@@ -675,7 +675,9 @@ def push(self, node: Apply, node_bitflag: int, is_ancestor: bool):
675675
def pop(self) -> tuple[Apply, int, bool]:
    """Remove the smallest queue entry and return ``(node, node_bitflag, is_ancestor)``.

    Entries are popped from ``self.queue`` as ``(key, node)`` pairs. A negative
    key is decoded as an ancestor entry: its sign is flipped and
    ``is_ancestor`` is ``True``. Any other key is returned unchanged with
    ``is_ancestor`` set to ``False``.

    Raises:
        IndexError: if the queue is empty (propagated from ``heappop``).
    """
    node_bitflag, node = heappop(self.queue)
    if node_bitflag < 0:
        # Undo the negation so the caller gets the bitflag itself, not the heap key.
        return node, -node_bitflag, True
    return node, node_bitflag, False
679681
def __bool__(self):
    """Return ``True`` while ``self.queue`` still holds at least one entry."""
    return bool(self.queue)
@@ -685,7 +687,6 @@ class NonConvexError(Exception):
685687
686688 class ConvexSubgraph :
687689 __slots__ = (
688- "nodes_bitflags" ,
689690 "ancestors_bitset" ,
690691 "nodes" ,
691692 "nodes_bitset" ,
@@ -694,8 +695,7 @@ class ConvexSubgraph:
694695 "inputs_and_outputs" ,
695696 )
696697
697- def __init__ (self , nodes_bitflags , ancestors_bitset ):
698- self .nodes_bitflags = nodes_bitflags
698+ def __init__ (self , ancestors_bitset ):
699699 self .ancestors_bitset = ancestors_bitset
700700 self .nodes = {}
701701 self .nodes_bitset = 0
@@ -706,11 +706,10 @@ def __init__(self, nodes_bitflags, ancestors_bitset):
def __len__(self):
    """Return the number of entries currently stored in ``self.nodes``."""
    return len(self.nodes)
708708
def __contains__(self, node_bitflag: int):
    """Return whether ``node_bitflag`` shares any set bit with ``self.nodes_bitset``.

    The operand is a node's integer bitflag, not the node object, so
    membership tests are written ``node_bitflag in subgraph``.
    """
    return bool(self.nodes_bitset & node_bitflag)
711711
712- def add_node (self , node : Apply , is_ancestor : bool ):
713- node_bitflag = self .nodes_bitflags [node ]
712+ def add_node (self , node : Apply , node_bitflag , is_ancestor : bool ):
714713 if is_ancestor :
715714 if node_bitflag & self .unfuseable_ancestors_bitset :
716715 raise NonConvexError
@@ -720,14 +719,14 @@ def add_node(self, node: Apply, is_ancestor: bool):
720719 self .nodes [node ] = None
721720 self .inputs_and_outputs = None # clear cache
722721
723- def add_unfuseable_ancestor (self , ancestor : Apply ):
724- # If an ancestor is unfuseable, so are all its ancestors
725- self .unfuseable_ancestors_bitset |= self .ancestors_bitset [ancestor ]
726-
727- def add_unfuseable_client (self , client : Apply ):
728- # If a client is unfuseable, so are all its clients, but we don't need to keep track of those
729- # Any downstream client will also depend on this unfuseable client and will be rejected when visited
730- self .unfuseable_clients_bitset |= self .nodes_bitflags [client ]
722+ # def add_unfuseable_ancestor(self, ancestor: Apply):
723+ # # If an ancestor is unfuseable, so are all its ancestors
724+ # self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor]
725+ #
726+ # def add_unfuseable_client(self, client: Apply):
727+ # # If a client is unfuseable, so are all its clients, but we don't need to keep track of those
728+ # # Any downstream client will also depend on this unfuseable client and will be rejected when visited
729+ # self.unfuseable_clients_bitset |= self.nodes_bitflags[client]
731730
732731 def get_inputs_and_outputs (self ):
733732 if self .inputs_and_outputs is not None :
@@ -842,20 +841,20 @@ def __iter__(self):
842841 ):
843842 continue
844843
845- subgraph = ConvexSubgraph (nodes_bitflags , ancestors_bitset )
844+ subgraph = ConvexSubgraph (ancestors_bitset )
846845
847846 fuseable_nodes_queue = SortedFuseableNodesQueue ()
848847 fuseable_nodes_queue .push (
849848 starting_node , starting_bitflag , is_ancestor = True
850849 )
851850 while fuseable_nodes_queue :
852- node , is_ancestor = fuseable_nodes_queue .pop ()
851+ node , node_bitflag , is_ancestor = fuseable_nodes_queue .pop ()
853852
854- if node in subgraph :
853+ if node_bitflag in subgraph :
855854 continue
856855
857856 try :
858- subgraph .add_node (node , is_ancestor = is_ancestor )
857+ subgraph .add_node (node , node_bitflag , is_ancestor = is_ancestor )
859858 except NonConvexError :
860859 continue
861860
@@ -866,29 +865,33 @@ def __iter__(self):
866865 # - unfuseable (add to respective unfuseable bitset)
867866 for ancestor in node .inputs :
868867 ancestor_node = ancestor .owner
869- if ancestor_node in subgraph :
868+ ancestor_bitset = nodes_bitflags [ancestor_node ]
869+ if ancestor_bitset in subgraph :
870870 continue
871871 if node in fuseable_clients [ancestor_node ]:
872872 fuseable_nodes_queue .push (
873873 ancestor_node ,
874- nodes_bitflags [ ancestor_node ] ,
874+ ancestor_bitset ,
875875 is_ancestor = True ,
876876 )
877877 else :
878- subgraph .add_unfuseable_ancestor (ancestor_node )
878+ subgraph .unfuseable_ancestors_bitset |= ancestors_bitset [
879+ ancestor_node
880+ ]
879881
880882 next_fuseable_clients = fuseable_clients [node ]
881883 for client_node , _ in fg_clients [node .outputs [0 ]]:
882- if client_node in subgraph :
884+ client_bitflag = nodes_bitflags [client_node ]
885+ if client_bitflag in subgraph :
883886 continue
884887 if client_node in next_fuseable_clients :
885888 fuseable_nodes_queue .push (
886889 client_node ,
887- nodes_bitflags [ client_node ] ,
890+ client_bitflag ,
888891 is_ancestor = False ,
889892 )
890893 else :
891- subgraph .add_unfuseable_client ( client_node )
894+ subgraph .unfuseable_clients_bitset |= client_bitflag
892895
893896 # Finished exploring this subgraph
894897 if len (subgraph ) == 1 :
0 commit comments