diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml index 0064fe0330..5097f54309 100644 --- a/environment-osx-arm64.yml +++ b/environment-osx-arm64.yml @@ -26,7 +26,6 @@ dependencies: - diff-cover - mypy - types-setuptools - - scipy-stubs - pytest - pytest-cov - pytest-xdist diff --git a/environment.yml b/environment.yml index 5a883752ce..dade3c8d9d 100644 --- a/environment.yml +++ b/environment.yml @@ -28,7 +28,6 @@ dependencies: - diff-cover - mypy - types-setuptools - - scipy-stubs - pytest - pytest-cov - pytest-xdist diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 1c76e2e3a3..625e20c991 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -546,26 +546,27 @@ def construct_pfunc_ins_and_outs( "variables in the inputs list." ) - # Check that we are not using `givens` to replace input variables, because - # this typically does nothing, contrary to what one may expect. - in_var_set = set(in_variables) - try: - givens_pairs = list(givens.items()) - except AttributeError: - givens_pairs = givens - for x, y in givens_pairs: - if x in in_var_set: - raise RuntimeError( - f"You are trying to replace variable '{x}' through the " - "`givens` parameter, but this variable is an input to your " - "function. Replacing inputs is currently forbidden because it " - "has no effect. One way to modify an input `x` to a function " - "evaluating f(x) is to define a new input `y` and use " - "`pytensor.function([y], f(x), givens={x: g(y)})`. Another " - "solution consists in using `pytensor.clone_replace`, e.g. like this: " - "`pytensor.function([x], " - "pytensor.clone_replace(f(x), replace={x: g(x)}))`." - ) + if givens: + # Check that we are not using `givens` to replace input variables, because + # this typically does nothing, contrary to what one may expect. + in_var_set = set(in_variables) + try: + givens_pairs = list(givens.items()) + except AttributeError: + givens_pairs = givens + for x, y in givens_pairs: + if x in in_var_set: + raise RuntimeError( + f"You are trying to replace variable '{x}' through the " + "`givens` parameter, but this variable is an input to your " + "function. Replacing inputs is currently forbidden because it " + "has no effect. One way to modify an input `x` to a function " + "evaluating f(x) is to define a new input `y` and use " + "`pytensor.function([y], f(x), givens={x: g(y)})`. Another " + "solution consists in using `pytensor.clone_replace`, e.g. like this: " + "`pytensor.function([x], " + "pytensor.clone_replace(f(x), replace={x: g(x)}))`." + ) if not fgraph: # Extend the outputs with the updates on input variables so they are diff --git a/pytensor/gradient.py b/pytensor/gradient.py index ecdf4fbd4c..022eba2454 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): # It is possible that the inputs are disconnected from expr, # even if they are connected to cost. # This should not be an error. - hess, updates = pytensor.scan( + hess = pytensor.scan( lambda i, y, x: grad( y[i], x, @@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): ), sequences=pytensor.tensor.arange(expr.shape[0]), non_sequences=[expr, input], - ) - assert not updates, ( - "Scan has returned a list of updates; this should not happen." 
+ return_updates=False, ) hessians.append(hess) return as_list_or_tuple(using_list, using_tuple, hessians) diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index 3082c6481a..23b790ecbc 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -60,23 +60,23 @@ def scan(*outer_inputs): mit_mot_init, mit_sot_init, sit_sot_init, - op.outer_shared(outer_inputs), + op.outer_untraced_sit_sot(outer_inputs), op.outer_non_seqs(outer_inputs), ) # JAX `init` def jax_args_to_inner_func_args(carry, x): """Convert JAX scan arguments into format expected by scan_inner_func. - scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs) + scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs) """ - # `carry` contains all inner taps, shared terms, and non_seqs + # `carry` contains all inner taps and non_seqs ( i, inner_mit_mot, inner_mit_sot, inner_sit_sot, - inner_shared, + inner_untraced_sit_sot, inner_non_seqs, ) = carry @@ -108,7 +108,7 @@ def jax_args_to_inner_func_args(carry, x): *mit_mot_flatten, *mit_sot_flatten, *inner_sit_sot, - *inner_shared, + *inner_untraced_sit_sot, *inner_non_seqs, ) @@ -118,14 +118,14 @@ def inner_func_outs_to_jax_outs( ): """Convert inner_scan_func outputs into format expected by JAX scan. - old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys) + old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys) """ ( i, old_mit_mot, old_mit_sot, _old_sit_sot, - _old_shared, + _old_untraced_sit_sot, inner_non_seqs, ) = old_carry @@ -133,7 +133,7 @@ def inner_func_outs_to_jax_outs( new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs) new_sit_sot = op.inner_sitsot_outs(inner_scan_outs) new_nit_sot = op.inner_nitsot_outs(inner_scan_outs) - new_shared = op.inner_shared_outs(inner_scan_outs) + new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs) # New carry for next step # Update MIT-MOT buffer at positions indicated by output taps @@ -150,14 +150,14 @@ def inner_func_outs_to_jax_outs( old_mit_sot, new_mit_sot_vals, strict=True ) ] - # For SIT-SOT, and shared just pass along the new value + # For SIT-SOT just pass along the new value # Non-sequences remain unchanged new_carry = ( i + 1, new_mit_mot, new_mit_sot, new_sit_sot, - new_shared, + new_untraced_sit_sot, inner_non_seqs, ) @@ -183,7 +183,7 @@ def jax_inner_func(carry, x): final_mit_mot, _final_mit_sot, _final_sit_sot, - final_shared, + final_untraced_sit_sot, _final_non_seqs, ), traces, @@ -238,7 +238,7 @@ def get_partial_traces(traces): scan_outs_final = [ *final_mit_mot, *get_partial_traces(traces), - *final_shared, + *final_untraced_sit_sot, ] if len(scan_outs_final) == 1: diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py index 26b308d4ef..df24335121 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py @@ -48,7 +48,7 @@ def _lu_1( Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A. """ - return linalg.lu( + return linalg.lu( # type: ignore[no-any-return] a, permute_l=permute_l, check_finite=check_finite, @@ -70,7 +70,7 @@ def _lu_2( Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. 
""" - return linalg.lu( + return linalg.lu( # type: ignore[no-any-return] a, permute_l=permute_l, check_finite=check_finite, @@ -92,7 +92,7 @@ def _lu_3( Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation matrix, P @ L @ U = A. """ - return linalg.lu( + return linalg.lu( # type: ignore[no-any-return] a, permute_l=permute_l, check_finite=check_finite, diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 694f341ed4..7c431fb707 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) - outer_in_shared_names = op.outer_shared(outer_in_names) + outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names) outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) # These are all the outer-input names that have produce outputs/have output # taps (i.e. they have inner-outputs and corresponding outer-outputs). # Outer-outputs are ordered as follows: - # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs + # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs outer_in_outtap_names = ( outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names + outer_in_nit_sot_names - + outer_in_shared_names + + outer_in_untraced_sit_sot_names ) # We create distinct variables for/references to the storage arrays for @@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): for outer_in_name in outer_in_nit_sot_names: outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage" - for outer_in_name in outer_in_shared_names: - outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage" + for outer_in_name in outer_in_untraced_sit_sot_names: + outer_in_to_storage_name[outer_in_name] = ( + f"{outer_in_name}_untraced_sit_sot_storage" + ) outer_output_names = list(outer_in_to_storage_name.values()) assert len(outer_output_names) == len(node.outputs) @@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): # Construct the inner-input expressions (e.g. indexed storage expressions) # Inner-inputs are ordered as follows: # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + - # shared-inputs + non-sequences. + # untraced-sit-sot-inputs + non-sequences. 
temp_scalar_storage_alloc_stmts: list[str] = [] inner_in_exprs_scalar: list[str] = [] inner_in_exprs: list[str] = [] @@ -204,11 +206,9 @@ def add_inner_in_expr( # Inner-outputs consist of: # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + - # shared-outputs [+ while-condition] + # untraced-sit-sot-outputs [+ while-condition] inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] - # inner_out_shared_names = op.inner_shared_outs(inner_output_names) - # The assignment statements that copy inner-outputs into the outer-outputs # storage inner_out_to_outer_in_stmts: list[str] = [] diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 6b03917a2b..b0ecc6e6fb 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -1,4 +1,6 @@ +import typing import warnings +from itertools import chain import numpy as np @@ -9,7 +11,8 @@ from pytensor.graph.basic import Constant, Variable from pytensor.graph.op import get_test_value from pytensor.graph.replace import clone_replace -from pytensor.graph.traversal import graph_inputs +from pytensor.graph.traversal import explicit_graph_inputs +from pytensor.graph.type import HasShape from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import expand_empty, safe_new, until @@ -21,6 +24,10 @@ from pytensor.updates import OrderedUpdates +if typing.TYPE_CHECKING: + from pytensor.tensor.type import TensorVariable + + def get_updates_and_outputs(ls): """Recognize and order the updates, outputs, and stopping condition for a `Scan`. @@ -161,6 +168,26 @@ def isNaN_or_Inf_or_None(x): return isNone or isNaN or isInf or isStr +def _manage_output_api_change(outputs, updates, return_updates): + if return_updates: + warnings.warn( + "Scan return signature will change. Updates dict will not be returned, only the first argument. " + "Pass `return_updates=False` to conform to the new API and avoid this warning", + DeprecationWarning, + # Only meant for developers for now. Switch to FutureWarning to warn users, before removing. + stacklevel=3, + ) + else: + if updates: + raise ValueError( + f"return_updates=False but Scan produced updates {updates}. " + "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates." + ) + return outputs + + return outputs, updates + + def scan( fn, sequences=None, @@ -175,6 +202,7 @@ def scan( allow_gc=None, strict=False, return_list=False, + return_updates: bool = True, ): r"""This function constructs and applies a `Scan` `Op` to the provided arguments. 
@@ -468,26 +496,22 @@ def wrap_into_list(x): # Make sure we get rid of numpy arrays or ints or anything like that # passed as inputs to scan - non_seqs = [] + non_seqs: list[Variable] = [] for elem in wrap_into_list(non_sequences): if not isinstance(elem, Variable): non_seqs.append(pt.as_tensor_variable(elem)) else: non_seqs.append(elem) - # If we provided a known number of steps ( before compilation) - # and if that number is 1 or -1, then we can skip the Scan Op, - # and just apply the inner function once - # To do that we check here to see the nature of n_steps - n_fixed_steps = None - + # This helper eagerly skips the Scan if n_steps is known to be 1 + single_step_requested = False if isinstance(n_steps, float | int): - n_fixed_steps = int(n_steps) + single_step_requested = n_steps == 1 else: try: - n_fixed_steps = pt.get_scalar_constant_value(n_steps) + single_step_requested = pt.get_scalar_constant_value(n_steps) == 1 except NotScalarConstantError: - n_fixed_steps = None + pass # Check n_steps is an int if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes: @@ -497,7 +521,6 @@ def wrap_into_list(x): n_seqs = len(seqs) n_outs = len(outs_info) - return_steps = {} # wrap sequences in a dictionary if they are not already dictionaries for i in range(n_seqs): if not isinstance(seqs[i], dict): @@ -689,10 +712,10 @@ def wrap_into_list(x): # MIT_MOT -- not provided by the user only by the grad function n_mit_mot = 0 - mit_mot_scan_inputs = [] - mit_mot_inner_inputs = [] - mit_mot_inner_outputs = [] - mit_mot_out_slices = [] + mit_mot_scan_inputs: list[TensorVariable] = [] + mit_mot_inner_inputs: list[TensorVariable] = [] + mit_mot_inner_outputs: list[TensorVariable] = [] + mit_mot_out_slices: list[TensorVariable] = [] # SIT_SOT -- provided by the user n_mit_sot = 0 @@ -700,7 +723,6 @@ def wrap_into_list(x): mit_sot_inner_inputs = [] mit_sot_inner_slices = [] mit_sot_inner_outputs = [] - mit_sot_return_steps = {} mit_sot_tap_array = [] mit_sot_rightOrder = [] @@ -709,9 +731,14 @@ def wrap_into_list(x): sit_sot_inner_inputs = [] sit_sot_inner_slices = [] sit_sot_inner_outputs = [] - sit_sot_return_steps = {} sit_sot_rightOrder = [] + n_untraced_sit_sot_outs = 0 + untraced_sit_sot_scan_inputs = [] + untraced_sit_sot_inner_inputs = [] + untraced_sit_sot_inner_outputs = [] + untraced_sit_sot_rightOrder = [] + # go through outputs picking up time slices as needed for i, init_out in enumerate(outs_info): # Note that our convention dictates that if an output uses @@ -747,19 +774,35 @@ def wrap_into_list(x): # We need now to allocate space for storing the output and copy # the initial state over. 
We do this using the expand function # defined in scan utils - sit_sot_scan_inputs.append( - expand_empty( - shape_padleft(actual_arg), - actual_n_steps, + if isinstance(actual_arg.type, HasShape): + sit_sot_scan_inputs.append( + expand_empty( + shape_padleft(actual_arg), + actual_n_steps, + ) ) - ) + sit_sot_inner_slices.append(actual_arg) - sit_sot_inner_slices.append(actual_arg) - if i in return_steps: - sit_sot_return_steps[n_sit_sot] = return_steps[i] - sit_sot_inner_inputs.append(arg) - sit_sot_rightOrder.append(i) - n_sit_sot += 1 + sit_sot_inner_inputs.append(arg) + sit_sot_rightOrder.append(i) + n_sit_sot += 1 + else: + # Assume variables without shape cannot be stacked (e.g., RNG variables) + # Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature + from pytensor.tensor.random.type import RandomType + + if not isinstance(arg.type, RandomType): + warnings.warn( + ( + f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. " + "Only the last value will be returned, not the entire sequence." + ), + UserWarning, + ) + untraced_sit_sot_scan_inputs.append(actual_arg) + untraced_sit_sot_inner_inputs.append(arg) + n_untraced_sit_sot_outs += 1 + untraced_sit_sot_rightOrder.append(i) elif init_out.get("taps", None): if np.any(np.array(init_out.get("taps", [])) > 0): @@ -774,8 +817,6 @@ def wrap_into_list(x): expand_empty(init_out["initial"][:mintap], actual_n_steps) ) - if i in return_steps: - mit_sot_return_steps[n_mit_sot] = return_steps[i] mit_sot_rightOrder.append(i) n_mit_sot += 1 for k in init_out["taps"]: @@ -812,14 +853,15 @@ def wrap_into_list(x): # a map); in that case we do not have to do anything .. # Re-order args - max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1 - max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1 - n_elems = np.max([max_mit_sot, max_sit_sot]) - _ordered_args = [[] for x in range(n_elems)] + max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1 + max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1 + max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1 + n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs)) + _ordered_args: list[list[Variable]] = [[] for x in range(n_elems)] offset = 0 for idx in range(n_mit_sot): n_inputs = len(mit_sot_tap_array[idx]) - if n_fixed_steps in (1, -1): + if single_step_requested: _ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[ offset : offset + n_inputs ] @@ -830,17 +872,19 @@ def wrap_into_list(x): offset += n_inputs for idx in range(n_sit_sot): - if n_fixed_steps in (1, -1): + if single_step_requested: _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]] else: _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]] - ordered_args = [] - for ls in _ordered_args: - ordered_args += ls - if n_fixed_steps in (1, -1): - args = inner_slices + ordered_args + non_seqs + for idx in range(n_untraced_sit_sot_outs): + _ordered_args[untraced_sit_sot_rightOrder[idx]] = [ + untraced_sit_sot_inner_inputs[idx] + ] + ordered_args = list(chain.from_iterable(_ordered_args)) + if single_step_requested: + args = inner_slices + ordered_args + non_seqs else: args = inner_seqs + ordered_args + non_seqs @@ -863,7 +907,7 @@ def wrap_into_list(x): # Step 3. 
Check if we actually need scan and remove it if we don't ## - if n_fixed_steps in (1, -1): + if single_step_requested: for pos, inner_out in enumerate(outputs): # we need to see if we need to pad our sequences with an # extra dimension; case example : we return an @@ -871,13 +915,13 @@ def wrap_into_list(x): # then, if we return the output as given by the innner function # this will represent only a slice and it will have one # dimension less. - if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: + if isinstance(inner_out.type, TensorType): outputs[pos] = shape_padleft(inner_out) if not return_list and len(outputs) == 1: outputs = outputs[0] - return (outputs, updates) + return _manage_output_api_change(outputs, updates, return_updates) ## # Step 4. Compile the dummy function @@ -896,15 +940,12 @@ def wrap_into_list(x): fake_outputs = clone_replace( outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True)) ) - all_inputs = filter( - lambda x: ( - isinstance(x, Variable) - and not isinstance(x, SharedVariable) - and not isinstance(x, Constant) - ), - graph_inputs(fake_outputs), - ) - extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs] + # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs` + # to find implicit inputs in a way that reduces the size of the inner function + known_inputs = [*args, *fake_nonseqs] + extra_inputs = [ + x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs + ] non_seqs += extra_inputs # Note we do not use all_inputs directly since the order of variables # in args is quite important @@ -957,18 +998,19 @@ def wrap_into_list(x): if "taps" in out and out["taps"] != [-1]: mit_sot_inner_outputs.append(outputs[i]) - # Step 5.2 Outputs with tap equal to -1 + # Step 5.2 Outputs with tap equal to -1 (traced and untraced) for i, out in enumerate(outs_info): if "taps" in out and out["taps"] == [-1]: - sit_sot_inner_outputs.append(outputs[i]) + output = outputs[i] + if isinstance(output.type, HasShape): + sit_sot_inner_outputs.append(output) + else: + untraced_sit_sot_inner_outputs.append(output) # Step 5.3 Outputs that correspond to update rules of shared variables + # This whole special logic for shared variables is deprecated + sit_sot_shared: list[Variable] = [] inner_replacements = {} - n_shared_outs = 0 - shared_scan_inputs = [] - shared_inner_inputs = [] - shared_inner_outputs = [] - sit_sot_shared = [] no_update_shared_inputs = [] for input in dummy_inputs: if not isinstance(input.variable, SharedVariable): @@ -994,8 +1036,8 @@ def wrap_into_list(x): new_var = safe_new(input.variable) - if getattr(input.variable, "name", None) is not None: - new_var.name = input.variable.name + "_copy" + if input.variable.name is not None: + new_var.name = f"{input.variable.name}_copy" inner_replacements[input.variable] = new_var @@ -1021,10 +1063,10 @@ def wrap_into_list(x): sit_sot_shared.append(input.variable) else: - shared_inner_inputs.append(new_var) - shared_scan_inputs.append(input.variable) - shared_inner_outputs.append(input.update) - n_shared_outs += 1 + untraced_sit_sot_inner_inputs.append(new_var) + untraced_sit_sot_scan_inputs.append(input.variable) + untraced_sit_sot_inner_outputs.append(input.update) + n_untraced_sit_sot_outs += 1 else: no_update_shared_inputs.append(input) @@ -1033,13 +1075,10 @@ def wrap_into_list(x): # Step 5.4 Outputs with no taps used in the input n_nit_sot = 0 nit_sot_inner_outputs = [] - nit_sot_return_steps = {} nit_sot_rightOrder = [] for 
i, out in enumerate(outs_info): if "taps" not in out: nit_sot_inner_outputs.append(outputs[i]) - if i in return_steps: - nit_sot_return_steps[n_nit_sot] = return_steps[i] nit_sot_rightOrder.append(i) n_nit_sot += 1 @@ -1058,7 +1097,7 @@ def wrap_into_list(x): if not isinstance(arg, SharedVariable | Constant) ] - inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) + inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) # type: ignore[arg-type] if strict: non_seqs_set = set(non_sequences if non_sequences is not None else []) @@ -1092,7 +1131,7 @@ def wrap_into_list(x): + mit_mot_inner_inputs + mit_sot_inner_inputs + sit_sot_inner_inputs - + shared_inner_inputs + + untraced_sit_sot_inner_inputs + other_shared_inner_args + other_inner_args ) @@ -1102,12 +1141,12 @@ def wrap_into_list(x): + mit_sot_inner_outputs + sit_sot_inner_outputs + nit_sot_inner_outputs - + shared_inner_outputs + + untraced_sit_sot_inner_outputs ) if condition is not None: inner_outs.append(condition) - new_outs = clone_replace(inner_outs, replace=inner_replacements) + new_outs = clone_replace(inner_outs, replace=inner_replacements) # type: ignore[arg-type] ## # Step 7. Create the Scan Op @@ -1122,7 +1161,7 @@ def wrap_into_list(x): mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices), mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array), sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)), - n_shared_outs=n_shared_outs, + n_untraced_sit_sot_outs=n_untraced_sit_sot_outs, n_nit_sot=n_nit_sot, n_non_seqs=len(other_shared_inner_args) + len(other_inner_args), as_while=as_while, @@ -1148,7 +1187,7 @@ def wrap_into_list(x): + mit_mot_scan_inputs + mit_sot_scan_inputs + sit_sot_scan_inputs - + shared_scan_inputs + + untraced_sit_sot_scan_inputs + [actual_n_steps for x in range(n_nit_sot)] + other_shared_scan_args + other_scan_args @@ -1173,46 +1212,49 @@ def wrap_into_list(x): update_map = OrderedUpdates() - def remove_dimensions(outs, steps_return, offsets=None): + def remove_dimensions(outs, offsets=None): out_ls = [] for idx, out in enumerate(outs): - if idx in steps_return: - if steps_return[idx] > 1: - out_ls.append(out[-steps_return[idx] :]) - else: - out_ls.append(out[-1]) + if offsets is None: + out_ls.append(out) else: - if offsets is None: - out_ls.append(out) - else: - out_ls.append(out[offsets[idx] :]) + out_ls.append(out[offsets[idx] :]) return out_ls offset = n_mit_mot offsets = [abs(np.min(x)) for x in mit_sot_tap_array] - mit_sot_outs = remove_dimensions( - scan_outs[offset : offset + n_mit_sot], mit_sot_return_steps, offsets - ) + mit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_mit_sot], offsets) offset += n_mit_sot offsets = [1 for x in range(n_sit_sot)] - sit_sot_outs = remove_dimensions( - scan_outs[offset : offset + n_sit_sot], sit_sot_return_steps, offsets - ) + sit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_sit_sot], offsets) offset += n_sit_sot - nit_sot_outs = remove_dimensions( - scan_outs[offset : offset + n_nit_sot], nit_sot_return_steps - ) + nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot]) offset += n_nit_sot - for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]): - update_map[shared_scan_inputs[idx]] = update_rule - _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + # Legacy support for explicit untraced sit_sot and those built with update dictionary + # Switch to n_untraced_sit_sot_outs after deprecation period + 
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder) + untraced_sit_sot_outs = scan_outs[ + offset : offset + n_explicit_untraced_sit_sot_outs + ] + + # Legacy support: map shared outputs to their updates + offset += n_explicit_untraced_sit_sot_outs + for idx, update_rule in enumerate(scan_outs[offset:]): + update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule + + _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs # Step 10. I need to reorder the outputs to be in the order expected by # the user - rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder + rightOrder = ( + mit_sot_rightOrder + + sit_sot_rightOrder + + untraced_sit_sot_rightOrder + + nit_sot_rightOrder + ) scan_out_list = [None] * len(rightOrder) for idx, pos in enumerate(rightOrder): if pos >= 0: @@ -1228,8 +1270,8 @@ def remove_dimensions(outs, steps_return, offsets=None): update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] scan_out_list = [x for x in scan_out_list if x is not None] if not return_list and len(scan_out_list) == 1: - scan_out_list = scan_out_list[0] + scan_out_list = scan_out_list[0] # type: ignore[assignment] elif len(scan_out_list) == 0: - scan_out_list = None + scan_out_list = None # type: ignore[assignment] - return (scan_out_list, update_map) + return _manage_output_api_change(scan_out_list, update_map, return_updates) diff --git a/pytensor/scan/checkpoints.py b/pytensor/scan/checkpoints.py index d974e8257e..12a63a1a3e 100644 --- a/pytensor/scan/checkpoints.py +++ b/pytensor/scan/checkpoints.py @@ -13,6 +13,7 @@ def scan_checkpoints( n_steps=None, save_every_N=10, padding=True, + return_updates=True, ): """Scan function that uses less memory, but is more restrictive. @@ -157,24 +158,28 @@ def outer_step(*args): ] * len(new_nitsots) # Call the user-provided function with the proper arguments - results, updates = scan( + results_and_updates = scan( fn=fn, sequences=i_sequences[:-1], outputs_info=i_outputs_infos, non_sequences=i_non_sequences, name=name + "_inner", n_steps=i_sequences[-1], + return_updates=return_updates, ) + if return_updates: + results, updates = results_and_updates + else: + results = results_and_updates + updates = {} + if not isinstance(results, list): results = [results] # Keep only the last timestep of every output but keep all the updates - if not isinstance(results, list): - return results[-1], updates - else: - return [r[-1] for r in results], updates + return [r[-1] for r in results], updates - results, updates = scan( + return scan( fn=outer_step, sequences=o_sequences, outputs_info=outputs_info, @@ -182,6 +187,5 @@ def outer_step(*args): name=name + "_outer", n_steps=o_n_steps, allow_gc=True, + return_updates=return_updates, ) - - return results, updates diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index eda97560b3..7e2d8186fd 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -46,6 +46,7 @@ import dataclasses import logging import time +import warnings from collections.abc import Callable, Iterable from copy import copy from itertools import chain, product @@ -208,10 +209,19 @@ class ScanInfo: mit_sot_in_slices: tuple sit_sot_in_slices: tuple n_nit_sot: int - n_shared_outs: int + n_untraced_sit_sot_outs: int n_non_seqs: int as_while: bool + @property + def n_shared_outs(self): + warnings.warn( + "The 'n_shared_outs' property is deprecated. 
Use 'n_untraced_sit_sot_outs' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.n_untraced_sit_sot_outs + @property def n_mit_mot(self): return len(self.mit_mot_in_slices) @@ -239,7 +249,7 @@ def n_inner_inputs(self): + sum(len(x) for x in self.mit_mot_in_slices) + sum(len(x) for x in self.mit_sot_in_slices) + self.n_sit_sot - + self.n_shared_outs + + self.n_untraced_sit_sot_outs + self.n_non_seqs ) @@ -250,7 +260,7 @@ def n_inner_outputs(self): + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot - + self.n_shared_outs + + self.n_untraced_sit_sot_outs + int(self.as_while) ) @@ -263,7 +273,7 @@ def n_outer_inputs(self): + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot - + self.n_shared_outs + + self.n_untraced_sit_sot_outs + self.n_non_seqs ) @@ -274,7 +284,7 @@ def n_outer_outputs(self): + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot - + self.n_shared_outs + + self.n_untraced_sit_sot_outs ) @@ -381,7 +391,7 @@ def outer_nitsot(self, list_inputs): + self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot - + self.info.n_shared_outs + + self.info.n_untraced_sit_sot_outs ) return list_inputs[offset : offset + self.info.n_nit_sot] @@ -394,15 +404,23 @@ def outer_nitsot_outs(self, list_outputs): offset = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot return list_outputs[offset : offset + self.info.n_nit_sot] - def inner_shared(self, list_inputs): + def inner_untraced_sit_sot(self, list_inputs): n_taps_upto_sit_sot = sum( len(x) for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices) ) offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot - return list_inputs[offset : offset + self.info.n_shared_outs] + return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs] - def outer_shared(self, list_inputs): + def inner_shared(self, list_inputs): + warnings.warn( + "The 'inner_shared' method is deprecated. Use 'inner_untraced_sit_sot' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.inner_untraced_sit_sot(list_inputs) + + def outer_untraced_sit_sot(self, list_inputs): offset = ( 1 + self.info.n_seqs @@ -410,23 +428,47 @@ def outer_shared(self, list_inputs): + self.info.n_mit_sot + self.info.n_sit_sot ) - return list_inputs[offset : offset + self.info.n_shared_outs] + return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs] - def inner_shared_outs(self, list_outputs): + def outer_shared(self, list_inputs): + warnings.warn( + "The 'outer_shared' method is deprecated. Use 'outer_untraced_sit_sot' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.outer_untraced_sit_sot(list_inputs) + + def inner_untraced_sit_sot_outs(self, list_outputs): n_taps = sum(len(x) for x in self.info.mit_mot_out_slices) offset = ( self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot ) - return list_outputs[offset : offset + self.info.n_shared_outs] + return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs] - def outer_shared_outs(self, list_outputs): + def inner_shared_outs(self, list_outputs): + warnings.warn( + "The 'inner_shared_outs' method is deprecated. 
Use 'inner_untraced_sit_sot_outs' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.inner_untraced_sit_sot_outs(list_outputs) + + def outer_untraced_sit_sot_outs(self, list_outputs): offset = ( self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot + self.info.n_nit_sot ) - return list_outputs[offset : offset + self.info.n_shared_outs] + return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs] + + def outer_shared_outs(self, list_outputs): + warnings.warn( + "The 'outer_shared_outs' method is deprecated. Use 'outer_untraced_sit_sot_outs' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.outer_untraced_sit_sot_outs(list_outputs) def inner_non_seqs(self, list_inputs): n_taps_upto_sit_sot = sum( @@ -437,7 +479,7 @@ def inner_non_seqs(self, list_inputs): self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot - + self.info.n_shared_outs + + self.info.n_untraced_sit_sot_outs ) return list_inputs[offset:] @@ -449,7 +491,7 @@ def outer_non_seqs(self, list_inputs): + self.info.n_mit_sot + self.info.n_sit_sot + self.info.n_nit_sot - + self.info.n_shared_outs + + self.info.n_untraced_sit_sot_outs ) return list_inputs[offset:] @@ -525,8 +567,8 @@ def get_oinp_iinp_iout_oout_mappings(self): outer_oidx += 1 # This is needed because, for outer inputs (and for outer inputs only) - # nitsots come *after* shared variables. - outer_iidx += self.info.n_shared_outs + # nitsots come *after* untraced_sitsot variables. + outer_iidx += self.info.n_untraced_sit_sot_outs # Handle nitsots variables for i in range(self.info.n_nit_sot): @@ -541,11 +583,11 @@ def get_oinp_iinp_iout_oout_mappings(self): outer_oidx += 1 # This is needed because, for outer inputs (and for outer inputs only) - # nitsots come *after* shared variables. - outer_iidx -= self.info.n_shared_outs + self.info.n_nit_sot + # nitsots come *after* untraced_sit_sot variables. + outer_iidx -= self.info.n_untraced_sit_sot_outs + self.info.n_nit_sot - # Handle shared states - for i in range(self.info.n_shared_outs): + # Handle untraced_sitsot states + for i in range(self.info.n_untraced_sit_sot_outs): outer_input_indices.append(outer_iidx) inner_input_indices.append([inner_iidx]) inner_output_indices.append([inner_oidx]) @@ -557,7 +599,7 @@ def get_oinp_iinp_iout_oout_mappings(self): outer_oidx += 1 # This is needed because, for outer inputs (and for outer inputs only) - # nitsots come *after* shared variables. + # nitsots come *after* untraced_sitsot variables. outer_iidx += self.info.n_nit_sot # Handle non-sequence inputs @@ -708,7 +750,7 @@ def __init__( Inputs of the inner function of `Scan`. These take the following general form: - sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences + sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + untraced-sit-sot-inputs + shared-inputs + non-sequences where each term is a list of `Variable`\s. @@ -716,7 +758,7 @@ def __init__( Outputs of the inner function of `Scan`. These take the following general form: - mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition] + mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs [+ while-condition] where each term is a list of `Variable`\s. 
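
An illustrative sketch (not part of the diff): the deprecation shims added above keep the old `shared` accessor spellings working while the internals move to the `untraced_sit_sot` naming; old names return the same slices but now emit a `DeprecationWarning`. The helper name and the `scan_op` argument are hypothetical.

import warnings

def read_untraced_state_slots(scan_op):
    """`scan_op` is assumed to be an existing `Scan` instance."""
    new_count = scan_op.info.n_untraced_sit_sot_outs
    new_inner = scan_op.inner_untraced_sit_sot(scan_op.inner_inputs)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_count = scan_op.info.n_shared_outs                 # deprecated spelling
        old_inner = scan_op.inner_shared(scan_op.inner_inputs)  # deprecated spelling
    # The shims forward to the new accessors, so the results are identical ...
    assert old_count == new_count
    assert all(a is b for a, b in zip(old_inner, new_inner))
    # ... but each legacy call now warns.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    return new_count, new_inner
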
@@ -817,7 +859,7 @@ def tensorConstructor(shape, dtype): typeConstructor((None, *o.type.shape), o.type.dtype) ) - # shared outputs + possibly the ending condition + # untraced_sit_sot outputs + possibly the ending condition for o in self.fgraph.outputs[end:]: self.output_types.append(o.type) @@ -836,10 +878,12 @@ def tensorConstructor(shape, dtype): ] self.mintaps += [0 for x in range(info.n_nit_sot)] self.seqs_arg_offset = 1 + info.n_seqs - self.shared_arg_offset = ( + self.untraced_sit_sot_arg_offset = ( self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot ) - self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs + self.nit_sot_arg_offset = ( + self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs + ) # XXX: This doesn't include `info.n_nit_sot`s, so it's really a count # of the number of outputs generated by taps with inputs self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot @@ -908,7 +952,7 @@ def make_node(self, *inputs): sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + - shared-inputs + + untraced-sit-sot-inputs + shared-inputs nit-sots + non-sequences @@ -923,7 +967,7 @@ def make_node(self, *inputs): [n_steps] + sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + - shared-inputs + + untraced-sit-sot-inputs + shared-inputs nit-sots + non-sequences @@ -931,7 +975,7 @@ def make_node(self, *inputs): mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + - shared-outputs + untraced-sit-sot-outputs These outer-outputs essentially follow the same form as their corresponding inner-outputs, excluding the final "while" condition @@ -949,7 +993,7 @@ def make_node(self, *inputs): + len(self.info.mit_mot_in_slices) + len(self.info.mit_sot_in_slices) + len(self.inner_sitsot(self.inner_inputs)) - + len(self.inner_shared(self.inner_inputs)) + + len(self.inner_untraced_sit_sot(self.inner_inputs)) + len(self.inner_non_seqs(self.inner_inputs)) ) @@ -1134,60 +1178,60 @@ def make_node(self, *inputs): ) argoffset += len(self.outer_sitsot(inputs)) - # Check that the shared variable and their update rule have the same + # Check that the untraced (u) sit-sot variable and their update rule have the same # dtype. Maybe even same type ?! 
- for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate( + for idx, (inner_u_sitsot, inner_u_sitsot_out, _outer_u_sitsot) in enumerate( zip( - self.inner_shared(self.inner_inputs), - self.inner_shared_outs(self.inner_outputs), - self.outer_shared(inputs), + self.inner_untraced_sit_sot(self.inner_inputs), + self.inner_untraced_sit_sot_outs(self.inner_outputs), + self.outer_untraced_sit_sot(inputs), strict=True, ) ): - outer_shared = copy_var_format(_outer_shared, as_var=inner_shared) - new_inputs.append(outer_shared) + outer_u_sitsot = copy_var_format(_outer_u_sitsot, as_var=inner_u_sitsot) + new_inputs.append(outer_u_sitsot) if ( - hasattr(outer_shared, "dtype") - and outer_shared.dtype != inner_shared_out.dtype + hasattr(outer_u_sitsot, "dtype") + and outer_u_sitsot.dtype != inner_u_sitsot_out.dtype ): raise ValueError( err_msg2 % ( - str(outer_shared), + str(outer_u_sitsot), idx + argoffset, - outer_shared.dtype, - inner_shared_out.dtype, + outer_u_sitsot.dtype, + inner_u_sitsot_out.dtype, ) ) if ( - hasattr(outer_shared, "dtype") - and outer_shared.ndim != inner_shared_out.ndim + hasattr(outer_u_sitsot, "dtype") + and outer_u_sitsot.ndim != inner_u_sitsot_out.ndim ): raise ValueError( err_msg3 % ( - str(outer_shared), + str(outer_u_sitsot), idx + argoffset, - outer_shared.ndim, - inner_shared_out.ndim, + outer_u_sitsot.ndim, + inner_u_sitsot_out.ndim, ) ) - if hasattr(outer_shared, "dtype") and ( - outer_shared.dtype != inner_shared.dtype - or outer_shared.ndim != inner_shared.ndim + if hasattr(outer_u_sitsot, "dtype") and ( + outer_u_sitsot.dtype != inner_u_sitsot.dtype + or outer_u_sitsot.ndim != inner_u_sitsot.ndim ): raise ValueError( err_msg1 % ( "initial state (outputs_info in scan nomenclature) ", - str(outer_shared), + str(outer_u_sitsot), argoffset + idx, - outer_shared.dtype, - outer_shared.ndim, - str(inner_shared), - inner_shared.dtype, - inner_shared.ndim, + outer_u_sitsot.dtype, + outer_u_sitsot.ndim, + str(inner_u_sitsot), + inner_u_sitsot.dtype, + inner_u_sitsot.ndim, ) ) # We do not need to call `copy_var_format` on outer_nisot arguments. 
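
A minimal usage sketch (not part of the diff, variable names illustrative): the untraced SIT-SOT slots that `make_node` validates above are produced by `scan` in basic.py when a recurrent output's type has no static shape, the motivating case being an explicitly threaded RNG state. The state goes through `outputs_info` instead of shared-variable updates, and only its last value is returned; `return_updates=False` opts into the new return signature, while omitting it keeps the legacy `(outputs, updates)` pair and emits a `DeprecationWarning`.

import pytensor
from pytensor.tensor.random.basic import normal
from pytensor.tensor.random.type import random_generator_type

rng = random_generator_type("rng")

def step(prev_rng):
    # A RandomVariable node outputs (next_rng, draw)
    next_rng, draw = normal(0, 1, rng=prev_rng).owner.outputs
    return next_rng, draw

# `rng` becomes an untraced SIT-SOT (only the final state is returned),
# `draw` a regular nit-sot trace of all 5 draws.
final_rng, draws = pytensor.scan(
    step,
    outputs_info=[rng, None],
    n_steps=5,
    return_updates=False,
)
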
@@ -1585,7 +1629,7 @@ def p(node, inputs, outputs): try: t_fn, n_steps = scan_perform_ext.perform( - self.info.n_shared_outs, + self.info.n_untraced_sit_sot_outs, self.info.n_mit_mot_outs, self.info.n_seqs, self.info.n_mit_mot, @@ -1719,7 +1763,7 @@ def perform(self, node, inputs, output_storage): # The length of each output store_steps = [ arg.shape[0] - for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset] + for arg in inputs[self.seqs_arg_offset : self.untraced_sit_sot_arg_offset] ] store_steps += list( inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot] @@ -1784,7 +1828,7 @@ def perform(self, node, inputs, output_storage): info.sit_sot_in_slices, ) ) - + info.n_shared_outs + + info.n_untraced_sit_sot_outs ) for idx in range(len(other_args)): inner_input_storage[idx + offset].storage[0] = other_args[idx] @@ -1827,14 +1871,14 @@ def perform(self, node, inputs, output_storage): ] offset += 1 - a_offset = self.shared_arg_offset + a_offset = self.untraced_sit_sot_arg_offset o_offset = self.n_outs + info.n_nit_sot if i == 0: - for j in range(info.n_shared_outs): + for j in range(info.n_untraced_sit_sot_outs): inner_input_storage[offset].storage[0] = inputs[a_offset + j] offset += 1 else: - for j in range(info.n_shared_outs): + for j in range(info.n_untraced_sit_sot_outs): inner_input_storage[offset].storage[0] = output_storage[ o_offset + j ][0] @@ -1866,14 +1910,14 @@ def perform(self, node, inputs, output_storage): for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot): inner_output_storage[idx + offset].storage[0] = None - # 4.3. Collect slices for shared outputs + # 4.3. Collect slices for untraced sitsot outputs offset += self.n_outs + info.n_nit_sot - info.n_mit_mot - for idx in range(info.n_shared_outs): + for idx in range(info.n_untraced_sit_sot_outs): inner_output_storage[idx + offset].storage[0] = None # 4.4. If there is a condition add it to the mix if info.as_while: - pdx = offset + info.n_shared_outs + pdx = offset + info.n_untraced_sit_sot_outs inner_output_storage[pdx].storage[0] = None # 4.5. 
Keep a reference to the variables (ndarrays, @@ -1942,7 +1986,7 @@ def perform(self, node, inputs, output_storage): dt_fn = time.perf_counter() - t0_fn if info.as_while: - pdx = offset + info.n_shared_outs + pdx = offset + info.n_untraced_sit_sot_outs cond = inner_output_storage[pdx].storage[0] == 0 t_fn += dt_fn @@ -2089,10 +2133,10 @@ def perform(self, node, inputs, output_storage): j + offset_out ].storage[0] - # 5.6 Copy over the values for outputs corresponding to shared + # 5.6 Copy over the values for outputs corresponding to untraced sitsot # variables begin = end - end += info.n_shared_outs + end += info.n_untraced_sit_sot_outs for j in range(begin, end): jout = j + offset_out output_storage[j][0] = inner_output_storage[jout].storage[0] @@ -2240,13 +2284,13 @@ def infer_shape(self, fgraph, node, input_shapes): # out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap outer_inp_idx += 1 - # shared_outs + # untraced sit_sot outputs offset = 1 + info.n_seqs + n_outs - for idx in range(info.n_shared_outs): + for idx in range(info.n_untraced_sit_sot_outs): outs_shape += [input_shapes[idx + offset]] # non_sequences - offset += info.n_nit_sot + info.n_shared_outs + offset += info.n_nit_sot + info.n_untraced_sit_sot_outs inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] assert len(inner_ins_shapes) == len(self.inner_inputs) @@ -2288,7 +2332,7 @@ def infer_shape(self, fgraph, node, input_shapes): # in the inner function. r = node.outputs[n_outs + x] assert r.ndim == 1 + len(out_shape_x) - shp = [node.inputs[offset + info.n_shared_outs + x]] + shp = [node.inputs[offset + info.n_untraced_sit_sot_outs + x]] for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True): # Validate shp_i. v_shape_i is either None (if invalid), # or a (variable, Boolean) tuple. 
The Boolean indicates @@ -2305,7 +2349,7 @@ def infer_shape(self, fgraph, node, input_shapes): shp.append(v_shp_i[0]) scan_outs.append(tuple(shp)) - scan_outs += list(input_shapes[offset : offset + info.n_shared_outs]) + scan_outs += list(input_shapes[offset : offset + info.n_untraced_sit_sot_outs]) # if we are dealing with a repeat-until, then we do not know the # leading dimension so we replace it for every entry with Shape_i if info.as_while: @@ -2735,7 +2779,7 @@ def compute_all_gradients(known_grads): mitmot_inp_taps.append([]) mitmot_out_taps.append([]) undefined_msg = None - through_shared = False + through_untraced = False disconnected = True for mit_mot_out_slice in info.mit_mot_out_slices[idx]: @@ -2779,9 +2823,9 @@ def compute_all_gradients(known_grads): disconnected &= disconnected_dC_dinps_t[ins_pos] - through_shared = any( + through_untraced = any( _sh in graph_inputs([dC_dinps_t[ins_pos]]) - for _sh in self.inner_shared(self_inputs) + for _sh in self.inner_untraced_sit_sot(self_inputs) ) ins_pos += 1 @@ -2795,8 +2839,8 @@ def compute_all_gradients(known_grads): if undefined_msg: type_outs.append(undefined_msg) - elif through_shared: - type_outs.append("through_shared") + elif through_untraced: + type_outs.append("through_untraced") elif disconnected: type_outs.append("disconnected") else: @@ -2814,7 +2858,7 @@ def compute_all_gradients(known_grads): out_pos += 1 n_mitmot_inps += 1 undefined_msg = None - through_shared = False + through_untraced = False disconnected = True mitmot_inp_taps[idx + offset].append(0) for tap in taps: @@ -2836,9 +2880,9 @@ def compute_all_gradients(known_grads): disconnected &= disconnected_dC_dinps_t[ins_pos] - through_shared = any( + through_untraced = any( _sh in graph_inputs([dC_dinps_t[ins_pos]]) - for _sh in self.inner_shared(self_inputs) + for _sh in self.inner_untraced_sit_sot(self_inputs) ) n_mitmot_inps += 1 @@ -2847,8 +2891,8 @@ def compute_all_gradients(known_grads): if undefined_msg: type_outs.append(undefined_msg) - elif through_shared: - type_outs.append("through_shared") + elif through_untraced: + type_outs.append("through_untraced") elif disconnected: type_outs.append("disconnected") else: @@ -2884,15 +2928,15 @@ def compute_all_gradients(known_grads): else: inner_out_mitmot.append(dC_dinps_t[ins_pos]) - through_shared = any( + through_untraced = any( _sh in graph_inputs([dC_dinps_t[ins_pos]]) - for _sh in self.inner_shared(self_inputs) + for _sh in self.inner_untraced_sit_sot(self_inputs) ) if isinstance(dC_dinps_t[ins_pos].type, NullType): type_outs.append(dC_dinps_t[ins_pos].type.why_null) - elif through_shared: - type_outs.append("through_shared") + elif through_untraced: + type_outs.append("through_untraced") elif disconnected_dC_dinps_t[ins_pos]: type_outs.append("disconnected") else: @@ -2911,10 +2955,10 @@ def compute_all_gradients(known_grads): inner_out_nitsot = dC_dinps_t[: info.n_seqs] inner_out_sitsot = dC_dinps_t[ins_pos:] for _p, vl in enumerate(inner_out_sitsot): - through_shared = False - for _sh in self.inner_shared(self_inputs): + through_untraced = False + for _sh in self.inner_untraced_sit_sot(self_inputs): if _sh in graph_inputs([vl]): - through_shared = True + through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of @@ -2922,18 +2966,18 @@ def compute_all_gradients(known_grads): inner_out_sitsot[_p] = pt.zeros( diff_inputs[ins_pos + _p].shape, dtype=config.floatX ) - elif through_shared: - type_outs.append("through_shared") + 
elif through_untraced: + type_outs.append("through_untraced") elif disconnected_dC_dinps_t[_p + ins_pos]: type_outs.append("disconnected") else: type_outs.append("connected") for _p, vl in enumerate(inner_out_nitsot): - through_shared = False - for _sh in self.inner_shared(self_inputs): + through_untraced = False + for _sh in self.inner_untraced_sit_sot(self_inputs): if _sh in graph_inputs([vl]): - through_shared = True + through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of @@ -2942,8 +2986,8 @@ def compute_all_gradients(known_grads): diff_inputs[_p].shape, dtype=config.floatX ) - if through_shared: - type_outs.append("through_shared") + if through_untraced: + type_outs.append("through_untraced") elif disconnected_dC_dinps_t[_p]: type_outs.append("disconnected") else: @@ -2983,7 +3027,7 @@ def compute_all_gradients(known_grads): + outer_inp_mitmot + outer_inp_sitsot + [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)] - + self.outer_shared(inputs) + + self.outer_untraced_sit_sot(inputs) + self.outer_non_seqs(inputs) ) @@ -2991,7 +3035,7 @@ def compute_all_gradients(known_grads): inner_inp_seqs + inner_inp_mitmot + inner_inp_sitsot - + self.inner_shared(self_inputs) + + self.inner_untraced_sit_sot(self_inputs) + self.inner_non_seqs(self_inputs) ) inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot @@ -3003,8 +3047,8 @@ def compute_all_gradients(known_grads): mit_sot_in_slices=(), sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)), n_nit_sot=n_nit_sot, - n_shared_outs=0, - n_non_seqs=len(self.outer_shared(inputs)) + n_untraced_sit_sot_outs=0, + n_non_seqs=len(self.outer_untraced_sit_sot(inputs)) + len(self.outer_non_seqs(inputs)), as_while=False, ) @@ -3047,10 +3091,10 @@ def compute_all_gradients(known_grads): gradients.append(x[::-1]) elif t == "disconnected": gradients.append(DisconnectedType()()) - elif t == "through_shared": + elif t == "through_untraced": gradients.append( grad_undefined( - self, p + 1, inputs[p + 1], "Depends on a shared variable" + self, p + 1, inputs[p + 1], "Depends on a untraced variable" ) ) else: @@ -3075,13 +3119,13 @@ def compute_all_gradients(known_grads): gradients.append(x[::-1]) elif t == "disconnected": gradients.append(DisconnectedType()()) - elif t == "through_shared": + elif t == "through_untraced": gradients.append( grad_undefined( self, p + 1 + info.n_seqs, inputs[p + 1 + info.n_seqs], - "Depends on a shared variable", + "Depends on an untraced variable", ) ) else: @@ -3090,7 +3134,7 @@ def compute_all_gradients(known_grads): start = len(gradients) node = outs[0].owner - for idx in range(info.n_shared_outs): + for idx in range(info.n_untraced_sit_sot_outs): disconnected = True connected_flags = self.connection_pattern(node)[idx + start] for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): @@ -3116,13 +3160,13 @@ def compute_all_gradients(known_grads): gradients.append(x[-1]) elif t == "disconnected": gradients.append(DisconnectedType()()) - elif t == "through_shared": + elif t == "through_untraced": gradients.append( grad_undefined( self, p + begin + 1, inputs[p + begin + 1], - "Depends on a shared variable", + "Depends on a untraced variable", ) ) else: @@ -3152,7 +3196,7 @@ def R_op(self, inputs, eval_points): self_inputs = self.inner_inputs rop_of_inputs = ( self_inputs[: info.n_seqs + self.n_outs] - + self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :] + + self_inputs[info.n_seqs + 
self.n_outs + info.n_untraced_sit_sot_outs :] ) self_outputs = self.inner_outputs @@ -3162,8 +3206,8 @@ def R_op(self, inputs, eval_points): rop_self_outputs = self_outputs[:-1] else: rop_self_outputs = self_outputs - if info.n_shared_outs > 0: - rop_self_outputs = rop_self_outputs[: -info.n_shared_outs] + if info.n_untraced_sit_sot_outs > 0: + rop_self_outputs = rop_self_outputs[: -info.n_untraced_sit_sot_outs] rop_outs = Rop( rop_self_outputs, rop_of_inputs, @@ -3247,13 +3291,13 @@ def R_op(self, inputs, eval_points): scan_sit_sot = inputs[b:e] + clean_eval_points inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] - # Shared outs ... + # Untraced outs ... b = e - e = e + info.n_shared_outs + e = e + info.n_untraced_sit_sot_outs ib = ie - ie = ie + info.n_shared_outs - scan_shared = inputs[b:e] - inner_shared = self_inputs[ib:ie] + ie = ie + info.n_untraced_sit_sot_outs + scan_untraced = inputs[b:e] + inner_untraced = self_inputs[ib:ie] # NIT_SOT sequences b = e @@ -3268,7 +3312,7 @@ def R_op(self, inputs, eval_points): else: clean_eval_points.append(inp.zeros_like()) scan_other = inputs[e:] + clean_eval_points - # inner_eval_points do not have entries for shared variables + # inner_eval_points do not have entries for untraced variables inner_other = self_inputs[ie:] + inner_eval_points[ib:] # Outputs @@ -3287,15 +3331,15 @@ def R_op(self, inputs, eval_points): e = e + info.n_nit_sot inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e] b = e - e = e + info.n_shared_outs - inner_out_shared = self_outputs[b:e] + e = e + info.n_untraced_sit_sot_outs + inner_out_untraced = self_outputs[b:e] inner_ins = ( inner_seqs + inner_mit_mot + inner_mit_sot + inner_sit_sot - + inner_shared + + inner_untraced + inner_other ) inner_outs = ( @@ -3303,7 +3347,7 @@ def R_op(self, inputs, eval_points): + inner_out_mit_sot + inner_out_sit_sot + inner_out_nit_sot - + inner_out_shared + + inner_out_untraced ) if info.as_while: @@ -3314,7 +3358,7 @@ def R_op(self, inputs, eval_points): *scan_mit_mot, *scan_mit_sot, *scan_sit_sot, - *scan_shared, + *scan_untraced, *scan_nit_sot, *scan_other, ] @@ -3326,7 +3370,7 @@ def R_op(self, inputs, eval_points): mit_sot_in_slices=new_mit_sot_in_slices, sit_sot_in_slices=new_sit_sot_in_slices, n_nit_sot=info.n_nit_sot * 2, - n_shared_outs=info.n_shared_outs, + n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs, n_non_seqs=len(inner_other), as_while=info.as_while, ) @@ -3358,7 +3402,7 @@ def R_op(self, inputs, eval_points): b = e + info.n_nit_sot e = e + info.n_nit_sot * 2 final_outs += outputs[b:e] - final_outs += [None] * info.n_shared_outs + final_outs += [None] * info.n_untraced_sit_sot_outs return final_outs diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 0bc4f12143..8cd69d2d6a 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) ) st += op_info.n_sit_sot - st += op_info.n_shared_outs + st += op_info.n_untraced_sit_sot_outs op_ins = op.inner_inputs op_outs = op.inner_outputs @@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): + op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot - + op_info.n_shared_outs + + op_info.n_untraced_sit_sot_outs + 1 ) outer_non_seqs = node.inputs[st:] @@ -983,7 +983,7 @@ def attempt_scan_inplace( ls = op.outer_mitmot(node.inputs) ls += op.outer_mitsot(node.inputs) ls += 
op.outer_sitsot(node.inputs) - ls_end = op.outer_shared(node.inputs) + ls_end = op.outer_untraced_sit_sot(node.inputs) ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_non_seqs(node.inputs) @@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: + idx + op_info.n_seqs + 1 - + op_info.n_shared_outs + + op_info.n_untraced_sit_sot_outs ) if nw_inputs[pos] == node.inputs[0]: nw_inputs[pos] = 1 if required_orphan else val @@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: elif ( idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot ): - in_idx = offset + idx + op_info.n_shared_outs + in_idx = offset + idx + op_info.n_untraced_sit_sot_outs if nw_inputs[in_idx] == node.inputs[0]: nw_inputs[in_idx] = nw_steps @@ -1886,8 +1886,8 @@ def merge(self, nodes): for idx, nd in enumerate(nodes): # Shared - inner_ins[idx].append(nd.op.inner_shared(nd.op.inner_inputs)) - outer_ins += nd.op.outer_shared(nd.inputs) + inner_ins[idx].append(nd.op.inner_untraced_sit_sot(nd.op.inner_inputs)) + outer_ins += nd.op.outer_untraced_sit_sot(nd.inputs) for idx, nd in enumerate(nodes): # NitSot @@ -1897,8 +1897,10 @@ def merge(self, nodes): for idx, nd in enumerate(nodes): # Shared - outer_outs += nd.op.outer_shared_outs(nd.outputs) - inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs)) + outer_outs += nd.op.outer_untraced_sit_sot_outs(nd.outputs) + inner_outs[idx].append( + nd.op.inner_untraced_sit_sot_outs(nd.op.inner_outputs) + ) n_non_seqs = 0 for idx, nd in enumerate(nodes): @@ -1978,7 +1980,9 @@ def merge(self, nodes): mit_sot_in_slices=mit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices, n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), - n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes), + n_untraced_sit_sot_outs=sum( + nd.op.info.n_untraced_sit_sot_outs for nd in nodes + ), n_non_seqs=n_non_seqs, as_while=as_while, ) @@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node): # When seq[t] is a vector/matrix and `value` is a matrix # Note that this works when only you need X[-1] in the end # and assumes dimshuffle are applied to vectors before calling dot - op = node.op + op: Scan = node.op sitsot_ins = op.inner_sitsot(op.inner_inputs) sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs) @@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node): inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) outer_nitsot = op.outer_nitsot(node.inputs) inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) - inner_shared = op.inner_shared(op.inner_inputs) - outer_shared = op.outer_shared(node.inputs) - inner_shared_outs = op.inner_shared_outs(op.inner_outputs) + inner_untraced_sitsot = op.inner_untraced_sitsot(op.inner_inputs) + outer_untraced_sitsot_outs = op.outer_untraced_sitsot_outs( + node.inputs + ) + inner_untraced_sitsot_outs = op.inner_untraced_sitsot_outs( + op.inner_outputs + ) inner_non_seqs = op.inner_non_seqs(op.inner_inputs) outer_non_seqs = op.outer_non_seqs(node.inputs) @@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node): + inner_mitmot + inner_mitsot + inner_sitsot - + inner_shared + + inner_untraced_sitsot + inner_non_seqs ) _new_inner_outs = ( @@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node): + inner_mitsot_outs + inner_sitsot_outs + inner_nitsot_outs - + inner_shared_outs + + inner_untraced_sitsot_outs ) new_inner_inps, new_inner_outs = reconstruct_graph( _new_inner_inps, 
_new_inner_outs @@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node): *outer_mitmot, *outer_mitsot, *outer_sitsot, - *outer_shared, + *outer_untraced_sitsot_outs, *outer_nitsot, node.inputs[0], *outer_non_seqs, diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index 4f39cfe0ad..d1afef639d 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs): out_ins += [op.inner_inputs[offset : offset + n_ins]] offset += n_ins out_ins += [[] for k in range(op.info.n_nit_sot)] - out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)] + out_ins += [ + [op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs) + ] added = True out_idxs_mask = [1 for idx in out_idxs] @@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs): mit_sot_in_slices=(), sit_sot_in_slices=(), n_nit_sot=0, - n_shared_outs=0, + n_untraced_sit_sot_outs=0, n_non_seqs=0, as_while=op_info.as_while, ) @@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs): info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) op_outputs += [op.inner_outputs[o_offset]] o_offset += 1 - nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]] + nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]] else: o_offset += 1 offset += op_info.n_nit_sot shared_ins = [] - for idx in range(op_info.n_shared_outs): + for idx in range(op_info.n_untraced_sit_sot_outs): if offset + idx not in not_required: map_old_new[offset + idx] = curr_pos curr_pos += 1 - info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1) + info = dataclasses.replace( + info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1 + ) op_outputs += [op.inner_outputs[o_offset]] o_offset += 1 op_inputs += [op.inner_inputs[i_offset]] @@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs): # other stuff op_inputs += op.inner_inputs[i_offset:] info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) - node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :] + node_inputs += inputs[ + ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot : + ] if op_info.as_while: op_outputs += [op.inner_outputs[o_offset]] map_old_new[o_offset] = len(op_outputs) - 1 @@ -658,11 +664,11 @@ def __init__( p += n_sit_sot q += n_sit_sot - n_shared_outs = info.n_shared_outs - self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs]) - self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs]) - p += n_shared_outs - q += n_shared_outs + n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs + self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs]) + self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs]) + p += n_untraced_sit_sot_outs + q += n_untraced_sit_sot_outs n_nit_sot = info.n_nit_sot self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot]) @@ -702,10 +708,10 @@ def __init__( p += n_nit_sot q += n_nit_sot - self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs]) - self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs]) - p += n_shared_outs - q += n_shared_outs + self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs]) + self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs]) + p += n_untraced_sit_sot_outs + q += n_untraced_sit_sot_outs assert p == len(outer_outputs) assert q == len(inner_outputs) @@ -816,7 +822,7 @@ def info(self) -> "ScanInfo": 
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices), sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot), n_nit_sot=len(self.outer_in_nit_sot), - n_shared_outs=len(self.outer_in_shared), + n_untraced_sit_sot_outs=len(self.outer_in_shared), n_non_seqs=len(self.inner_in_non_seqs), as_while=self.as_while, ) diff --git a/pytensor/scan/views.py b/pytensor/scan/views.py index b86476b330..68d09ea11e 100644 --- a/pytensor/scan/views.py +++ b/pytensor/scan/views.py @@ -16,6 +16,7 @@ def map( go_backwards=False, mode=None, name=None, + return_updates=True, ): """Construct a `Scan` `Op` that functions like `map`. @@ -50,6 +51,7 @@ def map( go_backwards=go_backwards, mode=mode, name=name, + return_updates=return_updates, ) @@ -61,6 +63,7 @@ def reduce( go_backwards=False, mode=None, name=None, + return_updates=True, ): """Construct a `Scan` `Op` that functions like `reduce`. @@ -97,14 +100,29 @@ def reduce( truncate_gradient=-1, mode=mode, name=name, + return_updates=return_updates, ) - if isinstance(rval[0], list | tuple): - return [x[-1] for x in rval[0]], rval[1] + if return_updates: + if isinstance(rval[0], list | tuple): + return [x[-1] for x in rval[0]], rval[1] + else: + return rval[0][-1], rval[1] else: - return rval[0][-1], rval[1] + if isinstance(rval, list | tuple): + return [x[-1] for x in rval] + else: + return rval[-1] -def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None): +def foldl( + fn, + sequences, + outputs_info, + non_sequences=None, + mode=None, + name=None, + return_updates=True, +): """Construct a `Scan` `Op` that functions like Haskell's `foldl`. Parameters @@ -135,10 +153,19 @@ def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None) go_backwards=False, mode=mode, name=name, + return_updates=return_updates, ) -def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None): +def foldr( + fn, + sequences, + outputs_info, + non_sequences=None, + mode=None, + name=None, + return_updates=True, +): """Construct a `Scan` `Op` that functions like Haskell's `foldr`. 
Parameters @@ -169,4 +196,5 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None) go_backwards=True, mode=mode, name=name, + return_updates=return_updates, ) diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py index cce7bee619..efe7da88dc 100644 --- a/pytensor/tensor/pad.py +++ b/pytensor/tensor/pad.py @@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable: def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis): - [_, parts], _ = scan( + [_, parts] = scan( inner_func, non_sequences=[array, array_flipped], outputs_info=[0, None], n_steps=repeats, + return_updates=False, ) parts = moveaxis(parts, 0, axis) diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py index ff9f4893af..e180d34fd7 100644 --- a/tests/link/jax/test_scan.py +++ b/tests/link/jax/test_scan.py @@ -23,10 +23,11 @@ @pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) def test_scan_sit_sot(view): x0 = pt.scalar("x0", dtype="float64") - xs, _ = scan( + xs = scan( lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10, + return_updates=False, ) if view: xs = xs[view] @@ -37,10 +38,11 @@ def test_scan_sit_sot(view): @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) def test_scan_mit_sot(view): x0 = pt.vector("x0", dtype="float64", shape=(3,)) - xs, _ = scan( + xs = scan( lambda xtm3, xtm1: xtm3 + xtm1 + 1, outputs_info=[{"initial": x0, "taps": [-3, -1]}], n_steps=10, + return_updates=False, ) if view: xs = xs[view] @@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y): def step(xtm3, xtm1, ytm4, ytm2): return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2 - [xs, ys], _ = scan( + [xs, ys] = scan( fn=step, outputs_info=[ {"initial": x0, "taps": [-3, -1]}, {"initial": y0, "taps": [-4, -2]}, ], n_steps=10, + return_updates=False, ) if view_x: xs = xs[view_x] @@ -80,10 +83,8 @@ def test_scan_nit_sot(view): xs = pt.vector("x0", dtype="float64", shape=(10,)) - ys, _ = scan( - lambda x: pt.exp(x), - outputs_info=[None], - sequences=[xs], + ys = scan( + lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False ) if view: ys = ys[view] @@ -106,11 +107,12 @@ def step(xtm1, ytm3, ytm1, rho): rho = pt.scalar("rho", dtype="float64") x0 = pt.vector("xs", shape=(2,)) y0 = pt.vector("ys", shape=(3,)) - [outs, _], _ = scan( + [outs, _] = scan( step, outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}], non_sequences=[rho], n_steps=10, + return_updates=False, ) grads = pt.grad(outs.sum(), wrt=[x0, y0, rho]) compare_jax_and_py( @@ -191,10 +193,11 @@ def update_fn(rng): @pytest.mark.xfail(raises=NotImplementedError) def test_scan_while(): - xs, _ = scan( + xs = scan( lambda x: (x + 1, until(x < 10)), outputs_info=[pt.zeros(())], n_steps=100, + return_updates=False, ) compare_jax_and_py([], [xs], []) @@ -210,7 +213,7 @@ def input_step_fn(y_tm1, y_tm3, a): res.name = "y_t" return res - y_scan_pt, _ = scan( + y_scan_pt = scan( fn=input_step_fn, outputs_info=[ { @@ -223,6 +226,7 @@ def input_step_fn(y_tm1, y_tm3, a): non_sequences=[a_pt], n_steps=10, name="y_scan", + return_updates=False, ) y_scan_pt.name = "y" y_scan_pt.owner.inputs[0].name = "y_all" @@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func): k = 3 # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph - xs, _ = scan( + xs = scan( lambda X, A: A @ X, non_sequences=[A], outputs_info=[x0], n_steps=n_steps, + return_updates=False, ) x0_val = ( @@ -267,11 +272,12 @@ def 
test_nd_scan_sit_sot_with_seq(): A = pt.matrix("A", shape=(k, k)) # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph - xs, _ = scan( + xs = scan( lambda X, A: A @ X, non_sequences=[A], sequences=[x], n_steps=n_steps, + return_updates=False, ) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) @@ -287,11 +293,12 @@ def test_nd_scan_mit_sot(): B = pt.matrix("B", shape=(3, 3)) # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph - xs, _ = scan( + xs = scan( lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1, outputs_info=[{"initial": x0, "taps": [-3, -1]}], non_sequences=[A, B], n_steps=10, + return_updates=False, ) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) @@ -310,12 +317,13 @@ def step(x, A): return A @ x, x.sum() # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph - xs, _ = scan( + xs = scan( step, outputs_info=[x0, None], non_sequences=[A], n_steps=10, mode=get_mode("JAX"), + return_updates=False, ) x0_val = np.arange(3, dtype=config.floatX) @@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites(): # See issue #426 A = matrix("A") B = matrix("B") - out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) + out = scan( + lambda a, b: a @ b, + outputs_info=[A], + non_sequences=[B], + n_steps=2, + return_updates=False, + ) compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX") @@ -353,8 +367,11 @@ def _(op, **kwargs): x = pt.tensor("x", shape=(None, 3)) - out, _ = scan( - lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x] + out = scan( + lambda x: inc_without_static_shape(x), + outputs_info=[None], + sequences=[x], + return_updates=False, ) f = function([x], out, mode=get_mode("JAX").excluding("scan")) assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 @@ -364,10 +381,11 @@ def _(op, **kwargs): np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3))) # With known static shape we should always manage, regardless of the internal implementation - out2, _ = scan( + out2 = scan( lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape), outputs_info=[None], sequences=[x], + return_updates=False, ) f2 = function([x], out2, mode=get_mode("JAX").excluding("scan")) np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]])) @@ -418,11 +436,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta): it1 = it0 + ct0 - dt0 return st1, et1, it1, logp_c1, logp_d1 - (st, et, it, logp_c_all, logp_d_all), _ = scan( + (st, et, it, logp_c_all, logp_d_all) = scan( fn=seir_one_step, sequences=[C_t, D_t], outputs_info=[st0, et0, it0, None, None], non_sequences=[beta, gamma, delta], + return_updates=False, ) st.name = "S_t" et.name = "E_t" @@ -511,11 +530,12 @@ def cycle_step(A0, A1, A2, A1_hat, _norm, step_num): max_iter = 100 tol = 1e-7 - (*_, A1_hat, norm, _n_steps), _ = scan( + (*_, A1_hat, norm, _n_steps) = scan( step, outputs_info=[A, B, C, B, norm, step_num], non_sequences=[tol], n_steps=max_iter, + return_updates=False, ) A1_hat = A1_hat[-1] diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 2edeff934c..77ceebbcf7 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -85,7 +85,7 @@ 3, [], [np.array([0.50100236, 2.16822932, 1.36326596])], - lambda op: op.info.n_shared_outs > 0, + lambda op: op.info.n_untraced_sit_sot_outs > 0, ), # mit-sot (that's also a type of sit-sot) ( @@ -206,11 
+206,12 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): it1 = it0 + ct0 - dt0 return st1, et1, it1, logp_c1, logp_d1 - (st, et, it, logp_c_all, logp_d_all), _ = scan( + (st, et, it, logp_c_all, logp_d_all) = scan( fn=seir_one_step, sequences=[pt_C, pt_D], outputs_info=[st0, et0, it0, logp_c, logp_d], non_sequences=[beta, gamma, delta], + return_updates=False, ) st.name = "S_t" et.name = "E_t" @@ -268,7 +269,7 @@ def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): y_t.name = "y_t" return x_t, y_t, pt.fill((10,), z_t) - scan_res, _ = scan( + scan_res = scan( fn=input_step_fn, sequences=[ { @@ -297,6 +298,7 @@ def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): n_steps=5, name="yz_scan", strict=True, + return_updates=False, ) test_input_vals = [ @@ -312,11 +314,12 @@ def power_of_2(previous_power, max_value): return previous_power * 2, until(previous_power * 2 > max_value) max_value = pt.scalar() - values, _ = scan( + values = scan( power_of_2, outputs_info=pt.constant(1.0), non_sequences=max_value, n_steps=1024, + return_updates=False, ) test_input_vals = [ @@ -331,11 +334,12 @@ def test_scan_multiple_none_output(): def power_step(prior_result, x): return prior_result * x, prior_result * x * x, prior_result * x * x * x - result, _ = scan( + result = scan( power_step, non_sequences=[A], outputs_info=[pt.ones_like(A), None, None], n_steps=3, + return_updates=False, ) test_input_vals = (np.array([1.0, 2.0]),) compare_numba_and_py([A], result, test_input_vals) @@ -343,8 +347,12 @@ def power_step(prior_result, x): def test_grad_sitsot(): def get_sum_of_grad(inp): - scan_outputs, _updates = scan( - fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA" + scan_outputs = scan( + fn=lambda x: x * 2, + outputs_info=[inp], + n_steps=5, + mode="NUMBA", + return_updates=False, ) return grad(scan_outputs.sum(), inp).sum() @@ -362,8 +370,11 @@ def test_mitmots_basic(): def inner_fct(seq, state_old, state_current): return state_old * 2 + state_current + seq - out, _ = scan( - inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]} + out = scan( + inner_fct, + sequences=seq, + outputs_info={"initial": init_x, "taps": [-2, -1]}, + return_updates=False, ) g_outs = grad(out.sum(), [seq, init_x]) @@ -383,10 +394,11 @@ def inner_fct(seq, state_old, state_current): def test_inner_graph_optimized(): """Test that inner graph of Scan is optimized""" xs = vector("xs") - seq, _ = scan( + seq = scan( fn=lambda x: log(1 + x), sequences=[xs], mode=get_mode("NUMBA"), + return_updates=False, ) # Disable scan pushout, in which case the whole scan is replaced by an Elemwise @@ -421,13 +433,14 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1): sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2) return mitsot3, sitsot2 - outs, _ = scan( + outs = scan( fn=step, sequences=[seq1, seq2], outputs_info=[ dict(initial=mitsot_init, taps=[-2, -1]), dict(initial=sitsot_init, taps=[-1]), ], + return_updates=False, ) rng = np.random.default_rng(474) @@ -468,7 +481,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): y = ytm1 + 1 + ytm2 + a return z, x, z + x + y, y - [zs, xs, ws, ys], _ = scan( + [zs, xs, ws, ys] = scan( fn=step, outputs_info=[ dict(initial=z0, taps=[-3, -1]), @@ -478,6 +491,7 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): ], non_sequences=[a], n_steps=n_steps, + return_updates=False, ) numba_fn, _ = compare_numba_and_py( [n_steps] * (not n_steps_constant) + [a, x0, y0, z0], @@ -529,10 +543,11 @@ def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): class TestScanSITSOTBuffer: def 
buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None): x0 = pt.vector(shape=(op_size,), dtype="float64") - xs, _ = pytensor.scan( + xs = pytensor.scan( fn=lambda xtm1: (xtm1 + 1), outputs_info=[x0], n_steps=n_steps - 1, # 1- makes it easier to align/misalign + return_updates=False, ) if buffer_size == "unit": xs_kept = xs[-1] # Only last state is used @@ -588,12 +603,13 @@ def f_pow2(x_tm2, x_tm1): init_x = pt.vector("init_x", shape=(2,)) n_steps = pt.iscalar("n_steps") - output, _ = scan( + output = scan( f_pow2, sequences=[], outputs_info=[{"initial": init_x, "taps": [-2, -1]}], non_sequences=[], n_steps=n_steps_val if constant_n_steps else n_steps, + return_updates=False, ) init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 98a249c154..b34e6ced28 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -27,7 +27,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian -from pytensor.graph.basic import Apply, equal_computations +from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.replace import vectorize_graph @@ -42,11 +42,13 @@ from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.random import normal +from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.shape import Shape_i, reshape, specify_shape from pytensor.tensor.sharedvar import SharedVariable from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( + TensorType, dcol, dmatrix, dscalar, @@ -65,6 +67,7 @@ vector, ) from tests import unittest_tools as utt +from tests.unittest_tools import assert_equal_computations if config.mode == "FAST_COMPILE": @@ -291,7 +294,7 @@ def inner_fn(x): def test_clone(self): a = vector() - output, _ = scan(fn=lambda x: x**2, sequences=[a]) + output = scan(fn=lambda x: x**2, sequences=[a], return_updates=False) scan_op = output.owner.op assert isinstance(scan_op, Scan) @@ -317,7 +320,7 @@ def f_pow2(x_tm1): state = scalar("state") n_steps = iscalar("nsteps") - output, updates = scan( + output = scan( f_pow2, [], state, @@ -325,10 +328,9 @@ def f_pow2(x_tm1): n_steps=n_steps, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - _my_f = function( - [state, n_steps], output, updates=updates, allow_input_downcast=True - ) + _my_f = function([state, n_steps], output, allow_input_downcast=True) origdir = Path.cwd() tmpdir = None @@ -365,11 +367,9 @@ def f_pow2(x_tm1): state = scalar("state") n_steps = iscalar("nsteps") - output, updates = scan(f_pow2, [], state, [], n_steps=n_steps) + output = scan(f_pow2, [], state, [], n_steps=n_steps, return_updates=False) - f = function( - [state, n_steps], output, updates=updates, allow_input_downcast=True - ) + f = function([state, n_steps], output, allow_input_downcast=True) scan_node = [ node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan) @@ -407,7 +407,9 @@ def f_pow(x_tm1): return 2 * x_tm1 n_steps = iscalar("n_steps") - values, _ = scan(f_pow, outputs_info=(x_init,), n_steps=n_steps) + values = scan( + f_pow, outputs_info=(x_init,), n_steps=n_steps, return_updates=False + ) 
update_fn = function((x_init, n_steps), values, mode=mode) @@ -440,7 +442,9 @@ def inner_fn(x_seq, x_i): return 2 * x_i with config.change_flags(mode=mode): - values, _ = scan(inner_fn, outputs_info=(x_init,), sequences=x) + values = scan( + inner_fn, outputs_info=(x_init,), sequences=x, return_updates=False + ) values_fn = function((x_init, x), values) assert isinstance(values.owner.inputs[0].owner.op, Scan) @@ -471,7 +475,7 @@ def inner_fn(x_i): return 2 * x_i with config.change_flags(mode=mode): - values, _ = scan(inner_fn, sequences=x) + values = scan(inner_fn, sequences=x, return_updates=False) values_fn = function((x,), values) assert isinstance(values.owner.op, Scan) @@ -488,7 +492,9 @@ def test_only_nonseq_inputs(self): # Compile the PyTensor function n_steps = 2 inp = matrix() - broadcasted_inp, _ = scan(lambda x: x, non_sequences=[inp], n_steps=n_steps) + broadcasted_inp = scan( + lambda x: x, non_sequences=[inp], n_steps=n_steps, return_updates=False + ) out = broadcasted_inp.sum() gr = grad(out, inp) fun = function([inp], [broadcasted_inp, gr]) @@ -516,7 +522,7 @@ def f_rnn(u_t, x_tm1, W_in, W): W_in = scalar("win") W = scalar("w") - output, updates = scan( + output = scan( f_rnn, u, x0, @@ -524,11 +530,10 @@ def f_rnn(u_t, x_tm1, W_in, W): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f2 = function( - [u, x0, W_in, W], output, updates=updates, allow_input_downcast=True - ) + f2 = function([u, x0, W_in, W], output, allow_input_downcast=True) # get random initial values rng = np.random.default_rng(utt.fetch_seed()) v_u = rng.uniform(-5.0, 5.0, size=(4,)) @@ -558,7 +563,7 @@ def test_one_sequence_one_output_weights_shared(self): def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W): return u_t * tmp_W_in + x_tm1 * tmp_W - output, updates = scan( + output = scan( f_rnn_shared, u, x0, @@ -566,8 +571,9 @@ def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f3 = function([u, x0], output, updates=updates, allow_input_downcast=True) + f3 = function([u, x0], output, allow_input_downcast=True) # get random initial values v_u = rng.uniform(-5.0, 5.0, size=(4,)) @@ -685,11 +691,14 @@ def test_using_taps_sequence(self): # this test refers to a bug reported by Nicolas # Boulanger-Lewandowski June 6th x = dvector() - y, updates = scan( - lambda x: [x], sequences=dict(input=x, taps=[-1]), outputs_info=[None] + y = scan( + lambda x: [x], + sequences=dict(input=x, taps=[-1]), + outputs_info=[None], + return_updates=False, ) inp = np.arange(5).astype("float64") - rval = function([x], y, updates=updates)(inp) + rval = function([x], y)(inp) assert np.all(rval == inp[:-1]) def test_output_only(self): @@ -698,11 +707,18 @@ def f_rnn(u_t): u = vector("u") - outputs, updates = scan( - f_rnn, u, [], [], n_steps=None, truncate_gradient=-1, go_backwards=False + outputs = scan( + f_rnn, + u, + [], + [], + n_steps=None, + truncate_gradient=-1, + go_backwards=False, + return_updates=False, ) - f2 = function([u], outputs, updates=updates, allow_input_downcast=True) + f2 = function([u], outputs, allow_input_downcast=True) rng = np.random.default_rng(utt.fetch_seed()) v_u = rng.uniform(-5.0, 5.0, size=(5,)) @@ -719,7 +735,7 @@ def f_rnn(u_t, x_tm1, W_in, W): W_in = scalar("win") W = scalar("w") - output, updates = scan( + output = scan( f_rnn, u, x0, @@ -727,11 +743,10 @@ def f_rnn(u_t, x_tm1, W_in, W): n_steps=None, truncate_gradient=-1, go_backwards=True, + return_updates=False, ) - f2 = function( - 
[u, x0, W_in, W], output, updates=updates, allow_input_downcast=True - ) + f2 = function([u, x0, W_in, W], output, allow_input_downcast=True) # get random initial values rng = np.random.default_rng(utt.fetch_seed()) v_u = rng.uniform(-5.0, 5.0, size=(4,)) @@ -794,8 +809,8 @@ def incr(s): def test_hash(self): x = vector() y = vector() - scan1, _updates = scan(lambda _x: _x + 1, x) - scan2, _updates = scan(lambda _x: _x + 1, y) + scan1 = scan(lambda _x: _x + 1, x, return_updates=False) + scan2 = scan(lambda _x: _x + 1, y, return_updates=False) assert scan1.owner.op == scan2.owner.op assert hash(scan1.owner.op) == hash(scan2.owner.op) @@ -806,9 +821,24 @@ def test_can_merge(self): y = vector("y") c = scalar("c") - scan_a, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c]) - scan_b, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c]) - scan_c, _ = scan(lambda x, y, c: x + y + c, sequences=[y, x], non_sequences=[c]) + scan_a = scan( + lambda x, y, c: x + y + c, + sequences=[x, y], + non_sequences=[c], + return_updates=False, + ) + scan_b = scan( + lambda x, y, c: x + y + c, + sequences=[x, y], + non_sequences=[c], + return_updates=False, + ) + scan_c = scan( + lambda x, y, c: x + y + c, + sequences=[y, x], + non_sequences=[c], + return_updates=False, + ) assert scan_b is not scan_a assert scan_c is not scan_a @@ -1003,7 +1033,7 @@ def test_while(self): def lambda_fn(x_t): return x_t + 1, until(x_t > 3) - o, _ = scan(lambda_fn, x) + o = scan(lambda_fn, x, return_updates=False) f = function([x], o) vx = np.zeros((50,), dtype=config.floatX) vx[23] = 4 @@ -1016,7 +1046,7 @@ def test_while_infer_shape(self): def lambda_fn(x_t): return x_t + 1, until(x_t > 3) - o, _ = scan(lambda_fn, x) + o = scan(lambda_fn, x, return_updates=False) f = function([x], o.shape[0], mode=mode_with_opt) vx = np.zeros((50,), dtype=config.floatX) @@ -1026,11 +1056,12 @@ def lambda_fn(x_t): def test_infer_shape_nsteps_smaller_seq_length(self): x = vector("x") - [o1, o2], _ = scan( + [o1, o2] = scan( lambda x, y: (x + 1, y + x), sequences=x, outputs_info=[None, x[0]], n_steps=20, + return_updates=False, ) f = function([x], [o1.shape[0], o2.shape[0]], mode=mode_with_opt) @@ -1068,17 +1099,18 @@ def detect_large_outputs(fgraph, i, node, fn): mode = MonitorMode(post_func=detect_large_outputs) # Symbolic description of the result - result, updates = scan( + result = scan( fn=lambda prior_result, A: prior_result * A, outputs_info=pt.ones_like(A), non_sequences=A, n_steps=k, mode=mode, + return_updates=False, ) final_result = result[-1] - f = function(inputs=[A, k], outputs=final_result, updates=updates) + f = function(inputs=[A, k], outputs=final_result) f(np.asarray([2, 3, 0.1, 0, 1], dtype=config.floatX), 4) # There should be 3 outputs greater than 10: prior_result[0] at step 3, @@ -1100,10 +1132,11 @@ def test_inner_grad(self): y.name = "y" gy = grad(y, x) gy.name = "gy" - hy, _updates = scan( + hy = scan( lambda i, gy, x: grad(gy[i] * fc2, x), sequences=pt.arange(gy.shape[0]), non_sequences=[gy, x], + return_updates=False, ) f = function([x, A], hy, allow_input_downcast=True) @@ -1120,8 +1153,13 @@ def test_inner_grad(self): def test_sequence_is_scan(self, mode): """Make sure that a `Scan` can be used as a sequence input to another `Scan`.""" x0 = scalar("x0") - scan_1, _ = scan(lambda x: x + 1, outputs_info={"initial": x0}, n_steps=10) - scan_2, _ = scan(lambda x: x + 1, sequences=[scan_1]) + scan_1 = scan( + lambda x: x + 1, + outputs_info={"initial": x0}, + n_steps=10, + 
return_updates=False, + ) + scan_2 = scan(lambda x: x + 1, sequences=[scan_1], return_updates=False) with config.change_flags(mode=mode): scan_2_fn = function([x0], scan_2) @@ -1182,7 +1220,7 @@ def get_sum_of_grad(input0, input1): def test_blockwise_scan(self): x = pt.tensor("x", shape=()) - out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10) + out = scan(lambda x: x + 1, outputs_info=[x], n_steps=10, return_updates=False) x_vec = pt.tensor("x_vec", shape=(None,)) out_vec = vectorize_graph(out, {x: x_vec}) @@ -1200,13 +1238,14 @@ def fn(a_m2, a_m1, b_m2, b_m1): a0 = shared(np.arange(2)) b0 = shared(np.arange(2)) - (a, _b), _ = scan( + (a, _b) = scan( fn, outputs_info=[ {"initial": a0, "taps": [-2, -1]}, {"initial": b0, "taps": [-2, -1]}, ], n_steps=2, + return_updates=False, ) grad(a[-1], a0) @@ -1238,8 +1277,11 @@ def inner_fct(seq, state_old, state_current): state_next = state_old * 2 + state_current + seq return state_next - out, _ = scan( - inner_fct, sequences=seq, outputs_info={"initial": x, "taps": [-2, -1]} + out = scan( + inner_fct, + sequences=seq, + outputs_info={"initial": x, "taps": [-2, -1]}, + return_updates=False, ) g_out = grad(out.sum(), [seq, x]) @@ -1299,12 +1341,13 @@ def inner_fn(cond, x, y): new_y = pt.switch(cond, y, sigmoid(x)) return new_cond, new_x, new_y - values, _ = scan( + values = scan( inner_fn, outputs_info=[c, x, y], n_steps=10, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) gX, gY = grad(values[1].sum(), [x, y]) f = function([c, x, y], [gX, gY], allow_input_downcast=True) @@ -1759,11 +1802,12 @@ def inner_fct(inp1, inp2, inp3): outputs_info = [None, dict(initial=out_init, taps=[-3])] - scan_outputs, _ = scan( + scan_outputs = scan( fn=inner_fct, sequences=seq, outputs_info=outputs_info, non_sequences=non_seq, + return_updates=False, ) # Attempt to take various gradients @@ -1831,7 +1875,9 @@ def inner_fct(inp1, inp2, inp3, inp4, inp5, inp6): dict(initial=out_init[3], taps=[-2, -1]), ] - scan_outputs, _ = scan(fn=inner_fct, outputs_info=outputs_info, n_steps=10) + scan_outputs = scan( + fn=inner_fct, outputs_info=outputs_info, n_steps=10, return_updates=False + ) grad(scan_outputs[0].sum(), out_init[1]) @@ -1854,11 +1900,12 @@ def test_grad_multiple_seqs_different_nsteps(self): x = scalar("x") _max_coefficients_supported = 1000 full_range = pt.arange(_max_coefficients_supported) - components, _updates = scan( + components = scan( fn=lambda coeff, power, free_var: coeff * (free_var**power), outputs_info=None, sequences=[c, full_range], non_sequences=x, + return_updates=False, ) P = components.sum() dP = grad(P, x) @@ -1874,11 +1921,12 @@ def test_grad_of_grad_of_state(self): x = scalar("x") _max_coefficients_supported = 1000 full_range = pt.arange(_max_coefficients_supported) - components, _updates = scan( + components = scan( fn=lambda coeff, power, free_var: coeff * (free_var**power), outputs_info=None, sequences=[c, full_range], non_sequences=x, + return_updates=False, ) P = components.sum() dP = grad(P, x).sum() @@ -1965,8 +2013,13 @@ def rnn_fn(_u, _y, _W): _W = specify_shape(W, v_W.shape) _W.name = "_W" - o, _ = scan( - rnn_fn, sequences=_u, outputs_info=_h0, non_sequences=_W, name="rnn_fn" + o = scan( + rnn_fn, + sequences=_u, + outputs_info=_h0, + non_sequences=_W, + name="rnn_fn", + return_updates=False, ) o = o[-1] eu = matrix("eu") @@ -1980,25 +2033,28 @@ def rnn_fn(_u, _y, _W): [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore" ) - n2o_u, _ = scan( + n2o_u = scan( lambda i, o, u, h0, W, eu: 
(grad(o[i], u) * eu).sum(), sequences=pt.arange(o.shape[0]), non_sequences=[o, u, h0, W, eu], name="jacobU", + return_updates=False, ) - n2o_h0, _ = scan( + n2o_h0 = scan( lambda i, o, u, h0, W, eh0: (grad(o[i], h0) * eh0).sum(), sequences=pt.arange(o.shape[0]), non_sequences=[o, u, h0, W, eh0], name="jacobh", + return_updates=False, ) - n2o_W, _ = scan( + n2o_W = scan( lambda i, o, u, h0, W, eW: (grad(o[i], W) * eW).sum(), sequences=pt.arange(o.shape[0]), non_sequences=[o, u, h0, W, eW], name="jacobW", + return_updates=False, ) fn_test = function( @@ -2129,10 +2185,11 @@ def test_R_op_mitmot(self, use_op_rop_implementation): transfer = sigmoid - hidden_rec, _ = scan( + hidden_rec = scan( lambda x, h_tm1: transfer(dot(h_tm1, W2) + x), sequences=hidden, outputs_info=[pt.zeros_like(hidden[0])], + return_updates=False, ) hidden_rec.reshape( @@ -2165,12 +2222,13 @@ def test_second_derivative_disconnected_cost_with_mit_mot(self): def step(s, xtm2, xtm1, z): return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2) - xs, _ = scan( + xs = scan( step, sequences=[seq], outputs_info=[{"initial": x0, "taps": (-2, -1)}], non_sequences=[z], n_steps=2, + return_updates=False, ) last_x = xs[-1] @@ -2251,11 +2309,12 @@ def step(s, data, rng): raise ValueError(f"Invalid case: {case}") seq = vector("seq") - xs, _ = scan( + xs = scan( step, sequences=[seq], non_sequences=non_sequences, strict=strict, + return_updates=False, ) x0 = xs[0] @@ -2295,7 +2354,7 @@ def perform(self, node, inputs, outputs): def scan_fn(): return myop(pt.as_tensor(1)) - res, _ = scan(scan_fn, n_steps=4, mode=mode) + res = scan(scan_fn, n_steps=4, mode=mode, return_updates=False) res_fn = function([], res, mode=mode) @@ -2325,14 +2384,14 @@ def f_py(): py_res = f_py() s_r = pt.as_tensor_variable(r, dtype=config.floatX) - s_y, updates = scan( + s_y = scan( fn=lambda ri, rii, M: ri + M * rii, sequences=[s_r[1:]], non_sequences=[pt.as_tensor_variable(M, dtype=config.floatX)], outputs_info=s_r[0], mode=Mode(linker="cvm", optimizer="fast_run"), + return_updates=False, ) - assert not updates f_cvm = function([], s_y, mode="FAST_RUN") f_cvm.trust_input = True @@ -2354,9 +2413,7 @@ def test_compute_test_values(): y = shared(np.arange(3, dtype=config.floatX), name="y") - z, updates = scan(fn=lambda u, v: u + v, sequences=[x, y]) - - assert not updates + z = scan(fn=lambda u, v: u + v, sequences=[x, y], return_updates=False) z_grad = grad(z.sum(), x) @@ -2365,9 +2422,9 @@ def test_compute_test_values(): # Use `non_sequences` this time y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y") - z, updates = scan(fn=lambda u, v: u + v, sequences=[x], non_sequences=[y]) - - assert not updates + z = scan( + fn=lambda u, v: u + v, sequences=[x], non_sequences=[y], return_updates=False + ) z_grad = grad(z.sum(), x) @@ -2396,20 +2453,22 @@ def loss_mi(mi, sum_mi, W): def loss_ti(ti, sum_ti, mi, W): return W.sum().sum().sum() + sum_ti - result_ti, _ = scan( + result_ti = scan( fn=loss_ti, outputs_info=outputs_ti, sequences=pt.arange(W.shape[1], dtype="int32"), non_sequences=[mi, W], + return_updates=False, ) lossmi = result_ti[-1] return sum_mi + lossmi - result_mi, _ = scan( + result_mi = scan( fn=loss_mi, outputs_info=outputs_mi, sequences=pt.arange(W.shape[0], dtype="int32"), non_sequences=[W], + return_updates=False, ) loss = result_mi[-1] @@ -2433,11 +2492,12 @@ def test_compute_test_value_grad_cast(): name="w", ) - outputs, _ = scan( + outputs = scan( lambda i, h, w: (dot(h[i], w), i), outputs_info=[None, 0], non_sequences=[h, w], n_steps=3, + 
return_updates=False, ) grad(outputs[0].sum(), w) @@ -2446,11 +2506,12 @@ def test_compute_test_value_grad_cast(): def test_constant_folding_n_steps(): # The following code used to crash at revision 2060b8f, in the constant # folding optimization step. - res, _ = scan( + res = scan( lambda x: x * 2, outputs_info=pt.ones(()), # The constant `n_steps` was causing the crash. n_steps=10, + return_updates=False, ) with config.change_flags(on_opt_error="raise"): function([], res)() @@ -2475,10 +2536,11 @@ def f(x, y): def test_inconsistent_broadcast_error(): x = tensor3() initial_x = pt.constant(np.zeros((1, 10))) - y, _updates = scan( + y = scan( fn=lambda x, prev_x: x + prev_x, sequences=x, outputs_info=[dict(initial=initial_x)], + return_updates=False, ) # Error, because the broadcast patterns are inconsistent. with pytest.raises(TypeError): @@ -2506,10 +2568,11 @@ def setup_method(self): self.numpy_gradient = 2 * np.concatenate([self.seq[:7], z], axis=0) def test_grad_until(self): - r, _ = scan( + r = scan( lambda x, u: (x * x, until(x > u)), sequences=self.x, non_sequences=[self.threshold], + return_updates=False, ) g = grad(r.sum(), self.x) f = function([self.x, self.threshold], [r, g]) @@ -2525,10 +2588,11 @@ def tile_array(inp): X = matrix(name="x") arr = tile_array(self.seq) - r, _ = scan( + r = scan( lambda x, u: (x * x, until(pt_all(x > u))), sequences=X, non_sequences=[self.threshold], + return_updates=False, ) g = grad(r.sum(), X) f = function([X, self.threshold], [r, g]) @@ -2539,11 +2603,12 @@ def tile_array(inp): def test_grad_until_and_truncate(self): n = 3 - r, _ = scan( + r = scan( lambda x, u: (x * x, until(x > u)), sequences=self.x, non_sequences=[self.threshold], truncate_gradient=n, + return_updates=False, ) g = grad(r.sum(), self.x) f = function([self.x, self.threshold], [r, g]) @@ -2555,11 +2620,12 @@ def test_grad_until_and_truncate(self): def test_grad_until_and_truncate_sequence_taps(self): n = 3 - r, _ = scan( + r = scan( lambda x, y, u: (x * y, until(y > u)), sequences=dict(input=self.x, taps=[-2, 0]), non_sequences=[self.threshold], truncate_gradient=n, + return_updates=False, ) g = grad(r.sum(), self.x) f = function([self.x, self.threshold], [r, g]) @@ -2578,8 +2644,12 @@ def accum(seq_t, prev_sum): new_sum = prev_sum + seq_t return new_sum - rs, _updates = scan( - fn=accum, sequences={"input": seq, "taps": [2]}, outputs_info=0, n_steps=1 + rs = scan( + fn=accum, + sequences={"input": seq, "taps": [2]}, + outputs_info=0, + n_steps=1, + return_updates=False, ) f = function(inputs=[seq], outputs=rs) @@ -2664,7 +2734,12 @@ def scan_body(size): def test_profile_info(): from pytensor.scan.utils import ScanProfileStats - z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=True) + z = scan( + fn=lambda u: u + 1, + sequences=[pt.arange(10)], + profile=True, + return_updates=False, + ) assert isinstance(z.owner.op, Scan) fn = z.owner.op.fn @@ -2673,8 +2748,11 @@ def test_profile_info(): assert fn.profile.name == "scan_fn" # Set the `ScanProfileStats` name - z, _updates = scan( - fn=lambda u: u + 1, sequences=[pt.arange(10)], profile="profile_name" + z = scan( + fn=lambda u: u + 1, + sequences=[pt.arange(10)], + profile="profile_name", + return_updates=False, ) assert isinstance(z.owner.op, Scan) @@ -2685,7 +2763,12 @@ def test_profile_info(): # Use an existing profile object profile = fn.profile - z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=profile) + z = scan( + fn=lambda u: u + 1, + sequences=[pt.arange(10)], + 
profile=profile, + return_updates=False, + ) assert isinstance(z.owner.op, Scan) fn = z.owner.op.fn @@ -2816,7 +2899,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1): y_tm1 + dot(x_tm1, W_out), ] - outputs, updates = scan( + outputs = scan( f_rnn_cmpl, [u1, u2], [None, None, x0, dict(initial=y0, taps=[-1, -3])], @@ -2824,11 +2907,10 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f4 = function( - [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True - ) + f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True) # compute the values in numpy v_x = np.zeros((3, 2), dtype=config.floatX) @@ -2854,8 +2936,12 @@ def test_scan_as_tensor_on_gradients(self, benchmark): def scanStep(prev, seq, f1): return prev + f1 * seq - scanned, _ = scan( - fn=scanStep, sequences=[seq], outputs_info=[to_scan], non_sequences=[f1] + scanned = scan( + fn=scanStep, + sequences=[seq], + outputs_info=[to_scan], + non_sequences=[f1], + return_updates=False, ) function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True) @@ -2876,8 +2962,12 @@ def one_step(x_t, h_tm1, W): expr = dot(h_tm1, W) + x_t return expr - expr, _ = scan( - fn=one_step, sequences=[inpt], outputs_info=[initial], non_sequences=[W] + expr = scan( + fn=one_step, + sequences=[inpt], + outputs_info=[initial], + non_sequences=[W], + return_updates=False, ) v1 = shared(np.ones(5, dtype=config.floatX)) @@ -2914,11 +3004,12 @@ def test_use_scan_direct_output(self): x = scalar() seq = vector() outputs_info = [x, pt.zeros_like(x)] - (out1, out2), _updates = scan( + (out1, out2) = scan( lambda a, b, c: (a + b, b + c), sequences=seq, outputs_info=outputs_info, mode=mode, + return_updates=False, ) # Obtain a reference to the scan outputs before the subtensor and @@ -2953,8 +3044,11 @@ def test_use_scan_direct_output2(self): x = dcol() seq = dcol() outputs_info = [x, pt.zeros_like(x)] - (out1, out2), _updates = scan( - lambda a, b, c: (a + b, a + c), sequences=seq, outputs_info=outputs_info + (out1, out2) = scan( + lambda a, b, c: (a + b, a + c), + sequences=seq, + outputs_info=outputs_info, + return_updates=False, ) # Obtain a reference to the scan outputs before the subtensor and @@ -3093,7 +3187,9 @@ def onestep(x, x_tm4): seq = matrix() initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) outputs_info = [{"initial": initial_value, "taps": [-4]}, None] - results, _updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) + results = scan( + fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False + ) f = function([seq], results[1]) assert np.all(exp_out == f(inp)) @@ -3116,7 +3212,9 @@ def onestep(x, x_tm4): seq = matrix() initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) outputs_info = [{"initial": initial_value, "taps": [-4]}, None] - results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) + results = scan( + fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False + ) sharedvar = shared(np.zeros((1, 1), dtype=config.floatX)) updates = {sharedvar: results[0][-1:]} @@ -3161,7 +3259,7 @@ def inner_fn(tap_m3, tap_m2, tap_m1): init = matrix() outputs_info = [None, None, None, None, dict(initial=init, taps=[-3, -2, -1])] - out, _ = scan(inner_fn, outputs_info=outputs_info, n_steps=3) + out = scan(inner_fn, outputs_info=outputs_info, n_steps=3, return_updates=False) fct = function([init], out) # Compare obtained outputs 
with expected outputs @@ -3194,21 +3292,23 @@ def loss_outer(sum_outer, W): def loss_inner(sum_inner, W): return sum_inner + (W**2).sum() - result_inner, _ = scan( + result_inner = scan( fn=loss_inner, outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)), non_sequences=[W], n_steps=1, + return_updates=False, ) return sum_outer + result_inner[-1] # Also test return_list for that case. - result_outer, _ = scan( + result_outer = scan( fn=loss_outer, outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)), non_sequences=[W], n_steps=n_steps, return_list=True, + return_updates=False, ) cost = result_outer[0][-1] @@ -3227,7 +3327,9 @@ def inner_fn(x_tm1, y_tm1, z_tm1): x0 = vector("X") y0 = vector("y0") z0 = vector("Z") - [x, y, z], _ = scan(inner_fn, outputs_info=[x0, y0, z0], n_steps=10) + [x, y, z] = scan( + inner_fn, outputs_info=[x0, y0, z0], n_steps=10, return_updates=False + ) cost = (x + y + z).sum() grad(cost, x0) # defined @@ -3244,7 +3346,12 @@ def test_disconnected_gradient(self): m = matrix("m") u0 = pt.zeros((7,)) - [_u, m2], _ = scan(lambda _, u: [u, v], sequences=m, outputs_info=[u0, None]) + [_u, m2] = scan( + lambda _, u: [u, v], + sequences=m, + outputs_info=[u0, None], + return_updates=False, + ) # This used to raise an exception with older versions because for a # disconnected gradient a non disconnected type was returned grad((m * m2).sum(), v) @@ -3254,8 +3361,11 @@ def test_disconnected_gradient2(self): m = matrix("m") u0 = pt.zeros((7,)) - [_u, m2], _ = scan( - lambda x, u: [x + u, u + v], sequences=m, outputs_info=[u0, None] + [_u, m2] = scan( + lambda x, u: [x + u, u + v], + sequences=m, + outputs_info=[u0, None], + return_updates=False, ) # This used to raise an exception with older versions because # scan could not detect the connection between `m2` and `x` @@ -3275,7 +3385,7 @@ def step(seq): out2 = out1 + 1 return out1, out2 - [_out1, out2], _ = scan(step, sequences=v) + [_out1, out2] = scan(step, sequences=v, return_updates=False) gv = grad(out2.sum(), [v]) f = function([v], gv) @@ -3286,7 +3396,13 @@ def step(seq): def test_grad_bug_disconnected_input(self): W = shared(np.zeros((3, 3)), name="W") v = ivector(name="v") - y, _ = scan(lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=W) + y = scan( + lambda i, W: W[i], + sequences=v, + outputs_info=None, + non_sequences=W, + return_updates=False, + ) # This used to raise an exception f = function([v], grad(y.sum(), W)) @@ -3296,10 +3412,8 @@ def test_grad_find_input(self): w = shared(np.array(0, dtype="float32"), name="w") init = fscalar("init") - out, _ = scan( - fn=lambda prev: w, - outputs_info=init, - n_steps=2, + out = scan( + fn=lambda prev: w, outputs_info=init, n_steps=2, return_updates=False ) grad(out[-1], w) @@ -3323,7 +3437,7 @@ def test_using_taps_input_output(self): def f_rnn_shared(u_tm2, x_tm1, x_tm2): return u_tm2 * W_in + x_tm1 * W + x_tm2 - outputs, updates = scan( + outputs = scan( f_rnn_shared, dict(input=u, taps=-2), dict(initial=x0, taps=[-1, -2]), @@ -3331,9 +3445,10 @@ def f_rnn_shared(u_tm2, x_tm1, x_tm2): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f7 = function([u, x0], outputs, updates=updates, allow_input_downcast=True) + f7 = function([u, x0], outputs, allow_input_downcast=True) pytensor_out = f7(vu, vx0) # compute output in numpy @@ -3369,7 +3484,7 @@ def test_past_future_taps_shared(self): def f_rnn_shared(u_tm2, u_tp2, x_tm1, x_tm2): return (u_tm2 + u_tp2) * W_in + x_tm1 * W + x_tm2 - output, updates = 
scan( + output = scan( f_rnn_shared, dict(input=u, taps=[-2, 2]), dict(initial=x0, taps=[-1, -2]), @@ -3377,9 +3492,10 @@ def f_rnn_shared(u_tm2, u_tp2, x_tm1, x_tm2): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f8 = function([u, x0], output, updates=updates, allow_input_downcast=True) + f8 = function([u, x0], output, allow_input_downcast=True) pytensor_out = f8(vu, vx0) # compute output in numpy numpy_out = np.zeros(2) @@ -3401,7 +3517,7 @@ def f_pow2(x_tm1): state = scalar("state") n_steps = iscalar("nsteps") # Test return_list at the same time. - output, updates = scan( + output = scan( f_pow2, [], state, @@ -3410,10 +3526,9 @@ def f_pow2(x_tm1): truncate_gradient=-1, return_list=True, go_backwards=False, + return_updates=False, ) - my_f = function( - [state, n_steps], output, updates=updates, allow_input_downcast=True - ) + my_f = function([state, n_steps], output, allow_input_downcast=True) rng = np.random.default_rng(utt.fetch_seed()) state = rng.uniform() @@ -3443,10 +3558,11 @@ def _active(x, pre_h): pre_h = dot(x, W_x) return pre_h - value, _scan_updates = scan( + value = scan( _active, sequences=X, outputs_info=[pt.alloc(floatx(0.0), 1, out_size)], + return_updates=False, ) cost = mean(value) gW_x = grad(cost, W_x) @@ -3464,7 +3580,7 @@ def accum(prev_value, step): condition = until(new_value > max_value) return [new_value, new_step], condition - rs, _updates = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps) + rs = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps, return_updates=False) f = function(inputs=[max_value, n_steps], outputs=rs) @@ -3484,33 +3600,37 @@ def test_outputs_info_not_typed(self): # Generate the components of the polynomial full_range = pt.arange(max_coefficients_supported) - components, _updates = scan( + components = scan( fn=lambda coeff, power, free_var: coeff * (free_var**power), sequences=[coefficients, full_range], non_sequences=x, + return_updates=False, ) polynomial1 = components.sum() - polynomial2, _updates = scan( + polynomial2 = scan( fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), outputs_info=pt.constant(0, dtype="floatX"), sequences=[coefficients, full_range], non_sequences=x, + return_updates=False, ) # python int - polynomial3, _updates = scan( + polynomial3 = scan( fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), outputs_info=0, sequences=[coefficients, full_range], non_sequences=x, + return_updates=False, ) # python float - polynomial4, _updates = scan( + polynomial4 = scan( fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), outputs_info=0.0, sequences=[coefficients, full_range], non_sequences=x, + return_updates=False, ) calculate_polynomial = function( @@ -3573,8 +3693,12 @@ def one_step(v, W): # o = v + 1 # <-- this line works return o - OS, _updates = scan( - fn=one_step, sequences=V, outputs_info=[None], non_sequences=[W] + OS = scan( + fn=one_step, + sequences=V, + outputs_info=[None], + non_sequences=[W], + return_updates=False, ) O = OS.sum() + W.sum() @@ -3588,11 +3712,12 @@ def one_step(v, W): ) def test_infershape_seq_shorter_nsteps(self): x = vector("x") - [o1, o2], _ = scan( + [o1, o2] = scan( lambda x, y: (x + 1, y + x), sequences=x, outputs_info=[None, x[0]], n_steps=20, + return_updates=False, ) f = function([x], [o1, o2], mode=mode_with_opt) @@ -3650,67 +3775,6 @@ def lm(m): if config.mode != "FAST_COMPILE": assert nb_shape_i == 1 - def test_return_steps(self): - rng = 
np.random.default_rng(utt.fetch_seed()) - - vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,))) - vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2))) - vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,))) - vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2))) - v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2))) - v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,))) - v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,))) - v_y0 = asarrayX(rng.uniform(size=(3,))) - - W_in2 = shared(vW_in2, name="win2") - W = shared(vW, name="w") - W_out = shared(vWout, name="wout") - W_in1 = matrix("win") - u1 = matrix("u1") - u2 = vector("u2") - x0 = vector("x0") - y0 = vector("y0") - - def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1): - return [ - y_tm3 + 1, - dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W), - y_tm1 + dot(x_tm1, W_out), - ] - - rval, updates = scan( - f_rnn_cmpl, - [u1, u2], - [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])], - W_in1, - n_steps=None, - truncate_gradient=-1, - go_backwards=False, - ) - - outputs = [] - outputs += [rval[0][-3:]] - outputs += [rval[1][-2:]] - outputs += [rval[2][-4:]] - f4 = function( - [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True - ) - - # compute the values in numpy - v_x = np.zeros((8, 2), dtype=config.floatX) - v_y = np.zeros((8,), dtype=config.floatX) - v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW) - v_y[0] = np.dot(v_x0, vWout) + v_y0[2] - - for i in range(1, 8): - v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW) - v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1] - - (_pytensor_dump, pytensor_x, pytensor_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1) - - utt.assert_allclose(pytensor_x, v_x[-2:]) - utt.assert_allclose(pytensor_y, v_y[-4:]) - def test_until_random_infer_shape(self): """ Test for a crash in scan.infer_shape when using both @@ -3725,10 +3789,14 @@ def inner_fct(previous_val): condition = until(previous_val > 5) return new_val, condition - out, _updates = scan(inner_fct, outputs_info=x, n_steps=10) + out, updates = scan(inner_fct, outputs_info=x, n_steps=10) g_out = grad(out.sum(), x) - fct = function([x], [out, g_out]) + fct = function( + [x], + [out, g_out], + updates=updates, + ) for i in range(-5, 5): output, g_output = fct(i) @@ -3760,7 +3828,7 @@ def step(seq1, sitsot_m1, mitsot_m2, mitsot_m1): ) return next_sitsot_val, next_mitsot_val, nitsot_out - out, _updates = scan( + out = scan( fn=step, sequences=seq, outputs_info=[ @@ -3769,6 +3837,7 @@ def step(seq1, sitsot_m1, mitsot_m2, mitsot_m1): None, ], n_steps=5, + return_updates=False, ) f = function([seq, sitsot_init, mitsot_init], out[2].shape) @@ -3804,7 +3873,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1): dot(x_tm1, W_out), ] - outputs, updates = scan( + outputs = scan( f_rnn_cmpl, [u1, u2], [x0, y0], @@ -3812,11 +3881,10 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f4 = function( - [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True - ) + f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True) # compute the values in numpy v_x = np.zeros((3, 2), dtype=config.floatX) @@ -3860,7 +3928,7 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1): dot(u1_t, W_in1), ] - outputs, updates = scan( + outputs = scan( f_rnn_cmpl, [u1, dict(input=u2, taps=[-1, 0, 1])], [x0, dict(initial=y0, taps=[-1, -3]), None], @@ -3868,11 +3936,10 @@ def 
f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - f = function( - [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True - ) + f = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True) ny0 = np.zeros((5, 2)) ny1 = np.zeros((5,)) @@ -3962,13 +4029,14 @@ def one_step(x_t, h_tm2, h_tm1, W_ih, W_hh, b_h, W_ho, b_o): return [h_t, y_t] # hidden and outputs of the entire sequence - [_h, y], _ = scan( + [_h, y] = scan( fn=one_step, sequences=dict(input=x), # corresponds to the return type of one_step outputs_info=[dict(initial=h0, taps=[-2, -1]), None], non_sequences=[W_ih, W_hh, b_h, W_ho, b_o], mode=mode, + return_updates=False, ) # target values @@ -4068,7 +4136,7 @@ def test_grad_multiple_outs_some_disconnected_2(self): [{}], [], 3, - lambda op: op.info.n_shared_outs > 0, + lambda op: op.info.n_untraced_sit_sot_outs > 0, ), # mit-sot (that's also a type of sit-sot) ( @@ -4142,7 +4210,7 @@ def fn(n): outer-output arrays are initialized using the outer-input arrays, the shape difference needs to be handled correctly. """ - s_in_y, _ = scan( + s_in_y = scan( fn=lambda z: (z + 1, until(z > 2)), outputs_info=[ {"taps": [-1], "initial": pt.as_tensor(0.0, dtype=np.float64)} @@ -4150,16 +4218,18 @@ def fn(n): mode=mode, n_steps=n - 1, allow_gc=False, + return_updates=False, ) return s_in_y.sum() - s_y, _updates = scan( + s_y = scan( fn=fn, outputs_info=[None], sequences=[pt.as_tensor([3, 2, 1], dtype=np.int64)], mode=mode, allow_gc=False, + return_updates=False, ) f_cvm = function([], s_y, mode=mode) @@ -4167,3 +4237,74 @@ def fn(n): res = f_cvm() assert np.array_equal(res, np.array([3, 1, 0])) + + +def test_rng_outputs_info(): + rng_init = random_generator_type("rng") + rng_x0, x0 = pt.random.normal(0, rng=rng_init, dtype="float64").owner.outputs + + def step(prev_x, prev_rng): + next_rng, next_x = pt.random.normal( + prev_x, rng=prev_rng, dtype="float64" + ).owner.outputs + return next_x, next_rng + + [xs, rng_final] = scan( + fn=step, + outputs_info=[x0, rng_x0], + n_steps=10, + return_updates=False, + ) + assert isinstance(xs.type, TensorType) + assert isinstance(rng_final.type, RandomGeneratorType) + + fn = function([rng_init], [xs, rng_final]) + xs_eval, rng_final_eval = fn(np.random.default_rng(0)) + + rng_ref = np.random.default_rng(0) + assert not random_generator_type.values_eq(rng_ref, rng_final_eval) + xs_ref = [rng_ref.normal(0)] + for i in range(10): + xs_ref.append(rng_ref.normal(xs_ref[-1])) + assert random_generator_type.values_eq(rng_ref, rng_final_eval) + np.testing.assert_allclose(xs_eval, xs_ref[1:]) + + +@pytest.mark.filterwarnings("error") +def test_return_updates_api_change(): + err_msg = "return_updates=False but Scan produced updates" + warn_msg = "Scan return signature will change. 
Updates dict will not be returned" + + x = shared(np.array(0, dtype="float64")) + + with pytest.warns(DeprecationWarning, match=warn_msg): + traced1, updates1 = scan( + lambda: {x: x + 1}, + outputs_info=[], + n_steps=5, + ) + assert traced1 is None + assert len(updates1) == 1 and x in updates1 + + with pytest.warns(DeprecationWarning, match=warn_msg): + traced2, updates2 = scan( + lambda x: x + 1, + outputs_info=[x], + n_steps=5, + ) + assert isinstance(traced2, Variable) + assert isinstance(updates2, dict) and not updates2 + + traced3 = scan( + lambda x: x + 1, + outputs_info=[x], + n_steps=5, + return_updates=False, + ) + assert isinstance(traced3, Variable) + + assert_equal_computations(list(updates1.values()), [traced2[-1]]) + assert_equal_computations([traced2], [traced3]) + + with pytest.raises(ValueError, match=err_msg): + scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False) diff --git a/tests/scan/test_checkpoints.py b/tests/scan/test_checkpoints.py index b30c1582fe..345dc2c0e2 100644 --- a/tests/scan/test_checkpoints.py +++ b/tests/scan/test_checkpoints.py @@ -9,44 +9,53 @@ from pytensor.tensor.type import iscalar, vector +@pytest.mark.parametrize("return_updates", [True, False]) class TestScanCheckpoint: - def setup_method(self): + def setup_method(self, return_updates): self.k = iscalar("k") self.A = vector("A") seq = arange(self.k, dtype="float32") + 1 - result, _ = scan( + result_raw = scan( fn=lambda s, prior_result, A: prior_result * A / s, outputs_info=ones_like(self.A), sequences=[seq], non_sequences=self.A, n_steps=self.k, + return_updates=return_updates, ) - result_check, _ = scan_checkpoints( + result_check_raw = scan_checkpoints( fn=lambda s, prior_result, A: prior_result * A / s, outputs_info=ones_like(self.A), sequences=[seq], non_sequences=self.A, n_steps=self.k, save_every_N=100, + return_updates=return_updates, ) + if return_updates: + result, _ = result_raw + result_check, _ = result_check_raw + else: + result = result_raw + result_check = result_check_raw self.result = result[-1] self.result_check = result_check[-1] self.grad_A = grad(self.result.sum(), self.A) self.grad_A_check = grad(self.result_check.sum(), self.A) - def test_forward_pass(self): + def test_forward_pass(self, return_updates): # Test forward computation of A**k. f = function(inputs=[self.A, self.k], outputs=[self.result, self.result_check]) out, out_check = f(range(10), 101) assert np.allclose(out, out_check) - def test_backward_pass(self): + def test_backward_pass(self, return_updates): # Test gradient computation of A**k. f = function(inputs=[self.A, self.k], outputs=[self.grad_A, self.grad_A_check]) out, out_check = f(range(10), 101) assert np.allclose(out, out_check) - def test_taps_error(self): + def test_taps_error(self, return_updates): # Test that an error rises if we use taps in outputs_info. 
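For readers following the test migrations above, here is a minimal, hypothetical sketch (not part of the patch) of how caller code moves from the old `(outputs, updates)` return convention to the new `return_updates=False` form. It assumes the patched `scan` from this diff and a loop that produces no updates, mirroring what `test_return_updates_api_change` exercises.

```python
# Illustrative migration sketch only; assumes the patched PyTensor `scan`.
import numpy as np
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")

# Old convention (still the default): a (outputs, updates) pair is returned,
# now accompanied by a DeprecationWarning about the upcoming signature change.
xs_old, updates = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10)
f_old = pytensor.function([x0], xs_old, updates=updates)

# New convention: only the traced outputs are returned.
xs_new = pytensor.scan(
    lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10, return_updates=False
)
f_new = pytensor.function([x0], xs_new)

np.testing.assert_allclose(f_old(0.0), f_new(0.0))
```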
with pytest.raises(RuntimeError): scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]}) diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index ef3aebd971..ba9de8809a 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -47,38 +47,47 @@ def test_remove_constants_and_unused_inputs_scan_non_seqs(self): """Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences.""" W = matrix(name="W") v = ivector(name="v") - y1, _ = scan( - lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W] + y1 = scan( + lambda i, W: W[i], + sequences=v, + outputs_info=None, + non_sequences=[W], + return_updates=False, ) - y2, _ = scan( + y2 = scan( lambda i, _, W: W[i], sequences=v, outputs_info=None, non_sequences=[W[0], W], + return_updates=False, ) - y3, _ = scan( + y3 = scan( lambda i, W, _: W[i], sequences=v, outputs_info=None, non_sequences=[W, W[0]], + return_updates=False, ) - y4, _ = scan( + y4 = scan( lambda i, _, _2, W: W[i], sequences=v, outputs_info=None, non_sequences=[W[0], W[0], W], + return_updates=False, ) - y5, _ = scan( + y5 = scan( lambda i, _, W, _2: W[i], sequences=v, outputs_info=None, non_sequences=[W[0], W, W[0]], + return_updates=False, ) - y6, _ = scan( + y6 = scan( lambda i, W, _, _2: W[i], sequences=v, outputs_info=None, non_sequences=[W, W[0], W[0]], + return_updates=False, ) # TODO: y7 have problem during run time. I think it should # raise an error during the scan construction. @@ -112,47 +121,61 @@ def test_remove_constants_and_unused_inputs_scan_seqs(self): W = matrix(name="W") v = ivector(name="v") vv = matrix(name="vv") - y1, _ = scan( - lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W] + y1 = scan( + lambda i, W: W[i], + sequences=v, + outputs_info=None, + non_sequences=[W], + return_updates=False, ) - y2, _ = scan( - lambda i, _, W: W[i], sequences=[v, v], outputs_info=None, non_sequences=W + y2 = scan( + lambda i, _, W: W[i], + sequences=[v, v], + outputs_info=None, + non_sequences=W, + return_updates=False, ) - y3, _ = scan( + y3 = scan( lambda i, _, W: W[i], sequences=[v, vv[0]], outputs_info=None, non_sequences=W, + return_updates=False, ) - y4, _ = scan( + y4 = scan( lambda _, i, W: W[i], sequences=[vv[0], v], outputs_info=None, non_sequences=W, + return_updates=False, ) - y5, _ = scan( + y5 = scan( lambda _, i, _2, W: W[i], sequences=[vv, v, vv[0]], outputs_info=None, non_sequences=W, + return_updates=False, ) - y6, _ = scan( + y6 = scan( lambda _, _2, i, W: W[i], sequences=[vv[0], vv, v], outputs_info=None, non_sequences=W, + return_updates=False, ) - y7, _ = scan( + y7 = scan( lambda i, _, _2, W: W[i], sequences=[v, vv[0], vv[0]], outputs_info=None, non_sequences=W, + return_updates=False, ) - y8, _ = scan( + y8 = scan( lambda _, i, W, _2, _3: W[i], sequences=[vv[0], v], outputs_info=None, non_sequences=[W, W[0], W[0]], + return_updates=False, ) W_val = np.random.normal(size=(3, 3)).astype(config.floatX) @@ -195,7 +218,7 @@ def test_pushout_all(self): def lambda_fn(h, W1, W2): return dot(h, W1 + W2) - o, _ = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5) + o = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5, return_updates=False) f = function([h0, W1, W2], o, mode=self.mode) @@ -232,19 +255,24 @@ def lambda_fn(step_idx, W1, W2): return dot(W1, W2), until_condition # Compile a function with the optimization - o, _ = scan( - lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5 + o = scan( + lambda_fn, + 
sequences=[step_indices, W1], + non_sequences=[W2], + n_steps=5, + return_updates=False, ) f = function([W1, W2, step_indices], o, mode=self.mode) # Compule an pytensor function without the optimization - o, _ = scan( + o = scan( lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5, mode="FAST_COMPILE", + return_updates=False, ) f_ref = function([W1, W2, step_indices], o, mode=self.mode) @@ -268,7 +296,13 @@ def test_pushout(self): def lambda_fn(h, W1, W2): return dot(h, W1 + W2) - o, _ = scan(lambda_fn, outputs_info=h0, non_sequences=[W1, W2], n_steps=5) + o = scan( + lambda_fn, + outputs_info=h0, + non_sequences=[W1, W2], + n_steps=5, + return_updates=False, + ) f = function([h0, W1, W2], o, mode=self.mode) @@ -290,10 +324,11 @@ def test_pushout_nomodif(self): def fn(i, i_tm1): return i + 10, i_tm1 - ([i_t, i_tm1], _) = scan( + [i_t, i_tm1] = scan( fn, sequences=[inp], outputs_info=[np.asarray([0.0, 0.0], config.floatX), None], + return_updates=False, ) f = function([inp], [i_t, i_tm1]) val = np.arange(10).reshape(5, 2).astype(config.floatX) @@ -397,17 +432,18 @@ def predict_mean_i(i, x_star, s_star, X, beta, h): @config.change_flags(on_opt_error="raise") def test_pushout_seqs2(self): x = matrix() - outputs, updates = scan( + outputs = scan( lambda x: [x * x, pt.constant(0).copy().copy()], n_steps=2, sequences=[], non_sequences=[], outputs_info=[x, None], + return_updates=False, ) # Compile an PyTensor function where any optimization error will lead to # an exception being raised - function([x], outputs, updates=updates) + function([x], outputs) @config.change_flags(on_opt_error="raise") def test_pushout_nonseq(self): @@ -418,7 +454,9 @@ def test_pushout_nonseq(self): outputs. This led the optimization to raise an exception. """ - outputs, _ = scan(lambda x: (x * x, x), non_sequences=[2], n_steps=2) + outputs = scan( + lambda x: (x * x, x), non_sequences=[2], n_steps=2, return_updates=False + ) f = function(inputs=[], outputs=outputs) outs = f() @@ -583,10 +621,12 @@ def test_nested_OpFromGraph_shared(self): test_ofg = OpFromGraph([], [y]) def inner_func(x): - out, _ = pytensor.scan(lambda: test_ofg(), n_steps=x) + out = pytensor.scan(lambda: test_ofg(), n_steps=x, return_updates=False) return out - out, _ = pytensor.scan(inner_func, sequences=[pt.arange(1, 2)]) + out = pytensor.scan( + inner_func, sequences=[pt.arange(1, 2)], return_updates=False + ) _ = pytensor.function([], test_ofg()) @@ -612,10 +652,11 @@ class TestPushOutAddScan: def test_sum_dot(self): A = matrix("A") B = matrix("B") - S, _ = scan( + S = scan( lambda x1, x2, u: u + dot(x1, x2), sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)], outputs_info=[pt.zeros_like(A)], + return_updates=False, ) # FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that. # They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]` @@ -636,13 +677,17 @@ def test_pregreedy_optimizer(self, benchmark): bv = pt.zeros((5,)) bh = pt.zeros((4,)) v = matrix("v") - (bv_t, bh_t), _ = scan( - lambda _: [bv, bh], sequences=v, outputs_info=[None, None] + (bv_t, bh_t) = scan( + lambda _: [bv, bh], + sequences=v, + outputs_info=[None, None], + return_updates=False, ) - chain, _ = scan( + chain = scan( lambda x: dot(dot(x, W) + bh_t, W.T) + bv_t, outputs_info=v, n_steps=2, + return_updates=False, ) # TODO FIXME: Make this a real test and assert something. 
chain_fn = function([v], chain) @@ -710,26 +755,28 @@ def rnn_step1( # Compile the function twice, once with the optimization and once # without opt_mode = mode.including("scan") - h, _ = pytensor.scan( + h = pytensor.scan( rnn_step1, sequences=[x, ri, zi], n_steps=seq_len, outputs_info=init, name="fpass1", mode=opt_mode, + return_updates=False, ) cost = h[-1].sum() grad1 = grad(cost, [U, V, W]) f_opt = pytensor.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode) no_opt_mode = mode.excluding("scan_pushout_add") - h, _ = pytensor.scan( + h = pytensor.scan( rnn_step1, sequences=[x, ri, zi], n_steps=seq_len, outputs_info=init, name="fpass1", mode=no_opt_mode, + return_updates=False, ) cost = h[-1].sum() grad1 = grad(cost, [U, V, W]) @@ -773,21 +820,23 @@ def inner_fct(seq1, seq2, seq3, previous_output): # Compile the function twice, once with the optimization and once without opt_mode = mode.including("scan") - h, _ = pytensor.scan( + h = pytensor.scan( inner_fct, sequences=[input1, input2, input3], outputs_info=init, mode=opt_mode, + return_updates=False, ) output = h[-1] f_opt = pytensor.function([input1, input2, input3], output, mode=opt_mode) no_opt_mode = mode.excluding("scan_pushout_add") - h, _ = pytensor.scan( + h = pytensor.scan( inner_fct, sequences=[input1, input2, input3], outputs_info=init, mode=no_opt_mode, + return_updates=False, ) output = h[-1] f_no_opt = pytensor.function([input1, input2, input3], output, mode=no_opt_mode) @@ -892,13 +941,20 @@ def test_belongs_to_set(self): """ inps = vector() state = scalar() - y1, _ = scan(lambda x, y: x * y, sequences=inps, outputs_info=state, n_steps=5) + y1 = scan( + lambda x, y: x * y, + sequences=inps, + outputs_info=state, + n_steps=5, + return_updates=False, + ) - y2, _ = scan( + y2 = scan( lambda x, y: (x + y, until(x > 0)), sequences=inps, outputs_info=state, n_steps=5, + return_updates=False, ) scan_node1 = y1.owner.inputs[0].owner assert isinstance(scan_node1.op, Scan) @@ -958,8 +1014,8 @@ def add(s1, s2, const): def sub(s1, s2, const): return s1 - 1, until(s2 > const) - sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) - sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1]) + sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False) + sy = scan(sub, sequences=[y, -z], non_sequences=[c1], return_updates=False) f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) assert self.count_scans(f) == 2 @@ -972,8 +1028,8 @@ def sub(s1, s2, const): np.testing.assert_array_equal(res_sx, [1, 1]) np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1]) - sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) - sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2]) + sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False) + sy = scan(sub, sequences=[y, z], non_sequences=[c2], return_updates=False) f = pytensor.function( inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode @@ -989,22 +1045,23 @@ def sub(s1, s2, const): np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1]) np.testing.assert_array_equal(res_sy, [-1, -1, -1]) - sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) - sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1]) + sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False) + sy = scan(sub, sequences=[y, z], non_sequences=[c1], return_updates=False) f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) assert self.count_scans(f) == 1 def nested_scan(c, x, z): - sx, _ = scan(add, sequences=[x, 
z], non_sequences=[c]) - sy, _ = scan(sub, sequences=[x, z], non_sequences=[c]) + sx = scan(add, sequences=[x, z], non_sequences=[c], return_updates=False) + sy = scan(sub, sequences=[x, z], non_sequences=[c], return_updates=False) return sx.sum() + sy.sum() - sz, _ = scan( + sz = scan( nested_scan, sequences=[stack([c1, c2])], non_sequences=[x, z], mode=self.mode, + return_updates=False, ) f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode) @@ -1023,9 +1080,8 @@ def test_no_inplace(self): x = pt.vector("x") - scan_out, _ = pytensor.scan( - lambda x: (x + 1) / 2 + 1, - sequences=[x], + scan_out = pytensor.scan( + lambda x: (x + 1) / 2 + 1, sequences=[x], return_updates=False ) fgraph = FunctionGraph( @@ -1039,10 +1095,8 @@ def test_no_inplace(self): assert equal_computations([scan_out], fgraph.outputs) def test_inplace_basic(self): - scan_out, _ = pytensor.scan( - lambda x: x + 1, - outputs_info=[pt.zeros(1)], - n_steps=3, + scan_out = pytensor.scan( + lambda x: x + 1, outputs_info=[pt.zeros(1)], n_steps=3, return_updates=False ) fgraph = FunctionGraph( @@ -1089,7 +1143,7 @@ def f_rnn_shared(u0_t, u1_t, u2_t, x0_tm1, x1_tm1): u0_t * W_in + x1_tm1 * W + u1_t + u2_t, ] - outputs, updates = scan( + outputs = scan( f_rnn_shared, [u0, u1, u2], [dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)], @@ -1098,12 +1152,12 @@ def f_rnn_shared(u0_t, u1_t, u2_t, x0_tm1, x1_tm1): truncate_gradient=-1, go_backwards=False, mode=self.mode, + return_updates=False, ) f9 = function( [mu0, mu1, mu2, x0, x1], outputs, - updates=updates, mode=self.mode, allow_input_downcast=True, ) @@ -1155,7 +1209,7 @@ def f_rnn_shared(u0_t, u1_t, u1_tp1, u2_tm1, u2_t, u2_tp1, x0_tm1, x1_tm1): u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1, ] - outputs, updates = scan( + outputs = scan( f_rnn_shared, [u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])], [dict(initial=x0), dict(initial=x1)], @@ -1164,11 +1218,11 @@ def f_rnn_shared(u0_t, u1_t, u1_tp1, u2_tm1, u2_t, u2_tp1, x0_tm1, x1_tm1): truncate_gradient=-1, go_backwards=False, mode=self.mode, + return_updates=False, ) f9 = function( [mu0, mu1, mu2, x0, x1], outputs, - updates=updates, mode=self.mode, allow_input_downcast=True, ) @@ -1202,8 +1256,12 @@ def test_inplace3(self): vx1 = asarrayX(rng.uniform()) x0 = shared(vx0) x1 = shared(vx1) - outputs, updates = scan( - lambda x, y: (x + asarrayX(1), y + asarrayX(1)), [], [x0, x1], n_steps=3 + outputs = scan( + lambda x, y: (x + asarrayX(1), y + asarrayX(1)), + [], + [x0, x1], + n_steps=3, + return_updates=False, ) x0 = asarrayX(np.zeros((4,))) x0[0] = vx0 @@ -1212,7 +1270,7 @@ def test_inplace3(self): to_replace = outputs[0].owner.inputs[0].owner.inputs[1] outputs = clone_replace(outputs, replace=[(to_replace, x0)]) - f9 = function([], outputs, updates=updates, mode=self.mode) + f9 = function([], outputs, mode=self.mode) scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] assert 0 not in scan_node[0].op.destroy_map assert 1 in scan_node[0].op.destroy_map @@ -1249,7 +1307,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1): y_tm1 + dot(x_tm1, W_out), ] - _outputs, updates = scan( + outs = scan( f_rnn_cmpl, [u1, u2], [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])], @@ -1257,12 +1315,12 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) - outputs = [_outputs[0][-1], _outputs[1][-1], _outputs[2][-1]] + outputs = [outs[0][-1], outs[1][-1], outs[2][-1]] f4 = 
function( [u1, u2, x0, y0, W_in1], outputs, - updates=updates, allow_input_downcast=True, mode=self.mode, ) @@ -1297,14 +1355,18 @@ def f_rnn(u_t): u = vector("u") idx = iscalar("idx") jdx = iscalar("jdx") - [x1, x2, x3, x4, x5, x6, x7], updates = scan( - f_rnn, u, n_steps=None, truncate_gradient=-1, go_backwards=False + [x1, x2, x3, x4, x5, x6, x7] = scan( + f_rnn, + u, + n_steps=None, + truncate_gradient=-1, + go_backwards=False, + return_updates=False, ) f2 = function( [u, idx, jdx], [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]], - updates=updates, allow_input_downcast=True, mode=self.mode.excluding("scan_push_out_seq"), ) @@ -1341,10 +1403,8 @@ def f_rnn(u_t): def test_save_mem_reduced_number_of_steps_constant(self): x0 = pt.scalar("x0") - xs, _ = scan( - lambda xtm1: xtm1 + 1, - outputs_info=[x0], - n_steps=10, + xs = scan( + lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10, return_updates=False ) fn = function([x0], xs[:5], mode=self.mode) @@ -1358,10 +1418,11 @@ def test_save_mem_reduced_number_of_steps_constant(self): def test_save_mem_cannot_reduce_constant_number_of_steps(self): x0 = pt.scalar("x0") - [xs, ys], _ = scan( + [xs, ys] = scan( lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1), outputs_info=[x0, x0], n_steps=10, + return_updates=False, ) # Because of ys[-1] we need all the steps! @@ -1399,7 +1460,7 @@ def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): x20 = scalar("x20") x30 = vector("x30") x40 = scalar("x40") - [x1, x2, x3, x4, x5, _x6, _x7], updates = scan( + [x1, x2, x3, x4, x5, _x6, _x7] = scan( step, u, [ @@ -1414,12 +1475,12 @@ def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): n_steps=None, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) f = function( [u, x10, x20, x30, x40], [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]], - updates=updates, allow_input_downcast=True, mode=self.mode, ) @@ -1479,10 +1540,11 @@ def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): def test_savemem_does_not_duplicate_number_of_scan_nodes(self): var = pt.ones(()) - values, _ = scan( + values = scan( lambda x: ([x], (), until(x)), outputs_info=[var], n_steps=2, + return_updates=False, ) tmp_fn = function([var], values, mode=self.mode) @@ -1493,10 +1555,11 @@ def test_savemem_does_not_duplicate_number_of_scan_nodes(self): def test_savemem_opt(self, benchmark): y0 = shared(np.ones((2, 10))) - [_y1, y2], _updates = scan( + [_y1, y2] = scan( lambda y: [y, y], outputs_info=[dict(initial=y0, taps=[-2]), None], n_steps=5, + return_updates=False, ) # TODO FIXME: Make this a real test and assert something. 
fn = function([], y2.sum(), mode=self.mode) @@ -1515,23 +1578,25 @@ def inner_scan_step(x_t_t, h_tm1, w): return dot(h_tm1, w) + x_t_t def outer_scan_step(x_t, w): - h, _ = scan( + h = scan( inner_scan_step, sequences=[x_t[1:]], outputs_info=[x_t[0]], non_sequences=[w], strict=True, name="the_inner_scan", + return_updates=False, ) return h def get_outputs(x, w): - features, _ = scan( + features = scan( outer_scan_step, sequences=[x], non_sequences=[w], strict=True, name="the_outer_scan", + return_updates=False, ) return_val = grad(features.sum(), w) @@ -1571,7 +1636,7 @@ def f_pow2(x_tm1): state = vector("state") n_steps = iscalar("nsteps") - output, updates = scan( + output = scan( f_pow2, [], state, @@ -1579,13 +1644,13 @@ def f_pow2(x_tm1): n_steps=n_steps, truncate_gradient=-1, go_backwards=False, + return_updates=False, ) nw_shape = ivector("nw_shape") # Note that the output is reshaped to 3 dimensional tensor, and my_f = function( [state, n_steps, nw_shape], [reshape(output, nw_shape, ndim=3)[:-2], output[:-4]], - updates=updates, allow_input_downcast=True, ) nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)] @@ -1599,11 +1664,12 @@ def test_while_scan_taps(self): n_steps = scalar("n_steps", dtype="int64") x0 = vector("x0") - ys, _ = pytensor.scan( + ys = pytensor.scan( # Fibonacci Sequence lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)), outputs_info=[{"initial": x0, "taps": [-2, -1]}], n_steps=n_steps, + return_updates=False, ) # Save memory is triggered by choosing only last value y = ys[-1] @@ -1629,10 +1695,11 @@ def test_while_scan_taps(self): def test_while_scan_map(self): xs = vector("xs") - ys, _ = pytensor.scan( + ys = pytensor.scan( lambda x: (x + 1, {}, until(x + 1 >= 10)), outputs_info=[None], sequences=[xs], + return_updates=False, ) # Save memory is triggered by choosing only last value y = ys[-1] @@ -1656,11 +1723,12 @@ def test_while_scan_taps_and_map(self): n_steps = scalar("n_steps", dtype="int64") # while loop - [ys, zs], _ = pytensor.scan( + [ys, zs] = pytensor.scan( lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)), sequences=[seq], outputs_info=[x0, None], n_steps=n_steps, + return_updates=False, ) # Save memory is triggered by choosing only last value y = ys[-1] @@ -1696,10 +1764,11 @@ def test_broadcasted_init(self, keep_beginning, val_ndim): val_test = np.zeros(val_shape, dtype=val.dtype) init = pt.full((2,), val) - ys, _ = pytensor.scan( + ys = pytensor.scan( fn=lambda *args: pt.add(*args), outputs_info=[{"initial": init, "taps": (-2, -1)}], n_steps=100, + return_updates=False, ) out = ys[:-50] if keep_beginning else ys[-50:] @@ -1729,12 +1798,13 @@ def test_inner_replace_dot(): mode = get_default_mode().including("scan") # .excluding("BlasOpt") - o, _ = scan( + o = scan( lambda hi, him1, W: (hi, dot(hi + him1, W)), outputs_info=[pt.zeros([h.shape[1]]), None], sequences=[h], non_sequences=[W], mode=mode, + return_updates=False, ) f = function([W, h], o, mode=mode) @@ -1753,11 +1823,12 @@ def test_alloc_inputs1(): def lambda_fn(h, W1, W2): return dot(h, W1 * W2) - o, _ = scan( + o = scan( lambda_fn, outputs_info=h0, non_sequences=[W1, pt.zeros_like(W2)], n_steps=5, + return_updates=False, ) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) @@ -1786,12 +1857,13 @@ def test_alloc_inputs2(): def lambda_fn(W1, h, W2): return W1 * dot(h, W2) - o, _ = scan( + o = scan( lambda_fn, sequences=pt.zeros_like(W1), outputs_info=h0, non_sequences=[pt.zeros_like(W2)], n_steps=5, + return_updates=False, 
) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) @@ -1821,12 +1893,13 @@ def test_alloc_inputs3(): def lambda_fn(W1, h, W2): return W1 * dot(h, W2) - o, _ = scan( + o = scan( lambda_fn, sequences=pt.zeros_like(W1), outputs_info=h0, non_sequences=[pt.zeros_like(W2)], n_steps=5, + return_updates=False, ) # TODO FIXME: This result depends on unrelated rewrites in the "fast" mode. @@ -1848,7 +1921,7 @@ def test_opt_order(): x = matrix("x") A = matrix("A") - z, _updates = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2) + z = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2, return_updates=False) f = function([x, A], z, mode="FAST_RUN") topo = f.maker.fgraph.toposort() diff --git a/tests/scan/test_views.py b/tests/scan/test_views.py index 38c9b9cfcd..b1e6f10957 100644 --- a/tests/scan/test_views.py +++ b/tests/scan/test_views.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import pytensor.tensor as pt from pytensor import config, function, grad, shared @@ -11,24 +12,41 @@ from tests.scan.test_basic import clone_optimized_graph, grab_scan_node -def test_reduce(): +@pytest.mark.parametrize("return_updates", [True, False]) +def test_reduce(return_updates): v = vector("v") s = scalar("s") - result, updates = pt_reduce(lambda x, y: x + y, v, s) + result_raw = pt_reduce(lambda x, y: x + y, v, s, return_updates=return_updates) + if return_updates: + result, updates = result_raw + assert not updates + else: + result = result_raw - f = function([v, s], result, updates=updates, allow_input_downcast=True) + f = function([v, s], result, allow_input_downcast=True) rng = np.random.default_rng(utt.fetch_seed()) v_v = rng.uniform(-5.0, 5.0, size=(5,)) assert abs(np.sum(v_v) - f(v_v, 0.0)) < 1e-3 -def test_map(): +@pytest.mark.parametrize("return_updates", [True, False]) +def test_map(return_updates): v = vector("v") - abs_expr, abs_updates = pt_map( - lambda x: abs(x), v, [], truncate_gradient=-1, go_backwards=False + abs_expr_raw = pt_map( + lambda x: abs(x), + v, + [], + truncate_gradient=-1, + go_backwards=False, + return_updates=return_updates, ) + if return_updates: + abs_expr, abs_updates = abs_expr_raw + assert not abs_updates + else: + abs_expr = abs_expr_raw - f = function([v], abs_expr, updates=abs_updates, allow_input_downcast=True) + f = function([v], abs_expr, allow_input_downcast=True) rng = np.random.default_rng(utt.fetch_seed()) vals = rng.uniform(-5.0, 5.0, size=(10,)) @@ -39,10 +57,11 @@ def test_map(): def test_reduce_memory_consumption(): x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX)) - o, _ = pt_reduce( + o = pt_reduce( lambda v, acc: acc + v, x, pt.constant(np.asarray(0.0, dtype=config.floatX)), + return_updates=False, ) mode = FAST_RUN mode = mode.excluding("inplace") @@ -69,13 +88,20 @@ def test_reduce_memory_consumption(): utt.assert_allclose(f2(), np.ones((10,))) -def test_foldl_memory_consumption(): +@pytest.mark.parametrize("return_updates", [True, False]) +def test_foldl_memory_consumption(return_updates): x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX)) - o, _ = foldl( + o_raw = foldl( lambda v, acc: acc + v, x, pt.constant(np.asarray(0.0, dtype=config.floatX)), + return_updates=return_updates, ) + if return_updates: + o, updates = o_raw + assert not updates + else: + o = o_raw mode = FAST_RUN mode = mode.excluding("inplace") @@ -102,13 +128,20 @@ def test_foldl_memory_consumption(): utt.assert_allclose(f2(), np.ones((10,))) -def test_foldr_memory_consumption(): 
+@pytest.mark.parametrize("return_updates", [True, False]) +def test_foldr_memory_consumption(return_updates): x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX)) - o, _ = foldr( + o_raw = foldr( lambda v, acc: acc + v, x, pt.constant(np.asarray(0.0, dtype=config.floatX)), + return_updates=return_updates, ) + if return_updates: + o, updates = o_raw + assert not updates + else: + o = o_raw mode = FAST_RUN mode = mode.excluding("inplace") diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index f1ea2e1af3..2e2b11257d 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): A = tensor("A", shape=(3, 3)) x0 = tensor("b", shape=(3, 4)) - xs, _ = scan( + xs = scan( lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed), outputs_info=[x0], non_sequences=[A], n_steps=10, + return_updates=False, ) fn_no_opt = function( diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index c2df7e9699..9f4acc74d6 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -694,10 +694,11 @@ def L_op(self, inputs, outputs, output_grads): def test_scan_gradient_core_type(): n_steps = 3 seq = tensor("seq", shape=(n_steps, 1), dtype="float64") - out, _ = scan( + out = scan( lambda s: s, sequences=[seq], n_steps=n_steps, + return_updates=False, ) vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
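A minimal usage sketch of the calling convention these tests exercise, assuming `return_updates` behaves exactly as the assertions above describe (the legacy call still returns an `(outputs, updates)` pair and now emits a DeprecationWarning, while `return_updates=False` returns only the traced outputs, and a ValueError is raised if the inner function produced updates). Illustrative only, not a definitive API reference:

import numpy as np
import pytensor
import pytensor.tensor as pt

xs = pt.vector("xs")

# Legacy signature: scan returns (outputs, updates); per
# test_return_updates_api_change this path now warns about the upcoming change.
ys_legacy, updates = pytensor.scan(lambda x: x + 1, sequences=[xs])

# New signature: only the traced outputs are returned.
ys = pytensor.scan(lambda x: x + 1, sequences=[xs], return_updates=False)

fn = pytensor.function([xs], ys, allow_input_downcast=True)
print(fn(np.arange(3)))  # [1. 2. 3.]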