11import warnings
2+ from itertools import chain
23
34import numpy as np
45
1011from pytensor .graph .op import get_test_value
1112from pytensor .graph .replace import clone_replace
1213from pytensor .graph .traversal import graph_inputs
14+ from pytensor .graph .type import HasShape
1315from pytensor .graph .utils import MissingInputError , TestValueError
1416from pytensor .scan .op import Scan , ScanInfo
1517from pytensor .scan .utils import expand_empty , safe_new , until
@@ -712,6 +714,12 @@ def wrap_into_list(x):
712714 sit_sot_return_steps = {}
713715 sit_sot_rightOrder = []
714716
717+ n_untraced_sit_sot_outs = 0
718+ untraced_sit_sot_scan_inputs = []
719+ untraced_sit_sot_inner_inputs = []
720+ untraced_sit_sot_inner_outputs = []
721+ untraced_sit_sot_rightOrder = []
722+
715723 # go through outputs picking up time slices as needed
716724 for i , init_out in enumerate (outs_info ):
717725 # Note that our convention dictates that if an output uses
@@ -747,19 +755,36 @@ def wrap_into_list(x):
747755 # We need now to allocate space for storing the output and copy
748756 # the initial state over. We do this using the expand function
749757 # defined in scan utils
750- sit_sot_scan_inputs .append (
751- expand_empty (
752- shape_padleft (actual_arg ),
753- actual_n_steps ,
758+ if isinstance (actual_arg .type , HasShape ):
759+ sit_sot_scan_inputs .append (
760+ expand_empty (
761+ shape_padleft (actual_arg ),
762+ actual_n_steps ,
763+ )
754764 )
755- )
765+ sit_sot_inner_slices .append (actual_arg )
766+ if i in return_steps :
767+ sit_sot_return_steps [n_sit_sot ] = return_steps [i ]
768+ sit_sot_inner_inputs .append (arg )
769+ sit_sot_rightOrder .append (i )
770+ n_sit_sot += 1
771+ else :
772+ # Assume variables without shape cannot be stacked (e.g., RNG variables)
773+ # Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
774+ from pytensor .tensor .random .type import RandomType
756775
757- sit_sot_inner_slices .append (actual_arg )
758- if i in return_steps :
759- sit_sot_return_steps [n_sit_sot ] = return_steps [i ]
760- sit_sot_inner_inputs .append (arg )
761- sit_sot_rightOrder .append (i )
762- n_sit_sot += 1
776+ if not isinstance (arg .type , RandomType ):
777+ warnings .warn (
778+ (
779+ f"Output { actual_arg } (index { i } ) with type { actual_arg .type } will be treated as untraced variable in scan. "
780+ "Only the last value will be returned, not the entire sequence."
781+ ),
782+ UserWarning ,
783+ )
784+ untraced_sit_sot_scan_inputs .append (actual_arg )
785+ untraced_sit_sot_inner_inputs .append (arg )
786+ n_untraced_sit_sot_outs += 1
787+ untraced_sit_sot_rightOrder .append (i )
763788
764789 elif init_out .get ("taps" , None ):
765790 if np .any (np .array (init_out .get ("taps" , [])) > 0 ):
@@ -812,9 +837,10 @@ def wrap_into_list(x):
812837 # a map); in that case we do not have to do anything ..
813838
814839 # Re-order args
815- max_mit_sot = np .max ([- 1 , * mit_sot_rightOrder ]) + 1
816- max_sit_sot = np .max ([- 1 , * sit_sot_rightOrder ]) + 1
817- n_elems = np .max ([max_mit_sot , max_sit_sot ])
840+ max_mit_sot = max (mit_sot_rightOrder , default = - 1 ) + 1
841+ max_sit_sot = max (sit_sot_rightOrder , default = - 1 ) + 1
842+ max_untraced_sit_sot_outs = max (untraced_sit_sot_rightOrder , default = - 1 ) + 1
843+ n_elems = np .max ((max_mit_sot , max_sit_sot , max_untraced_sit_sot_outs ))
818844 _ordered_args = [[] for x in range (n_elems )]
819845 offset = 0
820846 for idx in range (n_mit_sot ):
@@ -835,9 +861,12 @@ def wrap_into_list(x):
835861 else :
836862 _ordered_args [sit_sot_rightOrder [idx ]] = [sit_sot_inner_inputs [idx ]]
837863
838- ordered_args = []
839- for ls in _ordered_args :
840- ordered_args += ls
864+ for idx in range (n_untraced_sit_sot_outs ):
865+ _ordered_args [untraced_sit_sot_rightOrder [idx ]] = [
866+ untraced_sit_sot_inner_inputs [idx ]
867+ ]
868+
869+ ordered_args = list (chain .from_iterable (_ordered_args ))
841870 if n_fixed_steps in (1 , - 1 ):
842871 args = inner_slices + ordered_args + non_seqs
843872
@@ -855,6 +884,11 @@ def wrap_into_list(x):
855884 raw_inner_outputs = fn (* args )
856885
857886 condition , outputs , updates = get_updates_and_outputs (raw_inner_outputs )
887+ if updates :
888+ warnings .warn (
889+ "Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs." ,
890+ DeprecationWarning , # Only meant for developers for now, not users. Switch to FutureWarning later, before removing.
891+ )
858892 if condition is not None :
859893 as_while = True
860894 else :
@@ -957,18 +991,18 @@ def wrap_into_list(x):
957991 if "taps" in out and out ["taps" ] != [- 1 ]:
958992 mit_sot_inner_outputs .append (outputs [i ])
959993
960- # Step 5.2 Outputs with tap equal to -1
994+ # Step 5.2 Outputs with tap equal to -1 (traced and untraced)
961995 for i , out in enumerate (outs_info ):
962996 if "taps" in out and out ["taps" ] == [- 1 ]:
963- sit_sot_inner_outputs .append (outputs [i ])
997+ if isinstance (out ["initial" ].type , HasShape ):
998+ sit_sot_inner_outputs .append (outputs [i ])
999+ else :
1000+ untraced_sit_sot_inner_outputs .append (outputs [i ])
9641001
9651002 # Step 5.3 Outputs that correspond to update rules of shared variables
966- inner_replacements = {}
967- n_shared_outs = 0
968- shared_scan_inputs = []
969- shared_inner_inputs = []
970- shared_inner_outputs = []
1003+ # This whole special logic for shared variables is deprecated
9711004 sit_sot_shared = []
1005+ inner_replacements = {}
9721006 no_update_shared_inputs = []
9731007 for input in dummy_inputs :
9741008 if not isinstance (input .variable , SharedVariable ):
@@ -1021,10 +1055,10 @@ def wrap_into_list(x):
10211055 sit_sot_shared .append (input .variable )
10221056
10231057 else :
1024- shared_inner_inputs .append (new_var )
1025- shared_scan_inputs .append (input .variable )
1026- shared_inner_outputs .append (input .update )
1027- n_shared_outs += 1
1058+ untraced_sit_sot_inner_inputs .append (new_var )
1059+ untraced_sit_sot_scan_inputs .append (input .variable )
1060+ untraced_sit_sot_inner_outputs .append (input .update )
1061+ n_untraced_sit_sot_outs += 1
10281062 else :
10291063 no_update_shared_inputs .append (input )
10301064
@@ -1092,7 +1126,7 @@ def wrap_into_list(x):
10921126 + mit_mot_inner_inputs
10931127 + mit_sot_inner_inputs
10941128 + sit_sot_inner_inputs
1095- + shared_inner_inputs
1129+ + untraced_sit_sot_inner_inputs
10961130 + other_shared_inner_args
10971131 + other_inner_args
10981132 )
@@ -1102,7 +1136,7 @@ def wrap_into_list(x):
11021136 + mit_sot_inner_outputs
11031137 + sit_sot_inner_outputs
11041138 + nit_sot_inner_outputs
1105- + shared_inner_outputs
1139+ + untraced_sit_sot_inner_outputs
11061140 )
11071141 if condition is not None :
11081142 inner_outs .append (condition )
@@ -1122,7 +1156,7 @@ def wrap_into_list(x):
11221156 mit_mot_out_slices = tuple (tuple (v ) for v in mit_mot_out_slices ),
11231157 mit_sot_in_slices = tuple (tuple (v ) for v in mit_sot_tap_array ),
11241158 sit_sot_in_slices = tuple ((- 1 ,) for x in range (n_sit_sot )),
1125- n_shared_outs = n_shared_outs ,
1159+ n_untraced_sit_sot_outs = n_untraced_sit_sot_outs ,
11261160 n_nit_sot = n_nit_sot ,
11271161 n_non_seqs = len (other_shared_inner_args ) + len (other_inner_args ),
11281162 as_while = as_while ,
@@ -1148,7 +1182,7 @@ def wrap_into_list(x):
11481182 + mit_mot_scan_inputs
11491183 + mit_sot_scan_inputs
11501184 + sit_sot_scan_inputs
1151- + shared_scan_inputs
1185+ + untraced_sit_sot_scan_inputs
11521186 + [actual_n_steps for x in range (n_nit_sot )]
11531187 + other_shared_scan_args
11541188 + other_scan_args
@@ -1206,13 +1240,28 @@ def remove_dimensions(outs, steps_return, offsets=None):
12061240 )
12071241
12081242 offset += n_nit_sot
1209- for idx , update_rule in enumerate (scan_outs [offset : offset + n_shared_outs ]):
1210- update_map [shared_scan_inputs [idx ]] = update_rule
12111243
1212- _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1244+ # Legacy support for explicit untraced sit_sot and those built with update dictionary
1245+ # Switch to n_untraced_sit_sot_outs after deprecation period
1246+ n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
1247+ untraced_sit_sot_outs = scan_outs [
1248+ offset : offset + n_explicit_untraced_sit_sot_outs
1249+ ]
1250+
1251+ # Legacy support: map shared outputs to their updates
1252+ offset += n_explicit_untraced_sit_sot_outs
1253+ for idx , update_rule in enumerate (scan_outs [offset :]):
1254+ update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
1255+
1256+ _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
12131257 # Step 10. I need to reorder the outputs to be in the order expected by
12141258 # the user
1215- rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1259+ rightOrder = (
1260+ mit_sot_rightOrder
1261+ + sit_sot_rightOrder
1262+ + untraced_sit_sot_rightOrder
1263+ + nit_sot_rightOrder
1264+ )
12161265 scan_out_list = [None ] * len (rightOrder )
12171266 for idx , pos in enumerate (rightOrder ):
12181267 if pos >= 0 :
@@ -1232,4 +1281,4 @@ def remove_dimensions(outs, steps_return, offsets=None):
12321281 elif len (scan_out_list ) == 0 :
12331282 scan_out_list = None
12341283
1235- return ( scan_out_list , update_map )
1284+ return scan_out_list , update_map
0 commit comments