22from typing import Sequence , Union
33
44import numpy as np
5- import pymc
65import pytensor .tensor as pt
6+ import scipy
77from arviz import InferenceData , dict_to_dataset
8- from pymc import SymbolicRandomVariable
8+ from pymc import SymbolicRandomVariable , icdf
99from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
10+ from pymc .distributions import MvNormal
11+ from pymc .distributions .continuous import Continuous
1012from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
1113from pymc .distributions .transforms import Chain
1214from pymc .logprob .abstract import _logprob
1315from pymc .logprob .basic import conditional_logp , logp
1416from pymc .logprob .transforms import IntervalTransform
1517from pymc .model import Model
16- from pymc .pytensorf import compile_pymc , constant_fold
18+ from pymc .pytensorf import collect_default_updates , compile_pymc , constant_fold
1719from pymc .util import RandomState , _get_seeds_per_chain , treedict
1820from pytensor import Mode , scan
1921from pytensor .compile import SharedVariable
@@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
159161 f"Cannot marginalize { rv_to_marginalize } due to dependent Potential { pot } "
160162 )
161163
162- old_rvs , new_rvs = replace_finite_discrete_marginal_subgraph (
163- fg , rv_to_marginalize , self .basic_RVs + rvs_left_to_marginalize
164+ if isinstance (rv_to_marginalize .owner .op , Continuous ):
165+ subgraph_builder_fn = replace_continuous_marginal_subgraph
166+ else :
167+ subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
168+ old_rvs , new_rvs = subgraph_builder_fn (
169+ fg ,
170+ rv_to_marginalize ,
171+ self .basic_RVs + rvs_left_to_marginalize ,
172+ user_warnings = user_warnings ,
164173 )
165174
166- if user_warnings and len (new_rvs ) > 2 :
167- warnings .warn (
168- "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
169- f"Their joint logp terms will be assigned to the first RV: { old_rvs [1 ]} " ,
170- UserWarning ,
171- )
172-
173175 rvs_left_to_marginalize .remove (rv_to_marginalize )
174176
175177 for old_rv , new_rv in zip (old_rvs , new_rvs ):
@@ -267,7 +269,11 @@ def marginalize(
267269 )
268270
269271 rv_op = rv_to_marginalize .owner .op
270- if isinstance (rv_op , DiscreteMarkovChain ):
272+
273+ if isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
274+ pass
275+
276+ elif isinstance (rv_op , DiscreteMarkovChain ):
271277 if rv_op .n_lags > 1 :
272278 raise NotImplementedError (
273279 "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +282,11 @@ def marginalize(
276282 raise NotImplementedError (
277283 "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
278284 )
279- elif not isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
285+
286+ elif isinstance (rv_op , Continuous ):
287+ pass
288+
289+ else :
280290 raise NotImplementedError (
281291 f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
282292 )
@@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
549559 """Base class for Discrete Marginal Markov Chain RVs"""
550560
551561
562+ class QMCMarginalNormalRV (MarginalRV ):
563+ """Basec class for QMC Marginalized RVs"""
564+
565+ __props__ = ("qmc_order" ,)
566+
567+ def __init__ (self , * args , qmc_order : int , ** kwargs ):
568+ self .qmc_order = qmc_order
569+ super ().__init__ (* args , ** kwargs )
570+
571+
552572def static_shape_ancestors (vars ):
553573 """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
554574 return [
@@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers):
646666 ]
647667
648668
649- def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
669+ def replace_finite_discrete_marginal_subgraph (
670+ fgraph , rv_to_marginalize , all_rvs , user_warnings : bool = False
671+ ):
650672 # TODO: This should eventually be integrated in a more general routine that can
651673 # identify other types of supported marginalization, of which finite discrete
652674 # RVs is just one
@@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
655677 if not dependent_rvs :
656678 raise ValueError (f"No RVs depend on marginalized RV { rv_to_marginalize } " )
657679
680+ if user_warnings and len (dependent_rvs ) > 2 :
681+ warnings .warn (
682+ "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
683+ f"Their joint logp terms will be assigned to the first RV: { dependent_rvs [0 ]} " ,
684+ UserWarning ,
685+ )
686+
658687 ndim_supp = {rv .owner .op .ndim_supp for rv in dependent_rvs }
659688 if len (ndim_supp ) != 1 :
660689 raise NotImplementedError (
@@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
707736 return rvs_to_marginalize , marginalized_rvs
708737
709738
739+ def replace_continuous_marginal_subgraph (
740+ fgraph , rv_to_marginalize , all_rvs , user_warnings : bool = False
741+ ):
742+ dependent_rvs = find_conditional_dependent_rvs (rv_to_marginalize , all_rvs )
743+ if not dependent_rvs :
744+ raise ValueError (f"No RVs depend on marginalized RV { rv_to_marginalize } " )
745+
746+ marginalized_rv_input_rvs = find_conditional_input_rvs ([rv_to_marginalize ], all_rvs )
747+ dependent_rvs_input_rvs = [
748+ rv
749+ for rv in find_conditional_input_rvs (dependent_rvs , all_rvs )
750+ if rv is not rv_to_marginalize
751+ ]
752+
753+ input_rvs = [* marginalized_rv_input_rvs , * dependent_rvs_input_rvs ]
754+ rvs_to_marginalize = [rv_to_marginalize , * dependent_rvs ]
755+
756+ outputs = rvs_to_marginalize
757+ # We are strict about shared variables in SymbolicRandomVariables
758+ inputs = input_rvs + collect_shared_vars (rvs_to_marginalize , blockers = input_rvs )
759+
760+ # TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!!
761+ marginalized_rvs = QMCMarginalNormalRV (
762+ inputs = inputs ,
763+ outputs = [* outputs , * collect_default_updates (inputs = inputs , outputs = outputs ).values ()],
764+ ndim_supp = max ([rv .owner .op .ndim_supp for rv in dependent_rvs ]),
765+ qmc_order = 13 ,
766+ )(* inputs )[: len (outputs )]
767+
768+ fgraph .replace_all (tuple (zip (rvs_to_marginalize , marginalized_rvs )))
769+ return rvs_to_marginalize , marginalized_rvs
770+
771+
710772def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
711773 op = rv .owner .op
712774 dist_params = rv .owner .op .dist_params (rv .owner )
@@ -870,3 +932,65 @@ def step_alpha(logp_emission, log_alpha, log_P):
870932 # return is the joint probability of everything together, but PyMC still expects one logp for each one.
871933 dummy_logps = (pt .constant (0 ),) * (len (values ) - 1 )
872934 return joint_logp , * dummy_logps
935+
936+
937+ @_logprob .register (QMCMarginalNormalRV )
938+ def qmc_marginal_rv_logp (op , values , * inputs , ** kwargs ):
939+ # Clone the inner RV graph of the Marginalized RV
940+ marginalized_rvs_node = op .make_node (* inputs )
941+ # The MarginalizedRV contains the following outputs:
942+ # 1. The variable we marginalized
943+ # 2. The dependent variables
944+ # 3. The updates for the marginalized and dependent variables
945+ marginalized_rv , * inner_rvs_and_updates = clone_replace (
946+ op .inner_outputs ,
947+ replace = {u : v for u , v in zip (op .inner_inputs , marginalized_rvs_node .inputs )},
948+ )
949+ inner_rvs = inner_rvs_and_updates [: (len (inner_rvs_and_updates ) - 1 ) // 2 ]
950+
951+ marginalized_rv_node = marginalized_rv .owner
952+ marginalized_rv_op = marginalized_rv_node .op
953+
954+ # GET QMC draws from the marginalized RV
955+ # TODO: Make this an Op
956+ rng = marginalized_rv_op .rng_param (marginalized_rv_node )
957+ shape = constant_fold (tuple (marginalized_rv .shape ))
958+ size = np .prod (shape ).astype (int )
959+ n_draws = 2 ** op .qmc_order
960+
961+ # TODO: Wrap Sobol in an Op so we can control the RNG and change whenever
962+ qmc_engine = scipy .stats .qmc .Sobol (d = size , seed = rng .get_value (borrow = False ))
963+ uniform_draws = qmc_engine .random (n_draws ).reshape ((n_draws , * shape ))
964+
965+ if isinstance (marginalized_rv_op , MvNormal ):
966+ # Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
967+ mean , cov = marginalized_rv_op .dist_params (marginalized_rv_node )
968+ corr_matrix = pt .linalg .cholesky (cov ).mT
969+ base_draws = pt .as_tensor (scipy .stats .norm .ppf (0.5 + (1 - 1e-10 ) * (uniform_draws - 0.5 )))
970+ qmc_draws = base_draws @ corr_matrix + mean
971+ else :
972+ qmc_draws = icdf (marginalized_rv , uniform_draws )
973+
974+ qmc_draws .name = f"QMC_{ marginalized_rv_op .name } _draws"
975+
976+ # Obtain the logp of the dependent variables
977+ # We need to include the marginalized RV for correctness, we remove it later.
978+ inner_rv_values = dict (zip (inner_rvs , values ))
979+ marginalized_vv = marginalized_rv .clone ()
980+ rv_values = inner_rv_values | {marginalized_rv : marginalized_vv }
981+ logps_dict = conditional_logp (rv_values = rv_values , ** kwargs )
982+ # Pop the logp term corresponding to the marginalized RV
983+ # (it already got accounted for in the bias of the QMC draws)
984+ logps_dict .pop (marginalized_vv )
985+
986+ # Vectorize across QMC draws and take the mean on log scale
987+ core_marginalized_logps = list (logps_dict .values ())
988+ batched_marginalized_logps = vectorize_graph (
989+ core_marginalized_logps , replace = {marginalized_vv : qmc_draws }
990+ )
991+
992+ # Take the mean in log scale
993+ return tuple (
994+ pt .logsumexp (batched_marginalized_logp , axis = 0 ) - pt .log (n_draws )
995+ for batched_marginalized_logp in batched_marginalized_logps
996+ )
0 commit comments