Commit ceb0c8c

Simplify scan helper logic
return_steps has not been a thing for 14 years
1 parent 5ab9bf4

2 files changed: +25 -119 lines
pytensor/scan/basic.py (25 additions, 58 deletions)
@@ -1,4 +1,5 @@
 import warnings
+from itertools import chain

 import numpy as np

@@ -9,7 +10,7 @@
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.op import get_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.traversal import graph_inputs
+from pytensor.graph.traversal import explicit_graph_inputs
 from pytensor.graph.utils import MissingInputError, TestValueError
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import expand_empty, safe_new, until
@@ -475,19 +476,15 @@ def wrap_into_list(x):
         else:
             non_seqs.append(elem)

-    # If we provided a known number of steps ( before compilation)
-    # and if that number is 1 or -1, then we can skip the Scan Op,
-    # and just apply the inner function once
-    # To do that we check here to see the nature of n_steps
-    n_fixed_steps = None
-
+    # This helper eagerly skips the Scan if n_steps is known to be 1
+    single_step_requested = False
     if isinstance(n_steps, float | int):
-        n_fixed_steps = int(n_steps)
+        single_step_requested = n_steps == 1
     else:
         try:
-            n_fixed_steps = pt.get_scalar_constant_value(n_steps)
+            single_step_requested = pt.get_scalar_constant_value(n_steps) == 1
         except NotScalarConstantError:
-            n_fixed_steps = None
+            pass

     # Check n_steps is an int
     if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes:
@@ -497,7 +494,6 @@ def wrap_into_list(x):
     n_seqs = len(seqs)
     n_outs = len(outs_info)

-    return_steps = {}
     # wrap sequences in a dictionary if they are not already dictionaries
     for i in range(n_seqs):
         if not isinstance(seqs[i], dict):
@@ -700,7 +696,6 @@ def wrap_into_list(x):
     mit_sot_inner_inputs = []
     mit_sot_inner_slices = []
     mit_sot_inner_outputs = []
-    mit_sot_return_steps = {}
     mit_sot_tap_array = []
     mit_sot_rightOrder = []

@@ -709,7 +704,6 @@ def wrap_into_list(x):
     sit_sot_inner_inputs = []
     sit_sot_inner_slices = []
     sit_sot_inner_outputs = []
-    sit_sot_return_steps = {}
     sit_sot_rightOrder = []

     # go through outputs picking up time slices as needed
@@ -755,8 +749,6 @@ def wrap_into_list(x):
             )

             sit_sot_inner_slices.append(actual_arg)
-            if i in return_steps:
-                sit_sot_return_steps[n_sit_sot] = return_steps[i]
             sit_sot_inner_inputs.append(arg)
             sit_sot_rightOrder.append(i)
             n_sit_sot += 1
@@ -774,8 +766,6 @@ def wrap_into_list(x):
                 expand_empty(init_out["initial"][:mintap], actual_n_steps)
             )

-            if i in return_steps:
-                mit_sot_return_steps[n_mit_sot] = return_steps[i]
             mit_sot_rightOrder.append(i)
             n_mit_sot += 1
             for k in init_out["taps"]:
@@ -819,7 +809,7 @@ def wrap_into_list(x):
     offset = 0
     for idx in range(n_mit_sot):
         n_inputs = len(mit_sot_tap_array[idx])
-        if n_fixed_steps in (1, -1):
+        if single_step_requested:
             _ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[
                 offset : offset + n_inputs
             ]
@@ -830,17 +820,14 @@ def wrap_into_list(x):
         offset += n_inputs

     for idx in range(n_sit_sot):
-        if n_fixed_steps in (1, -1):
+        if single_step_requested:
             _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]]
         else:
             _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]

-    ordered_args = []
-    for ls in _ordered_args:
-        ordered_args += ls
-    if n_fixed_steps in (1, -1):
+    ordered_args = list(chain.from_iterable(_ordered_args))
+    if single_step_requested:
         args = inner_slices + ordered_args + non_seqs
-
     else:
         args = inner_seqs + ordered_args + non_seqs

@@ -863,15 +850,15 @@ def wrap_into_list(x):
     # Step 3. Check if we actually need scan and remove it if we don't
     ##

-    if n_fixed_steps in (1, -1):
+    if single_step_requested:
         for pos, inner_out in enumerate(outputs):
             # we need to see if we need to pad our sequences with an
             # extra dimension; case example : we return an
             # output for which we want all intermediate. If n_steps is 1
             # then, if we return the output as given by the innner function
             # this will represent only a slice and it will have one
             # dimension less.
-            if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
+            if isinstance(inner_out.type, TensorType):
                 outputs[pos] = shape_padleft(inner_out)

     if not return_list and len(outputs) == 1:
@@ -896,15 +883,10 @@ def wrap_into_list(x):
     fake_outputs = clone_replace(
         outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
     )
-    all_inputs = filter(
-        lambda x: (
-            isinstance(x, Variable)
-            and not isinstance(x, SharedVariable)
-            and not isinstance(x, Constant)
-        ),
-        graph_inputs(fake_outputs),
-    )
-    extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
+    known_inputs = [*args, *fake_nonseqs]
+    extra_inputs = [
+        x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
+    ]
     non_seqs += extra_inputs
     # Note we do not use all_inputs directly since the order of variables
     # in args is quite important
@@ -1033,13 +1015,10 @@ def wrap_into_list(x):
     # Step 5.4 Outputs with no taps used in the input
     n_nit_sot = 0
     nit_sot_inner_outputs = []
-    nit_sot_return_steps = {}
     nit_sot_rightOrder = []
     for i, out in enumerate(outs_info):
         if "taps" not in out:
             nit_sot_inner_outputs.append(outputs[i])
-            if i in return_steps:
-                nit_sot_return_steps[n_nit_sot] = return_steps[i]
             nit_sot_rightOrder.append(i)
             n_nit_sot += 1

@@ -1173,37 +1152,25 @@ def wrap_into_list(x):

     update_map = OrderedUpdates()

-    def remove_dimensions(outs, steps_return, offsets=None):
+    def remove_dimensions(outs, offsets=None):
         out_ls = []
         for idx, out in enumerate(outs):
-            if idx in steps_return:
-                if steps_return[idx] > 1:
-                    out_ls.append(out[-steps_return[idx] :])
-                else:
-                    out_ls.append(out[-1])
+            if offsets is None:
+                out_ls.append(out)
             else:
-                if offsets is None:
-                    out_ls.append(out)
-                else:
-                    out_ls.append(out[offsets[idx] :])
+                out_ls.append(out[offsets[idx] :])
         return out_ls

     offset = n_mit_mot
     offsets = [abs(np.min(x)) for x in mit_sot_tap_array]
-    mit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_mit_sot], mit_sot_return_steps, offsets
-    )
+    mit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_mit_sot], offsets)

     offset += n_mit_sot
     offsets = [1 for x in range(n_sit_sot)]
-    sit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_sit_sot], sit_sot_return_steps, offsets
-    )
+    sit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_sit_sot], offsets)

     offset += n_sit_sot
-    nit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_nit_sot], nit_sot_return_steps
-    )
+    nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])

     offset += n_nit_sot
     for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
@@ -1232,4 +1199,4 @@ def remove_dimensions(outs, steps_return, offsets=None):
     elif len(scan_out_list) == 0:
         scan_out_list = None

-    return (scan_out_list, update_map)
+    return scan_out_list, update_map
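The other substantive change replaces the hand-rolled filter over graph_inputs with explicit_graph_inputs, which yields only the root variables a caller must supply, skipping constants and shared variables. A quick sketch of the expected behaviour, assuming a pytensor version that ships pytensor.graph.traversal.explicit_graph_inputs:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.traversal import explicit_graph_inputs

    x = pt.vector("x")
    s = pytensor.shared(np.ones(3), name="s")
    out = x * s + 2.0
    # Constants (2.0) and shared variables (s) are skipped, matching
    # what the removed filter(...) over graph_inputs computed.
    print(list(explicit_graph_inputs([out])))  # [x]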

tests/scan/test_basic.py (0 additions, 61 deletions)
@@ -3650,67 +3650,6 @@ def lm(m):
         if config.mode != "FAST_COMPILE":
             assert nb_shape_i == 1

-    def test_return_steps(self):
-        rng = np.random.default_rng(utt.fetch_seed())
-
-        vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
-        vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
-        vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
-        vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
-        v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2)))
-        v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,)))
-        v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
-        v_y0 = asarrayX(rng.uniform(size=(3,)))
-
-        W_in2 = shared(vW_in2, name="win2")
-        W = shared(vW, name="w")
-        W_out = shared(vWout, name="wout")
-        W_in1 = matrix("win")
-        u1 = matrix("u1")
-        u2 = vector("u2")
-        x0 = vector("x0")
-        y0 = vector("y0")
-
-        def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
-            return [
-                y_tm3 + 1,
-                dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W),
-                y_tm1 + dot(x_tm1, W_out),
-            ]
-
-        rval, updates = scan(
-            f_rnn_cmpl,
-            [u1, u2],
-            [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
-            W_in1,
-            n_steps=None,
-            truncate_gradient=-1,
-            go_backwards=False,
-        )
-
-        outputs = []
-        outputs += [rval[0][-3:]]
-        outputs += [rval[1][-2:]]
-        outputs += [rval[2][-4:]]
-        f4 = function(
-            [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
-        )
-
-        # compute the values in numpy
-        v_x = np.zeros((8, 2), dtype=config.floatX)
-        v_y = np.zeros((8,), dtype=config.floatX)
-        v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW)
-        v_y[0] = np.dot(v_x0, vWout) + v_y0[2]
-
-        for i in range(1, 8):
-            v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
-            v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]
-
-        (_pytensor_dump, pytensor_x, pytensor_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)
-
-        utt.assert_allclose(pytensor_x, v_x[-2:])
-        utt.assert_allclose(pytensor_y, v_y[-4:])
-
     def test_until_random_infer_shape(self):
         """
         Test for a crash in scan.infer_shape when using both
