11import warnings
2- from typing import Sequence , Tuple , Union
2+ from typing import Sequence
33
44import numpy as np
55import pymc
66import pytensor .tensor as pt
7- from arviz import dict_to_dataset
7+ from arviz import InferenceData , dict_to_dataset
88from pymc import SymbolicRandomVariable
99from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
1010from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
1414from pymc .logprob .transforms import IntervalTransform
1515from pymc .model import Model
1616from pymc .pytensorf import compile_pymc , constant_fold
17- from pymc .util import _get_seeds_per_chain , treedict
17+ from pymc .util import RandomState , _get_seeds_per_chain , treedict
1818from pytensor import Mode , scan
1919from pytensor .compile import SharedVariable
2020from pytensor .graph import Constant , FunctionGraph , ancestors , clone_replace
@@ -235,7 +235,7 @@ def clone(self):
235235
236236 def marginalize (
237237 self ,
238- rvs_to_marginalize : Union [ TensorVariable , str , Sequence [TensorVariable ], Sequence [str ] ],
238+ rvs_to_marginalize : TensorVariable | Sequence [TensorVariable ] | str | Sequence [str ],
239239 ):
240240 if not isinstance (rvs_to_marginalize , Sequence ):
241241 rvs_to_marginalize = (rvs_to_marginalize ,)
@@ -292,7 +292,7 @@ def _to_transformed(self):
292292 fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
293293 return fn , transformed_names
294294
295- def unmarginalize (self , rvs_to_unmarginalize ):
295+ def unmarginalize (self , rvs_to_unmarginalize : Sequence [ TensorVariable ] ):
296296 for rv in rvs_to_unmarginalize :
297297 self .marginalized_rvs .remove (rv )
298298 if rv .name in self ._marginalized_named_vars_to_dims :
@@ -303,11 +303,11 @@ def unmarginalize(self, rvs_to_unmarginalize):
303303
304304 def recover_marginals (
305305 self ,
306- idata ,
307- var_names = None ,
308- return_samples = True ,
309- extend_inferencedata = True ,
310- random_seed = None ,
306+ idata : InferenceData ,
307+ var_names : Sequence [ str ] | None = None ,
308+ return_samples : bool = True ,
309+ extend_inferencedata : bool = True ,
310+ random_seed : RandomState = None ,
311311 ):
312312 """Computes posterior log-probabilities and samples of marginalized variables
313313 conditioned on parameters of the model given InferenceData with posterior group
@@ -648,7 +648,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
648648 return rvs_to_marginalize , marginalized_rvs
649649
650650
651- def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> Tuple [int , ...]:
651+ def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
652652 op = rv .owner .op
653653 if isinstance (op , Bernoulli ):
654654 return (0 , 1 )
0 commit comments