Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions doc/tutorial/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,45 +362,34 @@ Here's a brief example. The setup code is:

.. testcode::

from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.random import rng
from pytensor import function


srng = RandomStream(seed=234)
rv_u = srng.uniform(0, 1, size=(2,2))
rv_n = srng.normal(0, 1, size=(2,2))
f = function([], rv_u)
g = function([], rv_n, no_default_updates=True)
nearly_zeros = function([], rv_u + rv_u - 2 * rv_u)
srng = rng("rng")
next_srng, rv_u = srng.uniform(0, 1, size=(2,2))
final_srng = next_srng.normal(0, 1, size=(2,2))
f = function([srng], [rv_u])

Here, ``rv_u`` represents a random stream of 2x2 matrices of draws from a uniform
distribution. Likewise, ``rv_n`` represents a random stream of 2x2 matrices of
draws from a normal distribution. The distributions that are implemented are
defined as :class:`RandomVariable`\s
in :ref:`basic<libdoc_tensor_random_basic>`. They only work on CPU.
in :ref:`basic<libdoc_tensor_random_basic>`.


Now let's use these objects. If we call ``f()``, we get random uniform numbers.
The internal state of the random number generator is automatically updated,
so we get different random numbers every time.
Now let's use these objects. If we call ``f()``, with a numpy generator we get random uniform numbers.
Unlike numpy, PyTensor does not mutate the random generator so we get the same results when called sequentially.

>>> f_val0 = f()
>>> f_val1 = f() #different numbers from f_val0
>>> rng_np = np.random.default_rng(123)
>>> f_val0 = f(rng_np)
>>> f_val1 = f(rng_np) #different numbers from f_val0

When we add the extra argument ``no_default_updates=True`` to
``function`` (as in ``g``), then the random number generator state is
not affected by calling the returned function. So, for example, calling
``g`` multiple times will return the same numbers.
We can tell PyTensor it's safe to mutate the rng by compiling a function like this

>>> g_val0 = g() # different numbers from f_val0 and f_val1
>>> g_val1 = g() # same numbers as g_val0!

An important remark is that a random variable is drawn at most once during any
single function execution. So the `nearly_zeros` function is guaranteed to
return approximately 0 (except for rounding error) even though the ``rv_u``
random variable appears three times in the output expression.

>>> nearly_zeros = function([], rv_u + rv_u - 2 * rv_u)
>>> g = function([pytensor.In(srng, mutable=True)], [rv_u])
>>> g_val0 = g(rng_np) # different numbers in subsequent calls with the same rng
>>> g_val1 = g(rng_np)

Seeding Streams
---------------
Expand Down
2 changes: 1 addition & 1 deletion pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def get_underlying_scalar_constant(v):


# isort: off
import pytensor.tensor.random.var
import pytensor.tensor.random.variable
import pytensor.sparse
from pytensor.ifelse import ifelse
from pytensor.scan import checkpoints
Expand Down
12 changes: 11 additions & 1 deletion pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
import traceback as tb
import warnings
from collections.abc import Iterable
from pathlib import Path

Expand Down Expand Up @@ -102,7 +103,7 @@ def function(
givens: Iterable[tuple[Variable, Variable]]
| dict[Variable, Variable]
| None = None,
no_default_updates: bool = False,
no_default_updates: bool | None = None,
accept_inplace: bool = False,
name: str | None = None,
rebuild_strict: bool = True,
Expand Down Expand Up @@ -266,6 +267,15 @@ def opt_log1p(node):
of just writing it in C from scratch.

"""
if no_default_updates is not None:
warnings.warn(
"The no_default_updates parameter is deprecated and will be "
"removed in a future version of PyTensor. Please set updates manually ",
DeprecationWarning,
stacklevel=2,
)
else:
no_default_updates = False
if isinstance(outputs, dict):
assert all(isinstance(k, str) for k in outputs)

Expand Down
2 changes: 2 additions & 0 deletions pytensor/compile/sharedvalue.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provide a simple user friendly API to PyTensor-managed memory."""

import copy
import warnings
from contextlib import contextmanager
from functools import singledispatch
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -161,6 +162,7 @@ def default_update(self) -> Variable | None:

@default_update.setter
def default_update(self, value):
warnings.warn("Setting default_update is deprecated.", DeprecationWarning)
if value is not None:
self._default_update = self.type.filter_variable(value, allow_convert=True)
else:
Expand Down
1 change: 1 addition & 0 deletions pytensor/tensor/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from pytensor.tensor.random.basic import *
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.random.variable import rng
36 changes: 24 additions & 12 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
normal = NormalRV()


def standard_normal(*, size=None, rng=None, dtype=None):
def standard_normal(*, size=None, rng=None, dtype=None, **kwargs):
"""Draw samples from a standard normal distribution.

Signature
Expand All @@ -302,7 +302,7 @@ def standard_normal(*, size=None, rng=None, dtype=None):
is returned.

"""
return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype, **kwargs)


