 import warnings
+from itertools import chain

 import numpy as np

 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.op import get_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.traversal import graph_inputs
+from pytensor.graph.traversal import explicit_graph_inputs
 from pytensor.graph.utils import MissingInputError, TestValueError
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import expand_empty, safe_new, until
@@ -475,19 +476,15 @@ def wrap_into_list(x):
         else:
             non_seqs.append(elem)

-    # If we provided a known number of steps (before compilation)
-    # and if that number is 1 or -1, then we can skip the Scan Op,
-    # and just apply the inner function once
-    # To do that we check here to see the nature of n_steps
-    n_fixed_steps = None
-
+    # If n_steps is known to be the constant 1, we can skip the Scan Op entirely
+    single_step_requested = False
     if isinstance(n_steps, float | int):
-        n_fixed_steps = int(n_steps)
+        single_step_requested = n_steps == 1
     else:
         try:
-            n_fixed_steps = pt.get_scalar_constant_value(n_steps)
+            single_step_requested = pt.get_scalar_constant_value(n_steps) == 1
         except NotScalarConstantError:
-            n_fixed_steps = None
+            pass

     # Check n_steps is an int
     if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes:
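
For reference, a minimal standalone sketch of the constant-detection idiom the new flag relies on. The helper name `is_single_step` is hypothetical, not part of this PR; `pt.get_scalar_constant_value` and `NotScalarConstantError` are the same PyTensor APIs used in the hunk above:

import pytensor.tensor as pt
from pytensor.tensor.exceptions import NotScalarConstantError

def is_single_step(n_steps):
    # Plain Python numbers can be compared directly.
    if isinstance(n_steps, float | int):
        return n_steps == 1
    # Symbolic values may still be compile-time constants.
    try:
        return pt.get_scalar_constant_value(n_steps) == 1
    except NotScalarConstantError:
        return False  # genuinely symbolic: a real Scan is needed

print(is_single_step(1))                # True
print(is_single_step(pt.constant(1)))   # True: constant graph input
print(is_single_step(pt.iscalar("n")))  # False: unknown until runtime
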
@@ -497,7 +494,6 @@ def wrap_into_list(x):
     n_seqs = len(seqs)
     n_outs = len(outs_info)

-    return_steps = {}
     # wrap sequences in a dictionary if they are not already dictionaries
     for i in range(n_seqs):
         if not isinstance(seqs[i], dict):
@@ -700,7 +696,6 @@ def wrap_into_list(x):
     mit_sot_inner_inputs = []
     mit_sot_inner_slices = []
     mit_sot_inner_outputs = []
-    mit_sot_return_steps = {}
     mit_sot_tap_array = []
     mit_sot_rightOrder = []

@@ -709,7 +704,6 @@ def wrap_into_list(x):
     sit_sot_inner_inputs = []
     sit_sot_inner_slices = []
     sit_sot_inner_outputs = []
-    sit_sot_return_steps = {}
     sit_sot_rightOrder = []

     # go through outputs picking up time slices as needed
@@ -755,8 +749,6 @@ def wrap_into_list(x):
             )

             sit_sot_inner_slices.append(actual_arg)
-            if i in return_steps:
-                sit_sot_return_steps[n_sit_sot] = return_steps[i]
             sit_sot_inner_inputs.append(arg)
             sit_sot_rightOrder.append(i)
             n_sit_sot += 1
@@ -774,8 +766,6 @@ def wrap_into_list(x):
                 expand_empty(init_out["initial"][:mintap], actual_n_steps)
             )

-            if i in return_steps:
-                mit_sot_return_steps[n_mit_sot] = return_steps[i]
             mit_sot_rightOrder.append(i)
             n_mit_sot += 1
             for k in init_out["taps"]:
@@ -819,7 +809,7 @@ def wrap_into_list(x):
     offset = 0
     for idx in range(n_mit_sot):
         n_inputs = len(mit_sot_tap_array[idx])
-        if n_fixed_steps in (1, -1):
+        if single_step_requested:
             _ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[
                 offset : offset + n_inputs
             ]
@@ -830,17 +820,14 @@ def wrap_into_list(x):
         offset += n_inputs

     for idx in range(n_sit_sot):
-        if n_fixed_steps in (1, -1):
+        if single_step_requested:
             _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]]
         else:
             _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]

-    ordered_args = []
-    for ls in _ordered_args:
-        ordered_args += ls
-    if n_fixed_steps in (1, -1):
+    ordered_args = list(chain.from_iterable(_ordered_args))
+    if single_step_requested:
         args = inner_slices + ordered_args + non_seqs
-
     else:
         args = inner_seqs + ordered_args + non_seqs

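
A quick sketch of the flattening equivalence behind the `chain.from_iterable` change above; the sublist contents are made up for illustration:

from itertools import chain

_ordered_args = [["a0"], ["b0", "b1"], ["c0"]]

# Old style: repeated list concatenation, one pass per sublist.
flat = []
for ls in _ordered_args:
    flat += ls

# New style: a single pass over all sublists.
assert flat == list(chain.from_iterable(_ordered_args)) == ["a0", "b0", "b1", "c0"]
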
@@ -863,15 +850,15 @@ def wrap_into_list(x):
     # Step 3. Check if we actually need scan and remove it if we don't
     ##

-    if n_fixed_steps in (1, -1):
+    if single_step_requested:
         for pos, inner_out in enumerate(outputs):
             # we need to see if we need to pad our sequences with an
             # extra dimension; case example: we return an
             # output for which we want all intermediate. If n_steps is 1
             # then, if we return the output as given by the inner function
             # this will represent only a slice and it will have one
             # dimension less.
-            if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
+            if isinstance(inner_out.type, TensorType):
                 outputs[pos] = shape_padleft(inner_out)

         if not return_list and len(outputs) == 1:
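
A minimal sketch of why the single-step shortcut pads its outputs: `shape_padleft` prepends a length-1 time axis, so the shortcut's result has the same rank as the output a real Scan would produce (variable names are illustrative):

import pytensor.tensor as pt

inner_out = pt.vector("inner_out")    # one step's worth of output, shape (n,)
padded = pt.shape_padleft(inner_out)  # shape (1, n): a one-step "sequence"
print(padded.type.shape)              # (1, None)
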
@@ -896,15 +883,10 @@ def wrap_into_list(x):
     fake_outputs = clone_replace(
         outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
     )
-    all_inputs = filter(
-        lambda x: (
-            isinstance(x, Variable)
-            and not isinstance(x, SharedVariable)
-            and not isinstance(x, Constant)
-        ),
-        graph_inputs(fake_outputs),
-    )
-    extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
+    known_inputs = [*args, *fake_nonseqs]
+    extra_inputs = [
+        x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
+    ]
     non_seqs += extra_inputs
     # Note we do not use all_inputs directly since the order of variables
     # in args is quite important
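
A small sketch of the traversal change above, assuming `explicit_graph_inputs` behaves as this PR uses it: it yields only the root variables that must be supplied explicitly, which makes the deleted manual filter over `graph_inputs` (excluding constants and shared variables) unnecessary:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.traversal import explicit_graph_inputs

x = pt.scalar("x")
s = pytensor.shared(2.0, name="s")
out = x * s + 1.0

# Only `x` is reported; the shared variable and the constant are excluded.
print([v.name for v in explicit_graph_inputs([out])])  # ['x']
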
@@ -1033,13 +1015,10 @@ def wrap_into_list(x):
     # Step 5.4 Outputs with no taps used in the input
     n_nit_sot = 0
     nit_sot_inner_outputs = []
-    nit_sot_return_steps = {}
     nit_sot_rightOrder = []
     for i, out in enumerate(outs_info):
         if "taps" not in out:
             nit_sot_inner_outputs.append(outputs[i])
-            if i in return_steps:
-                nit_sot_return_steps[n_nit_sot] = return_steps[i]
             nit_sot_rightOrder.append(i)
             n_nit_sot += 1
@@ -1173,37 +1152,25 @@ def wrap_into_list(x):

     update_map = OrderedUpdates()

-    def remove_dimensions(outs, steps_return, offsets=None):
+    def remove_dimensions(outs, offsets=None):
         out_ls = []
         for idx, out in enumerate(outs):
-            if idx in steps_return:
-                if steps_return[idx] > 1:
-                    out_ls.append(out[-steps_return[idx] :])
-                else:
-                    out_ls.append(out[-1])
+            if offsets is None:
+                out_ls.append(out)
             else:
-                if offsets is None:
-                    out_ls.append(out)
-                else:
-                    out_ls.append(out[offsets[idx] :])
+                out_ls.append(out[offsets[idx] :])
         return out_ls

     offset = n_mit_mot
     offsets = [abs(np.min(x)) for x in mit_sot_tap_array]
-    mit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_mit_sot], mit_sot_return_steps, offsets
-    )
+    mit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_mit_sot], offsets)

     offset += n_mit_sot
     offsets = [1 for x in range(n_sit_sot)]
-    sit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_sit_sot], sit_sot_return_steps, offsets
-    )
+    sit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_sit_sot], offsets)

     offset += n_sit_sot
-    nit_sot_outs = remove_dimensions(
-        scan_outs[offset : offset + n_nit_sot], nit_sot_return_steps
-    )
+    nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])

     offset += n_nit_sot
     for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
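
A plain NumPy sketch (illustrative values only) of what the `offsets` passed to the simplified `remove_dimensions` accomplish: scan's output buffers carry the initial taps at the front, and slicing at the offset keeps only the values the steps actually produced:

import numpy as np

buffer = np.arange(7)   # e.g. 2 initial taps followed by 5 computed steps
offset = 2              # abs(min(taps)) for a mit-sot; always 1 for a sit-sot
print(buffer[offset:])  # [2 3 4 5 6]: only the per-step outputs
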
@@ -1232,4 +1199,4 @@ def remove_dimensions(outs, steps_return, offsets=None):
     elif len(scan_out_list) == 0:
         scan_out_list = None

-    return (scan_out_list, update_map)
+    return scan_out_list, update_map
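
Finally, a hedged end-to-end example of the case this refactor streamlines. Assuming the standard `pytensor.scan` interface, a constant `n_steps` of 1 lets the function skip building a Scan node and return the padded inner output directly, while the result still carries a one-step time axis:

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
outputs, updates = pytensor.scan(fn=lambda x: x + 1, outputs_info=[x0], n_steps=1)

f = pytensor.function([x0], outputs)
print(f(3.0))  # [4.]: one step, leading time dimension preserved
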