1414
1515
1616from collections import namedtuple
17- from typing import Sequence , Tuple , Union
17+ from collections . abc import Sequence
1818
1919import numpy as np
2020import pymc as pm
2626def _psivar2musigma (
2727 psi : pt .TensorVariable ,
2828 explained_var : pt .TensorVariable ,
29- psi_mask : Union [ pt .TensorLike , None ] ,
30- ) -> Tuple [pt .TensorVariable , pt .TensorVariable ]:
29+ psi_mask : pt .TensorLike | None ,
30+ ) -> tuple [pt .TensorVariable , pt .TensorVariable ]:
3131 sign = pt .sign (psi - 0.5 )
3232 if psi_mask is not None :
3333 # any computation might be ignored for ~psi_mask
@@ -55,7 +55,7 @@ def _R2D2M2CP_beta(
5555 psi : pt .TensorVariable ,
5656 * ,
5757 psi_mask ,
58- dims : Union [ str , Sequence [str ] ],
58+ dims : str | Sequence [str ],
5959 centered = False ,
6060) -> pt .TensorVariable :
6161 """R2D2M2CP beta prior.
@@ -120,7 +120,7 @@ def _R2D2M2CP_beta(
120120def _broadcast_as_dims (
121121 * values : np .ndarray ,
122122 dims : Sequence [str ],
123- ) -> Union [ Tuple [ np .ndarray , ...], np .ndarray ] :
123+ ) -> tuple [ np .ndarray , ...] | np .ndarray :
124124 model = pm .modelcontext (None )
125125 shape = [len (model .coords [d ]) for d in dims ]
126126 ret = tuple (np .broadcast_to (v , shape ) for v in values )
@@ -135,7 +135,7 @@ def _psi_masked(
135135 positive_probs_std : pt .TensorLike ,
136136 * ,
137137 dims : Sequence [str ],
138- ) -> Tuple [ Union [ pt .TensorLike , None ] , pt .TensorVariable ]:
138+ ) -> tuple [ pt .TensorLike | None , pt .TensorVariable ]:
139139 if not (
140140 isinstance (positive_probs , pt .Constant ) and isinstance (positive_probs_std , pt .Constant )
141141 ):
@@ -172,10 +172,10 @@ def _psi_masked(
172172
173173def _psi (
174174 positive_probs : pt .TensorLike ,
175- positive_probs_std : Union [ pt .TensorLike , None ] ,
175+ positive_probs_std : pt .TensorLike | None ,
176176 * ,
177177 dims : Sequence [str ],
178- ) -> Tuple [ Union [ pt .TensorLike , None ] , pt .TensorVariable ]:
178+ ) -> tuple [ pt .TensorLike | None , pt .TensorVariable ]:
179179 if positive_probs_std is not None :
180180 mask , psi = _psi_masked (
181181 positive_probs = pt .as_tensor (positive_probs ),
@@ -194,9 +194,9 @@ def _psi(
194194
195195
196196def _phi (
197- variables_importance : Union [ pt .TensorLike , None ] ,
198- variance_explained : Union [ pt .TensorLike , None ] ,
199- importance_concentration : Union [ pt .TensorLike , None ] ,
197+ variables_importance : pt .TensorLike | None ,
198+ variance_explained : pt .TensorLike | None ,
199+ importance_concentration : pt .TensorLike | None ,
200200 * ,
201201 dims : Sequence [str ],
202202) -> pt .TensorVariable :
@@ -210,15 +210,15 @@ def _phi(
210210 variables_importance = pt .as_tensor (variables_importance )
211211 if importance_concentration is not None :
212212 variables_importance *= importance_concentration
213- return pm .Dirichlet ("phi" , variables_importance , dims = broadcast_dims + [ dim ])
213+ return pm .Dirichlet ("phi" , variables_importance , dims = [ * broadcast_dims , dim ])
214214 elif variance_explained is not None :
215215 if len (model .coords [dim ]) <= 1 :
216216 raise TypeError ("Can't use variance explained with less than two variables" )
217217 phi = pt .as_tensor (variance_explained )
218218 else :
219219 phi = _broadcast_as_dims (1.0 , dims = dims )
220220 if importance_concentration is not None :
221- return pm .Dirichlet ("phi" , importance_concentration * phi , dims = broadcast_dims + [ dim ])
221+ return pm .Dirichlet ("phi" , importance_concentration * phi , dims = [ * broadcast_dims , dim ])
222222 else :
223223 return phi
224224
@@ -233,12 +233,12 @@ def R2D2M2CP(
233233 * ,
234234 dims : Sequence [str ],
235235 r2 : pt .TensorLike ,
236- variables_importance : Union [ pt .TensorLike , None ] = None ,
237- variance_explained : Union [ pt .TensorLike , None ] = None ,
238- importance_concentration : Union [ pt .TensorLike , None ] = None ,
239- r2_std : Union [ pt .TensorLike , None ] = None ,
240- positive_probs : Union [ pt .TensorLike , None ] = 0.5 ,
241- positive_probs_std : Union [ pt .TensorLike , None ] = None ,
236+ variables_importance : pt .TensorLike | None = None ,
237+ variance_explained : pt .TensorLike | None = None ,
238+ importance_concentration : pt .TensorLike | None = None ,
239+ r2_std : pt .TensorLike | None = None ,
240+ positive_probs : pt .TensorLike | None = 0.5 ,
241+ positive_probs_std : pt .TensorLike | None = None ,
242242 centered : bool = False ,
243243) -> R2D2M2CPOut :
244244 """R2D2M2CP Prior.
@@ -413,7 +413,7 @@ def R2D2M2CP(
413413 year = {2023}
414414 }
415415 """
416- if not isinstance (dims , ( list , tuple ) ):
416+ if not isinstance (dims , list | tuple ):
417417 dims = (dims ,)
418418 * broadcast_dims , dim = dims
419419 input_sigma = pt .as_tensor (input_sigma )
@@ -438,7 +438,7 @@ def R2D2M2CP(
438438 r2 ,
439439 phi ,
440440 psi ,
441- dims = broadcast_dims + [ dim ],
441+ dims = [ * broadcast_dims , dim ],
442442 centered = centered ,
443443 psi_mask = mask ,
444444 )
0 commit comments