Skip to content

Commit 207735f

Browse files
committed
Start deprecating shared updates API in Scan
Using DeprecationWarning to keep it visible only for devs for now
1 parent 0216080 commit 207735f

File tree

8 files changed

+186
-44
lines changed

8 files changed

+186
-44
lines changed

pytensor/gradient.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
21882188
# It is possible that the inputs are disconnected from expr,
21892189
# even if they are connected to cost.
21902190
# This should not be an error.
2191-
hess, updates = pytensor.scan(
2191+
hess = pytensor.scan(
21922192
lambda i, y, x: grad(
21932193
y[i],
21942194
x,
@@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
21972197
),
21982198
sequences=pytensor.tensor.arange(expr.shape[0]),
21992199
non_sequences=[expr, input],
2200-
)
2201-
assert not updates, (
2202-
"Scan has returned a list of updates; this should not happen."
2200+
return_updates=False,
22032201
)
22042202
hessians.append(hess)
22052203
return as_list_or_tuple(using_list, using_tuple, hessians)

pytensor/scan/basic.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
168168
return isNone or isNaN or isInf or isStr
169169

170170

171+
def _manage_output_api_change(outputs, updates, return_updates):
    """Shape the value returned by `scan` during the updates-API deprecation.

    Parameters
    ----------
    outputs
        The outputs `scan` is about to return.
    updates
        The updates mapping built by `scan`. It must be empty (falsy) when
        `return_updates` is false.
    return_updates : bool
        If true, keep the legacy ``(outputs, updates)`` return signature and
        emit a `DeprecationWarning`. If false, return only ``outputs``.

    Returns
    -------
    ``(outputs, updates)`` when `return_updates` is true, ``outputs`` otherwise.

    Raises
    ------
    ValueError
        If `return_updates` is false but `updates` is not empty.
    """
    if return_updates:
        warnings.warn(
            "Scan return signature will change. Updates dict will not be returned, only the first argument. "
            "Pass `return_updates=False` to conform to the new API and avoid this warning",
            DeprecationWarning,
            # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
            # stacklevel=3 skips this helper *and* the function that called it, so the
            # warning is attributed to the code that invoked that function.
            stacklevel=3,
        )
        return outputs, updates

    if updates:
        raise ValueError(
            f"return_updates=False but Scan produced updates {updates}. "
            "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
        )
    return outputs
189+
190+
171191
def scan(
172192
fn,
173193
sequences=None,
@@ -182,6 +202,7 @@ def scan(
182202
allow_gc=None,
183203
strict=False,
184204
return_list=False,
205+
return_updates: bool = True,
185206
):
186207
r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
187208
@@ -900,7 +921,7 @@ def wrap_into_list(x):
900921
if not return_list and len(outputs) == 1:
901922
outputs = outputs[0]
902923

903-
return (outputs, updates)
924+
return _manage_output_api_change(outputs, updates, return_updates)
904925

905926
##
906927
# Step 4. Compile the dummy function
@@ -919,6 +940,8 @@ def wrap_into_list(x):
919940
fake_outputs = clone_replace(
920941
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
921942
)
943+
# TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
944+
# to find implicit inputs in a way that reduces the size of the inner function
922945
known_inputs = [*args, *fake_nonseqs]
923946
extra_inputs = [
924947
x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -1074,7 +1097,7 @@ def wrap_into_list(x):
10741097
if not isinstance(arg, SharedVariable | Constant)
10751098
]
10761099

1077-
inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))
1100+
inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) # type: ignore[arg-type]
10781101

10791102
if strict:
10801103
non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1123,7 +1146,7 @@ def wrap_into_list(x):
11231146
if condition is not None:
11241147
inner_outs.append(condition)
11251148

1126-
new_outs = clone_replace(inner_outs, replace=inner_replacements)
1149+
new_outs = clone_replace(inner_outs, replace=inner_replacements) # type: ignore[arg-type]
11271150

11281151
##
11291152
# Step 7. Create the Scan Op
@@ -1211,12 +1234,14 @@ def remove_dimensions(outs, offsets=None):
12111234

12121235
offset += n_nit_sot
12131236

1214-
# Support for explicit untraced sit_sot
1237+
# Legacy support for explicit untraced sit_sot and those built with update dictionary
1238+
# Switch to n_untraced_sit_sot_outs after deprecation period
12151239
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
12161240
untraced_sit_sot_outs = scan_outs[
12171241
offset : offset + n_explicit_untraced_sit_sot_outs
12181242
]
12191243

1244+
# Legacy support: map shared outputs to their updates
12201245
offset += n_explicit_untraced_sit_sot_outs
12211246
for idx, update_rule in enumerate(scan_outs[offset:]):
12221247
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
@@ -1245,8 +1270,8 @@ def remove_dimensions(outs, offsets=None):
12451270
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
12461271
scan_out_list = [x for x in scan_out_list if x is not None]
12471272
if not return_list and len(scan_out_list) == 1:
1248-
scan_out_list = scan_out_list[0]
1273+
scan_out_list = scan_out_list[0] # type: ignore[assignment]
12491274
elif len(scan_out_list) == 0:
1250-
scan_out_list = None
1275+
scan_out_list = None # type: ignore[assignment]
12511276

