Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/extending/graph_rewriting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you
exercise, try to rewrite :class:`Simplify` using :class:`WalkingGraphRewriter`. (Hint: you
want to use the method it publishes instead of the call to toposort)

Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
Expand Down
4 changes: 0 additions & 4 deletions doc/library/graph/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator)

.. method:: replace_validate(fgraph, var, new_var, reason=None)

.. class:: NodeFinder(Bookkeeper)

.. class:: PrintListener(object)
94 changes: 0 additions & 94 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,100 +827,6 @@ def validate(self, fgraph):
raise InconsistencyError("Trying to reintroduce a removed node")


class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}

def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")

if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")

self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)

def clone(self):
return type(self)()

def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph:
raise Exception(
"This NodeFinder instance was not attached to the provided fgraph."
)
self.fgraph = None
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)

def on_import(self, fgraph, node, reason):
try:
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception as e:
print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201
try:
print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201
except Exception:
print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201
raise e

def on_prune(self, fgraph, node, reason):
try:
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self.d[node.op]

def query(self, fgraph, op):
try:
all = self.d.get(op, [])
except TypeError:
raise TypeError(
f"{op} in unhashable and cannot be queried by the optimizer"
)
all = list(all)
return all


class PrintListener(Feature):
def __init__(self, active=True):
self.active = active

def on_attach(self, fgraph):
if self.active:
print("-- attaching to: ", fgraph) # noqa: T201

def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.active:
print("-- detaching from: ", fgraph) # noqa: T201

def on_import(self, fgraph, node, reason):
if self.active:
print(f"-- importing: {node}, reason: {reason}") # noqa: T201

def on_prune(self, fgraph, node, reason):
if self.active:
print(f"-- pruning: {node}, reason: {reason}") # noqa: T201

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201


class PreserveVariableAttributes(Feature):
"""
This preserve some variables attributes and tag during optimization.
Expand Down
Loading