2929from pytensor .graph .features import AlreadyThere , Feature
3030from pytensor .graph .fg import FunctionGraph , Output
3131from pytensor .graph .op import Op
32- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32+ from pytensor .graph .rewriting .unify import OpPattern , Var , convert_strs_to_vars
3333from pytensor .graph .utils import AssocList , InconsistencyError
3434from pytensor .misc .ordered_set import OrderedSet
3535from pytensor .utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
13121312 The input and output patterns have the following syntax:
13131313
13141314 input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315+ input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
13151316 input_pattern ::= dict(pattern = <input_pattern>,
13161317 constraint = <constraint>)
13171318 sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
13251326 output_pattern ::= string
13261327 output_pattern ::= int
13271328 output_pattern ::= float
1329+ output_pattern ::= callable
13281330
13291331 Each string in the input pattern is a variable that will be set to
13301332 whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
13501352 Examples
13511353 --------
13521354
1353- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357- PatternNodeRewriter((boggle, {'pattern': 'x',
1358- 'constraint': lambda expr: expr.type == scrabble}),
1359- (scrabble, 'x'))
1355+ .. code-block:: python
13601356
1357+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358+ from pytensor.tensor import add, mul, sub, pow, square
1359+
1360+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364+ PatternNodeRewriter(
1365+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366+ (mul, "y", "x"),
1367+ )
1368+
1369+ You can use OpPattern to match a subtype of an Op, with some parameter constraints
1370+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371+
1372+
1373+ .. code-block:: python
1374+
1375+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376+ from pytensor.graph.rewriting.unify import OpPattern
1377+ from pytensor.tensor.basic import Join
1378+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1379+
1380+
1381+ def output_fn(fgraph, node, s):
1382+ reduce_op = node.op
1383+ reduced_a = reduce_op(s["a"])
1384+ reduced_b = reduce_op(s["b"])
1385+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386+
1387+
1388+ PatternNodeRewriter(
1389+ (
1390+ OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391+ (Join(), "join_axis", "a", "b"),
1392+ ),
1393+ output_fn,
1394+ )
1395+
1396+
1397+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1398+
1399+ .. code-block:: python
1400+
1401+
1402+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403+ from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404+ from pytensor.tensor.blockwise import Blockwise
1405+ from pytensor.tensor.slinalg import Solve
1406+
1407+ PatternNodeRewriter(
1408+ (
1409+ OpPattern(
1410+ Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411+ ),
1412+ "A",
1413+ "b",
1414+ )
1415+ )
13611416 """
13621417
13631418 def __init__ (
13641419 self ,
1365- in_pattern ,
1366- out_pattern ,
1420+ in_pattern : tuple ,
1421+ out_pattern : tuple | Callable ,
13671422 allow_multiple_clients : bool = False ,
13681423 name : str | None = None ,
13691424 tracks = (),
@@ -1378,7 +1433,7 @@ def __init__(
13781433 in_pattern
13791434 The input pattern that we want to replace.
13801435 out_pattern
1381- The replacement pattern.
1436+ The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs
13821437 allow_multiple_clients
13831438 If ``False``, the pattern matching will fail if one of the subpatterns has
13841439 more than one client.
@@ -1407,26 +1462,40 @@ def __init__(
14071462 self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
14081463 self .values_eq_approx = values_eq_approx
14091464 self .allow_cast = allow_cast
1410- if isinstance (in_pattern , list | tuple ):
1411- self .op = self .in_pattern [0 ]
1412- elif isinstance (in_pattern , dict ):
1413- self .op = self .in_pattern ["pattern" ][0 ]
1414- else :
1415- raise TypeError (
1416- "The pattern to search for must start with a specific Op instance."
1417- )
14181465 self .allow_multiple_clients = allow_multiple_clients
14191466 if name :
14201467 self .__name__ = name
1421- self ._tracks = tracks
14221468 self .get_nodes = get_nodes
14231469 if tracks != ():
1424- assert get_nodes
1470+ if not get_nodes :
1471+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1472+ self ._tracks = tracks
1473+ else :
1474+ if isinstance (in_pattern , list | tuple ):
1475+ op = self .in_pattern [0 ]
1476+ elif isinstance (in_pattern , dict ):
1477+ op = self .in_pattern ["pattern" ][0 ]
1478+ else :
1479+ raise TypeError (
1480+ f"The in_pattern must be a sequence or a dict, but got { type (in_pattern )} "
1481+ )
1482+ if isinstance (op , Op ):
1483+ self ._tracks = [op ]
1484+ elif isinstance (op , type ) and issubclass (op , Op ):
1485+ raise ValueError (
1486+ f"The in_pattern starts with an Op class { op } , not an instance.\n "
1487+ "You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class."
1488+ )
1489+ elif isinstance (op , OpPattern ):
1490+ self ._tracks = [op .op_type ]
1491+ else :
1492+ raise ValueError (
1493+ f"The in_pattern must start with a specific Op or an OpPattern instance. "
1494+ f"Got { op } , with type { type (op )} ."
1495+ )
14251496
14261497 def tracks (self ):
1427- if self ._tracks != ():
1428- return self ._tracks
1429- return [self .op ]
1498+ return self ._tracks
14301499
14311500 def transform (self , fgraph , node , get_nodes = True ):
14321501 """Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1516,39 @@ def transform(self, fgraph, node, get_nodes=True):
14471516 # PatternNodeRewriter doesn't support replacing multi-output nodes
14481517 return False
14491518
1450- s = unify (self .in_pattern , node .out )
1519+ s = unify (self .in_pattern , node .out , {} )
14511520
14521521 if s is False :
14531522 return False
14541523
1455- ret = reify (self .out_pattern , s )
1456-
1457- if isinstance (ret , ExpressionTuple ):
1458- ret = ret .evaled_obj
1459-
1460- if self .values_eq_approx :
1461- ret .tag .values_eq_approx = self .values_eq_approx
1462-
14631524 if not self .allow_multiple_clients :
1464- input_vars = list (s .values ())
1525+ input_vars = set (s .values ())
1526+ clients = fgraph .clients
14651527 if any (
1466- len (fgraph . clients [v ]) > 1
1528+ len (clients [v ]) > 1
14671529 for v in vars_between (input_vars , node .inputs )
14681530 if v not in input_vars
14691531 ):
14701532 return False
14711533
1534+ if callable (self .out_pattern ):
1535+ # token is the variable name used in the original pattern
1536+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1537+ if ret is None or ret is False :
1538+ # The output function is still allowed to reject the rewrite
1539+ return False
1540+ if not isinstance (ret , Variable ):
1541+ raise ValueError (
1542+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1543+ )
1544+ else :
1545+ ret = reify (self .out_pattern , s )
1546+ if isinstance (ret , ExpressionTuple ):
1547+ ret = ret .evaled_obj
1548+
1549+ if self .values_eq_approx :
1550+ ret .tag .values_eq_approx = self .values_eq_approx
1551+
14721552 [old_out ] = node .outputs
14731553 if not old_out .type .is_super (ret .type ):
14741554 from pytensor .tensor .type import TensorType
0 commit comments