class HalfNormalRV(ScipyRandomVariable):
Expand Down Expand Up @@ -516,7 +516,7 @@ def chisquare(df, size=None, **kwargs):
return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)


def rayleigh(scale=1.0, *, size=None, **kwargs):
def rayleigh(scale=1.0, *, size=None, return_next_rng=False, **kwargs):
r"""Draw samples from a Rayleigh distribution.

The probability density function for `rayleigh` with parameter `scale` is given by:
Expand Down Expand Up @@ -550,7 +550,13 @@ def rayleigh(scale=1.0, *, size=None, **kwargs):
scale = as_tensor_variable(scale)
if size is None:
size = scale.shape
return sqrt(chisquare(df=2, size=size, **kwargs)) * scale
next_rng, chisquare_draws = chisquare(
df=2, size=size, return_next_rng=True, **kwargs
)
rayleigh_draws = sqrt(chisquare_draws) * scale
if return_next_rng:
return next_rng, rayleigh_draws
return rayleigh_draws


class ParetoRV(ScipyRandomVariable):
Expand Down Expand Up @@ -1986,7 +1992,7 @@ def rng_fn(self, *params):
return out


def choice(a, size=None, replace=True, p=None, rng=None):
def choice(a, size=None, replace=True, p=None, rng=None, return_next_rng=False):
r"""Generate a random sample from an array.


Expand Down Expand Up @@ -2016,17 +2022,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
# This is equivalent to the numpy implementation:
# https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
if p is None:
idxs = integers(0, a_size, size=size, rng=rng)
next_rng, idxs = integers(
0, a_size, size=size, rng=rng, return_next_rng=True
)
else:
idxs = categorical(p, size=size, rng=rng)
next_rng, idxs = categorical(p, size=size, rng=rng, return_next_rng=True)

if a.type.ndim == 0:
# A was an implicit arange, we don't need to do any indexing
# TODO: Add rewrite for this optimization if users passed arange
return idxs

# TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
return a[idxs]
out = idxs
else:
# TODO: Can use take(a, idxs, axis) to support numpy axis argument to choice
out = a[idxs]
if return_next_rng:
return next_rng, out
else:
return out

# Sampling with p is not as trivial
# It involves some form of rejection sampling or iterative shuffling under the hood.
Expand Down Expand Up @@ -2063,7 +2075,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
op = ChoiceWithoutReplacement(signature=signature, dtype=dtype)

params = (a, core_shape) if p is None else (a, p, core_shape)
return op(*params, size=None, rng=rng)
return op(*params, size=None, rng=rng, return_next_rng=return_next_rng)


class PermutationRV(RandomVariable):
Expand Down
39 changes: 32 additions & 7 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,16 @@ def infer_shape(self, fgraph, node, input_shapes):

return [None, list(shape)]

def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
def __call__(
self,
*args,
size=None,
name=None,
rng=None,
dtype=None,
return_next_rng: bool | None = None,
**kwargs,
):
if dtype is None:
dtype = self.dtype
if dtype == "floatX":
Expand All @@ -332,15 +341,31 @@ def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
props["dtype"] = dtype
new_op = type(self)(**props)
return new_op.__call__(
*args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
*args,
size=size,
name=name,
rng=rng,
dtype=dtype,
return_next_rng=return_next_rng,
**kwargs,
)

res = super().__call__(rng, size, *args, **kwargs)

node = self.make_node(rng, size, *args)
outputs = node.outputs
if name is not None:
res.name = name

return res
outputs[self.default_output].name = name
if return_next_rng:
return outputs
else:
if return_next_rng is None:
warnings.warn(
"The default behavior of RandomVariable.__call__ is changing to return both the next RNG and the draws. "
"Please set return_next_rng explicitly to avoid this warning.",
)
out = outputs[self.default_output]
if kwargs.get("return_list", False):
return [out]
return out

def make_node(self, rng, size, *dist_params):
"""Create a random variable node.
Expand Down
6 changes: 6 additions & 0 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Callable, Sequence
from functools import wraps
from itertools import zip_longest
Expand Down Expand Up @@ -231,6 +232,11 @@ def __init__(
[np.random.SeedSequence], np.random.Generator
] = np.random.default_rng,
):
warnings.warn(
"pytensor.tensor.random.utils.RandomStream is deprecated and will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self

Expand Down
37 changes: 0 additions & 37 deletions pytensor/tensor/random/var.py

This file was deleted.

Loading
Loading