@@ -163,6 +163,26 @@ def isNaN_or_Inf_or_None(x):
163163 return isNone or isNaN or isInf or isStr
164164
165165
166+ def _manage_output_api_change (outputs , updates , return_updates ):
167+ if return_updates :
168+ warnings .warn (
169+ "Scan return signature well change. Updates dict will not be returned, only the first argument. "
170+ "Pass `return_updates=False` to conform to the new API and avoid this warning" ,
171+ DeprecationWarning ,
172+ # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
173+ stacklevel = 2 ,
174+ )
175+ else :
176+ if updates :
177+ raise ValueError (
178+ f"return_updates=False but Scan produced updates { updates } ."
179+ "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
180+ )
181+ return outputs
182+
183+ return outputs , updates
184+
185+
166186def scan (
167187 fn ,
168188 sequences = None ,
@@ -177,6 +197,7 @@ def scan(
177197 allow_gc = None ,
178198 strict = False ,
179199 return_list = False ,
200+ return_updates : bool = True ,
180201):
181202 r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
182203
@@ -873,6 +894,11 @@ def wrap_into_list(x):
873894 raw_inner_outputs = fn (* args )
874895
875896 condition , outputs , updates = get_updates_and_outputs (raw_inner_outputs )
897+ if updates :
898+ warnings .warn (
899+ "Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs." ,
900+ DeprecationWarning , # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
901+ )
876902 if condition is not None :
877903 as_while = True
878904 else :
@@ -895,7 +921,7 @@ def wrap_into_list(x):
895921 if not return_list and len (outputs ) == 1 :
896922 outputs = outputs [0 ]
897923
898- return (outputs , updates )
924+ return _manage_output_api_change (outputs , updates , return_updates )
899925
900926 ##
901927 # Step 4. Compile the dummy function
@@ -914,6 +940,8 @@ def wrap_into_list(x):
914940 fake_outputs = clone_replace (
915941 outputs , replace = dict (zip (non_seqs , fake_nonseqs , strict = True ))
916942 )
943+ # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
944+ # to find implicit inputs in a way that reduces the size of the inner function
917945 known_inputs = [* args , * fake_nonseqs ]
918946 extra_inputs = [
919947 x for x in explicit_graph_inputs (fake_outputs ) if x not in known_inputs
@@ -1206,12 +1234,14 @@ def remove_dimensions(outs, offsets=None):
12061234
12071235 offset += n_nit_sot
12081236
1209- # Support for explicit untraced sit_sot
1237+ # Legacy support for explicit untraced sit_sot and those built with update dictionary
1238+ # Switch to n_untraced_sit_sot_outs after deprecation period
12101239 n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
12111240 untraced_sit_sot_outs = scan_outs [
12121241 offset : offset + n_explicit_untraced_sit_sot_outs
12131242 ]
12141243
1244+ # Legacy support: map shared outputs to their updates
12151245 offset += n_explicit_untraced_sit_sot_outs
12161246 for idx , update_rule in enumerate (scan_outs [offset :]):
12171247 update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
@@ -1244,4 +1274,4 @@ def remove_dimensions(outs, offsets=None):
12441274 elif len (scan_out_list ) == 0 :
12451275 scan_out_list = None
12461276
1247- return scan_out_list , update_map
1277+ return _manage_output_api_change ( scan_out_list , update_map , return_updates )
0 commit comments