Skip to content

Commit 5b13c51

Browse files
committed
Start deprecating shared updates API in Scan
Using DeprecationWarning to keep it visible only for devs for now
1 parent 52296ff commit 5b13c51

File tree

8 files changed

+191
-44
lines changed

8 files changed

+191
-44
lines changed

pytensor/gradient.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
21882188
# It is possible that the inputs are disconnected from expr,
21892189
# even if they are connected to cost.
21902190
# This should not be an error.
2191-
hess, updates = pytensor.scan(
2191+
hess = pytensor.scan(
21922192
lambda i, y, x: grad(
21932193
y[i],
21942194
x,
@@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
21972197
),
21982198
sequences=pytensor.tensor.arange(expr.shape[0]),
21992199
non_sequences=[expr, input],
2200-
)
2201-
assert not updates, (
2202-
"Scan has returned a list of updates; this should not happen."
2200+
return_updates=False,
22032201
)
22042202
hessians.append(hess)
22052203
return as_list_or_tuple(using_list, using_tuple, hessians)

pytensor/scan/basic.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
168168
return isNone or isNaN or isInf or isStr
169169

170170

171+
def _manage_output_api_change(outputs, updates, return_updates):
172+
if return_updates:
173+
warnings.warn(
174+
"Scan return signature well change. Updates dict will not be returned, only the first argument. "
175+
"Pass `return_updates=False` to conform to the new API and avoid this warning",
176+
DeprecationWarning,
177+
# Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
178+
stacklevel=2,
179+
)
180+
else:
181+
if updates:
182+
raise ValueError(
183+
f"return_updates=False but Scan produced updates {updates}."
184+
"Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
185+
)
186+
return outputs
187+
188+
return outputs, updates
189+
190+
171191
def scan(
172192
fn,
173193
sequences=None,
@@ -182,6 +202,7 @@ def scan(
182202
allow_gc=None,
183203
strict=False,
184204
return_list=False,
205+
return_updates: bool = True,
185206
):
186207
r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
187208
@@ -878,6 +899,11 @@ def wrap_into_list(x):
878899
raw_inner_outputs = fn(*args)
879900

880901
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
902+
if updates:
903+
warnings.warn(
904+
"Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs.",
905+
DeprecationWarning, # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
906+
)
881907
if condition is not None:
882908
as_while = True
883909
else:
@@ -900,7 +926,7 @@ def wrap_into_list(x):
900926
if not return_list and len(outputs) == 1:
901927
outputs = outputs[0]
902928

903-
return (outputs, updates)
929+
return _manage_output_api_change(outputs, updates, return_updates)
904930

905931
##
906932
# Step 4. Compile the dummy function
@@ -919,6 +945,8 @@ def wrap_into_list(x):
919945
fake_outputs = clone_replace(
920946
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
921947
)
948+
# TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
949+
# to find implicit inputs in a way that reduces the size of the inner function
922950
known_inputs = [*args, *fake_nonseqs]
923951
extra_inputs = [
924952
x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -1074,7 +1102,7 @@ def wrap_into_list(x):
10741102
if not isinstance(arg, SharedVariable | Constant)
10751103
]
10761104

1077-
inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))
1105+
inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) # type: ignore[arg-type]
10781106

10791107
if strict:
10801108
non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1123,7 +1151,7 @@ def wrap_into_list(x):
11231151
if condition is not None:
11241152
inner_outs.append(condition)
11251153

1126-
new_outs = clone_replace(inner_outs, replace=inner_replacements)
1154+
new_outs = clone_replace(inner_outs, replace=inner_replacements) # type: ignore[arg-type]
11271155

11281156
##
11291157
# Step 7. Create the Scan Op
@@ -1211,12 +1239,14 @@ def remove_dimensions(outs, offsets=None):
12111239

12121240
offset += n_nit_sot
12131241

1214-
# Support for explicit untraced sit_sot
1242+
# Legacy support for explicit untraced sit_sot and those built with update dictionary
1243+
# Switch to n_untraced_sit_sot_outs after deprecation period
12151244
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
12161245
untraced_sit_sot_outs = scan_outs[
12171246
offset : offset + n_explicit_untraced_sit_sot_outs
12181247
]
12191248

