Commit 40e6dad

Implement Numba VM with caching

And make that the default backend

1 parent fb60edb · commit 40e6dad

File tree: 9 files changed, +240 −51 lines changed
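
Taken together, the changes below make the VM-based Numba backend the default and let Numba's on-disk cache persist compiled kernels across processes. A minimal sketch of the intended effect (the graph here is illustrative, not part of the commit):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    fn = pytensor.function([x], pt.exp(x) + 1)  # numba_vm is now the default linker

    # Re-running the same script should hit Numba's on-disk cache: the
    # generated Elemwise wrappers are written under <base_compiledir>/numba
    # with deterministic keys instead of throwaway temp files.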

pytensor/compile/mode.py

Lines changed: 9 additions & 3 deletions

@@ -50,7 +50,7 @@
     "jax": JAXLinker(),
     "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
-    "numba_vm": NumbaLinker(),
+    "numba_vm": NumbaLinker(vm=True),
 }
 
 
@@ -453,15 +453,15 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
 FAST_COMPILE = Mode(
-    NumbaLinker(),
+    NumbaLinker(vm=True),
     # TODO: Fast_compile should just use python code, CHANGE ME!
     RewriteDatabaseQuery(
         include=["fast_compile", "numba"],
         exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
     ),
 )
 FAST_RUN = Mode(
-    NumbaLinker(),
+    NumbaLinker(vm=True),
     RewriteDatabaseQuery(
         include=["fast_run", "numba"],
         exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
@@ -481,6 +481,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 
+NUMBA_VM = Mode(
+    NumbaLinker(vm=True),
+    NUMBA._optimizer,
+)
+
 JAX = Mode(
     JAXLinker(),
     RewriteDatabaseQuery(
@@ -519,6 +524,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
+    "NUMBA_VM": NUMBA_VM,
     "PYTORCH": PYTORCH,
 }
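
Since "NUMBA_VM" is now registered in the predefined modes mapping above, it can be requested by name when compiling; a small sketch (the graph is illustrative):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = (x**2).sum()

    # NUMBA_VM pairs NumbaLinker(vm=True) with the NUMBA mode's optimizer.
    fn = pytensor.function([x], y, mode="NUMBA_VM")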

pytensor/configdefaults.py

Lines changed: 3 additions & 2 deletions

@@ -380,11 +380,12 @@ def add_compile_configvars():
             "vm_nogc",
             "cvm_nogc",
             "jax",
+            "numba",
         ]
     else:
         # g++ is not present or the user disabled it,
         # linker should default to python only.
-        linker_options = ["py", "vm", "vm_nogc", "jax"]
+        linker_options = ["py", "vm", "vm_nogc", "jax", "numba"]
     if type(config).cxx.is_default:
         # If the user provided an empty value for cxx, do not warn.
         _logger.warning(
@@ -398,7 +399,7 @@ def add_compile_configvars():
         "linker",
         "Default linker used if the pytensor flags mode is Mode",
         # Not mutable because the default mode is cached after the first use.
-        EnumStr("numba", linker_options, mutable=False),
+        EnumStr("numba_vm", linker_options, mutable=False),
         in_c_key=False,
     )
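
The new default only applies when no explicit choice is made; because the config entry is declared with mutable=False, an override has to be in place before the first compilation, e.g. via PyTensor flags (a sketch):

    # Select the non-VM Numba linker for a single run (shell):
    #   PYTENSOR_FLAGS="linker=numba" python my_script.py
    # The same flag accepts the other entries of linker_options,
    # e.g. linker=vm or linker=jax.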

pytensor/link/numba/cache.py (new file)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+from collections.abc import Callable
+from pathlib import Path
+from tempfile import NamedTemporaryFile, TemporaryFile
+from typing import Any
+
+from numba.core.caching import CacheImpl, _CacheLocator
+
+from pytensor import config
+
+
+NUMBA_PYTENSOR_CACHE_ENABLED = True
+COMPILED_SRC_FUNCTIONS = {}
+
+
+def compile_and_cache_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    key: str | None = None,
+) -> Callable:
+    if key is not None:
+        numba_path = config.base_compiledir / "numba"
+        numba_path.mkdir(exist_ok=True)
+        filename = numba_path / key
+        with filename.open("wb") as f:
+            f.write(src.encode())
+    else:
+        with NamedTemporaryFile(delete=False) as f:
+            filename = f.name
+            f.write(src.encode())
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, filename, mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if key is not None:
+        COMPILED_SRC_FUNCTIONS[res] = key
+    return res
+
+
+class NumbaPyTensorCacheLocator(_CacheLocator):
+    def __init__(self, py_func, py_file, hash):
+        self._py_func = py_func
+        self._py_file = py_file
+        self._hash = hash
+
+    def ensure_cache_path(self):
+        path = self.get_cache_path()
+        path.mkdir(exist_ok=True)
+        # Ensure the directory is writable by trying to write a temporary file
+        TemporaryFile(dir=path).close()
+
+    def get_cache_path(self):
+        """
+        Return the directory the function is cached in.
+        """
+        return self._py_file
+
+    def get_source_stamp(self):
+        """
+        Get a timestamp representing the source code's freshness.
+        Can return any picklable Python object.
+        """
+        return 0
+
+    def get_disambiguator(self):
+        """
+        Get a string disambiguator for this locator's function.
+        It should allow disambiguating different but similarly-named functions.
+        """
+        return self._hash
+
+    @classmethod
+    def from_function(cls, py_func, py_file):
+        """
+        Create a locator instance for the given function located in the given file.
+        """
+        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in COMPILED_SRC_FUNCTIONS:
+            return cls(py_func, Path(py_file).parent, COMPILED_SRC_FUNCTIONS[py_func])
+
+
+CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
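
A toy illustration of what the new helper gives callers; the source string and key here are made up:

    from pytensor.link.numba.cache import compile_and_cache_numba_function_src

    src = "def add_one(x):\n    return x + 1\n"

    # With a key, the source is written to <base_compiledir>/numba/<key> and
    # the function is recorded in COMPILED_SRC_FUNCTIONS, so the locator
    # above claims it when Numba decides where to cache the compiled result.
    add_one = compile_and_cache_numba_function_src(src, "add_one", key="add_one_v1")
    assert add_one(41) == 42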

pytensor/link/numba/dispatch/basic.py

Lines changed: 10 additions & 8 deletions

@@ -14,7 +14,7 @@
 from numba import types
 from numba.core.errors import NumbaWarning, TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
-from numba.extending import box, overload
+from numba.extending import box, overload, register_jitable as _register_jitable
 
 from pytensor import In, config
 from pytensor.compile import NUMBA
@@ -50,10 +50,11 @@ def global_numba_func(func):
     return func
 
 
-def numba_njit(*args, fastmath=None, **kwargs):
-    kwargs.setdefault("cache", config.numba__cache)
-    kwargs.setdefault("no_cpython_wrapper", True)
-    kwargs.setdefault("no_cfunc_wrapper", True)
+def numba_njit(*args, fastmath=None, register_jitable: bool = False, **kwargs):
+    kwargs.setdefault("cache", True)
+    kwargs.setdefault("no_cpython_wrapper", False)
+    kwargs.setdefault("no_cfunc_wrapper", False)
     if fastmath is None:
         if config.numba__fastmath:
             # Opinionated default on fastmath flags
@@ -81,10 +82,11 @@ def numba_njit(*args, fastmath=None, **kwargs):
             category=NumbaWarning,
         )
 
+    func = _register_jitable if register_jitable else numba.njit
     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
-
-    return numba.njit(*args, fastmath=fastmath, **kwargs)
+        return func(*args[1:], fastmath=fastmath, **kwargs)(args[0])
+    else:
+        return func(*args, fastmath=fastmath, **kwargs)
 
 
 def numba_vectorize(*args, **kwargs):
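
A sketch of the new register_jitable switch; the functions are illustrative. With register_jitable=True the definition goes through numba.extending.register_jitable (compiled lazily when called from jitted code) instead of being compiled eagerly with numba.njit:

    from pytensor.link.numba.dispatch.basic import numba_njit

    @numba_njit(register_jitable=True)
    def half(x):
        return x / 2

    @numba_njit
    def quarter(x):
        # `half` is inlined into this jitted function when it compiles
        return half(half(x))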

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 44 additions & 9 deletions

@@ -1,4 +1,5 @@
 from functools import singledispatch
+from hashlib import sha256
 from textwrap import dedent, indent
 
 import numba
@@ -7,18 +8,17 @@
 from numpy.lib.stride_tricks import as_strided
 
 from pytensor.graph.op import Op
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     numba_funcify,
     numba_njit,
 )
 from pytensor.link.numba.dispatch.vectorize_codegen import (
-    _jit_options,
     _vectorized,
     encode_literals,
     store_core_outputs,
 )
-from pytensor.link.utils import compile_function_src
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar.basic import (
     AND,
@@ -237,7 +237,7 @@ def {careduce_fn_name}(x):
     careduce_def_src += "\n\n"
     careduce_def_src += indent(f"return {return_obj}", " " * 4)
 
-    careduce_fn = compile_function_src(
+    careduce_fn = compile_and_cache_numba_function_src(
         careduce_def_src, careduce_fn_name, {**globals(), **global_env}
     )
 
@@ -264,19 +264,34 @@ def axis_apply_fn(x):
 
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
+    nin = len(node.inputs)
+    nout = len(node.outputs)
+
     scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
     scalar_node = op.scalar_op.make_node(*scalar_inputs)
-
     scalar_op_fn = numba_funcify(
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
         **kwargs,
     )
 
-    nin = len(node.inputs)
-    nout = len(node.outputs)
-    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
+    # TODO: Proper key
+    core_op_key = "_".join(
+        map(
+            str,
+            (
+                op,
+                op.scalar_op,
+                tuple(op.inplace_pattern.items()),
+                tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()),
+            ),
+        )
+    )
+    core_op_key = sha256(core_op_key.encode()).hexdigest()
+    core_op_fn = store_core_outputs(
+        scalar_op_fn, nin=nin, nout=nout, core_op_key=core_op_key
+    )
 
     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
@@ -333,11 +348,31 @@ def elemwise(*inputs):
             return tuple(outputs_summed)
         return outputs_summed[0]
 
-    @overload(elemwise, jit_options=_jit_options)
+    @overload(elemwise)
     def ov_elemwise(*inputs):
         return elemwise_wrapper
 
-    return elemwise
+    # TODO: Also input dtypes in key
+    elemwise_key = "_".join(
+        map(
+            str,
+            (
+                "Elemwise",
+                core_op_key,
+                input_bc_patterns,
+                inplace_pattern,
+            ),
+        )
+    )
+    elemwise_key = sha256(elemwise_key.encode()).hexdigest()
+    f = compile_and_cache_numba_function_src(
+        "def f(*inputs): return elemwise(*inputs)",
+        "f",
+        {**globals(), **{"elemwise": elemwise}},
+        key=elemwise_key,
+    )
+
+    return numba_njit(f)
 
 
 @numba_funcify.register(Sum)
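
The keys are SHA-256 digests of a stable string built from the op's identity, so equal graphs map to equal cache files across runs; a toy version of the same construction:

    from hashlib import sha256

    # Illustrative stand-ins for core_op_key, the broadcast patterns, and
    # inplace_pattern as used above.
    parts = ("Elemwise", "3f8a...", ((False,), (False,)), ())
    key = sha256("_".join(map(str, parts)).encode()).hexdigest()
    # Deterministic: the same parts always produce the same key.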

pytensor/link/numba/dispatch/scalar.py

Lines changed: 11 additions & 7 deletions

@@ -4,6 +4,7 @@
 
 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph.basic import Variable
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_numba_signature,
@@ -12,7 +13,6 @@
 )
 from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
-    compile_function_src,
     get_name_for_object,
     unique_name_generator,
 )
@@ -128,16 +128,20 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
         return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
     """
 
-    scalar_op_fn = compile_function_src(
-        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
+    scalar_op_fn = compile_and_cache_numba_function_src(
+        scalar_op_src,
+        scalar_op_fn_name,
+        {**globals(), **global_env},
     )
 
-    signature = create_numba_signature(node, force_scalar=True)
+    # signature = create_numba_signature(node, force_scalar=True)
 
     return numba_basic.numba_njit(
-        signature,
+        # signature,
         # Functions that call a function pointer can't be cached
-        cache=False,
+        no_cfunc_wrapper=True,
+        no_cpython_wrapper=True,
+        register_jitable=False,
     )(scalar_op_fn)
 
 
@@ -164,7 +168,7 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
 def {binary_op_name}({input_signature}):
     return {output_expr}
 """
-    nary_fn = compile_function_src(nary_src, binary_op_name, globals())
+    nary_fn = compile_and_cache_numba_function_src(nary_src, binary_op_name, globals())
 
     return nary_fn
