Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 0 additions & 65 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,71 +827,6 @@ def validate(self, fgraph):
raise InconsistencyError("Trying to reintroduce a removed node")


class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}

def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")

if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")

self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)

def clone(self):
return type(self)()

def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph:
raise Exception(
"This NodeFinder instance was not attached to the provided fgraph."
)
self.fgraph = None
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)

def on_import(self, fgraph, node, reason):
try:
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception as e:
print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201
try:
print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201
except Exception:
print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201
raise e

def on_prune(self, fgraph, node, reason):
try:
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self.d[node.op]

def query(self, fgraph, op):
try:
all = self.d.get(op, [])
except TypeError:
raise TypeError(
f"{op} in unhashable and cannot be queried by the optimizer"
)
all = list(all)
return all


class PrintListener(Feature):
def __init__(self, active=True):
self.active = active
Expand Down
Loading
Loading