1249+
# Legacy support: map shared outputs to their updates
12201250
offset += n_explicit_untraced_sit_sot_outs
12211251
for idx, update_rule in enumerate(scan_outs[offset:]):
12221252
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
@@ -1245,8 +1275,8 @@ def remove_dimensions(outs, offsets=None):
12451275
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
12461276
scan_out_list = [x for x in scan_out_list if x is not None]
12471277
if not return_list and len(scan_out_list) == 1:
1248-
scan_out_list = scan_out_list[0]
1278+
scan_out_list = scan_out_list[0] # type: ignore[assignment]
12491279
elif len(scan_out_list) == 0:
1250-
scan_out_list = None
1280+
scan_out_list = None # type: ignore[assignment]
12511281

1252-
return scan_out_list, update_map
1282+
return _manage_output_api_change(scan_out_list, update_map, return_updates)

pytensor/scan/checkpoints.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def scan_checkpoints(
1313
n_steps=None,
1414
save_every_N=10,
1515
padding=True,
16+
return_updates=True,
1617
):
1718
"""Scan function that uses less memory, but is more restrictive.
1819
@@ -157,31 +158,34 @@ def outer_step(*args):
157158
] * len(new_nitsots)
158159

159160
# Call the user-provided function with the proper arguments
160-
results, updates = scan(
161+
results_and_updates = scan(
161162
fn=fn,
162163
sequences=i_sequences[:-1],
163164
outputs_info=i_outputs_infos,
164165
non_sequences=i_non_sequences,
165166
name=name + "_inner",
166167
n_steps=i_sequences[-1],
168+
return_updates=return_updates,
167169
)
170+
if return_updates:
171+
results, updates = results_and_updates
172+
else:
173+
results = results_and_updates
174+
updates = {}
175+
168176
if not isinstance(results, list):
169177
results = [results]
170178

171179
# Keep only the last timestep of every output but keep all the updates
172-
if not isinstance(results, list):
173-
return results[-1], updates
174-
else:
175-
return [r[-1] for r in results], updates
180+
return [r[-1] for r in results], updates
176181

177-
results, updates = scan(
182+
return scan(
178183
fn=outer_step,
179184
sequences=o_sequences,
180185
outputs_info=outputs_info,
181186
non_sequences=o_nonsequences,
182187
name=name + "_outer",
183188
n_steps=o_n_steps,
184189
allow_gc=True,
190+
return_updates=return_updates,
185191
)
186-
187-
return results, updates

pytensor/scan/views.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def map(
1616
go_backwards=False,
1717
mode=None,
1818
name=None,
19+
return_updates=True,
1920
):
2021
"""Construct a `Scan` `Op` that functions like `map`.
2122
@@ -50,6 +51,7 @@ def map(
5051
go_backwards=go_backwards,
5152
mode=mode,
5253
name=name,
54+
return_updates=return_updates,
5355
)
5456

5557

@@ -61,6 +63,7 @@ def reduce(
6163
go_backwards=False,
6264
mode=None,
6365
name=None,
66+
return_updates=True,
6467
):
6568
"""Construct a `Scan` `Op` that functions like `reduce`.
6669
@@ -97,14 +100,29 @@ def reduce(
97100
truncate_gradient=-1,
98101
mode=mode,
99102
name=name,
103+
return_updates=return_updates,
100104
)
101-
if isinstance(rval[0], list | tuple):
102-
return [x[-1] for x in rval[0]], rval[1]
105+
if return_updates:
106+
if isinstance(rval[0], list | tuple):
107+
return [x[-1] for x in rval[0]], rval[1]
108+
else:
109+
return rval[0][-1], rval[1]
103110
else:
104-
return rval[0][-1], rval[1]
111+
if isinstance(rval, list | tuple):
112+
return [x[-1] for x in rval]
113+
else:
114+
return rval[-1]
105115

106116

107-
def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
117+
def foldl(
118+
fn,
119+
sequences,
120+
outputs_info,
121+
non_sequences=None,
122+
mode=None,
123+
name=None,
124+
return_updates=True,
125+
):
108126
"""Construct a `Scan` `Op` that functions like Haskell's `foldl`.
109127
110128
Parameters
@@ -135,10 +153,19 @@ def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
135153
go_backwards=False,
136154
mode=mode,
137155
name=name,
156+
return_updates=return_updates,
138157
)
139158

