66from collections import defaultdict , deque
77from collections .abc import Generator , Sequence
88from functools import cache , reduce
9+ from operator import or_
910from typing import Literal
1011from warnings import warn
1112
2930)
3031from pytensor .graph .rewriting .db import SequenceDB
3132from pytensor .graph .rewriting .unify import OpPattern
32- from pytensor .graph .traversal import ancestors , toposort
33+ from pytensor .graph .traversal import toposort
3334from pytensor .graph .utils import InconsistencyError , MethodNotDefined
3435from pytensor .scalar .math import Grad2F1Loop , _grad_2f1_loop
3536from pytensor .tensor .basic import (
@@ -659,16 +660,9 @@ def find_fuseable_subgraph(
659660 visited_nodes : set [Apply ],
660661 fuseable_clients : FUSEABLE_MAPPING ,
661662 unfuseable_clients : UNFUSEABLE_MAPPING ,
663+ ancestors_bitset : dict [Apply , int ],
662664 toposort_index : dict [Apply , int ],
663665 ) -> tuple [list [Variable ], list [Variable ]]:
664- def variables_depend_on (
665- variables , depend_on , stop_search_at = None
666- ) -> bool :
667- return any (
668- a in depend_on
669- for a in ancestors (variables , blockers = stop_search_at )
670- )
671-
672666 for starting_node in toposort_index :
673667 if starting_node in visited_nodes :
674668 continue
@@ -680,7 +674,8 @@ def variables_depend_on(
680674
681675 subgraph_inputs : dict [Variable , Literal [None ]] = {} # ordered set
682676 subgraph_outputs : dict [Variable , Literal [None ]] = {} # ordered set
683- unfuseable_clients_subgraph : set [Variable ] = set ()
677+ subgraph_inputs_ancestors_bitset = 0
678+ unfuseable_clients_subgraph_bitset = 0
684679
685680 # If we need to manipulate the maps in place, we'll do a shallow copy later
686681 # For now we query on the original ones
@@ -712,50 +707,32 @@ def variables_depend_on(
712707 if must_become_output :
713708 subgraph_outputs .pop (next_out , None )
714709
715- required_unfuseable_inputs = [
716- inp
717- for inp in next_node .inputs
718- if next_node in unfuseable_clients_clone .get (inp )
719- ]
720- new_required_unfuseable_inputs = [
721- inp
722- for inp in required_unfuseable_inputs
723- if inp not in subgraph_inputs
724- ]
725-
726- must_backtrack = False
727- if new_required_unfuseable_inputs and subgraph_outputs :
728- # We need to check that any new inputs required by this node
729- # do not depend on other outputs of the current subgraph,
730- # via an unfuseable path.
731- if variables_depend_on (
732- [next_out ],
733- depend_on = unfuseable_clients_subgraph ,
734- stop_search_at = subgraph_outputs ,
735- ):
736- must_backtrack = True
710+ # We need to check that any inputs required by this node
711+ # do not depend on other outputs of the current subgraph,
712+ # via an unfuseable path.
713+ must_backtrack = (
714+ ancestors_bitset [next_node ]
715+ & unfuseable_clients_subgraph_bitset
716+ )
737717
738718 if not must_backtrack :
739- implied_unfuseable_clients = {
740- c
741- for client in unfuseable_clients_clone .get (next_out )
742- if not isinstance (client .op , Output )
743- for c in client .outputs
744- }
745-
746- new_implied_unfuseable_clients = (
747- implied_unfuseable_clients - unfuseable_clients_subgraph
719+ implied_unfuseable_clients_bitset = reduce (
720+ or_ ,
721+ (
722+ 1 << toposort_index [client ]
723+ for client in unfuseable_clients_clone .get (next_out )
724+ if not isinstance (client .op , Output )
725+ ),
726+ 0 ,
748727 )
749728
750- if new_implied_unfuseable_clients and subgraph_inputs :
751- # We need to check that any inputs of the current subgraph
752- # do not depend on other clients of this node,
753- # via an unfuseable path.
754- if variables_depend_on (
755- subgraph_inputs ,
756- depend_on = new_implied_unfuseable_clients ,
757- ):
758- must_backtrack = True
729+ # We need to check that any inputs of the current subgraph
730+ # do not depend on other clients of this node,
731+ # via an unfuseable path.
732+ must_backtrack = (
733+ subgraph_inputs_ancestors_bitset
734+ & implied_unfuseable_clients_bitset
735+ )
759736
760737 if must_backtrack :
761738 for inp in next_node .inputs :
@@ -796,29 +773,24 @@ def variables_depend_on(
796773 # immediate dependency problems. Update subgraph
797774 # mappings as if next_node was part of it.
798775 # Useless inputs will be removed by the useless Composite rewrite
799- for inp in new_required_unfuseable_inputs :
800- subgraph_inputs [inp ] = None
801-
802776 if must_become_output :
803777 subgraph_outputs [next_out ] = None
804- unfuseable_clients_subgraph . update (
805- new_implied_unfuseable_clients
778+ unfuseable_clients_subgraph_bitset |= (
779+ implied_unfuseable_clients_bitset
806780 )
807781
808- # Expand through unvisited fuseable ancestors
809- fuseable_nodes_to_visit .extendleft (
810- sorted (
811- (
812- inp .owner
813- for inp in next_node .inputs
814- if (
815- inp not in required_unfuseable_inputs
816- and inp .owner not in visited_nodes
817- )
818- ),
819- key = toposort_index .get , # type: ignore[arg-type]
820- )
821- )
782+ for inp in sorted (
783+ next_node .inputs ,
784+ key = lambda x : toposort_index .get (x .owner , - 1 ),
785+ ):
786+ if next_node in unfuseable_clients_clone .get (inp , ()):
787+ # input must become an input of the subgraph since it's unfuseable with new node
788+ subgraph_inputs_ancestors_bitset |= (
789+ ancestors_bitset .get (inp .owner , 0 )
790+ )
791+ subgraph_inputs [inp ] = None
792+ elif inp .owner not in visited_nodes :
793+ fuseable_nodes_to_visit .appendleft (inp .owner )
822794
823795 # Expand through unvisited fuseable clients
824796 fuseable_nodes_to_visit .extend (
@@ -855,6 +827,8 @@ def update_fuseable_mappings_after_fg_replace(
855827 visited_nodes : set [Apply ],
856828 fuseable_clients : FUSEABLE_MAPPING ,
857829 unfuseable_clients : UNFUSEABLE_MAPPING ,
830+ toposort_index : dict [Apply , int ],
831+ ancestors_bitset : dict [Apply , int ],
858832 starting_nodes : set [Apply ],
859833 updated_nodes : set [Apply ],
860834 ) -> None :
@@ -865,11 +839,25 @@ def update_fuseable_mappings_after_fg_replace(
865839 dropped_nodes = starting_nodes - updated_nodes
866840
867841 # Remove intermediate Composite nodes from mappings
842+ # And compute the ancestors bitset of the new composite node
843+ # As well as the new toposort index for the new node
844+ new_node_ancestor_bitset = 0
845+ new_node_toposort_index = len (toposort_index )
868846 for dropped_node in dropped_nodes :
869847 (dropped_out ,) = dropped_node .outputs
870848 fuseable_clients .pop (dropped_out , None )
871849 unfuseable_clients .pop (dropped_out , None )
872850 visited_nodes .remove (dropped_node )
851+ # The new composite ancestor bitset is the union
852+ # of the ancestors of all the dropped nodes
853+ new_node_ancestor_bitset |= ancestors_bitset [dropped_node ]
854+ # The new composite node can have the same order as the latest node that was absorbed into it
855+ new_node_toposort_index = max (
856+ new_node_toposort_index , toposort_index [dropped_node ]
857+ )
858+
859+ ancestors_bitset [new_composite_node ] = new_node_ancestor_bitset
860+ toposort_index [new_composite_node ] = new_node_toposort_index
873861
874862 # Update fuseable information for subgraph inputs
875863 for inp in subgraph_inputs :
@@ -901,12 +889,23 @@ def update_fuseable_mappings_after_fg_replace(
901889 fuseable_clients , unfuseable_clients = initialize_fuseable_mappings (fg = fg )
902890 visited_nodes : set [Apply ] = set ()
903891 toposort_index = {node : i for i , node in enumerate (fgraph .toposort ())}
892+ # Create a bitset for each node of all its ancestors
893+ # This makes it quick to check whether a variable depends on a set of nodes
894+ ancestors_bitset : dict [Apply , int ] = {}
895+ for node , index in toposort_index .items ():
896+ node_ancestor_bitset = 1 << index
897+ for inp in node .inputs :
898+ if (inp_node := inp .owner ) is not None :
899+ node_ancestor_bitset |= ancestors_bitset [inp_node ]
900+ ancestors_bitset [node ] = node_ancestor_bitset
901+
904902 while True :
905903 try :
906904 subgraph_inputs , subgraph_outputs = find_fuseable_subgraph (
907905 visited_nodes = visited_nodes ,
908906 fuseable_clients = fuseable_clients ,
909907 unfuseable_clients = unfuseable_clients ,
908+ ancestors_bitset = ancestors_bitset ,
910909 toposort_index = toposort_index ,
911910 )
912911 except ValueError :
@@ -925,6 +924,8 @@ def update_fuseable_mappings_after_fg_replace(
925924 visited_nodes = visited_nodes ,
926925 fuseable_clients = fuseable_clients ,
927926 unfuseable_clients = unfuseable_clients ,
927+ toposort_index = toposort_index ,
928+ ancestors_bitset = ancestors_bitset ,
928929 starting_nodes = starting_nodes ,
929930 updated_nodes = fg .apply_nodes ,
930931 )
0 commit comments