Skip to content

Commit 484a9c8

Browse files
committed
Allow non-shared untraced SIT-SOT
1 parent 5ab9bf4 commit 484a9c8

File tree

8 files changed

+336
-205
lines changed

8 files changed

+336
-205
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def scan(*outer_inputs):
6060
mit_mot_init,
6161
mit_sot_init,
6262
sit_sot_init,
63-
op.outer_shared(outer_inputs),
63+
op.outer_untraced_sit_sot(outer_inputs),
6464
op.outer_non_seqs(outer_inputs),
6565
) # JAX `init`
6666

@@ -118,7 +118,7 @@ def inner_func_outs_to_jax_outs(
118118
):
119119
"""Convert inner_scan_func outputs into format expected by JAX scan.
120120
121-
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
121+
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
122122
"""
123123
(
124124
i,
@@ -133,7 +133,7 @@ def inner_func_outs_to_jax_outs(
133133
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
134134
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
135135
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
136-
new_shared = op.inner_shared_outs(inner_scan_outs)
136+
new_shared = op.inner_untraced_sit_sot_outs(inner_scan_outs)
137137

138138
# New carry for next step
139139
# Update MIT-MOT buffer at positions indicated by output taps

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
108108
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
109109
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
110110
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
111-
outer_in_shared_names = op.outer_shared(outer_in_names)
111+
outer_in_shared_names = op.outer_untraced_sit_sot(outer_in_names)
112112
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
113113

114114
# These are all the outer-input names that have produce outputs/have output
@@ -204,11 +204,9 @@ def add_inner_in_expr(
204204

205205
# Inner-outputs consist of:
206206
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
207-
# shared-outputs [+ while-condition]
207+
# untraced-sit-sot-outputs [+ while-condition]
208208
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
209209

210-
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
211-
212210
# The assignment statements that copy inner-outputs into the outer-outputs
213211
# storage
214212
inner_out_to_outer_in_stmts: list[str] = []

pytensor/scan/basic.py

Lines changed: 86 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from itertools import chain
23

34
import numpy as np
45

@@ -10,6 +11,7 @@
1011
from pytensor.graph.op import get_test_value
1112
from pytensor.graph.replace import clone_replace
1213
from pytensor.graph.traversal import graph_inputs
14+
from pytensor.graph.type import HasShape
1315
from pytensor.graph.utils import MissingInputError, TestValueError
1416
from pytensor.scan.op import Scan, ScanInfo
1517
from pytensor.scan.utils import expand_empty, safe_new, until
@@ -712,6 +714,12 @@ def wrap_into_list(x):
712714
sit_sot_return_steps = {}
713715
sit_sot_rightOrder = []
714716

717+
n_untraced_sit_sot_outs = 0
718+
untraced_sit_sot_scan_inputs = []
719+
untraced_sit_sot_inner_inputs = []
720+
untraced_sit_sot_inner_outputs = []
721+
untraced_sit_sot_rightOrder = []
722+
715723
# go through outputs picking up time slices as needed
716724
for i, init_out in enumerate(outs_info):
717725
# Note that our convention dictates that if an output uses
@@ -747,19 +755,36 @@ def wrap_into_list(x):
747755
# We need now to allocate space for storing the output and copy
748756
# the initial state over. We do this using the expand function
749757
# defined in scan utils
750-
sit_sot_scan_inputs.append(
751-
expand_empty(
752-
shape_padleft(actual_arg),
753-
actual_n_steps,
758+
if isinstance(actual_arg.type, HasShape):
759+
sit_sot_scan_inputs.append(
760+
expand_empty(
761+
shape_padleft(actual_arg),
762+
actual_n_steps,
763+
)
754764
)
755-
)
765+
sit_sot_inner_slices.append(actual_arg)
766+
if i in return_steps:
767+
sit_sot_return_steps[n_sit_sot] = return_steps[i]
768+
sit_sot_inner_inputs.append(arg)
769+
sit_sot_rightOrder.append(i)
770+
n_sit_sot += 1
771+
else:
772+
# Assume variables without shape cannot be stacked (e.g., RNG variables)
773+
# Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
774+
from pytensor.tensor.random.type import RandomType
756775

757-
sit_sot_inner_slices.append(actual_arg)
758-
if i in return_steps:
759-
sit_sot_return_steps[n_sit_sot] = return_steps[i]
760-
sit_sot_inner_inputs.append(arg)
761-
sit_sot_rightOrder.append(i)
762-
n_sit_sot += 1
776+
if not isinstance(arg.type, RandomType):
777+
warnings.warn(
778+
(
779+
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
780+
"Only the last value will be returned, not the entire sequence."
781+
),
782+
UserWarning,
783+
)
784+
untraced_sit_sot_scan_inputs.append(actual_arg)
785+
untraced_sit_sot_inner_inputs.append(arg)
786+
n_untraced_sit_sot_outs += 1
787+
untraced_sit_sot_rightOrder.append(i)
763788

764789
elif init_out.get("taps", None):
765790
if np.any(np.array(init_out.get("taps", [])) > 0):
@@ -812,9 +837,10 @@ def wrap_into_list(x):
812837
# a map); in that case we do not have to do anything ..
813838

814839
# Re-order args
815-
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1
816-
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1
817-
n_elems = np.max([max_mit_sot, max_sit_sot])
840+
max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
841+
max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
842+
max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
843+
n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
818844
_ordered_args = [[] for x in range(n_elems)]
819845
offset = 0
820846
for idx in range(n_mit_sot):
@@ -835,9 +861,12 @@ def wrap_into_list(x):
835861
else:
836862
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
837863

838-
ordered_args = []
839-
for ls in _ordered_args:
840-
ordered_args += ls
864+
for idx in range(n_untraced_sit_sot_outs):
865+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
866+
untraced_sit_sot_inner_inputs[idx]
867+
]
868+
869+
ordered_args = list(chain.from_iterable(_ordered_args))
841870
if n_fixed_steps in (1, -1):
842871
args = inner_slices + ordered_args + non_seqs
843872

@@ -855,6 +884,11 @@ def wrap_into_list(x):
855884
raw_inner_outputs = fn(*args)
856885

857886
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
887+
if updates:
888+
warnings.warn(
889+
"Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs.",
890+
DeprecationWarning, # Only meant for developers for now, not users. Switch to FutureWarning later, before removing.
891+
)
858892
if condition is not None:
859893
as_while = True
860894
else:
@@ -957,18 +991,18 @@ def wrap_into_list(x):
957991
if "taps" in out and out["taps"] != [-1]:
958992
mit_sot_inner_outputs.append(outputs[i])
959993

960-
# Step 5.2 Outputs with tap equal to -1
994+
# Step 5.2 Outputs with tap equal to -1 (traced and untraced)
961995
for i, out in enumerate(outs_info):
962996
if "taps" in out and out["taps"] == [-1]:
963-
sit_sot_inner_outputs.append(outputs[i])
997+
if isinstance(out["initial"].type, HasShape):
998+
sit_sot_inner_outputs.append(outputs[i])
999+
else:
1000+
untraced_sit_sot_inner_outputs.append(outputs[i])
9641001

9651002
# Step 5.3 Outputs that correspond to update rules of shared variables
966-
inner_replacements = {}
967-
n_shared_outs = 0
968-
shared_scan_inputs = []
969-
shared_inner_inputs = []
970-
shared_inner_outputs = []
1003+
# This whole special logic for shared variables is deprecated
9711004
sit_sot_shared = []
1005+
inner_replacements = {}
9721006
no_update_shared_inputs = []
9731007
for input in dummy_inputs:
9741008
if not isinstance(input.variable, SharedVariable):
@@ -1021,10 +1055,10 @@ def wrap_into_list(x):
10211055
sit_sot_shared.append(input.variable)
10221056

10231057
else:
1024-
shared_inner_inputs.append(new_var)
1025-
shared_scan_inputs.append(input.variable)
1026-
shared_inner_outputs.append(input.update)
1027-
n_shared_outs += 1
1058+
untraced_sit_sot_inner_inputs.append(new_var)
1059+
untraced_sit_sot_scan_inputs.append(input.variable)
1060+
untraced_sit_sot_inner_outputs.append(input.update)
1061+
n_untraced_sit_sot_outs += 1
10281062
else:
10291063
no_update_shared_inputs.append(input)
10301064

@@ -1092,7 +1126,7 @@ def wrap_into_list(x):
10921126
+ mit_mot_inner_inputs
10931127
+ mit_sot_inner_inputs
10941128
+ sit_sot_inner_inputs
1095-
+ shared_inner_inputs
1129+
+ untraced_sit_sot_inner_inputs
10961130
+ other_shared_inner_args
10971131
+ other_inner_args
10981132
)
@@ -1102,7 +1136,7 @@ def wrap_into_list(x):
11021136
+ mit_sot_inner_outputs
11031137
+ sit_sot_inner_outputs
11041138
+ nit_sot_inner_outputs
1105-
+ shared_inner_outputs
1139+
+ untraced_sit_sot_inner_outputs
11061140
)
11071141
if condition is not None:
11081142
inner_outs.append(condition)
@@ -1122,7 +1156,7 @@ def wrap_into_list(x):
11221156
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
11231157
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
11241158
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
1125-
n_shared_outs=n_shared_outs,
1159+
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
11261160
n_nit_sot=n_nit_sot,
11271161
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
11281162
as_while=as_while,
@@ -1148,7 +1182,7 @@ def wrap_into_list(x):
11481182
+ mit_mot_scan_inputs
11491183
+ mit_sot_scan_inputs
11501184
+ sit_sot_scan_inputs
1151-
+ shared_scan_inputs
1185+
+ untraced_sit_sot_scan_inputs
11521186
+ [actual_n_steps for x in range(n_nit_sot)]
11531187
+ other_shared_scan_args
11541188
+ other_scan_args
@@ -1206,13 +1240,28 @@ def remove_dimensions(outs, steps_return, offsets=None):
12061240
)
12071241

