Skip to content

Commit d604423

Browse files
committed
Manual control of numba caching
1 parent 7446599 commit d604423

File tree

17 files changed

+450
-63
lines changed

17 files changed

+450
-63
lines changed

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
228228
# E No match.
229229
# ...(on this line)...
230230
# E shuffle_shape = res.shape[: len(shuffle)]
231-
@numba_basic.numba_njit(inline="always")
231+
@numba_basic.numba_njit
232232
def dimshuffle(x):
233233
return dimshuffle_inner(np.asarray(x), shuffle)
234234

pytensor/link/numba/cache.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import weakref
2+
from collections.abc import Callable
3+
from pathlib import Path
4+
from tempfile import NamedTemporaryFile
5+
from typing import Any
6+
7+
from numba.core.caching import CacheImpl, _CacheLocator
8+
9+
from pytensor import config
10+
11+
12+
NUMBA_PYTENSOR_CACHE_ENABLED = True
13+
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
14+
NUMBA_CACHE_PATH.mkdir(exist_ok=True)
15+
CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()
16+
17+
18+
class NumbaPyTensorCacheLocator(_CacheLocator):
19+
def __init__(self, py_func, py_file, hash):
20+
self._py_func = py_func
21+
self._py_file = py_file
22+
self._hash = hash
23+
24+
def ensure_cache_path(self):
25+
pass
26+
27+
def get_cache_path(self):
28+
"""
29+
Return the directory the function is cached in.
30+
"""
31+
return NUMBA_CACHE_PATH
32+
33+
def get_source_stamp(self):
34+
"""
35+
Get a timestamp representing the source code's freshness.
36+
Can return any picklable Python object.
37+
"""
38+
return 0
39+
40+
def get_disambiguator(self):
41+
"""
42+
Get a string disambiguator for this locator's function.
43+
It should allow disambiguating different but similarly-named functions.
44+
"""
45+
return self._hash
46+
47+
@classmethod
48+
def from_function(cls, py_func, py_file):
49+
"""
50+
Create a locator instance for the given function located in the given file.
51+
"""
52+
if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
53+
return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])
54+
55+
56+
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
57+
58+
59+
def compile_numba_function_src(
60+
src: str,
61+
function_name: str,
62+
global_env: dict[Any, Any] | None = None,
63+
local_env: dict[Any, Any] | None = None,
64+
store_to_disk: bool = False,
65+
cache_key: str | None = None,
66+
) -> Callable:
67+
# TODO: Docstrings
68+
if store_to_disk:
69+
with NamedTemporaryFile(delete=False) as f:
70+
filename = f.name
71+
f.write(src.encode())
72+
else:
73+
filename = "<string>"
74+
75+
if global_env is None:
76+
global_env = {}
77+
78+
if local_env is None:
79+
local_env = {}
80+
81+
mod_code = compile(src, filename, mode="exec")
82+
exec(mod_code, global_env, local_env)
83+
84+
res = local_env[function_name]
85+
res.__source__ = src # type: ignore
86+
87+
if cache_key is not None:
88+
CACHED_SRC_FUNCTIONS[res] = cache_key
89+
return res

0 commit comments

Comments
 (0)