Skip to content

Commit d4f676d

Browse files
committed
Manual control of numba caching
1 parent 4f52dea commit d4f676d

File tree

18 files changed

+489
-78
lines changed

18 files changed

+489
-78
lines changed

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ Here's an example for :class:`DimShuffle`:
228228
# E No match.
229229
# ...(on this line)...
230230
# E shuffle_shape = res.shape[: len(shuffle)]
231-
@numba_basic.numba_njit(inline="always")
231+
@numba_basic.numba_njit
232232
def dimshuffle(x):
233233
return dimshuffle_inner(np.asarray(x), shuffle)
234234

pytensor/bin/pytensor_cache.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import shutil
34
import sys
45
from pathlib import Path
56

@@ -74,7 +75,10 @@ def main():
7475
'You can also call "pytensor-cache purge" to '
7576
"remove everything from that directory."
7677
)
77-
_logger.debug(f"Remaining elements ({len(items)}): {', '.join(items)}")
78+
_logger.debug(f"Remaining elements ({len(items)}): {items}")
79+
numba_cache_dir: Path = config.base_compiledir / "numba"
80+
shutil.rmtree(numba_cache_dir, ignore_errors=True)
81+
7882
elif sys.argv[1] == "list":
7983
pytensor.compile.compiledir.print_compiledir_content()
8084
elif sys.argv[1] == "cleanup":
@@ -86,6 +90,8 @@ def main():
8690
print("Lock successfully removed!")
8791
elif sys.argv[1] == "purge":
8892
pytensor.compile.compiledir.compiledir_purge()
93+
numba_cache_dir: Path = config.base_compiledir / "numba"
94+
shutil.rmtree(numba_cache_dir, ignore_errors=True)
8995
elif sys.argv[1] == "basecompiledir":
9096
# Simply print the base_compiledir
9197
print(pytensor.config.base_compiledir)

pytensor/link/numba/cache.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import weakref
2+
from collections.abc import Callable
3+
from pathlib import Path
4+
from tempfile import NamedTemporaryFile
5+
from typing import Any
6+
7+
from numba.core.caching import CacheImpl, _CacheLocator
8+
9+
from pytensor.configdefaults import config
10+
11+
12+
# Directory, under PyTensor's base compile directory, that holds the Numba cache.
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
# Create it once at import time. `parents=True` keeps the import from failing
# with FileNotFoundError when the base compile directory does not exist yet.
NUMBA_CACHE_PATH.mkdir(parents=True, exist_ok=True)
# Maps source-generated Python functions to their cache keys. Keys are held
# weakly, so an entry disappears once its function is garbage collected.
CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()
15+
16+
17+
class NumbaPyTensorCacheLocator(_CacheLocator):
    """Cache locator for Numba functions built from PyTensor-generated source.

    Functions are told apart by a hash supplied by PyTensor.

    When `compile_numba_function_src` is called with a `cache_key`, the
    resulting Python function is recorded in the `CACHED_SRC_FUNCTIONS` weakref
    dictionary. When numba later looks for a cache for such a function, this
    locator points it at the PyTensor Numba cache directory and hands back the
    recorded key as the disambiguator.

    The Python function object itself does not have to be reused by the
    dispatchers: a fresh function registered under the same key is directed to
    the same cache entry. The flip side is that a function whose code changed
    but whose key did not will still be served the old cache entry.
    """

    def __init__(self, py_func, py_file, hash):
        # `hash` is the cache key later returned by `get_disambiguator`.
        self._py_func = py_func
        self._py_file = py_file
        self._hash = hash

    def ensure_cache_path(self):
        """Do nothing: the cache directory is created when this module is loaded.

        Checking on every cache request would be too slow.
        """

    def get_cache_path(self):
        """Return the directory in which the function is cached."""
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """Return a stamp representing the freshness of the source code.

        Any picklable Python object is acceptable.

        The value is constant; changing it can be used to invalidate every
        cache written by previous PyTensor releases.
        """
        return 0

    def get_disambiguator(self):
        """Return the string that disambiguates this locator's function.

        It allows different but similarly-named functions to be told apart.
        """
        return self._hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """Build a locator for a function registered in CACHED_SRC_FUNCTIONS.

        Returns None when Numba caching is disabled in the PyTensor config or
        when `py_func` was not registered with a cache key.
        """
        if not config.numba__cache:
            return None
        if py_func not in CACHED_SRC_FUNCTIONS:
            return None
        return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])
67+
68+
69+
# Put our locator at the very front of Numba's list of locator classes.
CacheImpl._locator_classes[:0] = [NumbaPyTensorCacheLocator]
71+
72+
73+
def compile_numba_function_src(
    src: str,
    function_name: str,
    global_env: dict[Any, Any] | None = None,
    local_env: dict[Any, Any] | None = None,
    store_to_disk: bool = False,
    cache_key: str | None = None,
) -> Callable:
    """Build a Python function from source code, optionally registering it for caching.

    The source string is compiled and executed with `global_env` and
    `local_env` as its global and local namespaces (fresh dictionaries are used
    for whichever is omitted). The object bound to `function_name` in the local
    namespace is then returned, with the source attached as `__source__`.

    Parameters
    ----------
    src
        Source code that defines `function_name`.
    function_name
        Name to look up in the local namespace after executing `src`.
    global_env
        Globals used when executing `src`.
    local_env
        Locals used when executing `src`; it receives the definitions.
    store_to_disk
        If True, `src` is also written to a temporary file that is not deleted,
        and the file's path is used as the filename of the compiled code.
        Otherwise the filename is ``"<string>"``.
    cache_key
        If given, the returned function is recorded under this key in the
        `CACHED_SRC_FUNCTIONS` weak reference dictionary, where
        `NumbaPyTensorCacheLocator` looks it up for caching.

    Raises
    ------
    KeyError
        If executing `src` does not bind `function_name` in the local namespace.
    """
    source_file = "<string>"
    if store_to_disk:
        with NamedTemporaryFile(delete=False) as tmp:
            source_file = tmp.name
            tmp.write(src.encode())

    exec_globals = {} if global_env is None else global_env
    exec_locals = {} if local_env is None else local_env

    exec(compile(src, source_file, mode="exec"), exec_globals, exec_locals)

    compiled_func = exec_locals[function_name]
    compiled_func.__source__ = src  # type: ignore

    if cache_key is not None:
        CACHED_SRC_FUNCTIONS[compiled_func] = cache_key

    return compiled_func

0 commit comments

Comments
 (0)