22from typing import Sequence , Union
33
44import numpy as np
5- import pymc
65import pytensor .tensor as pt
6+ import scipy
77from arviz import InferenceData , dict_to_dataset
8- from pymc import SymbolicRandomVariable
8+ from pymc import SymbolicRandomVariable , icdf
99from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
10+ from pymc .distributions .continuous import Continuous , Normal
1011from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
1112from pymc .distributions .transforms import Chain
1213from pymc .logprob .abstract import _logprob
@@ -159,7 +160,11 @@ def _marginalize(self, user_warnings=False):
159160 f"Cannot marginalize { rv_to_marginalize } due to dependent Potential { pot } "
160161 )
161162
162- old_rvs , new_rvs = replace_finite_discrete_marginal_subgraph (
163+ if isinstance (rv_to_marginalize .owner .op , Continuous ):
164+ subgraph_builder_fn = replace_continuous_marginal_subgraph
165+ else :
166+ subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
167+ old_rvs , new_rvs = subgraph_builder_fn (
163168 fg , rv_to_marginalize , self .basic_RVs + rvs_left_to_marginalize
164169 )
165170
@@ -267,7 +272,11 @@ def marginalize(
267272 )
268273
269274 rv_op = rv_to_marginalize .owner .op
270- if isinstance (rv_op , DiscreteMarkovChain ):
275+
276+ if isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
277+ pass
278+
279+ elif isinstance (rv_op , DiscreteMarkovChain ):
271280 if rv_op .n_lags > 1 :
272281 raise NotImplementedError (
273282 "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +285,11 @@ def marginalize(
276285 raise NotImplementedError (
277286 "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
278287 )
279- elif not isinstance (rv_op , (Bernoulli , Categorical , DiscreteUniform )):
288+
289+ elif isinstance (rv_op , Normal ):
290+ pass
291+
292+ else :
280293 raise NotImplementedError (
281294 f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
282295 )
@@ -549,6 +562,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
549562 """Base class for Discrete Marginal Markov Chain RVs"""
550563
551564
class QMCMarginalNormalRV(MarginalRV):
    """Base class for Quasi-Monte Carlo marginalized RVs.

    Carries ``qmc_order``: the marginal logp is approximated using
    ``2**qmc_order`` low-discrepancy (Sobol) draws of the marginalized
    variable.
    """

    __props__ = ("qmc_order",)

    def __init__(self, *args, qmc_order: int, **kwargs):
        # Store the order before delegating the rest of the setup upstream.
        self.qmc_order = qmc_order
        super().__init__(*args, **kwargs)
574+
552575def static_shape_ancestors (vars ):
553576 """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
554577 return [
@@ -707,6 +730,36 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
707730 return rvs_to_marginalize , marginalized_rvs
708731
709732
def replace_continuous_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs, qmc_order: int = 13):
    """Replace a continuous RV and its dependents by a QMC-marginalized node.

    Parameters
    ----------
    fgraph
        FunctionGraph in which the replacement happens (modified in place).
    rv_to_marginalize
        Continuous RV to be integrated out.
    all_rvs
        All model RVs, used to discover conditional dependencies.
    qmc_order : int, default 13
        log2 of the number of QMC draws used when evaluating the marginal
        logp (i.e. ``2**qmc_order`` Sobol draws). Previously hard-coded;
        exposed as a parameter with the same default for backward
        compatibility.

    Returns
    -------
    tuple
        ``(rvs_to_marginalize, marginalized_rvs)``: the original RVs that
        were replaced and their replacements.

    Raises
    ------
    ValueError
        If no other RV depends on ``rv_to_marginalize`` (nothing to
        marginalize over).
    """
    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    # Inputs of the dependent RVs, excluding the RV being marginalized itself.
    dependent_rvs_input_rvs = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

    input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
    rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

    outputs = rvs_to_marginalize
    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

    marginalized_rvs = QMCMarginalNormalRV(
        inputs=inputs,
        outputs=outputs,
        ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
        qmc_order=qmc_order,
    )(*inputs)

    fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
    return rvs_to_marginalize, marginalized_rvs
761+
762+
710763def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
711764 op = rv .owner .op
712765 dist_params = rv .owner .op .dist_params (rv .owner )
@@ -870,3 +923,47 @@ def step_alpha(logp_emission, log_alpha, log_P):
870923 # return is the joint probability of everything together, but PyMC still expects one logp for each one.
871924 dummy_logps = (pt .constant (0 ),) * (len (values ) - 1 )
872925 return joint_logp , * dummy_logps
926+
927+
@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
    """Approximate the marginal logp of the dependent RVs by Quasi-Monte Carlo.

    The marginalized RV is integrated out by evaluating the joint logp at
    ``2**op.qmc_order`` Sobol draws (pushed through the RV's inverse CDF)
    and averaging on the log scale.

    Returns one logp term per value variable, matching PyMC's convention.
    """
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    marginalized_rv_node = marginalized_rv.owner
    marginalized_rv_op = marginalized_rv_node.op

    # GET QMC draws from the marginalized RV
    # TODO: Make this an Op
    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
    shape = constant_fold(tuple(marginalized_rv.shape))
    size = np.prod(shape).astype(int)
    n_draws = 2**op.qmc_order
    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
    qmc_draws = icdf(marginalized_rv, uniform_draws)
    qmc_draws.name = f"QMC_{op.name}_draws"

    # Obtain the logp of the dependent variables
    # We need to include the marginalized RV for correctness, we remove it later.
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
    # Pop the logp term corresponding to the marginalized RV
    # (it already got accounted for in the bias of the QMC draws)
    logps_dict.pop(marginalized_vv)

    # Vectorize across QMC draws and take the mean on log scale:
    # log((1/n) * sum_i exp(logp_i)) = logsumexp_i(logp_i) - log(n)
    core_marginalized_logps = list(logps_dict.values())
    batched_marginalized_logps = vectorize_graph(
        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
    )
    # Fix: normalize by the number of draws (length of axis 0 being reduced),
    # not by `size` (the flattened event size of the marginalized RV), which
    # biased the estimate whenever size != n_draws.
    return tuple(
        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
        for batched_marginalized_logp in batched_marginalized_logps
    )
0 commit comments