@@ -179,54 +179,47 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
179179
180180 """
181181 # this co-recurses with clone_a
182- stack = [v ]
183- try :
184- while True :
185- v = stack .pop ()
186- if v in clone_d :
187- continue
188- if (apply := v .owner ) is not None :
189- if all (i in clone_d for i in apply .inputs ):
190- # all inputs have been cloned, we can clone this node
191- clone_node_and_cache (
192- apply ,
193- clone_d ,
194- strict = rebuild_strict ,
195- clone_inner_graphs = clone_inner_graphs ,
182+ assert v is not None
183+ if v in clone_d :
184+ return clone_d [v ]
185+ if v .owner :
186+ owner = v .owner
187+ if owner not in clone_d :
188+ for i in owner .inputs :
189+ clone_v_get_shared_updates (i , copy_inputs_over )
190+ clone_node_and_cache (
191+ owner ,
192+ clone_d ,
193+ strict = rebuild_strict ,
194+ clone_inner_graphs = clone_inner_graphs ,
195+ )
196+ return clone_d .setdefault (v , v )
197+ elif isinstance (v , SharedVariable ):
198+ if v not in shared_inputs :
199+ shared_inputs .append (v )
200+ if v .default_update is not None :
201+ # Check that v should not be excluded from the default
202+ # updates list
203+ if no_default_updates is False or (
204+ isinstance (no_default_updates , list ) and v not in no_default_updates
205+ ):
206+ # Do not use default_update if a "real" update was
207+ # provided
208+ if v not in update_d :
209+ v_update = v .type .filter_variable (
210+ v .default_update , allow_convert = False
196211 )
197- else :
198- # expand on the inputs
199- stack .extend (apply .inputs )
200- else :
201- clone_d [v ] = v if copy_inputs_over else v .clone ()
202-
203- # Special handling of SharedVariables
204- if isinstance (v , SharedVariable ):
205- if v not in shared_inputs :
206- shared_inputs .append (v )
207- if v .default_update is not None :
208- # Check that v should not be excluded from the default
209- # updates list
210- if no_default_updates is False or (
211- isinstance (no_default_updates , list )
212- and v not in no_default_updates
213- ):
214- # Do not use default_update if a "real" update was
215- # provided
216- if v not in update_d :
217- v_update = v .type .filter_variable (
218- v .default_update , allow_convert = False
219- )
220- if not v .type .is_super (v_update .type ):
221- raise TypeError (
222- "An update must have a type compatible with "
223- "the original shared variable"
224- )
225- update_d [v ] = v_update
226- update_expr .append ((v , v_update ))
227- except IndexError :
228- pass # stack is empty
229- return clone_d [v ]
212+ if not v .type .is_super (v_update .type ):
213+ raise TypeError (
214+ "An update must have a type compatible with "
215+ "the original shared variable"
216+ )
217+ update_d [v ] = v_update
218+ update_expr .append ((v , v_update ))
219+ if not copy_inputs_over :
220+ return clone_d .setdefault (v , v .clone ())
221+ else :
222+ return clone_d .setdefault (v , v )
230223
231224 # initialize the clone_d mapping with the replace dictionary
232225 if replace is None :
0 commit comments