|
1 | 1 | import operator |
2 | 2 | import sys |
3 | 3 | import warnings |
| 4 | +from collections.abc import Callable |
4 | 5 | from functools import singledispatch |
5 | 6 |
|
6 | 7 | import numba |
|
18 | 19 | from pytensor.compile.ops import DeepCopyOp |
19 | 20 | from pytensor.graph.fg import FunctionGraph |
20 | 21 | from pytensor.ifelse import IfElse |
| 22 | +from pytensor.link.numba.cache import ( |
| 23 | + cache_node_key, |
| 24 | +) |
21 | 25 | from pytensor.link.numba.compile import ( |
| 26 | + compile_and_cache_numba_function_src, |
22 | 27 | get_numba_type, |
23 | 28 | numba_njit, |
24 | 29 | ) |
@@ -208,20 +213,80 @@ def perform(*inputs): |
208 | 213 | ret = py_perform_return(inputs) |
209 | 214 | return ret |
210 | 215 |
|
211 | | - return perform |
| 216 | + # Assume we can't cache python functions |
| 217 | + return perform, None |
212 | 218 |
|
213 | 219 |
|
@singledispatch
def numba_funcify(
    op, node=None, storage_map=None, **kwargs
) -> Callable | tuple[Callable, str | int | None]:
    """Generate a numba function for a given op and apply node.

    The resulting function will usually use the `no_cpython_wrapper`
    argument in numba, so it can not be called directly from python,
    but only from other jit functions.

    Implementations may return either the jitable function alone, or a
    ``(function, key)`` pair. The key provides extra caching context;
    returning ``None`` as the key disables caching for this function.
    When only the function is returned, PyTensor will assume the function
    can be cached based on the op and node signature alone.
    """
    return generate_fallback_impl(op, node, storage_map, **kwargs)
223 | 236 |
|
224 | 237 |
|
| 238 | +def numba_funcify_njit(op, node, **kwargs): |
| 239 | + jitable_func_and_key = numba_funcify(op, node=node, **kwargs) |
| 240 | + |
| 241 | + match jitable_func_and_key: |
| 242 | + case Callable(): |
| 243 | + jitable_func = jitable_func_and_key |
| 244 | + key = cache_node_key(node) |
| 245 | + case (Callable(), str() | int()): |
| 246 | + jitable_func, funcify_key = jitable_func_and_key |
| 247 | + key = cache_node_key(node, funcify_key) |
| 248 | + case (Callable(), None): |
| 249 | + # We were explicitly told by the dispatch not to try and cache this function |
| 250 | + jitable_func, key = jitable_func_and_key |
| 251 | + case _: |
| 252 | + raise TypeError( |
| 253 | + f"numpy_funcify should return a callable or a (callable, key) pair, got {jitable_func_and_key}" |
| 254 | + ) |
| 255 | + |
| 256 | + if key is not None: |
| 257 | + # To force numba to use our cache, we must compile the function so that any closure |
| 258 | + # becomes a global variable... |
| 259 | + op_name = op.__class__.__name__ |
| 260 | + cached_func = compile_and_cache_numba_function_src( |
| 261 | + src=f"def {op_name}(*args): return jitable_func(*args)", |
| 262 | + function_name=op_name, |
| 263 | + global_env=globals() | dict(jitable_func=jitable_func), |
| 264 | + cache_key=key, |
| 265 | + ) |
| 266 | + return numba_njit(cached_func, final_function=True, cache=True) |
| 267 | + else: |
| 268 | + return numba_njit( |
| 269 | + lambda *args: jitable_func(*args), final_function=True, cache=False |
| 270 | + ) |
| 271 | + |
| 272 | + |
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="numba_funcified_fgraph",
    **kwargs,
):
    """Convert a whole `FunctionGraph` to a python function, jitting each node via `numba_funcify_njit`."""
    # TODO: Create hash key for whole graph
    conversion_options = dict(
        op_conversion_fn=numba_funcify_njit,
        type_conversion_fn=numba_typify,
        fgraph_name=fgraph_name,
    )
    return fgraph_to_python(fgraph, **conversion_options, **kwargs)
| 288 | + |
| 289 | + |
225 | 290 | @numba_funcify.register(OpFromGraph) |
226 | 291 | def numba_funcify_OpFromGraph(op, node=None, **kwargs): |
227 | 292 | _ = kwargs.pop("storage_map", None) |
@@ -251,23 +316,8 @@ def opfromgraph(*inputs): |
251 | 316 | def opfromgraph(*inputs): |
252 | 317 | return fgraph_fn(*inputs) |
253 | 318 |
|
254 | | - return opfromgraph |
255 | | - |
256 | | - |
257 | | -@numba_funcify.register(FunctionGraph) |
258 | | -def numba_funcify_FunctionGraph( |
259 | | - fgraph, |
260 | | - node=None, |
261 | | - fgraph_name="numba_funcified_fgraph", |
262 | | - **kwargs, |
263 | | -): |
264 | | - return fgraph_to_python( |
265 | | - fgraph, |
266 | | - numba_funcify, |
267 | | - type_conversion_fn=numba_typify, |
268 | | - fgraph_name=fgraph_name, |
269 | | - **kwargs, |
270 | | - ) |
| 319 | + # We can't cache this correctly until we can define a key for it |
| 320 | + return opfromgraph, None |
271 | 321 |
|
272 | 322 |
|
273 | 323 | @numba_funcify.register(DeepCopyOp) |
|
0 commit comments