From 61089a9a91b3f36ad5a18d36a1e8dbf0e1ae8968 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 4 Sep 2025 19:25:51 +0200 Subject: [PATCH 01/13] Benchmark another FusionOptimizer graph --- pytensor/tensor/rewriting/elemwise.py | 12 ++++--- tests/tensor/rewriting/test_elemwise.py | 42 +++++++++++++++++++++---- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e2d420f361..0eb2900729 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -569,8 +569,6 @@ def elemwise_to_scalar(inputs, outputs): return scalar_inputs, scalar_outputs def apply(self, fgraph): - nb_replacement = 0 - if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() @@ -925,6 +923,8 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes=starting_nodes, ) + nb_fused = 0 + nb_replacement = 0 for inputs, outputs in find_next_fuseable_subgraph(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( @@ -943,11 +943,13 @@ def update_fuseable_mappings_after_fg_replace( if old_out.name: composite_out.name = old_out.name + starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( list(zip(outputs, composite_outputs, strict=True)), reason=self.__class__.__name__, ) - nb_replacement += 1 + nb_fused += 1 + nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -965,7 +967,7 @@ def update_fuseable_mappings_after_fg_replace( return ( self, - 1, # nb_iter + nb_fused, nb_replacement, 0, # nb_inconsintency_replace validate_time, @@ -978,7 +980,7 @@ def update_fuseable_mappings_after_fg_replace( def print_profile(stream, prof, level=0): blanc = " " * level print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_fused", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " validate_time", prof[4], file=stream) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index c23d0ac23a..3c549788e1 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) - def large_fuseable_graph(self, n): + @staticmethod + def large_fuseable_graph(n): factors = [] sd = dscalar() means = dvector() @@ -296,6 +297,24 @@ def large_fuseable_graph(self, n): dlogp = [pytensor.grad(logp, v) for v in vars] return vars, dlogp + @staticmethod + def deep_small_kernels(n): + x = pt.matrix("x") + out = x + for _ in range(n): + out = pt.sin(out.T) + pt.cos(out) + + return [x], [out] + + @staticmethod + def diamond_graph(n): + a = pt.matrix("a") + b = pt.exp(a) + c = pt.log(b) + d = pt.sin(c) + e = c + d + return [a], [e] + @pytest.mark.parametrize( "case", [ @@ -1347,16 +1366,27 @@ def test_eval_benchmark(self, benchmark): benchmark(func) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_rewrite_benchmark(self, benchmark): - inps, outs = self.large_fuseable_graph(n=25) + @pytest.mark.parametrize( + "graph_fn, n, expected_n_repl", + [ + # ("diamond_graph", None, (1, 4)), + ("deep_small_kernels", 20, (20, 60)), + ("large_fuseable_graph", 25, (103, 876)), + ], + ) + def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): + inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) opt = FusionOptimizer() def rewrite_func(): - nb_replacement = opt.apply(fg.clone())[2] - return nb_replacement + fg_clone = fg.clone() + _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone) + # fg_clone.dprint() + return nb_fused, nb_replacement - assert benchmark(rewrite_func) == 103 + assert rewrite_func() == expected_n_repl + benchmark.pedantic(rewrite_func, rounds=7, iterations=5) def test_no_warning_from_old_client(self): # There used to be a warning issued when creating fuseable mapping From d3c5133fa40bf4b4b9ce625f8d37692ca377afff Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 20 Sep 2025 10:19:40 +0200 Subject: [PATCH 02/13] Short-circuit `as_scalar` common cases faster --- pytensor/scalar/basic.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 339da84cd1..26d242d3f0 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: - from pytensor.tensor.basic import scalar_from_tensor - from pytensor.tensor.type import TensorType + if isinstance(x, ScalarVariable): + return x + + if isinstance(x, Variable): + from pytensor.tensor.basic import scalar_from_tensor + from pytensor.tensor.type import TensorType + + if isinstance(x.type, TensorType) and x.type.ndim == 0: + return scalar_from_tensor(x) + else: + raise TypeError(f"Cannot convert {x} to a scalar type") if isinstance(x, Apply): + # FIXME: Why do we support calling this with Apply? + # Also, if we do, why can't we support multiple outputs? if len(x.outputs) != 1: raise ValueError( "It is ambiguous which output of a multi-output" " Op has to be fetched.", x, ) - else: - x = x.outputs[0] - if isinstance(x, Variable): - if isinstance(x, ScalarVariable): - return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: - return scalar_from_tensor(x) - else: - raise TypeError(f"Cannot convert {x} to a scalar type") + return as_scalar(x.outputs[0]) return constant(x) From 820d99d3882f02941f81b0ca9b183599d2c08f04 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 19 Sep 2025 01:01:55 +0200 Subject: [PATCH 03/13] Speedup supports c_code Not using `__call__` avoids the test_value computation --- pytensor/scalar/basic.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 26d242d3f0..f12449cfc4 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1332,32 +1332,26 @@ def supports_c_code(self, inputs, outputs): the given Elemwise inputs, outputs. """ - try: - tmp_s_input = [] - # To keep the same aliasing between inputs - mapping = dict() - for ii in inputs: - if ii in mapping: - tmp_s_input.append(mapping[ii]) - else: - tmp = get_scalar_type(ii.dtype).make_variable() - tmp_s_input.append(tmp) - mapping[ii] = tmp_s_input[-1] - - with config.change_flags(compute_test_value="ignore"): - s_op = self(*tmp_s_input, return_list=True) + tmp_s_input = [] + # To keep the same aliasing between inputs + mapping = {} + for ii in inputs: + if ii in mapping: + tmp_s_input.append(mapping[ii]) + else: + tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable() + tmp_s_input.append(tmp) - # if the scalar_op don't have a c implementation, - # we skip its fusion to allow the fusion of the - # other ops. + try: self.c_code( - s_op[0].owner, + self.make_node(*tmp_s_input), "test_presence_of_c_code", + # FIXME: Shouldn't this be a unique name per unique variable? ["x" for x in inputs], ["z" for z in outputs], {"fail": "%(fail)s"}, ) - except (MethodNotDefined, NotImplementedError): + except (NotImplementedError, MethodNotDefined): return False return True From 69237f3ad72cbb1686240f02f2ffe215e5e97260 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 13:33:36 +0200 Subject: [PATCH 04/13] Speedup FusionOptimizer.elemwise_to_scalar --- pytensor/scalar/basic.py | 8 ++-- pytensor/tensor/rewriting/elemwise.py | 55 +++++++++------------------ 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f12449cfc4..cbf7b73542 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -779,9 +779,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: This caches objects to save allocation and run time. """ - if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) - return cache[dtype] + try: + return cache[dtype] + except KeyError: + cache[dtype] = res = ScalarType(dtype=dtype) + return res # Register C code for ViewOp on Scalars. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 0eb2900729..1eb3d7c037 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -28,7 +28,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import ancestors, toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -530,43 +530,24 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): - replace_inputs = [(inp, inp.clone()) for inp in inputs] - outputs = clone_replace(outputs, replace=replace_inputs) - - inputs = [inp for _, inp in replace_inputs] - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) - middle_inputs = [] - - scalar_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs - ] - middle_scalar_inputs = [] - - for node in fg.toposort(): - node_scalar_inputs = [] - for inp in node.inputs: - if inp in inputs: - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) - elif inp in middle_inputs: - node_scalar_inputs.append( - middle_scalar_inputs[middle_inputs.index(inp)] + replacement = { + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + } + for node in toposort(outputs, blockers=inputs): + scalar_inputs = [replacement[inp] for inp in node.inputs] + replacement.update( + dict( + zip( + node.outputs, + node.op.scalar_op.make_node(*scalar_inputs).outputs, ) - else: - new_scalar_input = ps.get_scalar_type( - inp.type.dtype - ).make_variable() - node_scalar_inputs.append(new_scalar_input) - middle_scalar_inputs.append(new_scalar_input) - middle_inputs.append(inp) - - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) - middle_scalar_inputs.append(new_scalar_node.outputs[0]) - middle_inputs.append(node.outputs[0]) - - scalar_outputs = [ - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs - ] - return scalar_inputs, scalar_outputs + ) + ) + + return ( + [replacement[inp] for inp in inputs], + [replacement[out] for out in outputs], + ) def apply(self, fgraph): if fgraph.profile: From a39dc8b4921d851a121189b8ad0bb91cdd020af0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 20 Sep 2025 10:05:06 +0200 Subject: [PATCH 05/13] Avoid double cloning of Composite Ops created by FusionOptimizer --- pytensor/scalar/basic.py | 19 ++++++++++++------- pytensor/tensor/rewriting/elemwise.py | 13 +++++++------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index cbf7b73542..769a5dfeeb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,7 +13,6 @@ import builtins import math from collections.abc import Callable -from copy import copy from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs): + def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): # TODO: We could convert to TensorVariable, optimize graph, # and then convert back to ScalarVariable. # This would introduce rewrites like `log(1 + x) -> log1p`. - fgraph = FunctionGraph(copy(inputs), copy(outputs)) + fgraph = FunctionGraph(inputs, outputs, clone=clone) # Validate node types for node in fgraph.apply_nodes: @@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs, name="Composite"): + def __init__( + self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + ): self.name = name self._name = None # We need to clone the graph as sometimes its nodes already @@ -4299,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"): if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): - # No inner Composite - inputs, outputs = clone(inputs, outputs) + if clone_graph: + inputs, outputs = clone(inputs, outputs) + else: # Inner Composite that we need to flatten + # FIXME: There could be a composite in the middle of the graph, why is this here? + # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. assert len(outputs) == 1 # 1. Create a new graph from inputs up to the # Composite @@ -4321,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"): assert res[0] != inputs inputs, outputs = res[0], res2[1] - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 1eb3d7c037..42f4b6fc67 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -915,12 +915,13 @@ def update_fuseable_mappings_after_fg_replace( break scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) - composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( - *inputs - ) - if not isinstance(composite_outputs, list): - composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs, strict=True): + composite_outputs = Elemwise( + # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables + ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False) + )(*inputs, return_list=True) + assert len(outputs) == len(composite_outputs) + for old_out, composite_out in zip(outputs, composite_outputs): + # Preserve any names on the original outputs if old_out.name: composite_out.name = old_out.name From 9baa8a48676923337fa9c89bdaecdc7bb42651a7 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:06:06 +0200 Subject: [PATCH 06/13] Do not recompute toposort in every iteration of FusionOptimizer It's not really needed as we never expand on the new nodes --- pytensor/tensor/rewriting/elemwise.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 42f4b6fc67..689b47c28d 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -625,10 +625,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: def find_fuseable_subgraph( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: KT = TypeVar("KT") VT = TypeVar("VT", list, set) @@ -648,8 +648,7 @@ def variables_depend_on( for a in ancestors(variables, blockers=stop_search_at) ) - toposort = fg.toposort() - for starting_node in toposort: + for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -791,7 +790,7 @@ def variables_depend_on( and inp.owner not in visited_nodes ) ), - key=lambda inp: toposort.index(inp.owner), + key=lambda inp: toposort_index[inp.owner], reverse=True, ): fuseable_nodes_to_visit.appendleft(inp.owner) @@ -803,7 +802,7 @@ def variables_depend_on( for node in fuseable_clients_temp.get(next_out, ()) if node not in visited_nodes ), - key=lambda node: toposort.index(node), + key=lambda node: toposort_index[node], ): fuseable_nodes_to_visit.append(next_node) @@ -877,20 +876,22 @@ def update_fuseable_mappings_after_fg_replace( # client (those that don't fit into 1)) fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() + toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} while True: - starting_nodes = fg.apply_nodes.copy() try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, ) except ValueError: return else: # The caller is now expected to update fg in place, # by replacing the subgraph with a Composite Op + starting_nodes = fg.apply_nodes.copy() + yield subgraph_inputs, subgraph_outputs # This is where we avoid repeated work by using a stateful From 824af00f867706fc93dd44ef1db19076146dc36d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 07/13] Cleanup FusionOptimizer code --- pytensor/tensor/rewriting/elemwise.py | 164 ++++++++++++-------------- 1 file changed, 76 insertions(+), 88 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 689b47c28d..d37f04feb5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -5,7 +5,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce -from typing import TypeVar +from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -568,8 +568,7 @@ def find_next_fuseable_subgraph( This generator assumes that such subgraph is replaced by a single Elemwise Composite before being accessed again in the next iteration. """ - - FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] + FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] def initialize_fuseable_mappings( @@ -591,35 +590,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: # to ensure the rewrite remains deterministic. # This is not a problem from unfuseable ones, as they can never # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) + fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) for out, clients in fg.clients.items(): - # Old FunctionGraph nodes remain in the clients dictionary - # even after they are removed by rewrites - if not clients: - continue - out_maybe_fuseable = ( - out.owner + out.owner is not None and isinstance(out.owner.op, Elemwise) # and not isinstance(out.owner.op.scalar_op, ps.Composite) and len(out.owner.outputs) == 1 and elemwise_scalar_op_has_c_code(out.owner) ) - for client, _ in clients: - if ( - out_maybe_fuseable - and isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out.type.broadcastable - == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - if client not in fuseable_clients[out]: - fuseable_clients[out].append(client) - else: - unfuseable_clients[out].add(client) + if out_maybe_fuseable: + out_bcast = ( + out.type.broadcastable if out_maybe_fuseable else None + ) + for client, _ in clients: + if ( + isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, ps.Composite) + and len(client.outputs) == 1 + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + fuseable_clients[out].add(client) + else: + unfuseable_clients[out].add(client) + else: + unfuseable_clients[out] = {client for client, _ in clients} return fuseable_clients, unfuseable_clients @@ -630,16 +627,6 @@ def find_fuseable_subgraph( unfuseable_clients: UNFUSEABLE_MAPPING, toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - KT = TypeVar("KT") - VT = TypeVar("VT", list, set) - - def shallow_clone_defaultdict( - d: defaultdict[KT, VT], - ) -> defaultdict[KT, VT]: - new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory) - new_dict.update({k: v.copy() for k, v in d.items()}) - return new_dict - def variables_depend_on( variables, depend_on, stop_search_at=None ) -> bool: @@ -657,17 +644,19 @@ def variables_depend_on( visited_nodes.add(starting_node) continue - subgraph_inputs: list[Variable] = [] - subgraph_outputs: list[Variable] = [] + subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set + subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) - unfuseable_clients_clone = shallow_clone_defaultdict( - unfuseable_clients + fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) + fuseable_clients_clone.update( + {k: v.copy() for k, v in fuseable_clients.items()} + ) + unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) + unfuseable_clients_clone.update( + {k: v.copy() for k, v in unfuseable_clients.items()} ) - - fuseable_nodes_to_visit = deque([starting_node]) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -676,6 +665,7 @@ def variables_depend_on( # some inputs or clients may depend on other nodes of the same # subgraph via a path that cannot be included in the Composite # (unfuseable) + fuseable_nodes_to_visit = deque([starting_node]) while fuseable_nodes_to_visit: next_node = fuseable_nodes_to_visit.popleft() visited_nodes.add(next_node) @@ -684,15 +674,14 @@ def variables_depend_on( # If the output variable of next_node has no fuseable clients # or has unfuseable clients, then next_node must become an output # if it is to be fused. - must_become_output = ( - next_out not in fuseable_clients_temp - or next_out in unfuseable_clients_clone - ) + must_become_output = not fuseable_clients_clone.get( + next_out + ) or unfuseable_clients_clone.get(next_out) # We have backtracked to this node, and it may no longer be a viable output, # so we remove it and check again as if we had never seen this node - if must_become_output and next_out in subgraph_outputs: - subgraph_outputs.remove(next_out) + if must_become_output: + subgraph_outputs.pop(next_out, None) required_unfuseable_inputs = [ inp @@ -744,18 +733,19 @@ def variables_depend_on( if ( inp.owner in visited_nodes # next_node could have the same input repeated - and next_node in fuseable_clients_temp[inp] + and next_node in fuseable_clients_clone[inp] ): - fuseable_clients_temp[inp].remove(next_node) + fuseable_clients_clone[inp].remove(next_node) unfuseable_clients_clone[inp].add(next_node) # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. fuseable_nodes_to_visit.appendleft(inp.owner) - for client in fuseable_clients_temp[next_out]: + # need to convert to tuple not to change set size during iteration + for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_temp[next_out].remove(client) + fuseable_clients_clone[next_out].remove(client) unfuseable_clients_clone[next_out].add(client) # next_out must become an input of the subgraph. # We will revisit any of its clients currently @@ -771,74 +761,72 @@ def variables_depend_on( # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite for inp in new_required_unfuseable_inputs: - if inp not in subgraph_inputs: - subgraph_inputs.append(inp) + subgraph_inputs[inp] = None if must_become_output: - subgraph_outputs.append(next_out) + subgraph_outputs[next_out] = None unfuseable_clients_subgraph.update( new_implied_unfuseable_clients ) # Expand through unvisited fuseable ancestors - for inp in sorted( - ( - inp - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=lambda inp: toposort_index[inp.owner], - reverse=True, - ): - fuseable_nodes_to_visit.appendleft(inp.owner) + fuseable_nodes_to_visit.extendleft( + sorted( + ( + inp.owner + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + and inp.owner not in visited_nodes + ) + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Expand through unvisited fuseable clients - for next_node in sorted( - ( - node - for node in fuseable_clients_temp.get(next_out, ()) - if node not in visited_nodes - ), - key=lambda node: toposort_index[node], - ): - fuseable_nodes_to_visit.append(next_node) + fuseable_nodes_to_visit.extend( + sorted( + ( + node + for node in fuseable_clients_clone.get(next_out, ()) + if node not in visited_nodes + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Don't return if final subgraph is just the original Elemwise if len(subgraph_outputs) == 1 and set( - subgraph_outputs[0].owner.inputs + next(iter(subgraph_outputs)).owner.inputs ) == set(subgraph_inputs): # Update global fuseable mappings # No input was actually fuseable for inp in starting_node.inputs: - if starting_node in fuseable_clients.get(inp, ()): - fuseable_clients[inp].remove(starting_node) - unfuseable_clients[inp].add(starting_node) + fuseable_clients[inp].discard(starting_node) + unfuseable_clients[inp].add(starting_node) # No client was actually fuseable unfuseable_clients[starting_out].update( fuseable_clients.pop(starting_out, ()) ) continue - return subgraph_inputs, subgraph_outputs + return list(subgraph_inputs), list(subgraph_outputs) raise ValueError def update_fuseable_mappings_after_fg_replace( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, starting_nodes: set[Apply], + updated_nodes: set[Apply], ) -> None: # Find new composite node and dropped intermediate nodes # by comparing the current fg.apply nodes with the cached # original nodes - next_nodes = fg.apply_nodes - (new_composite_node,) = next_nodes - starting_nodes - dropped_nodes = starting_nodes - next_nodes + (new_composite_node,) = updated_nodes - starting_nodes + dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings for dropped_node in dropped_nodes: @@ -850,11 +838,11 @@ def update_fuseable_mappings_after_fg_replace( # Update fuseable information for subgraph inputs for inp in subgraph_inputs: if inp in fuseable_clients: - new_fuseable_clients = [ + new_fuseable_clients = { client for client in fuseable_clients[inp] if client not in dropped_nodes - ] + } if new_fuseable_clients: fuseable_clients[inp] = new_fuseable_clients else: @@ -898,11 +886,11 @@ def update_fuseable_mappings_after_fg_replace( # generator. For large models (as in `TestFusion.test_big_fusion`) # this can provide huge speedups update_fuseable_mappings_after_fg_replace( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, starting_nodes=starting_nodes, + updated_nodes=fg.apply_nodes, ) nb_fused = 0 From d73debfb45b55d280bcd326deec541dd2cfed648 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 08/13] Copy on write in FusionOptimizer --- pytensor/tensor/rewriting/elemwise.py | 82 ++++++++++++++++++++------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index d37f04feb5..7d65ce5f95 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,6 +2,7 @@ import itertools import operator import sys +import typing from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce @@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int: return 1024 +class CopyOnWriteDictOfSets: + __slots__ = ("d", "d_copy") + + def __init__(self, d: dict[typing.Any, set]): + self.d = d + self.d_copy: dict[typing.Any, set] = {} + + def __getitem__(self, key): + try: + return self.d_copy[key] + except KeyError: + return self.d[key] + + def get(self, key, default=frozenset()): + try: + return self.d_copy[key] + except KeyError: + try: + return self.d[key] + except KeyError: + return default + + def remove_from_key(self, key, value): + try: + self.d_copy[key].remove(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.remove(value) + + def add_to_key(self, key, value): + try: + self.d_copy[key].add(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.add(value) + + class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -648,15 +686,10 @@ def variables_depend_on( subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() - # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) - fuseable_clients_clone.update( - {k: v.copy() for k, v in fuseable_clients.items()} - ) - unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients_clone.update( - {k: v.copy() for k, v in unfuseable_clients.items()} - ) + # If we need to manipulate the maps in place, we'll do a shallow copy later + # For now we query on the original ones + fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) + unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -686,7 +719,7 @@ def variables_depend_on( required_unfuseable_inputs = [ inp for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp, ()) + if next_node in unfuseable_clients_clone.get(inp) ] new_required_unfuseable_inputs = [ inp @@ -709,7 +742,7 @@ def variables_depend_on( if not must_backtrack: implied_unfuseable_clients = { c - for client in unfuseable_clients_clone.get(next_out, ()) + for client in unfuseable_clients_clone.get(next_out) if not isinstance(client.op, Output) for c in client.outputs } @@ -730,13 +763,15 @@ def variables_depend_on( if must_backtrack: for inp in next_node.inputs: - if ( - inp.owner in visited_nodes - # next_node could have the same input repeated - and next_node in fuseable_clients_clone[inp] - ): - fuseable_clients_clone[inp].remove(next_node) - unfuseable_clients_clone[inp].add(next_node) + if inp.owner in visited_nodes: + if next_node not in fuseable_clients_clone[inp]: + # This can happen when next node has repeated inputs + continue + fuseable_clients_clone.remove_from_key( + inp, next_node + ) + unfuseable_clients_clone.add_to_key(inp, next_node) + # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. @@ -745,8 +780,13 @@ def variables_depend_on( # need to convert to tuple not to change set size during iteration for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_clone[next_out].remove(client) - unfuseable_clients_clone[next_out].add(client) + fuseable_clients_clone.remove_from_key( + next_out, client + ) + unfuseable_clients_clone.add_to_key( + next_out, client + ) + # next_out must become an input of the subgraph. # We will revisit any of its clients currently # in the subgraph to make sure this is safe. @@ -789,7 +829,7 @@ def variables_depend_on( sorted( ( node - for node in fuseable_clients_clone.get(next_out, ()) + for node in fuseable_clients_clone.get(next_out) if node not in visited_nodes ), key=toposort_index.get, # type: ignore[arg-type] From 514832e62a18124ec539838ff820cc40539ec109 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:57:02 +0200 Subject: [PATCH 09/13] Use bitset to check ancestors more efficiently --- pytensor/tensor/rewriting/elemwise.py | 139 +++++++++++++------------- tests/test_printing.py | 14 +-- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 7d65ce5f95..77cf934705 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -6,6 +6,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from operator import or_ from typing import Literal from warnings import warn @@ -29,7 +30,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors, toposort +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -663,16 +664,9 @@ def find_fuseable_subgraph( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + ancestors_bitset: dict[Apply, int], toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - def variables_depend_on( - variables, depend_on, stop_search_at=None - ) -> bool: - return any( - a in depend_on - for a in ancestors(variables, blockers=stop_search_at) - ) - for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -684,7 +678,8 @@ def variables_depend_on( subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - unfuseable_clients_subgraph: set[Variable] = set() + subgraph_inputs_ancestors_bitset = 0 + unfuseable_clients_subgraph_bitset = 0 # If we need to manipulate the maps in place, we'll do a shallow copy later # For now we query on the original ones @@ -716,50 +711,32 @@ def variables_depend_on( if must_become_output: subgraph_outputs.pop(next_out, None) - required_unfuseable_inputs = [ - inp - for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp) - ] - new_required_unfuseable_inputs = [ - inp - for inp in required_unfuseable_inputs - if inp not in subgraph_inputs - ] - - must_backtrack = False - if new_required_unfuseable_inputs and subgraph_outputs: - # We need to check that any new inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - if variables_depend_on( - [next_out], - depend_on=unfuseable_clients_subgraph, - stop_search_at=subgraph_outputs, - ): - must_backtrack = True + # We need to check that any inputs required by this node + # do not depend on other outputs of the current subgraph, + # via an unfuseable path. + must_backtrack = ( + ancestors_bitset[next_node] + & unfuseable_clients_subgraph_bitset + ) if not must_backtrack: - implied_unfuseable_clients = { - c - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - for c in client.outputs - } - - new_implied_unfuseable_clients = ( - implied_unfuseable_clients - unfuseable_clients_subgraph + implied_unfuseable_clients_bitset = reduce( + or_, + ( + 1 << toposort_index[client] + for client in unfuseable_clients_clone.get(next_out) + if not isinstance(client.op, Output) + ), + 0, ) - if new_implied_unfuseable_clients and subgraph_inputs: - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - if variables_depend_on( - subgraph_inputs, - depend_on=new_implied_unfuseable_clients, - ): - must_backtrack = True + # We need to check that any inputs of the current subgraph + # do not depend on other clients of this node, + # via an unfuseable path. + must_backtrack = ( + subgraph_inputs_ancestors_bitset + & implied_unfuseable_clients_bitset + ) if must_backtrack: for inp in next_node.inputs: @@ -800,29 +777,24 @@ def variables_depend_on( # immediate dependency problems. Update subgraph # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite - for inp in new_required_unfuseable_inputs: - subgraph_inputs[inp] = None - if must_become_output: subgraph_outputs[next_out] = None - unfuseable_clients_subgraph.update( - new_implied_unfuseable_clients + unfuseable_clients_subgraph_bitset |= ( + implied_unfuseable_clients_bitset ) - # Expand through unvisited fuseable ancestors - fuseable_nodes_to_visit.extendleft( - sorted( - ( - inp.owner - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) + for inp in sorted( + next_node.inputs, + key=lambda x: toposort_index.get(x.owner, -1), + ): + if next_node in unfuseable_clients_clone.get(inp, ()): + # input must become an input of the subgraph since it's unfuseable with new node + subgraph_inputs_ancestors_bitset |= ( + ancestors_bitset.get(inp.owner, 0) + ) + subgraph_inputs[inp] = None + elif inp.owner not in visited_nodes: + fuseable_nodes_to_visit.appendleft(inp.owner) # Expand through unvisited fuseable clients fuseable_nodes_to_visit.extend( @@ -859,6 +831,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], + ancestors_bitset: dict[Apply, int], starting_nodes: set[Apply], updated_nodes: set[Apply], ) -> None: @@ -869,11 +843,25 @@ def update_fuseable_mappings_after_fg_replace( dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings + # And compute the ancestors bitset of the new composite node + # As well as the new toposort index for the new node + new_node_ancestor_bitset = 0 + new_node_toposort_index = len(toposort_index) for dropped_node in dropped_nodes: (dropped_out,) = dropped_node.outputs fuseable_clients.pop(dropped_out, None) unfuseable_clients.pop(dropped_out, None) visited_nodes.remove(dropped_node) + # The new composite ancestor bitset is the union + # of the ancestors of all the dropped nodes + new_node_ancestor_bitset |= ancestors_bitset[dropped_node] + # The new composite node can have the same order as the latest node that was absorbed into it + new_node_toposort_index = max( + new_node_toposort_index, toposort_index[dropped_node] + ) + + ancestors_bitset[new_composite_node] = new_node_ancestor_bitset + toposort_index[new_composite_node] = new_node_toposort_index # Update fuseable information for subgraph inputs for inp in subgraph_inputs: @@ -905,12 +893,23 @@ def update_fuseable_mappings_after_fg_replace( fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} + # Create a bitset for each node of all its ancestors + # This allows to quickly check if a variable depends on a set + ancestors_bitset: dict[Apply, int] = {} + for node, index in toposort_index.items(): + node_ancestor_bitset = 1 << index + for inp in node.inputs: + if (inp_node := inp.owner) is not None: + node_ancestor_bitset |= ancestors_bitset[inp_node] + ancestors_bitset[node] = node_ancestor_bitset + while True: try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + ancestors_bitset=ancestors_bitset, toposort_index=toposort_index, ) except ValueError: @@ -929,6 +928,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, + ancestors_bitset=ancestors_bitset, starting_nodes=starting_nodes, updated_nodes=fg.apply_nodes, ) diff --git a/tests/test_printing.py b/tests/test_printing.py index 95c3c938cf..dbad8c063b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -301,7 +301,8 @@ def test_debugprint(): Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" exp_res = dedent( r""" - Composite{(i2 + (i0 - i1))} 4 + Composite{(i0 + (i1 - i2))} 4 + ├─ A ├─ ExpandDims{axis=0} v={0: [0]} 3 """ f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" @@ -313,17 +314,16 @@ def test_debugprint(): │ ├─ B │ ├─ │ └─ 0.0 - ├─ D - └─ A + └─ D Inner graphs: - Composite{(i2 + (i0 - i1))} + Composite{(i0 + (i1 - i2))} ← add 'o0' - ├─ i2 - └─ sub ├─ i0 - └─ i1 + └─ sub + ├─ i1 + └─ i2 """ ).lstrip() From eb010b718c245ec03c5c843b99c9ce3fb837e565 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 18 Sep 2025 09:36:13 +0200 Subject: [PATCH 10/13] Avoid backtracking in FusionOptimizer The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happen to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in https://github.com/pymc-devs/pytensor/issues/249 --- pytensor/tensor/rewriting/elemwise.py | 616 ++++++++++-------------- tests/tensor/rewriting/test_elemwise.py | 2 +- 2 files changed, 254 insertions(+), 364 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 77cf934705..f2f3957eb3 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,12 +2,10 @@ import itertools import operator import sys -import typing -from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from heapq import heapify, heappop, heappush from operator import or_ -from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int: return 1024 -class CopyOnWriteDictOfSets: - __slots__ = ("d", "d_copy") - - def __init__(self, d: dict[typing.Any, set]): - self.d = d - self.d_copy: dict[typing.Any, set] = {} - - def __getitem__(self, key): - try: - return self.d_copy[key] - except KeyError: - return self.d[key] - - def get(self, key, default=frozenset()): - try: - return self.d_copy[key] - except KeyError: - try: - return self.d[key] - except KeyError: - return default - - def remove_from_key(self, key, value): - try: - self.d_copy[key].remove(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.remove(value) - - def add_to_key(self, key, value): - try: - self.d_copy[key].add(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.add(value) - - class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -596,353 +557,282 @@ def apply(self, fgraph): max_operands = elemwise_max_operands_fct(None) - def find_next_fuseable_subgraph( + def find_fuseable_subgraphs( fg: FunctionGraph, - ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: - """Find all subgraphs in a FunctionGraph that can be fused together - - Yields - ------- - List of inputs and outputs that determine subgraphs which can be fused. - This generator assumes that such subgraph is replaced by a single - Elemwise Composite before being accessed again in the next iteration. + ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: + """Find subgraphs of Elemwise nodes that can be fused together. + + In general there is no single solution, we try to find large subgraphs eagerly + + Any two consecutive Elemwise nodes that have the same broadcasting pattern, + and a C-implementation (historical accident that should be revisited), are potentially fuseable. + + However, we need to be careful about keeping the fused subgraph "convex", meaning that no two + nodes in the same subgraph are connected via a path that goes outside the subgraph, either because they + are connected via unfuseable nodes, or nodes that have been claimed by another subgraph. + + For example the graph add(sin(exp(x)), sum(exp(x)) cannot be fused into a single Elemwise, because the sum + node breaks the convexity of the subgraph {exp, sin, add}. However, we can fuse {exp, sin}, + and perhaps fuse add with somethnig else. """ - FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - - def initialize_fuseable_mappings( - *, fg: FunctionGraph - ) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: - @cache - def elemwise_scalar_op_has_c_code(node: Apply) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - else: - if config.optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." - ) - return False - - # Fuseable nodes have to be accessed in a deterministic manner - # to ensure the rewrite remains deterministic. - # This is not a problem from unfuseable ones, as they can never - # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) - for out, clients in fg.clients.items(): - out_maybe_fuseable = ( - out.owner is not None - and isinstance(out.owner.op, Elemwise) - # and not isinstance(out.owner.op.scalar_op, ps.Composite) - and len(out.owner.outputs) == 1 - and elemwise_scalar_op_has_c_code(out.owner) + + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." ) - if out_maybe_fuseable: - out_bcast = ( - out.type.broadcastable if out_maybe_fuseable else None - ) - for client, _ in clients: - if ( - isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - fuseable_clients[out].add(client) - else: - unfuseable_clients[out].add(client) - else: - unfuseable_clients[out] = {client for client, _ in clients} - - return fuseable_clients, unfuseable_clients - - def find_fuseable_subgraph( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - ancestors_bitset: dict[Apply, int], - toposort_index: dict[Apply, int], - ) -> tuple[list[Variable], list[Variable]]: - for starting_node in toposort_index: - if starting_node in visited_nodes: - continue + return False + + fuseable_clients: dict[Apply, set[Apply]] = {} + candidate_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) + ): + continue - starting_out = starting_node.outputs[0] - if not fuseable_clients.get(starting_out): - visited_nodes.add(starting_node) - continue + candidate_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients + + if not fuseable_clients: + return None + + # Create a bitset of ancestors for each node. + # Each node is represented by a bit flag of it's position in the toposort + # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100} + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` + nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + ancestors_bitset = { + None: 0 + } # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None` + for node, node_bitflag in nodes_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit + ancestors_bitset[node] = reduce( + or_, + (ancestors_bitset[inp.owner] for inp in node.inputs), + node_bitflag, + ) + # handle root and leaf nodes gracefully + nodes_bitflags[None] = ( + 0 # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None` + ) + out_bitflag = 1 << len( + nodes_bitflags + ) # Nothing ever depends on output nodes, so just use a new bit for all + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + nodes_bitflags[client] = out_bitflag + + sorted_subgraphs: list[ + tuple[int, tuple[tuple[Variable], tuple[Variable]]] + ] = [] + all_subgraphs_bitset = 0 + # Start exploring from candidate sink nodes (backwards) + # These are Elemwise nodes with a C-implementation, that are not part of another subgraph + # And have no other fuseable clients (i.e., are sinks) + for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): + if ( + starting_bitflag & all_subgraphs_bitset + or starting_node not in candidate_nodes + ): + continue - subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_inputs_ancestors_bitset = 0 - unfuseable_clients_subgraph_bitset = 0 - - # If we need to manipulate the maps in place, we'll do a shallow copy later - # For now we query on the original ones - fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) - unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) - - # We now try to expand as much as possible towards the potentially - # fuseable clients and ancestors to detect the largest possible - # subgraph that can be Composed together into a single `Op`. The - # largest issue to watch out is for cyclical dependencies, where - # some inputs or clients may depend on other nodes of the same - # subgraph via a path that cannot be included in the Composite - # (unfuseable) - fuseable_nodes_to_visit = deque([starting_node]) - while fuseable_nodes_to_visit: - next_node = fuseable_nodes_to_visit.popleft() - visited_nodes.add(next_node) - next_out = next_node.outputs[0] - - # If the output variable of next_node has no fuseable clients - # or has unfuseable clients, then next_node must become an output - # if it is to be fused. - must_become_output = not fuseable_clients_clone.get( - next_out - ) or unfuseable_clients_clone.get(next_out) - - # We have backtracked to this node, and it may no longer be a viable output, - # so we remove it and check again as if we had never seen this node - if must_become_output: - subgraph_outputs.pop(next_out, None) - - # We need to check that any inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - must_backtrack = ( - ancestors_bitset[next_node] - & unfuseable_clients_subgraph_bitset - ) - - if not must_backtrack: - implied_unfuseable_clients_bitset = reduce( - or_, - ( - 1 << toposort_index[client] - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - ), - 0, - ) + if starting_node in fuseable_clients: + # Not a sink, + continue - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - must_backtrack = ( - subgraph_inputs_ancestors_bitset - & implied_unfuseable_clients_bitset - ) + # We keep an ordered queue for expanding the subgraph + # We always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # We negate the bitflag for ancestors to achieve this ordering + fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + heapify(fuseables_nodes_queue) + + # We keep 3 bitsets during the exploration: + # - the nodes that are part of the subgraph + # - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with any node in the subgraph) + # - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with any node in the subgraph) + # Whenever we visit a node, we check if unfuseable ancestors depend on it, or if it depends on an unfuseable client, + # in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective bitsets + # and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above. + # Otherwise, we need to recompute target bitsets in every iteration and/or backtrack. + subgraph_nodes = [] + subgraph_bitset = 0 + unfuseable_ancestors_bitset = 0 + unfuseable_clients_bitset = 0 + + # print(f"\nStarting new subgraph exploration from {starting_node}") + while fuseables_nodes_queue: + node_bitflag, node = heappop(fuseables_nodes_queue) + is_ancestor = node_bitflag < 0 + if is_ancestor: + node_bitflag = -node_bitflag + # print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}") + + if node_bitflag & subgraph_bitset: + # Already part of the subgraph + # print("\t - already in subgraph") + continue - if must_backtrack: - for inp in next_node.inputs: - if inp.owner in visited_nodes: - if next_node not in fuseable_clients_clone[inp]: - # This can happen when next node has repeated inputs - continue - fuseable_clients_clone.remove_from_key( - inp, next_node - ) - unfuseable_clients_clone.add_to_key(inp, next_node) - - # This input must become an output of the subgraph, - # because it can't be merged with next_node. - # We will revisit it to make sure this is safe. - fuseable_nodes_to_visit.appendleft(inp.owner) - - # need to convert to tuple not to change set size during iteration - for client in tuple(fuseable_clients_clone[next_out]): - if client in visited_nodes: - fuseable_clients_clone.remove_from_key( - next_out, client - ) - unfuseable_clients_clone.add_to_key( - next_out, client - ) - - # next_out must become an input of the subgraph. - # We will revisit any of its clients currently - # in the subgraph to make sure this is safe. - fuseable_nodes_to_visit.appendleft(client) - - # Revisit node at a later time - visited_nodes.remove(next_node) + if is_ancestor: + if node_bitflag & unfuseable_ancestors_bitset: + # An unfuseable ancestor depends on this node, can't fuse + # print("\t failed - unfuseable ancestor depends on it") continue + elif ancestors_bitset[node] & unfuseable_clients_bitset: + # This node depends on an unfuseable client, can't fuse + # print("\t failed - depends on unfuseable client") + continue - # Adding next_node to subgraph does not result in any - # immediate dependency problems. Update subgraph - # mappings as if it next_node was part of it. - # Useless inputs will be removed by the useless Composite rewrite - if must_become_output: - subgraph_outputs[next_out] = None - unfuseable_clients_subgraph_bitset |= ( - implied_unfuseable_clients_bitset + # print("\t succeeded - adding to subgraph") + subgraph_nodes.append(node) + subgraph_bitset |= node_bitflag + + # Expand through ancestors and client nodes + # A node can either be: + # - already part of the subgraph (skip) + # - fuseable (add to queue) + # - unfuseable (add to respective unfuseable bitset) + for ancestor in node.inputs: + ancestor_node = ancestor.owner + ancestor_bitflag = nodes_bitflags[ancestor_node] + if ancestor_bitflag & subgraph_bitset: + continue + if node in fuseable_clients.get(ancestor_node, ()): + heappush( + fuseables_nodes_queue, + (-ancestor_bitflag, ancestor_node), ) + else: + # If an ancestor is unfuseable, so are all its ancestors + unfuseable_ancestors_bitset |= ancestors_bitset[ + ancestor_node + ] + + next_fuseable_clients = fuseable_clients.get(node, ()) + for client, _ in fg_clients[node.outputs[0]]: + client_bitflag = nodes_bitflags[client] + if client_bitflag & subgraph_bitset: + continue + if client in next_fuseable_clients: + heappush(fuseables_nodes_queue, (client_bitflag, client)) + else: + # If a client is unfuseable, so are all its clients, but we don't need to keep track of those + # Any downstream client will also depend on this unfuseable client and will be rejected when visited + unfuseable_clients_bitset |= client_bitflag - for inp in sorted( - next_node.inputs, - key=lambda x: toposort_index.get(x.owner, -1), - ): - if next_node in unfuseable_clients_clone.get(inp, ()): - # input must become an input of the subgraph since it's unfuseable with new node - subgraph_inputs_ancestors_bitset |= ( - ancestors_bitset.get(inp.owner, 0) - ) - subgraph_inputs[inp] = None - elif inp.owner not in visited_nodes: - fuseable_nodes_to_visit.appendleft(inp.owner) - - # Expand through unvisited fuseable clients - fuseable_nodes_to_visit.extend( - sorted( - ( - node - for node in fuseable_clients_clone.get(next_out) - if node not in visited_nodes - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) - - # Don't return if final subgraph is just the original Elemwise - if len(subgraph_outputs) == 1 and set( - next(iter(subgraph_outputs)).owner.inputs - ) == set(subgraph_inputs): - # Update global fuseable mappings - # No input was actually fuseable - for inp in starting_node.inputs: - fuseable_clients[inp].discard(starting_node) - unfuseable_clients[inp].add(starting_node) - # No client was actually fuseable - unfuseable_clients[starting_out].update( - fuseable_clients.pop(starting_out, ()) - ) - continue + # Finished exploring this subgraph + all_subgraphs_bitset |= subgraph_bitset + + if subgraph_bitset == starting_bitflag: + # No fusion possible, single node subgraph + continue - return list(subgraph_inputs), list(subgraph_outputs) - raise ValueError - - def update_fuseable_mappings_after_fg_replace( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - toposort_index: dict[Apply, int], - ancestors_bitset: dict[Apply, int], - starting_nodes: set[Apply], - updated_nodes: set[Apply], - ) -> None: - # Find new composite node and dropped intermediate nodes - # by comparing the current fg.apply nodes with the cached - # original nodes - (new_composite_node,) = updated_nodes - starting_nodes - dropped_nodes = starting_nodes - updated_nodes - - # Remove intermediate Composite nodes from mappings - # And compute the ancestors bitset of the new composite node - # As well as the new toposort index for the new node - new_node_ancestor_bitset = 0 - new_node_toposort_index = len(toposort_index) - for dropped_node in dropped_nodes: - (dropped_out,) = dropped_node.outputs - fuseable_clients.pop(dropped_out, None) - unfuseable_clients.pop(dropped_out, None) - visited_nodes.remove(dropped_node) - # The new composite ancestor bitset is the union - # of the ancestors of all the dropped nodes - new_node_ancestor_bitset |= ancestors_bitset[dropped_node] - # The new composite node can have the same order as the latest node that was absorbed into it - new_node_toposort_index = max( - new_node_toposort_index, toposort_index[dropped_node] + # Find out inputs/outputs of subgraph_nodes + not_subgraph_bitset = ~subgraph_bitset + # Use a dict to deduplicate while preserving order + subgraph_inputs = tuple( + dict.fromkeys( + inp + for node in subgraph_nodes + for inp in node.inputs + if (ancestor_node := inp.owner) is None + or nodes_bitflags[ancestor_node] & not_subgraph_bitset ) + ) - ancestors_bitset[new_composite_node] = new_node_ancestor_bitset - toposort_index[new_composite_node] = new_node_toposort_index - - # Update fuseable information for subgraph inputs - for inp in subgraph_inputs: - if inp in fuseable_clients: - new_fuseable_clients = { - client - for client in fuseable_clients[inp] - if client not in dropped_nodes - } - if new_fuseable_clients: - fuseable_clients[inp] = new_fuseable_clients - else: - fuseable_clients.pop(inp) - unfuseable_clients[inp] = ( - unfuseable_clients[inp] - dropped_nodes - ) | {new_composite_node} - - # Update fuseable information for subgraph outputs - for out in new_composite_node.outputs: - unfuseable_clients[out] = {client for client, _ in fg.clients[out]} - - visited_nodes.add(new_composite_node) - return - - # We start by creating two maps, 1) from each node to each potentially - # fuseable client (both nodes must be single output Elemwise with same - # broadcast type) and 2) from each node to each certainly unfuseable - # client (those that don't fit into 1)) - fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) - visited_nodes: set[Apply] = set() - toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} - # Create a bitset for each node of all its ancestors - # This allows to quickly check if a variable depends on a set - ancestors_bitset: dict[Apply, int] = {} - for node, index in toposort_index.items(): - node_ancestor_bitset = 1 << index - for inp in node.inputs: - if (inp_node := inp.owner) is not None: - node_ancestor_bitset |= ancestors_bitset[inp_node] - ancestors_bitset[node] = node_ancestor_bitset - - while True: - try: - subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - ancestors_bitset=ancestors_bitset, - toposort_index=toposort_index, + subgraph_outputs = tuple( + node.outputs[0] + for node in subgraph_nodes + if any( + nodes_bitflags[client] & not_subgraph_bitset + for client, _ in fg_clients[node.outputs[0]] + ) + ) + + # print(f"Found subgraph with {len(subgraph_inputs)} inputs, {len(subgraph_outputs)} outputs, and {len(subgraph_nodes)} nodes (subgraph_bitset={bin(subgraph_bitset)})") + # FunctionGraph(list(subgraph_inputs), subgraph_outputs).dprint() + + # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end + # But in some cases they can, so we need to insert at the right position. + if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): + sorted_subgraphs.append( + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) ) - except ValueError: - return else: - # The caller is now expected to update fg in place, - # by replacing the subgraph with a Composite Op - starting_nodes = fg.apply_nodes.copy() - - yield subgraph_inputs, subgraph_outputs - - # This is where we avoid repeated work by using a stateful - # generator. For large models (as in `TestFusion.test_big_fusion`) - # this can provide huge speedups - update_fuseable_mappings_after_fg_replace( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - toposort_index=toposort_index, - ancestors_bitset=ancestors_bitset, - starting_nodes=starting_nodes, - updated_nodes=fg.apply_nodes, + # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph + # no longer depends on what's left. This tells us where to insert the current subgraph. + remaining_subgraphs_bitset = all_subgraphs_bitset + for index, (other_subgraph_bitset, _) in enumerate( + reversed(sorted_subgraphs) + ): + remaining_subgraphs_bitset &= ~other_subgraph_bitset + if not ( + unfuseable_ancestors_bitset & remaining_subgraphs_bitset + ): + break + + sorted_subgraphs.insert( + -(index + 1), + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), ) + # Update fuseable clients, inputs can no longer be fused with graph variables + # and outputs can't be fused with anything else + for ancestor in subgraph_inputs: + if (ancestor_node := ancestor.owner) is not None: + if ancestor_fuseable_clients := fuseable_clients.get( + ancestor_node + ): + ancestor_fuseable_clients.difference_update(subgraph_nodes) + if not ancestor_fuseable_clients: + del fuseable_clients[ancestor_node] + + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + yield from (io for _, io in sorted_subgraphs) + nb_fused = 0 nb_replacement = 0 - for inputs, outputs in find_next_fuseable_subgraph(fgraph): + for inputs, outputs in find_fuseable_subgraphs(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( "Loop fusion failed because the resulting node would exceed " "the kernel argument limit." ) - break + continue scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( @@ -957,7 +847,7 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - list(zip(outputs, composite_outputs, strict=True)), + tuple(zip(outputs, composite_outputs)), reason=self.__class__.__name__, ) nb_fused += 1 diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 3c549788e1..7e625043ec 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1371,7 +1371,7 @@ def test_eval_benchmark(self, benchmark): [ # ("diamond_graph", None, (1, 4)), ("deep_small_kernels", 20, (20, 60)), - ("large_fuseable_graph", 25, (103, 876)), + ("large_fuseable_graph", 25, (128, 876)), ], ) def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): From 5b9baec69fdc768c4b94588a9a40de10063af3a9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 27 Sep 2025 13:43:22 +0200 Subject: [PATCH 11/13] Regular sets instead of bitsets --- pytensor/tensor/rewriting/elemwise.py | 118 ++++++++++---------------- 1 file changed, 46 insertions(+), 72 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index f2f3957eb3..c4d67dd2ea 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -5,7 +5,6 @@ from collections.abc import Generator, Sequence from functools import cache, reduce from heapq import heapify, heappop, heappush -from operator import or_ from warnings import warn import pytensor.scalar.basic as ps @@ -16,7 +15,7 @@ from pytensor.graph.basic import Apply, Variable from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import ReplaceValidate -from pytensor.graph.fg import FunctionGraph, Output +from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import ( GraphRewriter, @@ -621,48 +620,28 @@ def elemwise_scalar_op_has_c_code( if not fuseable_clients: return None - # Create a bitset of ancestors for each node. - # Each node is represented by a bit flag of it's position in the toposort - # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, - # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100} - # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} - # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND - # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0` - # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do - # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` - nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} - ancestors_bitset = { - None: 0 - } # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None` - for node, node_bitflag in nodes_bitflags.items(): - # The bitset of each node is the union of the bitsets of its inputs, plus its own bit - ancestors_bitset[node] = reduce( - or_, - (ancestors_bitset[inp.owner] for inp in node.inputs), - node_bitflag, + toposort_idx = { + node: idx for idx, node in enumerate(fg.toposort(), start=1) + } + node_ancestors = {None: frozenset()} + for node in toposort_idx: + node_ancestors[node] = frozenset.union( + *(node_ancestors[inp.owner] for inp in node.inputs), {node} ) - # handle root and leaf nodes gracefully - nodes_bitflags[None] = ( - 0 # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None` - ) - out_bitflag = 1 << len( - nodes_bitflags - ) # Nothing ever depends on output nodes, so just use a new bit for all - for out in fg.outputs: - for client, _ in fg_clients[out]: - if isinstance(client.op, Output): - nodes_bitflags[client] = out_bitflag sorted_subgraphs: list[ tuple[int, tuple[tuple[Variable], tuple[Variable]]] ] = [] - all_subgraphs_bitset = 0 + subgraph_set = set() + unfuseable_ancestors_set = set() + unfuseable_clients_set = set() + all_subgraphs_set = set() # Start exploring from candidate sink nodes (backwards) # These are Elemwise nodes with a C-implementation, that are not part of another subgraph # And have no other fuseable clients (i.e., are sinks) - for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): + for starting_node, starting_index in reversed(toposort_idx.items()): if ( - starting_bitflag & all_subgraphs_bitset + starting_node in all_subgraphs_set or starting_node not in candidate_nodes ): continue @@ -676,7 +655,7 @@ def elemwise_scalar_op_has_c_code( # For ancestors, we want to visit the later nodes first (those that have more dependencies) # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) # We negate the bitflag for ancestors to achieve this ordering - fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + fuseables_nodes_queue = [(-starting_index, starting_node)] heapify(fuseables_nodes_queue) # We keep 3 bitsets during the exploration: @@ -687,37 +666,35 @@ def elemwise_scalar_op_has_c_code( # in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective bitsets # and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above. # Otherwise, we need to recompute target bitsets in every iteration and/or backtrack. - subgraph_nodes = [] - subgraph_bitset = 0 - unfuseable_ancestors_bitset = 0 - unfuseable_clients_bitset = 0 + subgraph_set.clear() + unfuseable_ancestors_set.clear() + unfuseable_clients_set.clear() # print(f"\nStarting new subgraph exploration from {starting_node}") while fuseables_nodes_queue: - node_bitflag, node = heappop(fuseables_nodes_queue) - is_ancestor = node_bitflag < 0 + node_idx, node = heappop(fuseables_nodes_queue) + is_ancestor = node_idx < 0 if is_ancestor: - node_bitflag = -node_bitflag + node_idx = -node_idx # print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}") - if node_bitflag & subgraph_bitset: + if node in subgraph_set: # Already part of the subgraph # print("\t - already in subgraph") continue if is_ancestor: - if node_bitflag & unfuseable_ancestors_bitset: + if node in unfuseable_ancestors_set: # An unfuseable ancestor depends on this node, can't fuse # print("\t failed - unfuseable ancestor depends on it") continue - elif ancestors_bitset[node] & unfuseable_clients_bitset: + elif not node_ancestors[node].isdisjoint(unfuseable_clients_set): # This node depends on an unfuseable client, can't fuse # print("\t failed - depends on unfuseable client") continue # print("\t succeeded - adding to subgraph") - subgraph_nodes.append(node) - subgraph_bitset |= node_bitflag + subgraph_set.add(node) # Expand through ancestors and client nodes # A node can either be: @@ -726,49 +703,48 @@ def elemwise_scalar_op_has_c_code( # - unfuseable (add to respective unfuseable bitset) for ancestor in node.inputs: ancestor_node = ancestor.owner - ancestor_bitflag = nodes_bitflags[ancestor_node] - if ancestor_bitflag & subgraph_bitset: + if ancestor_node in subgraph_set: continue if node in fuseable_clients.get(ancestor_node, ()): heappush( fuseables_nodes_queue, - (-ancestor_bitflag, ancestor_node), + (-toposort_idx[ancestor_node], ancestor_node), ) else: # If an ancestor is unfuseable, so are all its ancestors - unfuseable_ancestors_bitset |= ancestors_bitset[ - ancestor_node - ] + unfuseable_ancestors_set |= node_ancestors[ancestor_node] next_fuseable_clients = fuseable_clients.get(node, ()) for client, _ in fg_clients[node.outputs[0]]: - client_bitflag = nodes_bitflags[client] - if client_bitflag & subgraph_bitset: + if client in subgraph_set: continue if client in next_fuseable_clients: - heappush(fuseables_nodes_queue, (client_bitflag, client)) + heappush( + fuseables_nodes_queue, (toposort_idx[client], client) + ) else: # If a client is unfuseable, so are all its clients, but we don't need to keep track of those # Any downstream client will also depend on this unfuseable client and will be rejected when visited - unfuseable_clients_bitset |= client_bitflag + unfuseable_clients_set.add(client) # Finished exploring this subgraph - all_subgraphs_bitset |= subgraph_bitset + all_subgraphs_set |= subgraph_set - if subgraph_bitset == starting_bitflag: + if len(subgraph_set) == 1: # No fusion possible, single node subgraph continue # Find out inputs/outputs of subgraph_nodes - not_subgraph_bitset = ~subgraph_bitset + # not_subgraph_bitset = ~subgraph_set # Use a dict to deduplicate while preserving order + subgraph_nodes = sorted(subgraph_set, key=toposort_idx.get) subgraph_inputs = tuple( dict.fromkeys( inp for node in subgraph_nodes for inp in node.inputs if (ancestor_node := inp.owner) is None - or nodes_bitflags[ancestor_node] & not_subgraph_bitset + or ancestor_node not in subgraph_set ) ) @@ -776,7 +752,7 @@ def elemwise_scalar_op_has_c_code( node.outputs[0] for node in subgraph_nodes if any( - nodes_bitflags[client] & not_subgraph_bitset + client not in subgraph_set for client, _ in fg_clients[node.outputs[0]] ) ) @@ -786,26 +762,24 @@ def elemwise_scalar_op_has_c_code( # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end # But in some cases they can, so we need to insert at the right position. - if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): + if not (unfuseable_ancestors_set & all_subgraphs_set): sorted_subgraphs.append( - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) + (subgraph_set, (subgraph_inputs, subgraph_outputs)) ) else: # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph # no longer depends on what's left. This tells us where to insert the current subgraph. - remaining_subgraphs_bitset = all_subgraphs_bitset - for index, (other_subgraph_bitset, _) in enumerate( + remaining_subgraphs_bitset = all_subgraphs_set.copy() + for index, (other_subgraph_set, _) in enumerate( reversed(sorted_subgraphs) ): - remaining_subgraphs_bitset &= ~other_subgraph_bitset - if not ( - unfuseable_ancestors_bitset & remaining_subgraphs_bitset - ): + remaining_subgraphs_bitset.difference_update(other_subgraph_set) + if not (unfuseable_ancestors_set & remaining_subgraphs_bitset): break sorted_subgraphs.insert( -(index + 1), - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), + (subgraph_set, (subgraph_inputs, subgraph_outputs)), ) # Update fuseable clients, inputs can no longer be fused with graph variables From 98204161e026f48b63f64e5d5fc7d46ffa7bd2f9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 27 Sep 2025 21:49:37 +0200 Subject: [PATCH 12/13] Use helper classes for readability --- pytensor/tensor/rewriting/elemwise.py | 421 +++++++++++++++----------- 1 file changed, 248 insertions(+), 173 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index c4d67dd2ea..2d676d2c4a 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -575,47 +575,219 @@ def find_fuseable_subgraphs( and perhaps fuse add with somethnig else. """ - @cache - def elemwise_scalar_op_has_c_code( - node: Apply, optimizer_verbose=config.optimizer_verbose - ) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - elif optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." + class FuseableClients: + __slots__ = ("fuseable_clients", "candidate_nodes") + + def __init__(self, fgraph): + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." + ) + return False + + self.fuseable_clients = fuseable_clients = {} + self.candidate_nodes = candidate_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) + ): + continue + + candidate_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients + + def __bool__(self): + return bool(self.fuseable_clients) + + def __getitem__(self, node: Apply): + return self.fuseable_clients.get(node, ()) + + def is_sink_node(self, node: Apply) -> bool: + # A sink node is a candidate node that has no fuseable clients + return ( + node in self.candidate_nodes + and node not in self.fuseable_clients ) - return False - - fuseable_clients: dict[Apply, set[Apply]] = {} - candidate_nodes = set() - fg_clients = fg.clients - for out, clients_and_indices in fg_clients.items(): - out_node = out.owner - - if not ( - out_node is not None - and len(out_node.outputs) == 1 - and isinstance(out_node.op, Elemwise) - and elemwise_scalar_op_has_c_code(out_node) - ): - continue - candidate_nodes.add(out_node) - out_bcast = out.type.broadcastable - out_fuseable_clients = { - client - for client, _ in clients_and_indices - if ( - len(client.outputs) == 1 - and isinstance(client.op, Elemwise) - and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) + def remove_subgraph_connections(self, subgraph: "ConvexSubgraph"): + # Update fuseable clients, inputs can no longer be fused with graph variables + # and outputs can't be fused with anything else + subgraph_inputs, subgraph_outputs = ( + subgraph.get_inputs_and_outputs() ) - } - if out_fuseable_clients: - fuseable_clients[out_node] = out_fuseable_clients + fuseable_clients = self.fuseable_clients + for ancestor in subgraph_inputs: + if (ancestor_node := ancestor.owner) is not None: + if ancestor_fuseable_clients := fuseable_clients.get( + ancestor_node + ): + ancestor_fuseable_clients.difference_update( + subgraph.nodes + ) + if not ancestor_fuseable_clients: + del fuseable_clients[ancestor_node] + + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + class SortedFuseableNodesQueue: + __slots__ = ("queue",) + + def __init__(self): + # We keep an ordered queue for expanding the subgraph + # We always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # We negate the bitflag for ancestors to achieve this ordering + self.queue = queue = [] + heapify(queue) + + def push(self, node: Apply, toposort_index: int, is_ancestor: bool): + if is_ancestor: + toposort_index = -toposort_index + heappush(self.queue, (toposort_index, node)) + + def pop(self) -> tuple[Apply, bool]: + toposort_index, node = heappop(self.queue) + return node, toposort_index < 0 + + def __bool__(self): + return bool(self.queue) + + class NonConvexError(Exception): + pass + + class ConvexSubgraph: + __slots__ = ( + "node_ancestors", + "nodes", + "bitset", + "unfuseable_ancestors", + "unfuseable_clients", + "inputs_and_outputs", + ) + + def __init__(self, node_ancestors): + self.node_ancestors = node_ancestors + self.nodes = {} + self.bitset = 0 + self.unfuseable_ancestors = set() + self.unfuseable_clients = set() + self.inputs_and_outputs = None + + def __len__(self): + return len(self.nodes) + + def __contains__(self, node: Apply): + return node in self.nodes + + def add_node(self, node: Apply, is_ancestor: bool): + if is_ancestor: + if node in self.unfuseable_ancestors: + raise NonConvexError + elif self.node_ancestors[node] & self.unfuseable_clients: + raise NonConvexError + self.nodes[node] = None + self.inputs_and_outputs = None # clear cache + + def add_unfuseable_ancestor(self, ancestor: Apply): + # If an ancestor is unfuseable, so are all its ancestors + self.unfuseable_ancestors |= self.node_ancestors[ancestor] + + def add_unfuseable_client(self, client: Apply): + # If a client is unfuseable, so are all its clients, but we don't need to keep track of those + # Any downstream client will also depend on this unfuseable client and will be rejected when visited + self.unfuseable_clients.add(client) + + def get_inputs_and_outputs(self): + if self.inputs_and_outputs is not None: + return self.inputs_and_outputs + + nodes = self.nodes + # Use a dict to deduplicate while preserving order + subgraph_inputs = tuple( + dict.fromkeys( + inp + for node in nodes + for inp in node.inputs + if (ancestor_node := inp.owner) is None + or ancestor_node not in nodes + ) + ) + + subgraph_outputs = tuple( + node.outputs[0] + for node in nodes + if any( + client not in nodes + for client, _ in fg_clients[node.outputs[0]] + ) + ) + self.inputs_and_outputs = subgraph_inputs, subgraph_outputs + return subgraph_inputs, subgraph_outputs + + class SortedSubgraphCollection: + __slots__ = ("subgraphs", "nodes") + + def __init__(self): + self.subgraphs: list[ + tuple[int, tuple[tuple[Variable], tuple[Variable]]] + ] = [] + self.nodes = {} + + def __contains__(self, node: Apply): + return node in self.nodes + + def insert_subgraph(self, subgraph: ConvexSubgraph): + # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end + # But in some cases they can, so we need to insert at the right position. + subgraph_unfuseable_ancestors = subgraph.unfuseable_ancestors + if subgraph_unfuseable_ancestors.isdisjoint(self.nodes): + self.subgraphs.append(subgraph) + else: + # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph + # no longer depends on what's left. This tells us where to insert the current subgraph. + remaining_nodes = set(self.nodes) + for index, other_subgraph in enumerate( + reversed(self.subgraphs) + ): + remaining_nodes.difference_update(other_subgraph.nodes) + if subgraph_unfuseable_ancestors.isdisjoint( + remaining_nodes + ): + break + self.subgraphs.insert(-(index + 1), subgraph) + self.nodes |= subgraph.nodes + + def __iter__(self): + yield from self.subgraphs + + fuseable_clients = FuseableClients(fgraph) if not fuseable_clients: return None @@ -629,73 +801,34 @@ def elemwise_scalar_op_has_c_code( *(node_ancestors[inp.owner] for inp in node.inputs), {node} ) - sorted_subgraphs: list[ - tuple[int, tuple[tuple[Variable], tuple[Variable]]] - ] = [] - subgraph_set = set() - unfuseable_ancestors_set = set() - unfuseable_clients_set = set() - all_subgraphs_set = set() + fg_clients = fgraph.clients + sorted_subgraphs = SortedSubgraphCollection() + # Start exploring from candidate sink nodes (backwards) # These are Elemwise nodes with a C-implementation, that are not part of another subgraph # And have no other fuseable clients (i.e., are sinks) - for starting_node, starting_index in reversed(toposort_idx.items()): + for starting_node, starting_idx in reversed(toposort_idx.items()): if ( - starting_node in all_subgraphs_set - or starting_node not in candidate_nodes + starting_node in sorted_subgraphs + or not fuseable_clients.is_sink_node(starting_node) ): continue - if starting_node in fuseable_clients: - # Not a sink, - continue + subgraph = ConvexSubgraph(node_ancestors) - # We keep an ordered queue for expanding the subgraph - # We always want to visit ancestors before clients - # For ancestors, we want to visit the later nodes first (those that have more dependencies) - # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) - # We negate the bitflag for ancestors to achieve this ordering - fuseables_nodes_queue = [(-starting_index, starting_node)] - heapify(fuseables_nodes_queue) - - # We keep 3 bitsets during the exploration: - # - the nodes that are part of the subgraph - # - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with any node in the subgraph) - # - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with any node in the subgraph) - # Whenever we visit a node, we check if unfuseable ancestors depend on it, or if it depends on an unfuseable client, - # in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective bitsets - # and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above. - # Otherwise, we need to recompute target bitsets in every iteration and/or backtrack. - subgraph_set.clear() - unfuseable_ancestors_set.clear() - unfuseable_clients_set.clear() - - # print(f"\nStarting new subgraph exploration from {starting_node}") - while fuseables_nodes_queue: - node_idx, node = heappop(fuseables_nodes_queue) - is_ancestor = node_idx < 0 - if is_ancestor: - node_idx = -node_idx - # print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}") + fuseable_nodes_queue = SortedFuseableNodesQueue() + fuseable_nodes_queue.push(starting_node, starting_idx, is_ancestor=True) + while fuseable_nodes_queue: + node, is_ancestor = fuseable_nodes_queue.pop() - if node in subgraph_set: - # Already part of the subgraph - # print("\t - already in subgraph") + if node in subgraph: continue - if is_ancestor: - if node in unfuseable_ancestors_set: - # An unfuseable ancestor depends on this node, can't fuse - # print("\t failed - unfuseable ancestor depends on it") - continue - elif not node_ancestors[node].isdisjoint(unfuseable_clients_set): - # This node depends on an unfuseable client, can't fuse - # print("\t failed - depends on unfuseable client") + try: + subgraph.add_node(node, is_ancestor=is_ancestor) + except NonConvexError: continue - # print("\t succeeded - adding to subgraph") - subgraph_set.add(node) - # Expand through ancestors and client nodes # A node can either be: # - already part of the subgraph (skip) @@ -703,100 +836,42 @@ def elemwise_scalar_op_has_c_code( # - unfuseable (add to respective unfuseable bitset) for ancestor in node.inputs: ancestor_node = ancestor.owner - if ancestor_node in subgraph_set: + if ancestor_node in subgraph: continue - if node in fuseable_clients.get(ancestor_node, ()): - heappush( - fuseables_nodes_queue, - (-toposort_idx[ancestor_node], ancestor_node), + if node in fuseable_clients[ancestor_node]: + fuseable_nodes_queue.push( + ancestor_node, + toposort_idx[ancestor_node], + is_ancestor=True, ) else: - # If an ancestor is unfuseable, so are all its ancestors - unfuseable_ancestors_set |= node_ancestors[ancestor_node] + subgraph.add_unfuseable_ancestor(ancestor_node) - next_fuseable_clients = fuseable_clients.get(node, ()) - for client, _ in fg_clients[node.outputs[0]]: - if client in subgraph_set: + next_fuseable_clients = fuseable_clients[node] + for client_node, _ in fg_clients[node.outputs[0]]: + if client_node in subgraph: continue - if client in next_fuseable_clients: - heappush( - fuseables_nodes_queue, (toposort_idx[client], client) + if client_node in next_fuseable_clients: + fuseable_nodes_queue.push( + client_node, + toposort_idx[client_node], + is_ancestor=False, ) else: - # If a client is unfuseable, so are all its clients, but we don't need to keep track of those - # Any downstream client will also depend on this unfuseable client and will be rejected when visited - unfuseable_clients_set.add(client) + subgraph.add_unfuseable_client(client_node) # Finished exploring this subgraph - all_subgraphs_set |= subgraph_set - - if len(subgraph_set) == 1: + if len(subgraph) == 1: # No fusion possible, single node subgraph continue - # Find out inputs/outputs of subgraph_nodes - # not_subgraph_bitset = ~subgraph_set - # Use a dict to deduplicate while preserving order - subgraph_nodes = sorted(subgraph_set, key=toposort_idx.get) - subgraph_inputs = tuple( - dict.fromkeys( - inp - for node in subgraph_nodes - for inp in node.inputs - if (ancestor_node := inp.owner) is None - or ancestor_node not in subgraph_set - ) - ) - - subgraph_outputs = tuple( - node.outputs[0] - for node in subgraph_nodes - if any( - client not in subgraph_set - for client, _ in fg_clients[node.outputs[0]] - ) - ) - - # print(f"Found subgraph with {len(subgraph_inputs)} inputs, {len(subgraph_outputs)} outputs, and {len(subgraph_nodes)} nodes (subgraph_bitset={bin(subgraph_bitset)})") - # FunctionGraph(list(subgraph_inputs), subgraph_outputs).dprint() + sorted_subgraphs.insert_subgraph(subgraph) + # Mark the nodes of this subgraph as no longer fuseable + fuseable_clients.remove_subgraph_connections(subgraph) - # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end - # But in some cases they can, so we need to insert at the right position. - if not (unfuseable_ancestors_set & all_subgraphs_set): - sorted_subgraphs.append( - (subgraph_set, (subgraph_inputs, subgraph_outputs)) - ) - else: - # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph - # no longer depends on what's left. This tells us where to insert the current subgraph. - remaining_subgraphs_bitset = all_subgraphs_set.copy() - for index, (other_subgraph_set, _) in enumerate( - reversed(sorted_subgraphs) - ): - remaining_subgraphs_bitset.difference_update(other_subgraph_set) - if not (unfuseable_ancestors_set & remaining_subgraphs_bitset): - break - - sorted_subgraphs.insert( - -(index + 1), - (subgraph_set, (subgraph_inputs, subgraph_outputs)), - ) - - # Update fuseable clients, inputs can no longer be fused with graph variables - # and outputs can't be fused with anything else - for ancestor in subgraph_inputs: - if (ancestor_node := ancestor.owner) is not None: - if ancestor_fuseable_clients := fuseable_clients.get( - ancestor_node - ): - ancestor_fuseable_clients.difference_update(subgraph_nodes) - if not ancestor_fuseable_clients: - del fuseable_clients[ancestor_node] - - for out in subgraph_outputs: - fuseable_clients.pop(out.owner, None) - - yield from (io for _, io in sorted_subgraphs) + yield from ( + subgraph.get_inputs_and_outputs() for subgraph in sorted_subgraphs + ) nb_fused = 0 nb_replacement = 0 From f2683c91e758da9d727e0a561c56801270186f9e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 28 Sep 2025 00:16:49 +0200 Subject: [PATCH 13/13] Try helper classes with bitset --- pytensor/tensor/rewriting/elemwise.py | 124 ++++++++++++++++---------- 1 file changed, 77 insertions(+), 47 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 2d676d2c4a..192d289af5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -5,6 +5,7 @@ from collections.abc import Generator, Sequence from functools import cache, reduce from heapq import heapify, heappop, heappush +from operator import or_ from warnings import warn import pytensor.scalar.basic as ps @@ -15,7 +16,7 @@ from pytensor.graph.basic import Apply, Variable from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import ReplaceValidate -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import ( GraphRewriter, @@ -667,14 +668,14 @@ def __init__(self): self.queue = queue = [] heapify(queue) - def push(self, node: Apply, toposort_index: int, is_ancestor: bool): + def push(self, node: Apply, node_bitflag: int, is_ancestor: bool): if is_ancestor: - toposort_index = -toposort_index - heappush(self.queue, (toposort_index, node)) + node_bitflag = -node_bitflag + heappush(self.queue, (node_bitflag, node)) - def pop(self) -> tuple[Apply, bool]: - toposort_index, node = heappop(self.queue) - return node, toposort_index < 0 + def pop(self) -> tuple[Apply, int, bool]: + node_bitflag, node = heappop(self.queue) + return node, node_bitflag < 0 def __bool__(self): return bool(self.queue) @@ -684,45 +685,49 @@ class NonConvexError(Exception): class ConvexSubgraph: __slots__ = ( - "node_ancestors", + "nodes_bitflags", + "ancestors_bitset", "nodes", - "bitset", - "unfuseable_ancestors", - "unfuseable_clients", + "nodes_bitset", + "unfuseable_ancestors_bitset", + "unfuseable_clients_bitset", "inputs_and_outputs", ) - def __init__(self, node_ancestors): - self.node_ancestors = node_ancestors + def __init__(self, nodes_bitflags, ancestors_bitset): + self.nodes_bitflags = nodes_bitflags + self.ancestors_bitset = ancestors_bitset self.nodes = {} - self.bitset = 0 - self.unfuseable_ancestors = set() - self.unfuseable_clients = set() + self.nodes_bitset = 0 + self.unfuseable_ancestors_bitset = 0 + self.unfuseable_clients_bitset = 0 self.inputs_and_outputs = None def __len__(self): return len(self.nodes) - def __contains__(self, node: Apply): - return node in self.nodes + def __contains__(self, node: int): + return bool(self.nodes_bitset & self.nodes_bitflags[node]) def add_node(self, node: Apply, is_ancestor: bool): + node_bitflag = self.nodes_bitflags[node] if is_ancestor: - if node in self.unfuseable_ancestors: + if node_bitflag & self.unfuseable_ancestors_bitset: raise NonConvexError - elif self.node_ancestors[node] & self.unfuseable_clients: + elif self.ancestors_bitset[node] & self.unfuseable_clients_bitset: raise NonConvexError + self.nodes_bitset |= node_bitflag self.nodes[node] = None self.inputs_and_outputs = None # clear cache def add_unfuseable_ancestor(self, ancestor: Apply): # If an ancestor is unfuseable, so are all its ancestors - self.unfuseable_ancestors |= self.node_ancestors[ancestor] + self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor] def add_unfuseable_client(self, client: Apply): # If a client is unfuseable, so are all its clients, but we don't need to keep track of those # Any downstream client will also depend on this unfuseable client and will be rejected when visited - self.unfuseable_clients.add(client) + self.unfuseable_clients_bitset |= self.nodes_bitflags[client] def get_inputs_and_outputs(self): if self.inputs_and_outputs is not None: @@ -752,37 +757,40 @@ def get_inputs_and_outputs(self): return subgraph_inputs, subgraph_outputs class SortedSubgraphCollection: - __slots__ = ("subgraphs", "nodes") + __slots__ = ("subgraphs", "nodes_bitset") def __init__(self): self.subgraphs: list[ tuple[int, tuple[tuple[Variable], tuple[Variable]]] ] = [] - self.nodes = {} + self.nodes_bitset = 0 - def __contains__(self, node: Apply): - return node in self.nodes + def __contains__(self, node_bitflag: int): + return bool(node_bitflag & self.nodes_bitset) def insert_subgraph(self, subgraph: ConvexSubgraph): # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end # But in some cases they can, so we need to insert at the right position. - subgraph_unfuseable_ancestors = subgraph.unfuseable_ancestors - if subgraph_unfuseable_ancestors.isdisjoint(self.nodes): + subgraph_unfuseable_ancestors_bitset = ( + subgraph.unfuseable_ancestors_bitset + ) + if not (subgraph_unfuseable_ancestors_bitset & self.nodes_bitset): self.subgraphs.append(subgraph) else: # Iterate from the end, removing the bitsets of each previous subgraphs until our current subgraph # no longer depends on what's left. This tells us where to insert the current subgraph. - remaining_nodes = set(self.nodes) + remaining_nodes_bitset = self.nodes_bitset for index, other_subgraph in enumerate( reversed(self.subgraphs) ): - remaining_nodes.difference_update(other_subgraph.nodes) - if subgraph_unfuseable_ancestors.isdisjoint( - remaining_nodes + remaining_nodes_bitset &= ~other_subgraph.nodes_bitset + if not ( + subgraph_unfuseable_ancestors_bitset + & remaining_nodes_bitset ): break self.subgraphs.insert(-(index + 1), subgraph) - self.nodes |= subgraph.nodes + self.nodes_bitset |= subgraph.nodes_bitset def __iter__(self): yield from self.subgraphs @@ -792,32 +800,54 @@ def __iter__(self): if not fuseable_clients: return None - toposort_idx = { - node: idx for idx, node in enumerate(fg.toposort(), start=1) - } - node_ancestors = {None: frozenset()} - for node in toposort_idx: - node_ancestors[node] = frozenset.union( - *(node_ancestors[inp.owner] for inp in node.inputs), {node} + # Create a bitset of ancestors for each node. + # Each node is represented by a bit flag of it's position in the toposort + # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100} + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` + fg_clients = fgraph.clients + nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None` + ancestors_bitset = {None: 0} + for node, node_bitflag in nodes_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit + ancestors_bitset[node] = reduce( + or_, + (ancestors_bitset[inp.owner] for inp in node.inputs), + node_bitflag, ) + # handle root and leaf nodes gracefully + # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None` + nodes_bitflags[None] = 0 + # Nothing ever depends on output nodes, so just use a new bit for all + out_bitflag = 1 << len(nodes_bitflags) + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + nodes_bitflags[client] = out_bitflag - fg_clients = fgraph.clients sorted_subgraphs = SortedSubgraphCollection() # Start exploring from candidate sink nodes (backwards) # These are Elemwise nodes with a C-implementation, that are not part of another subgraph # And have no other fuseable clients (i.e., are sinks) - for starting_node, starting_idx in reversed(toposort_idx.items()): + for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): if ( - starting_node in sorted_subgraphs + starting_bitflag in sorted_subgraphs or not fuseable_clients.is_sink_node(starting_node) ): continue - subgraph = ConvexSubgraph(node_ancestors) + subgraph = ConvexSubgraph(nodes_bitflags, ancestors_bitset) fuseable_nodes_queue = SortedFuseableNodesQueue() - fuseable_nodes_queue.push(starting_node, starting_idx, is_ancestor=True) + fuseable_nodes_queue.push( + starting_node, starting_bitflag, is_ancestor=True + ) while fuseable_nodes_queue: node, is_ancestor = fuseable_nodes_queue.pop() @@ -841,7 +871,7 @@ def __iter__(self): if node in fuseable_clients[ancestor_node]: fuseable_nodes_queue.push( ancestor_node, - toposort_idx[ancestor_node], + nodes_bitflags[ancestor_node], is_ancestor=True, ) else: @@ -854,7 +884,7 @@ def __iter__(self): if client_node in next_fuseable_clients: fuseable_nodes_queue.push( client_node, - toposort_idx[client_node], + nodes_bitflags[client_node], is_ancestor=False, ) else: