2929 EquilibriumGraphRewriter ,
3030 GraphRewriter ,
3131 copy_stack_trace ,
32- in2out ,
32+ dfs_rewriter ,
3333 node_rewriter ,
3434)
3535from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
@@ -2558,15 +2558,15 @@ def apply(self, fgraph, start_from=None):
25582558# ScanSaveMem should execute only once per node.
25592559optdb .register (
25602560 "scan_save_mem_prealloc" ,
2561- in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
2561+ dfs_rewriter (scan_save_mem_prealloc , ignore_newtrees = True ),
25622562 "fast_run" ,
25632563 "scan" ,
25642564 "scan_save_mem" ,
25652565 position = 1.61 ,
25662566)
25672567optdb .register (
25682568 "scan_save_mem_no_prealloc" ,
2569- in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2569+ dfs_rewriter (scan_save_mem_no_prealloc , ignore_newtrees = True ),
25702570 "numba" ,
25712571 "jax" ,
25722572 "pytorch" ,
@@ -2587,7 +2587,7 @@ def apply(self, fgraph, start_from=None):
25872587
25882588scan_seqopt1 .register (
25892589 "scan_remove_constants_and_unused_inputs0" ,
2590- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2590+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
25912591 "remove_constants_and_unused_inputs_scan" ,
25922592 "fast_run" ,
25932593 "scan" ,
@@ -2596,7 +2596,7 @@ def apply(self, fgraph, start_from=None):
25962596
25972597scan_seqopt1 .register (
25982598 "scan_push_out_non_seq" ,
2599- in2out (scan_push_out_non_seq , ignore_newtrees = True ),
2599+ dfs_rewriter (scan_push_out_non_seq , ignore_newtrees = True ),
26002600 "scan_pushout_nonseqs_ops" , # For backcompat: so it can be tagged with old name
26012601 "fast_run" ,
26022602 "scan" ,
@@ -2606,7 +2606,7 @@ def apply(self, fgraph, start_from=None):
26062606
26072607scan_seqopt1 .register (
26082608 "scan_push_out_seq" ,
2609- in2out (scan_push_out_seq , ignore_newtrees = True ),
2609+ dfs_rewriter (scan_push_out_seq , ignore_newtrees = True ),
26102610 "scan_pushout_seqs_ops" , # For backcompat: so it can be tagged with old name
26112611 "fast_run" ,
26122612 "scan" ,
@@ -2617,7 +2617,7 @@ def apply(self, fgraph, start_from=None):
26172617
26182618scan_seqopt1 .register (
26192619 "scan_push_out_dot1" ,
2620- in2out (scan_push_out_dot1 , ignore_newtrees = True ),
2620+ dfs_rewriter (scan_push_out_dot1 , ignore_newtrees = True ),
26212621 "scan_pushout_dot1" , # For backcompat: so it can be tagged with old name
26222622 "fast_run" ,
26232623 "more_mem" ,
@@ -2630,7 +2630,7 @@ def apply(self, fgraph, start_from=None):
26302630scan_seqopt1 .register (
26312631 "scan_push_out_add" ,
26322632 # TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2633- in2out (scan_push_out_add , ignore_newtrees = False ),
2633+ dfs_rewriter (scan_push_out_add , ignore_newtrees = False ),
26342634 "scan_pushout_add" , # For backcompat: so it can be tagged with old name
26352635 "fast_run" ,
26362636 "more_mem" ,
@@ -2641,22 +2641,22 @@ def apply(self, fgraph, start_from=None):
26412641
26422642scan_eqopt2 .register (
26432643 "while_scan_merge_subtensor_last_element" ,
2644- in2out (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
2644+ dfs_rewriter (while_scan_merge_subtensor_last_element , ignore_newtrees = True ),
26452645 "fast_run" ,
26462646 "scan" ,
26472647)
26482648
26492649scan_eqopt2 .register (
26502650 "constant_folding_for_scan2" ,
2651- in2out (constant_folding , ignore_newtrees = True ),
2651+ dfs_rewriter (constant_folding , ignore_newtrees = True ),
26522652 "fast_run" ,
26532653 "scan" ,
26542654)
26552655
26562656
26572657scan_eqopt2 .register (
26582658 "scan_remove_constants_and_unused_inputs1" ,
2659- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2659+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26602660 "remove_constants_and_unused_inputs_scan" ,
26612661 "fast_run" ,
26622662 "scan" ,
@@ -2671,23 +2671,23 @@ def apply(self, fgraph, start_from=None):
26712671# After Merge optimization
26722672scan_eqopt2 .register (
26732673 "scan_remove_constants_and_unused_inputs2" ,
2674- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2674+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26752675 "remove_constants_and_unused_inputs_scan" ,
26762676 "fast_run" ,
26772677 "scan" ,
26782678)
26792679
26802680scan_eqopt2 .register (
26812681 "scan_merge_inouts" ,
2682- in2out (scan_merge_inouts , ignore_newtrees = True ),
2682+ dfs_rewriter (scan_merge_inouts , ignore_newtrees = True ),
26832683 "fast_run" ,
26842684 "scan" ,
26852685)
26862686
26872687# After everything else
26882688scan_eqopt2 .register (
26892689 "scan_remove_constants_and_unused_inputs3" ,
2690- in2out (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
2690+ dfs_rewriter (remove_constants_and_unused_inputs_scan , ignore_newtrees = True ),
26912691 "remove_constants_and_unused_inputs_scan" ,
26922692 "fast_run" ,
26932693 "scan" ,
0 commit comments