@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
168168 return isNone or isNaN or isInf or isStr
169169
170170
171+ def _manage_output_api_change (outputs , updates , return_updates ):
172+ if return_updates :
173+ warnings .warn (
174+ "Scan return signature well change. Updates dict will not be returned, only the first argument. "
175+ "Pass `return_updates=False` to conform to the new API and avoid this warning" ,
176+ DeprecationWarning ,
177+ # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
178+ stacklevel = 2 ,
179+ )
180+ else :
181+ if updates :
182+ raise ValueError (
183+ f"return_updates=False but Scan produced updates { updates } ."
184+ "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
185+ )
186+ return outputs
187+
188+ return outputs , updates
189+
190+
171191def scan (
172192 fn ,
173193 sequences = None ,
@@ -182,6 +202,7 @@ def scan(
182202 allow_gc = None ,
183203 strict = False ,
184204 return_list = False ,
205+ return_updates : bool = True ,
185206):
186207 r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
187208
@@ -878,6 +899,11 @@ def wrap_into_list(x):
878899 raw_inner_outputs = fn (* args )
879900
880901 condition , outputs , updates = get_updates_and_outputs (raw_inner_outputs )
902+ if updates :
903+ warnings .warn (
904+ "Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs." ,
905+ DeprecationWarning , # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
906+ )
881907 if condition is not None :
882908 as_while = True
883909 else :
@@ -900,7 +926,7 @@ def wrap_into_list(x):
900926 if not return_list and len (outputs ) == 1 :
901927 outputs = outputs [0 ]
902928
903- return (outputs , updates )
929+ return _manage_output_api_change (outputs , updates , return_updates )
904930
905931 ##
906932 # Step 4. Compile the dummy function
@@ -919,6 +945,8 @@ def wrap_into_list(x):
919945 fake_outputs = clone_replace (
920946 outputs , replace = dict (zip (non_seqs , fake_nonseqs , strict = True ))
921947 )
948+ # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
949+ # to find implicit inputs in a way that reduces the size of the inner function
922950 known_inputs = [* args , * fake_nonseqs ]
923951 extra_inputs = [
924952 x for x in explicit_graph_inputs (fake_outputs ) if x not in known_inputs
@@ -1074,7 +1102,7 @@ def wrap_into_list(x):
10741102 if not isinstance (arg , SharedVariable | Constant )
10751103 ]
10761104
1077- inner_replacements .update (dict (zip (other_scan_args , other_inner_args , strict = True )))
1105+ inner_replacements .update (dict (zip (other_scan_args , other_inner_args , strict = True ))) # type: ignore[arg-type]
10781106
10791107 if strict :
10801108 non_seqs_set = set (non_sequences if non_sequences is not None else [])
@@ -1123,7 +1151,7 @@ def wrap_into_list(x):
11231151 if condition is not None :
11241152 inner_outs .append (condition )
11251153
1126- new_outs = clone_replace (inner_outs , replace = inner_replacements )
1154+ new_outs = clone_replace (inner_outs , replace = inner_replacements ) # type: ignore[arg-type]
11271155
11281156 ##
11291157 # Step 7. Create the Scan Op
@@ -1211,12 +1239,14 @@ def remove_dimensions(outs, offsets=None):
12111239
12121240 offset += n_nit_sot
12131241
1214- # Support for explicit untraced sit_sot
1242+ # Legacy support for explicit untraced sit_sot and those built with update dictionary
1243+ # Switch to n_untraced_sit_sot_outs after deprecation period
12151244 n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
12161245 untraced_sit_sot_outs = scan_outs [
12171246 offset : offset + n_explicit_untraced_sit_sot_outs
12181247 ]
12191248
1249+ # Legacy support: map shared outputs to their updates
12201250 offset += n_explicit_untraced_sit_sot_outs
12211251 for idx , update_rule in enumerate (scan_outs [offset :]):
12221252 update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
@@ -1245,8 +1275,8 @@ def remove_dimensions(outs, offsets=None):
12451275 update_map [sit_sot_shared [abs (pos ) - 1 ]] = _scan_out_list [idx ][- 1 ]
12461276 scan_out_list = [x for x in scan_out_list if x is not None ]
12471277 if not return_list and len (scan_out_list ) == 1 :
1248- scan_out_list = scan_out_list [0 ]
1278+ scan_out_list = scan_out_list [0 ] # type: ignore[assignment]
12491279 elif len (scan_out_list ) == 0 :
1250- scan_out_list = None
1280+ scan_out_list = None # type: ignore[assignment]
12511281
1252- return scan_out_list , update_map
1282+ return _manage_output_api_change ( scan_out_list , update_map , return_updates )
0 commit comments