From 619fe66b1dee5cb4c6bcb7bfc7cfa93f9a2494b6 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Wed, 29 Oct 2025 23:23:09 +0100
Subject: [PATCH 1/3] Better guess for IterationError

---
 pytensor/tensor/variable.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 31e08fd39b..1e564488bc 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -614,8 +614,7 @@ def __iter__(self):
         # This prevents accidental iteration via sum(self)
         raise TypeError(
             "TensorType does not support iteration.\n"
-            "\tDid you pass a PyTensor variable to a function that expects a list?\n"
-            "\tMaybe you are using builtins.sum instead of pytensor.tensor.sum?"
+            "\tDid you try to unpack a Variable or pass it to a function that expects a list?\n"
         )
 
     @property

From 065ce0ef251e336218a31271fd0200bbfabc269c Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Wed, 29 Oct 2025 23:24:01 +0100
Subject: [PATCH 2/3] .WIP new explicit RNG API

---
 pytensor/__init__.py                        |  2 +-
 pytensor/compile/function/__init__.py       | 12 ++-
 pytensor/compile/sharedvalue.py             |  2 +
 pytensor/tensor/random/__init__.py          |  1 +
 pytensor/tensor/random/basic.py             | 36 +++++---
 pytensor/tensor/random/op.py                | 39 +++++++--
 pytensor/tensor/random/utils.py             |  6 ++
 pytensor/tensor/random/var.py               | 37 ---------
 pytensor/tensor/random/variable.py          | 91 +++++++++++++++++++++
 tests/link/numba/test_random.py             |  2 +-
 tests/scan/test_basic.py                    | 34 ++------
 tests/tensor/random/rewriting/test_basic.py |  2 +-
 12 files changed, 177 insertions(+), 87 deletions(-)
 delete mode 100644 pytensor/tensor/random/var.py
 create mode 100644 pytensor/tensor/random/variable.py

diff --git a/pytensor/__init__.py b/pytensor/__init__.py
index 12f67c9a37..d16fdcc69f 100644
--- a/pytensor/__init__.py
+++ b/pytensor/__init__.py
@@ -159,7 +159,7 @@ def get_underlying_scalar_constant(v):
 
 
 # isort: off
-import pytensor.tensor.random.var
+import pytensor.tensor.random.variable
 import pytensor.sparse
 from pytensor.ifelse import ifelse
 from pytensor.scan import checkpoints

diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py
index ffce6db4fb..cd2892c9ff 100644
--- a/pytensor/compile/function/__init__.py
+++ b/pytensor/compile/function/__init__.py
@@ -1,6 +1,7 @@
 import logging
 import re
 import traceback as tb
+import warnings
 from collections.abc import Iterable
 from pathlib import Path
 
@@ -102,7 +103,7 @@ def function(
     givens: Iterable[tuple[Variable, Variable]]
     | dict[Variable, Variable]
    | None = None,
-    no_default_updates: bool = False,
+    no_default_updates: bool | None = None,
     accept_inplace: bool = False,
     name: str | None = None,
     rebuild_strict: bool = True,
@@ -266,6 +267,15 @@ def opt_log1p(node):
     of just writing it in C from scratch.
 
     """
+    if no_default_updates is not None:
+        warnings.warn(
+            "The no_default_updates parameter is deprecated and will be "
+            "removed in a future version of PyTensor. Please set updates manually.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+    else:
+        no_default_updates = False
 
     if isinstance(outputs, dict):
         assert all(isinstance(k, str) for k in outputs)
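The deprecation above makes RNG state updates opt-in rather than implicit. A minimal sketch of the intended migration, assuming the ``.normal`` method on RNG shared variables introduced later in this patch (the variable names are illustrative, not part of the diff):

    import numpy as np
    from pytensor import function, shared

    rng = shared(np.random.default_rng(0))
    next_rng, x = rng.normal(0, 1, size=(2,))

    # Previously the generator advanced through an implicit default update
    # (unless no_default_updates=True was passed); now the update is explicit.
    f = function([], x, updates={rng: next_rng})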
diff --git a/pytensor/compile/sharedvalue.py b/pytensor/compile/sharedvalue.py
index 8c6f0726a4..dc27e070e0 100644
--- a/pytensor/compile/sharedvalue.py
+++ b/pytensor/compile/sharedvalue.py
@@ -1,6 +1,7 @@
 """Provide a simple user friendly API to PyTensor-managed memory."""
 
 import copy
+import warnings
 from contextlib import contextmanager
 from functools import singledispatch
 from typing import TYPE_CHECKING
@@ -161,6 +162,7 @@ def default_update(self) -> Variable | None:
 
     @default_update.setter
     def default_update(self, value):
+        warnings.warn("Setting default_update is deprecated.", DeprecationWarning)
         if value is not None:
             self._default_update = self.type.filter_variable(value, allow_convert=True)
         else:

diff --git a/pytensor/tensor/random/__init__.py b/pytensor/tensor/random/__init__.py
index 78994fd40c..8e62971775 100644
--- a/pytensor/tensor/random/__init__.py
+++ b/pytensor/tensor/random/__init__.py
@@ -4,3 +4,4 @@
 from pytensor.tensor.random.basic import *
 from pytensor.tensor.random.op import default_rng
 from pytensor.tensor.random.utils import RandomStream
+from pytensor.tensor.random.variable import rng

diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index 59e02ee3b4..bf43248399 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -285,7 +285,7 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
 normal = NormalRV()
 
 
-def standard_normal(*, size=None, rng=None, dtype=None):
+def standard_normal(*, size=None, rng=None, dtype=None, **kwargs):
     """Draw samples from a standard normal distribution.
 
     Signature
@@ -302,7 +302,7 @@ def standard_normal(*, size=None, rng=None, dtype=None):
     is returned.
 
     """
-    return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
+    return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype, **kwargs)
 
 
 class HalfNormalRV(ScipyRandomVariable):
@@ -516,7 +516,7 @@ def chisquare(df, size=None, **kwargs):
     return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)
 
 
-def rayleigh(scale=1.0, *, size=None, **kwargs):
+def rayleigh(scale=1.0, *, size=None, return_next_rng=False, **kwargs):
     r"""Draw samples from a Rayleigh distribution.
 
     The probability density function for `rayleigh` with parameter `scale` is given by:
@@ -550,7 +550,13 @@ def rayleigh(scale=1.0, *, size=None, **kwargs):
     scale = as_tensor_variable(scale)
     if size is None:
         size = scale.shape
-    return sqrt(chisquare(df=2, size=size, **kwargs)) * scale
+    next_rng, chisquare_draws = chisquare(
+        df=2, size=size, return_next_rng=True, **kwargs
+    )
+    rayleigh_draws = sqrt(chisquare_draws) * scale
+    if return_next_rng:
+        return next_rng, rayleigh_draws
+    return rayleigh_draws
 
 
 class ParetoRV(ScipyRandomVariable):
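The `rayleigh` change above shows the pattern this patch uses for distributions implemented in terms of other distributions: request the next RNG from the inner draw and hand it back to the caller when asked. A sketch of the same pattern for a hypothetical half-Cauchy helper (not part of this patch), assuming the new ``return_next_rng`` keyword behaves as above:

    from pytensor.tensor.random.basic import cauchy

    def halfcauchy(scale=1.0, *, size=None, return_next_rng=False, **kwargs):
        # Always thread the RNG through the inner draw.
        next_rng, cauchy_draws = cauchy(
            0.0, scale, size=size, return_next_rng=True, **kwargs
        )
        # |Cauchy(0, scale)| is half-Cauchy distributed.
        draws = abs(cauchy_draws)
        if return_next_rng:
            return next_rng, draws
        return draws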
@@ -1986,7 +1992,7 @@ def rng_fn(self, *params):
         return out
 
 
-def choice(a, size=None, replace=True, p=None, rng=None):
+def choice(a, size=None, replace=True, p=None, rng=None, return_next_rng=False):
     r"""Generate a random sample from an array.
 
@@ -2016,17 +2022,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
         # This is equivalent to the numpy implementation:
         # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
         if p is None:
-            idxs = integers(0, a_size, size=size, rng=rng)
+            next_rng, idxs = integers(
+                0, a_size, size=size, rng=rng, return_next_rng=True
+            )
         else:
-            idxs = categorical(p, size=size, rng=rng)
+            next_rng, idxs = categorical(p, size=size, rng=rng, return_next_rng=True)
 
         if a.type.ndim == 0:
             # A was an implicit arange, we don't need to do any indexing
             # TODO: Add rewrite for this optimization if users passed arange
-            return idxs
-
-        # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
-        return a[idxs]
+            out = idxs
+        else:
+            # TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
+            out = a[idxs]
+        if return_next_rng:
+            return next_rng, out
+        else:
+            return out
 
     # Sampling with p is not as trivial
     # It involves some form of rejection sampling or iterative shuffling under the hood.
@@ -2063,7 +2075,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
         op = ChoiceWithoutReplacement(signature=signature, dtype=dtype)
 
     params = (a, core_shape) if p is None else (a, p, core_shape)
-    return op(*params, size=None, rng=rng)
+    return op(*params, size=None, rng=rng, return_next_rng=return_next_rng)
 
 
 class PermutationRV(RandomVariable):

diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 6891823576..5813b3f7a5 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -314,7 +314,16 @@ def infer_shape(self, fgraph, node, input_shapes):
 
         return [None, list(shape)]
 
-    def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
+    def __call__(
+        self,
+        *args,
+        size=None,
+        name=None,
+        rng=None,
+        dtype=None,
+        return_next_rng: bool | None = None,
+        **kwargs,
+    ):
         if dtype is None:
             dtype = self.dtype
         if dtype == "floatX":
@@ -332,15 +341,31 @@ def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
             props["dtype"] = dtype
             new_op = type(self)(**props)
             return new_op.__call__(
-                *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
+                *args,
+                size=size,
+                name=name,
+                rng=rng,
+                dtype=dtype,
+                return_next_rng=return_next_rng,
+                **kwargs,
             )
 
-        res = super().__call__(rng, size, *args, **kwargs)
-
+        node = self.make_node(rng, size, *args)
+        outputs = node.outputs
         if name is not None:
-            res.name = name
-
-        return res
+            outputs[self.default_output].name = name
+        if return_next_rng:
+            return outputs
+        else:
+            if return_next_rng is None:
+                warnings.warn(
+                    "The default behavior of RandomVariable.__call__ is changing to return both the next RNG and the draws. "
+                    "Please set return_next_rng explicitly to avoid this warning.",
+                )
+            out = outputs[self.default_output]
+            if kwargs.get("return_list", False):
+                return [out]
+            return out
 
     def make_node(self, rng, size, *dist_params):
         """Create a random variable node.
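The `__call__` change above distinguishes three cases at the call site: opted in, opted out, and unspecified. A sketch of the resulting behavior, using `normal` for illustration (the shared generator here is illustrative):

    import numpy as np
    from pytensor import shared
    from pytensor.tensor.random.basic import normal

    rng = shared(np.random.default_rng(0))

    # New API: returns both the next RNG and the draws.
    next_rng, x = normal(0, 1, rng=rng, return_next_rng=True)

    # Legacy API, opted into explicitly: draws only, no warning.
    x = normal(0, 1, rng=rng, return_next_rng=False)

    # Unspecified (None): draws only, but warns that the default will change.
    x = normal(0, 1, rng=rng)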
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 86628a81cb..a3db8295bf 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -1,3 +1,4 @@
+import warnings
 from collections.abc import Callable, Sequence
 from functools import wraps
 from itertools import zip_longest
@@ -231,6 +232,11 @@ def __init__(
             [np.random.SeedSequence], np.random.Generator
         ] = np.random.default_rng,
     ):
+        warnings.warn(
+            "pytensor.tensor.random.utils.RandomStream is deprecated and will be removed in a future release.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
         if namespace is None:
             from pytensor.tensor.random import basic  # pylint: disable=import-self

diff --git a/pytensor/tensor/random/var.py b/pytensor/tensor/random/var.py
deleted file mode 100644
index 09fef393e6..0000000000
--- a/pytensor/tensor/random/var.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import copy
-
-import numpy as np
-
-from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
-from pytensor.tensor.random.type import random_generator_type
-
-
-class RandomGeneratorSharedVariable(SharedVariable):
-    def __str__(self):
-        return self.name or f"RNG({self.container!r})"
-
-
-@shared_constructor.register(np.random.RandomState)
-@shared_constructor.register(np.random.Generator)
-def randomgen_constructor(
-    value, name=None, strict=False, allow_downcast=None, borrow=False
-):
-    r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
-    if isinstance(value, np.random.RandomState):
-        raise TypeError(
-            "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
-        )
-
-    rng_sv_type = RandomGeneratorSharedVariable
-    rng_type = random_generator_type
-
-    if not borrow:
-        value = copy.deepcopy(value)
-
-    return rng_sv_type(
-        type=rng_type,
-        value=value,
-        strict=strict,
-        allow_downcast=allow_downcast,
-        name=name,
-    )
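With `RandomStream` deprecated, graphs that previously pulled fresh states from a stream instead thread a single generator variable explicitly. A minimal before/after sketch, assuming the `rng` constructor introduced in the new `variable.py` below:

    # Before (deprecated):
    # srng = RandomStream(seed=234)
    # x = srng.normal(0, 1, size=(2, 2))

    # After (explicit RNG threading):
    from pytensor.tensor.random import rng

    r = rng("r")
    next_r, x = r.normal(0, 1, size=(2, 2))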
" + "You probably want to use the new RandomGeneratorVariable that was returned when you used it.", + UserWarning, + ) + self.tag.used = True + return func(self, *args, **kwargs) + + return wrapper + + +class _random_generator_py_operators: + @warn_reuse + def normal(self, loc=0, scale=1, size=None) -> RNG_AND_DRAW: + return normal(loc, scale, size=size, rng=self, return_next_rng=True) + + +class RandomGeneratorVariable( + _random_generator_py_operators, + Variable[RandomGeneratorType, OptionalApplyType], +): + """The Variable type used for random number generator states.""" + + +RandomGeneratorType.variable_type = RandomGeneratorVariable + + +def rng(name=None) -> RandomGeneratorVariable: + """Create a new default random number generator variable. + + Returns + ------- + RandomGeneratorVariable + A new random number generator variable initialized with the default + numpy random generator. + """ + + return random_generator_type(name=name) + + +class RandomGeneratorSharedVariable(SharedVariable, RandomGeneratorVariable): + def __str__(self): + return self.name or f"RNG({self.container!r})" + + +@shared_constructor.register(np.random.RandomState) +@shared_constructor.register(np.random.Generator) +def randomgen_constructor( + value, name=None, strict=False, allow_downcast=None, borrow=False +): + r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`.""" + if isinstance(value, np.random.RandomState): + raise TypeError( + "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead." + ) + + rng_sv_type = RandomGeneratorSharedVariable + rng_type = random_generator_type + + if not borrow: + value = copy.deepcopy(value) + + return rng_sv_type( + type=rng_type, + value=value, + strict=strict, + allow_downcast=allow_downcast, + name=name, + ) diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index c7da82b2db..12ab6640db 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -82,7 +82,7 @@ def test_rng_copy(): rng.type.values_eq(rng.get_value(), np.random.default_rng(123)) -def test_rng_non_default_update(): +def test_rng_custom_update(): rng = shared(np.random.default_rng(1)) rng_new = shared(np.random.default_rng(2)) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 98a249c154..c6c1e03b33 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -193,19 +193,13 @@ def max_err(self, _g_pt): # verify_grad method so that other ops with multiple outputs can be tested. # DONE - rp def scan_project_sum(*args, **kwargs): - rng = RandomStream(123) - scan_outputs, updates = scan(*args, **kwargs) - if not isinstance(scan_outputs, list | tuple): - scan_outputs = [scan_outputs] - # we should ignore the random-state updates so that - # the uniform numbers are the same every evaluation and on every call - rng.add_default_updates = False + kwargs["return_list"] = True + scan_outputs = scan(*args, **kwargs, return_updates=False) + + # We don't recur on the rng so uniform numbers are the same every evaluation and on every call + rng = shared(np.random.default_rng(123)) factors = [rng.uniform(0.1, 0.9, size=s.shape) for s in scan_outputs] - # Random values (?) 
-    return (
-        sum((s * f).sum() for s, f in zip(scan_outputs, factors, strict=True)),
-        updates,
-    )
+    return sum((s * f).sum() for s, f in zip(scan_outputs, factors, strict=True))
 
 
 def asarrayX(value):
@@ -1335,14 +1329,12 @@ def f_rnn(u_t, x_tm1, W_in, W):
         [u, x0, W_in, W],
         [gu, gx0, gW_in, gW],
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     cost_fn = function(
         [u, x0, W_in, W],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1402,14 +1394,12 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
         [u1, u2, x0, y0, W_in1],
         gparams,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     cost_fn = function(
         [u1, u2, x0, y0, W_in1],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1477,14 +1467,12 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
         [u1, u2, x0, y0, W_in1],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     grad_fn = function(
         [u1, u2, x0, y0, W_in1],
         gparams,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1547,14 +1535,12 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
         [u1, u2, x0, y0, W_in1],
         gparams,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     cost_fn = function(
         [u1, u2, x0, y0, W_in1],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1603,14 +1589,12 @@ def f_rnn_cmpl(u_t, u2_t, x_tm1, W_in):
         [u, u2, x0, W_in],
         gparams,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     cost_fn = function(
         [u, u2, x0, W_in],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1686,14 +1670,12 @@ def f_rnn_cmpl(u_t, x_tm1, W_in):
         [u, x0, W_in],
         gparams,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
     cost_fn = function(
         [u, x0, W_in],
         cost,
         updates=updates,
-        no_default_updates=True,
         allow_input_downcast=True,
     )
@@ -1900,9 +1882,7 @@ def onestep(xdl, xprev, w):
         non_sequences=w,
     )
     loss = (xseq[-1] ** 2).sum()
-    cost_fn = function(
-        [xinit, w], loss, no_default_updates=True, allow_input_downcast=True
-    )
+    cost_fn = function([xinit, w], loss, allow_input_downcast=True)
     gw, gx = grad(loss, [w, xinit])
     grad_fn = function([xinit, w], [gx, gw], allow_input_downcast=True)

diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py
index f17f20ddbd..29fb7f057f 100644
--- a/tests/tensor/random/rewriting/test_basic.py
+++ b/tests/tensor/random/rewriting/test_basic.py
@@ -124,13 +124,13 @@ def test_inplace_rewrites(rv_op):
     out = rv_op(np.e)
     node = out.owner
     op = node.op
-    node.inputs[0].default_update = node.outputs[0]
 
     assert op.inplace is False
 
     f = function(
         [],
         out,
         mode="FAST_RUN",
+        updates={node.inputs[0]: node.outputs[0]},
     )
 
     (new_out, _new_rng) = f.maker.fgraph.outputs

From 6b8c5feac173574a1e56dc685fd21419f4d72e96 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sat, 1 Nov 2025 15:38:20 +0100
Subject: [PATCH 3/3] .broken wip

---
 doc/tutorial/examples.rst | 41 ++++++++++++++-------------------------
 1 file changed, 15 insertions(+), 26 deletions(-)

diff --git a/doc/tutorial/examples.rst b/doc/tutorial/examples.rst
index 859d57a3ae..af965ef45e 100644
--- a/doc/tutorial/examples.rst
+++ b/doc/tutorial/examples.rst
@@ -362,45 +362,34 @@ Here's a brief example. The setup code is:
 .. testcode::
 
-    from pytensor.tensor.random.utils import RandomStream
+    from pytensor.tensor.random import rng
     from pytensor import function
-    srng = RandomStream(seed=234)
-    rv_u = srng.uniform(0, 1, size=(2,2))
-    rv_n = srng.normal(0, 1, size=(2,2))
-    f = function([], rv_u)
-    g = function([], rv_n, no_default_updates=True)
-    nearly_zeros = function([], rv_u + rv_u - 2 * rv_u)
+    srng = rng("rng")
+    next_srng, rv_u = srng.uniform(0, 1, size=(2,2))
+    final_srng, rv_n = next_srng.normal(0, 1, size=(2,2))
+    f = function([srng], [rv_u])
 
 Here, ``rv_u`` represents a random stream of 2x2 matrices of draws from a uniform
 distribution. Likewise, ``rv_n`` represents a random stream of 2x2 matrices of
 draws from a normal distribution. The distributions that are implemented are
 defined as :class:`RandomVariable`\s
-in :ref:`basic`. They only work on CPU.
+in :ref:`basic`.
 
-Now let's use these objects. If we call ``f()``, we get random uniform numbers.
-The internal state of the random number generator is automatically updated,
-so we get different random numbers every time.
+Now let's use these objects. If we call ``f`` with a NumPy generator, we get random uniform numbers.
+Unlike NumPy, PyTensor does not mutate the generator, so repeated calls with the same generator state return the same numbers.
 
->>> f_val0 = f()
->>> f_val1 = f()  #different numbers from f_val0
+>>> import numpy as np
+>>> rng_np = np.random.default_rng(123)
+>>> f_val0 = f(rng_np)
+>>> f_val1 = f(rng_np)  # same numbers as f_val0
 
-When we add the extra argument ``no_default_updates=True`` to
-``function`` (as in ``g``), then the random number generator state is
-not affected by calling the returned function. So, for example, calling
-``g`` multiple times will return the same numbers.
+We can tell PyTensor it is safe to mutate the generator by marking the input as mutable when compiling:
 
->>> g_val0 = g()  # different numbers from f_val0 and f_val1
->>> g_val1 = g()  # same numbers as g_val0!
-
-An important remark is that a random variable is drawn at most once during any
-single function execution. So the `nearly_zeros` function is guaranteed to
-return approximately 0 (except for rounding error) even though the ``rv_u``
-random variable appears three times in the output expression.
-
->>> nearly_zeros = function([], rv_u + rv_u - 2 * rv_u)
+>>> g = function([pytensor.In(srng, mutable=True)], [rv_u])
+>>> g_val0 = g(rng_np)  # each call now advances rng_np in place
+>>> g_val1 = g(rng_np)  # different numbers from g_val0
 
 Seeding Streams
 ---------------
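The tutorial above stops at a single draw; chaining works the same way when a graph consumes several draws. A sketch extending the setup above, assuming the WIP API behaves as the patch indicates:

>>> h = function([srng], [rv_u, rv_n])
>>> u_vals, n_vals = h(np.random.default_rng(42))

Because ``rv_n`` is built from ``next_srng``, both draws derive from the single generator passed at call time, and repeated calls with equal generator states return identical values.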