diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 0a36ceadb7..3d4ae44092 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -1073,6 +1073,7 @@ def __init__(self): defaultdict(lambda: defaultdict(list)) ) self.untracked_rewrites: list[NodeRewriter] = [] + self.get_trackers = functools.cache(self._get_trackers) self._cached_composed_mro = None def add_tracker(self, rw: NodeRewriter): @@ -1080,6 +1081,7 @@ def add_tracker(self, rw: NodeRewriter): if self._cached_composed_mro is not None: # We shouldn't actually add_trackers after the first call to get_trackers # But just to be safe we kill the cache here + self.get_trackers = functools.cache(self._get_trackers) self._cached_composed_mro = None tracks = rw.tracks() @@ -1107,8 +1109,7 @@ def add_tracker(self, rw: NodeRewriter): else: self.tracked_instances[c].append(rw) - @functools.cache - def get_trackers(self, op: Op) -> list[NodeRewriter]: + def _get_trackers(self, op: Op) -> list[NodeRewriter]: """Get all the rewrites applicable to an `Op`.""" if self._cached_composed_mro is None: diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index 07c518af93..aef4ad7a18 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -1,6 +1,10 @@ +import gc +import operator + import pytest from pytensor.configdefaults import config +from pytensor.graph import rewrite_graph from pytensor.graph.basic import Apply, Constant, equal_computations from pytensor.graph.features import Feature from pytensor.graph.fg import FunctionGraph @@ -930,3 +934,44 @@ def perform(self, *args): local_rewriter_2, local_rewriter_1, ] + + +def test_rewrite_weakref_leak(): + """Check we don't have weakref leak on our rewrites""" + + def _growth(limit=10, peak_stats={}): + """Vendoring of objgraph.growth + + Source: https://github.com/mgedmin/objgraph/blob/94b1ca61a11109547442701800292dcfc7f59fc8/objgraph.py#L253 + """ + gc.collect() + objects = gc.get_objects() + + stats = {} + for o in objects: + n = type(o).__name__ + stats[n] = stats.get(n, 0) + 1 + + deltas = {} + for name, count in stats.items(): + old_count = peak_stats.get(name, 0) + if count > old_count: + deltas[name] = count - old_count + peak_stats[name] = count + + deltas = sorted(deltas.items(), key=operator.itemgetter(1), reverse=True) + + if limit: + deltas = deltas[:limit] + + return [(name, stats[name], delta) for name, delta in deltas] + + x = vector("x") + y = exp(x) + + for i in range(20): + rewrite_graph(y, clone=False) + res = _growth() + # Only start checking after warmup + if i > 15: + assert not res, "Object counts are still growing"