@@ -60,23 +60,23 @@ def scan(*outer_inputs):
6060 mit_mot_init ,
6161 mit_sot_init ,
6262 sit_sot_init ,
63- op .outer_shared (outer_inputs ),
63+ op .outer_untraced_sit_sot (outer_inputs ),
6464 op .outer_non_seqs (outer_inputs ),
6565 ) # JAX `init`
6666
6767 def jax_args_to_inner_func_args (carry , x ):
6868 """Convert JAX scan arguments into format expected by scan_inner_func.
6969
70- scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared , non_seqs)
70+ scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT , non_seqs)
7171 """
7272
73- # `carry` contains all inner taps, shared terms, and non_seqs
73+ # `carry` contains all inner taps and non_seqs
7474 (
7575 i ,
7676 inner_mit_mot ,
7777 inner_mit_sot ,
7878 inner_sit_sot ,
79- inner_shared ,
79+ inner_untraced_sit_sot ,
8080 inner_non_seqs ,
8181 ) = carry
8282
@@ -108,7 +108,7 @@ def jax_args_to_inner_func_args(carry, x):
108108 * mit_mot_flatten ,
109109 * mit_sot_flatten ,
110110 * inner_sit_sot ,
111- * inner_shared ,
111+ * inner_untraced_sit_sot ,
112112 * inner_non_seqs ,
113113 )
114114
@@ -118,22 +118,22 @@ def inner_func_outs_to_jax_outs(
118118 ):
119119 """Convert inner_scan_func outputs into format expected by JAX scan.
120120
121- old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs ) -> (new_carry, ys)
121+ old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs ) -> (new_carry, ys)
122122 """
123123 (
124124 i ,
125125 old_mit_mot ,
126126 old_mit_sot ,
127127 _old_sit_sot ,
128- _old_shared ,
128+ _old_untraced_sit_sot ,
129129 inner_non_seqs ,
130130 ) = old_carry
131131
132132 new_mit_mot_vals = op .inner_mitmot_outs_grouped (inner_scan_outs )
133133 new_mit_sot_vals = op .inner_mitsot_outs (inner_scan_outs )
134134 new_sit_sot = op .inner_sitsot_outs (inner_scan_outs )
135135 new_nit_sot = op .inner_nitsot_outs (inner_scan_outs )
136- new_shared = op .inner_shared_outs (inner_scan_outs )
136+ new_untraced_sit_sot = op .inner_untraced_sit_sot_outs (inner_scan_outs )
137137
138138 # New carry for next step
139139 # Update MIT-MOT buffer at positions indicated by output taps
@@ -150,14 +150,14 @@ def inner_func_outs_to_jax_outs(
150150 old_mit_sot , new_mit_sot_vals , strict = True
151151 )
152152 ]
153- # For SIT-SOT, and shared just pass along the new value
153+ # For SIT-SOT just pass along the new value
154154 # Non-sequences remain unchanged
155155 new_carry = (
156156 i + 1 ,
157157 new_mit_mot ,
158158 new_mit_sot ,
159159 new_sit_sot ,
160- new_shared ,
160+ new_untraced_sit_sot ,
161161 inner_non_seqs ,
162162 )
163163
@@ -183,7 +183,7 @@ def jax_inner_func(carry, x):
183183 final_mit_mot ,
184184 _final_mit_sot ,
185185 _final_sit_sot ,
186- final_shared ,
186+ final_untraced_sit_sot ,
187187 _final_non_seqs ,
188188 ),
189189 traces ,
@@ -238,7 +238,7 @@ def get_partial_traces(traces):
238238 scan_outs_final = [
239239 * final_mit_mot ,
240240 * get_partial_traces (traces ),
241- * final_shared ,
241+ * final_untraced_sit_sot ,
242242 ]
243243
244244 if len (scan_outs_final ) == 1 :
0 commit comments