Skip to content

Commit 10e03fd

Browse files
committed
Allow non-shared untraced SIT-SOT
1 parent bc72ef3 commit 10e03fd

File tree

8 files changed

+362
-221
lines changed

8 files changed

+362
-221
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,23 @@ def scan(*outer_inputs):
6060
mit_mot_init,
6161
mit_sot_init,
6262
sit_sot_init,
63-
op.outer_shared(outer_inputs),
63+
op.outer_untraced_sit_sot(outer_inputs),
6464
op.outer_non_seqs(outer_inputs),
6565
) # JAX `init`
6666

6767
def jax_args_to_inner_func_args(carry, x):
6868
"""Convert JAX scan arguments into format expected by scan_inner_func.
6969
70-
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs)
70+
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs)
7171
"""
7272

73-
# `carry` contains all inner taps, shared terms, and non_seqs
73+
# `carry` contains all inner taps and non_seqs
7474
(
7575
i,
7676
inner_mit_mot,
7777
inner_mit_sot,
7878
inner_sit_sot,
79-
inner_shared,
79+
inner_untraced_sit_sot,
8080
inner_non_seqs,
8181
) = carry
8282

@@ -108,7 +108,7 @@ def jax_args_to_inner_func_args(carry, x):
108108
*mit_mot_flatten,
109109
*mit_sot_flatten,
110110
*inner_sit_sot,
111-
*inner_shared,
111+
*inner_untraced_sit_sot,
112112
*inner_non_seqs,
113113
)
114114

@@ -118,22 +118,22 @@ def inner_func_outs_to_jax_outs(
118118
):
119119
"""Convert inner_scan_func outputs into format expected by JAX scan.
120120
121-
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
121+
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
122122
"""
123123
(
124124
i,
125125
old_mit_mot,
126126
old_mit_sot,
127127
_old_sit_sot,
128-
_old_shared,
128+
_old_untraced_sit_sot,
129129
inner_non_seqs,
130130
) = old_carry
131131

132132
new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs)
133133
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
134134
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
135135
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
136-
new_shared = op.inner_shared_outs(inner_scan_outs)
136+
new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs)
137137

138138
# New carry for next step
139139
# Update MIT-MOT buffer at positions indicated by output taps
@@ -150,14 +150,14 @@ def inner_func_outs_to_jax_outs(
150150
old_mit_sot, new_mit_sot_vals, strict=True
151151
)
152152
]
153-
# For SIT-SOT, and shared just pass along the new value
153+
# For SIT-SOT just pass along the new value
154154
# Non-sequences remain unchanged
155155
new_carry = (
156156
i + 1,
157157
new_mit_mot,
158158
new_mit_sot,
159159
new_sit_sot,
160-
new_shared,
160+
new_untraced_sit_sot,
161161
inner_non_seqs,
162162
)
163163

@@ -183,7 +183,7 @@ def jax_inner_func(carry, x):
183183
final_mit_mot,
184184
_final_mit_sot,
185185
_final_sit_sot,
186-
final_shared,
186+
final_untraced_sit_sot,
187187
_final_non_seqs,
188188
),
189189
traces,
@@ -238,7 +238,7 @@ def get_partial_traces(traces):
238238
scan_outs_final = [
239239
*final_mit_mot,
240240
*get_partial_traces(traces),
241-
*final_shared,
241+
*final_untraced_sit_sot,
242242
]
243243

244244
if len(scan_outs_final) == 1:

pytensor/link/numba/dispatch/scan.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
108108
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
109109
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
110110
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
111-
outer_in_shared_names = op.outer_shared(outer_in_names)
111+
outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names)
112112
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
113113

114114
# These are all the outer-input names that have produce outputs/have output
115115
# taps (i.e. they have inner-outputs and corresponding outer-outputs).
116116
# Outer-outputs are ordered as follows:
117-
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
117+
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs
118118
outer_in_outtap_names = (
119119
outer_in_mit_mot_names
120120
+ outer_in_mit_sot_names
121121
+ outer_in_sit_sot_names
122122
+ outer_in_nit_sot_names
123-
+ outer_in_shared_names
123+
+ outer_in_untraced_sit_sot_names
124124
)
125125

126126
# We create distinct variables for/references to the storage arrays for
@@ -138,16 +138,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
138138
for outer_in_name in outer_in_nit_sot_names:
139139
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
140140

141-
for outer_in_name in outer_in_shared_names:
142-
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"
141+
for outer_in_name in outer_in_untraced_sit_sot_names:
142+
outer_in_to_storage_name[outer_in_name] = (
143+
f"{outer_in_name}_untraced_sit_sot_storage"
144+
)
143145

144146
outer_output_names = list(outer_in_to_storage_name.values())
145147
assert len(outer_output_names) == len(node.outputs)
146148

147149
# Construct the inner-input expressions (e.g. indexed storage expressions)
148150
# Inner-inputs are ordered as follows:
149151
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
150-
# shared-inputs + non-sequences.
152+
# untraced-sit-sot-inputs + non-sequences.
151153
temp_scalar_storage_alloc_stmts: list[str] = []
152154
inner_in_exprs_scalar: list[str] = []
153155
inner_in_exprs: list[str] = []
@@ -204,11 +206,9 @@ def add_inner_in_expr(
204206

205207
# Inner-outputs consist of:
206208
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
207-
# shared-outputs [+ while-condition]
209+
# untraced-sit-sot-outputs [+ while-condition]
208210
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
209211

210-
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
211-
212212
# The assignment statements that copy inner-outputs into the outer-outputs
213213
# storage
214214
inner_out_to_outer_in_stmts: list[str] = []

0 commit comments

Comments
 (0)