1252-
return scan_out_list, update_map
1277+
return _manage_output_api_change(scan_out_list, update_map, return_updates)

pytensor/scan/checkpoints.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def scan_checkpoints(
1313
n_steps=None,
1414
save_every_N=10,
1515
padding=True,
16+
return_updates=True,
1617
):
1718
"""Scan function that uses less memory, but is more restrictive.
1819
@@ -157,31 +158,34 @@ def outer_step(*args):
157158
] * len(new_nitsots)
158159

159160
# Call the user-provided function with the proper arguments
160-
results, updates = scan(
161+
results_and_updates = scan(
161162
fn=fn,
162163
sequences=i_sequences[:-1],
163164
outputs_info=i_outputs_infos,
164165
non_sequences=i_non_sequences,
165166
name=name + "_inner",
166167
n_steps=i_sequences[-1],
168+
return_updates=return_updates,
167169
)
170+
if return_updates:
171+
results, updates = results_and_updates
172+
else:
173+
results = results_and_updates
174+
updates = {}
175+
168176
if not isinstance(results, list):
169177
results = [results]
170178

171179
# Keep only the last timestep of every output but keep all the updates
172-
if not isinstance(results, list):
173-
return results[-1], updates
174-
else:
175-
return [r[-1] for r in results], updates
180+
return [r[-1] for r in results], updates
176181

177-
results, updates = scan(
182+
return scan(
178183
fn=outer_step,
179184
sequences=o_sequences,
180185
outputs_info=outputs_info,
181186
non_sequences=o_nonsequences,
182187
name=name + "_outer",
183188
n_steps=o_n_steps,
184189
allow_gc=True,
190+
return_updates=return_updates,
185191
)
186-
187-
return results, updates

pytensor/scan/views.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def map(
1616
go_backwards=False,
1717
mode=None,
1818
name=None,
19+
return_updates=True,
1920
):
2021
"""Construct a `Scan` `Op` that functions like `map`.
2122
@@ -50,6 +51,7 @@ def map(
5051
go_backwards=go_backwards,
5152
mode=mode,
5253
name=name,
54+
return_updates=return_updates,
5355
)
5456

5557

@@ -61,6 +63,7 @@ def reduce(
6163
go_backwards=False,
6264
mode=None,
6365
name=None,
66+
return_updates=True,
6467
):
6568
"""Construct a `Scan` `Op` that functions like `reduce`.
6669
@@ -97,14 +100,29 @@ def reduce(
97100
truncate_gradient=-1,
98101
mode=mode,
99102
name=name,
103+
return_updates=return_updates,
100104
)
101-
if isinstance(rval[0], list | tuple):
102-
return [x[-1] for x in rval[0]], rval[1]
105+
if return_updates:
106+
if isinstance(rval[0], list | tuple):
107+
return [x[-1] for x in rval[0]], rval[1]
108+
else:
109+
return rval[0][-1], rval[1]
103110
else:
104-
return rval[0][-1], rval[1]
111+
if isinstance(rval, list | tuple):
112+
return [x[-1] for x in rval]
113+
else:
114+
return rval[-1]
105115

106116

107-
def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
117+
def foldl(
118+
fn,
119+
sequences,
120+
outputs_info,
121+
non_sequences=None,
122+
mode=None,
123+
name=None,
124+
return_updates=True,
125+
):
108126
"""Construct a `Scan` `Op` that functions like Haskell's `foldl`.
109127
110128
Parameters
@@ -135,10 +153,19 @@ def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
135153
go_backwards=False,
136154
mode=mode,
137155
name=name,
156+
return_updates=return_updates,
138157
)
139158

140159

