2626
2727import numpy as np
2828
29- import pytensor .scalar .basic as ps
3029from pytensor import compile , config
3130from pytensor .compile .ops import ViewOp
32- from pytensor .graph import FunctionGraph
31+ from pytensor .graph import FunctionGraph , Op
3332from pytensor .graph .basic import Constant
3433from pytensor .graph .rewriting .basic import (
3534 NodeProcessingGraphRewriter ,
4039 node_rewriter ,
4140)
4241from pytensor .graph .rewriting .db import RewriteDatabase
42+ from pytensor .graph .rewriting .unify import OpPattern , OpPatternOpTypeType
4343from pytensor .npy_2_compat import normalize_axis_index
4444from pytensor .raise_op import Assert , CheckAndRaise , assert_op
45- from pytensor .scalar .basic import Second
45+ from pytensor .scalar import (
46+ AND ,
47+ EQ ,
48+ LE ,
49+ NEQ ,
50+ OR ,
51+ XOR ,
52+ Add ,
53+ BinaryScalarOp ,
54+ Cast ,
55+ Identity ,
56+ Mul ,
57+ Second ,
58+ Switch ,
59+ )
4660from pytensor .tensor .basic import (
4761 Alloc ,
4862 AllocEmpty ,
@@ -225,6 +239,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
225239 return node_rewriter
226240
227241
242+ def elemwise_of (scalar_op : OpPatternOpTypeType | OpPattern ) -> OpPattern :
243+ if not isinstance (scalar_op , Op | OpPattern ):
244+ scalar_op = OpPattern (scalar_op )
245+ return OpPattern (Elemwise , scalar_op = scalar_op )
246+
247+
228248@register_canonicalize
229249@register_specialize
230250@node_rewriter ([TensorFromScalar ])
@@ -551,15 +571,15 @@ def local_useless_elemwise(fgraph, node):
551571 dtype = node .outputs [0 ].type .dtype
552572 scalar_op = node .op .scalar_op
553573
554- if isinstance (scalar_op , ps . EQ ) and len (node .inputs ) == 2 :
574+ if isinstance (scalar_op , EQ ) and len (node .inputs ) == 2 :
555575 if node .inputs [0 ] is node .inputs [1 ]:
556576 # it is the same var in the graph. That will always be true
557577 ret = ones_like (node .inputs [0 ], dtype = dtype , opt = True )
558578
559579 # Copy stack trace from input to constant output
560580 copy_stack_trace (node .outputs [0 ], ret )
561581 return [ret ]
562- elif isinstance (scalar_op , ps . NEQ | ps . XOR ) and len (node .inputs ) == 2 :
582+ elif isinstance (scalar_op , NEQ | XOR ) and len (node .inputs ) == 2 :
563583 if node .inputs [0 ] is node .inputs [1 ]:
564584 # it is the same var in the graph. That will always be false
565585 ret = zeros_like (node .inputs [0 ], dtype = dtype , opt = True )
@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
568588 copy_stack_trace (node .outputs [0 ], ret )
569589 return [ret ]
570590
571- elif (
572- isinstance (node .op .scalar_op , ps .Mul | ps .Add | ps .Identity )
573- and len (node .inputs ) == 1
574- ):
591+ elif isinstance (node .op .scalar_op , Mul | Add | Identity ) and len (node .inputs ) == 1 :
575592 # No need to copy over any stack trace
576593 return [node .inputs [0 ]]
577594
578- elif isinstance (node .op .scalar_op , ps . AND ) and len (node .inputs ) == 2 :
595+ elif isinstance (node .op .scalar_op , AND ) and len (node .inputs ) == 2 :
579596 if (
580597 isinstance (node .inputs [0 ], TensorConstant )
581598 and node .inputs [1 ].type .broadcastable == out_bcast
@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
602619 # and this rewrite would be wrong
603620 return [node .inputs [0 ].astype (node .outputs [0 ].dtype )]
604621
605- elif isinstance (node .op .scalar_op , ps . OR ) and len (node .inputs ) == 2 :
622+ elif isinstance (node .op .scalar_op , OR ) and len (node .inputs ) == 2 :
606623 if (
607624 isinstance (node .inputs [0 ], TensorConstant )
608625 and node .inputs [1 ].type .broadcastable == out_bcast
@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
653670
654671@register_canonicalize
655672@register_specialize
656- @node_rewriter ([Elemwise ])
673+ @node_rewriter ([elemwise_of ( Cast ) ])
657674def local_cast_cast (fgraph , node ):
658675 """cast(cast(x, dtype1), dtype2)
659676
@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
663680 and the first cast cause an upcast.
664681
665682 """
666- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , ps .Cast )):
667- return
668683 x = node .inputs [0 ]
669684 if not (
670685 x .owner
671686 and isinstance (x .owner .op , Elemwise )
672- and isinstance (x .owner .op .scalar_op , ps . Cast )
687+ and isinstance (x .owner .op .scalar_op , Cast )
673688 ):
674689 return
675690
@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
10091024 node .outputs [0 ].type .ndim == 0
10101025 and cond_var .owner
10111026 and isinstance (cond_var .owner .op , Elemwise )
1012- and isinstance (cond_var .owner .op .scalar_op , ps . LE )
1027+ and isinstance (cond_var .owner .op .scalar_op , LE )
10131028 and cond_var .owner .inputs [0 ].owner
10141029 and isinstance (cond_var .owner .inputs [0 ].owner .op , Shape_i )
10151030 and get_scalar_constant_value (
@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
10311046
10321047
10331048@register_canonicalize
1034- @node_rewriter ([Elemwise ])
1049+ @node_rewriter ([elemwise_of ( BinaryScalarOp | Add | Mul ) ])
10351050def local_merge_switch_same_cond (fgraph , node ):
10361051 """
10371052 Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
10381053 condition, to enable further simplification of their branches
10391054 Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
10401055 """
1041- # node must be binary elemwise or add or mul
1042- if not (
1043- isinstance (node .op , Elemwise )
1044- and isinstance (node .op .scalar_op , ps .BinaryScalarOp | ps .Add | ps .Mul )
1045- ):
1046- return
10471056 # all inputs must be switch
10481057 if not all (
10491058 s .owner
10501059 and isinstance (s .owner .op , Elemwise )
1051- and isinstance (s .owner .op .scalar_op , ps . Switch )
1060+ and isinstance (s .owner .op .scalar_op , Switch )
10521061 for s in node .inputs
10531062 ):
10541063 return
@@ -1174,10 +1183,9 @@ def constant_folding(fgraph, node):
11741183@register_infer_shape
11751184@register_canonicalize ("fast_compile" )
11761185@register_useless ("fast_compile" )
1177- @node_rewriter (None )
1186+ @node_rewriter ([ ViewOp ] )
11781187def local_view_op (fgraph , node ):
1179- if isinstance (node .op , ViewOp ):
1180- return node .inputs
1188+ return node .inputs
11811189
11821190
11831191@register_infer_shape
0 commit comments