
Commit 065ce0e

WIP new explicit RNG API
1 parent 619fe66 commit 065ce0e

12 files changed (+177, -87 lines)

pytensor/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def get_underlying_scalar_constant(v):


 # isort: off
-import pytensor.tensor.random.var
+import pytensor.tensor.random.variable
 import pytensor.sparse
 from pytensor.ifelse import ifelse
 from pytensor.scan import checkpoints

pytensor/compile/function/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -1,6 +1,7 @@
 import logging
 import re
 import traceback as tb
+import warnings
 from collections.abc import Iterable
 from pathlib import Path

@@ -102,7 +103,7 @@ def function(
     givens: Iterable[tuple[Variable, Variable]]
     | dict[Variable, Variable]
     | None = None,
-    no_default_updates: bool = False,
+    no_default_updates: bool | None = None,
     accept_inplace: bool = False,
     name: str | None = None,
     rebuild_strict: bool = True,

@@ -266,6 +267,15 @@ def opt_log1p(node):
     of just writing it in C from scratch.

     """
+    if no_default_updates is not None:
+        warnings.warn(
+            "The no_default_updates parameter is deprecated and will be "
+            "removed in a future version of PyTensor. Please set updates manually.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+    else:
+        no_default_updates = False
     if isinstance(outputs, dict):
         assert all(isinstance(k, str) for k in outputs)

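A minimal sketch of the calling pattern this deprecation points toward, assuming a shared NumPy Generator and the return_next_rng keyword introduced elsewhere in this commit; instead of toggling no_default_updates, the RNG state transition is spelled out through updates:

import numpy as np

import pytensor
import pytensor.tensor as pt

# Shared RNG whose state used to advance implicitly through a default update.
rng = pytensor.shared(np.random.default_rng(42), name="rng")

# The draw also hands back the advanced RNG.
next_rng, x = pt.random.normal(0, 1, size=(3,), rng=rng, return_next_rng=True)

# The state transition is now an explicit update, so there is nothing left
# for no_default_updates to switch off.
f = pytensor.function([], x, updates={rng: next_rng})
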
pytensor/compile/sharedvalue.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 """Provide a simple user friendly API to PyTensor-managed memory."""

 import copy
+import warnings
 from contextlib import contextmanager
 from functools import singledispatch
 from typing import TYPE_CHECKING

@@ -161,6 +162,7 @@ def default_update(self) -> Variable | None:

     @default_update.setter
     def default_update(self, value):
+        warnings.warn("Setting default_update is deprecated.", DeprecationWarning)
         if value is not None:
             self._default_update = self.type.filter_variable(value, allow_convert=True)
         else:

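A short sketch of what the new warning targets, under the same shared-Generator assumptions as above; assigning default_update still works but is flagged, the explicit updates argument being the replacement:

import numpy as np

import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(0), name="rng")
next_rng, x = pt.random.normal(0, 1, rng=rng, return_next_rng=True)

# Deprecated: attaching the state transition to the shared variable itself
# now emits a DeprecationWarning.
rng.default_update = next_rng

# Preferred: declare the transition explicitly at compile time.
f = pytensor.function([], x, updates={rng: next_rng})
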
pytensor/tensor/random/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 from pytensor.tensor.random.basic import *
 from pytensor.tensor.random.op import default_rng
 from pytensor.tensor.random.utils import RandomStream
+from pytensor.tensor.random.variable import rng

pytensor/tensor/random/basic.py

Lines changed: 24 additions & 12 deletions
@@ -285,7 +285,7 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
 normal = NormalRV()


-def standard_normal(*, size=None, rng=None, dtype=None):
+def standard_normal(*, size=None, rng=None, dtype=None, **kwargs):
     """Draw samples from a standard normal distribution.

     Signature

@@ -302,7 +302,7 @@ def standard_normal(*, size=None, rng=None, dtype=None):
     is returned.

     """
-    return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
+    return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype, **kwargs)


 class HalfNormalRV(ScipyRandomVariable):

@@ -516,7 +516,7 @@ def chisquare(df, size=None, **kwargs):
     return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)


-def rayleigh(scale=1.0, *, size=None, **kwargs):
+def rayleigh(scale=1.0, *, size=None, return_next_rng=False, **kwargs):
     r"""Draw samples from a Rayleigh distribution.

     The probability density function for `rayleigh` with parameter `scale` is given by:

@@ -550,7 +550,13 @@ def rayleigh(scale=1.0, *, size=None, **kwargs):
     scale = as_tensor_variable(scale)
     if size is None:
         size = scale.shape
-    return sqrt(chisquare(df=2, size=size, **kwargs)) * scale
+    next_rng, chisquare_draws = chisquare(
+        df=2, size=size, return_next_rng=True, **kwargs
+    )
+    rayleigh_draws = sqrt(chisquare_draws) * scale
+    if return_next_rng:
+        return next_rng, rayleigh_draws
+    return rayleigh_draws


 class ParetoRV(ScipyRandomVariable):

@@ -1986,7 +1992,7 @@ def rng_fn(self, *params):
         return out


-def choice(a, size=None, replace=True, p=None, rng=None):
+def choice(a, size=None, replace=True, p=None, rng=None, return_next_rng=False):
     r"""Generate a random sample from an array.


@@ -2016,17 +2022,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
         # This is equivalent to the numpy implementation:
         # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
         if p is None:
-            idxs = integers(0, a_size, size=size, rng=rng)
+            next_rng, idxs = integers(
+                0, a_size, size=size, rng=rng, return_next_rng=True
+            )
         else:
-            idxs = categorical(p, size=size, rng=rng)
+            next_rng, idxs = categorical(p, size=size, rng=rng, return_next_rng=True)

         if a.type.ndim == 0:
             # A was an implicit arange, we don't need to do any indexing
             # TODO: Add rewrite for this optimization if users passed arange
-            return idxs
-
-        # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
-        return a[idxs]
+            out = idxs
+        else:
+            # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
+            out = a[idxs]
+        if return_next_rng:
+            return next_rng, out
+        else:
+            return out

     # Sampling with p is not as trivial
     # It involves some form of rejection sampling or iterative shuffling under the hood.

@@ -2063,7 +2075,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
     op = ChoiceWithoutReplacement(signature=signature, dtype=dtype)

     params = (a, core_shape) if p is None else (a, p, core_shape)
-    return op(*params, size=None, rng=rng)
+    return op(*params, size=None, rng=rng, return_next_rng=return_next_rng)


 class PermutationRV(RandomVariable):

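A sketch of how the reworked helpers compose, assuming rayleigh and choice behave as in the hunks above; each call consumes an RNG variable and returns the advanced one next to its draws, so successive draws can be chained without hidden state:

import numpy as np

import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(1), name="rng")

# rayleigh threads the RNG through its internal chisquare draw and hands it back.
rng1, r = pt.random.rayleigh(scale=2.0, size=(3,), rng=rng, return_next_rng=True)

# The returned RNG feeds the next draw, keeping the two samples independent.
rng2, c = pt.random.choice(pt.as_tensor([1, 2, 3]), size=(2,), rng=rng1, return_next_rng=True)

# Only the final RNG state needs to be written back.
f = pytensor.function([], [r, c], updates={rng: rng2})
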
pytensor/tensor/random/op.py

Lines changed: 32 additions & 7 deletions
@@ -314,7 +314,16 @@ def infer_shape(self, fgraph, node, input_shapes):

         return [None, list(shape)]

-    def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
+    def __call__(
+        self,
+        *args,
+        size=None,
+        name=None,
+        rng=None,
+        dtype=None,
+        return_next_rng: bool | None = None,
+        **kwargs,
+    ):
         if dtype is None:
             dtype = self.dtype
         if dtype == "floatX":

@@ -332,15 +341,31 @@ def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
             props["dtype"] = dtype
             new_op = type(self)(**props)
             return new_op.__call__(
-                *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
+                *args,
+                size=size,
+                name=name,
+                rng=rng,
+                dtype=dtype,
+                return_next_rng=return_next_rng,
+                **kwargs,
             )

-        res = super().__call__(rng, size, *args, **kwargs)
-
+        node = self.make_node(rng, size, *args)
+        outputs = node.outputs
         if name is not None:
-            res.name = name
-
-        return res
+            outputs[self.default_output].name = name
+        if return_next_rng:
+            return outputs
+        else:
+            if return_next_rng is None:
+                warnings.warn(
+                    "The default behavior of RandomVariable.__call__ is changing to return both the next RNG and the draws. "
+                    "Please set return_next_rng explicitly to avoid this warning.",
+                )
+            out = outputs[self.default_output]
+            if kwargs.get("return_list", False):
+                return [out]
+            return out

     def make_node(self, rng, size, *dist_params):
         """Create a random variable node.

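The transitional behaviour of RandomVariable.__call__ sketched below follows from the hunk above, assuming a shared NumPy Generator: leaving return_next_rng unset keeps the old single-output result but emits the warning, while opting in returns the node's advanced RNG and the draws as a pair:

import numpy as np

import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(7), name="rng")

# Opt in: both node outputs come back, (advanced RNG, draws).
next_rng, x = pt.random.normal(0, 1, size=(2,), rng=rng, return_next_rng=True)

# Legacy call: only the draws are returned, plus a warning about the
# upcoming change of default.
y = pt.random.normal(0, 1, size=(2,), rng=rng)
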
pytensor/tensor/random/utils.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,4 @@
+import warnings
 from collections.abc import Callable, Sequence
 from functools import wraps
 from itertools import zip_longest

@@ -231,6 +232,11 @@ def __init__(
             [np.random.SeedSequence], np.random.Generator
         ] = np.random.default_rng,
     ):
+        warnings.warn(
+            "pytensor.tensor.random.utils.RandomStream is deprecated and will be removed in a future release.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
         if namespace is None:
             from pytensor.tensor.random import basic  # pylint: disable=import-self

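A migration sketch for this deprecation: the seed-managed RandomStream is what now warns at construction, and the explicit RandomGeneratorVariable API added in this commit is the intended replacement (shown at the graph level only, since feeding the explicit RNG at compile time is still WIP):

import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

# Deprecated: construction emits a DeprecationWarning and the RNG state is
# managed implicitly behind the stream.
srng = RandomStream(seed=123)
x = srng.normal(0, 1, size=(2,))

# Replacement: an explicit, symbolic RNG input whose draws also return the
# advanced RNG.
rng = pt.random.rng("rng")
next_rng, y = rng.normal(0, 1, size=(2,))
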
pytensor/tensor/random/var.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

pytensor/tensor/random/variable.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import copy
+import warnings
+from functools import wraps
+from typing import TypeAlias
+
+import numpy as np
+
+from pytensor import config
+from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
+from pytensor.graph.basic import OptionalApplyType, Variable
+from pytensor.tensor.random.basic import normal
+from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
+from pytensor.tensor.variable import TensorVariable
+
+
+RNG_AND_DRAW: TypeAlias = tuple["RandomGeneratorVariable", TensorVariable]
+
+
+def warn_reuse(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if getattr(self.tag, "used", False) and config.warn_rng.reuse:
+            warnings.warn(
+                f"RandomGeneratorVariable {self} has already been used. "
+                "You probably want to use the new RandomGeneratorVariable that was returned when you used it.",
+                UserWarning,
+            )
+        self.tag.used = True
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class _random_generator_py_operators:
+    @warn_reuse
+    def normal(self, loc=0, scale=1, size=None) -> RNG_AND_DRAW:
+        return normal(loc, scale, size=size, rng=self, return_next_rng=True)
+
+
+class RandomGeneratorVariable(
+    _random_generator_py_operators,
+    Variable[RandomGeneratorType, OptionalApplyType],
+):
+    """The Variable type used for random number generator states."""
+
+
+RandomGeneratorType.variable_type = RandomGeneratorVariable
+
+
+def rng(name=None) -> RandomGeneratorVariable:
+    """Create a new default random number generator variable.
+
+    Returns
+    -------
+    RandomGeneratorVariable
+        A new random number generator variable initialized with the default
+        numpy random generator.
+    """
+
+    return random_generator_type(name=name)
+
+
+class RandomGeneratorSharedVariable(SharedVariable, RandomGeneratorVariable):
+    def __str__(self):
+        return self.name or f"RNG({self.container!r})"
+
+
+@shared_constructor.register(np.random.RandomState)
+@shared_constructor.register(np.random.Generator)
+def randomgen_constructor(
+    value, name=None, strict=False, allow_downcast=None, borrow=False
+):
+    r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
+    if isinstance(value, np.random.RandomState):
+        raise TypeError(
+            "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
+        )
+
+    rng_sv_type = RandomGeneratorSharedVariable
+    rng_type = random_generator_type
+
+    if not borrow:
+        value = copy.deepcopy(value)
+
+    return rng_sv_type(
+        type=rng_type,
+        value=value,
+        strict=strict,
+        allow_downcast=allow_downcast,
+        name=name,
+    )

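A sketch of the workflow this new module appears to target, staying at the graph level because the compilation story is still WIP; the warn_rng.reuse config flag referenced by warn_reuse is assumed to be enabled:

import pytensor.tensor as pt

# A fresh symbolic RNG state with no hidden default update attached.
rng0 = pt.random.rng("rng0")

# Drawing returns the advanced RNG together with the draws.
rng1, x = rng0.normal(0, 1, size=(3,))

# Thread the returned RNG into the next draw; calling rng0.normal(...) again
# instead would trigger the warn_reuse warning, since rng0 was already consumed.
rng2, y = rng1.normal(x, 1.0)
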
tests/link/numba/test_random.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def test_rng_copy():
     rng.type.values_eq(rng.get_value(), np.random.default_rng(123))


-def test_rng_non_default_update():
+def test_rng_custom_update():
     rng = shared(np.random.default_rng(1))
     rng_new = shared(np.random.default_rng(2))
