@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
168168 return isNone or isNaN or isInf or isStr
169169
170170
171+ def _manage_output_api_change (outputs , updates , return_updates ):
172+ if return_updates :
173+ warnings .warn (
174+ "Scan return signature will change. Updates dict will not be returned, only the first argument. "
175+ "Pass `return_updates=False` to conform to the new API and avoid this warning" ,
176+ DeprecationWarning ,
177+ # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
178+ stacklevel = 2 ,
179+ )
180+ else :
181+ if updates :
182+ raise ValueError (
183+ f"return_updates=False but Scan produced updates { updates } ."
184+ "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
185+ )
186+ return outputs
187+
188+ return outputs , updates
189+
190+
171191def scan (
172192 fn ,
173193 sequences = None ,
@@ -182,6 +202,7 @@ def scan(
182202 allow_gc = None ,
183203 strict = False ,
184204 return_list = False ,
205+ return_updates : bool = True ,
185206):
186207 r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
187208
@@ -900,7 +921,7 @@ def wrap_into_list(x):
900921 if not return_list and len (outputs ) == 1 :
901922 outputs = outputs [0 ]
902923
903- return (outputs , updates )
924+ return _manage_output_api_change (outputs , updates , return_updates )
904925
905926 ##
906927 # Step 4. Compile the dummy function
@@ -919,6 +940,8 @@ def wrap_into_list(x):
919940 fake_outputs = clone_replace (
920941 outputs , replace = dict (zip (non_seqs , fake_nonseqs , strict = True ))
921942 )
943+ # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
944+ # to find implicit inputs in a way that reduces the size of the inner function
922945 known_inputs = [* args , * fake_nonseqs ]
923946 extra_inputs = [
924947 x for x in explicit_graph_inputs (fake_outputs ) if x not in known_inputs
@@ -1074,7 +1097,7 @@ def wrap_into_list(x):
10741097 if not isinstance (arg , SharedVariable | Constant )
10751098 ]
10761099
1077- inner_replacements .update (dict (zip (other_scan_args , other_inner_args , strict = True )))
1100+ inner_replacements .update (dict (zip (other_scan_args , other_inner_args , strict = True ))) # type: ignore[arg-type]
10781101
10791102 if strict :
10801103 non_seqs_set = set (non_sequences if non_sequences is not None else [])
@@ -1123,7 +1146,7 @@ def wrap_into_list(x):
11231146 if condition is not None :
11241147 inner_outs .append (condition )
11251148
1126- new_outs = clone_replace (inner_outs , replace = inner_replacements )
1149+ new_outs = clone_replace (inner_outs , replace = inner_replacements ) # type: ignore[arg-type]
11271150
11281151 ##
11291152 # Step 7. Create the Scan Op
@@ -1211,12 +1234,14 @@ def remove_dimensions(outs, offsets=None):
12111234
12121235 offset += n_nit_sot
12131236
1214- # Support for explicit untraced sit_sot
1237+ # Legacy support for explicit untraced sit_sot and those built with update dictionary
1238+ # Switch to n_untraced_sit_sot_outs after deprecation period
12151239 n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
12161240 untraced_sit_sot_outs = scan_outs [
12171241 offset : offset + n_explicit_untraced_sit_sot_outs
12181242 ]
12191243
1244+ # Legacy support: map shared outputs to their updates
12201245 offset += n_explicit_untraced_sit_sot_outs
12211246 for idx , update_rule in enumerate (scan_outs [offset :]):
12221247 update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
@@ -1245,8 +1270,8 @@ def remove_dimensions(outs, offsets=None):
12451270 update_map [sit_sot_shared [abs (pos ) - 1 ]] = _scan_out_list [idx ][- 1 ]
12461271 scan_out_list = [x for x in scan_out_list if x is not None ]
12471272 if not return_list and len (scan_out_list ) == 1 :
1248- scan_out_list = scan_out_list [0 ]
1273+ scan_out_list = scan_out_list [0 ] # type: ignore[assignment]
12491274 elif len (scan_out_list ) == 0 :
1250- scan_out_list = None
1275+ scan_out_list = None # type: ignore[assignment]
12511276
1252- return scan_out_list , update_map
1277+ return _manage_output_api_change ( scan_out_list , update_map , return_updates )
0 commit comments