
Commit 1329890

More hacking around
1 parent 9c1ee06 commit 1329890

File tree

23 files changed: +558 -573 lines changed


pytensor/link/numba/cache.py

Lines changed: 20 additions & 47 deletions
@@ -1,64 +1,21 @@
-from collections.abc import Callable
+import weakref
+from hashlib import sha256
 from pathlib import Path
-from tempfile import NamedTemporaryFile
-from typing import Any

 from numba.core.caching import CacheImpl, _CacheLocator

 from pytensor import config
+from pytensor.graph.basic import Apply


 NUMBA_PYTENSOR_CACHE_ENABLED = True
 NUMBA_CACHE_PATH = config.base_compiledir / "numba"
 NUMBA_CACHE_PATH.mkdir(exist_ok=True)
-CACHED_SRC_FUNCTIONS = {}
-
-
-def compile_and_cache_numba_function_src(
-    src: str,
-    function_name: str,
-    global_env: dict[Any, Any] | None = None,
-    local_env: dict[Any, Any] | None = None,
-    key: str | None = None,
-) -> Callable:
-    if key is not None:
-        filename = NUMBA_CACHE_PATH / key
-        with filename.open("wb") as f:
-            f.write(src.encode())
-    else:
-        with NamedTemporaryFile(delete=False) as f:
-            filename = f.name
-            f.write(src.encode())
-
-    if global_env is None:
-        global_env = {}
-
-    if local_env is None:
-        local_env = {}
-
-    mod_code = compile(src, filename, mode="exec")
-    exec(mod_code, global_env, local_env)
-
-    res = local_env[function_name]
-    res.__source__ = src  # type: ignore
-
-    if key is not None:
-        CACHED_SRC_FUNCTIONS[res] = key
-    return res
-
-
-def cache_numba_function(
-    fn,
-    key: str | None = None,
-) -> Callable:
-    if key is not None:
-        CACHED_SRC_FUNCTIONS[fn] = key
-    return fn
+CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()


 class NumbaPyTensorCacheLocator(_CacheLocator):
     def __init__(self, py_func, py_file, hash):
-        # print(f"New locator {py_func=}, {py_file=}, {hash=}")
         self._py_func = py_func
         self._py_file = py_file
         self._hash = hash
@@ -101,3 +58,19 @@ def from_function(cls, py_func, py_file):


 CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
+
+
+def cache_node_key(node: Apply, extra_key="") -> str:
+    op = node.op
+    return sha256(
+        str(
+            (
+                # Op signature
+                (type(op), op._props_dict() if hasattr(op, "_props_dict") else ""),
+                # Node signature
+                tuple((type(inp_type := inp.type), inp_type) for inp in node.inputs),
+                # Extra key given by the caller
+                extra_key,
+            ),
+        ).encode()
+    ).hexdigest()
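
Two things change above: the module-level registry becomes a weakref.WeakKeyDictionary, so a compiled function's entry is dropped automatically once the function itself is garbage collected, and the new cache_node_key helper derives a deterministic sha256 key from an Apply node's Op (its class plus _props_dict, when available) and its input types. A minimal usage sketch, assuming this branch is installed; the graph below is an illustrative example, not part of the commit:

    # Hypothetical usage of cache_node_key: the key should be stable across
    # processes for the same Op class, Op props, and input types.
    import pytensor.tensor as pt
    from pytensor.link.numba.cache import cache_node_key

    x = pt.vector("x")
    out = pt.exp(x)  # any Apply node works; out.owner is the node
    key = cache_node_key(out.owner, extra_key="numba-v1")
    print(len(key))  # 64 hex characters from sha256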

pytensor/link/numba/compile.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+import warnings
+from collections.abc import Callable
+from typing import Any
+
+import numba
+import numpy as np
+from numba import NumbaWarning
+from numba import njit as _njit
+from numba.core.extending import register_jitable
+
+from pytensor import config
+from pytensor.graph import Apply, FunctionGraph, Type
+from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS
+from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
+from pytensor.scalar import ScalarType
+from pytensor.sparse import SparseTensorType
+from pytensor.tensor import TensorType
+
+
+def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
+    if fastmath is None:
+        if config.numba__fastmath:
+            # Opinionated default on fastmath flags
+            # https://llvm.org/docs/LangRef.html#fast-math-flags
+            fastmath = {
+                "arcp",  # Allow Reciprocal
+                "contract",  # Allow floating-point contraction
+                "afn",  # Approximate functions
+                "reassoc",
+                "nsz",  # No signed zeros
+            }
+        else:
+            fastmath = False
+
+    if final_function:
+        kwargs.setdefault("cache", True)
+    else:
+        kwargs.setdefault("no_cpython_wrapper", True)
+        kwargs.setdefault("no_cfunc_wrapper", True)
+
+    # Suppress cache warning for internal functions
+    # We have to add an ansi escape code for optional bold text by numba
+    warnings.filterwarnings(
+        "ignore",
+        message=(
+            "(\x1b\\[1m)*"  # ansi escape code for bold text
+            "Cannot cache compiled function "
+            '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
+            "as it uses dynamic globals"
+        ),
+        category=NumbaWarning,
+    )
+
+    func = _njit if final_function else register_jitable
+    if len(args) > 0 and callable(args[0]):
+        return func(*args[1:], fastmath=fastmath, **kwargs)(args[0])
+    else:
+        return func(*args, fastmath=fastmath, **kwargs)
+
+
+def compile_and_cache_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    key: str | None = None,
+) -> Callable:
+    # if key is not None:
+    #     filename = NUMBA_CACHE_PATH / key
+    #     with filename.open("wb") as f:
+    #         f.write(src.encode())
+    # else:
+    #     with NamedTemporaryFile(delete=False) as f:
+    #         filename = f.name
+    #         f.write(src.encode())
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, "<string>", mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if key is not None:
+        CACHED_SRC_FUNCTIONS[res] = key
+    return res
+
+
+def get_numba_type(
+    pytensor_type: Type,
+    layout: str = "A",
+    force_scalar: bool = False,
+    reduce_to_scalar: bool = False,
+) -> numba.types.Type:
+    r"""Create a Numba type object for a :class:`Type`.
+
+    Parameters
+    ----------
+    pytensor_type
+        The :class:`Type` to convert.
+    layout
+        The :class:`numpy.ndarray` layout to use.
+    force_scalar
+        Ignore dimension information and return the corresponding Numba scalar types.
+    reduce_to_scalar
+        Return Numba scalars for zero dimensional :class:`TensorType`\s.
+    """
+
+    if isinstance(pytensor_type, TensorType):
+        dtype = pytensor_type.numpy_dtype
+        numba_dtype = numba.from_dtype(dtype)
+        if force_scalar or (
+            reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
+        ):
+            return numba_dtype
+        return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
+    elif isinstance(pytensor_type, ScalarType):
+        dtype = np.dtype(pytensor_type.dtype)
+        numba_dtype = numba.from_dtype(dtype)
+        return numba_dtype
+    elif isinstance(pytensor_type, SparseTensorType):
+        dtype = pytensor_type.numpy_dtype
+        numba_dtype = numba.from_dtype(dtype)
+        if pytensor_type.format == "csr":
+            return CSRMatrixType(numba_dtype)
+        if pytensor_type.format == "csc":
+            return CSCMatrixType(numba_dtype)
+
+        raise NotImplementedError()
+    else:
+        raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
+
+
+def create_numba_signature(
+    node_or_fgraph: FunctionGraph | Apply,
+    force_scalar: bool = False,
+    reduce_to_scalar: bool = False,
+) -> numba.types.Type:
+    """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
+    input_types = [
+        get_numba_type(
+            inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
+        )
+        for inp in node_or_fgraph.inputs
+    ]
+
+    output_types = [
+        get_numba_type(
+            out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
+        )
+        for out in node_or_fgraph.outputs
+    ]
+
+    if len(output_types) > 1:
+        return numba.types.Tuple(output_types)(*input_types)
+    elif len(output_types) == 1:
+        return output_types[0](*input_types)
+    else:
+        return numba.types.void(*input_types)
+
+
+def create_tuple_creator(f, n):
+    """Construct a compile-time ``tuple``-comprehension-like loop.
+
+    See https://github.com/numba/numba/issues/2771#issuecomment-414358902
+    """
+    assert n > 0
+
+    f = numba_njit(f)
+
+    @numba_njit
+    def creator(args):
+        return (f(0, *args),)
+
+    for i in range(1, n):
+
+        @numba_njit
+        def creator(args, creator=creator, i=i):
+            return (*creator(args), f(i, *args))
+
+    return numba_njit(lambda *args: creator(args))
+
+
+def create_tuple_string(x):
+    args = ", ".join(x + ([""] if len(x) == 1 else []))
+    return f"({args})"
+
+
+def create_arg_string(x):
+    args = ", ".join(x)
+    return args
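
The numba_njit wrapper above does double duty: with final_function=True it applies a real numba.njit with cache=True by default, while the default path goes through numba's register_jitable (skipping the CPython and cfunc wrappers) so internal helpers stay cheap to compile and callable from other jitted code. A rough usage sketch, assuming this branch is installed; the functions are made-up examples:

    # Sketch: an internal helper vs. a cacheable entry point.
    import numpy as np
    from pytensor.link.numba.compile import numba_njit

    @numba_njit  # internal: register_jitable, no wrappers, not disk-cached
    def _twice(x):
        return 2.0 * x

    @numba_njit(final_function=True)  # entry point: real njit, cache=True
    def entry(x):
        return _twice(x) + 1.0

    print(entry(np.ones(3)))  # [3. 3. 3.]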

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 import pytensor.link.numba.dispatch.random
 import pytensor.link.numba.dispatch.scan
 import pytensor.link.numba.dispatch.scalar
+import pytensor.link.numba.dispatch.shape
 import pytensor.link.numba.dispatch.signal
 import pytensor.link.numba.dispatch.slinalg
 import pytensor.link.numba.dispatch.sparse
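
This one-line import matters because each dispatch submodule registers its Op lowerings with the numba_funcify singledispatch function as an import side effect; without it, Ops handled by the new shape module would never be registered. A rough sketch of that registration pattern follows, with an illustrative body that is not the actual contents of shape.py:

    # Hypothetical registration sketch; numba_funcify is assumed to live in
    # pytensor.link.numba.dispatch.basic, as in mainline pytensor.
    import numpy as np
    from pytensor.link.numba.compile import numba_njit
    from pytensor.link.numba.dispatch.basic import numba_funcify
    from pytensor.tensor.shape import Shape

    @numba_funcify.register(Shape)
    def numba_funcify_Shape(op, node=None, **kwargs):
        # Illustrative body only: return a jittable callable for the Op.
        @numba_njit
        def shape(x):
            return np.asarray(np.shape(x))

        return shape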
