1+ import warnings
12import weakref
3+ from collections .abc import Callable
4+ from functools import singledispatch , wraps
25from hashlib import sha256
36from pathlib import Path
7+ from pickle import dumps
8+ from tempfile import NamedTemporaryFile
9+ from typing import Any
410
511from numba .core .caching import CacheImpl , _CacheLocator
612
713from pytensor import config
814from pytensor .graph .basic import Apply
15+ from pytensor .link .numba .compile import numba_funcify , numba_njit
916
1017
1118NUMBA_PYTENSOR_CACHE_ENABLED = True
@@ -19,8 +26,6 @@ def __init__(self, py_func, py_file, hash):
1926 self ._py_func = py_func
2027 self ._py_file = py_file
2128 self ._hash = hash
22- # src_hash = hash(pytensor_loader._module_sources[self._py_file])
23- # self._hash = hash((src_hash, py_file, pytensor.__version__))
2429
2530 def ensure_cache_path (self ):
2631 pass
@@ -74,3 +79,165 @@ def cache_node_key(node: Apply, extra_key="") -> str:
7479 ),
7580 ).encode ()
7681 ).hexdigest ()
82+
83+
84+ @singledispatch
85+ def numba_funcify_default_op_cache_key (
86+ op , node = None , ** kwargs
87+ ) -> Callable | tuple [Callable , Any ]:
88+ """Funcify an Op and implement a default cache key.
89+
90+ The default cache key is based on the op class and its properties.
91+ It does not take into account the node inputs or other context.
92+ Note that numba will use the array dtypes, rank and layout as part of the cache key,
93+ but not the static shape or constant values.
94+ If the funcify implementation exploits this information, then this method should not be used.
95+ Instead dispatch directly on `numba_funcify_and_cache_key` (or just numba_funcify)
96+ which won't use any cache key.
97+ """
98+ # Default cache key of None which means "don't try to do directly cache this function"
99+ raise NotImplementedError ()
100+
101+
def register_funcify_default_op_cache_key(op_type):
    """Register a funcify implementation for both cache and non-cache versions.

    The decorated function is registered as-is on
    `numba_funcify_default_op_cache_key`, and a thin wrapper around it is
    registered on `numba_funcify`.

    The decorated function may return either a bare callable or a
    ``(callable, salt)`` tuple; the `numba_funcify` wrapper returns only the
    callable in both cases.

    Parameters
    ----------
    op_type
        The type to register the implementation for on both dispatchers.

    Returns
    -------
    A decorator that performs the registrations and returns the decorated
    function unchanged.
    """

    def decorator(dispatch_func):
        # Register with the cache key dispatcher
        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)

        # Create a wrapper for the non-cache dispatcher
        @wraps(dispatch_func)
        def dispatch_func_wrapper(*args, **kwargs):
            result = dispatch_func(*args, **kwargs)
            # The salt is optional: only unpack when a tuple was returned,
            # otherwise the result already is the function itself.
            if isinstance(result, tuple):
                func, _salt = result
                return func
            return result

        # Register the wrapper with the non-cache dispatcher
        numba_funcify.register(op_type)(dispatch_func_wrapper)

        return dispatch_func

    return decorator
122+
123+
@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
    """Funcify an Op and return it together with a cache key.

    For an Op that has a ``_props`` attribute and a registered
    `numba_funcify_default_op_cache_key` implementation, the key is the
    SHA-256 hex digest of the Op type, its props and an optional salt
    returned by that implementation (``"0"`` when none is returned).

    In every other case the function falls back to `numba_funcify` and
    returns a key of ``None``, which means "don't try to directly cache
    this function".

    Returns
    -------
    tuple
        ``(func, key)`` where ``key`` is a hex digest string or ``None``.
    """
    if hasattr(op, "_props"):
        try:
            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
        except NotImplementedError:
            # No default-key implementation registered: use the fallback below
            pass
        else:
            if isinstance(func_and_salt, tuple):
                func, salt = func_and_salt
            else:
                func, salt = func_and_salt, "0"

            props_dict = op._props_dict()
            if not props_dict:
                # Simple op, just use the type string as key
                key_bytes = str((type(op), salt)).encode()
            else:
                simple_types = (str, bool, int, type(None), float)
                container_types = (tuple, frozenset)

                all_simple = True
                stable_items = []
                for name, value in props_dict.items():
                    if isinstance(value, simple_types):
                        pass
                    elif isinstance(value, container_types) and all(
                        isinstance(item, simple_types) for item in value
                    ):
                        if isinstance(value, frozenset):
                            # The iteration order of a frozenset of strings
                            # depends on per-process hash randomization, so
                            # its str() is not stable across interpreter runs.
                            # Use a deterministically ordered, tagged tuple.
                            value = (
                                "frozenset",
                                tuple(
                                    sorted(
                                        value,
                                        key=lambda item: (
                                            type(item).__name__,
                                            repr(item),
                                        ),
                                    )
                                ),
                            )
                    else:
                        all_simple = False
                    stable_items.append((name, value))

                if all_simple:
                    # Simple props, can use string representation of props as key
                    key_bytes = str((type(op), tuple(stable_items), salt)).encode()
                else:
                    # Complex props, use pickle to serialize them
                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
            return func, sha256(key_bytes).hexdigest()

    # Fallback
    return numba_funcify(op, node=node, **kwargs), None
