1111import warnings
1212from collections import Counter , UserList , defaultdict , deque
1313from collections .abc import Callable , Iterable , Sequence
14- from functools import _compose_mro , partial # type: ignore
14+ from functools import _compose_mro , lru_cache , partial # type: ignore
1515from itertools import chain
1616from typing import Literal
1717
@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141141
142142 @abc .abstractmethod
143143 def transform (
144- self , fgraph : FunctionGraph , node : Apply , * args , ** kwargs
144+ self ,
145+ fgraph : FunctionGraph ,
146+ node : Apply ,
147+ enfoce_tracks : bool = True ,
148+ * args ,
149+ ** kwargs ,
145150 ) -> TransformOutputType :
146151 r"""Rewrite the sub-graph given by `node`.
147152
@@ -159,7 +164,8 @@ def transform(
159164 A `FunctionGraph` containing `node`.
160165 node
161166 An `Apply` node to be rewritten.
162-
167+ enforce_tracks: bool
168+ Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
163169 """
164170
165171 raise NotImplementedError ()
@@ -935,15 +941,43 @@ class FromFunctionNodeRewriter(NodeRewriter):
935941 def __init__ (self , fn , tracks = None , requirements = ()):
936942 self .fn = fn
937943 self ._tracks = tracks
938- self ._tracked_types = (
939- tuple (t for t in tracks if isinstance (t , type )) if tracks else ()
940- )
944+ self ._tracked_ops = set ()
945+ self ._tracked_types = type (None )
946+ self ._tracked_parametrized_types = type (None )
947+ self ._tracked_op_instance_patterns : list [OpPattern ] = []
948+ if tracks is not None :
949+ if not tracks :
950+ raise ValueError (
951+ "To specify a general rewrite leave tracks as None instead of an empty container"
952+ )
953+ for t in tracks :
954+ if isinstance (t , Op ):
955+ self ._tracked_ops .add (t )
956+ if isinstance (t , type ):
957+ self ._tracked_types |= t
958+ elif isinstance (t , OpPattern ):
959+ if t .parameters :
960+ self ._tracked_op_instance_patterns .append (t )
961+ self ._tracked_parametrized_types |= t .op_type
962+ else :
963+ # It's a regular tracked_type
964+ self ._tracked_types |= t
941965 self .requirements = requirements
942966
943- def transform (self , fgraph , node ):
944- if self ._tracks :
967+ def transform (self , fgraph , node , enforce_tracks : bool = True ):
968+ if enforce_tracks and self ._tracks :
969+ node_op = node .op
945970 if not (
946- node .op in self ._tracks or isinstance (node .op , self ._tracked_types )
971+ node_op in self ._tracked_ops
972+ or isinstance (node_op , self ._tracked_types )
973+ or (
974+ isinstance (node .op , self ._tracked_parametrized_types )
975+ and any (
976+ t .match_parameters (node_op )
977+ for t in self ._tracked_op_instance_patterns
978+ if isinstance (node_op , t .op_type )
979+ )
980+ )
947981 ):
948982 return False
949983
@@ -967,7 +1001,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
9671001
9681002
9691003def node_rewriter (
970- tracks : Sequence [Op | type ] | None ,
1004+ tracks : Sequence [Op | type , OpPattern ] | None ,
9711005 inplace : bool = False ,
9721006 requirements : tuple [type , ...] | None = (),
9731007):
@@ -995,14 +1029,15 @@ def decorator(f):
9951029 if tracks is not None :
9961030 if len (tracks ) == 0 :
9971031 raise ValueError (
998- "Use `None` instead of an empty list to make an rewrite apply to all nodes."
1032+ "Use `None` instead of an empty list to make a rewrite apply to all nodes."
9991033 )
10001034 for t in tracks :
10011035 if not (
1002- isinstance (t , Op ) or (isinstance (t , type ) and issubclass (t , Op ))
1036+ isinstance (t , Op | OpPattern )
1037+ or (isinstance (t , type ) and issubclass (t , Op ))
10031038 ):
10041039 raise TypeError (
1005- "`tracks` must consist of `Op` classes or instances."
1040+ "`tracks` must consist of `Op` classes, instances or `OpPattern` instances."
10061041 )
10071042 req = requirements
10081043 if inplace :
@@ -1024,47 +1059,91 @@ class OpToRewriterTracker:
10241059 def __init__ (self ):
10251060 self .tracked_instances : dict [Op , list [NodeRewriter ]] = defaultdict (list )
10261061 self .tracked_types : dict [type , list [NodeRewriter ]] = defaultdict (list )
1062+ self .tracked_parametrized_types : dict [
1063+ type , dict [OpPattern , list [NodeRewriter ]]
1064+ ] = defaultdict (lambda : defaultdict (list ))
10271065 self .untracked_rewrites : list [NodeRewriter ] = []
1066+ self ._cached_composed_mro = None
10281067
10291068 def add_tracker (self , rw : NodeRewriter ):
10301069 """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1070+ if self ._cached_composed_mro is not None :
1071+ # We shouldn't actually add_trackers after the first call to get_trackers
1072+ # But just to be safe we kill the cache here
1073+ self ._cached_composed_mro = None
1074+
10311075 tracks = rw .tracks ()
10321076
10331077 if tracks is None :
10341078 self .untracked_rewrites .append (rw )
10351079 else :
10361080 for c in tracks :
1081+ if isinstance (c , OpPattern ):
1082+ if not isinstance (c .op_type , type ):
1083+ raise NotImplementedError (
1084+ "OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1085+ f"Got { c .op_type } of type { type (c .op_type )} "
1086+ )
1087+
1088+ if c .parameters :
1089+ self .tracked_parametrized_types [c .op_type ][c ].append (rw )
1090+ else :
1091+ # It's a simple type track
1092+ self .tracked_types [c .op_type ].append (rw )
10371093 if isinstance (c , type ):
10381094 self .tracked_types [c ].append (rw )
10391095 else :
10401096 self .tracked_instances [c ].append (rw )
10411097
1042- def _find_impl (self , cls ) -> list [NodeRewriter ]:
1043- r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1098+ @functools .lru_cache
1099+ def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1100+ """Get all the rewrites applicable to an `Op`."""
1101+
1102+ if self ._cached_composed_mro is None :
1103+ # Cache the mro call on the Op type. We have a small subset of op_types we actuall care about
1104+ # like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1105+ tracked_types = (
1106+ self .tracked_types .keys () | self .tracked_parametrized_types .keys ()
1107+ )
1108+
1109+ @lru_cache
1110+ def cached_composed_mro (op_type , tracked_types = tracked_types ):
1111+ return _compose_mro (op_type , tracked_types )
1112+
1113+ self ._cached_composed_mro = cached_composed_mro
10441114
1045- This based on `functools._find_impl`.
1046- """
1047- mro = _compose_mro (cls , self .tracked_types .keys ())
10481115 matches = []
1049- for t in mro :
1050- match = self .tracked_types .get (t , None )
1051- if match :
1052- matches .extend (match )
1116+ if self .tracked_types or self .tracked_parametrized_types :
1117+ # Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1118+ mro = self ._cached_composed_mro (type (op ))
1119+ for t in mro :
1120+ if (match := self .tracked_types .get (t , None )) is not None :
1121+ matches .extend (match )
1122+ if (
1123+ potential_matches := self .tracked_parametrized_types .get (t , None )
1124+ ) is not None :
1125+ # We still need to check if the Op parameters match the constraints
1126+ matches .extend (
1127+ [
1128+ item
1129+ for op_pattern , r_list in potential_matches .items ()
1130+ if op_pattern .match_parameters (op )
1131+ for item in r_list
1132+ ]
1133+ )
1134+ matches .extend (self .tracked_instances .get (op , []))
1135+ matches .extend (self .untracked_rewrites )
10531136 return matches
10541137
1055- @functools .lru_cache
1056- def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1057- """Get all the rewrites applicable to `op`."""
1058- return (
1059- self ._find_impl (type (op ))
1060- + self .tracked_instances .get (op , [])
1061- + self .untracked_rewrites
1062- )
1063-
1064- def get_rewriters (self ):
1138+ def get_rewriters (self ) -> Iterable [NodeRewriter ]:
1139+ """Get all the registered rewriters."""
10651140 return chain (
1141+ chain .from_iterable (self .tracked_types .values ()),
1142+ chain .from_iterable (self .tracked_instances .values ()),
10661143 chain .from_iterable (
1067- chain (self .tracked_types .values (), self .tracked_instances .values ())
1144+ item
1145+ for sub_dict in self .tracked_parametrized_types .values ()
1146+ for item in sub_dict .values ()
10681147 ),
10691148 self .untracked_rewrites ,
10701149 )
@@ -1138,7 +1217,7 @@ def tracks(self):
11381217 t .extend (at )
11391218 return t
11401219
1141- def transform (self , fgraph , node ):
1220+ def transform (self , fgraph , node , enforce_tracks = False ):
11421221 if len (self .rewrites ) == 0 :
11431222 return
11441223
@@ -1150,7 +1229,8 @@ def transform(self, fgraph, node):
11501229 new_repl = None
11511230 for rewrite in rewrites :
11521231 rewrite_start = time .perf_counter ()
1153- new_repl = rewrite .transform (fgraph , node )
1232+ # Tracks are already enforced by `self.tracker.get_trackers`
1233+ new_repl = rewrite .transform (fgraph , node , enforce_tracks = False )
11541234 rewrite_finish = time .perf_counter ()
11551235 if self .profile :
11561236 self .time_rewrites [rewrite ] += rewrite_start - rewrite_finish
@@ -1292,8 +1372,8 @@ def __init__(self, op1, op2, transfer_tags=True):
12921372 def tracks (self ):
12931373 return [self .op1 ]
12941374
1295- def transform (self , fgraph , node ):
1296- if node .op != self .op1 :
1375+ def transform (self , fgraph , node , enforce_tracks = True ):
1376+ if enforce_tracks and ( node .op != self .op1 ) :
12971377 return False
12981378 repl = self .op2 .make_node (* node .inputs )
12991379 if self .transfer_tags :
@@ -1492,7 +1572,7 @@ def __init__(
14921572 def tracks (self ):
14931573 return self ._tracks
14941574
1495- def transform (self , fgraph , node , get_nodes = True ):
1575+ def transform (self , fgraph , node , enforce_tracks : bool = False , get_nodes = True ):
14961576 """Check if the graph from node corresponds to ``in_pattern``.
14971577
14981578 If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1782,6 +1862,7 @@ def process_node(
17821862 fgraph : FunctionGraph ,
17831863 node : Apply ,
17841864 node_rewriter : NodeRewriter | None = None ,
1865+ enforce_tracks : bool = True ,
17851866 ):
17861867 r"""Apply `node_rewriter` to `node`.
17871868
@@ -1799,6 +1880,9 @@ def process_node(
17991880 node_rewriter
18001881 A `NodeRewriter` instance that may have a better idea for
18011882 how to compute node's outputs.
1883+ enforce_tracks: bool
1884+ Whether the transform method should enforce tracks,
1885+ or it can be assumed the caller already enforced them in a pre-filter stage.
18021886
18031887 Returns
18041888 -------
@@ -1814,7 +1898,9 @@ def process_node(
18141898 # TODO FIXME: This class's interface is broken
18151899 assert node_rewriter is not None
18161900 try :
1817- replacements = node_rewriter .transform (fgraph , node )
1901+ replacements = node_rewriter .transform (
1902+ fgraph , node , enforce_tracks = enforce_tracks
1903+ )
18181904 except Exception as e :
18191905 if self .failure_callback is not None :
18201906 self .failure_callback (
@@ -1932,7 +2018,8 @@ def importer(node):
19322018 if node not in fgraph .apply_nodes :
19332019 continue
19342020 current_node = node
1935- nb += self .process_node (fgraph , node )
2021+ # This rewriter does not enforce tracks itself
2022+ nb += self .process_node (fgraph , node , enforce_tracks = True )
19362023 loop_t = time .perf_counter () - t0
19372024 finally :
19382025 self .detach_updater (fgraph , u )
@@ -2273,8 +2360,9 @@ def chin_(node, i, r, new_r, reason):
22732360 for node_rewriter in self .node_tracker .get_trackers (node .op ):
22742361 nb = change_tracker .nb_imported
22752362 t_rewrite = time .perf_counter ()
2363+ # Tracks are already enfoced by `self.node_tracker.get_trackers`
22762364 node_rewriter_change = self .process_node (
2277- fgraph , node , node_rewriter
2365+ fgraph , node , node_rewriter , enforce_tracks = False
22782366 )
22792367 time_rewriters [node_rewriter ] += time .perf_counter () - t_rewrite
22802368 if not node_rewriter_change :
0 commit comments