Skip to content

Commit 11f18df

Browse files
committed
Try helper classes with bitset (reduce dict access)
1 parent f2683c9 commit 11f18df

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,9 @@ def push(self, node: Apply, node_bitflag: int, is_ancestor: bool):
675675

676676
def pop(self) -> tuple[Apply, int, bool]:
677677
node_bitflag, node = heappop(self.queue)
678-
return node, node_bitflag < 0
678+
if node_bitflag < 0:
679+
return node, -node_bitflag, True
680+
return node, node_bitflag, False
679681

680682
def __bool__(self):
681683
return bool(self.queue)
@@ -685,7 +687,6 @@ class NonConvexError(Exception):
685687

686688
class ConvexSubgraph:
687689
__slots__ = (
688-
"nodes_bitflags",
689690
"ancestors_bitset",
690691
"nodes",
691692
"nodes_bitset",
@@ -694,8 +695,7 @@ class ConvexSubgraph:
694695
"inputs_and_outputs",
695696
)
696697

697-
def __init__(self, nodes_bitflags, ancestors_bitset):
698-
self.nodes_bitflags = nodes_bitflags
698+
def __init__(self, ancestors_bitset):
699699
self.ancestors_bitset = ancestors_bitset
700700
self.nodes = {}
701701
self.nodes_bitset = 0
@@ -706,11 +706,10 @@ def __init__(self, nodes_bitflags, ancestors_bitset):
706706
def __len__(self):
707707
return len(self.nodes)
708708

709-
def __contains__(self, node: int):
710-
return bool(self.nodes_bitset & self.nodes_bitflags[node])
709+
def __contains__(self, node_bitflag: int):
710+
return bool(self.nodes_bitset & node_bitflag)
711711

712-
def add_node(self, node: Apply, is_ancestor: bool):
713-
node_bitflag = self.nodes_bitflags[node]
712+
def add_node(self, node: Apply, node_bitflag, is_ancestor: bool):
714713
if is_ancestor:
715714
if node_bitflag & self.unfuseable_ancestors_bitset:
716715
raise NonConvexError
@@ -720,14 +719,14 @@ def add_node(self, node: Apply, is_ancestor: bool):
720719
self.nodes[node] = None
721720
self.inputs_and_outputs = None # clear cache
722721

723-
def add_unfuseable_ancestor(self, ancestor: Apply):
724-
# If an ancestor is unfuseable, so are all its ancestors
725-
self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor]
726-
727-
def add_unfuseable_client(self, client: Apply):
728-
# If a client is unfuseable, so are all its clients, but we don't need to keep track of those
729-
# Any downstream client will also depend on this unfuseable client and will be rejected when visited
730-
self.unfuseable_clients_bitset |= self.nodes_bitflags[client]
722+
# def add_unfuseable_ancestor(self, ancestor: Apply):
723+
# # If an ancestor is unfuseable, so are all its ancestors
724+
# self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor]
725+
#
726+
# def add_unfuseable_client(self, client: Apply):
727+
# # If a client is unfuseable, so are all its clients, but we don't need to keep track of those
728+
# # Any downstream client will also depend on this unfuseable client and will be rejected when visited
729+
# self.unfuseable_clients_bitset |= self.nodes_bitflags[client]
731730

732731
def get_inputs_and_outputs(self):
733732
if self.inputs_and_outputs is not None:
@@ -842,20 +841,20 @@ def __iter__(self):
842841
):
843842
continue
844843

845-
subgraph = ConvexSubgraph(nodes_bitflags, ancestors_bitset)
844+
subgraph = ConvexSubgraph(ancestors_bitset)
846845

847846
fuseable_nodes_queue = SortedFuseableNodesQueue()
848847
fuseable_nodes_queue.push(
849848
starting_node, starting_bitflag, is_ancestor=True
850849
)
851850
while fuseable_nodes_queue:
852-
node, is_ancestor = fuseable_nodes_queue.pop()
851+
node, node_bitflag, is_ancestor = fuseable_nodes_queue.pop()
853852

854-
if node in subgraph:
853+
if node_bitflag in subgraph:
855854
continue
856855

857856
try:
858-
subgraph.add_node(node, is_ancestor=is_ancestor)
857+
subgraph.add_node(node, node_bitflag, is_ancestor=is_ancestor)
859858
except NonConvexError:
860859
continue
861860

@@ -866,29 +865,33 @@ def __iter__(self):
866865
# - unfuseable (add to respective unfuseable bitset)
867866
for ancestor in node.inputs:
868867
ancestor_node = ancestor.owner
869-
if ancestor_node in subgraph:
868+
ancestor_bitset = nodes_bitflags[ancestor_node]
869+
if ancestor_bitset in subgraph:
870870
continue
871871
if node in fuseable_clients[ancestor_node]:
872872
fuseable_nodes_queue.push(
873873
ancestor_node,
874-
nodes_bitflags[ancestor_node],
874+
ancestor_bitset,
875875
is_ancestor=True,
876876
)
877877
else:
878-
subgraph.add_unfuseable_ancestor(ancestor_node)
878+
subgraph.unfuseable_ancestors_bitset |= ancestors_bitset[
879+
ancestor_node
880+
]
879881

880882
next_fuseable_clients = fuseable_clients[node]
881883
for client_node, _ in fg_clients[node.outputs[0]]:
882-
if client_node in subgraph:
884+
client_bitflag = nodes_bitflags[client_node]
885+
if client_bitflag in subgraph:
883886
continue
884887
if client_node in next_fuseable_clients:
885888
fuseable_nodes_queue.push(
886889
client_node,
887-
nodes_bitflags[client_node],
890+
client_bitflag,
888891
is_ancestor=False,
889892
)
890893
else:
891-
subgraph.add_unfuseable_client(client_node)
894+
subgraph.unfuseable_clients_bitset |= client_bitflag
892895

893896
# Finished exploring this subgraph
894897
if len(subgraph) == 1:

0 commit comments

Comments
 (0)