diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 7611a380bd..54ae1d90c0 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -827,71 +827,6 @@ def validate(self, fgraph): raise InconsistencyError("Trying to reintroduce a removed node") -class NodeFinder(Bookkeeper): - def __init__(self): - self.fgraph = None - self.d = {} - - def on_attach(self, fgraph): - if hasattr(fgraph, "get_nodes"): - raise AlreadyThere("NodeFinder is already present") - - if self.fgraph is not None and self.fgraph != fgraph: - raise Exception("A NodeFinder instance can only serve one FunctionGraph.") - - self.fgraph = fgraph - fgraph.get_nodes = partial(self.query, fgraph) - Bookkeeper.on_attach(self, fgraph) - - def clone(self): - return type(self)() - - def on_detach(self, fgraph): - """ - Should remove any dynamically added functionality - that it installed into the function_graph - """ - if self.fgraph is not fgraph: - raise Exception( - "This NodeFinder instance was not attached to the provided fgraph." 
- ) - self.fgraph = None - del fgraph.get_nodes - Bookkeeper.on_detach(self, fgraph) - - def on_import(self, fgraph, node, reason): - try: - self.d.setdefault(node.op, []).append(node) - except TypeError: # node.op is unhashable - return - except Exception as e: - print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201 - try: - print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201 - except Exception: - print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201 - raise e - - def on_prune(self, fgraph, node, reason): - try: - nodes = self.d[node.op] - except TypeError: # node.op is unhashable - return - nodes.remove(node) - if not nodes: - del self.d[node.op] - - def query(self, fgraph, op): - try: - all = self.d.get(op, []) - except TypeError: - raise TypeError( - f"{op} in unhashable and cannot be queried by the optimizer" - ) - all = list(all) - return all - - class PrintListener(Feature): def __init__(self, active=True): self.active = active diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 5b45fa40f4..12bfb672c3 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -11,12 +11,10 @@ import warnings from collections import Counter, UserList, defaultdict, deque from collections.abc import Callable, Iterable, Sequence -from collections.abc import Iterable as IterableType -from functools import _compose_mro, partial, reduce # type: ignore +from functools import _compose_mro, partial # type: ignore from itertools import chain -from typing import TYPE_CHECKING, Literal +from typing import Literal -import pytensor from pytensor.configdefaults import config from pytensor.graph import destroyhandler as dh from pytensor.graph.basic import ( @@ -28,18 +26,15 @@ io_toposort, vars_between, ) -from pytensor.graph.features import AlreadyThere, Feature, NodeFinder +from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import 
FunctionGraph, Output from pytensor.graph.op import Op +from pytensor.graph.rewriting.unify import OpInstance, Var, convert_strs_to_vars from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet from pytensor.utils import flatten -if TYPE_CHECKING: - from pytensor.graph.rewriting.unify import Var - - _logger = logging.getLogger("pytensor.graph.rewriting.basic") RemoveKeyType = Literal["remove"] @@ -942,129 +937,6 @@ def recursive_merge(var): return [recursive_merge(v) for v in variables] -class MetaNodeRewriter(NodeRewriter): - r""" - Base class for meta-rewriters that try a set of `NodeRewriter`\s - to replace a node and choose the one that executes the fastest. - - If the error `MetaNodeRewriterSkip` is raised during - compilation, we will skip that function compilation and not print - the error. - - """ - - def __init__(self): - self.verbose = config.metaopt__verbose - self.track_dict = defaultdict(list) - self.tag_dict = defaultdict(list) - self._tracks = [] - self.rewriters = [] - - def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]): - self.rewriters.append(rewriter) - - tracks = rewriter.tracks() - if tracks: - self._tracks.extend(tracks) - for c in tracks: - self.track_dict[c].append(rewriter) - - for tag in tag_list: - self.tag_dict[tag].append(rewriter) - - def tracks(self): - return self._tracks - - def transform(self, fgraph, node, *args, **kwargs): - # safety check: depending on registration, tracks may have been ignored - if self._tracks is not None: - if not isinstance(node.op, tuple(self._tracks)): - return - # first, we need to provide dummy values for all inputs - # to the node that are not shared variables anyway - givens = {} - missing = set() - for input in node.inputs: - if isinstance(input, pytensor.compile.SharedVariable): - pass - elif hasattr(input.tag, "test_value"): - givens[input] = pytensor.shared( - input.type.filter(input.tag.test_value), - input.name, - 
shape=input.broadcastable, - borrow=True, - ) - else: - missing.add(input) - if missing: - givens.update(self.provide_inputs(node, missing)) - missing.difference_update(givens.keys()) - # ensure we have data for all input variables that need it - if missing: - if self.verbose > 0: - print( # noqa: T201 - f"{self.__class__.__name__} cannot meta-rewrite {node}, " - f"{len(missing)} of {int(node.nin)} input shapes unknown" - ) - return - # now we can apply the different rewrites in turn, - # compile the resulting subgraphs and time their execution - if self.verbose > 1: - print( # noqa: T201 - f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):" - ) - timings = [] - for node_rewriter in self.get_rewrites(node): - outputs = node_rewriter.transform(fgraph, node, *args, **kwargs) - if outputs: - try: - fn = pytensor.function( - [], outputs, givens=givens, on_unused_input="ignore" - ) - fn.trust_input = True - timing = min(self.time_call(fn) for _ in range(2)) - except MetaNodeRewriterSkip: - continue - except Exception as e: - if self.verbose > 0: - print(f"* {node_rewriter}: exception", e) # noqa: T201 - continue - else: - if self.verbose > 1: - print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201 - timings.append((timing, outputs, node_rewriter)) - else: - if self.verbose > 0: - print(f"* {node_rewriter}: not applicable") # noqa: T201 - # finally, we choose the fastest one - if timings: - timings.sort() - if self.verbose > 1: - print(f"= {timings[0][2]}") # noqa: T201 - return timings[0][1] - return - - def provide_inputs(self, node, inputs): - """Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values. - - The `node` argument can be inspected to infer required input shapes. - - """ - raise NotImplementedError() - - def get_rewrites(self, node): - """Return the rewrites that apply to `node`. - - This uses ``self.track_dict[type(node.op)]`` by default. 
- """ - return self.track_dict[type(node.op)] - - def time_call(self, fn): - start = time.perf_counter() - fn() - return time.perf_counter() - start - - class FromFunctionNodeRewriter(NodeRewriter): """A `NodeRewriter` constructed from a function.""" @@ -1214,9 +1086,6 @@ class SequentialNodeRewriter(NodeRewriter): reentrant : bool Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to determine if they should ignore new nodes. - retains_inputs : bool - States whether or not the inputs of a transformed node are transferred - to the outputs. """ def __init__( @@ -1247,9 +1116,6 @@ def __init__( self.reentrant = any( getattr(rewrite, "reentrant", True) for rewrite in rewriters ) - self.retains_inputs = all( - getattr(rewrite, "retains_inputs", False) for rewrite in rewriters - ) self.apply_all_rewrites = apply_all_rewrites @@ -1425,17 +1291,12 @@ class SubstitutionNodeRewriter(NodeRewriter): # an SubstitutionNodeRewriter does not apply to the nodes it produces reentrant = False - # all the inputs of the original node are transferred to the outputs - retains_inputs = True def __init__(self, op1, op2, transfer_tags=True): self.op1 = op1 self.op2 = op2 self.transfer_tags = transfer_tags - def op_key(self): - return self.op1 - def tracks(self): return [self.op1] @@ -1453,45 +1314,13 @@ def __str__(self): return f"{self.op1} -> {self.op2}" -class RemovalNodeRewriter(NodeRewriter): - """ - Removes all applications of an `Op` by transferring each of its - outputs to the corresponding input. 
- - """ - - reentrant = False # no nodes are added at all - - def __init__(self, op): - self.op = op - - def op_key(self): - return self.op - - def tracks(self): - return [self.op] - - def transform(self, fgraph, node): - if node.op != self.op: - return False - return node.inputs - - def __str__(self): - return f"{self.op}(x) -> x" - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print( - f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}", - file=stream, - ) - - class PatternNodeRewriter(NodeRewriter): """Replace all occurrences of an input pattern with an output pattern. The input and output patterns have the following syntax: input_pattern ::= (op, , , ...) + input_pattern ::= (OpInstance(type(op), {: , ...}), , , ...) input_pattern ::= dict(pattern = , constraint = ) sub_pattern ::= input_pattern @@ -1505,6 +1334,7 @@ class PatternNodeRewriter(NodeRewriter): output_pattern ::= string output_pattern ::= int output_pattern ::= float + output_pattern ::= callable Each string in the input pattern is a variable that will be set to whatever expression is found in its place. If the same string is @@ -1530,22 +1360,74 @@ class PatternNodeRewriter(NodeRewriter): Examples -------- - PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x')) - PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x')) - PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x') - PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x')) - PatternNodeRewriter((boggle, {'pattern': 'x', - 'constraint': lambda expr: expr.type == scrabble}), - (scrabble, 'x')) + .. 
code-block:: python + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.tensor import add, mul, sub, pow, square + + PatternNodeRewriter((add, "x", "y"), (add, "y", "x")) + PatternNodeRewriter((mul, "x", "x"), (square, "x")) + PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x") + PatternNodeRewriter((pow, "x", 2.0), (square, "x")) + PatternNodeRewriter( + (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"), + (mul, "y", "x"), + ) + + You can use OpInstance to match a subtype of an Op, with some parameter constraints + You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments. + + + .. code-block:: python + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.graph.rewriting.unify import OpInstance + from pytensor.tensor.basic import Join + from pytensor.tensor.elemwise import CAReduce, Elemwise + + + def output_fn(fgraph, node, s): + reduce_op = node.op + reduced_a = reduce_op(s["a"]) + reduced_b = reduce_op(s["b"]) + return Elemwise(s["scalar_op"])(reduced_a, reduced_b) + + + PatternNodeRewriter( + ( + OpInstance(CAReduce, scalar_op="scalar_op", axis=None), + (Join(), "join_axis", "a", "b"), + ), + output_fn, + ) + + + If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable. + + .. 
code-block:: python + + from pytensor.graph.rewriting.basic import PatternNodeRewriter + from pytensor.graph.rewriting.unify import OpInstance, LiteralString + from pytensor.tensor.blockwise import Blockwise + from pytensor.tensor.slinalg import Solve + + PatternNodeRewriter( + ( + OpInstance( + Blockwise, core_op=OpInstance(Solve, assume_a=LiteralString("gen")) + ), + "A", + "b", + ) + ) """ def __init__( self, - in_pattern, - out_pattern, + in_pattern: tuple, + out_pattern: tuple | Callable, allow_multiple_clients: bool = False, - skip_identities_fn=None, name: str | None = None, tracks=(), get_nodes=None, @@ -1559,12 +1441,10 @@ def __init__( in_pattern The input pattern that we want to replace. out_pattern - The replacement pattern. + The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs allow_multiple_clients If ``False``, the pattern matching will fail if one of the subpatterns has more than one client. - skip_identities_fn - TODO name Set the name of this rewriter. tracks @@ -1574,49 +1454,51 @@ def __init__( function that takes the tracked node and returns a list of nodes on which we will try this rewrite. values_eq_approx - TODO + If specified, this value will be assigned to the ``values_eq_approx`` + tag of the output variable. This is used by DebugMode to determine if rewrites are correct. allow_cast Automatically cast the output of the rewrite whenever new and old types differ Notes ----- `tracks` and `get_nodes` can be used to make this rewrite track a less - frequent `Op`, which will prevent the rewrite from being tried as - often. + frequent `Op`, which will prevent the rewrite from being tried as often. 
""" - from pytensor.graph.rewriting.unify import convert_strs_to_vars - var_map: dict[str, Var] = {} self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) self.values_eq_approx = values_eq_approx self.allow_cast = allow_cast - if isinstance(in_pattern, list | tuple): - self.op = self.in_pattern[0] - elif isinstance(in_pattern, dict): - self.op = self.in_pattern["pattern"][0] - else: - raise TypeError( - "The pattern to search for must start with a specific Op instance." - ) - self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n" self.allow_multiple_clients = allow_multiple_clients - self.skip_identities_fn = skip_identities_fn if name: self.__name__ = name - self._tracks = tracks self.get_nodes = get_nodes if tracks != (): - assert get_nodes - - def op_key(self): - return self.op + if not get_nodes: + raise ValueError("Custom `tracks` requires `get_nodes` to be provided.") + self._tracks = tracks + else: + if isinstance(in_pattern, list | tuple): + op = self.in_pattern[0] + elif isinstance(in_pattern, dict): + op = self.in_pattern["pattern"][0] + else: + raise TypeError( + "The pattern to search for must start with a specific Op instance." + ) + if isinstance(op, Op): + self._tracks = [op] + elif isinstance(op, OpInstance): + self._tracks = [op.op_type] + else: + raise ValueError( + f"The pattern to search for must start with a specific Op instance or an OpInstance class. " + f"Got {op}, with type {type(op)}." + ) def tracks(self): - if self._tracks != (): - return self._tracks - return [self.op] + return self._tracks def transform(self, fgraph, node, get_nodes=True): """Check if the graph from node corresponds to ``in_pattern``. 
@@ -1633,42 +1515,52 @@ def transform(self, fgraph, node, get_nodes=True): if ret is not False and ret is not None: return dict(zip(real_node.outputs, ret, strict=True)) - if node.op != self.op: - return False - if len(node.outputs) != 1: # PatternNodeRewriter doesn't support replacing multi-output nodes return False - s = unify(self.in_pattern, node.out) + s = unify(self.in_pattern, node.out, {}) if s is False: return False - ret = reify(self.out_pattern, s) - - if isinstance(ret, ExpressionTuple): - ret = ret.evaled_obj - - if self.values_eq_approx: - ret.tag.values_eq_approx = self.values_eq_approx - if not self.allow_multiple_clients: - input_vars = list(s.values()) + input_vars = set(s.values()) + clients = fgraph.clients if any( - len(fgraph.clients[v]) > 1 + len(clients[v]) > 1 for v in vars_between(input_vars, node.inputs) if v not in input_vars ): return False + if callable(self.out_pattern): + # token is the variable name used in the original pattern + ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()}) + if ret is None or ret is False: + # The output function is still allowed to reject the rewrite + return False + if not isinstance(ret, Variable): + raise ValueError( + f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}." 
+ ) + else: + ret = reify(self.out_pattern, s) + if isinstance(ret, ExpressionTuple): + ret = ret.evaled_obj + + if self.values_eq_approx: + ret.tag.values_eq_approx = self.values_eq_approx + [old_out] = node.outputs if not old_out.type.is_super(ret.type): + from pytensor.tensor.type import TensorType + # Type doesn't match if not ( self.allow_cast - and isinstance(old_out.type, pytensor.tensor.TensorType) - and isinstance(ret.type, pytensor.tensor.TensorType) + and isinstance(old_out.type, TensorType) + and isinstance(ret.type, TensorType) ): return False @@ -2136,7 +2028,7 @@ def walking_rewriter( else: (node_rewriters,) = node_rewriters if not name: - name = node_rewriters.__name__ + name = getattr(node_rewriters, "__name__", None) ret = WalkingGraphRewriter( node_rewriters, order=order, @@ -2152,52 +2044,6 @@ def walking_rewriter( out2in = partial(walking_rewriter, "out_to_in") -class OpKeyGraphRewriter(NodeProcessingGraphRewriter): - r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s. - - The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either - as a list of `Op`\s or a single `Op`), and discovered within a - `FunctionGraph` using the `NodeFinder` `Feature`. - - This is similar to the `Op`-based tracking feature used by other rewriters. 
- - """ - - def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None): - if not hasattr(node_rewriter, "op_key"): - raise TypeError(f"{node_rewriter} must have an `op_key` method.") - super().__init__(node_rewriter, ignore_newtrees, failure_callback) - - def apply(self, fgraph): - op = self.node_rewriter.op_key() - if isinstance(op, list | tuple): - q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) - else: - q = list(fgraph.get_nodes(op)) - - def importer(node): - if node is not current_node: - if node.op == op: - q.append(node) - - u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) - ) - try: - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - current_node = node - self.process_node(fgraph, node) - finally: - self.detach_updater(fgraph, u) - - def add_requirements(self, fgraph): - super().add_requirements(fgraph) - fgraph.attach_feature(NodeFinder()) - - class ChangeTracker(Feature): def __init__(self): self.changed = False @@ -2785,38 +2631,6 @@ def merge(rewriters, attr, idx): ) -def _check_chain(r, chain): - """ - WRITEME - - """ - chain = list(reversed(chain)) - while chain: - elem = chain.pop() - if elem is None: - if r.owner is not None: - return False - elif r.owner is None: - return False - elif isinstance(elem, Op): - if r.owner.op != elem: - return False - else: - try: - if issubclass(elem, Op) and not isinstance(r.owner.op, elem): - return False - except TypeError: - return False - if chain: - r = r.owner.inputs[chain.pop()] - # print 'check_chain', _check_chain.n_calls - # _check_chain.n_calls += 1 - - # The return value will be used as a Boolean, but some Variables cannot - # be used as Booleans (the results of comparisons, for instance) - return r is not None - - def pre_greedy_node_rewriter( fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable ) -> Variable: @@ -2998,10 +2812,10 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", 
bug_print="raise"): otherwise. """ - if isinstance(f_or_fgraph, pytensor.compile.function.types.Function): - fgraph = f_or_fgraph.maker.fgraph - elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph): + if isinstance(f_or_fgraph, FunctionGraph): fgraph = f_or_fgraph + elif hasattr(f_or_fgraph, "fgraph"): + fgraph = f_or_fgraph.fgraph else: raise ValueError("The type of f_or_fgraph is not supported") diff --git a/pytensor/graph/rewriting/trie_unification.py b/pytensor/graph/rewriting/trie_unification.py new file mode 100644 index 0000000000..358b226eea --- /dev/null +++ b/pytensor/graph/rewriting/trie_unification.py @@ -0,0 +1,403 @@ +from dataclasses import dataclass, field +from typing import Any, Union + +from pytensor.graph import Op +from pytensor.graph.basic import Variable +from pytensor.graph.rewriting.unify import OpInstance + + +@dataclass(frozen=True, eq=False) +class MatchPattern: + name: str | None + pattern: tuple + _var_to_standard: dict[str, int] = field(default_factory=dict) + _standard_to_var: dict[int, str] = field(default_factory=dict) + + def __repr__(self): + if self.name is not None: + return self.name + return str(self.pattern) + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + def get_standard_var_name(self, var: str) -> str: + """Get a canonicalized variable name for a pattern variable. + This increases sharing of paths in the trie as only uniqueness of names matters for the purpose of unification. 
+ """ + standard_var = self._var_to_standard.get(var, None) + if standard_var is None: + standard_var = f"{len(self._var_to_standard)}" + self._var_to_standard[var] = standard_var + self._standard_to_var[standard_var] = var + return standard_var + + +@dataclass(frozen=True) +class Literal: + # Wrapper class to signal that a pattern is a literal value, not a pattern variable + pattern: Any + + +@dataclass(frozen=True) +class Asterisk: + # Wrapper class to signal that a pattern is a wildcard matching zero or more variables + variable: str + + +@dataclass(frozen=True) +class TrieNode: + # Class for Op level trie nodes + # Each node has edges for exact Op matches, Op type matches, variable matches, and + # edges for starting parametrized Op matches (which lead to ParameterTrieNode) + # Terminal patterns are stored at the nodes where patterns end + op_edges: dict[Op, "TrieNode"] = field(default_factory=dict) + op_type_edges: dict[type[Op], "TrieNode"] = field(default_factory=dict) + start_parameter_edges: dict[type[Op], "ParameterTrieNode"] = field( + default_factory=dict + ) + variable_edges: dict[str, "TrieNode"] = field(default_factory=dict) + asterisk_edges: dict[str, "TrieNode"] = field( + default_factory=dict + ) # New: asterisk variable edges + terminal_patterns: list[MatchPattern] = field(default_factory=list) + + +@dataclass(frozen=False) +class ParameterTrieNode: + # Class for Op parameter level trie nodes + # Each node has edges for matching Op parameters (key, pattern) pairs + # (where pattern can be a variable name, an Op type, a literal value, or a nested parametrized Op (OpType, {param: value, ...})) + + # A ParameterTrieNode may have multiple parameter edges to move to the next ParameterTrieNode + # A ParameterTrieNode may have an end_parameter_edge, to move back to the outer TrieNode/ ParameterTrieNode + # This allows different patterns to match a different number of parameters. 
+ # Parameters are arranged in alphabetical order to help sharing of common paths. + + # A ParameterTrieNode may also have a sub_op_parameter_edge, to start matching parameters of a nested parametrized Op + # A sub_op_parameter_edge always follows a parameter_edge for the same parameter key and op type. + + parameter_edges: dict[tuple[str, Any], "ParameterTrieNode"] = field( + default_factory=dict + ) + sub_op_parameter_edge: tuple[str, "ParameterTrieNode"] | None = field(default=None) + + # A ParameterTrieNode may end up followed by a ParameterTrieNode, if it was a nested parametrized op + # Or with a regular TrieNode, if it was the end of a parametrized op pattern + end_parameter_edge: Union["TrieNode", "ParameterTrieNode"] | None = field( + default=None + ) + + +@dataclass(frozen=False) +class Trie: + root_node: TrieNode = field(default_factory=TrieNode) + op_type_match_cache: dict = field(default_factory=dict) + + def add_pattern(self, pattern: MatchPattern | tuple): + """Expand Trie with new pattern""" + self.op_type_match_cache.clear() + + if not isinstance(pattern, MatchPattern): + pattern = MatchPattern(None, pattern) + + def get_keyed_edge(edges_dict, key, trie_class=TrieNode): + next_trie_node = edges_dict.get(key, None) + if next_trie_node is None: + edges_dict[key] = next_trie_node = trie_class() + return next_trie_node + + def recursive_insert_params(trie_node, parameters, nested=False): + assert isinstance(trie_node, ParameterTrieNode) + if not parameters: + # Base case: We consumed all the parameters. 
Add an end_parameter edge to signal we're done + if trie_node.end_parameter_edge is None: + trie_node.end_parameter_edge = ( + ParameterTrieNode() if nested else TrieNode() + ) + return trie_node.end_parameter_edge + + (item_key, item_pattern), *rest_key_pattern_pairs = parameters + + if isinstance(item_pattern, OpInstance): + # Nested parametrized op + sub_op_type, sub_parameters = ( + item_pattern.op_type, + item_pattern.parameters, + ) + # Start with a parameter edge for the op parameter + start_trie_node = get_keyed_edge( + trie_node.parameter_edges, + (item_key, sub_op_type), + trie_class=ParameterTrieNode, + ) + if item_pattern.parameters: + # Add a sub_op_parameter edge to start matching the nested Op parameters + # A trie node can only have one sub_op_parameter edge, since it's always preceded by a parameter edge + if start_trie_node.sub_op_parameter_edge is None: + start_trie_node.sub_op_parameter_edge = ( + item_key, + ParameterTrieNode(), + ) + (sub_op_key, sub_op_trie_node) = ( + start_trie_node.sub_op_parameter_edge + ) + assert sub_op_key == item_key + next_trie_node = recursive_insert_params( + sub_op_trie_node, sub_parameters, nested=True + ) + else: + # No parameters, so we can directly move to the next trie node + next_trie_node = start_trie_node + else: + # Simple parameter pattern: add a parameter edge + if isinstance(item_pattern, str): + # Pattern variable, replace with a unique variable name + item_pattern = pattern.get_standard_var_name(item_pattern) + # All edges (including variables) go through parameter_edges + # TODO: Consider splitting variable edges into a separate dict for faster matching + next_trie_node = get_keyed_edge( + trie_node.parameter_edges, + (item_key, item_pattern), + trie_class=ParameterTrieNode, + ) + # Recurse with the rest of the parameters + return recursive_insert_params( + next_trie_node, rest_key_pattern_pairs, nested=nested + ) + + def recursinve_insert(trie_node, sub_pattern): + if not sub_pattern: + # Base 
case: we've consumed the entire pattern + trie_node.terminal_patterns.append(pattern) + return + + head, *tail = sub_pattern + if isinstance(head, tuple): + # ((op, input1, input2, ...), ...) + head_head, *head_tail = head + return recursinve_insert(trie_node, (head_head, *head_tail, *tail)) + + if isinstance(head, OpInstance) and head.parameters: + op_type, parameters = head.op_type, head.parameters + # Start with an edge for the op type + next_trie_node = get_keyed_edge( + trie_node.start_parameter_edges, + op_type, + trie_class=ParameterTrieNode, + ) + # Recurse into the parameters, with parameter edges + next_trie_node = recursive_insert_params(next_trie_node, parameters) + else: + key = head + if isinstance(head, Op): + edge_type = trie_node.op_edges + elif isinstance(head, type) and issubclass(head, Op): + edge_type = trie_node.op_type_edges + elif isinstance(head, OpInstance): + # Empty ParametrizedOp, handle with a simple op_type edge + assert not head.parameters + key = head.op_type + edge_type = trie_node.op_type_edges + elif isinstance(head, str): + key = pattern.get_standard_var_name(head) + edge_type = trie_node.variable_edges + elif isinstance(head, Asterisk): + key = pattern.get_standard_var_name(head.variable) + edge_type = trie_node.asterisk_edges + else: + raise TypeError(f"Invalid head type {type(head)}: {head}") + next_trie_node = get_keyed_edge(edge_type, key) + + # Recurse with the tail of the pattern + recursinve_insert(next_trie_node, tail) + + recursinve_insert(self.root_node, pattern.pattern) + + def match(self, variable): + if not isinstance(variable, Variable): + return False + + def find_op_type_edge_matches(edges_dict, op: Op): + type_op = type(op) + cache_key = (id(edges_dict), type_op) + if cache_key in self.op_type_match_cache: + yield from self.op_type_match_cache[cache_key] + return + + self.op_type_match_cache[cache_key] = matches = [ + match + for base_cls in type_op.mro() + if (match := edges_dict.get(base_cls)) is not None + ] 
+ yield from matches + + def find_op_matches(trie_node: TrieNode, op: Op): + if (next_trie_node := trie_node.op_edges.get(op)) is not None: + yield next_trie_node + + yield from find_op_type_edge_matches(trie_node.op_type_edges, op) + + def recursive_match( + trie_node: TrieNode | ParameterTrieNode, + subject_pattern: tuple[Variable, tuple[Variable, ...]], + subs: dict[str, Any], + num_op_inputs: tuple, + ): + if isinstance(trie_node, TrieNode): + # Base case, terminal patterns are successfully matched + # whenever trie node is reached with no subject pattern left to unify + if not subject_pattern: + for terminal_pattern in trie_node.terminal_patterns: + # Convert the canonicalized variable names back to the original pattern variable names + d = terminal_pattern._standard_to_var + yield terminal_pattern, {d[k]: v for k, v in subs.items()} + + # Unify asterisk variables + # This must be the last pattern for the current op's inputs + for asterisk_var, next_trie_node in trie_node.asterisk_edges.items(): + remaining_n_inputs, tail_n_inputs = num_op_inputs + consumed_vars = subject_pattern[:remaining_n_inputs] + remaining_subject = subject_pattern[remaining_n_inputs:] + subs_copy = subs + + if asterisk_var in subs: + if subs[asterisk_var] != consumed_vars: + continue # mismatch + else: + subs_copy = subs.copy() + subs_copy[asterisk_var] = consumed_vars + yield from recursive_match( + next_trie_node, + remaining_subject, + subs_copy, + num_op_inputs=tail_n_inputs, + ) + + if not subject_pattern: + # Nothing left to match + return None + + head, *tail = subject_pattern + assert isinstance(head, Variable), (type(head), head) + + # Unify variable patterns + for ( + variable_pattern, + next_trie_node, + ) in trie_node.variable_edges.items(): + subs_copy = subs + if variable_pattern in subs: + if subs[variable_pattern] != head: + continue # mismatch + else: + subs_copy = subs.copy() + subs_copy[variable_pattern] = head + + remaining_n_inputs, tail_n_inputs = num_op_inputs + 
if remaining_n_inputs == 0: + # We've exhausted the inputs for the current op, this next variable belongs to the next input of the outer Op + remaining_n_inputs, tail_n_inputs = tail_n_inputs + assert ( + remaining_n_inputs > 0 + ), "Number of inputs to consume is smaller than expected. Perhaps missing an Asterisk pattern?" + yield from recursive_match( + next_trie_node, + tail, + subs_copy, + (remaining_n_inputs - 1, tail_n_inputs), + ) + + if head.owner is None: + # head is a root variable, can only be matched to wildcard patterns above + return False + head_op = head.owner.op + + # Match exact op or type op (including subclasses) + # We consume the head variable and extend the tail pattern with its inputs + for next_trie_node in find_op_matches(trie_node, head_op): + yield from recursive_match( + next_trie_node, + (*head.owner.inputs, *tail), + subs, + (len(head.owner.inputs), num_op_inputs), + ) + + # Match start of parametrized op pattern + for next_trie_node in find_op_type_edge_matches( + trie_node.start_parameter_edges, head_op + ): + # We place the Op variable at the head of the subject pattern + # And extend the tail pattern with the inputs of the head variable, just like a regular op match + yield from recursive_match( + next_trie_node, + (head_op, *head.owner.inputs, *tail), + subs, + (len(head.owner.inputs), num_op_inputs), + ) + + else: # ParameterTrieNode + head_op, *tail = subject_pattern + assert isinstance(head_op, Op), (type(head_op), head_op) + + # Exit parametrized op pattern matching + if (next_trie_node := trie_node.end_parameter_edge) is not None: + # We discard the head variable and keep working on the tail pattern + yield from recursive_match( + next_trie_node, tail, subs, num_op_inputs + ) + + # Match op parameters + for ( + op_param_key, + op_param_pattern, + ), next_trie_node in trie_node.parameter_edges.items(): + op_param_value = getattr(head_op, op_param_key) + subs_copy = subs + + # Match variable pattern + if 
isinstance(op_param_pattern, str): + if op_param_pattern in subs: + if subs[op_param_pattern] != op_param_value: + continue # mismatch + else: + subs_copy = subs.copy() + subs_copy[op_param_pattern] = op_param_value + # Match op type + elif isinstance(op_param_pattern, type) and issubclass( + op_param_pattern, Op + ): + if not isinstance(op_param_value, op_param_pattern): + continue # mismatch + # Match literal value + elif isinstance(op_param_pattern, Literal): + if op_param_value != op_param_pattern.pattern: + continue # mismatch + # Match exact value + elif op_param_value != op_param_pattern: + continue # mismatch + + # We arrive here if there was no mismatch + # For parameter edges, we continue to the next trie_node with the same pattern + # as we may still need to check other parameters from the same Op + # We'll eventually move to the tail pattern via an end_parameter edge + yield from recursive_match( + next_trie_node, subject_pattern, subs_copy, num_op_inputs + ) + + # Match nested op parametrizations + # This always follows an op parameter edge + if trie_node.sub_op_parameter_edge is not None: + (sub_op_param_key, next_trie_node) = trie_node.sub_op_parameter_edge + sub_op = getattr(head_op, sub_op_param_key) + # For sub_op parameter edges, we continue to the next trie_node with the sub_op as the head + yield from recursive_match( + next_trie_node, (sub_op, *subject_pattern), subs, num_op_inputs + ) + return None + + yield from recursive_match(self.root_node, (variable,), {}, ()) + return None diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index e9361d62c2..177cc1925d 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -10,8 +10,10 @@ """ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence +from dataclasses import dataclass from numbers import Number +from typing import Any import numpy as np from cons.core import ConsError, _car, _cdr @@ -254,6 
+256,103 @@ def _unify_ConstrainedVar_object(u, v, s): _unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) +@dataclass(frozen=True) +class LiteralString: + value: str + + +class OpInstance: + """Class that can be unified with Op instances of a given type and parameters. + + An op instance is unified as long as the parameters specified in the OpInstance can be unified as well. + Parameters that are not specified in the OpInstance are ignored during unification. + + This is needed because some Ops can be complex to parametrize fully, + and not all parameters are relevant for a given pattern. + + Examples + -------- + + .. testcode:: + + from unification import var, unify + from etuples import etuple + + import pytensor.tensor as pt + from pytensor.graph.rewriting.unify import OpInstance + from pytensor.tensor.blockwise import Blockwise + from pytensor.tensor.slinalg import Solve + + A = var("A") + b = var("b") + pattern = etuple( + OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a="gen")), A, b + ) + + A_pt = pt.tensor3("A") + b_pt = pt.tensor3("b") + out1 = pt.linalg.solve(A_pt, b_pt) + out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos") + + assert unify(pattern, out1) == {A: A_pt, b: b_pt} + assert unify(pattern, out2) is False + + assume_a = var("assume_a") + pattern = etuple( + OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a=assume_a)), + A, + b, + ) + assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"} + assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"} + + + """ + + def __init__( + self, + op_type: type[Op], + parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None, + **kwargs, + ): + if not (isinstance(op_type, type) and issubclass(op_type, Op)): + raise TypeError(f"Invalid op_type {op_type}. 
Expected type(Op)") + + if kwargs: + if parameters is not None: + raise ValueError( + "Cannot provide both parameters dict and keyword arguments" + ) + parameters = kwargs + if isinstance(parameters, dict): + parameters = tuple(sorted(parameters.items())) + elif isinstance(parameters, list | tuple): + parameters = tuple(sorted(parameters)) + elif parameters is None: + parameters = () + self.op_type = op_type + self.parameters = parameters + + def __str__(self): + return f"{self.op_type.__name__}({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})" + + +def _unify_parametrized_op(v: Op, u: OpInstance, s: Mapping): + if not isinstance(v, u.op_type): + yield False + return + for parameter_key, parameter_pattern in u.parameters: + parameter_value = getattr(v, parameter_key) + s = yield _unify(parameter_value, parameter_pattern, s) + if s is False: + yield False + return + yield s + + +_unify.add((Op, OpInstance, Mapping), _unify_parametrized_op) + + def convert_strs_to_vars( x: tuple | str | dict, var_map: dict[str, Var] | None = None ) -> ExpressionTuple | Var: @@ -266,11 +365,13 @@ def convert_strs_to_vars( if var_map is None: var_map = {} - def _convert(y): + def _convert(y, op_prop=False): if isinstance(y, str): v = var_map.get(y, var(y)) var_map[y] = v return v + if isinstance(y, LiteralString): + return y.value elif isinstance(y, dict): pattern = y["pattern"] if not isinstance(pattern, str): @@ -282,8 +383,14 @@ def _convert(y): var_map[pattern] = v return v elif isinstance(y, tuple): - return etuple(*(_convert(e) for e in y)) - elif isinstance(y, Number | np.ndarray): + return etuple(*(_convert(e, op_prop=op_prop) for e in y)) + elif isinstance(y, OpInstance): + return OpInstance( + y.op_type, + {k: _convert(v, op_prop=True) for k, v in y.parameters}, + ) + elif (not op_prop) and isinstance(y, Number | np.ndarray): + # If we are converting an Op property, we don't want to convert numbers to PyTensor constants from pytensor.tensor import 
as_tensor_variable return as_tensor_variable(y) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 77723917b0..49b92e4ce3 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -159,8 +159,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): # This is the list of the original dimensions that we keep self.shuffle = [x for x in new_order if x != "x"] self.transposition = self.shuffle + drop - # List of dimensions of the output that are broadcastable and were not - # in the original input + # List of dimensions of the output that are broadcastable and were not in the original input self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.drop = drop @@ -175,6 +174,12 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.is_right_expand_dims = self.is_expand_dims and new_order[ :input_ndim ] == list(range(input_ndim)) + self.is_matrix_transpose = False + if dims_are_shuffled and (not drop) and input_ndim >= 2: + # We consider a matrix transpose if we only flip the last two dims + # Regardless of whether there's an expand_dims or not + mt_pattern = [*range(input_ndim - 2), input_ndim - 1, input_ndim - 2] + self.is_matrix_transpose = new_order[len(augment) :] == mt_pattern def __setstate__(self, state): self.__dict__.update(state) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 83ee8c2c3b..e9c2c8e47e 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -34,7 +34,6 @@ from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, - RemovalNodeRewriter, Rewriter, copy_stack_trace, in2out, @@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node): return [alloc(inputs_inner[0], *dims_outer)] -register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") +@register_canonicalize 
+@node_rewriter(tracks=[tensor_copy]) +def remove_tensor_copy(fgraph, node): + return node.inputs @register_specialize diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 8367642c4c..44f5da0399 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -66,25 +66,11 @@ def is_matrix_transpose(x: TensorVariable) -> bool: """Check if a variable corresponds to a transpose of the last two axes""" node = x.owner - if ( - node + return ( + node is not None and isinstance(node.op, DimShuffle) - and not (node.op.drop or node.op.augment) - ): - [inp] = node.inputs - ndims = inp.type.ndim - if ndims < 2: - return False - transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) - - # Allow expand_dims on the left of the transpose - if (diff := len(transpose_order) - len(node.op.new_order)) > 0: - transpose_order = ( - *(["x"] * diff), - *transpose_order, - ) - return node.op.new_order == transpose_order - return False + and node.op.is_matrix_transpose + ) @register_canonicalize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index f8156067f9..df2355ca12 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10): return np.allclose(x, ref, rtol=rtol, atol=atol) -def _skip_mul_1(r): - if r.owner and r.owner.op == mul: - not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] - if len(not_is_1) == 1: - return not_is_1[0] - - def _is_1(expr): """ @@ -3190,7 +3183,6 @@ def _is_1(expr): (neg, (softplus, (neg, "x"))), allow_multiple_clients=True, values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, tracks=[sigmoid], get_nodes=get_clients_at_depth1, ) @@ -3199,7 +3191,6 @@ def _is_1(expr): (neg, (softplus, "x")), allow_multiple_clients=True, values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, tracks=[sigmoid], 
get_nodes=get_clients_at_depth2, ) diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index d0cb94f9fb..1522d6d8fe 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -8,18 +8,16 @@ from pytensor.graph.rewriting.basic import ( EquilibriumGraphRewriter, MergeOptimizer, - OpKeyGraphRewriter, OpToRewriterTracker, PatternNodeRewriter, SequentialNodeRewriter, - SubstitutionNodeRewriter, - WalkingGraphRewriter, in2out, logging, node_rewriter, pre_constant_merge, pre_greedy_node_rewriter, ) +from pytensor.graph.rewriting.unify import LiteralString, OpInstance from pytensor.raise_op import assert_op from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.rewriting.basic import constant_folding @@ -52,16 +50,12 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False): - return OpKeyGraphRewriter( + return in2out( PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients), ignore_newtrees=ign, ) -def WalkingPatternNodeRewriter(p1, p2, ign=True): - return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) - - class TestPatternNodeRewriter: def test_replace_output(self): # replacing the whole graph @@ -160,7 +154,7 @@ def test_ambiguous(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) g = FunctionGraph([x, y, z], [e]) - WalkingPatternNodeRewriter((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g) + OpKeyPatternNodeRewriter((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g) assert str(g) == "FunctionGraph(Op1(x))" def test_constant(self): @@ -202,7 +196,7 @@ def test_match_same_illegal(self): g = FunctionGraph([x, y, z], [e]) def constraint(r): - # Only replacing if the input is an instance of Op2 + # Only replacing if the inputs are not identical return r.owner.inputs[0] is not r.owner.inputs[1] 
OpKeyPatternNodeRewriter( @@ -287,25 +281,41 @@ def test_eq(self): str_g = str(g) assert str_g == "FunctionGraph(Op4(z, y))" + def test_op_instance(self): + a = MyVariable("a") + e1 = MyOp(name="MyOp(x=1)", x=1)(a) + e2 = MyOp(name="MyOp(x=2)", x=2)(a) + e_hello = MyOp(name="MyOp(x='hello')", x="hello")(a) + op_x3 = MyOp(name="MyOp(x=3)", x=3) + assert not equal_computations([e1], [op_x3(a)]) + assert not equal_computations([e2], [op_x3(a)]) + + rewriter = OpKeyPatternNodeRewriter( + (OpInstance(MyOp, x=1), "a"), + "a", + ) + g = FunctionGraph([a], [e1, e2, e1], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [a, e2, a]) + + rewriter = OpKeyPatternNodeRewriter( + (OpInstance(MyOp, x="x"), "a"), + lambda fgraph, node, subs: ( + MyOp(name="MyOp(x+=10)", x=subs["x"] + 10)(subs["a"]) + if subs["x"] < 10 + else False + ), + ) + g = FunctionGraph([a], [e1], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [MyOp(name="x=11", x=11)(a)]) -def KeyedSubstitutionNodeRewriter(op1, op2): - return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2)) - - -class TestSubstitutionNodeRewriter: - def test_straightforward(self): - x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") - e = op1(op1(op1(op1(op1(x))))) - g = FunctionGraph([x, y, z], [e]) - KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g) - assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" - - def test_straightforward_2(self): - x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") - e = op1(op2(x), op3(y), op4(z)) - g = FunctionGraph([x, y, z], [e]) - KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" + rewriter = OpKeyPatternNodeRewriter( + (OpInstance(MyOp, x=LiteralString("hello")), "a"), "a" + ) + g = FunctionGraph([a], [e1, e_hello], copy_inputs=False) + rewriter.rewrite(g) + assert equal_computations(g.outputs, [e1, a]) class NoInputOp(Op): diff --git 
a/tests/graph/rewriting/test_trie_unification.py b/tests/graph/rewriting/test_trie_unification.py new file mode 100644 index 0000000000..b43aef9047 --- /dev/null +++ b/tests/graph/rewriting/test_trie_unification.py @@ -0,0 +1,182 @@ +# TODO: Use simpler test ops +import pytensor.tensor as pt +from pytensor.graph.rewriting.trie_unification import ( + Asterisk, + Literal, + MatchPattern, + Trie, +) +from pytensor.graph.rewriting.unify import OpInstance +from pytensor.tensor.basic import Join +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle +from pytensor.tensor.slinalg import Cholesky, Solve + + +def test_solve_of_cholesky(): + def blockwise_of(core_op): + return OpInstance(Blockwise, {"core_op": core_op}) + + MatrixTransposePattern = OpInstance(DimShuffle, {"is_matrix_transpose": True}) + GenSolvePattern = blockwise_of(OpInstance(Solve, {"assume_a": Literal("gen")})) + CholeskyPattern = blockwise_of(OpInstance(Cholesky, {"lower": "lower"})) + + A = pt.matrix("A") + b = pt.vector("b") + out1 = pt.linalg.solve(pt.linalg.cholesky(A), b) + out2 = pt.linalg.solve(pt.linalg.cholesky(A).mT, b) + + P1 = MatchPattern( + "GenSolve(Cholesky(A), b)", + (GenSolvePattern, (CholeskyPattern, "A"), "b"), + ) + P2 = MatchPattern( + "GenSolve(Cholesky(A).mT, b)", + (GenSolvePattern, (MatrixTransposePattern, (CholeskyPattern, "A")), "b"), + ) + + trie = Trie() + trie.add_pattern(P1) + trie.add_pattern(P2) + + r1 = dict(trie.match(out1)) + assert list(r1) == [P1] + assert r1[P1] == {"A": A, "b": b, "lower": True} + + r2 = dict(trie.match(out2)) + assert list(r2) == [P2] + assert r2[P2] == {"A": A, "b": b, "lower": True} + + +def test_mixed_blockwise_types(): + blockwise_unary = MatchPattern("Blockwise(x)", (Blockwise, "x")) + blockwise_lower_cholesky = MatchPattern( + "Blockwise(Cholesky(lower=True))(x)", (Blockwise(Cholesky(lower=True)), "x") + ) + blockwise_cholesky = MatchPattern( + "Blockwise(Cholesky)(x)", 
(OpInstance(Blockwise, core_op=Cholesky), "x") + ) + alt_blockwise_lower_cholesky = MatchPattern( + "[Alt]Blockwise(Cholesky)(lower=True)(x)", + ( + OpInstance(Blockwise, {"core_op": OpInstance(Cholesky, [("lower", True)])}), + "x", + ), + ) + solve_gen_var = MatchPattern( + "Blockwise(Solve(assume_a=?gen))(A, b)", + ( + OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a="?gen")), + "A", + "b", + ), + ) + solve_gen_literal = MatchPattern( + "Blockwise(Solve(assume_a=gen))(A, b)", + ( + OpInstance(Blockwise, core_op=OpInstance(Solve, assume_a=Literal("gen"))), + "A", + "b", + ), + ) + + trie = Trie() + for pattern in ( + blockwise_unary, + blockwise_lower_cholesky, + blockwise_cholesky, + alt_blockwise_lower_cholesky, + solve_gen_var, + solve_gen_literal, + ): + trie.add_pattern(pattern) + + X = pt.matrix("X") + out = pt.linalg.cholesky(X) + res = dict(trie.match(out)) + assert set(res) == { + blockwise_unary, + blockwise_lower_cholesky, + blockwise_cholesky, + alt_blockwise_lower_cholesky, + } + for subs in res.values(): + assert subs == {"x": X} + + out = pt.linalg.cholesky(X, lower=False) + res = dict(trie.match(out)) + assert set(res) == { + blockwise_unary, + blockwise_cholesky, + } + for subs in res.values(): + assert subs == {"x": X} + + A, b = pt.matrix("A"), pt.vector("b") + out = pt.linalg.solve(A, b) + res = dict(trie.match(out)) + assert set(res) == {solve_gen_var, solve_gen_literal} + assert res[solve_gen_literal] == {"A": A, "b": b} + assert res[solve_gen_var] == {"?gen": "gen", "A": A, "b": b} + + +def test_asterisk(): + P1 = MatchPattern( + "Reduce(Join(*entries))", + (CAReduce, (Join, "axis", Asterisk("entries"))), + ) + P2 = MatchPattern( + "Pow(Reduce(Join(*entries)), y)", + (pt.pow, (CAReduce, (Join, "axis", Asterisk("entries"))), "y"), + ) + + trie = Trie() + trie.add_pattern(P1) + trie.add_pattern(P2) + + x = pt.vector("x") + y = pt.vector("y") + z = pt.vector("z") + zeroth_axis = pt.constant(0, dtype="int64") + sum_of_join = 
pt.sum(pt.join(zeroth_axis, x, y, z)) + res = dict(trie.match(sum_of_join)) + assert set(res) == {P1} + assert res[P1] == {"axis": zeroth_axis, "entries": [x, y, z]} + + exponent = pt.scalar("exponent", dtype="int64") + pow_of_sum = pt.pow(sum_of_join, exponent) + res = dict(trie.match(pow_of_sum)) + assert set(res) == {P2} + assert res[P2] == {"axis": zeroth_axis, "entries": [x, y, z], "y": exponent} + + +def test_repeated_vars(): + P = MatchPattern( + "Join(x, x)", + (Join, "axis", "x", "x", Asterisk("xs")), + ) + + trie = Trie() + trie.add_pattern(P) + + x, y = pt.vectors("xy") + zeroth_axis = pt.constant(0, dtype="int64") + + join_xx = pt.join(zeroth_axis, x, x) + res = dict(trie.match(join_xx)) + assert set(res) == {P} + assert res[P] == {"axis": zeroth_axis, "x": x, "xs": []} + + join_xy = pt.join(zeroth_axis, x, y) + res = dict(trie.match(join_xy)) + assert set(res) == set() + + join_xxy = pt.join(zeroth_axis, x, x, y) + res = dict(trie.match(join_xxy)) + assert set(res) == {P} + assert res[P] == {"axis": zeroth_axis, "x": x, "xs": [y]} + + join_xxyx = pt.join(zeroth_axis, x, x, y, x) + res = dict(trie.match(join_xxyx)) + assert set(res) == {P} + assert res[P] == {"axis": zeroth_axis, "x": x, "xs": [y, x]} diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index da430a1587..792ced9c56 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -11,7 +11,11 @@ import pytensor.tensor as pt from pytensor.graph.basic import Apply, Constant, equal_computations from pytensor.graph.op import Op -from pytensor.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars +from pytensor.graph.rewriting.unify import ( + ConstrainedVar, + OpInstance, + convert_strs_to_vars, +) from pytensor.tensor.type import TensorType from tests.graph.utils import MyType @@ -348,3 +352,25 @@ def constraint(x): res = convert_strs_to_vars((val,)) assert isinstance(res[0], Constant) assert 
np.array_equal(res[0].data, val) + + +def test_unify_OpInstance(): + x_pt = MyType()("x_pt") + y_pt = MyType()("y_pt") + out1 = CustomOp(a=1)(x_pt, y_pt) + out2 = CustomOp(a=2)(x_pt, y_pt) + + x = var("x") + y = var("y") + pattern = etuple(OpInstance(CustomOp), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt} + assert unify(pattern, out2) == {x: x_pt, y: y_pt} + + pattern = etuple(OpInstance(CustomOp, a=1), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt} + assert unify(pattern, out2) is False + + a = var("a") + pattern = etuple(OpInstance(CustomOp, a=a), x, y) + assert unify(pattern, out1) == {x: x_pt, y: y_pt, a: 1} + assert unify(pattern, out2) == {x: x_pt, y: y_pt, a: 2} diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 16a654da26..70333f369b 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -10,7 +10,6 @@ from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, - OpKeyGraphRewriter, PatternNodeRewriter, SubstitutionNodeRewriter, WalkingGraphRewriter, @@ -21,7 +20,7 @@ def OpKeyPatternNodeRewriter(p1, p2, ign=True): - return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) + return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) def TopoSubstitutionNodeRewriter( diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index d23caf52ee..474104269d 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -2,92 +2,12 @@ import pytensor.tensor as pt from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Apply, Variable, equal_computations -from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate +from pytensor.graph.basic import equal_computations +from pytensor.graph.features import Feature, FullHistory, ReplaceValidate from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op 
import Op -from pytensor.graph.type import Type from tests.graph.utils import MyVariable, op1 -class TestNodeFinder: - def test_straightforward(self): - class MyType(Type): - def __init__(self, name): - self.name = name - - def filter(self, *args, **kwargs): - raise NotImplementedError() - - def __str__(self): - return self.name - - def __repr__(self): - return self.name - - def __eq__(self, other): - return isinstance(other, MyType) - - class MyOp(Op): - __props__ = ("nin", "name") - - def __init__(self, nin, name): - self.nin = nin - self.name = name - - def make_node(self, *inputs): - def as_variable(x): - assert isinstance(x, Variable) - return x - - assert len(inputs) == self.nin - inputs = list(map(as_variable, inputs)) - for input in inputs: - if not isinstance(input.type, MyType): - raise Exception("Error 1") - outputs = [MyType(self.name + "_R")()] - return Apply(self, inputs, outputs) - - def __str__(self): - return self.name - - def perform(self, *args, **kwargs): - raise NotImplementedError() - - sigmoid = MyOp(1, "Sigmoid") - add = MyOp(2, "Add") - dot = MyOp(2, "Dot") - - def MyVariable(name): - return Variable(MyType(name), None, None) - - def inputs(): - x = MyVariable("x") - y = MyVariable("y") - z = MyVariable("z") - return x, y, z - - x, y, z = inputs() - e0 = dot(y, z) - e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) - g = FunctionGraph([x, y, z], [e], clone=False) - g.attach_feature(NodeFinder()) - - assert hasattr(g, "get_nodes") - for type, num in ((add, 3), (sigmoid, 3), (dot, 2)): - if len(list(g.get_nodes(type))) != num: - raise Exception(f"Expected: {num} times {type}") - new_e0 = add(y, z) - assert e0.owner in g.get_nodes(dot) - assert new_e0.owner not in g.get_nodes(add) - g.replace(e0, new_e0) - assert e0.owner not in g.get_nodes(dot) - assert new_e0.owner in g.get_nodes(add) - for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): - if len(list(g.get_nodes(type))) != num: - raise Exception(f"Expected: {num} times 
{type}") - - class TestReplaceValidate: def test_verbose(self, capsys): var1 = MyVariable("var1")