163+
164+
def register_funcify_and_cache_key(op_type):
    """Register one implementation on both the cache-key and plain dispatchers.

    The decorated function must return a ``(func, key)`` pair. It is
    registered unchanged on `numba_funcify_and_cache_key`; a wrapper that
    drops the key and returns only ``func`` is registered on `numba_funcify`.

    Returns
    -------
    A decorator. Applying it yields the key-dropping wrapper, not the
    decorated function itself.
    """

    def register(dispatch_func):
        # Cache-aware dispatcher gets the implementation as written
        numba_funcify_and_cache_key.register(op_type)(dispatch_func)

        @wraps(dispatch_func)
        def funcify_only(*args, **kwargs):
            # The plain dispatcher has no use for the key
            func, _key = dispatch_func(*args, **kwargs)
            return func

        # Plain dispatcher gets the key-dropping wrapper
        numba_funcify.register(op_type)(funcify_only)

        # NOTE(review): register_funcify_default_op_cache_key returns the
        # undecorated function here instead - confirm which is intended.
        return funcify_only

    return register
185+
186+
def numba_njit_and_cache(op, *args, **kwargs):
    """Funcify `op`, jit the result and return it together with its cache key.

    The function and key come from `numba_funcify_and_cache_key`. When the
    key is ``None`` the function is jitted with ``cache=False``; otherwise it
    is wrapped in a function compiled from source and jitted with
    ``cache=True``.

    Returns
    -------
    tuple
        ``(jitted_function, key)`` where ``key`` may be ``None``.
    """
    jitable_func, key = numba_funcify_and_cache_key(op, *args, **kwargs)

    if key is None:
        # Only warn when the user asked for caching and for verbose output
        if config.numba__cache and config.compiler_verbose:
            warnings.warn(
                f"Custom numba cache disabled for {op} of type {type(op)}. "
                f"Even if the function is cached by numba, larger graphs using this function cannot be cached.\n"
                "To enable custom caching, register a numba_funcify_and_cache_key implementation for this Op, with a proper cache key."
            )

        uncached = numba_njit(
            lambda *args: jitable_func(*args), final_function=True, cache=False
        )
        return uncached, None

    # To force numba to use our cache, the function is compiled from source
    # so that what would otherwise be a closure becomes a global variable.
    op_name = op.__class__.__name__
    wrapper_src = f"def {op_name}(*args): return jitable_func(*args)"
    wrapper_globals = globals() | {"jitable_func": jitable_func}
    cached_func = compile_numba_function_src(
        src=wrapper_src,
        function_name=op_name,
        global_env=wrapper_globals,
        cache_key=key,
    )
    return numba_njit(cached_func, final_function=True, cache=True), key
212+
213+
214+ def compile_numba_function_src (
215+ src : str ,
216+ function_name : str ,
217+ global_env : dict [Any , Any ] | None = None ,
218+ local_env : dict [Any , Any ] | None = None ,
219+ store_to_disk : bool = False ,
220+ cache_key : str | None = None ,
221+ ) -> Callable :
222+ if store_to_disk :
223+ with NamedTemporaryFile (delete = False ) as f :
224+ filename = f .name
225+ f .write (src .encode ())
226+ else :
227+ filename = "<string>"
228+
229+ if global_env is None :
230+ global_env = {}
231+
232+ if local_env is None :
233+ local_env = {}
234+
235+ mod_code = compile (src , filename , mode = "exec" )
236+ exec (mod_code , global_env , local_env )
237+
238+ res = local_env [function_name ]
239+ res .__source__ = src # type: ignore
240+
241+ if cache_key is not None :
242+ CACHED_SRC_FUNCTIONS [res ] = cache_key
243+ return res
0 commit comments