From 079a51aebd62a2170f17427d43c302bf58294414 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Jun 2024 19:05:19 +0200 Subject: [PATCH 01/12] Add error message in Numba implementation of SpecifyShape --- pytensor/link/numba/dispatch/basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 87b8e380d3..8bb28c03a5 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -545,11 +545,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) - if shape_input is not NoneConst + if node_dim_input is not NoneConst ] func = dedent( From 675e7e51ac43682315cd851b4a52477e9122041f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 14 Sep 2025 12:19:15 +0200 Subject: [PATCH 02/12] Make Numba the default backend --- pytensor/compile/mode.py | 34 ++++++++++++++++++---------- pytensor/configdefaults.py | 16 ++++++++++--- tests/tensor/rewriting/test_basic.py | 6 ++++- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 8bd0e2f901..e92ba26eb3 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -50,6 +50,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), + "numba_vm": NumbaLinker(), } @@ -63,9 +64,8 @@ def register_linker(name, linker): # If a string is passed as the optimizer argument in the constructor # for Mode, it will be used as the key to retrieve the real optimizer # in this dictionary -exclude = [] -if not config.cxx: - exclude = ["cxx_only"] + +exclude = ["cxx_only", "BlasOpt"] OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) # Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude) @@ -351,6 +351,11 @@ def __setstate__(self, state): optimizer = predefined_optimizers[optimizer] if isinstance(optimizer, RewriteDatabaseQuery): self.provided_optimizer = optimizer + + # Force numba-required rewrites if using NumbaLinker + if isinstance(linker, NumbaLinker): + optimizer = optimizer.including("numba") + self._optimizer = optimizer self.call_time = 0 self.fn_time = 0 @@ -448,16 +453,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - VMLinker(use_cloop=False, c_thunks=False), - RewriteDatabaseQuery(include=["fast_compile", "py_only"]), + NumbaLinker(), + # TODO: Fast_compile should just use python code, CHANGE ME! + RewriteDatabaseQuery( + include=["fast_compile", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), +) +FAST_RUN = Mode( + NumbaLinker(), + RewriteDatabaseQuery( + include=["fast_run", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), ) -if config.cxx: - FAST_RUN = Mode("cvm", "fast_run") -else: - FAST_RUN = Mode( - "vm", - RewriteDatabaseQuery(include=["fast_run", "py_only"]), - ) NUMBA = Mode( NumbaLinker(), @@ -574,6 +583,7 @@ def register_mode(name, mode): Add a `Mode` which can be referred to by `name` in `function`. """ + # TODO: Remove me if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 7698c5d441..ccb8a7434d 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -370,11 +370,21 @@ def add_compile_configvars(): if rc == 0 and config.cxx != "": # Keep the default linker the same as the one for the mode FAST_RUN - linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] + linker_options = [ + "cvm", + "c|py", + "py", + "c", + "c|py_nogc", + "vm", + "vm_nogc", + "cvm_nogc", + "jax", + ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm_nogc"] + linker_options = ["py", "vm", "vm_nogc", "jax"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -388,7 +398,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. - EnumStr("cvm", linker_options, mutable=False), + EnumStr("numba", linker_options, mutable=False), in_c_key=False, ) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 4a78a1e9fe..2fd8eca2ef 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1,4 +1,5 @@ import copy +import re import numpy as np import pytest @@ -306,7 +307,9 @@ def test_inconsistent_shared(self, shape_unsafe): # Error raised by Alloc Op with pytest.raises( ValueError, - match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", + match=re.escape( + "cannot assign slice of shape (3, 7) from input of shape (6, 7)" + ), ): f() @@ -1203,6 +1206,7 @@ def test_sum_bool_upcast(self): f(5) +@pytest.mark.xfail(reason="Numba does not support float16") class TestLocalOptAllocF16(TestLocalOptAlloc): dtype = "float16" From fb5768466062d76fdc5104e9b0aa00c0145798f9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 14 Sep 2025 11:46:01 +0200 Subject: [PATCH 03/12] Implement Numba VM with caching And make that the default backend --- pytensor/compile/mode.py | 12 ++- pytensor/configdefaults.py | 5 +- pytensor/link/numba/cache.py | 102 ++++++++++++++++++ pytensor/link/numba/dispatch/basic.py | 18 ++-- pytensor/link/numba/dispatch/elemwise.py | 53 +++++++-- pytensor/link/numba/dispatch/scalar.py | 18 ++-- .../link/numba/dispatch/vectorize_codegen.py | 50 +++++---- pytensor/link/numba/linker.py | 13 ++- tests/link/numba/test_basic.py | 20 ++++ 9 files changed, 240 insertions(+), 51 deletions(-) create mode 100644 pytensor/link/numba/cache.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index e92ba26eb3..1c7f9e70a4 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -50,7 +50,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), - "numba_vm": NumbaLinker(), + "numba_vm": NumbaLinker(vm=True), } @@ -453,7 +453,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - NumbaLinker(), + NumbaLinker(vm=True), # TODO: Fast_compile should just use python code, CHANGE ME! RewriteDatabaseQuery( include=["fast_compile", "numba"], @@ -461,7 +461,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) FAST_RUN = Mode( - NumbaLinker(), + NumbaLinker(vm=True), RewriteDatabaseQuery( include=["fast_run", "numba"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], @@ -481,6 +481,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) +NUMBA_VM = Mode( + NumbaLinker(vm=True), + NUMBA._optimizer, +) + JAX = Mode( JAXLinker(), RewriteDatabaseQuery( @@ -519,6 +524,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "FAST_RUN": FAST_RUN, "JAX": JAX, "NUMBA": NUMBA, + "NUMBA_VM": NUMBA_VM, "PYTORCH": PYTORCH, } diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index ccb8a7434d..e8ea54e7c2 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -380,11 +380,12 @@ def add_compile_configvars(): "vm_nogc", "cvm_nogc", "jax", + "numba", ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm", "vm_nogc", "jax"] + linker_options = ["py", "vm", "vm_nogc", "jax", "numba"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -398,7 +399,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. - EnumStr("numba", linker_options, mutable=False), + EnumStr("numba_vm", linker_options, mutable=False), in_c_key=False, ) diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py new file mode 100644 index 0000000000..0db8a7db42 --- /dev/null +++ b/pytensor/link/numba/cache.py @@ -0,0 +1,102 @@ +from collections.abc import Callable +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryFile +from typing import Any + +from numba.core.caching import CacheImpl, _CacheLocator + +from pytensor import config + + +NUMBA_PYTENSOR_CACHE_ENABLED = True +COMPILED_SRC_FUNCTIONS = {} + + +def compile_and_cache_numba_function_src( + src: str, + function_name: str, + global_env: dict[Any, Any] | None = None, + local_env: dict[Any, Any] | None = None, + key: str | None = None, +) -> Callable: + if key is not None: + numba_path = config.base_compiledir / "numba" + numba_path.mkdir(exist_ok=True) + filename = numba_path / key + with filename.open("wb") as f: + f.write(src.encode()) + else: + with NamedTemporaryFile(delete=False) as f: + filename = f.name + f.write(src.encode()) + + if global_env is None: + global_env = {} + + if local_env is None: + local_env = {} + + mod_code = compile(src, filename, mode="exec") + exec(mod_code, global_env, local_env) + + res = local_env[function_name] + res.__source__ = src # type: ignore + + if key is not None: + COMPILED_SRC_FUNCTIONS[res] = key + return res + + +class NumbaPyTensorCacheLocator(_CacheLocator): + def __init__(self, py_func, py_file, hash): + # print(f"New locator {py_func=}, {py_file=}, {hash=}") + self._py_func = py_func + self._py_file = py_file + self._hash = hash + # src_hash = hash(pytensor_loader._module_sources[self._py_file]) + # self._hash = hash((src_hash, py_file, pytensor.__version__)) + + def ensure_cache_path(self): + # print("ensure_cache_path called") + path = self.get_cache_path() + path.mkdir(exist_ok=True) + # Ensure the directory is writable by trying to write a temporary file + TemporaryFile(dir=path).close() + + def get_cache_path(self): + """ + Return the directory the function is cached in. + """ + # print("get_cache_path called") + return self._py_file + + def get_source_stamp(self): + """ + Get a timestamp representing the source code's freshness. + Can return any picklable Python object. + """ + return 0 + # print("get_source_stamp called") + return self._hash + + def get_disambiguator(self): + """ + Get a string disambiguator for this locator's function. + It should allow disambiguating different but similarly-named functions. + """ + # print("get_disambiguator called") + return self._hash + + @classmethod + def from_function(cls, py_func, py_file): + """ + Create a locator instance for the given function located in the given file. + """ + # py_file = Path(py_file).parent + # if py_file == (config.base_compiledir / "numba"): + if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in COMPILED_SRC_FUNCTIONS: + # print(f"Applies to {py_file}") + return cls(py_func, Path(py_file).parent, COMPILED_SRC_FUNCTIONS[py_func]) + + +CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8bb28c03a5..7c307683f2 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -14,7 +14,7 @@ from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload +from numba.extending import box, overload, register_jitable as _register_jitable from pytensor import In, config from pytensor.compile import NUMBA @@ -50,10 +50,11 @@ def global_numba_func(func): return func -def numba_njit(*args, fastmath=None, **kwargs): - kwargs.setdefault("cache", config.numba__cache) - kwargs.setdefault("no_cpython_wrapper", True) - kwargs.setdefault("no_cfunc_wrapper", True) +def numba_njit(*args, fastmath=None, register_jitable: bool = False, **kwargs): + kwargs.setdefault("cache", True) + kwargs.setdefault("no_cpython_wrapper", False) + kwargs.setdefault("no_cfunc_wrapper", False) + # print(kwargs) if fastmath is None: if config.numba__fastmath: # Opinionated default on fastmath flags @@ -81,10 +82,11 @@ def numba_njit(*args, fastmath=None, **kwargs): category=NumbaWarning, ) + func = _register_jitable if register_jitable else numba.njit if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - - return numba.njit(*args, fastmath=fastmath, **kwargs) + return func(*args[1:], fastmath=fastmath, **kwargs)(args[0]) + else: + return func(*args, fastmath=fastmath, **kwargs) def numba_vectorize(*args, **kwargs): diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 7244762b93..550b2a5f5d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,4 +1,5 @@ from functools import singledispatch +from hashlib import sha256 from textwrap import dedent, indent import numba @@ -7,18 +8,17 @@ from numpy.lib.stride_tricks import as_strided from pytensor.graph.op import Op +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify, numba_njit, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( - _jit_options, _vectorized, encode_literals, store_core_outputs, ) -from pytensor.link.utils import compile_function_src from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar.basic import ( AND, @@ -237,7 +237,7 @@ def {careduce_fn_name}(x): careduce_def_src += "\n\n" careduce_def_src += indent(f"return {return_obj}", " " * 4) - careduce_fn = compile_function_src( + careduce_fn = compile_and_cache_numba_function_src( careduce_def_src, careduce_fn_name, {**globals(), **global_env} ) @@ -264,9 +264,11 @@ def axis_apply_fn(x): @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): + nin = len(node.inputs) + nout = len(node.outputs) + scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) - scalar_op_fn = numba_funcify( op.scalar_op, node=scalar_node, @@ -274,9 +276,22 @@ def numba_funcify_Elemwise(op, node, **kwargs): **kwargs, ) - nin = len(node.inputs) - nout = len(node.outputs) - core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + # TODO: Proper key + core_op_key = "_".join( + map( + str, + ( + op, + op.scalar_op, + tuple(op.inplace_pattern.items()), + tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), + ), + ) + ) + core_op_key = sha256(core_op_key.encode()).hexdigest() + core_op_fn = store_core_outputs( + scalar_op_fn, nin=nin, nout=nout, core_op_key=core_op_key + ) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) @@ -333,11 +348,31 @@ def elemwise(*inputs): return tuple(outputs_summed) return outputs_summed[0] - @overload(elemwise, jit_options=_jit_options) + @overload(elemwise) def ov_elemwise(*inputs): return elemwise_wrapper - return elemwise + # TODO: Also input dtypes in key + elemwise_key = "_".join( + map( + str, + ( + "Elemwise", + core_op_key, + input_bc_patterns, + inplace_pattern, + ), + ) + ) + elemwise_key = sha256(elemwise_key.encode()).hexdigest() + f = compile_and_cache_numba_function_src( + "def f(*inputs): return elemwise(*inputs)", + "f", + {**globals(), **{"elemwise": elemwise}}, + key=elemwise_key, + ) + + return numba_njit(f) @numba_funcify.register(Sum) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index ada4e8cc36..29189b2f80 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -4,6 +4,7 @@ from pytensor.compile.ops import TypeCastingOp from pytensor.graph.basic import Variable +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( create_numba_signature, @@ -12,7 +13,6 @@ ) from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( - compile_function_src, get_name_for_object, unique_name_generator, ) @@ -128,16 +128,20 @@ def {scalar_op_fn_name}({', '.join(input_names)}): return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) """ - scalar_op_fn = compile_function_src( - scalar_op_src, scalar_op_fn_name, {**globals(), **global_env} + scalar_op_fn = compile_and_cache_numba_function_src( + scalar_op_src, + scalar_op_fn_name, + {**globals(), **global_env}, ) - signature = create_numba_signature(node, force_scalar=True) + # signature = create_numba_signature(node, force_scalar=True) return numba_basic.numba_njit( - signature, + # signature, # Functions that call a function pointer can't be cached - cache=False, + no_cfunc_wrapper=True, + no_cpython_wrapper=True, + register_jitable=False, )(scalar_op_fn) @@ -164,7 +168,7 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: def {binary_op_name}({input_signature}): return {output_expr} """ - nary_fn = compile_function_src(nary_src, binary_op_name, globals()) + nary_fn = compile_and_cache_numba_function_src(nary_src, binary_op_name, globals()) return nary_fn diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 060418cb6c..a25ba507ab 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -4,7 +4,7 @@ import pickle from collections.abc import Callable, Sequence from textwrap import indent -from typing import Any, cast +from typing import Any import numba import numpy as np @@ -15,15 +15,17 @@ from numba.core.types.misc import NoneType from numba.np import arrayobj +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.utils import compile_function_src def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: +def store_core_outputs( + core_op_fn: Callable, nin: int, nout: int, core_op_key=None +) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @njit @@ -52,10 +54,20 @@ def store_core_outputs({inp_signature}, {out_signature}): {indent(store_outputs, " " * 4)} """ global_env = {"core_op_fn": core_op_fn} - func = compile_function_src( - func_src, "store_core_outputs", {**globals(), **global_env} + + key = "_".join(("store_core_outputs", core_op_key)) if core_op_key else None + func = compile_and_cache_numba_function_src( + func_src, + "store_core_outputs", + {**globals(), **global_env}, + key=key, + ) + return numba_basic.numba_njit( + func, + register_jitable=True, + no_cpython_wrapper=True, + no_cfunc_wrapper=True, ) - return cast(Callable, numba_basic.numba_njit(func)) _jit_options = { @@ -74,7 +86,7 @@ def store_core_outputs({inp_signature}, {out_signature}): @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) def _vectorized( typingctx, - scalar_func, + core_func, input_bc_patterns, output_bc_patterns, output_dtypes, @@ -85,7 +97,7 @@ def _vectorized( size_type, ): arg_types = [ - scalar_func, + core_func, input_bc_patterns, output_bc_patterns, output_dtypes, @@ -173,16 +185,6 @@ def _vectorized( ) out_types[output_idx] = output_type - core_signature = typingctx.resolve_function_type( - scalar_func, - [ - *constant_inputs_types, - *core_input_types, - *core_out_types, - ], - {}, - ) - ret_type = types.Tuple(out_types) if len(output_dtypes) == 1: @@ -239,11 +241,21 @@ def codegen( output_core_shapes, ) + core_signature = typingctx.resolve_function_type( + core_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, + ) + make_loop_call( typingctx, ctx, builder, - scalar_func, + core_func, core_signature, iter_shape, constant_inputs, diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 59dc81e1b0..199e6f48b1 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -4,16 +4,23 @@ class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" + def __init__(self, *args, vm: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.vm = vm + def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): - from pytensor.link.numba.dispatch.basic import numba_njit + if self.vm: + return fn + else: + from pytensor.link.numba.dispatch.basic import numba_njit - jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) - return jitted_fn + jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) + return jitted_fn def create_thunk_inputs(self, storage_map): return [storage_map[n] for n in self.fgraph.inputs] diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fd9a48111f..4b5a44b053 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -8,6 +8,7 @@ import pytest import scipy +import pytensor.link.numba.cache from pytensor.compile import SymbolicInput @@ -950,3 +951,22 @@ def test_mat_vec_dot_performance(dtype, benchmark): x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) benchmark(fn, A_test, x_test) + + +@pytest.mark.parametrize("use_cache", [False, True], ids=["no-cache", "cache"]) +@pytest.mark.parametrize("func", [pt.cos, pt.sin], ids=["cos", "sin"]) +def test_compile_time_benchmark(func, use_cache, benchmark): + x = pt.matrix("x") + y = func(x) + rng = np.random.default_rng(42) + x_test = rng.normal(size=(5, 3)) + + def compile(): + fn = function([x], y, mode="NUMBA_VM", trust_input=True) + return fn(x_test) + + pytensor.link.numba.cache.NUMBA_PYTENSOR_CACHE_ENABLED = use_cache + + res = compile() + np.testing.assert_allclose(res, y.eval({x: x_test})) + benchmark(compile) From b98cf4fb33b47f9559b78ded99631a985a7ecb36 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 14 Sep 2025 12:05:35 +0200 Subject: [PATCH 04/12] Cache more Ops --- pytensor/link/numba/dispatch/basic.py | 17 ++++++++----- pytensor/link/numba/dispatch/subtensor.py | 7 ++++-- pytensor/link/numba/dispatch/tensor_basic.py | 25 +++++++++++++++----- pytensor/utils.py | 5 ++-- 4 files changed, 37 insertions(+), 17 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 7c307683f2..4f2104746e 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -14,7 +14,8 @@ from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload, register_jitable as _register_jitable +from numba.extending import box, overload +from numba.extending import register_jitable as _register_jitable from pytensor import In, config from pytensor.compile import NUMBA @@ -25,11 +26,9 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType -from pytensor.link.utils import ( - compile_function_src, - fgraph_to_python, -) +from pytensor.link.utils import fgraph_to_python from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType from pytensor.tensor.basic import Nonzero @@ -40,6 +39,7 @@ from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst +from pytensor.utils import hash_from_code def global_numba_func(func): @@ -562,7 +562,12 @@ def specify_shape(x, {create_arg_string(shape_input_names)}): """ ) - specify_shape = compile_function_src(func, "specify_shape", globals()) + specify_shape = compile_and_cache_numba_function_src( + func, + "specify_shape", + globals(), + key=hash_from_code(func), + ) return numba_njit(specify_shape) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index fe0eda153e..f9ff3ebeee 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,9 +1,10 @@ import numpy as np from pytensor.graph import Type +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit -from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.link.utils import unique_name_generator from pytensor.tensor import TensorType from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( @@ -15,6 +16,7 @@ Subtensor, ) from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.utils import hash_from_code @numba_funcify.register(Subtensor) @@ -95,10 +97,11 @@ def {function_name}({", ".join(input_names)}): return np.asarray(z) """ - func = compile_function_src( + func = compile_and_cache_numba_function_src( subtensor_def_src, function_name=function_name, global_env=globals() | {"np": np}, + key=hash_from_code(subtensor_def_src), ) return numba_njit(func, boundscheck=True) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 3a9d8767b9..a9325fb0b5 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -2,9 +2,10 @@ import numpy as np +from pytensor.link.numba.cache import compile_and_cache_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify -from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.link.utils import unique_name_generator from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -17,6 +18,7 @@ Split, TensorFromScalar, ) +from pytensor.utils import hash_from_code @numba_funcify.register(AllocEmpty) @@ -49,8 +51,11 @@ def allocempty({", ".join(shape_var_names)}): return np.empty(scalar_shape, dtype) """ - alloc_fn = compile_function_src( - alloc_def_src, "allocempty", {**globals(), **global_env} + alloc_fn = compile_and_cache_numba_function_src( + alloc_def_src, + "allocempty", + {**globals(), **global_env}, + key=hash_from_code(alloc_def_src), ) return numba_basic.numba_njit(alloc_fn) @@ -93,7 +98,12 @@ def alloc(val, {", ".join(shape_var_names)}): res[...] = val return res """ - alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) + alloc_fn = compile_and_cache_numba_function_src( + alloc_def_src, + "alloc", + {**globals(), **global_env}, + key=hash_from_code(alloc_def_src), + ) return numba_basic.numba_njit(alloc_fn) @@ -212,8 +222,11 @@ def makevector({", ".join(input_names)}): return np.array({create_list_string(input_names)}, dtype=dtype) """ - makevector_fn = compile_function_src( - makevector_def_src, "makevector", {**globals(), **global_env} + makevector_fn = compile_and_cache_numba_function_src( + makevector_def_src, + "makevector", + {**globals(), **global_env}, + key=f"MakeVector({op.dtype})", ) return numba_basic.numba_njit(makevector_fn) diff --git a/pytensor/utils.py b/pytensor/utils.py index c81fb74f56..137d03f380 100644 --- a/pytensor/utils.py +++ b/pytensor/utils.py @@ -191,9 +191,8 @@ def hash_from_code(msg: str | bytes) -> str: # but Python 3 (unicode) strings don't. if isinstance(msg, str): msg = msg.encode() - # Python 3 does not like module names that start with - # a digit. - return "m" + hashlib.sha256(msg).hexdigest() + # Python 3 does not like module names that start with a digit. + return f"m{hashlib.sha256(msg).hexdigest()}" def uniq(seq: Sequence) -> list: From 8dd4b48fc979afe503969fb4da50985f2cc56525 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 30 Sep 2025 19:47:27 +0200 Subject: [PATCH 05/12] .wip --- pytensor/link/numba/dispatch/basic.py | 11 +++-- pytensor/link/numba/dispatch/elemwise.py | 44 ++++++++------------ pytensor/link/numba/dispatch/extra_ops.py | 8 ++-- pytensor/link/numba/dispatch/nlinalg.py | 6 +-- pytensor/link/numba/dispatch/slinalg.py | 2 +- pytensor/link/numba/dispatch/tensor_basic.py | 16 +++---- pytensor/link/numba/linker.py | 2 +- 7 files changed, 42 insertions(+), 47 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 4f2104746e..14374c69c3 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -50,11 +50,11 @@ def global_numba_func(func): return func -def numba_njit(*args, fastmath=None, register_jitable: bool = False, **kwargs): +def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs): kwargs.setdefault("cache", True) kwargs.setdefault("no_cpython_wrapper", False) kwargs.setdefault("no_cfunc_wrapper", False) - # print(kwargs) + if fastmath is None: if config.numba__fastmath: # Opinionated default on fastmath flags @@ -380,11 +380,16 @@ def numba_funcify_FunctionGraph( fgraph, node=None, fgraph_name="numba_funcified_fgraph", + jit_nodes: bool = False, **kwargs, ): + def numba_funcify_njit(op, node, **kwargs): + jitable_func = numba_funcify(op, node=node, **kwargs) + return numba_njit(lambda *args: jitable_func(*args), register_jitable=False) + return fgraph_to_python( fgraph, - numba_funcify, + op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify, type_conversion_fn=numba_typify, fgraph_name=fgraph_name, **kwargs, diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 550b2a5f5d..ef57225609 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -320,33 +320,23 @@ def elemwise_wrapper(*inputs): # Pure python implementation, that will be used in tests def elemwise(*inputs): - inputs = [np.asarray(input) for input in inputs] + Elemwise._check_runtime_broadcast(node, inputs) inputs_bc = np.broadcast_arrays(*inputs) - shape = inputs[0].shape - for input, bc in zip(inputs, input_bc_patterns, strict=True): - for length, allow_bc, iter_length in zip( - input.shape, bc, shape, strict=True - ): - if length == 1 and shape and iter_length != 1 and not allow_bc: - raise ValueError("Broadcast not allowed.") - - outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes] - - for idx in np.ndindex(shape): - vals = [input[idx] for input in inputs_bc] - outs = scalar_op_fn(*vals) - if not isinstance(outs, tuple): - outs = (outs,) - for out, out_val in zip(outputs, outs, strict=True): - out[idx] = out_val - - outputs_summed = [] - for output, bc in zip(outputs, output_bc_patterns, strict=True): - axes = tuple(np.nonzero(bc)[0]) - outputs_summed.append(output.sum(axes, keepdims=True)) - if len(outputs_summed) != 1: - return tuple(outputs_summed) - return outputs_summed[0] + shape = inputs_bc[0].shape + + if len(output_dtypes) == 1: + output = np.empty(shape, dtype=output_dtypes[0]) + for idx in np.ndindex(shape): + output[idx] = scalar_op_fn(*(inp[idx] for inp in inputs_bc)) + return output + + else: + outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes] + for idx in np.ndindex(shape): + outs_vals = scalar_op_fn(*(inp[idx] for inp in inputs_bc)) + for out, out_val in zip(outputs, outs_vals): + out[idx] = out_val + return outputs @overload(elemwise) def ov_elemwise(*inputs): @@ -594,7 +584,7 @@ def numba_funcify_Argmax(op, node, **kwargs): if x_ndim == 0: - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def argmax(x): return np.array(0, dtype="int64") diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index f7700acf47..21df44d3a8 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -24,7 +24,7 @@ @numba_funcify.register(Bartlett) def numba_funcify_Bartlett(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def bartlett(x): return np.bartlett(numba_basic.to_scalar(x)) @@ -228,13 +228,13 @@ def repeatop(x, repeats): if repeats_ndim == 0: - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats.item()) else: - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats) @@ -348,7 +348,7 @@ def searchsorted(a, v, sorter): else: - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def searchsorted(a, v): return np.searchsorted(a, v, side) diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 98d59a4595..58fe0e3719 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -49,7 +49,7 @@ def numba_funcify_Det(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def det(x): return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype) @@ -128,7 +128,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def matrix_inverse(x): return np.linalg.inv(inputs_cast(x)).astype(out_dtype) @@ -140,7 +140,7 @@ def numba_funcify_MatrixPinv(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def matrixpinv(x): return np.linalg.pinv(inputs_cast(x)).astype(out_dtype) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 7d1e915298..5578a8379c 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -118,7 +118,7 @@ def numba_funcify_LU(op, node, **kwargs): if dtype in complex_dtypes: NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) - @numba_njit(inline="always") + @numba_njit def lu(a): if check_finite: if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index a9325fb0b5..e37a133f1c 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -112,12 +112,12 @@ def alloc(val, {", ".join(shape_var_names)}): def numba_funcify_ARange(op, **kwargs): dtype = np.dtype(op.dtype) - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def arange(start, stop, step): return np.arange( - numba_basic.to_scalar(start), - numba_basic.to_scalar(stop), - numba_basic.to_scalar(step), + start.item(), + stop.item(), + step.item(), dtype=dtype, ) @@ -164,7 +164,7 @@ def extract_diag(x): leading_dims = (slice(None),) * axis1 middle_dims = (slice(None),) * (axis2 - axis1 - 1) - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def extract_diag(x): if offset >= 0: diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset)) @@ -234,7 +234,7 @@ def makevector({", ".join(input_names)}): @numba_funcify.register(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def tensor_from_scalar(x): return np.array(x) @@ -243,8 +243,8 @@ def tensor_from_scalar(x): @numba_funcify.register(ScalarFromTensor) def numba_funcify_ScalarFromTensor(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit def scalar_from_tensor(x): - return numba_basic.to_scalar(x) + return x.item() return scalar_from_tensor diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 199e6f48b1..5ca598e472 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -11,7 +11,7 @@ def __init__(self, *args, vm: bool = False, **kwargs): def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify - return numba_funcify(fgraph, **kwargs) + return numba_funcify(fgraph, jit_nodes=self.vm, **kwargs) def jit_compile(self, fn): if self.vm: From 3692b3af7a528ae866a2dd4b25ea4a109d67f233 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 1 Oct 2025 13:58:26 +0200 Subject: [PATCH 06/12] .More hacking around --- pytensor/link/numba/cache.py | 37 +++++----- pytensor/link/numba/dispatch/basic.py | 71 ++++++++++++++----- pytensor/link/numba/dispatch/elemwise.py | 30 ++------ pytensor/link/numba/dispatch/extra_ops.py | 2 +- pytensor/link/numba/dispatch/scalar.py | 8 +-- pytensor/link/numba/dispatch/subtensor.py | 7 +- pytensor/link/numba/dispatch/tensor_basic.py | 24 ++++--- .../link/numba/dispatch/vectorize_codegen.py | 13 +--- pytensor/mod1.py | 6 ++ tests/tensor/rewriting/test_basic.py | 4 +- 10 files changed, 107 insertions(+), 95 deletions(-) create mode 100644 pytensor/mod1.py diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py index 0db8a7db42..7a9535e64f 100644 --- a/pytensor/link/numba/cache.py +++ b/pytensor/link/numba/cache.py @@ -1,6 +1,6 @@ from collections.abc import Callable from pathlib import Path -from tempfile import NamedTemporaryFile, TemporaryFile +from tempfile import NamedTemporaryFile from typing import Any from numba.core.caching import CacheImpl, _CacheLocator @@ -9,7 +9,9 @@ NUMBA_PYTENSOR_CACHE_ENABLED = True -COMPILED_SRC_FUNCTIONS = {} +NUMBA_CACHE_PATH = config.base_compiledir / "numba" +NUMBA_CACHE_PATH.mkdir(exist_ok=True) +CACHED_SRC_FUNCTIONS = {} def compile_and_cache_numba_function_src( @@ -20,9 +22,7 @@ def compile_and_cache_numba_function_src( key: str | None = None, ) -> Callable: if key is not None: - numba_path = config.base_compiledir / "numba" - numba_path.mkdir(exist_ok=True) - filename = numba_path / key + filename = NUMBA_CACHE_PATH / key with filename.open("wb") as f: f.write(src.encode()) else: @@ -43,10 +43,19 @@ def compile_and_cache_numba_function_src( res.__source__ = src # type: ignore if key is not None: - COMPILED_SRC_FUNCTIONS[res] = key + CACHED_SRC_FUNCTIONS[res] = key return res +def cache_numba_function( + fn, + key: str | None = None, +) -> Callable: + if key is not None: + CACHED_SRC_FUNCTIONS[fn] = key + return fn + + class NumbaPyTensorCacheLocator(_CacheLocator): def __init__(self, py_func, py_file, hash): # print(f"New locator {py_func=}, {py_file=}, {hash=}") @@ -57,18 +66,13 @@ def __init__(self, py_func, py_file, hash): # self._hash = hash((src_hash, py_file, pytensor.__version__)) def ensure_cache_path(self): - # print("ensure_cache_path called") - path = self.get_cache_path() - path.mkdir(exist_ok=True) - # Ensure the directory is writable by trying to write a temporary file - TemporaryFile(dir=path).close() + pass def get_cache_path(self): """ Return the directory the function is cached in. """ - # print("get_cache_path called") - return self._py_file + return NUMBA_CACHE_PATH def get_source_stamp(self): """ @@ -76,15 +80,12 @@ def get_source_stamp(self): Can return any picklable Python object. """ return 0 - # print("get_source_stamp called") - return self._hash def get_disambiguator(self): """ Get a string disambiguator for this locator's function. It should allow disambiguating different but similarly-named functions. """ - # print("get_disambiguator called") return self._hash @classmethod @@ -94,9 +95,9 @@ def from_function(cls, py_func, py_file): """ # py_file = Path(py_file).parent # if py_file == (config.base_compiledir / "numba"): - if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in COMPILED_SRC_FUNCTIONS: + if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS: # print(f"Applies to {py_file}") - return cls(py_func, Path(py_file).parent, COMPILED_SRC_FUNCTIONS[py_func]) + return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func]) CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 14374c69c3..0d822177f2 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -3,6 +3,7 @@ import warnings from copy import copy from functools import singledispatch +from hashlib import sha256 from textwrap import dedent import numba @@ -11,11 +12,11 @@ import scipy import scipy.special from llvmlite import ir +from numba import njit as _njit from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload -from numba.extending import register_jitable as _register_jitable +from numba.extending import box, overload, register_jitable from pytensor import In, config from pytensor.compile import NUMBA @@ -26,7 +27,9 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse -from pytensor.link.numba.cache import compile_and_cache_numba_function_src +from pytensor.link.numba.cache import ( + compile_and_cache_numba_function_src, +) from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import fgraph_to_python from pytensor.scalar.basic import ScalarType @@ -50,11 +53,7 @@ def global_numba_func(func): return func -def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs): - kwargs.setdefault("cache", True) - kwargs.setdefault("no_cpython_wrapper", False) - kwargs.setdefault("no_cfunc_wrapper", False) - +def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs): if fastmath is None: if config.numba__fastmath: # Opinionated default on fastmath flags @@ -69,6 +68,12 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs): else: fastmath = False + if final_function: + kwargs.setdefault("cache", True) + # else: + # kwargs.setdefault("no_cpython_wrapper", True) + # kwargs.setdefault("no_cfunc_wrapper", True) + # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba warnings.filterwarnings( @@ -82,7 +87,7 @@ def numba_njit(*args, fastmath=None, register_jitable: bool = True, **kwargs): category=NumbaWarning, ) - func = _register_jitable if register_jitable else numba.njit + func = register_jitable if final_function else _njit if len(args) > 0 and callable(args[0]): return func(*args[1:], fastmath=fastmath, **kwargs)(args[0]) else: @@ -384,8 +389,43 @@ def numba_funcify_FunctionGraph( **kwargs, ): def numba_funcify_njit(op, node, **kwargs): - jitable_func = numba_funcify(op, node=node, **kwargs) - return numba_njit(lambda *args: jitable_func(*args), register_jitable=False) + jitable_func_and_key = numba_funcify(op, node=node, **kwargs) + from collections.abc import Callable + + match jitable_func_and_key: + case (Callable(), str()): + jitable_func, key = jitable_func_and_key + case (Callable(), int()): + # Default key for Ops that return an integer + jitable_func, int_key = jitable_func_and_key + key = sha256( + str((type(op), op._props_dict(), int_key)).encode() + ).hexdigest() + case Callable(): + jitable_func, key = jitable_func_and_key, None + warnings.warn( + f"No cache key returned by numba_funcify of op {op}. This function won't be cached by Numba" + ) + case _: + raise TypeError( + f"numpy_funcify should return a callable or a callable, key pair, got {jitable_func_and_key}" + ) + + if 0 and key is not None: + # To force numba to use our cache, we must compile the function so that any closure + # becomes a global variable... + op_name = op.__class__.__name__ + cached_func = compile_and_cache_numba_function_src( + src=f"def {op_name}(*args): return jitable_func(*args)", + function_name=op_name, + global_env=globals() | dict(jitable_func=jitable_func), + key=key, + ) + return numba_njit(cached_func, final_function=True, cache=True) + else: + return numba_njit( + lambda *args: jitable_func(*args), final_function=True, cache=False + ) return fgraph_to_python( fgraph, @@ -410,7 +450,7 @@ def dispatch_deepcopyop(x): @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): - return deepcopyop + return deepcopyop, 0 @numba_funcify.register(MakeSlice) @@ -439,7 +479,7 @@ def numba_funcify_Shape_i(op, **kwargs): def shape_i(x): return np.asarray(np.shape(x)[i]) - return shape_i + return shape_i, 0 @numba_funcify.register(SortOp) @@ -543,7 +583,7 @@ def reshape(x, shape): numba_ndarray.to_fixed_tuple(shape, ndim), ) - return reshape + return reshape, 0 @numba_funcify.register(SpecifyShape) @@ -571,9 +611,8 @@ def specify_shape(x, {create_arg_string(shape_input_names)}): func, "specify_shape", globals(), - key=hash_from_code(func), ) - return numba_njit(specify_shape) + return numba_njit(specify_shape), hash_from_code(func) def int_to_float_fn(inputs, out_dtype): diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index ef57225609..4a9140cc41 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -288,10 +288,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): ), ) ) - core_op_key = sha256(core_op_key.encode()).hexdigest() - core_op_fn = store_core_outputs( - scalar_op_fn, nin=nin, nout=nout, core_op_key=core_op_key - ) + core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) @@ -342,27 +339,8 @@ def elemwise(*inputs): def ov_elemwise(*inputs): return elemwise_wrapper - # TODO: Also input dtypes in key - elemwise_key = "_".join( - map( - str, - ( - "Elemwise", - core_op_key, - input_bc_patterns, - inplace_pattern, - ), - ) - ) - elemwise_key = sha256(elemwise_key.encode()).hexdigest() - f = compile_and_cache_numba_function_src( - "def f(*inputs): return elemwise(*inputs)", - "f", - {**globals(), **{"elemwise": elemwise}}, - key=elemwise_key, - ) - - return numba_njit(f) + elemwise_key = sha256(f"Elemwise2{core_op_key}".encode()).hexdigest() + return elemwise, elemwise_key @numba_funcify.register(Sum) @@ -470,7 +448,7 @@ def dimshuffle(x): return as_strided(x, shape=new_shape, strides=new_strides) - return dimshuffle + return dimshuffle, 0 @numba_funcify.register(Softmax) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 21df44d3a8..0d61158061 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -367,4 +367,4 @@ def check_and_raise(x, *conditions): raise error(msg) return x - return check_and_raise + return check_and_raise, 0 diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 29189b2f80..ee2960a1c4 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -136,13 +136,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}): # signature = create_numba_signature(node, force_scalar=True) - return numba_basic.numba_njit( - # signature, - # Functions that call a function pointer can't be cached - no_cfunc_wrapper=True, - no_cpython_wrapper=True, - register_jitable=False, - )(scalar_op_fn) + return numba_basic.numba_njit(scalar_op_fn) @numba_funcify.register(Switch) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index f9ff3ebeee..ce7f8fc3a1 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -101,9 +101,8 @@ def {function_name}({", ".join(input_names)}): subtensor_def_src, function_name=function_name, global_env=globals() | {"np": np}, - key=hash_from_code(subtensor_def_src), ) - return numba_njit(func, boundscheck=True) + return numba_njit(func, boundscheck=True), hash_from_code(subtensor_def_src) @numba_funcify.register(AdvancedSubtensor) @@ -350,7 +349,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): return x if inplace: - return advancedincsubtensor1_inplace + return advancedincsubtensor1_inplace, 0 else: @@ -359,4 +358,4 @@ def advancedincsubtensor1(x, vals, idxs): x = x.copy() return advancedincsubtensor1_inplace(x, vals, idxs) - return advancedincsubtensor1 + return advancedincsubtensor1, 0 diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index e37a133f1c..91ca7ab15e 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -55,10 +55,12 @@ def allocempty({", ".join(shape_var_names)}): alloc_def_src, "allocempty", {**globals(), **global_env}, - key=hash_from_code(alloc_def_src), ) - return numba_basic.numba_njit(alloc_fn) + return ( + numba_basic.numba_njit(alloc_fn), + hash_from_code(alloc_def_src), + ) @numba_funcify.register(Alloc) @@ -102,10 +104,12 @@ def alloc(val, {", ".join(shape_var_names)}): alloc_def_src, "alloc", {**globals(), **global_env}, - key=hash_from_code(alloc_def_src), ) - return numba_basic.numba_njit(alloc_fn) + return ( + numba_basic.numba_njit(alloc_fn), + hash_from_code(alloc_def_src), + ) @numba_funcify.register(ARange) @@ -130,7 +134,7 @@ def numba_funcify_Join(op, **kwargs): def join(axis, *tensors): return np.concatenate(tensors, axis.item()) - return join + return join, 0 @numba_funcify.register(Split) @@ -139,7 +143,7 @@ def numba_funcify_Split(op, **kwargs): def split(tensor, axis, indices): return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item()) - return split + return split, 0 @numba_funcify.register(ExtractDiag) @@ -226,10 +230,8 @@ def makevector({", ".join(input_names)}): makevector_def_src, "makevector", {**globals(), **global_env}, - key=f"MakeVector({op.dtype})", ) - - return numba_basic.numba_njit(makevector_fn) + return numba_basic.numba_njit(makevector_fn), hash_from_code(makevector_def_src) @numba_funcify.register(TensorFromScalar) @@ -238,7 +240,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs): def tensor_from_scalar(x): return np.array(x) - return tensor_from_scalar + return tensor_from_scalar, 0 @numba_funcify.register(ScalarFromTensor) @@ -247,4 +249,4 @@ def numba_funcify_ScalarFromTensor(op, **kwargs): def scalar_from_tensor(x): return x.item() - return scalar_from_tensor + return scalar_from_tensor, 0 diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index a25ba507ab..332a165539 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -23,9 +23,7 @@ def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs( - core_op_fn: Callable, nin: int, nout: int, core_op_key=None -) -> Callable: +def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @njit @@ -55,19 +53,12 @@ def store_core_outputs({inp_signature}, {out_signature}): """ global_env = {"core_op_fn": core_op_fn} - key = "_".join(("store_core_outputs", core_op_key)) if core_op_key else None func = compile_and_cache_numba_function_src( func_src, "store_core_outputs", {**globals(), **global_env}, - key=key, - ) - return numba_basic.numba_njit( - func, - register_jitable=True, - no_cpython_wrapper=True, - no_cfunc_wrapper=True, ) + return numba_basic.numba_njit(func) _jit_options = { diff --git a/pytensor/mod1.py b/pytensor/mod1.py new file mode 100644 index 0000000000..8ba7e5abe2 --- /dev/null +++ b/pytensor/mod1.py @@ -0,0 +1,6 @@ +import numba + + +@numba.extending.register_jitable(cache=True) +def foo(x): + return x + 1 diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 2fd8eca2ef..3ffa1aa267 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -127,7 +127,7 @@ _specialize_rewrites.position_cutoff = 2.01 _specialize_rewrites = optdb.query(_specialize_rewrites) -_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"], exclude=["inplace"]) _fast_run_rewrites = optdb.query(_fast_run_rewrites) @@ -476,6 +476,8 @@ def test_basic(self): x = scalar() y = scalar() f = function([x, y], assert_op(x, eq(x, y)), mode=mode) + f.dprint() + return assert f(1, 1) == 1 with pytest.raises(AssertionError): f(1, 0) From 022e1879e448c1f921917d239501f1213164061d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 1 Oct 2025 16:57:46 +0200 Subject: [PATCH 07/12] .More hacking around --- pytensor/link/numba/cache.py | 67 +-- pytensor/link/numba/compile.py | 196 +++++++ pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/basic.py | 514 ++++-------------- pytensor/link/numba/dispatch/blockwise.py | 3 +- pytensor/link/numba/dispatch/elemwise.py | 22 +- pytensor/link/numba/dispatch/extra_ops.py | 52 +- .../dispatch/linalg/solve/tridiagonal.py | 2 +- .../link/numba/dispatch/linalg/solve/utils.py | 4 +- pytensor/link/numba/dispatch/linalg/utils.py | 8 +- pytensor/link/numba/dispatch/nlinalg.py | 22 +- pytensor/link/numba/dispatch/random.py | 39 +- pytensor/link/numba/dispatch/scalar.py | 39 +- pytensor/link/numba/dispatch/scan.py | 8 +- pytensor/link/numba/dispatch/shape.py | 85 +++ pytensor/link/numba/dispatch/signal/conv.py | 2 +- pytensor/link/numba/dispatch/slinalg.py | 3 +- pytensor/link/numba/dispatch/subtensor.py | 8 +- pytensor/link/numba/dispatch/tensor_basic.py | 38 +- .../link/numba/dispatch/vectorize_codegen.py | 6 +- pytensor/link/numba/linker.py | 2 +- tests/link/numba/test_basic.py | 8 +- tests/link/numba/test_tensor_basic.py | 2 +- 23 files changed, 558 insertions(+), 573 deletions(-) create mode 100644 pytensor/link/numba/compile.py create mode 100644 pytensor/link/numba/dispatch/shape.py diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py index 7a9535e64f..ad3a8dfe1b 100644 --- a/pytensor/link/numba/cache.py +++ b/pytensor/link/numba/cache.py @@ -1,64 +1,21 @@ -from collections.abc import Callable +import weakref +from hashlib import sha256 from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any from numba.core.caching import CacheImpl, _CacheLocator from pytensor import config +from pytensor.graph.basic import Apply NUMBA_PYTENSOR_CACHE_ENABLED = True NUMBA_CACHE_PATH = config.base_compiledir / "numba" NUMBA_CACHE_PATH.mkdir(exist_ok=True) -CACHED_SRC_FUNCTIONS = {} - - -def compile_and_cache_numba_function_src( - src: str, - function_name: str, - global_env: dict[Any, Any] | None = None, - local_env: dict[Any, Any] | None = None, - key: str | None = None, -) -> Callable: - if key is not None: - filename = NUMBA_CACHE_PATH / key - with filename.open("wb") as f: - f.write(src.encode()) - else: - with NamedTemporaryFile(delete=False) as f: - filename = f.name - f.write(src.encode()) - - if global_env is None: - global_env = {} - - if local_env is None: - local_env = {} - - mod_code = compile(src, filename, mode="exec") - exec(mod_code, global_env, local_env) - - res = local_env[function_name] - res.__source__ = src # type: ignore - - if key is not None: - CACHED_SRC_FUNCTIONS[res] = key - return res - - -def cache_numba_function( - fn, - key: str | None = None, -) -> Callable: - if key is not None: - CACHED_SRC_FUNCTIONS[fn] = key - return fn +CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary() class NumbaPyTensorCacheLocator(_CacheLocator): def __init__(self, py_func, py_file, hash): - # print(f"New locator {py_func=}, {py_file=}, {hash=}") self._py_func = py_func self._py_file = py_file self._hash = hash @@ -101,3 +58,19 @@ def from_function(cls, py_func, py_file): CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator) + + +def cache_node_key(node: Apply, extra_key="") -> str: + op = node.op + return sha256( + str( + ( + # Op signature + (type(op), op._props_dict() if hasattr(op, "_props_dict") else ""), + # Node signature + tuple((type(inp_type := inp.type), inp_type) for inp in node.inputs), + # Extra key given by the caller + extra_key, + ), + ).encode() + ).hexdigest() diff --git a/pytensor/link/numba/compile.py b/pytensor/link/numba/compile.py new file mode 100644 index 0000000000..01eaca26e6 --- /dev/null +++ b/pytensor/link/numba/compile.py @@ -0,0 +1,196 @@ +import warnings +from collections.abc import Callable +from typing import Any + +import numba +import numpy as np +from numba import NumbaWarning +from numba import njit as _njit +from numba.core.extending import register_jitable + +from pytensor import config +from pytensor.graph import Apply, FunctionGraph, Type +from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS +from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType +from pytensor.scalar import ScalarType +from pytensor.sparse import SparseTensorType +from pytensor.tensor import TensorType + + +def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs): + if fastmath is None: + if config.numba__fastmath: + # Opinionated default on fastmath flags + # https://llvm.org/docs/LangRef.html#fast-math-flags + fastmath = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # no-signed zeros + } + else: + fastmath = False + + if final_function: + kwargs.setdefault("cache", True) + else: + kwargs.setdefault("no_cpython_wrapper", True) + kwargs.setdefault("no_cfunc_wrapper", True) + + # Suppress cache warning for internal functions + # We have to add an ansi escape code for optional bold text by numba + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + "Cannot cache compiled function " + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' + "as it uses dynamic globals" + ), + category=NumbaWarning, + ) + + func = _njit if final_function else register_jitable + if len(args) > 0 and callable(args[0]): + return func(*args[1:], fastmath=fastmath, **kwargs)(args[0]) + else: + return func(*args, fastmath=fastmath, **kwargs) + + +def compile_and_cache_numba_function_src( + src: str, + function_name: str, + global_env: dict[Any, Any] | None = None, + local_env: dict[Any, Any] | None = None, + key: str | None = None, +) -> Callable: + # if key is not None: + # filename = NUMBA_CACHE_PATH / key + # with filename.open("wb") as f: + # f.write(src.encode()) + # else: + # with NamedTemporaryFile(delete=False) as f: + # filename = f.name + # f.write(src.encode()) + + if global_env is None: + global_env = {} + + if local_env is None: + local_env = {} + + mod_code = compile(src, "", mode="exec") + exec(mod_code, global_env, local_env) + + res = local_env[function_name] + res.__source__ = src # type: ignore + + if key is not None: + CACHED_SRC_FUNCTIONS[res] = key + return res + + +def get_numba_type( + pytensor_type: Type, + layout: str = "A", + force_scalar: bool = False, + reduce_to_scalar: bool = False, +) -> numba.types.Type: + r"""Create a Numba type object for a :class:`Type`. + + Parameters + ---------- + pytensor_type + The :class:`Type` to convert. + layout + The :class:`numpy.ndarray` layout to use. + force_scalar + Ignore dimension information and return the corresponding Numba scalar types. + reduce_to_scalar + Return Numba scalars for zero dimensional :class:`TensorType`\s. + """ + + if isinstance(pytensor_type, TensorType): + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if force_scalar or ( + reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0 + ): + return numba_dtype + return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) + elif isinstance(pytensor_type, ScalarType): + dtype = np.dtype(pytensor_type.dtype) + numba_dtype = numba.from_dtype(dtype) + return numba_dtype + elif isinstance(pytensor_type, SparseTensorType): + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if pytensor_type.format == "csr": + return CSRMatrixType(numba_dtype) + if pytensor_type.format == "csc": + return CSCMatrixType(numba_dtype) + + raise NotImplementedError() + else: + raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") + + +def create_numba_signature( + node_or_fgraph: FunctionGraph | Apply, + force_scalar: bool = False, + reduce_to_scalar: bool = False, +) -> numba.types.Type: + """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" + input_types = [ + get_numba_type( + inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar + ) + for inp in node_or_fgraph.inputs + ] + + output_types = [ + get_numba_type( + out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar + ) + for out in node_or_fgraph.outputs + ] + + if len(output_types) > 1: + return numba.types.Tuple(output_types)(*input_types) + elif len(output_types) == 1: + return output_types[0](*input_types) + else: + return numba.types.void(*input_types) + + +def create_tuple_creator(f, n): + """Construct a compile-time ``tuple``-comprehension-like loop. + + See https://github.com/numba/numba/issues/2771#issuecomment-414358902 + """ + assert n > 0 + + f = numba_njit(f) + + @numba_njit + def creator(args): + return (f(0, *args),) + + for i in range(1, n): + + @numba_njit + def creator(args, creator=creator, i=i): + return (*creator(args), f(i, *args)) + + return numba_njit(lambda *args: creator(args)) + + +def create_tuple_string(x): + args = ", ".join(x + ([""] if len(x) == 1 else [])) + return f"({args})" + + +def create_arg_string(x): + args = ", ".join(x) + return args diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..1541331d31 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -9,6 +9,7 @@ import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar +import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.slinalg import pytensor.link.numba.dispatch.sparse diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0d822177f2..267f40b608 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,177 +1,39 @@ import operator import sys import warnings -from copy import copy +from collections.abc import Callable from functools import singledispatch -from hashlib import sha256 -from textwrap import dedent import numba -import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -import scipy -import scipy.special from llvmlite import ir -from numba import njit as _njit from numba import types -from numba.core.errors import NumbaWarning, TypingError +from numba.core.errors import TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload, register_jitable +from numba.extending import box from pytensor import In, config from pytensor.compile import NUMBA from pytensor.compile.builders import OpFromGraph from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.ops import DeepCopyOp -from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph -from pytensor.graph.type import Type from pytensor.ifelse import IfElse from pytensor.link.numba.cache import ( + cache_node_key, +) +from pytensor.link.numba.compile import ( compile_and_cache_numba_function_src, + get_numba_type, + numba_njit, ) -from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import fgraph_to_python -from pytensor.scalar.basic import ScalarType -from pytensor.sparse import SparseTensorType +from pytensor.tensor import TensorType from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Solve from pytensor.tensor.sort import ArgSortOp, SortOp -from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import MakeSlice, NoneConst -from pytensor.utils import hash_from_code - - -def global_numba_func(func): - """Use to return global numba functions in numba_funcify_*. - - This allows tests to remove the compilation using mock. - """ - return func - - -def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs): - if fastmath is None: - if config.numba__fastmath: - # Opinionated default on fastmath flags - # https://llvm.org/docs/LangRef.html#fast-math-flags - fastmath = { - "arcp", # Allow Reciprocal - "contract", # Allow floating-point contraction - "afn", # Approximate functions - "reassoc", - "nsz", # no-signed zeros - } - else: - fastmath = False - - if final_function: - kwargs.setdefault("cache", True) - # else: - # kwargs.setdefault("no_cpython_wrapper", True) - # kwargs.setdefault("no_cfunc_wrapper", True) - - # Suppress cache warning for internal functions - # We have to add an ansi escape code for optional bold text by numba - warnings.filterwarnings( - "ignore", - message=( - "(\x1b\\[1m)*" # ansi escape code for bold text - "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' - "as it uses dynamic globals" - ), - category=NumbaWarning, - ) - - func = register_jitable if final_function else _njit - if len(args) > 0 and callable(args[0]): - return func(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - else: - return func(*args, fastmath=fastmath, **kwargs) - - -def numba_vectorize(*args, **kwargs): - if len(args) > 0 and callable(args[0]): - return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0]) - - return numba.vectorize(*args, cache=config.numba__cache, **kwargs) - - -def get_numba_type( - pytensor_type: Type, - layout: str = "A", - force_scalar: bool = False, - reduce_to_scalar: bool = False, -) -> numba.types.Type: - r"""Create a Numba type object for a :class:`Type`. - - Parameters - ---------- - pytensor_type - The :class:`Type` to convert. - layout - The :class:`numpy.ndarray` layout to use. - force_scalar - Ignore dimension information and return the corresponding Numba scalar types. - reduce_to_scalar - Return Numba scalars for zero dimensional :class:`TensorType`\s. - """ - - if isinstance(pytensor_type, TensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) - if force_scalar or ( - reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0 - ): - return numba_dtype - return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) - elif isinstance(pytensor_type, ScalarType): - dtype = np.dtype(pytensor_type.dtype) - numba_dtype = numba.from_dtype(dtype) - return numba_dtype - elif isinstance(pytensor_type, SparseTensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) - if pytensor_type.format == "csr": - return CSRMatrixType(numba_dtype) - if pytensor_type.format == "csc": - return CSCMatrixType(numba_dtype) - - raise NotImplementedError() - else: - raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") - - -def create_numba_signature( - node_or_fgraph: FunctionGraph | Apply, - force_scalar: bool = False, - reduce_to_scalar: bool = False, -) -> numba.types.Type: - """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" - input_types = [ - get_numba_type( - inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - for inp in node_or_fgraph.inputs - ] - - output_types = [ - get_numba_type( - out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - for out in node_or_fgraph.outputs - ] - - if len(output_types) > 1: - return numba.types.Tuple(output_types)(*input_types) - elif len(output_types) == 1: - return output_types[0](*input_types) - else: - return numba.types.void(*input_types) +from pytensor.tensor.type_other import MakeSlice def slice_new(self, start, stop, step): @@ -251,36 +113,53 @@ def impl_to_scalar(x): raise TypingError(f"{x} must be a scalar compatible type.") -def create_tuple_creator(f, n): - """Construct a compile-time ``tuple``-comprehension-like loop. +@numba.extending.intrinsic +def direct_cast(typingctx, val, typ): + if isinstance(typ, numba.types.TypeRef): + casted = typ.instance_type + elif isinstance(typ, numba.types.DTypeSpec): + casted = typ.dtype + else: + casted = typ - See https://github.com/numba/numba/issues/2771#issuecomment-414358902 - """ - assert n > 0 + sig = casted(casted, typ) - f = numba_njit(f) + def codegen(context, builder, signature, args): + val, _ = args + context.nrt.incref(builder, signature.return_type, val) + return val - @numba_njit - def creator(args): - return (f(0, *args),) + return sig, codegen - for i in range(1, n): - @numba_njit - def creator(args, creator=creator, i=i): - return (*creator(args), f(i, *args)) +def int_to_float_fn(inputs, out_dtype): + """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" + + if ( + all(inp.type.dtype == out_dtype for inp in inputs) + and np.dtype(out_dtype).kind == "f" + ): - return numba_njit(lambda *args: creator(args)) + @numba_njit(inline="always") + def inputs_cast(x): + return x + elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): + args_dtype = np.dtype(f"f{out_dtype.itemsize}") -def create_tuple_string(x): - args = ", ".join(x + ([""] if len(x) == 1 else [])) - return f"({args})" + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + else: + args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) + args_dtype = np.dtype(f"f{args_dtype_sz}") + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) -def create_arg_string(x): - args = ", ".join(x) - return args + return inputs_cast @singledispatch @@ -348,6 +227,58 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): return generate_fallback_impl(op, node, storage_map, **kwargs) +def numba_funcify_njit(op, node, **kwargs): + jitable_func_and_key = numba_funcify(op, node=node, **kwargs) + + match jitable_func_and_key: + case Callable(): + jitable_func = jitable_func_and_key + key = cache_node_key(node) + case (Callable(), str() | int()): + jitable_func, funcify_key = jitable_func_and_key + key = cache_node_key(node, funcify_key) + case (Callable(), None): + # We were explicitly told by the dispatch not to try and cache this function + jitable_func, key = jitable_func_and_key + case _: + raise TypeError( + f"numpy_funcify should return a callable or a callable, key pair, got {jitable_func_and_key}" + ) + + if key is not None: + # To force numba to use our cache, we must compile the function so that any closure + # becomes a global variable... + op_name = op.__class__.__name__ + cached_func = compile_and_cache_numba_function_src( + src=f"def {op_name}(*args): return jitable_func(*args)", + function_name=op_name, + global_env=globals() | dict(jitable_func=jitable_func), + key=key, + ) + return numba_njit(cached_func, final_function=True, cache=True) + else: + return numba_njit( + lambda *args: jitable_func(*args), final_function=True, cache=False + ) + + +@numba_funcify.register(FunctionGraph) +def numba_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="numba_funcified_fgraph", + jit_nodes: bool = False, + **kwargs, +): + return fgraph_to_python( + fgraph, + op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify, + type_conversion_fn=numba_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + @numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) @@ -377,80 +308,25 @@ def opfromgraph(*inputs): def opfromgraph(*inputs): return fgraph_fn(*inputs) - return opfromgraph - - -@numba_funcify.register(FunctionGraph) -def numba_funcify_FunctionGraph( - fgraph, - node=None, - fgraph_name="numba_funcified_fgraph", - jit_nodes: bool = False, - **kwargs, -): - def numba_funcify_njit(op, node, **kwargs): - jitable_func_and_key = numba_funcify(op, node=node, **kwargs) - from collections.abc import Callable - - match jitable_func_and_key: - case (Callable(), str()): - jitable_func, key = jitable_func_and_key - case (Callable(), int()): - # Default key for Ops that return an integer - jitable_func, int_key = jitable_func_and_key - key = sha256( - str((type(op), op._props_dict(), int_key)).encode() - ).hexdigest() - case Callable(): - jitable_func, key = jitable_func_and_key, None - warnings.warn( - f"No cache key returned by numba_funcify of op {op}. This function won't be cached by Numba" - ) - case _: - raise TypeError( - f"numpy_funcify should return a callable or a callable, key pair, got {jitable_func_and_key}" - ) - - if 0 and key is not None: - # To force numba to use our cache, we must compile the function so that any closure - # becomes a global variable... - op_name = op.__class__.__name__ - cached_func = compile_and_cache_numba_function_src( - src=f"def {op_name}(*args): return jitable_func(*args)", - function_name=op_name, - global_env=globals() | dict(jitable_func=jitable_func), - key=key, - ) - return numba_njit(cached_func, final_function=True, cache=True) - else: - return numba_njit( - lambda *args: jitable_func(*args), final_function=True, cache=False - ) - - return fgraph_to_python( - fgraph, - op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify, - type_conversion_fn=numba_typify, - fgraph_name=fgraph_name, - **kwargs, - ) - + # We can't cache this correctly until we can define a key for it + return opfromgraph, None -def deepcopyop(x): - return copy(x) +@numba_funcify.register(DeepCopyOp) +def numba_funcify_DeepCopyOp(op, node, **kwargs): + if isinstance(node.inputs[0].type, TensorType): -@overload(deepcopyop) -def dispatch_deepcopyop(x): - if isinstance(x, types.Array): - return lambda x: np.copy(x) + @numba_njit + def deepcopy_fn(x): + return np.copy(x) - return lambda x: x + else: + @numba_njit + def deepcopy_fn(x): + return x -@numba_funcify.register(DeepCopyOp) -def numba_funcify_DeepCopyOp(op, node, **kwargs): - return deepcopyop, 0 + return deepcopy_fn @numba_funcify.register(MakeSlice) @@ -462,26 +338,6 @@ def makeslice(*x): return makeslice -@numba_funcify.register(Shape) -def numba_funcify_Shape(op, **kwargs): - @numba_njit - def shape(x): - return np.asarray(np.shape(x)) - - return shape - - -@numba_funcify.register(Shape_i) -def numba_funcify_Shape_i(op, **kwargs): - i = op.i - - @numba_njit - def shape_i(x): - return np.asarray(np.shape(x)[i]) - - return shape_i, 0 - - @numba_funcify.register(SortOp) def numba_funcify_SortOp(op, node, **kwargs): @numba_njit @@ -544,107 +400,6 @@ def argort_vec(X, axis): return argsort_f_kind(kind) -@numba.extending.intrinsic -def direct_cast(typingctx, val, typ): - if isinstance(typ, numba.types.TypeRef): - casted = typ.instance_type - elif isinstance(typ, numba.types.DTypeSpec): - casted = typ.dtype - else: - casted = typ - - sig = casted(casted, typ) - - def codegen(context, builder, signature, args): - val, _ = args - context.nrt.incref(builder, signature.return_type, val) - return val - - return sig, codegen - - -@numba_funcify.register(Reshape) -def numba_funcify_Reshape(op, **kwargs): - ndim = op.ndim - - if ndim == 0: - - @numba_njit - def reshape(x, shape): - return np.asarray(x.item()) - - else: - - @numba_njit - def reshape(x, shape): - # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. - return np.reshape( - np.ascontiguousarray(np.asarray(x)), - numba_ndarray.to_fixed_tuple(shape, ndim), - ) - - return reshape, 0 - - -@numba_funcify.register(SpecifyShape) -def numba_funcify_SpecifyShape(op, node, **kwargs): - shape_inputs = node.inputs[1:] - shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] - - func_conditions = [ - f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" - for i, (node_dim_input, eval_dim_name) in enumerate( - zip(shape_inputs, shape_input_names, strict=True) - ) - if node_dim_input is not NoneConst - ] - - func = dedent( - f""" - def specify_shape(x, {create_arg_string(shape_input_names)}): - {"; ".join(func_conditions)} - return x - """ - ) - - specify_shape = compile_and_cache_numba_function_src( - func, - "specify_shape", - globals(), - ) - return numba_njit(specify_shape), hash_from_code(func) - - -def int_to_float_fn(inputs, out_dtype): - """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" - - if ( - all(inp.type.dtype == out_dtype for inp in inputs) - and np.dtype(out_dtype).kind == "f" - ): - - @numba_njit(inline="always") - def inputs_cast(x): - return x - - elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): - args_dtype = np.dtype(f"f{out_dtype.itemsize}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - else: - args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) - args_dtype = np.dtype(f"f{args_dtype_sz}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - return inputs_cast - - @numba_funcify.register(Dot) def numba_funcify_Dot(op, node, **kwargs): # Numba's `np.dot` does not support integer dtypes, so we need to cast to float. @@ -692,51 +447,6 @@ def dot_with_cast(x, y): return dot_with_cast -@numba_funcify.register(Solve) -def numba_funcify_Solve(op, node, **kwargs): - assume_a = op.assume_a - # check_finite = op.check_finite - - if assume_a != "gen": - lower = op.lower - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`compute_uv` argument to `numpy.linalg.svd`." - ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def solve(a, b): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.solve_triangular( - a, - b, - lower=lower, - # check_finite=check_finite - ) - return ret - - else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def solve(a, b): - return np.linalg.solve( - inputs_cast(a), - inputs_cast(b), - # assume_a=assume_a, - # check_finite=check_finite, - ).astype(out_dtype) - - return solve - - @numba_funcify.register(BatchedDot) def numba_funcify_BatchedDot(op, node, **kwargs): dtype = node.outputs[0].type.numpy_dtype diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index 45df8341ea..cfbf5e89e7 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -4,7 +4,8 @@ from numba.core.extending import overload from numba.np.unsafe.ndarray import to_fixed_tuple -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.link.numba.compile import numba_njit +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.vectorize_codegen import ( _jit_options, _vectorized, diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 4a9140cc41..4ed8636979 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -7,12 +7,12 @@ from numba.core.extending import overload from numpy.lib.stride_tricks import as_strided +import pytensor.link.numba.compile from pytensor.graph.op import Op -from pytensor.link.numba.cache import compile_and_cache_numba_function_src +from pytensor.link.numba.compile import compile_and_cache_numba_function_src, numba_njit from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify, - numba_njit, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( _vectorized, @@ -249,7 +249,7 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): reaxis_first = (*(i for i in range(ndim) if i != axis), axis) - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def axis_apply_fn(x): x_reaxis = x.transpose(reaxis_first) @@ -424,7 +424,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs): if new_order == (): # Special case needed because of https://github.com/numba/numba/issues/9933 - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def squeeze_to_0d(x): return as_strided(x, shape=(), strides=()) @@ -432,7 +432,7 @@ def squeeze_to_0d(x): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def dimshuffle(x): old_shape = x.shape old_strides = x.strides @@ -448,7 +448,7 @@ def dimshuffle(x): return as_strided(x, shape=new_shape, strides=new_strides) - return dimshuffle, 0 + return dimshuffle @numba_funcify.register(Softmax) @@ -467,7 +467,7 @@ def numba_funcify_Softmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -499,7 +499,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_sum = jit_fn(reduce_sum_py) else: reduce_sum = np.sum @@ -536,7 +536,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -562,7 +562,7 @@ def numba_funcify_Argmax(op, node, **kwargs): if x_ndim == 0: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def argmax(x): return np.array(0, dtype="int64") @@ -582,7 +582,7 @@ def argmax(x): sl1 = slice(None, len(keep_axes)) sl2 = slice(len(keep_axes), None) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def argmax(x): # Not-reduced axes in front transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order)) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 0d61158061..74f7b27751 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -4,9 +4,11 @@ import numba import numpy as np +import pytensor.link.numba.compile from pytensor.graph import Apply +from pytensor.link.numba.compile import get_numba_type from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.raise_op import CheckAndRaise from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import ( @@ -24,7 +26,7 @@ @numba_funcify.register(Bartlett) def numba_funcify_Bartlett(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def bartlett(x): return np.bartlett(numba_basic.to_scalar(x)) @@ -49,13 +51,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -73,13 +75,13 @@ def cumop(x): else: if axis is None or ndim == 1: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -99,7 +101,7 @@ def cumop(x): @numba_funcify.register(FillDiagonal) def numba_funcify_FillDiagonal(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def filldiagonal(a, val): np.fill_diagonal(a, val) return a @@ -109,7 +111,7 @@ def filldiagonal(a, val): @numba_funcify.register(FillDiagonalOffset) def numba_funcify_FillDiagonalOffset(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def filldiagonaloffset(a, val, offset): height, width = a.shape @@ -144,25 +146,25 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs): if mode == "raise": - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def mode_fn(*args): raise ValueError("invalid entry in coordinates array") elif mode == "wrap": - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def mode_fn(new_arr, i, j, v, d): new_arr[i, j] = v % d elif mode == "clip": - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def mode_fn(new_arr, i, j, v, d): new_arr[i, j] = min(max(v, 0), d - 1) if node.inputs[0].ndim == 0: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def ravelmultiindex(*inp): shape = inp[-1] arr = np.stack(inp[:-1]) @@ -178,7 +180,7 @@ def ravelmultiindex(*inp): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def ravelmultiindex(*inp): shape = inp[-1] arr = np.stack(inp[:-1]) @@ -217,7 +219,7 @@ def numba_funcify_Repeat(op, node, **kwargs): ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): with numba.objmode(ret=ret_sig): ret = np.repeat(x, repeats, axis) @@ -228,13 +230,13 @@ def repeatop(x, repeats): if repeats_ndim == 0: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats.item()) else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats) @@ -259,7 +261,7 @@ def numba_funcify_Unique(op, node, **kwargs): if not use_python: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def unique(x): return np.unique(x) @@ -277,7 +279,7 @@ def unique(x): else: ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def unique(x): with numba.objmode(ret=ret_sig): ret = np.unique(x, return_index, return_inverse, return_counts, axis) @@ -297,17 +299,17 @@ def numba_funcify_UnravelIndex(op, node, **kwargs): if len(node.outputs) == 1: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def maybe_expand_dim(arr): return arr else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def maybe_expand_dim(arr): return np.expand_dims(arr, 1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def unravelindex(arr, shape): a = np.ones(len(shape), dtype=np.int64) a[1:] = shape[:0:-1] @@ -340,7 +342,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def searchsorted(a, v, sorter): with numba.objmode(ret=ret_sig): ret = np.searchsorted(a, v, side, sorter) @@ -348,7 +350,7 @@ def searchsorted(a, v, sorter): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def searchsorted(a, v): return np.searchsorted(a, v, side) @@ -360,11 +362,11 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs): error = op.exc_type msg = op.msg - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def check_and_raise(x, *conditions): for cond in conditions: if not cond: raise error(msg) return x - return check_and_raise, 0 + return check_and_raise diff --git a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py index 9575dd7d56..82bf2f009d 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py +++ b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py @@ -6,8 +6,8 @@ from numpy import ndarray from scipy import linalg +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.link.numba.dispatch.linalg._LAPACK import ( _LAPACK, _get_underlying_float, diff --git a/pytensor/link/numba/dispatch/linalg/solve/utils.py b/pytensor/link/numba/dispatch/linalg/solve/utils.py index ec6c4ef213..5eedd3ecba 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/utils.py +++ b/pytensor/link/numba/dispatch/linalg/solve/utils.py @@ -1,9 +1,9 @@ from scipy import linalg -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _solve_check_input_shapes(A, B): if A.shape[0] != B.shape[0]: raise linalg.LinAlgError("Dimensions of A and B do not conform") diff --git a/pytensor/link/numba/dispatch/linalg/utils.py b/pytensor/link/numba/dispatch/linalg/utils.py index b15888abd6..568fa34235 100644 --- a/pytensor/link/numba/dispatch/linalg/utils.py +++ b/pytensor/link/numba/dispatch/linalg/utils.py @@ -6,7 +6,7 @@ from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numpy.linalg import LinAlgError -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile from pytensor.link.numba.dispatch.linalg._LAPACK import ( _LAPACK, _get_underlying_float, @@ -14,13 +14,13 @@ ) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _copy_to_fortran_order_even_if_1d(x): # Numba's _copy_to_fortran_order doesn't do anything for vectors return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _trans_char_to_int(trans): if trans not in [0, 1, 2]: raise ValueError('Parameter "trans" should be one of 0, 1, 2') @@ -53,7 +53,7 @@ def _check_scipy_linalg_matrix(a, func_name): raise numba.TypingError(msg, highlighting=False) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _solve_check(n, info, lamch=False, rcond=None): """ Check arguments during the different steps of the solution phase diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 58fe0e3719..c85fdbd3e8 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -3,9 +3,9 @@ import numba import numpy as np -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile +from pytensor.link.numba.compile import get_numba_type from pytensor.link.numba.dispatch.basic import ( - get_numba_type, int_to_float_fn, numba_funcify, ) @@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs): if not compute_uv: - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def svd(x): _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices) return ret else: - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def svd(x): return np.linalg.svd(inputs_cast(x), full_matrices) @@ -49,7 +49,7 @@ def numba_funcify_Det(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def det(x): return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype) @@ -63,7 +63,7 @@ def numba_funcify_SLogDet(op, node, **kwargs): inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def slogdet(x): sign, det = np.linalg.slogdet(inputs_cast(x)) return ( @@ -81,7 +81,7 @@ def numba_funcify_Eig(op, node, **kwargs): inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def eig(x): out = np.linalg.eig(inputs_cast(x)) return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2)) @@ -107,7 +107,7 @@ def numba_funcify_Eigh(op, node, **kwargs): [get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)] ) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def eigh(x): with numba.objmode(ret=ret_sig): out = np.linalg.eigh(x, UPLO=uplo) @@ -116,7 +116,7 @@ def eigh(x): else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def eigh(x): return np.linalg.eigh(x) @@ -128,7 +128,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def matrix_inverse(x): return np.linalg.inv(inputs_cast(x)).astype(out_dtype) @@ -140,7 +140,7 @@ def numba_funcify_MatrixPinv(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def matrixpinv(x): return np.linalg.pinv(inputs_cast(x)).astype(out_dtype) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 36618ceb26..a700de5689 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -9,10 +9,11 @@ from numba import types from numba.core.extending import overload +import pytensor.link.numba.compile import pytensor.tensor.random.basic as ptr from pytensor.graph import Apply from pytensor.graph.op import Op -from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify from pytensor.link.numba.dispatch.vectorize_codegen import ( _jit_options, @@ -84,14 +85,14 @@ def {name}(rng, {input_signature}): """) func = compile_function_src(func_src, name, {**globals()}) - return numba_basic.numba_njit(func) + return numba_njit(func) @numba_core_rv_funcify.register(ptr.BernoulliRV) def numba_core_BernoulliRV(op, node): out_dtype = node.outputs[1].type.numpy_dtype - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def random(rng, p): return ( direct_cast(0, out_dtype) @@ -104,7 +105,7 @@ def random(rng, p): @numba_core_rv_funcify.register(ptr.StudentTRV) def numba_core_StudentTRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, df, loc, scale): return loc + scale * rng.standard_t(df) @@ -113,7 +114,7 @@ def random_fn(rng, df, loc, scale): @numba_core_rv_funcify.register(ptr.HalfNormalRV) def numba_core_HalfNormalRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, loc, scale): return loc + scale * np.abs(rng.standard_normal()) @@ -122,7 +123,7 @@ def random_fn(rng, loc, scale): @numba_core_rv_funcify.register(ptr.CauchyRV) def numba_core_CauchyRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, loc, scale): return (loc + rng.standard_cauchy()) / scale @@ -131,7 +132,7 @@ def random(rng, loc, scale): @numba_core_rv_funcify.register(ptr.ParetoRV) def numba_core_ParetoRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, b, scale): # Follows scipy implementation U = rng.random() @@ -142,7 +143,7 @@ def random(rng, b, scale): @numba_core_rv_funcify.register(ptr.InvGammaRV) def numba_core_InvGammaRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, shape, scale): return 1 / rng.gamma(shape, 1 / scale) @@ -151,7 +152,7 @@ def random(rng, shape, scale): @numba_core_rv_funcify.register(ptr.CategoricalRV) def core_CategoricalRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, p): unif_sample = rng.uniform(0, 1) return np.searchsorted(np.cumsum(p), unif_sample) @@ -163,7 +164,7 @@ def random_fn(rng, p): def core_MultinomialRV(op, node): dtype = op.dtype - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, n, p): n_cat = p.shape[0] draws = np.zeros(n_cat, dtype=dtype) @@ -186,7 +187,7 @@ def random_fn(rng, n, p): def core_MvNormalRV(op, node): method = op.method - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, mean, cov): if method == "cholesky": A = np.linalg.cholesky(cov) @@ -209,7 +210,7 @@ def random_fn(rng, mean, cov): @numba_core_rv_funcify.register(ptr.DirichletRV) def core_DirichletRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, alpha): y = np.empty_like(alpha) for i in range(len(alpha)): @@ -226,7 +227,7 @@ def core_GumbelRV(op, node): https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511 """ - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, loc, scale): U = 1.0 - rng.random() if U < 1.0: @@ -244,7 +245,7 @@ def core_VonMisesRV(op, node): https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925 """ - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, mu, kappa): if np.isnan(kappa): return np.nan @@ -310,7 +311,7 @@ def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node): if op.has_p_param: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, a, p, core_shape): # Adapted from Numpy: https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L922-L941 size = np.prod(core_shape) @@ -363,7 +364,7 @@ def random_fn(rng, a, p, core_shape): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, a, core_shape): # Until Numba supports generator.choice we use a poor implementation # that permutates the whole arange array and takes the first `size` elements @@ -447,4 +448,8 @@ def random(core_shape, rng, size, *dist_params): def ov_random(core_shape, rng, size, *dist_params): return random_wrapper - return random + return random, str( + type( + rv_op, + ) + ) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index ee2960a1c4..e021f1c7b9 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,12 +2,15 @@ import numpy as np +import pytensor.link.numba.compile from pytensor.compile.ops import TypeCastingOp from pytensor.graph.basic import Variable -from pytensor.link.numba.cache import compile_and_cache_numba_function_src +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_numba_signature, +) from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( - create_numba_signature, generate_fallback_impl, numba_funcify, ) @@ -136,12 +139,12 @@ def {scalar_op_fn_name}({', '.join(input_names)}): # signature = create_numba_signature(node, force_scalar=True) - return numba_basic.numba_njit(scalar_op_fn) + return pytensor.link.numba.compile.numba_njit(scalar_op_fn) @numba_funcify.register(Switch) def numba_funcify_Switch(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def switch(condition, x, y): if condition: return x @@ -172,7 +175,7 @@ def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(signature)(nary_add_fn) + return pytensor.link.numba.compile.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Mul) @@ -180,14 +183,14 @@ def numba_funcify_Mul(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(signature)(nary_add_fn) + return pytensor.link.numba.compile.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Cast) def numba_funcify_Cast(op, node, **kwargs): dtype = np.dtype(op.o_type.dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cast(x): return numba_basic.direct_cast(x, dtype) @@ -197,7 +200,7 @@ def cast(x): @numba_funcify.register(Identity) @numba_funcify.register(TypeCastingOp) def numba_funcify_type_casting(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def identity(x): return x @@ -206,7 +209,7 @@ def identity(x): @numba_funcify.register(Clip) def numba_funcify_Clip(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def clip(x, min_val, max_val): x = numba_basic.to_scalar(x) min_scalar = numba_basic.to_scalar(min_val) @@ -228,7 +231,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) - composite_fn = numba_basic.numba_njit(signature)( + composite_fn = pytensor.link.numba.compile.numba_njit(signature)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) return composite_fn @@ -236,7 +239,7 @@ def numba_funcify_Composite(op, node, **kwargs): @numba_funcify.register(Second) def numba_funcify_Second(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def second(x, y): return y @@ -245,7 +248,7 @@ def second(x, y): @numba_funcify.register(Reciprocal) def numba_funcify_Reciprocal(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def reciprocal(x): # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when # `x` is an `int` @@ -256,7 +259,7 @@ def reciprocal(x): @numba_funcify.register(Sigmoid) def numba_funcify_Sigmoid(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -265,7 +268,7 @@ def sigmoid(x): @numba_funcify.register(GammaLn) def numba_funcify_GammaLn(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def gammaln(x): return math.lgamma(x) @@ -274,7 +277,7 @@ def gammaln(x): @numba_funcify.register(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -286,7 +289,7 @@ def logp1mexp(x): @numba_funcify.register(Erf) def numba_funcify_Erf(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def erf(x): return math.erf(x) @@ -295,7 +298,7 @@ def erf(x): @numba_funcify.register(Erfc) def numba_funcify_Erfc(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def erfc(x): return math.erfc(x) @@ -306,7 +309,7 @@ def erfc(x): def numba_funcify_Softplus(op, node, **kwargs): out_dtype = np.dtype(node.outputs[0].type.dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def softplus(x): if x < -37.0: value = np.exp(x) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c75a4cf890..c73bb8cce5 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -4,13 +4,13 @@ from numba import types from numba.extending import overload +import pytensor.link.numba.compile from pytensor import In from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.mode import NUMBA, get_mode +from pytensor.link.numba.compile import create_arg_string, create_tuple_string from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( - create_arg_string, - create_tuple_string, numba_funcify, ) from pytensor.link.utils import compile_function_src @@ -97,7 +97,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ) rewriter(fgraph) - scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) + scan_inner_func = pytensor.link.numba.compile.numba_njit(numba_funcify(op.fgraph)) outer_in_names_to_vars = { (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) @@ -442,4 +442,4 @@ def scan({", ".join(outer_in_names)}): scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) - return numba_basic.numba_njit(scan_op_fn, boundscheck=False) + return pytensor.link.numba.compile.numba_njit(scan_op_fn, boundscheck=False) diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py new file mode 100644 index 0000000000..44ad42da75 --- /dev/null +++ b/pytensor/link/numba/dispatch/shape.py @@ -0,0 +1,85 @@ +from textwrap import dedent + +import numpy as np +from numba.np.unsafe.ndarray import to_fixed_tuple + +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_arg_string, + numba_njit, +) +from pytensor.link.numba.dispatch.basic import numba_funcify +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.type_other import NoneTypeT + + +@numba_funcify.register(Shape) +def numba_funcify_Shape(op, **kwargs): + @numba_njit + def shape(x): + return np.asarray(np.shape(x)) + + return shape + + +@numba_funcify.register(Shape_i) +def numba_funcify_Shape_i(op, **kwargs): + i = op.i + + @numba_njit + def shape_i(x): + return np.asarray(np.shape(x)[i]) + + return shape_i + + +@numba_funcify.register(Reshape) +def numba_funcify_Reshape(op, **kwargs): + ndim = op.ndim + + if ndim == 0: + + @numba_njit + def reshape(x, shape): + return np.asarray(x.item()) + + else: + + @numba_njit + def reshape(x, shape): + # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + return np.reshape( + np.ascontiguousarray(np.asarray(x)), + to_fixed_tuple(shape, ndim), + ) + + return reshape + + +@numba_funcify.register(SpecifyShape) +def numba_funcify_SpecifyShape(op, node, **kwargs): + shape_inputs = node.inputs[1:] + shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] + + func_conditions = [ + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( + zip(shape_inputs, shape_input_names, strict=True) + ) + if not isinstance(node_dim_input.type, NoneTypeT) + ] + + func = dedent( + f""" + def specify_shape(x, {create_arg_string(shape_input_names)}): + {"; ".join(func_conditions)} + return x + """ + ) + + specify_shape = compile_and_cache_numba_function_src( + func, + "specify_shape", + globals(), + ) + return numba_njit(specify_shape) diff --git a/pytensor/link/numba/dispatch/signal/conv.py b/pytensor/link/numba/dispatch/signal/conv.py index 15d1bb29b1..7bb809093b 100644 --- a/pytensor/link/numba/dispatch/signal/conv.py +++ b/pytensor/link/numba/dispatch/signal/conv.py @@ -1,8 +1,8 @@ import numpy as np from numba.np.arraymath import _get_inner_prod +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.tensor.signal.conv import Convolve1d diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 5578a8379c..73b7a7221c 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -3,7 +3,8 @@ import numpy as np from pytensor import config -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.link.numba.compile import numba_njit +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( _lu_1, diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index ce7f8fc3a1..82845ec713 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,9 +1,9 @@ import numpy as np from pytensor.graph import Type -from pytensor.link.numba.cache import compile_and_cache_numba_function_src +from pytensor.link.numba.compile import compile_and_cache_numba_function_src, numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit +from pytensor.link.numba.dispatch.basic import generate_fallback_impl from pytensor.link.utils import unique_name_generator from pytensor.tensor import TensorType from pytensor.tensor.rewriting.subtensor import is_full_slice @@ -349,7 +349,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): return x if inplace: - return advancedincsubtensor1_inplace, 0 + return advancedincsubtensor1_inplace else: @@ -358,4 +358,4 @@ def advancedincsubtensor1(x, vals, idxs): x = x.copy() return advancedincsubtensor1_inplace(x, vals, idxs) - return advancedincsubtensor1, 0 + return advancedincsubtensor1 diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 91ca7ab15e..b95670024e 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -2,9 +2,13 @@ import numpy as np -from pytensor.link.numba.cache import compile_and_cache_numba_function_src +import pytensor.link.numba.compile +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_tuple_string, +) from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.utils import unique_name_generator from pytensor.tensor.basic import ( Alloc, @@ -58,7 +62,7 @@ def allocempty({", ".join(shape_var_names)}): ) return ( - numba_basic.numba_njit(alloc_fn), + pytensor.link.numba.compile.numba_njit(alloc_fn), hash_from_code(alloc_def_src), ) @@ -107,7 +111,7 @@ def alloc(val, {", ".join(shape_var_names)}): ) return ( - numba_basic.numba_njit(alloc_fn), + pytensor.link.numba.compile.numba_njit(alloc_fn), hash_from_code(alloc_def_src), ) @@ -116,7 +120,7 @@ def alloc(val, {", ".join(shape_var_names)}): def numba_funcify_ARange(op, **kwargs): dtype = np.dtype(op.dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def arange(start, stop, step): return np.arange( start.item(), @@ -130,20 +134,20 @@ def arange(start, stop, step): @numba_funcify.register(Join) def numba_funcify_Join(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def join(axis, *tensors): return np.concatenate(tensors, axis.item()) - return join, 0 + return join @numba_funcify.register(Split) def numba_funcify_Split(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def split(tensor, axis, indices): return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item()) - return split, 0 + return split @numba_funcify.register(ExtractDiag) @@ -153,7 +157,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs): if node.inputs[0].type.ndim == 2: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def extract_diag(x): out = np.diag(x, k=offset) @@ -168,7 +172,7 @@ def extract_diag(x): leading_dims = (slice(None),) * axis1 middle_dims = (slice(None),) * (axis2 - axis1 - 1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def extract_diag(x): if offset >= 0: diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset)) @@ -193,7 +197,7 @@ def extract_diag(x): def numba_funcify_Eye(op, **kwargs): dtype = np.dtype(op.dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def eye(N, M, k): return np.eye( numba_basic.to_scalar(N), @@ -231,21 +235,23 @@ def makevector({", ".join(input_names)}): "makevector", {**globals(), **global_env}, ) - return numba_basic.numba_njit(makevector_fn), hash_from_code(makevector_def_src) + return pytensor.link.numba.compile.numba_njit(makevector_fn), hash_from_code( + makevector_def_src + ) @numba_funcify.register(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def tensor_from_scalar(x): return np.array(x) - return tensor_from_scalar, 0 + return tensor_from_scalar @numba_funcify.register(ScalarFromTensor) def numba_funcify_ScalarFromTensor(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def scalar_from_tensor(x): return x.item() diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 332a165539..f2fe02eff8 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -15,8 +15,8 @@ from numba.core.types.misc import NoneType from numba.np import arrayobj -from pytensor.link.numba.cache import compile_and_cache_numba_function_src -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile +from pytensor.link.numba.compile import compile_and_cache_numba_function_src def encode_literals(literals: Sequence) -> str: @@ -58,7 +58,7 @@ def store_core_outputs({inp_signature}, {out_signature}): "store_core_outputs", {**globals(), **global_env}, ) - return numba_basic.numba_njit(func) + return pytensor.link.numba.compile.numba_njit(func) _jit_options = { diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 5ca598e472..2738fd60da 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -17,7 +17,7 @@ def jit_compile(self, fn): if self.vm: return fn else: - from pytensor.link.numba.dispatch.basic import numba_njit + from pytensor.link.numba.compile import numba_njit jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) return jitted_fn diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 4b5a44b053..da058c60e9 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -9,6 +9,7 @@ import scipy import pytensor.link.numba.cache +import pytensor.link.numba.compile from pytensor.compile import SymbolicInput @@ -27,7 +28,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.type import Type from pytensor.ifelse import ifelse -from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar @@ -325,7 +325,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented): else pytest.raises(NotImplementedError) ) with cm: - res = numba_basic.get_numba_type(v, force_scalar=force_scalar) + res = pytensor.link.numba.compile.get_numba_type(v, force_scalar=force_scalar) assert res == expected @@ -367,7 +367,9 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented): ], ) def test_create_numba_signature(v, expected, force_scalar): - res = numba_basic.create_numba_signature(v, force_scalar=force_scalar) + res = pytensor.link.numba.compile.create_numba_signature( + v, force_scalar=force_scalar + ) assert res == expected diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 625246e340..741021dff5 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -280,7 +280,7 @@ def test_ExtractDiag(val, offset): ) @pytest.mark.parametrize("reverse_axis", (False, True)) def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis): - from pytensor.link.numba.dispatch.basic import numba_njit + from pytensor.link.numba.compile import numba_njit if reverse_axis: axis1, axis2 = axis2, axis1 From 5798bf6e580cbfe04dd256710565e06936a72e27 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 1 Oct 2025 19:16:57 +0200 Subject: [PATCH 08/12] New bench function --- tests/link/numba/test_compile.py | 102 +++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/link/numba/test_compile.py diff --git a/tests/link/numba/test_compile.py b/tests/link/numba/test_compile.py new file mode 100644 index 0000000000..aa1d282fc7 --- /dev/null +++ b/tests/link/numba/test_compile.py @@ -0,0 +1,102 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor import function +from pytensor.graph import rewrite_graph +from pytensor.graph.traversal import explicit_graph_inputs + + +def test_radon_model_logp_dlogp(): + def halfnormal(name, *, sigma=1.0, model_logp): + log_value = pt.scalar(f"{name}_log") + value = pt.exp(log_value) + + logp = ( + -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) + ) + logp = pt.switch(value >= 0, logp, -np.inf) + model_logp.append(logp + value) + return value + + def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): + value = pt.scalar(name) if observed is None else pt.as_tensor(observed) + + logp = ( + -0.5 * (((value - mu) / sigma) ** 2) + - pt.log(pt.sqrt(2.0 * np.pi)) + - pt.log(sigma) + ) + model_logp.append(logp) + return value + + def zerosumnormal(name, *, sigma=1.0, size, model_logp): + raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) + n = raw_value.shape[0] + 1 + sum_vals = raw_value.sum(0, keepdims=True) + norm = sum_vals / (pt.sqrt(n) + n) + fill_value = norm - sum_vals / pt.sqrt(n) + value = pt.concatenate([raw_value, fill_value]) - norm + + shape = value.shape + _full_size = pt.prod(shape) + _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) + logp = pt.sum( + -0.5 * ((value / sigma) ** 2) + - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) + * (_degrees_of_freedom / _full_size) + ) + model_logp.append(logp) + return value + + rng = np.random.default_rng(1) + n_counties = 85 + county_idx = rng.integers(n_counties, size=919) + county_idx.sort() + floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) + log_radon = rng.normal(size=919) + + # joined_inputs = pt.vector("joined_inputs") + + model_logp = [] + intercept = normal("intercept", sigma=10, model_logp=model_logp) + + # County effects + county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) + county_sd = halfnormal("county_sd", model_logp=model_logp) + county_effect = county_raw * county_sd + + # Global floor effect + floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) + + county_floor_raw = zerosumnormal( + "county_floor_raw", size=n_counties, model_logp=model_logp + ) + county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) + county_floor_effect = county_floor_raw * county_floor_sd + + mu = ( + intercept + + county_effect[county_idx] + + floor_effect * floor + + county_floor_effect[county_idx] * floor + ) + + sigma = halfnormal("sigma", model_logp=model_logp) + _ = normal( + "log_radon", + mu=mu, + sigma=sigma, + observed=log_radon, + model_logp=model_logp, + ) + + model_logp = pt.sum([logp.sum() for logp in model_logp]) + model_logp = rewrite_graph( + model_logp, include=("canonicalize", "stabilize"), clone=False + ) + params = list(explicit_graph_inputs(model_logp)) + model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) + + # TODO: Replace inputs by raveled vector + + function(params, [model_logp, model_dlogp]).dprint() From b57828429833501c5d01067f8a50047a90d39aa6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Oct 2025 13:32:26 +0200 Subject: [PATCH 09/12] Benchmark radon function --- pytensor/compile/mode.py | 17 ++-- tests/compile/function/test_types.py | 140 ++++++++++++++++++++++++++- 2 files changed, 150 insertions(+), 7 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 1c7f9e70a4..2f4eed31fb 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -453,7 +453,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - NumbaLinker(vm=True), + "numba_vm", # TODO: Fast_compile should just use python code, CHANGE ME! RewriteDatabaseQuery( include=["fast_compile", "numba"], @@ -461,15 +461,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) FAST_RUN = Mode( - NumbaLinker(vm=True), + "numba_vm", RewriteDatabaseQuery( include=["fast_run", "numba"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], ), ) +C = Mode("c", "fast_run") +C_VM = Mode("cvm", "fast_run") + NUMBA = Mode( - NumbaLinker(), + "numba", RewriteDatabaseQuery( include=["fast_run", "numba"], exclude=[ @@ -482,12 +485,12 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ) NUMBA_VM = Mode( - NumbaLinker(vm=True), + "numba_vm", NUMBA._optimizer, ) JAX = Mode( - JAXLinker(), + "jax", RewriteDatabaseQuery( include=["fast_run", "jax"], exclude=[ @@ -503,7 +506,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) PYTORCH = Mode( - PytorchLinker(), + "pytorch", RewriteDatabaseQuery( include=["fast_run"], exclude=[ @@ -522,6 +525,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): predefined_modes = { "FAST_COMPILE": FAST_COMPILE, "FAST_RUN": FAST_RUN, + "C": C, + "C_VM": C_VM, "JAX": JAX, "NUMBA": NUMBA, "NUMBA_VM": NUMBA_VM, diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 90589db337..4aaeb97663 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -12,7 +12,9 @@ from pytensor.compile.io import In, Out from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, explicit_graph_inputs +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting import rewrite_graph from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker @@ -1357,3 +1359,139 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): rng_val = np.random.default_rng() benchmark(f, rng_val) + + +@pytest.fixture(scope="module") +def radon_model(): + def halfnormal(name, *, sigma=1.0, model_logp): + log_value = pt.scalar(f"{name}_log") + value = pt.exp(log_value) + + logp = ( + -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) + ) + logp = pt.switch(value >= 0, logp, -np.inf) + model_logp.append(logp + value) + return value + + def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): + value = pt.scalar(name) if observed is None else pt.as_tensor(observed) + + logp = ( + -0.5 * (((value - mu) / sigma) ** 2) + - pt.log(pt.sqrt(2.0 * np.pi)) + - pt.log(sigma) + ) + model_logp.append(logp) + return value + + def zerosumnormal(name, *, sigma=1.0, size, model_logp): + raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) + n = raw_value.shape[0] + 1 + sum_vals = raw_value.sum(0, keepdims=True) + norm = sum_vals / (pt.sqrt(n) + n) + fill_value = norm - sum_vals / pt.sqrt(n) + value = pt.concatenate([raw_value, fill_value]) - norm + + shape = value.shape + _full_size = pt.prod(shape) + _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) + logp = pt.sum( + -0.5 * ((value / sigma) ** 2) + - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) + * (_degrees_of_freedom / _full_size) + ) + model_logp.append(logp) + return value + + rng = np.random.default_rng(1) + n_counties = 85 + county_idx = rng.integers(n_counties, size=919) + county_idx.sort() + floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) + log_radon = rng.normal(size=919) + + model_logp = [] + intercept = normal("intercept", sigma=10, model_logp=model_logp) + + # County effects + county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) + county_sd = halfnormal("county_sd", model_logp=model_logp) + county_effect = county_raw * county_sd + + # Global floor effect + floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) + + county_floor_raw = zerosumnormal( + "county_floor_raw", size=n_counties, model_logp=model_logp + ) + county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) + county_floor_effect = county_floor_raw * county_floor_sd + + mu = ( + intercept + + county_effect[county_idx] + + floor_effect * floor + + county_floor_effect[county_idx] * floor + ) + + sigma = halfnormal("sigma", model_logp=model_logp) + _ = normal( + "log_radon", + mu=mu, + sigma=sigma, + observed=log_radon, + model_logp=model_logp, + ) + + model_logp = pt.sum([logp.sum() for logp in model_logp]) + model_logp = rewrite_graph( + model_logp, include=("canonicalize", "stabilize"), clone=False + ) + params = list(explicit_graph_inputs(model_logp)) + model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) + + size = sum(int(np.prod(p.type.shape)) for p in params) + joined_inputs = pt.vector("joined_inputs", shape=(size,)) + idx = 0 + replacement = {} + for param in params: + param_shape = param.type.shape + param_size = int(np.prod(param_shape)) + replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape) + idx += param_size + assert idx == joined_inputs.type.shape[0] + + model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement) + return joined_inputs, [model_logp, model_dlogp] + + +@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA", "NUMBA_VM"]) +def test_radon_model_compile_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + + def compile_and_call_once(): + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True + ) + fn(x) + + benchmark(compile_and_call_once) + + +@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA", "NUMBA_VM"]) +def test_radon_model_call_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + + real_mode = "C_VM" if mode == "C_VM_NOGC" else mode + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True + ) + if mode == "C_VM_NOGC": + fn.vm.allow_gc = False + + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + benchmark(fn, x) From 5bb8c9b8580846d41c890436d458e304a9542547 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Oct 2025 13:32:38 +0200 Subject: [PATCH 10/12] Fix non-vm NUMBA --- pytensor/link/numba/dispatch/basic.py | 8 +- pytensor/link/numba/dispatch/subtensor.py | 3 +- pytensor/link/numba/dispatch/tensor_basic.py | 5 +- pytensor/link/numba/linker.py | 2 +- tests/link/numba/test_compile.py | 102 ------------------- 5 files changed, 10 insertions(+), 110 deletions(-) delete mode 100644 tests/link/numba/test_compile.py diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 267f40b608..9fe77f7a07 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -270,9 +270,15 @@ def numba_funcify_FunctionGraph( jit_nodes: bool = False, **kwargs, ): + def numba_funcify_wrapper(*args, **kwargs): + result = numba_funcify(*args, **kwargs) + if isinstance(result, tuple): + return result[0] + return result + return fgraph_to_python( fgraph, - op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify, + op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify_wrapper, type_conversion_fn=numba_typify, fgraph_name=fgraph_name, **kwargs, diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 82845ec713..8f30597887 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -16,7 +16,6 @@ Subtensor, ) from pytensor.tensor.type_other import NoneTypeT, SliceType -from pytensor.utils import hash_from_code @numba_funcify.register(Subtensor) @@ -102,7 +101,7 @@ def {function_name}({", ".join(input_names)}): function_name=function_name, global_env=globals() | {"np": np}, ) - return numba_njit(func, boundscheck=True), hash_from_code(subtensor_def_src) + return numba_njit(func, boundscheck=True) @numba_funcify.register(AdvancedSubtensor) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index b95670024e..531b695c23 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -110,10 +110,7 @@ def alloc(val, {", ".join(shape_var_names)}): {**globals(), **global_env}, ) - return ( - pytensor.link.numba.compile.numba_njit(alloc_fn), - hash_from_code(alloc_def_src), - ) + return pytensor.link.numba.compile.numba_njit(alloc_fn) @numba_funcify.register(ARange) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 2738fd60da..d8e8eb332a 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -19,7 +19,7 @@ def jit_compile(self, fn): else: from pytensor.link.numba.compile import numba_njit - jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) + jitted_fn = numba_njit(fn, final_function=True) return jitted_fn def create_thunk_inputs(self, storage_map): diff --git a/tests/link/numba/test_compile.py b/tests/link/numba/test_compile.py deleted file mode 100644 index aa1d282fc7..0000000000 --- a/tests/link/numba/test_compile.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy as np - -import pytensor.tensor as pt -from pytensor import function -from pytensor.graph import rewrite_graph -from pytensor.graph.traversal import explicit_graph_inputs - - -def test_radon_model_logp_dlogp(): - def halfnormal(name, *, sigma=1.0, model_logp): - log_value = pt.scalar(f"{name}_log") - value = pt.exp(log_value) - - logp = ( - -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) - ) - logp = pt.switch(value >= 0, logp, -np.inf) - model_logp.append(logp + value) - return value - - def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): - value = pt.scalar(name) if observed is None else pt.as_tensor(observed) - - logp = ( - -0.5 * (((value - mu) / sigma) ** 2) - - pt.log(pt.sqrt(2.0 * np.pi)) - - pt.log(sigma) - ) - model_logp.append(logp) - return value - - def zerosumnormal(name, *, sigma=1.0, size, model_logp): - raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) - n = raw_value.shape[0] + 1 - sum_vals = raw_value.sum(0, keepdims=True) - norm = sum_vals / (pt.sqrt(n) + n) - fill_value = norm - sum_vals / pt.sqrt(n) - value = pt.concatenate([raw_value, fill_value]) - norm - - shape = value.shape - _full_size = pt.prod(shape) - _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) - logp = pt.sum( - -0.5 * ((value / sigma) ** 2) - - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) - * (_degrees_of_freedom / _full_size) - ) - model_logp.append(logp) - return value - - rng = np.random.default_rng(1) - n_counties = 85 - county_idx = rng.integers(n_counties, size=919) - county_idx.sort() - floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) - log_radon = rng.normal(size=919) - - # joined_inputs = pt.vector("joined_inputs") - - model_logp = [] - intercept = normal("intercept", sigma=10, model_logp=model_logp) - - # County effects - county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) - county_sd = halfnormal("county_sd", model_logp=model_logp) - county_effect = county_raw * county_sd - - # Global floor effect - floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) - - county_floor_raw = zerosumnormal( - "county_floor_raw", size=n_counties, model_logp=model_logp - ) - county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) - county_floor_effect = county_floor_raw * county_floor_sd - - mu = ( - intercept - + county_effect[county_idx] - + floor_effect * floor - + county_floor_effect[county_idx] * floor - ) - - sigma = halfnormal("sigma", model_logp=model_logp) - _ = normal( - "log_radon", - mu=mu, - sigma=sigma, - observed=log_radon, - model_logp=model_logp, - ) - - model_logp = pt.sum([logp.sum() for logp in model_logp]) - model_logp = rewrite_graph( - model_logp, include=("canonicalize", "stabilize"), clone=False - ) - params = list(explicit_graph_inputs(model_logp)) - model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) - - # TODO: Replace inputs by raveled vector - - function(params, [model_logp, model_dlogp]).dprint() From 4f1585b82e47077fe035bfe96d3c20295cf3b2e4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 6 Oct 2025 18:45:26 +0200 Subject: [PATCH 11/12] .more caching --- pytensor/compile/mode.py | 2 +- pytensor/link/numba/dispatch/basic.py | 4 ++- pytensor/link/numba/dispatch/blockwise.py | 28 +++++++++++++++-- pytensor/link/numba/dispatch/elemwise.py | 37 ++++++++++++++--------- pytensor/link/numba/dispatch/random.py | 29 +++++++++++++++--- pytensor/link/numba/dispatch/scalar.py | 7 +++-- pytensor/link/numba/dispatch/scan.py | 18 ++++++++--- pytensor/link/numba/dispatch/slinalg.py | 21 ++++++++----- tests/tensor/rewriting/test_basic.py | 1 - 9 files changed, 109 insertions(+), 38 deletions(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 2f4eed31fb..bb3707f768 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -461,7 +461,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) FAST_RUN = Mode( - "numba_vm", + "numba", RewriteDatabaseQuery( include=["fast_run", "numba"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 9fe77f7a07..1de5a461da 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -213,7 +213,8 @@ def perform(*inputs): ret = py_perform_return(inputs) return ret - return perform + # Assume we can't cache python functions + return perform, None @singledispatch @@ -276,6 +277,7 @@ def numba_funcify_wrapper(*args, **kwargs): return result[0] return result + # TODO: Create hash key for whole graph return fgraph_to_python( fgraph, op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify_wrapper, diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index cfbf5e89e7..421ede1334 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -1,4 +1,5 @@ import sys +from hashlib import sha256 from typing import cast from numba.core.extending import overload @@ -30,12 +31,17 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): cast(tuple[TensorVariable], node.inputs[:nin]), propagate_unbatched_core_inputs=True, ) - core_op_fn = numba_funcify( + core_op_fn_and_key = numba_funcify( core_op, node=core_node, parent_node=node, **kwargs, ) + if isinstance(core_op_fn_and_key, tuple): + core_op_fn, core_op_key = core_op_fn_and_key + else: + # Assume we can cache core_op_fn + core_op_fn, core_op_key = core_op_fn_and_key, 0 core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) batch_ndim = blockwise_op.batch_ndim(node) @@ -90,4 +96,22 @@ def blockwise(*inputs_and_core_shapes): def ov_blockwise(*inputs_and_core_shapes): return blockwise_wrapper - return blockwise + if core_op_key is None: + # We were told the scalar op cannot be cached + blockwise_key = None + else: + blockwise_key = "_".join( + map( + str, + ( + type(op), + type(op.scalar_op), + tuple(op.inplace_pattern.items()), + tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), + core_op_key, + ), + ) + ) + blockwise_key = sha256(blockwise_key.encode()).hexdigest() + + return blockwise, blockwise_key diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 4ed8636979..61264f2bdc 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -269,25 +269,18 @@ def numba_funcify_Elemwise(op, node, **kwargs): scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) - scalar_op_fn = numba_funcify( + scalar_op_fn_and_key = numba_funcify( op.scalar_op, node=scalar_node, parent_node=node, **kwargs, ) + if isinstance(scalar_op_fn_and_key, tuple): + scalar_op_fn, scalar_op_key = scalar_op_fn_and_key + else: + # Assume op can be cached + scalar_op_fn, scalar_op_key = scalar_op_fn_and_key, 0 - # TODO: Proper key - core_op_key = "_".join( - map( - str, - ( - op, - op.scalar_op, - tuple(op.inplace_pattern.items()), - tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), - ), - ) - ) core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) @@ -339,7 +332,23 @@ def elemwise(*inputs): def ov_elemwise(*inputs): return elemwise_wrapper - elemwise_key = sha256(f"Elemwise2{core_op_key}".encode()).hexdigest() + if scalar_op_key is None: + # We were told the scalar op cannot be cached + elemwise_key = None + else: + elemwise_key = "_".join( + map( + str, + ( + type(op), + type(op.scalar_op), + tuple(op.inplace_pattern.items()), + tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), + scalar_op_key, + ), + ) + ) + elemwise_key = sha256(elemwise_key.encode()).hexdigest() return elemwise, elemwise_key diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index a700de5689..f3099d54b7 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,6 +1,7 @@ from collections.abc import Callable from copy import copy, deepcopy from functools import singledispatch +from hashlib import sha256 from textwrap import dedent import numba @@ -408,7 +409,13 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs core_shape_len = get_vector_length(core_shape) inplace = rv_op.inplace - core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) + core_rv_fn_and_key = numba_core_rv_funcify(rv_op, rv_node) + if isinstance(core_rv_fn_and_key, tuple): + core_rv_fn, core_rv_key = core_rv_fn_and_key + else: + # Assume we can cache core_op_fn + core_rv_fn, core_rv_key = core_rv_fn_and_key, 0 + nin = 1 + len(dist_params) # rng + params core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) @@ -448,8 +455,20 @@ def random(core_shape, rng, size, *dist_params): def ov_random(core_shape, rng, size, *dist_params): return random_wrapper - return random, str( - type( - rv_op, + if core_rv_key is None: + random_rv_key = None + else: + random_rv_key = "_".join( + map( + str, + ( + type(op), + type(rv_op), + tuple(getattr(rv_op, "props_dict", lambda: {})().items()), + core_rv_key, + ), + ) ) - ) + random_rv_key = sha256(random_rv_key.encode()).hexdigest() + + return random, random_rv_key diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index e021f1c7b9..cbac51db4d 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -62,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs): output_inner_dtype = None # Cython functions might have an additional argument + cython_func = None has_pyx_skip_dispatch = False if scalar_func_path.startswith("scipy.special"): @@ -137,9 +138,9 @@ def {scalar_op_fn_name}({', '.join(input_names)}): {**globals(), **global_env}, ) - # signature = create_numba_signature(node, force_scalar=True) - - return pytensor.link.numba.compile.numba_njit(scalar_op_fn) + # Functions that call a function pointer can't be cached + cache_key = None if cython_func else 0 + return pytensor.link.numba.compile.numba_njit(scalar_op_fn), cache_key @numba_funcify.register(Switch) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c73bb8cce5..f357cb73fc 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -8,12 +8,16 @@ from pytensor import In from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.mode import NUMBA, get_mode -from pytensor.link.numba.compile import create_arg_string, create_tuple_string +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_arg_string, + create_tuple_string, + numba_njit, +) from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify, ) -from pytensor.link.utils import compile_function_src from pytensor.scan.op import Scan from pytensor.tensor.type import TensorType @@ -440,6 +444,12 @@ def scan({", ".join(outer_in_names)}): } global_env["np"] = np - scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) + scan_op_fn = compile_and_cache_numba_function_src( + scan_op_src, + "scan", + {**globals(), **global_env}, + # We can't cache until we can hash FunctionGraph + key=None, + ) - return pytensor.link.numba.compile.numba_njit(scan_op_fn, boundscheck=False) + return numba_njit(scan_op_fn, boundscheck=False), None diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 73b7a7221c..d04ad8161f 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -88,7 +88,8 @@ def cholesky(a): return res - return cholesky + # We cannot cache LAPACK functions + return cholesky, None @numba_funcify.register(PivotToPermutations) @@ -154,7 +155,8 @@ def lu(a): return res - return lu + # We cannot cache LAPACK functions + return lu, None @numba_funcify.register(LUFactor) @@ -178,7 +180,8 @@ def lu_factor(a): return LU, piv - return lu_factor + # We cannot cache LAPACK functions + return lu_factor, None @numba_funcify.register(BlockDiagonal) @@ -251,7 +254,8 @@ def solve(a, b): res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) return res - return solve + # We cannot cache LAPACK functions + return solve, None @numba_funcify.register(SolveTriangular) @@ -292,7 +296,8 @@ def solve_triangular(a, b): return res - return solve_triangular + # We cannot cache LAPACK functions + return solve_triangular, None @numba_funcify.register(CholeskySolve) @@ -321,7 +326,8 @@ def cho_solve(c, b): c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite ) - return cho_solve + # We cannot cache LAPACK functions + return cho_solve, None @numba_funcify.register(QR) @@ -414,4 +420,5 @@ def qr(a): f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." ) - return qr + # We cannot cache LAPACK functions + return qr, None diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 3ffa1aa267..aabe71b7dd 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1208,7 +1208,6 @@ def test_sum_bool_upcast(self): f(5) -@pytest.mark.xfail(reason="Numba does not support float16") class TestLocalOptAllocF16(TestLocalOptAlloc): dtype = "float16" From 44a4ab751e348f806c102613d9a256753f7d587c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 6 Oct 2025 20:30:29 +0200 Subject: [PATCH 12/12] .fix shit --- pytensor/compile/function/types.py | 1 - pytensor/configdefaults.py | 4 ++-- pytensor/link/numba/dispatch/basic.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 635af25e47..ee3dc27124 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -1838,7 +1838,6 @@ def orig_function( profile.compile_time += t2 - t1 # TODO: append profile.nb_nodes = len(fn.maker.fgraph.apply_nodes) - return fn diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index e8ea54e7c2..8ad1755508 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -380,12 +380,12 @@ def add_compile_configvars(): "vm_nogc", "cvm_nogc", "jax", - "numba", + "numba_vm", ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm", "vm_nogc", "jax", "numba"] + linker_options = ["py", "vm", "vm_nogc", "cvm" "jax", "numba", "numba_vm"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 1de5a461da..6314fc8628 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -280,7 +280,7 @@ def numba_funcify_wrapper(*args, **kwargs): # TODO: Create hash key for whole graph return fgraph_to_python( fgraph, - op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify_wrapper, + op_conversion_fn=numba_funcify_njit, # if jit_nodes else numba_funcify_wrapper, type_conversion_fn=numba_typify, fgraph_name=fgraph_name, **kwargs,