@@ -285,7 +285,7 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
285285normal = NormalRV ()
286286
287287
288- def standard_normal (* , size = None , rng = None , dtype = None ):
288+ def standard_normal (* , size = None , rng = None , dtype = None , ** kwargs ):
289289 """Draw samples from a standard normal distribution.
290290
291291 Signature
@@ -302,7 +302,7 @@ def standard_normal(*, size=None, rng=None, dtype=None):
302302 is returned.
303303
304304 """
305- return normal (0.0 , 1.0 , size = size , rng = rng , dtype = dtype )
305+ return normal (0.0 , 1.0 , size = size , rng = rng , dtype = dtype , ** kwargs )
306306
307307
308308class HalfNormalRV (ScipyRandomVariable ):
@@ -516,7 +516,7 @@ def chisquare(df, size=None, **kwargs):
516516 return gamma (shape = df / 2.0 , scale = 2.0 , size = size , ** kwargs )
517517
518518
519- def rayleigh (scale = 1.0 , * , size = None , ** kwargs ):
519+ def rayleigh (scale = 1.0 , * , size = None , return_next_rng = False , ** kwargs ):
520520 r"""Draw samples from a Rayleigh distribution.
521521
522522 The probability density function for `rayleigh` with parameter `scale` is given by:
@@ -550,7 +550,13 @@ def rayleigh(scale=1.0, *, size=None, **kwargs):
550550 scale = as_tensor_variable (scale )
551551 if size is None :
552552 size = scale .shape
553- return sqrt (chisquare (df = 2 , size = size , ** kwargs )) * scale
553+ next_rng , chisquare_draws = chisquare (
554+ df = 2 , size = size , return_next_rng = True , ** kwargs
555+ )
556+ rayleigh_draws = sqrt (chisquare_draws ) * scale
557+ if return_next_rng :
558+ return next_rng , rayleigh_draws
559+ return rayleigh_draws
554560
555561
556562class ParetoRV (ScipyRandomVariable ):
@@ -1986,7 +1992,7 @@ def rng_fn(self, *params):
19861992 return out
19871993
19881994
1989- def choice (a , size = None , replace = True , p = None , rng = None ):
1995+ def choice (a , size = None , replace = True , p = None , rng = None , return_next_rng = False ):
19901996 r"""Generate a random sample from an array.
19911997
19921998
@@ -2016,17 +2022,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
20162022 # This is equivalent to the numpy implementation:
20172023 # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
20182024 if p is None :
2019- idxs = integers (0 , a_size , size = size , rng = rng )
2025+ next_rng , idxs = integers (
2026+ 0 , a_size , size = size , rng = rng , return_next_rng = True
2027+ )
20202028 else :
2021- idxs = categorical (p , size = size , rng = rng )
2029+ next_rng , idxs = categorical (p , size = size , rng = rng , return_next_rng = True )
20222030
20232031 if a .type .ndim == 0 :
20242032 # A was an implicit arange, we don't need to do any indexing
20252033 # TODO: Add rewrite for this optimization if users passed arange
2026- return idxs
2027-
2028- # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
2029- return a [idxs ]
2034+ out = idxs
2035+ else :
2036+ # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
2037+ out = a [idxs ]
2038+ if return_next_rng :
2039+ return next_rng , out
2040+ else :
2041+ return out
20302042
20312043 # Sampling with p is not as trivial
20322044 # It involves some form of rejection sampling or iterative shuffling under the hood.
@@ -2063,7 +2075,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
20632075 op = ChoiceWithoutReplacement (signature = signature , dtype = dtype )
20642076
20652077 params = (a , core_shape ) if p is None else (a , p , core_shape )
2066- return op (* params , size = None , rng = rng )
2078+ return op (* params , size = None , rng = rng , return_next_rng = return_next_rng )
20672079
20682080
20692081class PermutationRV (RandomVariable ):
0 commit comments