
Commit f3851a7

Control caching of numba functions

1 parent: 2363646

File tree

16 files changed, +607 -243 lines


pytensor/link/numba/cache.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@ (entire file is new)

import warnings
import weakref
from collections.abc import Callable
from functools import singledispatch, wraps
from hashlib import sha256
from pathlib import Path
from pickle import dumps
from tempfile import NamedTemporaryFile
from typing import Any

from numba.core.caching import CacheImpl, _CacheLocator

from pytensor import config
from pytensor.link.numba.compile import numba_funcify, numba_njit


NUMBA_PYTENSOR_CACHE_ENABLED = True
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
NUMBA_CACHE_PATH.mkdir(exist_ok=True)
CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()


class NumbaPyTensorCacheLocator(_CacheLocator):
    def __init__(self, py_func, py_file, hash):
        self._py_func = py_func
        self._py_file = py_file
        self._hash = hash

    def ensure_cache_path(self):
        pass

    def get_cache_path(self):
        """Return the directory the function is cached in."""
        return NUMBA_CACHE_PATH

    def get_source_stamp(self):
        """Get a timestamp representing the source code's freshness.

        Can return any picklable Python object.
        """
        return 0

    def get_disambiguator(self):
        """Get a string disambiguator for this locator's function.

        It should allow disambiguating different but similarly-named functions.
        """
        return self._hash

    @classmethod
    def from_function(cls, py_func, py_file):
        """Create a locator instance for the given function located in the given file."""
        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
            return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])


# Give our locator precedence over numba's default ones
CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)


@singledispatch
def numba_funcify_default_op_cache_key(
    op, node=None, **kwargs
) -> Callable | tuple[Callable, Any]:
    """Funcify an Op and provide a default cache key.

    The default cache key is based on the Op class and its properties.
    It does not take into account the node inputs or other context.
    Note that numba will use the array dtypes, rank and layout as part of the
    cache key, but not the static shapes or constant values.
    If the funcify implementation exploits this information, this method should
    not be used. Instead, dispatch directly on `numba_funcify_and_cache_key`
    (or on plain `numba_funcify`, which won't use any cache key).
    """
    raise NotImplementedError()


def register_funcify_default_op_cache_key(op_type):
    """Register a funcify implementation for both the cache and non-cache dispatchers."""

    def decorator(dispatch_func):
        # Register with the cache-key dispatcher
        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)

        # Create a wrapper for the non-cache dispatcher
        @wraps(dispatch_func)
        def dispatch_func_wrapper(*args, **kwargs):
            func_and_salt = dispatch_func(*args, **kwargs)
            # Discard the salt (if any) for the non-cache version
            if isinstance(func_and_salt, tuple):
                return func_and_salt[0]
            return func_and_salt

        # Register the wrapper with the non-cache dispatcher
        numba_funcify.register(op_type)(dispatch_func_wrapper)

        return dispatch_func

    return decorator


@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
    # A cache key of None means "don't try to directly cache this function"
    if hasattr(op, "_props"):
        try:
            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
        except NotImplementedError:
            pass
        else:
            if isinstance(func_and_salt, tuple):
                func, salt = func_and_salt
            else:
                func, salt = func_and_salt, "0"
            props_dict = op._props_dict()
            if not props_dict:
                # Simple op, just use the type string as key
                key_bytes = str((type(op), salt)).encode()
            else:
                simple_types = (str, bool, int, type(None), float)
                container_types = (tuple, frozenset)
                if all(
                    isinstance(v, simple_types)
                    or (
                        isinstance(v, container_types)
                        and all(isinstance(i, simple_types) for i in v)
                    )
                    for v in props_dict.values()
                ):
                    # Simple props, can use their string representation as key
                    key_bytes = str(
                        (type(op), tuple(props_dict.items()), salt)
                    ).encode()
                else:
                    # Complex props, use pickle to serialize them
                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
            return func, sha256(key_bytes).hexdigest()

    # Fallback: funcify without a custom cache key
    return numba_funcify(op, node=node, **kwargs), None


def register_funcify_and_cache_key(op_type):
    """Register a funcify implementation for both the cache and non-cache dispatchers."""

    def decorator(dispatch_func):
        # Register with the cache-key dispatcher
        numba_funcify_and_cache_key.register(op_type)(dispatch_func)

        # Create a wrapper for the non-cache dispatcher
        @wraps(dispatch_func)
        def dispatch_func_wrapper(*args, **kwargs):
            func, _key = dispatch_func(*args, **kwargs)
            # Discard the key for the non-cache version
            return func

        # Register the wrapper with the non-cache dispatcher
        numba_funcify.register(op_type)(dispatch_func_wrapper)

        return dispatch_func_wrapper

    return decorator


def numba_njit_and_cache(op, *args, **kwargs):
    jitable_func, key = numba_funcify_and_cache_key(op, *args, **kwargs)

    if key is not None:
        # To force numba to use our cache, we compile the function from source
        # so that any closed-over variables become global variables.
        op_name = op.__class__.__name__
        cached_func = compile_numba_function_src(
            src=f"def {op_name}(*args): return jitable_func(*args)",
            function_name=op_name,
            global_env=globals() | {"jitable_func": jitable_func},
            cache_key=key,
        )
        return numba_njit(cached_func, final_function=True, cache=True), key
    else:
        if config.numba__cache and config.compiler_verbose:
            warnings.warn(
                f"Custom numba cache disabled for {op} of type {type(op)}. "
                "Even if the function is cached by numba, larger graphs using this function cannot be cached.\n"
                "To enable custom caching, register a `numba_funcify_and_cache_key` implementation for this Op, with a proper cache key."
            )

        return numba_njit(
            lambda *args: jitable_func(*args), final_function=True, cache=False
        ), None


def compile_numba_function_src(
    src: str,
    function_name: str,
    global_env: dict[Any, Any] | None = None,
    local_env: dict[Any, Any] | None = None,
    store_to_disk: bool = False,
    cache_key: str | None = None,
) -> Callable:
    if store_to_disk:
        with NamedTemporaryFile(delete=False) as f:
            filename = f.name
            f.write(src.encode())
    else:
        filename = "<string>"

    if global_env is None:
        global_env = {}

    if local_env is None:
        local_env = {}

    mod_code = compile(src, filename, mode="exec")
    exec(mod_code, global_env, local_env)

    res = local_env[function_name]
    res.__source__ = src  # type: ignore

    if cache_key is not None:
        CACHED_SRC_FUNCTIONS[res] = cache_key
    return res
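
For illustration, this is how an Op author might opt in to the new caching machinery. A minimal sketch, not part of the commit: the `ScalarClip` Op and its funcify body are hypothetical, and pytensor's `Op` base class is assumed; only `register_funcify_default_op_cache_key` and `numba_njit` come from the code above.

import numpy as np

from pytensor.graph.op import Op
from pytensor.link.numba.cache import register_funcify_default_op_cache_key
from pytensor.link.numba.compile import numba_njit


class ScalarClip(Op):
    """Hypothetical Op; make_node/perform are elided for brevity."""

    # The default cache key hashes the Op type together with these props
    __props__ = ("lower", "upper")

    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper


@register_funcify_default_op_cache_key(ScalarClip)
def numba_funcify_ScalarClip(op, node=None, **kwargs):
    lower, upper = op.lower, op.upper

    @numba_njit
    def clip(x):
        return np.minimum(np.maximum(x, lower), upper)

    # Return the jitable function plus a "salt" that versions this
    # implementation; bump it whenever the generated code changes
    return clip, "1"

Because `lower` and `upper` are simple floats, `numba_funcify_and_cache_key` derives the key from `str((type(op), tuple(props_dict.items()), salt))`; Ops with more complex props fall back to pickling them.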

pytensor/link/numba/compile.py

Lines changed: 17 additions & 1 deletion
@@ -1,4 +1,6 @@
 import warnings
+from collections.abc import Callable
+from functools import singledispatch
 
 import numba
 import numpy as np
@@ -8,7 +10,6 @@
 
 from pytensor import config
 from pytensor.graph import Apply, FunctionGraph, Type
-from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.scalar import ScalarType
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor import TensorType
@@ -55,6 +56,19 @@ def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
         return func(*args, fastmath=fastmath, **kwargs)
 
 
+@singledispatch
+def numba_funcify(
+    typ, node=None, storage_map=None, **kwargs
+) -> Callable | tuple[Callable, str | int | None]:
+    """Generate a numba function for a given Op and Apply node (or FunctionGraph).
+
+    The resulting function will usually use the `no_cpython_wrapper`
+    argument in numba, so it cannot be called directly from Python,
+    only from other jitted functions.
+    """
+    raise NotImplementedError(f"Numba funcify not implemented for type {typ}")
+
+
 def get_numba_type(
     pytensor_type: Type,
     layout: str = "A",
@@ -88,6 +102,8 @@ def get_numba_type(
         numba_dtype = numba.from_dtype(dtype)
         return numba_dtype
     elif isinstance(pytensor_type, SparseTensorType):
+        from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
+
         dtype = pytensor_type.numpy_dtype
         numba_dtype = numba.from_dtype(dtype)
         if pytensor_type.format == "csr":
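
Conversely, here is a minimal sketch (also not part of the commit; the `Negate` Op is hypothetical) of registering only with the non-caching `numba_funcify` dispatcher added above:

from pytensor.graph.op import Op
from pytensor.link.numba.compile import numba_funcify, numba_njit


class Negate(Op):
    """Hypothetical Op with no cache-key registration."""

    __props__ = ()


@numba_funcify.register(Negate)
def numba_funcify_Negate(op, node=None, storage_map=None, **kwargs):
    @numba_njit
    def negate(x):
        return -x

    return negate

Since no `numba_funcify_and_cache_key` (or default op cache key) implementation exists for `Negate`, `numba_njit_and_cache` falls back to `key=None`, compiles the final function with `cache=False`, and, when `config.numba__cache` and `config.compiler_verbose` are set, emits the warning defined in cache.py.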
