 from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
+from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
 from pytensor.graph.utils import AssocList, InconsistencyError
 from pytensor.misc.ordered_set import OrderedSet
 from pytensor.utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
     The input and output patterns have the following syntax:
 
         input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
+        input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
         input_pattern ::= dict(pattern = <input_pattern>,
                                constraint = <constraint>)
         sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
         output_pattern ::= string
         output_pattern ::= int
         output_pattern ::= float
+        output_pattern ::= callable
 
     Each string in the input pattern is a variable that will be set to
     whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
     Examples
     --------
 
-    PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
-    PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
-    PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
-    PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
-    PatternNodeRewriter((boggle, {'pattern': 'x',
-                                  'constraint': lambda expr: expr.type == scrabble}),
-                        (scrabble, 'x'))
+    .. code-block:: python
 
+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.tensor import add, mul, sub, pow, square
+
+        PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
+        PatternNodeRewriter((mul, "x", "x"), (square, "x"))
+        PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
+        PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
+        PatternNodeRewriter(
+            (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
+            (mul, "y", "x"),
+        )
+
+    You can use ``OpPattern`` to match any Op of a given type, optionally constraining
+    some of its parameters. You can also specify a callable as the output pattern; it
+    will be called with ``(fgraph, node, subs_dict)`` as arguments and must return the
+    replacement variable (or ``None``/``False`` to reject the rewrite).
+
+    .. code-block:: python
+
+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.graph.rewriting.unify import OpPattern
+        from pytensor.tensor.basic import Join
+        from pytensor.tensor.elemwise import CAReduce, Elemwise
+
+
+        def output_fn(fgraph, node, s):
+            reduce_op = node.op
+            reduced_a = reduce_op(s["a"])
+            reduced_b = reduce_op(s["b"])
+            return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
+
+
+        PatternNodeRewriter(
+            (
+                OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
+                (Join(), "join_axis", "a", "b"),
+            ),
+            output_fn,
+        )
+
+    If you want to match a string parameter, you must wrap it in ``LiteralString``
+    so that it is not interpreted as a unification variable. In the example below
+    the output callable is only a placeholder; the point is the nested ``OpPattern``
+    with a literal ``assume_a`` parameter.
+
+    .. code-block:: python
+
+        from pytensor.graph.rewriting.basic import PatternNodeRewriter
+        from pytensor.graph.rewriting.unify import OpPattern, LiteralString
+        from pytensor.tensor.blockwise import Blockwise
+        from pytensor.tensor.slinalg import Solve
+
+
+        def rewrite_general_solve(fgraph, node, s):
+            # Placeholder: build and return the replacement variable here,
+            # or return None to leave the node unchanged.
+            ...
+
+
+        PatternNodeRewriter(
+            (
+                OpPattern(
+                    Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
+                ),
+                "A",
+                "b",
+            ),
+            rewrite_general_solve,
+        )
     """
 
     def __init__(
         self,
-        in_pattern,
-        out_pattern,
+        in_pattern: tuple | dict,
+        out_pattern: tuple | Callable,
         allow_multiple_clients: bool = False,
         name: str | None = None,
         tracks=(),
@@ -1378,7 +1433,7 @@ def __init__(
         in_pattern
             The input pattern that we want to replace.
         out_pattern
-            The replacement pattern.
+            The replacement pattern, or a callable that takes ``(fgraph, node, subs_dict)``
+            as inputs and returns the replacement variable.
         allow_multiple_clients
             If ``False``, the pattern matching will fail if one of the subpatterns has
             more than one client.
@@ -1407,26 +1462,35 @@ def __init__(
         self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
         self.values_eq_approx = values_eq_approx
         self.allow_cast = allow_cast
-        if isinstance(in_pattern, list | tuple):
-            self.op = self.in_pattern[0]
-        elif isinstance(in_pattern, dict):
-            self.op = self.in_pattern["pattern"][0]
-        else:
-            raise TypeError(
-                "The pattern to search for must start with a specific Op instance."
-            )
         self.allow_multiple_clients = allow_multiple_clients
         if name:
             self.__name__ = name
-        self._tracks = tracks
         self.get_nodes = get_nodes
         if tracks != ():
-            assert get_nodes
+            if not get_nodes:
+                raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
+            self._tracks = tracks
+        else:
+            if isinstance(in_pattern, list | tuple):
+                op = self.in_pattern[0]
+            elif isinstance(in_pattern, dict):
+                op = self.in_pattern["pattern"][0]
+            else:
+                raise TypeError(
+                    "The pattern to search for must start with a specific Op instance."
+                )
+            if isinstance(op, Op):
+                self._tracks = [op]
+            elif isinstance(op, OpPattern):
+                self._tracks = [op.op_type]
+            else:
+                raise ValueError(
+                    "The pattern to search for must start with a specific Op instance or an OpPattern. "
+                    f"Got {op}, with type {type(op)}."
+                )
 
     def tracks(self):
-        if self._tracks != ():
-            return self._tracks
-        return [self.op]
+        return self._tracks
 
     def transform(self, fgraph, node, get_nodes=True):
         """Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1511,39 @@ def transform(self, fgraph, node, get_nodes=True):
             # PatternNodeRewriter doesn't support replacing multi-output nodes
             return False
 
-        s = unify(self.in_pattern, node.out)
+        s = unify(self.in_pattern, node.out, {})
 
         if s is False:
             return False
 
-        ret = reify(self.out_pattern, s)
-
-        if isinstance(ret, ExpressionTuple):
-            ret = ret.evaled_obj
-
-        if self.values_eq_approx:
-            ret.tag.values_eq_approx = self.values_eq_approx
-
         if not self.allow_multiple_clients:
-            input_vars = list(s.values())
+            input_vars = set(s.values())
+            clients = fgraph.clients
             if any(
-                len(fgraph.clients[v]) > 1
+                len(clients[v]) > 1
                 for v in vars_between(input_vars, node.inputs)
                 if v not in input_vars
             ):
                 return False
 
+        if callable(self.out_pattern):
+            # `token` is the variable name used in the original pattern, so the
+            # callable receives a plain {name: matched_variable} dict.
+            ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
+            if ret is None or ret is False:
+                # The output function is still allowed to reject the rewrite
+                return False
+            if not isinstance(ret, Variable):
+                raise ValueError(
+                    "The output of the PatternNodeRewriter callable must be a Variable, "
+                    f"got {ret} of type {type(ret)}."
+                )
+        else:
+            ret = reify(self.out_pattern, s)
+            if isinstance(ret, ExpressionTuple):
+                ret = ret.evaled_obj
+
+            if self.values_eq_approx:
+                ret.tag.values_eq_approx = self.values_eq_approx
+
         [old_out] = node.outputs
         if not old_out.type.is_super(ret.type):
             from pytensor.tensor.type import TensorType
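
A minimal usage sketch (not part of this diff) of how a `PatternNodeRewriter` like the ones documented above can be applied to a graph, assuming only existing PyTensor helpers (`in2out`, `FunctionGraph`); the pattern mirrors the `(sub, (add, "x", "y"), "y") -> "x"` docstring example:

    import pytensor.tensor as pt
    from pytensor.graph.fg import FunctionGraph
    from pytensor.graph.rewriting.basic import PatternNodeRewriter, in2out

    # (x + y) - y  ->  x
    sub_add_cancel = PatternNodeRewriter(
        (pt.sub, (pt.add, "x", "y"), "y"),
        "x",
        name="sub_add_cancel",
    )

    x = pt.scalar("x")
    y = pt.scalar("y")
    fg = FunctionGraph([x, y], [(x + y) - y])

    # in2out wraps the node rewriter into a graph rewriter and applies it in place
    in2out(sub_add_cancel).rewrite(fg)
    print(fg.outputs)  # the (x + y) - y output has been replaced by x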