Skip to content

Commit 64a08e7

Browse files
committed
Allow non-shared untraced SIT-SOT
1 parent ceb0c8c commit 64a08e7

File tree

8 files changed

+335
-199
lines changed

8 files changed

+335
-199
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def scan(*outer_inputs):
6060
mit_mot_init,
6161
mit_sot_init,
6262
sit_sot_init,
63-
op.outer_shared(outer_inputs),
63+
op.outer_untraced_sit_sot(outer_inputs),
6464
op.outer_non_seqs(outer_inputs),
6565
) # JAX `init`
6666

@@ -118,7 +118,7 @@ def inner_func_outs_to_jax_outs(
118118
):
119119
"""Convert inner_scan_func outputs into format expected by JAX scan.
120120
121-
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
121+
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
122122
"""
123123
(
124124
i,
@@ -133,7 +133,7 @@ def inner_func_outs_to_jax_outs(
133133
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
134134
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
135135
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
136-
new_shared = op.inner_shared_outs(inner_scan_outs)
136+
new_shared = op.inner_untraced_sit_sot_outs(inner_scan_outs)
137137

138138
# New carry for next step
139139
# Update MIT-MOT buffer at positions indicated by output taps

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
108108
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
109109
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
110110
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
111-
outer_in_shared_names = op.outer_shared(outer_in_names)
111+
outer_in_shared_names = op.outer_untraced_sit_sot(outer_in_names)
112112
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
113113

114114
# These are all the outer-input names that produce outputs/have output
@@ -204,11 +204,9 @@ def add_inner_in_expr(
204204

205205
# Inner-outputs consist of:
206206
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
207-
# shared-outputs [+ while-condition]
207+
# untraced-sit-sot-outputs [+ while-condition]
208208
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
209209

210-
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
211-
212210
# The assignment statements that copy inner-outputs into the outer-outputs
213211
# storage
214212
inner_out_to_outer_in_stmts: list[str] = []

pytensor/scan/basic.py

Lines changed: 85 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.graph.op import get_test_value
1212
from pytensor.graph.replace import clone_replace
1313
from pytensor.graph.traversal import explicit_graph_inputs
14+
from pytensor.graph.type import HasShape
1415
from pytensor.graph.utils import MissingInputError, TestValueError
1516
from pytensor.scan.op import Scan, ScanInfo
1617
from pytensor.scan.utils import expand_empty, safe_new, until
@@ -706,6 +707,12 @@ def wrap_into_list(x):
706707
sit_sot_inner_outputs = []
707708
sit_sot_rightOrder = []
708709

710+
n_untraced_sit_sot_outs = 0
711+
untraced_sit_sot_scan_inputs = []
712+
untraced_sit_sot_inner_inputs = []
713+
untraced_sit_sot_inner_outputs = []
714+
untraced_sit_sot_rightOrder = []
715+
709716
# go through outputs picking up time slices as needed
710717
for i, init_out in enumerate(outs_info):
711718
# Note that our convention dictates that if an output uses
@@ -741,17 +748,35 @@ def wrap_into_list(x):
741748
# We need now to allocate space for storing the output and copy
742749
# the initial state over. We do this using the expand function
743750
# defined in scan utils
744-
sit_sot_scan_inputs.append(
745-
expand_empty(
746-
shape_padleft(actual_arg),
747-
actual_n_steps,
751+
if isinstance(actual_arg.type, HasShape):
752+
sit_sot_scan_inputs.append(
753+
expand_empty(
754+
shape_padleft(actual_arg),
755+
actual_n_steps,
756+
)
748757
)
749-
)
758+
sit_sot_inner_slices.append(actual_arg)
759+
760+
sit_sot_inner_inputs.append(arg)
761+
sit_sot_rightOrder.append(i)
762+
n_sit_sot += 1
763+
else:
764+
# Assume variables without shape cannot be stacked (e.g., RNG variables)
765+
# Because this is new, issue a warning to inform the user, except for RNGs, which were the main reason for this feature
766+
from pytensor.tensor.random.type import RandomType
750767

751-
sit_sot_inner_slices.append(actual_arg)
752-
sit_sot_inner_inputs.append(arg)
753-
sit_sot_rightOrder.append(i)
754-
n_sit_sot += 1
768+
if not isinstance(arg.type, RandomType):
769+
warnings.warn(
770+
(
771+
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
772+
"Only the last value will be returned, not the entire sequence."
773+
),
774+
UserWarning,
775+
)
776+
untraced_sit_sot_scan_inputs.append(actual_arg)
777+
untraced_sit_sot_inner_inputs.append(arg)
778+
n_untraced_sit_sot_outs += 1
779+
untraced_sit_sot_rightOrder.append(i)
755780

756781
elif init_out.get("taps", None):
757782
if np.any(np.array(init_out.get("taps", [])) > 0):
@@ -802,9 +827,10 @@ def wrap_into_list(x):
802827
# a map); in that case we do not have to do anything ..
803828

804829
# Re-order args
805-
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1
806-
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1
807-
n_elems = np.max([max_mit_sot, max_sit_sot])
830+
max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
831+
max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
832+
max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
833+
n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
808834
_ordered_args = [[] for x in range(n_elems)]
809835
offset = 0
810836
for idx in range(n_mit_sot):
@@ -825,6 +851,11 @@ def wrap_into_list(x):
825851
else:
826852
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
827853

854+
for idx in range(n_untraced_sit_sot_outs):
855+
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
856+
untraced_sit_sot_inner_inputs[idx]
857+
]
858+
828859
ordered_args = list(chain.from_iterable(_ordered_args))
829860
if single_step_requested:
830861
args = inner_slices + ordered_args + non_seqs
@@ -842,6 +873,11 @@ def wrap_into_list(x):
842873
raw_inner_outputs = fn(*args)
843874

844875
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
876+
if updates:
877+
warnings.warn(
878+
"Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs.",
879+
DeprecationWarning, # Only meant for developers for now, not users. Switch to FutureWarning later, before removing.
880+
)
845881
if condition is not None:
846882
as_while = True
847883
else:
@@ -883,6 +919,8 @@ def wrap_into_list(x):
883919
fake_outputs = clone_replace(
884920
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
885921
)
922+
# TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
923+
# to find implicit inputs in a way that reduces the size of the inner function
886924
known_inputs = [*args, *fake_nonseqs]
887925
extra_inputs = [
888926
x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -939,18 +977,19 @@ def wrap_into_list(x):
939977
if "taps" in out and out["taps"] != [-1]:
940978
mit_sot_inner_outputs.append(outputs[i])
941979

942-
# Step 5.2 Outputs with tap equal to -1
980+
# Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943981
for i, out in enumerate(outs_info):
944982
if "taps" in out and out["taps"] == [-1]:
945-
sit_sot_inner_outputs.append(outputs[i])
983+
output = outputs[i]
984+
if isinstance(output.type, HasShape):
985+
sit_sot_inner_outputs.append(output)
986+
else:
987+
untraced_sit_sot_inner_outputs.append(output)
946988

947989
# Step 5.3 Outputs that correspond to update rules of shared variables
948-
inner_replacements = {}
949-
n_shared_outs = 0
950-
shared_scan_inputs = []
951-
shared_inner_inputs = []
952-
shared_inner_outputs = []
990+
# This whole special logic for shared variables is deprecated
953991
sit_sot_shared = []
992+
inner_replacements = {}
954993
no_update_shared_inputs = []
955994
for input in dummy_inputs:
956995
if not isinstance(input.variable, SharedVariable):
@@ -1003,10 +1042,10 @@ def wrap_into_list(x):
10031042
sit_sot_shared.append(input.variable)
10041043

10051044
else:
1006-
shared_inner_inputs.append(new_var)
1007-
shared_scan_inputs.append(input.variable)
1008-
shared_inner_outputs.append(input.update)
1009-
n_shared_outs += 1
1045+
untraced_sit_sot_inner_inputs.append(new_var)
1046+
untraced_sit_sot_scan_inputs.append(input.variable)
1047+
untraced_sit_sot_inner_outputs.append(input.update)
1048+
n_untraced_sit_sot_outs += 1
10101049
else:
10111050
no_update_shared_inputs.append(input)
10121051

@@ -1071,7 +1110,7 @@ def wrap_into_list(x):
10711110
+ mit_mot_inner_inputs
10721111
+ mit_sot_inner_inputs
10731112
+ sit_sot_inner_inputs
1074-
+ shared_inner_inputs
1113+
+ untraced_sit_sot_inner_inputs
10751114
+ other_shared_inner_args
10761115
+ other_inner_args
10771116
)
@@ -1081,7 +1120,7 @@ def wrap_into_list(x):
10811120
+ mit_sot_inner_outputs
10821121
+ sit_sot_inner_outputs
10831122
+ nit_sot_inner_outputs
1084-
+ shared_inner_outputs
1123+
+ untraced_sit_sot_inner_outputs
10851124
)
10861125
if condition is not None:
10871126
inner_outs.append(condition)
@@ -1101,7 +1140,7 @@ def wrap_into_list(x):
11011140
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
11021141
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
11031142
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
1104-
n_shared_outs=n_shared_outs,
1143+
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
11051144
n_nit_sot=n_nit_sot,
11061145
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
11071146
as_while=as_while,
@@ -1127,7 +1166,7 @@ def wrap_into_list(x):
11271166
+ mit_mot_scan_inputs
11281167
+ mit_sot_scan_inputs
11291168
+ sit_sot_scan_inputs
1130-
+ shared_scan_inputs
1169+
+ untraced_sit_sot_scan_inputs
11311170
+ [actual_n_steps for x in range(n_nit_sot)]
11321171
+ other_shared_scan_args
11331172
+ other_scan_args
@@ -1173,13 +1212,28 @@ def remove_dimensions(outs, offsets=None):
11731212
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
11741213

11751214
offset += n_nit_sot
1176-
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
1177-
update_map[shared_scan_inputs[idx]] = update_rule
11781215

1179-
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1216+
# Legacy support: separate the explicit untraced sit_sot (declared via outputs_info) from those built from the updates dictionary
1217+
# Switch to n_untraced_sit_sot_outs after deprecation period
1218+
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
1219+
untraced_sit_sot_outs = scan_outs[
1220+
offset : offset + n_explicit_untraced_sit_sot_outs
1221+
]
1222+
1223+
# Legacy support: map shared outputs to their updates
1224+
offset += n_explicit_untraced_sit_sot_outs
1225+
for idx, update_rule in enumerate(scan_outs[offset:]):
1226+
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
1227+
1228+
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801229
# Step 10. I need to reorder the outputs to be in the order expected by
11811230
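# the user
# NOTE(review): `_scan_out_list` is built as mit_sot + sit_sot + nit_sot + untraced_sit_sot,
# while `rightOrder` below is mit_sot + sit_sot + untraced_sit_sot + nit_sot. The two are
# matched by index, so the category order likely needs to agree - TODO confirm.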
1182-
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1231+
rightOrder = (
1232+
mit_sot_rightOrder
1233+
+ sit_sot_rightOrder
1234+
+ untraced_sit_sot_rightOrder
1235+
+ nit_sot_rightOrder
1236+
)
11831237
scan_out_list = [None] * len(rightOrder)
11841238
for idx, pos in enumerate(rightOrder):
11851239
if pos >= 0:

0 commit comments

Comments
 (0)