2929from pytensor .graph .features import AlreadyThere , Feature
3030from pytensor .graph .fg import FunctionGraph , Output
3131from pytensor .graph .op import Op
32- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32+ from pytensor .graph .rewriting .unify import OpInstance , Var , convert_strs_to_vars
3333from pytensor .graph .utils import AssocList , InconsistencyError
3434from pytensor .misc .ordered_set import OrderedSet
3535from pytensor .utils import flatten
@@ -1320,6 +1320,7 @@ class PatternNodeRewriter(NodeRewriter):
13201320 The input and output patterns have the following syntax:
13211321
13221322 input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1323+ input_pattern ::= (OpInstance(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
13231324 input_pattern ::= dict(pattern = <input_pattern>,
13241325 constraint = <constraint>)
13251326 sub_pattern ::= input_pattern
@@ -1333,6 +1334,7 @@ class PatternNodeRewriter(NodeRewriter):
13331334 output_pattern ::= string
13341335 output_pattern ::= int
13351336 output_pattern ::= float
1337+ output_pattern ::= callable
13361338
13371339 Each string in the input pattern is a variable that will be set to
13381340 whatever expression is found in its place. If the same string is
@@ -1358,20 +1360,73 @@ class PatternNodeRewriter(NodeRewriter):
13581360 Examples
13591361 --------
13601362
1361- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1362- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1363- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1364- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1365- PatternNodeRewriter((boggle, {'pattern': 'x',
1366- 'constraint': lambda expr: expr.type == scrabble}),
1367- (scrabble, 'x'))
1363+ .. code-block:: python
13681364
1365+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1366+ from pytensor.tensor import add, mul, sub, pow, square
1367+
1368+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1369+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1370+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1371+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1372+ PatternNodeRewriter(
1373+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1374+ (mul, "y", "x"),
1375+ )
1376+
1377+ You can use OpInstance to match a subtype of an Op, with some parameter constraints.
1378+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1379+
1380+
1381+ .. code-block:: python
1382+
1383+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1384+ from pytensor.graph.rewriting.unify import OpInstance
1385+ from pytensor.tensor.basic import Join
1386+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1387+
1388+
1389+ def output_fn(fgraph, node, s):
1390+ reduce_op = node.op
1391+ reduced_a = reduce_op(s["a"])
1392+ reduced_b = reduce_op(s["b"])
1393+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1394+
1395+
1396+ PatternNodeRewriter(
1397+ (
1398+ OpInstance(CAReduce, scalar_op="scalar_op", axis=None),
1399+ (Join(), "join_axis", "a", "b"),
1400+ ),
1401+ output_fn,
1402+ )
1403+
1404+
1405+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1406+
1407+ .. code-block:: python
1408+
1409+
1410+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1411+ from pytensor.graph.rewriting.unify import OpInstance, LiteralString
1412+ from pytensor.tensor.blockwise import Blockwise
1413+ from pytensor.tensor.slinalg import Solve
1414+
1415+ PatternNodeRewriter(
1416+ (
1417+ OpInstance(
1418+ Blockwise, core_op=OpInstance(Solve, assume_a=LiteralString("gen"))
1419+ ),
1420+ "A",
1421+ "b",
1422+ )
1423+ )
13691424 """
13701425
13711426 def __init__ (
13721427 self ,
1373- in_pattern ,
1374- out_pattern ,
1428+ in_pattern : tuple ,
1429+ out_pattern : tuple | Callable ,
13751430 allow_multiple_clients : bool = False ,
13761431 name : str | None = None ,
13771432 tracks = (),
@@ -1386,7 +1441,7 @@ def __init__(
13861441 in_pattern
13871442 The input pattern that we want to replace.
13881443 out_pattern
1389- The replacement pattern.
1444+ The replacement pattern, or a callable that takes (fgraph, node, subs_dict) as inputs.
13901445 allow_multiple_clients
13911446 If ``False``, the pattern matching will fail if one of the subpatterns has
13921447 more than one client.
@@ -1415,26 +1470,35 @@ def __init__(
14151470 self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
14161471 self .values_eq_approx = values_eq_approx
14171472 self .allow_cast = allow_cast
1418- if isinstance (in_pattern , list | tuple ):
1419- self .op = self .in_pattern [0 ]
1420- elif isinstance (in_pattern , dict ):
1421- self .op = self .in_pattern ["pattern" ][0 ]
1422- else :
1423- raise TypeError (
1424- "The pattern to search for must start with a specific Op instance."
1425- )
14261473 self .allow_multiple_clients = allow_multiple_clients
14271474 if name :
14281475 self .__name__ = name
1429- self ._tracks = tracks
14301476 self .get_nodes = get_nodes
14311477 if tracks != ():
1432- assert get_nodes
1478+ if not get_nodes :
1479+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1480+ self ._tracks = tracks
1481+ else :
1482+ if isinstance (in_pattern , list | tuple ):
1483+ op = self .in_pattern [0 ]
1484+ elif isinstance (in_pattern , dict ):
1485+ op = self .in_pattern ["pattern" ][0 ]
1486+ else :
1487+ raise TypeError (
1488+ "The pattern to search for must start with a specific Op instance."
1489+ )
1490+ if isinstance (op , Op ):
1491+ self ._tracks = [op ]
1492+ elif isinstance (op , OpInstance ):
1493+ self ._tracks = [op .op_type ]
1494+ else :
1495+ raise ValueError (
1496+ f"The pattern to search for must start with a specific Op instance or an OpInstance class. "
1497+ f"Got { op } , with type { type (op )} ."
1498+ )
14331499
14341500 def tracks (self ):
1435- if self ._tracks != ():
1436- return self ._tracks
1437- return [self .op ]
1501+ return self ._tracks
14381502
14391503 def transform (self , fgraph , node , get_nodes = True ):
14401504 """Check if the graph from node corresponds to ``in_pattern``.
@@ -1455,28 +1519,39 @@ def transform(self, fgraph, node, get_nodes=True):
14551519 # PatternNodeRewriter doesn't support replacing multi-output nodes
14561520 return False
14571521
1458- s = unify (self .in_pattern , node .out )
1522+ s = unify (self .in_pattern , node .out , {} )
14591523
14601524 if s is False :
14611525 return False
14621526
1463- ret = reify (self .out_pattern , s )
1464-
1465- if isinstance (ret , ExpressionTuple ):
1466- ret = ret .evaled_obj
1467-
1468- if self .values_eq_approx :
1469- ret .tag .values_eq_approx = self .values_eq_approx
1470-
14711527 if not self .allow_multiple_clients :
1472- input_vars = list (s .values ())
1528+ input_vars = set (s .values ())
1529+ clients = fgraph .clients
14731530 if any (
1474- len (fgraph . clients [v ]) > 1
1531+ len (clients [v ]) > 1
14751532 for v in vars_between (input_vars , node .inputs )
14761533 if v not in input_vars
14771534 ):
14781535 return False
14791536
1537+ if callable (self .out_pattern ):
1538+ # token is the variable name used in the original pattern
1539+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1540+ if ret is None or ret is False :
1541+ # The output function is still allowed to reject the rewrite
1542+ return False
1543+ if not isinstance (ret , Variable ):
1544+ raise ValueError (
1545+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1546+ )
1547+ else :
1548+ ret = reify (self .out_pattern , s )
1549+ if isinstance (ret , ExpressionTuple ):
1550+ ret = ret .evaled_obj
1551+
1552+ if self .values_eq_approx :
1553+ ret .tag .values_eq_approx = self .values_eq_approx
1554+
14801555 [old_out ] = node .outputs
14811556 if not old_out .type .is_super (ret .type ):
14821557 from pytensor .tensor .type import TensorType
0 commit comments