6262
6363
6464class MarginalModel (Model ):
65- """Subclass of PyMC Model that implements functionality for automatic
66- marginalization of variables in the logp transformation
67-
68- After defining the full Model, the `marginalize` method can be used to indicate a
69- subset of variables that should be marginalized
70-
71- Notes
72- -----
73- Marginalization functionality is still very restricted. Only finite discrete
74- variables can be marginalized. Deterministics and Potentials cannot be conditionally
75- dependent on the marginalized variables.
76-
77- Furthermore, not all instances of such variables can be marginalized. If a variable
78- has batched dimensions, it is required that any conditionally dependent variables
79- use information from an individual batched dimension. In other words, the graph
80- connecting the marginalized variable(s) to the dependent variable(s) must be
81- composed strictly of Elemwise Operations. This is necessary to ensure an efficient
82- logprob graph can be generated. If you want to bypass this restriction you can
83- separate each dimension of the marginalized variable into the scalar components
84- and then stack them together. Note that such graphs will grow exponentially in the
85- number of marginalized variables.
86-
87- For the same reason, it's not possible to marginalize RVs with multivariate
88- dependent RVs.
89-
90- Examples
91- --------
92- Marginalize over a single variable
93-
94- .. code-block:: python
95-
96- import pymc as pm
97- from pymc_extras import MarginalModel
98-
99- with MarginalModel() as m:
100- p = pm.Beta("p", 1, 1)
101- x = pm.Bernoulli("x", p=p, shape=(3,))
102- y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
103-
104- m.marginalize([x])
105-
106- idata = pm.sample()
107-
108- """
10965
11066 def __init__ (self , * args , ** kwargs ):
11167 raise TypeError (
@@ -147,10 +103,29 @@ def _unique(seq: Sequence) -> list:
147103def marginalize (model : Model , rvs_to_marginalize : ModelRVs ) -> MarginalModel :
148104 """Marginalize a subset of variables in a PyMC model.
149105
150- This creates a class of `MarginalModel` from an existing `Model`, with the specified
151- variables marginalized.
106+ Notes
107+ -----
108+ Marginalization functionality is still very restricted. Only finite discrete
109+ variables and some closed-form graphs can be marginalized.
110+ Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.
152111
153- See documentation for `MarginalModel` for more information.
112+
113+ Examples
114+ --------
115+ Marginalize over a single variable
116+
117+ .. code-block:: python
118+
119+ import pymc as pm
120+ from pymc_extras import marginalize
121+
122+ with pm.Model() as m:
123+ p = pm.Beta("p", 1, 1)
124+ x = pm.Bernoulli("x", p=p, shape=(3,))
125+ y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
126+
127+ with marginalize(m, [x]) as marginal_m:
128+ idata = pm.sample()
154129
155130 Parameters
156131 ----------
@@ -161,8 +136,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
161136
162137 Returns
163138 -------
164- marginal_model: MarginalModel
165- Marginal model with the specified variables marginalized.
139+ marginal_model: Model
140+ PyMC model with the specified variables marginalized.
166141 """
167142 if isinstance (rvs_to_marginalize , str | Variable ):
168143 rvs_to_marginalize = (rvs_to_marginalize ,)
@@ -176,20 +151,20 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
176151 if rv_to_marginalize not in model .free_RVs :
177152 raise ValueError (f"Marginalized RV { rv_to_marginalize } is not a free RV in the model" )
178153
179- rv_op = rv_to_marginalize .owner .op
180- if isinstance (rv_op , DiscreteMarkovChain ):
181- if rv_op .n_lags > 1 :
182- raise NotImplementedError (
183- "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
184- )
185- if rv_to_marginalize .owner .inputs [0 ].type .ndim > 2 :
186- raise NotImplementedError (
187- "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
188- )
189- elif not isinstance (rv_op , Bernoulli | Categorical | DiscreteUniform ):
190- raise NotImplementedError (
191- f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
192- )
154+ # rv_op = rv_to_marginalize.owner.op
155+ # if isinstance(rv_op, DiscreteMarkovChain):
156+ # if rv_op.n_lags > 1:
157+ # raise NotImplementedError(
158+ # "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
159+ # )
160+ # if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
161+ # raise NotImplementedError(
162+ # "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
163+ # )
164+ # elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
165+ # raise NotImplementedError(
166+ # f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
167+ # )
193168
194169 fg , memo = fgraph_from_model (model )
195170 rvs_to_marginalize = [memo [rv ] for rv in rvs_to_marginalize ]
@@ -241,11 +216,52 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
241216 ]
242217 input_rvs = _unique ((* marginalized_rv_input_rvs , * other_direct_rv_ancestors ))
243218
244- replace_finite_discrete_marginal_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
219+ marginalize_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
245220
246221 return model_from_fgraph (fg , mutate_fgraph = True )
247222
248223
def marginalize_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs) -> None:
    """Swap the subgraph of a marginalized RV and its dependents for a single `MarginalRV` node.

    The marginalized variable, its dependent variables and any default RNG updates
    are wrapped as the outputs of one `MarginalRV` op, and the corresponding
    variables in ``fgraph`` are replaced by that op's outputs. ``fgraph`` is
    modified in place through ``fgraph.replace_all``; nothing is returned.

    Parameters
    ----------
    fgraph
        Graph in which the replacements are performed (presumably the model
        function graph built by the caller -- TODO confirm).
    rv_to_marginalize
        Model variable being marginalized. Its owner inputs are unpacked as two
        leading entries followed by the dims.
    dependent_rvs
        Variables that depend on ``rv_to_marginalize``.
    input_rvs
        List of variables used as the inputs of the wrapped subgraph.
    """
    output_rvs = [rv_to_marginalize, *dependent_rvs]
    rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)

    # Outer inputs/outputs of the subgraph: the RVs themselves plus the RNG
    # variables (keys) and their updated counterparts (values).
    outputs = [*output_rvs, *rng_updates.values()]
    inputs = input_rvs + list(rng_updates)
    # Add any other shared variable inputs
    inputs.extend(collect_shared_vars(output_rvs, blockers=inputs))

    # Rebuild the subgraph on fresh copies of the inputs and strip model variables.
    cloned_inputs = [outer_input.clone() for outer_input in inputs]
    cloned_outputs = remove_model_vars(
        clone_replace(outputs, replace=dict(zip(inputs, cloned_inputs)))
    )

    # Skip the first two owner inputs; what remains are the dims.
    _rv, _value, *dims = rv_to_marginalize.owner.inputs
    marginalization_op = MarginalRV(
        inputs=cloned_inputs,
        outputs=cloned_outputs,
        dims=dims,
        n_dependent_rvs=len(dependent_rvs),
    )

    new_outputs = marginalization_op(*inputs)
    assert len(new_outputs) == len(outputs)
    for replaced, replacement in zip(outputs, new_outputs):
        replacement.name = replaced.name

    # The marginalized variable and non model-valued outputs are replaced
    # directly. For the remaining model-valued outputs only the underlying RV
    # (first owner input) is replaced, keeping the same value, transform and dims.
    model_replacements = [
        (
            replaced
            if replaced is rv_to_marginalize
            or not isinstance(replaced.owner.op, ModelValuedVar)
            else replaced.owner.inputs[0],
            replacement,
        )
        for replaced, replacement in zip(outputs, new_outputs)
    ]

    fgraph.replace_all(model_replacements)
263+
264+
249265@node_rewriter (tracks = [MarginalRV ])
250266def local_unmarginalize (fgraph , node ):
251267 unmarginalized_rv , * dependent_rvs_and_rngs = inline_ofg_outputs (node .op , node .inputs )
0 commit comments