141-
def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
160+
def foldr(
161+
fn,
162+
sequences,
163+
outputs_info,
164+
non_sequences=None,
165+
mode=None,
166+
name=None,
167+
return_updates=True,
168+
):
142169
"""Construct a `Scan` `Op` that functions like Haskell's `foldr`.
143170
144171
Parameters
@@ -169,4 +196,5 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
169196
go_backwards=True,
170197
mode=mode,
171198
name=name,
199+
return_updates=return_updates,
172200
)

pytensor/tensor/pad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
314314

315315

316316
def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
317-
[_, parts], _ = scan(
317+
[_, parts] = scan(
318318
inner_func,
319319
non_sequences=[array, array_flipped],
320320
outputs_info=[0, None],
321321
n_steps=repeats,
322+
return_updates=False,
322323
)
323324

324325
parts = moveaxis(parts, 0, axis)

tests/scan/test_basic.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.compile.sharedvalue import shared
2828
from pytensor.configdefaults import config
2929
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
30-
from pytensor.graph.basic import Apply, equal_computations
30+
from pytensor.graph.basic import Apply, Variable, equal_computations
3131
from pytensor.graph.fg import FunctionGraph
3232
from pytensor.graph.op import Op
3333
from pytensor.graph.replace import vectorize_graph
@@ -67,6 +67,7 @@
6767
vector,
6868
)
6969
from tests import unittest_tools as utt
70+
from tests.unittest_tools import assert_equal_computations
7071

7172

7273
if config.mode == "FAST_COMPILE":
@@ -4139,3 +4140,46 @@ def step(prev_x, prev_rng):
41394140
xs_ref.append(rng_ref.normal(xs_ref[-1]))
41404141
assert random_generator_type.values_eq(rng_ref, rng_final_eval)
41414142
np.testing.assert_allclose(xs_eval, xs_ref[1:])
4143+
4144+
4145+
@pytest.mark.filterwarnings("error")
def test_return_updates_api_change():
    """Check the transitional `return_updates` behaviour of `scan`.

    Exercises three call styles (update dictionary, legacy tuple return and
    the new single return value) and the error raised when updates are
    produced while `return_updates=False`.
    """
    error_pattern = "return_updates=False but Scan produced updates"
    updates_deprecated_pattern = "Updates functionality in Scan are deprecated"
    signature_change_pattern = "Pass `return_updates=False` to conform to the new API"

    x = shared(np.array(0, dtype="float64"))

    # Legacy style: recurrence expressed through a shared-variable update dict.
    with (
        pytest.warns(DeprecationWarning, match=signature_change_pattern),
        pytest.warns(DeprecationWarning, match=updates_deprecated_pattern),
    ):
        traced1, updates1 = scan(
            lambda: {x: x + 1},
            outputs_info=[],
            n_steps=5,
        )
    assert traced1 is None
    assert len(updates1) == 1
    assert x in updates1

    # Recurrence through outputs_info, but still the old tuple return.
    with pytest.warns(DeprecationWarning, match=signature_change_pattern):
        traced2, updates2 = scan(
            lambda prev: prev + 1,
            outputs_info=[x],
            n_steps=5,
        )
    assert isinstance(traced2, Variable)
    assert isinstance(updates2, dict)
    assert not updates2

    # New API: only the traced output comes back (no warning expected).
    traced3 = scan(
        lambda prev: prev + 1,
        outputs_info=[x],
        n_steps=5,
        return_updates=False,
    )
    assert isinstance(traced3, Variable)

    assert_equal_computations(list(updates1.values()), [traced2[-1]])
    assert_equal_computations([traced2], [traced3])

    # Producing updates while opting into the new API must fail.
    with (
        pytest.raises(ValueError, match=error_pattern),
        pytest.warns(DeprecationWarning, match=updates_deprecated_pattern),
    ):
        scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)

tests/scan/test_checkpoints.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,53 @@
99
from pytensor.tensor.type import iscalar, vector
1010

1111

12+
@pytest.mark.parametrize("return_updates", [True, False])
1213
class TestScanCheckpoint:
13-
def setup_method(self):
14+
def setup_method(self, return_updates):
1415
self.k = iscalar("k")
1516
self.A = vector("A")
1617
seq = arange(self.k, dtype="float32") + 1
17-
result, _ = scan(
18+
result_raw = scan(
1819
fn=lambda s, prior_result, A: prior_result * A / s,
1920
outputs_info=ones_like(self.A),
2021
sequences=[seq],
2122
non_sequences=self.A,
2223
n_steps=self.k,
24+
return_updates=return_updates,
2325
)
24-
result_check, _ = scan_checkpoints(
26+
result_check_raw = scan_checkpoints(
2527
fn=lambda s, prior_result, A: prior_result * A / s,
2628
outputs_info=ones_like(self.A),
2729
sequences=[seq],
2830
non_sequences=self.A,
2931
n_steps=self.k,
3032
save_every_N=100,
33+
return_updates=return_updates,
3134
)
35+
if return_updates:
36+
result, _ = result_raw
37+
result_check, _ = result_check_raw
38+
else:
39+
result = result_raw
40+
result_check = result_check_raw
3241
self.result = result[-1]
3342
self.result_check = result_check[-1]
3443
self.grad_A = grad(self.result.sum(), self.A)
3544
self.grad_A_check = grad(self.result_check.sum(), self.A)
3645

37-
def test_forward_pass(self):
46+
def test_forward_pass(self, return_updates):
3847
# Test forward computation of A**k.
3948
f = function(inputs=[self.A, self.k], outputs=[self.result, self.result_check])
4049
out, out_check = f(range(10), 101)
4150
assert np.allclose(out, out_check)
4251

43-
def test_backward_pass(self):
52+
def test_backward_pass(self, return_updates):
4453
# Test gradient computation of A**k.
4554
f = function(inputs=[self.A, self.k], outputs=[self.grad_A, self.grad_A_check])
4655
out, out_check = f(range(10), 101)
4756
assert np.allclose(out, out_check)
4857

49-
def test_taps_error(self):
58+
def test_taps_error(self, return_updates):
5059
# Test that an error rises if we use taps in outputs_info.
5160
with pytest.raises(RuntimeError):
5261
scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]})

0 commit comments

Comments
 (0)