2929from pytensor .graph .features import AlreadyThere , Feature
3030from pytensor .graph .fg import FunctionGraph , Output
3131from pytensor .graph .op import Op
32- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32+ from pytensor .graph .rewriting .unify import OpPattern , Var , convert_strs_to_vars
3333from pytensor .graph .utils import AssocList , InconsistencyError
3434from pytensor .misc .ordered_set import OrderedSet
3535from pytensor .utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
13121312 The input and output patterns have the following syntax:
13131313
13141314 input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315+ input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
13151316 input_pattern ::= dict(pattern = <input_pattern>,
13161317 constraint = <constraint>)
13171318 sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
13251326 output_pattern ::= string
13261327 output_pattern ::= int
13271328 output_pattern ::= float
1329+ output_pattern ::= callable
13281330
13291331 Each string in the input pattern is a variable that will be set to
13301332 whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
13501352 Examples
13511353 --------
13521354
1353- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357- PatternNodeRewriter((boggle, {'pattern': 'x',
1358- 'constraint': lambda expr: expr.type == scrabble}),
1359- (scrabble, 'x'))
1355+ .. code-block:: python
13601356
1357+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358+ from pytensor.tensor import add, mul, sub, pow, square
1359+
1360+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364+ PatternNodeRewriter(
1365+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366+ (mul, "y", "x"),
1367+ )
1368+
1369+ You can use OpPattern to match a subtype of an Op, with some parameter constraints
1370+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371+
1372+
1373+ .. code-block:: python
1374+
1375+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376+ from pytensor.graph.rewriting.unify import OpPattern
1377+ from pytensor.tensor.basic import Join
1378+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1379+
1380+
1381+ def output_fn(fgraph, node, s):
1382+ reduce_op = node.op
1383+ reduced_a = reduce_op(s["a"])
1384+ reduced_b = reduce_op(s["b"])
1385+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386+
1387+
1388+ PatternNodeRewriter(
1389+ (
1390+ OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391+ (Join(), "join_axis", "a", "b"),
1392+ ),
1393+ output_fn,
1394+ )
1395+
1396+
1397+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1398+
1399+ .. code-block:: python
1400+
1401+
1402+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403+ from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404+ from pytensor.tensor.blockwise import Blockwise
1405+ from pytensor.tensor.slinalg import Solve
1406+
1407+ PatternNodeRewriter(
1408+ (
1409+ OpPattern(
1410+ Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411+ ),
1412+ "A",
1413+ "b",
1414+ )
1415+ )
13611416 """
13621417
13631418 def __init__ (
13641419 self ,
1365- in_pattern ,
1366- out_pattern ,
1420+ in_pattern : tuple ,
1421+ out_pattern : tuple | Callable | str ,
13671422 allow_multiple_clients : bool = False ,
13681423 name : str | None = None ,
13691424 tracks = (),
@@ -1378,7 +1433,8 @@ def __init__(
13781433 in_pattern
13791434 The input pattern that we want to replace.
13801435 out_pattern
1381- The replacement pattern.
1436+ The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs,
1437+ and returns the replacement variable (or None/False to reject the rewrite).
13821438 allow_multiple_clients
13831439 If ``False``, the pattern matching will fail if one of the subpatterns has
13841440 more than one client.
@@ -1407,26 +1463,40 @@ def __init__(
14071463 self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
14081464 self .values_eq_approx = values_eq_approx
14091465 self .allow_cast = allow_cast
1410- if isinstance (in_pattern , list | tuple ):
1411- self .op = self .in_pattern [0 ]
1412- elif isinstance (in_pattern , dict ):
1413- self .op = self .in_pattern ["pattern" ][0 ]
1414- else :
1415- raise TypeError (
1416- "The pattern to search for must start with a specific Op instance."
1417- )
14181466 self .allow_multiple_clients = allow_multiple_clients
14191467 if name :
14201468 self .__name__ = name
1421- self ._tracks = tracks
14221469 self .get_nodes = get_nodes
14231470 if tracks != ():
1424- assert get_nodes
1471+ if not get_nodes :
1472+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1473+ self ._tracks = tracks
1474+ else :
1475+ if isinstance (in_pattern , list | tuple ):
1476+ op = self .in_pattern [0 ]
1477+ elif isinstance (in_pattern , dict ):
1478+ op = self .in_pattern ["pattern" ][0 ]
1479+ else :
1480+ raise TypeError (
1481+ f"The in_pattern must be a sequence or a dict, but got { in_pattern } of type { type (in_pattern )} "
1482+ )
1483+ if isinstance (op , Op ):
1484+ self ._tracks = [op ]
1485+ elif isinstance (op , type ) and issubclass (op , Op ):
1486+ raise ValueError (
1487+ f"The in_pattern starts with an Op class { op } , not an instance.\n "
1488+ "You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
1489+ )
1490+ elif isinstance (op , OpPattern ):
1491+ self ._tracks = [op .op_type ]
1492+ else :
1493+ raise ValueError (
1494+ f"The in_pattern must start with a specific Op or an OpPattern instance. "
1495+ f"Got { op } , with type { type (op )} ."
1496+ )
14251497
14261498 def tracks (self ):
1427- if self ._tracks != ():
1428- return self ._tracks
1429- return [self .op ]
1499+ return self ._tracks
14301500
14311501 def transform (self , fgraph , node , get_nodes = True ):
14321502 """Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1517,39 @@ def transform(self, fgraph, node, get_nodes=True):
14471517 # PatternNodeRewriter doesn't support replacing multi-output nodes
14481518 return False
14491519
1450- s = unify (self .in_pattern , node .out )
1520+ s = unify (self .in_pattern , node .out , {} )
14511521
14521522 if s is False :
14531523 return False
14541524
1455- ret = reify (self .out_pattern , s )
1456-
1457- if isinstance (ret , ExpressionTuple ):
1458- ret = ret .evaled_obj
1459-
1460- if self .values_eq_approx :
1461- ret .tag .values_eq_approx = self .values_eq_approx
1462-
14631525 if not self .allow_multiple_clients :
1464- input_vars = list (s .values ())
1526+ input_vars = set (s .values ())
1527+ clients = fgraph .clients
14651528 if any (
1466- len (fgraph . clients [v ]) > 1
1529+ len (clients [v ]) > 1
14671530 for v in vars_between (input_vars , node .inputs )
14681531 if v not in input_vars
14691532 ):
14701533 return False
14711534
1535+ if callable (self .out_pattern ):
1536+ # token is the variable name used in the original pattern
1537+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1538+ if ret is None or ret is False :
1539+ # The output function is still allowed to reject the rewrite
1540+ return False
1541+ if not isinstance (ret , Variable ):
1542+ raise ValueError (
1543+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1544+ )
1545+ else :
1546+ ret = reify (self .out_pattern , s )
1547+ if isinstance (ret , ExpressionTuple ):
1548+ ret = ret .evaled_obj
1549+
1550+ if self .values_eq_approx :
1551+ ret .tag .values_eq_approx = self .values_eq_approx
1552+
14721553 [old_out ] = node .outputs
14731554 if not old_out .type .is_super (ret .type ):
14741555 from pytensor .tensor .type import TensorType
0 commit comments