12081242
offset += n_nit_sot
1209-
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
1210-
update_map[shared_scan_inputs[idx]] = update_rule
12111243

1212-
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1244+
# Legacy support for explicit untraced sit_sot and those built with update dictionary
1245+
# Switch to n_untraced_sit_sot_outs after deprecation period
1246+
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
1247+
untraced_sit_sot_outs = scan_outs[
1248+
offset : offset + n_explicit_untraced_sit_sot_outs
1249+
]
1250+
1251+
# Legacy support: map shared outputs to their updates
1252+
offset += n_explicit_untraced_sit_sot_outs
1253+
for idx, update_rule in enumerate(scan_outs[offset:]):
1254+
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
1255+
1256+
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
12131257
# Step 10. I need to reorder the outputs to be in the order expected by
12141258
# the user
1215-
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1259+
rightOrder = (
1260+
mit_sot_rightOrder
1261+
+ sit_sot_rightOrder
1262+
+ untraced_sit_sot_rightOrder
1263+
+ nit_sot_rightOrder
1264+
)
12161265
scan_out_list = [None] * len(rightOrder)
12171266
for idx, pos in enumerate(rightOrder):
12181267
if pos >= 0:
@@ -1232,4 +1281,4 @@ def remove_dimensions(outs, steps_return, offsets=None):
12321281
elif len(scan_out_list) == 0:
12331282
scan_out_list = None
12341283

1235-
return (scan_out_list, update_map)
1284+
return scan_out_list, update_map

0 commit comments

Comments
 (0)