@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141141
142142 @abc .abstractmethod
143143 def transform (
144- self , fgraph : FunctionGraph , node : Apply , * args , ** kwargs
144+ self ,
145+ fgraph : FunctionGraph ,
146+ node : Apply ,
147+ enforce_tracks : bool = True ,
148+ * args ,
149+ ** kwargs ,
145150 ) -> TransformOutputType :
146151 r"""Rewrite the sub-graph given by `node`.
147152
@@ -159,7 +164,9 @@ def transform(
159164 A `FunctionGraph` containing `node`.
160165 node
161166 An `Apply` node to be rewritten.
162-
167+ enforce_tracks: bool
168+ Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
169+ See `node_rewriter` tracks argument for more details.
163170 """
164171
165172 raise NotImplementedError ()
@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter):
935942 def __init__ (self , fn , tracks = None , requirements = ()):
936943 self .fn = fn
937944 self ._tracks = tracks
938- self ._tracked_types = (
939- tuple (t for t in tracks if isinstance (t , type )) if tracks else ()
940- )
945+ self ._tracked_ops = set ()
946+ self ._tracked_types = type (None )
947+ self ._tracked_op_pattern_types = type (None )
948+ self ._tracked_op_patterns : list [OpPattern ] = []
949+ if tracks is not None :
950+ if not tracks :
951+ raise ValueError (
952+ "To specify a general rewrite leave tracks as None instead of an empty container"
953+ )
954+ for t in tracks :
955+ if isinstance (t , Op ):
956+ self ._tracked_ops .add (t )
957+ elif isinstance (t , type ):
958+ self ._tracked_types |= t
959+ elif isinstance (t , OpPattern ):
960+ if t .parameters :
961+ self ._tracked_op_patterns .append (t )
962+ self ._tracked_op_pattern_types |= t .op_type
963+ else :
964+ # An OpPattern without parameters behaves like a regular tracked_type
965+ self ._tracked_types |= t
966+ else :
967+ raise TypeError (
968+ "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
969+ f"Got { t } of type { type (t )} "
970+ )
941971 self .requirements = requirements
942972
943- def transform (self , fgraph , node ):
944- if self ._tracks :
973+ def transform (self , fgraph , node , enforce_tracks : bool = True ):
974+ if enforce_tracks and self ._tracks :
975+ node_op = node .op
945976 if not (
946- node .op in self ._tracks or isinstance (node .op , self ._tracked_types )
977+ node_op in self ._tracked_ops
978+ or isinstance (node_op , self ._tracked_types )
979+ or (
980+ isinstance (node .op , self ._tracked_op_pattern_types )
981+ and any (
982+ t .match_parameters (node_op )
983+ for t in self ._tracked_op_patterns
984+ if isinstance (node_op , t .op_type )
985+ )
986+ )
947987 ):
948988 return False
949989
@@ -967,7 +1007,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
9671007
9681008
9691009def node_rewriter (
970- tracks : Sequence [Op | type ] | None ,
1010+ tracks : Sequence [Op | type , OpPattern ] | None ,
9711011 inplace : bool = False ,
9721012 requirements : tuple [type , ...] | None = (),
9731013):
@@ -976,7 +1016,7 @@ def node_rewriter(
9761016 Parameters
9771017 ----------
9781018 tracks
979- The `Op` types or instances to which this rewrite applies.
1019+ The `Op` type, instances or `OpPattern` to which this rewrite applies.
9801020 Use ``None`` instead of an empty list to have the rewrite apply to
9811021 all `Op`\s.
9821022 inplace
@@ -995,14 +1035,16 @@ def decorator(f):
9951035 if tracks is not None :
9961036 if len (tracks ) == 0 :
9971037 raise ValueError (
998- "Use `None` instead of an empty list to make an rewrite apply to all nodes."
1038+ "Use `None` instead of an empty list to make a rewrite apply to all nodes."
9991039 )
10001040 for t in tracks :
10011041 if not (
1002- isinstance (t , Op ) or (isinstance (t , type ) and issubclass (t , Op ))
1042+ isinstance (t , Op | OpPattern )
1043+ or (isinstance (t , type ) and issubclass (t , Op ))
10031044 ):
10041045 raise TypeError (
1005- "`tracks` must consist of `Op` classes or instances."
1046+ "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
1047+ f"Got { t } of type { type (t )} "
10061048 )
10071049 req = requirements
10081050 if inplace :
@@ -1024,47 +1066,93 @@ class OpToRewriterTracker:
10241066 def __init__ (self ):
10251067 self .tracked_instances : dict [Op , list [NodeRewriter ]] = defaultdict (list )
10261068 self .tracked_types : dict [type , list [NodeRewriter ]] = defaultdict (list )
1069+ self .tracked_pattern_types : dict [type , dict [OpPattern , list [NodeRewriter ]]] = (
1070+ defaultdict (lambda : defaultdict (list ))
1071+ )
10271072 self .untracked_rewrites : list [NodeRewriter ] = []
1073+ self ._cached_composed_mro = None
10281074
10291075 def add_tracker (self , rw : NodeRewriter ):
10301076 """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1077+ if self ._cached_composed_mro is not None :
1078+ # We shouldn't actually add_trackers after the first call to get_trackers
1079+ # But just to be safe we kill the cache here
1080+ self ._cached_composed_mro = None
1081+
10311082 tracks = rw .tracks ()
10321083
10331084 if tracks is None :
10341085 self .untracked_rewrites .append (rw )
10351086 else :
10361087 for c in tracks :
1088+ if isinstance (c , OpPattern ):
1089+ if not isinstance (c .op_type , type ):
1090+ # OpPattern allows anything that you can check with isinstance(op, op_type),
1091+ # including tuples or union types. But for OpToRewriterTracker we need a single type.
1092+ raise NotImplementedError (
1093+ "OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1094+ f"Got { c .op_type } of type { type (c .op_type )} "
1095+ )
1096+
1097+ if c .parameters :
1098+ self .tracked_pattern_types [c .op_type ][c ].append (rw )
1099+ else :
1100+ # An OpPattern without parameters behaves like a regular tracked_type
1101+ self .tracked_types [c .op_type ].append (rw )
10371102 if isinstance (c , type ):
10381103 self .tracked_types [c ].append (rw )
10391104 else :
10401105 self .tracked_instances [c ].append (rw )
10411106
1042- def _find_impl (self , cls ) -> list [NodeRewriter ]:
1043- r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1107+ @functools .cache
1108+ def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1109+ """Get all the rewrites applicable to an `Op`."""
1110+
1111+ if self ._cached_composed_mro is None :
1112+ # Cache the mro call on the Op type. We have a small subset of op_types we actually care about
1113+ # like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1114+ tracked_types = (
1115+ self .tracked_types .keys () | self .tracked_pattern_types .keys ()
1116+ )
1117+
1118+ @functools .cache
1119+ def cached_composed_mro (op_type , tracked_types = tracked_types ):
1120+ return _compose_mro (op_type , tracked_types )
1121+
1122+ self ._cached_composed_mro = cached_composed_mro
10441123
1045- This based on `functools._find_impl`.
1046- """
1047- mro = _compose_mro (cls , self .tracked_types .keys ())
10481124 matches = []
1049- for t in mro :
1050- match = self .tracked_types .get (t , None )
1051- if match :
1052- matches .extend (match )
1125+ if self .tracked_types or self .tracked_pattern_types :
1126+ # Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1127+ mro = self ._cached_composed_mro (type (op ))
1128+ for t in mro :
1129+ if (match := self .tracked_types .get (t , None )) is not None :
1130+ matches .extend (match )
1131+ if (
1132+ potential_matches := self .tracked_pattern_types .get (t , None )
1133+ ) is not None :
1134+ # We still need to check if the Op parameters match the constraints
1135+ matches .extend (
1136+ [
1137+ item
1138+ for op_pattern , r_list in potential_matches .items ()
1139+ if op_pattern .match_parameters (op )
1140+ for item in r_list
1141+ ]
1142+ )
1143+ matches .extend (self .tracked_instances .get (op , []))
1144+ matches .extend (self .untracked_rewrites )
10531145 return matches
10541146
1055- @functools .lru_cache
1056- def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1057- """Get all the rewrites applicable to `op`."""
1058- return (
1059- self ._find_impl (type (op ))
1060- + self .tracked_instances .get (op , [])
1061- + self .untracked_rewrites
1062- )
1063-
1064- def get_rewriters (self ):
1147+ def get_rewriters (self ) -> Iterable [NodeRewriter ]:
1148+ """Get all the registered rewriters."""
10651149 return chain (
1150+ chain .from_iterable (self .tracked_types .values ()),
1151+ chain .from_iterable (self .tracked_instances .values ()),
10661152 chain .from_iterable (
1067- chain (self .tracked_types .values (), self .tracked_instances .values ())
1153+ item
1154+ for sub_dict in self .tracked_pattern_types .values ()
1155+ for item in sub_dict .values ()
10681156 ),
10691157 self .untracked_rewrites ,
10701158 )
@@ -1138,7 +1226,7 @@ def tracks(self):
11381226 t .extend (at )
11391227 return t
11401228
1141- def transform (self , fgraph , node ):
1229+ def transform (self , fgraph , node , enforce_tracks = False ):
11421230 if len (self .rewrites ) == 0 :
11431231 return
11441232
@@ -1150,7 +1238,8 @@ def transform(self, fgraph, node):
11501238 new_repl = None
11511239 for rewrite in rewrites :
11521240 rewrite_start = time .perf_counter ()
1153- new_repl = rewrite .transform (fgraph , node )
1241+ # Tracks are already enforced by `self.tracker.get_trackers`
1242+ new_repl = rewrite .transform (fgraph , node , enforce_tracks = False )
11541243 rewrite_finish = time .perf_counter ()
11551244 if self .profile :
11561245 self .time_rewrites [rewrite ] += rewrite_start - rewrite_finish
@@ -1292,8 +1381,8 @@ def __init__(self, op1, op2, transfer_tags=True):
12921381 def tracks (self ):
12931382 return [self .op1 ]
12941383
1295- def transform (self , fgraph , node ):
1296- if node .op != self .op1 :
1384+ def transform (self , fgraph , node , enforce_tracks = True ):
1385+ if enforce_tracks and ( node .op != self .op1 ) :
12971386 return False
12981387 repl = self .op2 .make_node (* node .inputs )
12991388 if self .transfer_tags :
@@ -1498,7 +1587,7 @@ def __init__(
14981587 def tracks (self ):
14991588 return self ._tracks
15001589
1501- def transform (self , fgraph , node , get_nodes = True ):
1590+ def transform (self , fgraph , node , enforce_tracks : bool = False , get_nodes = True ):
15021591 """Check if the graph from node corresponds to ``in_pattern``.
15031592
15041593 If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1788,6 +1877,7 @@ def process_node(
17881877 fgraph : FunctionGraph ,
17891878 node : Apply ,
17901879 node_rewriter : NodeRewriter | None = None ,
1880+ enforce_tracks : bool = True ,
17911881 ):
17921882 r"""Apply `node_rewriter` to `node`.
17931883
@@ -1805,6 +1895,9 @@ def process_node(
18051895 node_rewriter
18061896 A `NodeRewriter` instance that may have a better idea for
18071897 how to compute node's outputs.
1898+ enforce_tracks: bool
1899+ Whether the transform method should enforce tracks,
1900+ or it can be assumed the caller already enforced them in a pre-filter stage.
18081901
18091902 Returns
18101903 -------
@@ -1820,7 +1913,9 @@ def process_node(
18201913 # TODO FIXME: This class's interface is broken
18211914 assert node_rewriter is not None
18221915 try :
1823- replacements = node_rewriter .transform (fgraph , node )
1916+ replacements = node_rewriter .transform (
1917+ fgraph , node , enforce_tracks = enforce_tracks
1918+ )
18241919 except Exception as e :
18251920 if self .failure_callback is not None :
18261921 self .failure_callback (
@@ -1938,7 +2033,8 @@ def importer(node):
19382033 if node not in fgraph .apply_nodes :
19392034 continue
19402035 current_node = node
1941- nb += self .process_node (fgraph , node )
2036+ # This rewriter does not enforce tracks itself
2037+ nb += self .process_node (fgraph , node , enforce_tracks = True )
19422038 loop_t = time .perf_counter () - t0
19432039 finally :
19442040 self .detach_updater (fgraph , u )
@@ -2279,8 +2375,9 @@ def chin_(node, i, r, new_r, reason):
22792375 for node_rewriter in self .node_tracker .get_trackers (node .op ):
22802376 nb = change_tracker .nb_imported
22812377 t_rewrite = time .perf_counter ()
2378+ # Tracks are already enfoced by `self.node_tracker.get_trackers`
22822379 node_rewriter_change = self .process_node (
2283- fgraph , node , node_rewriter
2380+ fgraph , node , node_rewriter , enforce_tracks = False
22842381 )
22852382 time_rewriters [node_rewriter ] += time .perf_counter () - t_rewrite
22862383 if not node_rewriter_change :
0 commit comments