140159

141-
def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
160+
def foldr(
161+
fn,
162+
sequences,
163+
outputs_info,
164+
non_sequences=None,
165+
mode=None,
166+
name=None,
167+
return_updates=True,
168+
):
142169
"""Construct a `Scan` `Op` that functions like Haskell's `foldr`.
143170
144171
Parameters
@@ -169,4 +196,5 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
169196
go_backwards=True,
170197
mode=mode,
171198
name=name,
199+
return_updates=return_updates,
172200
)

pytensor/tensor/pad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
314314

315315

316316
def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
317-
[_, parts], _ = scan(
317+
[_, parts] = scan(
318318
inner_func,
319319
non_sequences=[array, array_flipped],
320320
outputs_info=[0, None],
321321
n_steps=repeats,
322+
return_updates=False,
322323
)
323324

324325
parts = moveaxis(parts, 0, axis)

tests/scan/test_basic.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.compile.sharedvalue import shared
2828
from pytensor.configdefaults import config
2929
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
30-
from pytensor.graph.basic import Apply, equal_computations
30+
from pytensor.graph.basic import Apply, Variable, equal_computations
3131
from pytensor.graph.fg import FunctionGraph
3232
from pytensor.graph.op import Op
3333
from pytensor.graph.replace import vectorize_graph
@@ -67,6 +67,7 @@
6767
vector,
6868
)
6969
from tests import unittest_tools as utt
70+
from tests.unittest_tools import assert_equal_computations
7071

7172

7273
if config.mode == "FAST_COMPILE":
@@ -4139,3 +4140,46 @@ def step(prev_x, prev_rng):
41394140
xs_ref.append(rng_ref.normal(xs_ref[-1]))
41404141
assert random_generator_type.values_eq(rng_ref, rng_final_eval)
41414142
np.testing.assert_allclose(xs_eval, xs_ref[1:])
4143+
4144+
4145+
@pytest.mark.filterwarnings("error")
4146+
def test_return_updates_api_change():
4147+
err_msg = "return_updates=False but Scan produced updates"
4148+
warn_msg1 = "Updates functionality in Scan are deprecated"
4149+
warn_msg2 = "Pass `return_updates=False` to conform to the new API"
4150+
4151+
x = shared(np.array(0, dtype="float64"))
4152+
4153+
with pytest.warns(DeprecationWarning, match=warn_msg2):
4154+
with pytest.warns(DeprecationWarning, match=warn_msg1):
4155+
traced1, updates1 = scan(
4156+
lambda: {x: x + 1},
4157+
outputs_info=[],
4158+
n_steps=5,
4159+
)
4160+
assert traced1 is None
4161+
assert len(updates1) == 1 and x in updates1
4162+
4163+
with pytest.warns(DeprecationWarning, match=warn_msg2):
4164+
traced2, updates2 = scan(
4165+
lambda x: x + 1,
4166+
outputs_info=[x],
4167+
n_steps=5,
4168+
)
4169+
assert isinstance(traced2, Variable)
4170+
assert isinstance(updates2, dict) and not updates2
4171+
4172+
traced3 = scan(
4173+
lambda x: x + 1,
4174+
outputs_info=[x],
4175+
n_steps=5,
4176+
return_updates=False,
4177+
)
4178+
assert isinstance(traced3, Variable)
4179+
4180+
assert_equal_computations(list(updates1.values()), [traced2[-1]])
4181+
assert_equal_computations([traced2], [traced3])
4182+
4183+
with pytest.raises(ValueError, match=err_msg):
4184+
with pytest.warns(DeprecationWarning, match=warn_msg1):
4185+
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)

0 commit comments

Comments
 (0)