diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 635af25e47..ee3dc27124 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -1838,7 +1838,6 @@ def orig_function( profile.compile_time += t2 - t1 # TODO: append profile.nb_nodes = len(fn.maker.fgraph.apply_nodes) - return fn diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 8bd0e2f901..bb3707f768 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -50,6 +50,7 @@ "jax": JAXLinker(), "pytorch": PytorchLinker(), "numba": NumbaLinker(), + "numba_vm": NumbaLinker(vm=True), } @@ -63,9 +64,8 @@ def register_linker(name, linker): # If a string is passed as the optimizer argument in the constructor # for Mode, it will be used as the key to retrieve the real optimizer # in this dictionary -exclude = [] -if not config.cxx: - exclude = ["cxx_only"] + +exclude = ["cxx_only", "BlasOpt"] OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) # Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude) @@ -351,6 +351,11 @@ def __setstate__(self, state): optimizer = predefined_optimizers[optimizer] if isinstance(optimizer, RewriteDatabaseQuery): self.provided_optimizer = optimizer + + # Force numba-required rewrites if using NumbaLinker + if isinstance(linker, NumbaLinker): + optimizer = optimizer.including("numba") + self._optimizer = optimizer self.call_time = 0 self.fn_time = 0 @@ -448,19 +453,26 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): # string as the key # Use VM_linker to allow lazy evaluation by default. FAST_COMPILE = Mode( - VMLinker(use_cloop=False, c_thunks=False), - RewriteDatabaseQuery(include=["fast_compile", "py_only"]), + "numba_vm", + # TODO: Fast_compile should just use python code, CHANGE ME! 
+ RewriteDatabaseQuery( + include=["fast_compile", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), +) +FAST_RUN = Mode( + "numba", + RewriteDatabaseQuery( + include=["fast_run", "numba"], + exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], + ), ) -if config.cxx: - FAST_RUN = Mode("cvm", "fast_run") -else: - FAST_RUN = Mode( - "vm", - RewriteDatabaseQuery(include=["fast_run", "py_only"]), - ) + +C = Mode("c", "fast_run") +C_VM = Mode("cvm", "fast_run") NUMBA = Mode( - NumbaLinker(), + "numba", RewriteDatabaseQuery( include=["fast_run", "numba"], exclude=[ @@ -472,8 +484,13 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) +NUMBA_VM = Mode( + "numba_vm", + NUMBA._optimizer, +) + JAX = Mode( - JAXLinker(), + "jax", RewriteDatabaseQuery( include=["fast_run", "jax"], exclude=[ @@ -489,7 +506,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): ), ) PYTORCH = Mode( - PytorchLinker(), + "pytorch", RewriteDatabaseQuery( include=["fast_run"], exclude=[ @@ -508,8 +525,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): predefined_modes = { "FAST_COMPILE": FAST_COMPILE, "FAST_RUN": FAST_RUN, + "C": C, + "C_VM": C_VM, "JAX": JAX, "NUMBA": NUMBA, + "NUMBA_VM": NUMBA_VM, "PYTORCH": PYTORCH, } @@ -574,6 +594,7 @@ def register_mode(name, mode): Add a `Mode` which can be referred to by `name` in `function`. 
""" + # TODO: Remove me if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 7698c5d441..8ad1755508 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -370,11 +370,22 @@ def add_compile_configvars(): if rc == 0 and config.cxx != "": # Keep the default linker the same as the one for the mode FAST_RUN - linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] + linker_options = [ + "cvm", + "c|py", + "py", + "c", + "c|py_nogc", + "vm", + "vm_nogc", + "cvm_nogc", + "jax", + "numba_vm", + ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm_nogc"] + linker_options = ["py", "vm", "vm_nogc", "cvm", "jax", "numba", "numba_vm"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -388,7 +399,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. 
- EnumStr("cvm", linker_options, mutable=False), + EnumStr("numba_vm", linker_options, mutable=False), in_c_key=False, ) diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py new file mode 100644 index 0000000000..ad3a8dfe1b --- /dev/null +++ b/pytensor/link/numba/cache.py @@ -0,0 +1,76 @@ +import weakref +from hashlib import sha256 +from pathlib import Path + +from numba.core.caching import CacheImpl, _CacheLocator + +from pytensor import config +from pytensor.graph.basic import Apply + + +NUMBA_PYTENSOR_CACHE_ENABLED = True +NUMBA_CACHE_PATH = config.base_compiledir / "numba" +NUMBA_CACHE_PATH.mkdir(exist_ok=True) +CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary() + + +class NumbaPyTensorCacheLocator(_CacheLocator): + def __init__(self, py_func, py_file, hash): + self._py_func = py_func + self._py_file = py_file + self._hash = hash + # src_hash = hash(pytensor_loader._module_sources[self._py_file]) + # self._hash = hash((src_hash, py_file, pytensor.__version__)) + + def ensure_cache_path(self): + pass + + def get_cache_path(self): + """ + Return the directory the function is cached in. + """ + return NUMBA_CACHE_PATH + + def get_source_stamp(self): + """ + Get a timestamp representing the source code's freshness. + Can return any picklable Python object. + """ + return 0 + + def get_disambiguator(self): + """ + Get a string disambiguator for this locator's function. + It should allow disambiguating different but similarly-named functions. + """ + return self._hash + + @classmethod + def from_function(cls, py_func, py_file): + """ + Create a locator instance for the given function located in the given file. 
+ """ + # py_file = Path(py_file).parent + # if py_file == (config.base_compiledir / "numba"): + if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS: + # print(f"Applies to {py_file}") + return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func]) + + +CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator) + + +def cache_node_key(node: Apply, extra_key="") -> str: + op = node.op + return sha256( + str( + ( + # Op signature + (type(op), op._props_dict() if hasattr(op, "_props_dict") else ""), + # Node signature + tuple((type(inp_type := inp.type), inp_type) for inp in node.inputs), + # Extra key given by the caller + extra_key, + ), + ).encode() + ).hexdigest() diff --git a/pytensor/link/numba/compile.py b/pytensor/link/numba/compile.py new file mode 100644 index 0000000000..01eaca26e6 --- /dev/null +++ b/pytensor/link/numba/compile.py @@ -0,0 +1,196 @@ +import warnings +from collections.abc import Callable +from typing import Any + +import numba +import numpy as np +from numba import NumbaWarning +from numba import njit as _njit +from numba.core.extending import register_jitable + +from pytensor import config +from pytensor.graph import Apply, FunctionGraph, Type +from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS +from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType +from pytensor.scalar import ScalarType +from pytensor.sparse import SparseTensorType +from pytensor.tensor import TensorType + + +def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs): + if fastmath is None: + if config.numba__fastmath: + # Opinionated default on fastmath flags + # https://llvm.org/docs/LangRef.html#fast-math-flags + fastmath = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # no-signed zeros + } + else: + fastmath = False + + if final_function: + kwargs.setdefault("cache", True) + else: + 
kwargs.setdefault("no_cpython_wrapper", True) + kwargs.setdefault("no_cfunc_wrapper", True) + + # Suppress cache warning for internal functions + # We have to add an ansi escape code for optional bold text by numba + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + "Cannot cache compiled function " + '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' + "as it uses dynamic globals" + ), + category=NumbaWarning, + ) + + func = _njit if final_function else register_jitable + if len(args) > 0 and callable(args[0]): + return func(*args[1:], fastmath=fastmath, **kwargs)(args[0]) + else: + return func(*args, fastmath=fastmath, **kwargs) + + +def compile_and_cache_numba_function_src( + src: str, + function_name: str, + global_env: dict[Any, Any] | None = None, + local_env: dict[Any, Any] | None = None, + key: str | None = None, +) -> Callable: + # if key is not None: + # filename = NUMBA_CACHE_PATH / key + # with filename.open("wb") as f: + # f.write(src.encode()) + # else: + # with NamedTemporaryFile(delete=False) as f: + # filename = f.name + # f.write(src.encode()) + + if global_env is None: + global_env = {} + + if local_env is None: + local_env = {} + + mod_code = compile(src, "", mode="exec") + exec(mod_code, global_env, local_env) + + res = local_env[function_name] + res.__source__ = src # type: ignore + + if key is not None: + CACHED_SRC_FUNCTIONS[res] = key + return res + + +def get_numba_type( + pytensor_type: Type, + layout: str = "A", + force_scalar: bool = False, + reduce_to_scalar: bool = False, +) -> numba.types.Type: + r"""Create a Numba type object for a :class:`Type`. + + Parameters + ---------- + pytensor_type + The :class:`Type` to convert. + layout + The :class:`numpy.ndarray` layout to use. + force_scalar + Ignore dimension information and return the corresponding Numba scalar types. 
+ reduce_to_scalar + Return Numba scalars for zero dimensional :class:`TensorType`\s. + """ + + if isinstance(pytensor_type, TensorType): + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if force_scalar or ( + reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0 + ): + return numba_dtype + return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) + elif isinstance(pytensor_type, ScalarType): + dtype = np.dtype(pytensor_type.dtype) + numba_dtype = numba.from_dtype(dtype) + return numba_dtype + elif isinstance(pytensor_type, SparseTensorType): + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if pytensor_type.format == "csr": + return CSRMatrixType(numba_dtype) + if pytensor_type.format == "csc": + return CSCMatrixType(numba_dtype) + + raise NotImplementedError() + else: + raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") + + +def create_numba_signature( + node_or_fgraph: FunctionGraph | Apply, + force_scalar: bool = False, + reduce_to_scalar: bool = False, +) -> numba.types.Type: + """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" + input_types = [ + get_numba_type( + inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar + ) + for inp in node_or_fgraph.inputs + ] + + output_types = [ + get_numba_type( + out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar + ) + for out in node_or_fgraph.outputs + ] + + if len(output_types) > 1: + return numba.types.Tuple(output_types)(*input_types) + elif len(output_types) == 1: + return output_types[0](*input_types) + else: + return numba.types.void(*input_types) + + +def create_tuple_creator(f, n): + """Construct a compile-time ``tuple``-comprehension-like loop. 
+ + See https://github.com/numba/numba/issues/2771#issuecomment-414358902 + """ + assert n > 0 + + f = numba_njit(f) + + @numba_njit + def creator(args): + return (f(0, *args),) + + for i in range(1, n): + + @numba_njit + def creator(args, creator=creator, i=i): + return (*creator(args), f(i, *args)) + + return numba_njit(lambda *args: creator(args)) + + +def create_tuple_string(x): + args = ", ".join(x + ([""] if len(x) == 1 else [])) + return f"({args})" + + +def create_arg_string(x): + args = ", ".join(x) + return args diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..1541331d31 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -9,6 +9,7 @@ import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar +import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.slinalg import pytensor.link.numba.dispatch.sparse diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 87b8e380d3..6314fc8628 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,170 +1,39 @@ import operator import sys import warnings -from copy import copy +from collections.abc import Callable from functools import singledispatch -from textwrap import dedent import numba -import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -import scipy -import scipy.special from llvmlite import ir from numba import types -from numba.core.errors import NumbaWarning, TypingError +from numba.core.errors import TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload +from numba.extending import box from pytensor import In, config from pytensor.compile import NUMBA from pytensor.compile.builders import OpFromGraph 
from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.ops import DeepCopyOp -from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph -from pytensor.graph.type import Type from pytensor.ifelse import IfElse -from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType -from pytensor.link.utils import ( - compile_function_src, - fgraph_to_python, +from pytensor.link.numba.cache import ( + cache_node_key, ) -from pytensor.scalar.basic import ScalarType -from pytensor.sparse import SparseTensorType +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + get_numba_type, + numba_njit, +) +from pytensor.link.utils import fgraph_to_python +from pytensor.tensor import TensorType from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Solve from pytensor.tensor.sort import ArgSortOp, SortOp -from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import MakeSlice, NoneConst - - -def global_numba_func(func): - """Use to return global numba functions in numba_funcify_*. - - This allows tests to remove the compilation using mock. 
- """ - return func - - -def numba_njit(*args, fastmath=None, **kwargs): - kwargs.setdefault("cache", config.numba__cache) - kwargs.setdefault("no_cpython_wrapper", True) - kwargs.setdefault("no_cfunc_wrapper", True) - if fastmath is None: - if config.numba__fastmath: - # Opinionated default on fastmath flags - # https://llvm.org/docs/LangRef.html#fast-math-flags - fastmath = { - "arcp", # Allow Reciprocal - "contract", # Allow floating-point contraction - "afn", # Approximate functions - "reassoc", - "nsz", # no-signed zeros - } - else: - fastmath = False - - # Suppress cache warning for internal functions - # We have to add an ansi escape code for optional bold text by numba - warnings.filterwarnings( - "ignore", - message=( - "(\x1b\\[1m)*" # ansi escape code for bold text - "Cannot cache compiled function " - '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" ' - "as it uses dynamic globals" - ), - category=NumbaWarning, - ) - - if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - - return numba.njit(*args, fastmath=fastmath, **kwargs) - - -def numba_vectorize(*args, **kwargs): - if len(args) > 0 and callable(args[0]): - return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0]) - - return numba.vectorize(*args, cache=config.numba__cache, **kwargs) - - -def get_numba_type( - pytensor_type: Type, - layout: str = "A", - force_scalar: bool = False, - reduce_to_scalar: bool = False, -) -> numba.types.Type: - r"""Create a Numba type object for a :class:`Type`. - - Parameters - ---------- - pytensor_type - The :class:`Type` to convert. - layout - The :class:`numpy.ndarray` layout to use. - force_scalar - Ignore dimension information and return the corresponding Numba scalar types. - reduce_to_scalar - Return Numba scalars for zero dimensional :class:`TensorType`\s. 
- """ - - if isinstance(pytensor_type, TensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) - if force_scalar or ( - reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0 - ): - return numba_dtype - return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) - elif isinstance(pytensor_type, ScalarType): - dtype = np.dtype(pytensor_type.dtype) - numba_dtype = numba.from_dtype(dtype) - return numba_dtype - elif isinstance(pytensor_type, SparseTensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) - if pytensor_type.format == "csr": - return CSRMatrixType(numba_dtype) - if pytensor_type.format == "csc": - return CSCMatrixType(numba_dtype) - - raise NotImplementedError() - else: - raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") - - -def create_numba_signature( - node_or_fgraph: FunctionGraph | Apply, - force_scalar: bool = False, - reduce_to_scalar: bool = False, -) -> numba.types.Type: - """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" - input_types = [ - get_numba_type( - inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - for inp in node_or_fgraph.inputs - ] - - output_types = [ - get_numba_type( - out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - for out in node_or_fgraph.outputs - ] - - if len(output_types) > 1: - return numba.types.Tuple(output_types)(*input_types) - elif len(output_types) == 1: - return output_types[0](*input_types) - else: - return numba.types.void(*input_types) +from pytensor.tensor.type_other import MakeSlice def slice_new(self, start, stop, step): @@ -244,36 +113,53 @@ def impl_to_scalar(x): raise TypingError(f"{x} must be a scalar compatible type.") -def create_tuple_creator(f, n): - """Construct a compile-time ``tuple``-comprehension-like loop. 
+@numba.extending.intrinsic +def direct_cast(typingctx, val, typ): + if isinstance(typ, numba.types.TypeRef): + casted = typ.instance_type + elif isinstance(typ, numba.types.DTypeSpec): + casted = typ.dtype + else: + casted = typ - See https://github.com/numba/numba/issues/2771#issuecomment-414358902 - """ - assert n > 0 + sig = casted(casted, typ) - f = numba_njit(f) + def codegen(context, builder, signature, args): + val, _ = args + context.nrt.incref(builder, signature.return_type, val) + return val - @numba_njit - def creator(args): - return (f(0, *args),) + return sig, codegen - for i in range(1, n): - @numba_njit - def creator(args, creator=creator, i=i): - return (*creator(args), f(i, *args)) +def int_to_float_fn(inputs, out_dtype): + """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" - return numba_njit(lambda *args: creator(args)) + if ( + all(inp.type.dtype == out_dtype for inp in inputs) + and np.dtype(out_dtype).kind == "f" + ): + @numba_njit(inline="always") + def inputs_cast(x): + return x -def create_tuple_string(x): - args = ", ".join(x + ([""] if len(x) == 1 else [])) - return f"({args})" + elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): + args_dtype = np.dtype(f"f{out_dtype.itemsize}") + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) -def create_arg_string(x): - args = ", ".join(x) - return args + else: + args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) + args_dtype = np.dtype(f"f{args_dtype_sz}") + + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + return inputs_cast @singledispatch @@ -327,7 +213,8 @@ def perform(*inputs): ret = py_perform_return(inputs) return ret - return perform + # Assume we can't cache python functions + return perform, None @singledispatch @@ -341,6 +228,65 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): return generate_fallback_impl(op, node, storage_map, **kwargs) 
+def numba_funcify_njit(op, node, **kwargs): + jitable_func_and_key = numba_funcify(op, node=node, **kwargs) + + match jitable_func_and_key: + case Callable(): + jitable_func = jitable_func_and_key + key = cache_node_key(node) + case (Callable(), str() | int()): + jitable_func, funcify_key = jitable_func_and_key + key = cache_node_key(node, funcify_key) + case (Callable(), None): + # We were explicitly told by the dispatch not to try and cache this function + jitable_func, key = jitable_func_and_key + case _: + raise TypeError( + f"numba_funcify should return a callable or a callable, key pair, got {jitable_func_and_key}" + ) + + if key is not None: + # To force numba to use our cache, we must compile the function so that any closure + # becomes a global variable... + op_name = op.__class__.__name__ + cached_func = compile_and_cache_numba_function_src( + src=f"def {op_name}(*args): return jitable_func(*args)", + function_name=op_name, + global_env=globals() | dict(jitable_func=jitable_func), + key=key, + ) + return numba_njit(cached_func, final_function=True, cache=True) + else: + return numba_njit( + lambda *args: jitable_func(*args), final_function=True, cache=False + ) + + +@numba_funcify.register(FunctionGraph) +def numba_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="numba_funcified_fgraph", + jit_nodes: bool = False, + **kwargs, +): + def numba_funcify_wrapper(*args, **kwargs): + result = numba_funcify(*args, **kwargs) + if isinstance(result, tuple): + return result[0] + return result + + # TODO: Create hash key for whole graph + return fgraph_to_python( + fgraph, + op_conversion_fn=numba_funcify_njit, # if jit_nodes else numba_funcify_wrapper, + type_conversion_fn=numba_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + @numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) @@ -370,40 +316,25 @@ def opfromgraph(*inputs): def opfromgraph(*inputs): return 
fgraph_fn(*inputs) - return opfromgraph + # We can't cache this correctly until we can define a key for it + return opfromgraph, None -@numba_funcify.register(FunctionGraph) -def numba_funcify_FunctionGraph( - fgraph, - node=None, - fgraph_name="numba_funcified_fgraph", - **kwargs, -): - return fgraph_to_python( - fgraph, - numba_funcify, - type_conversion_fn=numba_typify, - fgraph_name=fgraph_name, - **kwargs, - ) - - -def deepcopyop(x): - return copy(x) - +@numba_funcify.register(DeepCopyOp) +def numba_funcify_DeepCopyOp(op, node, **kwargs): + if isinstance(node.inputs[0].type, TensorType): -@overload(deepcopyop) -def dispatch_deepcopyop(x): - if isinstance(x, types.Array): - return lambda x: np.copy(x) + @numba_njit + def deepcopy_fn(x): + return np.copy(x) - return lambda x: x + else: + @numba_njit + def deepcopy_fn(x): + return x -@numba_funcify.register(DeepCopyOp) -def numba_funcify_DeepCopyOp(op, node, **kwargs): - return deepcopyop + return deepcopy_fn @numba_funcify.register(MakeSlice) @@ -415,26 +346,6 @@ def makeslice(*x): return makeslice -@numba_funcify.register(Shape) -def numba_funcify_Shape(op, **kwargs): - @numba_njit - def shape(x): - return np.asarray(np.shape(x)) - - return shape - - -@numba_funcify.register(Shape_i) -def numba_funcify_Shape_i(op, **kwargs): - i = op.i - - @numba_njit - def shape_i(x): - return np.asarray(np.shape(x)[i]) - - return shape_i - - @numba_funcify.register(SortOp) def numba_funcify_SortOp(op, node, **kwargs): @numba_njit @@ -497,103 +408,6 @@ def argort_vec(X, axis): return argsort_f_kind(kind) -@numba.extending.intrinsic -def direct_cast(typingctx, val, typ): - if isinstance(typ, numba.types.TypeRef): - casted = typ.instance_type - elif isinstance(typ, numba.types.DTypeSpec): - casted = typ.dtype - else: - casted = typ - - sig = casted(casted, typ) - - def codegen(context, builder, signature, args): - val, _ = args - context.nrt.incref(builder, signature.return_type, val) - return val - - return sig, codegen - - 
-@numba_funcify.register(Reshape) -def numba_funcify_Reshape(op, **kwargs): - ndim = op.ndim - - if ndim == 0: - - @numba_njit - def reshape(x, shape): - return np.asarray(x.item()) - - else: - - @numba_njit - def reshape(x, shape): - # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. - return np.reshape( - np.ascontiguousarray(np.asarray(x)), - numba_ndarray.to_fixed_tuple(shape, ndim), - ) - - return reshape - - -@numba_funcify.register(SpecifyShape) -def numba_funcify_SpecifyShape(op, node, **kwargs): - shape_inputs = node.inputs[1:] - shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] - - func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( - zip(shape_inputs, shape_input_names, strict=True) - ) - if shape_input is not NoneConst - ] - - func = dedent( - f""" - def specify_shape(x, {create_arg_string(shape_input_names)}): - {"; ".join(func_conditions)} - return x - """ - ) - - specify_shape = compile_function_src(func, "specify_shape", globals()) - return numba_njit(specify_shape) - - -def int_to_float_fn(inputs, out_dtype): - """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" - - if ( - all(inp.type.dtype == out_dtype for inp in inputs) - and np.dtype(out_dtype).kind == "f" - ): - - @numba_njit(inline="always") - def inputs_cast(x): - return x - - elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): - args_dtype = np.dtype(f"f{out_dtype.itemsize}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - else: - args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) - args_dtype = np.dtype(f"f{args_dtype_sz}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - return inputs_cast - - @numba_funcify.register(Dot) def numba_funcify_Dot(op, node, **kwargs): # Numba's `np.dot` does not support integer dtypes, so we need 
to cast to float. @@ -641,51 +455,6 @@ def dot_with_cast(x, y): return dot_with_cast -@numba_funcify.register(Solve) -def numba_funcify_Solve(op, node, **kwargs): - assume_a = op.assume_a - # check_finite = op.check_finite - - if assume_a != "gen": - lower = op.lower - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`compute_uv` argument to `numpy.linalg.svd`." - ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def solve(a, b): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.solve_triangular( - a, - b, - lower=lower, - # check_finite=check_finite - ) - return ret - - else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def solve(a, b): - return np.linalg.solve( - inputs_cast(a), - inputs_cast(b), - # assume_a=assume_a, - # check_finite=check_finite, - ).astype(out_dtype) - - return solve - - @numba_funcify.register(BatchedDot) def numba_funcify_BatchedDot(op, node, **kwargs): dtype = node.outputs[0].type.numpy_dtype diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index 45df8341ea..421ede1334 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -1,10 +1,12 @@ import sys +from hashlib import sha256 from typing import cast from numba.core.extending import overload from numba.np.unsafe.ndarray import to_fixed_tuple -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.link.numba.compile import numba_njit +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.vectorize_codegen import ( _jit_options, _vectorized, @@ -29,12 +31,17 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): cast(tuple[TensorVariable], node.inputs[:nin]), propagate_unbatched_core_inputs=True, ) - core_op_fn = numba_funcify( + core_op_fn_and_key = 
numba_funcify( core_op, node=core_node, parent_node=node, **kwargs, ) + if isinstance(core_op_fn_and_key, tuple): + core_op_fn, core_op_key = core_op_fn_and_key + else: + # Assume we can cache core_op_fn + core_op_fn, core_op_key = core_op_fn_and_key, 0 core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) batch_ndim = blockwise_op.batch_ndim(node) @@ -89,4 +96,22 @@ def blockwise(*inputs_and_core_shapes): def ov_blockwise(*inputs_and_core_shapes): return blockwise_wrapper - return blockwise + if core_op_key is None: + # We were told the scalar op cannot be cached + blockwise_key = None + else: + blockwise_key = "_".join( + map( + str, + ( + type(op), + type(op.scalar_op), + tuple(op.inplace_pattern.items()), + tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), + core_op_key, + ), + ) + ) + blockwise_key = sha256(blockwise_key.encode()).hexdigest() + + return blockwise, blockwise_key diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 7244762b93..61264f2bdc 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,4 +1,5 @@ from functools import singledispatch +from hashlib import sha256 from textwrap import dedent, indent import numba @@ -6,19 +7,18 @@ from numba.core.extending import overload from numpy.lib.stride_tricks import as_strided +import pytensor.link.numba.compile from pytensor.graph.op import Op +from pytensor.link.numba.compile import compile_and_cache_numba_function_src, numba_njit from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify, - numba_njit, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( - _jit_options, _vectorized, encode_literals, store_core_outputs, ) -from pytensor.link.utils import compile_function_src from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar.basic import ( AND, @@ -237,7 
+237,7 @@ def {careduce_fn_name}(x): careduce_def_src += "\n\n" careduce_def_src += indent(f"return {return_obj}", " " * 4) - careduce_fn = compile_function_src( + careduce_fn = compile_and_cache_numba_function_src( careduce_def_src, careduce_fn_name, {**globals(), **global_env} ) @@ -249,7 +249,7 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): reaxis_first = (*(i for i in range(ndim) if i != axis), axis) - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def axis_apply_fn(x): x_reaxis = x.transpose(reaxis_first) @@ -264,18 +264,23 @@ def axis_apply_fn(x): @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): + nin = len(node.inputs) + nout = len(node.outputs) + scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) - - scalar_op_fn = numba_funcify( + scalar_op_fn_and_key = numba_funcify( op.scalar_op, node=scalar_node, parent_node=node, **kwargs, ) + if isinstance(scalar_op_fn_and_key, tuple): + scalar_op_fn, scalar_op_key = scalar_op_fn_and_key + else: + # Assume op can be cached + scalar_op_fn, scalar_op_key = scalar_op_fn_and_key, 0 - nin = len(node.inputs) - nout = len(node.outputs) core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs) @@ -305,39 +310,46 @@ def elemwise_wrapper(*inputs): # Pure python implementation, that will be used in tests def elemwise(*inputs): - inputs = [np.asarray(input) for input in inputs] + Elemwise._check_runtime_broadcast(node, inputs) inputs_bc = np.broadcast_arrays(*inputs) - shape = inputs[0].shape - for input, bc in zip(inputs, input_bc_patterns, strict=True): - for length, allow_bc, iter_length in zip( - input.shape, bc, shape, strict=True - ): - if length == 1 and shape and iter_length != 1 and not allow_bc: - raise ValueError("Broadcast not allowed.") - - outputs = 
[np.empty(shape, dtype=dtype) for dtype in output_dtypes] - - for idx in np.ndindex(shape): - vals = [input[idx] for input in inputs_bc] - outs = scalar_op_fn(*vals) - if not isinstance(outs, tuple): - outs = (outs,) - for out, out_val in zip(outputs, outs, strict=True): - out[idx] = out_val - - outputs_summed = [] - for output, bc in zip(outputs, output_bc_patterns, strict=True): - axes = tuple(np.nonzero(bc)[0]) - outputs_summed.append(output.sum(axes, keepdims=True)) - if len(outputs_summed) != 1: - return tuple(outputs_summed) - return outputs_summed[0] - - @overload(elemwise, jit_options=_jit_options) + shape = inputs_bc[0].shape + + if len(output_dtypes) == 1: + output = np.empty(shape, dtype=output_dtypes[0]) + for idx in np.ndindex(shape): + output[idx] = scalar_op_fn(*(inp[idx] for inp in inputs_bc)) + return output + + else: + outputs = [np.empty(shape, dtype=dtype) for dtype in output_dtypes] + for idx in np.ndindex(shape): + outs_vals = scalar_op_fn(*(inp[idx] for inp in inputs_bc)) + for out, out_val in zip(outputs, outs_vals): + out[idx] = out_val + return outputs + + @overload(elemwise) def ov_elemwise(*inputs): return elemwise_wrapper - return elemwise + if scalar_op_key is None: + # We were told the scalar op cannot be cached + elemwise_key = None + else: + elemwise_key = "_".join( + map( + str, + ( + type(op), + type(op.scalar_op), + tuple(op.inplace_pattern.items()), + tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()), + scalar_op_key, + ), + ) + ) + elemwise_key = sha256(elemwise_key.encode()).hexdigest() + return elemwise, elemwise_key @numba_funcify.register(Sum) @@ -421,7 +433,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs): if new_order == (): # Special case needed because of https://github.com/numba/numba/issues/9933 - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def squeeze_to_0d(x): return as_strided(x, shape=(), strides=()) @@ -429,7 +441,7 @@ def squeeze_to_0d(x): else: - @numba_basic.numba_njit 
+ @pytensor.link.numba.compile.numba_njit def dimshuffle(x): old_shape = x.shape old_strides = x.strides @@ -464,7 +476,7 @@ def numba_funcify_Softmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -496,7 +508,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_sum = jit_fn(reduce_sum_py) else: reduce_sum = np.sum @@ -533,7 +545,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit(boundscheck=False) + jit_fn = pytensor.link.numba.compile.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -559,7 +571,7 @@ def numba_funcify_Argmax(op, node, **kwargs): if x_ndim == 0: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def argmax(x): return np.array(0, dtype="int64") @@ -579,7 +591,7 @@ def argmax(x): sl1 = slice(None, len(keep_axes)) sl2 = slice(len(keep_axes), None) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def argmax(x): # Not-reduced axes in front transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order)) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index f7700acf47..74f7b27751 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -4,9 +4,11 @@ import numba import numpy as np +import pytensor.link.numba.compile from pytensor.graph import Apply +from pytensor.link.numba.compile import get_numba_type from 
pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.raise_op import CheckAndRaise from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import ( @@ -24,7 +26,7 @@ @numba_funcify.register(Bartlett) def numba_funcify_Bartlett(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def bartlett(x): return np.bartlett(numba_basic.to_scalar(x)) @@ -49,13 +51,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -73,13 +75,13 @@ def cumop(x): else: if axis is None or ndim == 1: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False) + @pytensor.link.numba.compile.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -99,7 +101,7 @@ def cumop(x): @numba_funcify.register(FillDiagonal) def numba_funcify_FillDiagonal(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def filldiagonal(a, val): np.fill_diagonal(a, val) return a @@ -109,7 +111,7 @@ def filldiagonal(a, val): @numba_funcify.register(FillDiagonalOffset) def numba_funcify_FillDiagonalOffset(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def filldiagonaloffset(a, val, offset): height, width = a.shape @@ -144,25 +146,25 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs): if mode == "raise": - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def 
mode_fn(*args): raise ValueError("invalid entry in coordinates array") elif mode == "wrap": - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def mode_fn(new_arr, i, j, v, d): new_arr[i, j] = v % d elif mode == "clip": - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def mode_fn(new_arr, i, j, v, d): new_arr[i, j] = min(max(v, 0), d - 1) if node.inputs[0].ndim == 0: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def ravelmultiindex(*inp): shape = inp[-1] arr = np.stack(inp[:-1]) @@ -178,7 +180,7 @@ def ravelmultiindex(*inp): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def ravelmultiindex(*inp): shape = inp[-1] arr = np.stack(inp[:-1]) @@ -217,7 +219,7 @@ def numba_funcify_Repeat(op, node, **kwargs): ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): with numba.objmode(ret=ret_sig): ret = np.repeat(x, repeats, axis) @@ -228,13 +230,13 @@ def repeatop(x, repeats): if repeats_ndim == 0: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats.item()) else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def repeatop(x, repeats): return np.repeat(x, repeats) @@ -259,7 +261,7 @@ def numba_funcify_Unique(op, node, **kwargs): if not use_python: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def unique(x): return np.unique(x) @@ -277,7 +279,7 @@ def unique(x): else: ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def unique(x): with numba.objmode(ret=ret_sig): ret = np.unique(x, return_index, return_inverse, return_counts, axis) @@ -297,17 +299,17 @@ def numba_funcify_UnravelIndex(op, node, 
**kwargs): if len(node.outputs) == 1: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def maybe_expand_dim(arr): return arr else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def maybe_expand_dim(arr): return np.expand_dims(arr, 1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def unravelindex(arr, shape): a = np.ones(len(shape), dtype=np.int64) a[1:] = shape[:0:-1] @@ -340,7 +342,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ret_sig = get_numba_type(node.outputs[0].type) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def searchsorted(a, v, sorter): with numba.objmode(ret=ret_sig): ret = np.searchsorted(a, v, side, sorter) @@ -348,7 +350,7 @@ def searchsorted(a, v, sorter): else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def searchsorted(a, v): return np.searchsorted(a, v, side) @@ -360,7 +362,7 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs): error = op.exc_type msg = op.msg - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def check_and_raise(x, *conditions): for cond in conditions: if not cond: diff --git a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py index 9575dd7d56..82bf2f009d 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py +++ b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py @@ -6,8 +6,8 @@ from numpy import ndarray from scipy import linalg +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.link.numba.dispatch.linalg._LAPACK import ( _LAPACK, _get_underlying_float, diff --git a/pytensor/link/numba/dispatch/linalg/solve/utils.py b/pytensor/link/numba/dispatch/linalg/solve/utils.py index 
ec6c4ef213..5eedd3ecba 100644 --- a/pytensor/link/numba/dispatch/linalg/solve/utils.py +++ b/pytensor/link/numba/dispatch/linalg/solve/utils.py @@ -1,9 +1,9 @@ from scipy import linalg -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _solve_check_input_shapes(A, B): if A.shape[0] != B.shape[0]: raise linalg.LinAlgError("Dimensions of A and B do not conform") diff --git a/pytensor/link/numba/dispatch/linalg/utils.py b/pytensor/link/numba/dispatch/linalg/utils.py index b15888abd6..568fa34235 100644 --- a/pytensor/link/numba/dispatch/linalg/utils.py +++ b/pytensor/link/numba/dispatch/linalg/utils.py @@ -6,7 +6,7 @@ from numba.np.linalg import _copy_to_fortran_order, ensure_lapack from numpy.linalg import LinAlgError -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile from pytensor.link.numba.dispatch.linalg._LAPACK import ( _LAPACK, _get_underlying_float, @@ -14,13 +14,13 @@ ) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _copy_to_fortran_order_even_if_1d(x): # Numba's _copy_to_fortran_order doesn't do anything for vectors return x.copy() if x.ndim == 1 else _copy_to_fortran_order(x) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _trans_char_to_int(trans): if trans not in [0, 1, 2]: raise ValueError('Parameter "trans" should be one of 0, 1, 2') @@ -53,7 +53,7 @@ def _check_scipy_linalg_matrix(a, func_name): raise numba.TypingError(msg, highlighting=False) -@numba_basic.numba_njit(inline="always") +@pytensor.link.numba.compile.numba_njit(inline="always") def _solve_check(n, info, lamch=False, rcond=None): """ Check arguments during the different steps of the solution phase diff --git a/pytensor/link/numba/dispatch/nlinalg.py 
b/pytensor/link/numba/dispatch/nlinalg.py index 98d59a4595..c85fdbd3e8 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -3,9 +3,9 @@ import numba import numpy as np -from pytensor.link.numba.dispatch import basic as numba_basic +import pytensor.link.numba.compile +from pytensor.link.numba.compile import get_numba_type from pytensor.link.numba.dispatch.basic import ( - get_numba_type, int_to_float_fn, numba_funcify, ) @@ -30,14 +30,14 @@ def numba_funcify_SVD(op, node, **kwargs): if not compute_uv: - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def svd(x): _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices) return ret else: - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def svd(x): return np.linalg.svd(inputs_cast(x), full_matrices) @@ -49,7 +49,7 @@ def numba_funcify_Det(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def det(x): return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype) @@ -63,7 +63,7 @@ def numba_funcify_SLogDet(op, node, **kwargs): inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def slogdet(x): sign, det = np.linalg.slogdet(inputs_cast(x)) return ( @@ -81,7 +81,7 @@ def numba_funcify_Eig(op, node, **kwargs): inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def eig(x): out = np.linalg.eig(inputs_cast(x)) return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2)) @@ -107,7 +107,7 @@ def numba_funcify_Eigh(op, node, **kwargs): [get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)] ) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def eigh(x): with numba.objmode(ret=ret_sig): out = 
np.linalg.eigh(x, UPLO=uplo) @@ -116,7 +116,7 @@ def eigh(x): else: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def eigh(x): return np.linalg.eigh(x) @@ -128,7 +128,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def matrix_inverse(x): return np.linalg.inv(inputs_cast(x)).astype(out_dtype) @@ -140,7 +140,7 @@ def numba_funcify_MatrixPinv(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype inputs_cast = int_to_float_fn(node.inputs, out_dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def matrixpinv(x): return np.linalg.pinv(inputs_cast(x)).astype(out_dtype) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 36618ceb26..f3099d54b7 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,6 +1,7 @@ from collections.abc import Callable from copy import copy, deepcopy from functools import singledispatch +from hashlib import sha256 from textwrap import dedent import numba @@ -9,10 +10,11 @@ from numba import types from numba.core.extending import overload +import pytensor.link.numba.compile import pytensor.tensor.random.basic as ptr from pytensor.graph import Apply from pytensor.graph.op import Op -from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify from pytensor.link.numba.dispatch.vectorize_codegen import ( _jit_options, @@ -84,14 +86,14 @@ def {name}(rng, {input_signature}): """) func = compile_function_src(func_src, name, {**globals()}) - return numba_basic.numba_njit(func) + return numba_njit(func) 
@numba_core_rv_funcify.register(ptr.BernoulliRV) def numba_core_BernoulliRV(op, node): out_dtype = node.outputs[1].type.numpy_dtype - @numba_basic.numba_njit() + @pytensor.link.numba.compile.numba_njit() def random(rng, p): return ( direct_cast(0, out_dtype) @@ -104,7 +106,7 @@ def random(rng, p): @numba_core_rv_funcify.register(ptr.StudentTRV) def numba_core_StudentTRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, df, loc, scale): return loc + scale * rng.standard_t(df) @@ -113,7 +115,7 @@ def random_fn(rng, df, loc, scale): @numba_core_rv_funcify.register(ptr.HalfNormalRV) def numba_core_HalfNormalRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, loc, scale): return loc + scale * np.abs(rng.standard_normal()) @@ -122,7 +124,7 @@ def random_fn(rng, loc, scale): @numba_core_rv_funcify.register(ptr.CauchyRV) def numba_core_CauchyRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, loc, scale): return (loc + rng.standard_cauchy()) / scale @@ -131,7 +133,7 @@ def random(rng, loc, scale): @numba_core_rv_funcify.register(ptr.ParetoRV) def numba_core_ParetoRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, b, scale): # Follows scipy implementation U = rng.random() @@ -142,7 +144,7 @@ def random(rng, b, scale): @numba_core_rv_funcify.register(ptr.InvGammaRV) def numba_core_InvGammaRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random(rng, shape, scale): return 1 / rng.gamma(shape, 1 / scale) @@ -151,7 +153,7 @@ def random(rng, shape, scale): @numba_core_rv_funcify.register(ptr.CategoricalRV) def core_CategoricalRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, p): unif_sample = rng.uniform(0, 1) return np.searchsorted(np.cumsum(p), unif_sample) @@ -163,7 +165,7 @@ def random_fn(rng, p): def 
core_MultinomialRV(op, node): dtype = op.dtype - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, n, p): n_cat = p.shape[0] draws = np.zeros(n_cat, dtype=dtype) @@ -186,7 +188,7 @@ def random_fn(rng, n, p): def core_MvNormalRV(op, node): method = op.method - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, mean, cov): if method == "cholesky": A = np.linalg.cholesky(cov) @@ -209,7 +211,7 @@ def random_fn(rng, mean, cov): @numba_core_rv_funcify.register(ptr.DirichletRV) def core_DirichletRV(op, node): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, alpha): y = np.empty_like(alpha) for i in range(len(alpha)): @@ -226,7 +228,7 @@ def core_GumbelRV(op, node): https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511 """ - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, loc, scale): U = 1.0 - rng.random() if U < 1.0: @@ -244,7 +246,7 @@ def core_VonMisesRV(op, node): https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925 """ - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, mu, kappa): if np.isnan(kappa): return np.nan @@ -310,7 +312,7 @@ def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node): if op.has_p_param: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, a, p, core_shape): # Adapted from Numpy: https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L922-L941 size = np.prod(core_shape) @@ -363,7 +365,7 @@ def random_fn(rng, a, p, core_shape): else: - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def random_fn(rng, a, core_shape): # Until Numba supports generator.choice we use a poor implementation # that 
permutates the whole arange array and takes the first `size` elements @@ -407,7 +409,13 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs core_shape_len = get_vector_length(core_shape) inplace = rv_op.inplace - core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) + core_rv_fn_and_key = numba_core_rv_funcify(rv_op, rv_node) + if isinstance(core_rv_fn_and_key, tuple): + core_rv_fn, core_rv_key = core_rv_fn_and_key + else: + # Assume we can cache core_op_fn + core_rv_fn, core_rv_key = core_rv_fn_and_key, 0 + nin = 1 + len(dist_params) # rng + params core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1) @@ -447,4 +455,20 @@ def random(core_shape, rng, size, *dist_params): def ov_random(core_shape, rng, size, *dist_params): return random_wrapper - return random + if core_rv_key is None: + random_rv_key = None + else: + random_rv_key = "_".join( + map( + str, + ( + type(op), + type(rv_op), + tuple(getattr(rv_op, "props_dict", lambda: {})().items()), + core_rv_key, + ), + ) + ) + random_rv_key = sha256(random_rv_key.encode()).hexdigest() + + return random, random_rv_key diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index ada4e8cc36..cbac51db4d 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,17 +2,20 @@ import numpy as np +import pytensor.link.numba.compile from pytensor.compile.ops import TypeCastingOp from pytensor.graph.basic import Variable +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_numba_signature, +) from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( - create_numba_signature, generate_fallback_impl, numba_funcify, ) from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( - compile_function_src, get_name_for_object, unique_name_generator, ) @@ -59,6 +62,7 @@ def 
numba_funcify_ScalarOp(op, node, **kwargs): output_inner_dtype = None # Cython functions might have an additional argument + cython_func = None has_pyx_skip_dispatch = False if scalar_func_path.startswith("scipy.special"): @@ -128,22 +132,20 @@ def {scalar_op_fn_name}({', '.join(input_names)}): return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) """ - scalar_op_fn = compile_function_src( - scalar_op_src, scalar_op_fn_name, {**globals(), **global_env} + scalar_op_fn = compile_and_cache_numba_function_src( + scalar_op_src, + scalar_op_fn_name, + {**globals(), **global_env}, ) - signature = create_numba_signature(node, force_scalar=True) - - return numba_basic.numba_njit( - signature, - # Functions that call a function pointer can't be cached - cache=False, - )(scalar_op_fn) + # Functions that call a function pointer can't be cached + cache_key = None if cython_func else 0 + return pytensor.link.numba.compile.numba_njit(scalar_op_fn), cache_key @numba_funcify.register(Switch) def numba_funcify_Switch(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def switch(condition, x, y): if condition: return x @@ -164,7 +166,7 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: def {binary_op_name}({input_signature}): return {output_expr} """ - nary_fn = compile_function_src(nary_src, binary_op_name, globals()) + nary_fn = compile_and_cache_numba_function_src(nary_src, binary_op_name, globals()) return nary_fn @@ -174,7 +176,7 @@ def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(signature)(nary_add_fn) + return pytensor.link.numba.compile.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Mul) @@ -182,14 +184,14 @@ def numba_funcify_Mul(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) 
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(signature)(nary_add_fn) + return pytensor.link.numba.compile.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Cast) def numba_funcify_Cast(op, node, **kwargs): dtype = np.dtype(op.o_type.dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def cast(x): return numba_basic.direct_cast(x, dtype) @@ -199,7 +201,7 @@ def cast(x): @numba_funcify.register(Identity) @numba_funcify.register(TypeCastingOp) def numba_funcify_type_casting(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def identity(x): return x @@ -208,7 +210,7 @@ def identity(x): @numba_funcify.register(Clip) def numba_funcify_Clip(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def clip(x, min_val, max_val): x = numba_basic.to_scalar(x) min_scalar = numba_basic.to_scalar(min_val) @@ -230,7 +232,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) - composite_fn = numba_basic.numba_njit(signature)( + composite_fn = pytensor.link.numba.compile.numba_njit(signature)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) return composite_fn @@ -238,7 +240,7 @@ def numba_funcify_Composite(op, node, **kwargs): @numba_funcify.register(Second) def numba_funcify_Second(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def second(x, y): return y @@ -247,7 +249,7 @@ def second(x, y): @numba_funcify.register(Reciprocal) def numba_funcify_Reciprocal(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def reciprocal(x): # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when # `x` is an `int` @@ -258,7 +260,7 @@ def reciprocal(x): @numba_funcify.register(Sigmoid) def numba_funcify_Sigmoid(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def sigmoid(x): return 1 / 
(1 + np.exp(-x)) @@ -267,7 +269,7 @@ def sigmoid(x): @numba_funcify.register(GammaLn) def numba_funcify_GammaLn(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def gammaln(x): return math.lgamma(x) @@ -276,7 +278,7 @@ def gammaln(x): @numba_funcify.register(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -288,7 +290,7 @@ def logp1mexp(x): @numba_funcify.register(Erf) def numba_funcify_Erf(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def erf(x): return math.erf(x) @@ -297,7 +299,7 @@ def erf(x): @numba_funcify.register(Erfc) def numba_funcify_Erfc(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def erfc(x): return math.erfc(x) @@ -308,7 +310,7 @@ def erfc(x): def numba_funcify_Softplus(op, node, **kwargs): out_dtype = np.dtype(node.outputs[0].type.dtype) - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def softplus(x): if x < -37.0: value = np.exp(x) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c75a4cf890..f357cb73fc 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -4,16 +4,20 @@ from numba import types from numba.extending import overload +import pytensor.link.numba.compile from pytensor import In from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.mode import NUMBA, get_mode -from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import ( +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, create_arg_string, create_tuple_string, + numba_njit, +) +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import ( numba_funcify, ) -from 
pytensor.link.utils import compile_function_src from pytensor.scan.op import Scan from pytensor.tensor.type import TensorType @@ -97,7 +101,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ) rewriter(fgraph) - scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) + scan_inner_func = pytensor.link.numba.compile.numba_njit(numba_funcify(op.fgraph)) outer_in_names_to_vars = { (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) @@ -440,6 +444,12 @@ def scan({", ".join(outer_in_names)}): } global_env["np"] = np - scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) + scan_op_fn = compile_and_cache_numba_function_src( + scan_op_src, + "scan", + {**globals(), **global_env}, + # We can't cache until we can hash FunctionGraph + key=None, + ) - return numba_basic.numba_njit(scan_op_fn, boundscheck=False) + return numba_njit(scan_op_fn, boundscheck=False), None diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py new file mode 100644 index 0000000000..44ad42da75 --- /dev/null +++ b/pytensor/link/numba/dispatch/shape.py @@ -0,0 +1,85 @@ +from textwrap import dedent + +import numpy as np +from numba.np.unsafe.ndarray import to_fixed_tuple + +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_arg_string, + numba_njit, +) +from pytensor.link.numba.dispatch.basic import numba_funcify +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.type_other import NoneTypeT + + +@numba_funcify.register(Shape) +def numba_funcify_Shape(op, **kwargs): + @numba_njit + def shape(x): + return np.asarray(np.shape(x)) + + return shape + + +@numba_funcify.register(Shape_i) +def numba_funcify_Shape_i(op, **kwargs): + i = op.i + + @numba_njit + def shape_i(x): + return np.asarray(np.shape(x)[i]) + + return shape_i + + +@numba_funcify.register(Reshape) +def numba_funcify_Reshape(op, **kwargs): + ndim = 
op.ndim + + if ndim == 0: + + @numba_njit + def reshape(x, shape): + return np.asarray(x.item()) + + else: + + @numba_njit + def reshape(x, shape): + # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + return np.reshape( + np.ascontiguousarray(np.asarray(x)), + to_fixed_tuple(shape, ndim), + ) + + return reshape + + +@numba_funcify.register(SpecifyShape) +def numba_funcify_SpecifyShape(op, node, **kwargs): + shape_inputs = node.inputs[1:] + shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] + + func_conditions = [ + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( + zip(shape_inputs, shape_input_names, strict=True) + ) + if not isinstance(node_dim_input.type, NoneTypeT) + ] + + func = dedent( + f""" + def specify_shape(x, {create_arg_string(shape_input_names)}): + {"; ".join(func_conditions)} + return x + """ + ) + + specify_shape = compile_and_cache_numba_function_src( + func, + "specify_shape", + globals(), + ) + return numba_njit(specify_shape) diff --git a/pytensor/link/numba/dispatch/signal/conv.py b/pytensor/link/numba/dispatch/signal/conv.py index 15d1bb29b1..7bb809093b 100644 --- a/pytensor/link/numba/dispatch/signal/conv.py +++ b/pytensor/link/numba/dispatch/signal/conv.py @@ -1,8 +1,8 @@ import numpy as np from numba.np.arraymath import _get_inner_prod +from pytensor.link.numba.compile import numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit from pytensor.tensor.signal.conv import Convolve1d diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 7d1e915298..d04ad8161f 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -3,7 +3,8 @@ import numpy as np from pytensor import config -from 
pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit +from pytensor.link.numba.compile import numba_njit +from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( _lu_1, @@ -87,7 +88,8 @@ def cholesky(a): return res - return cholesky + # We cannot cache LAPACK functions + return cholesky, None @numba_funcify.register(PivotToPermutations) @@ -118,7 +120,7 @@ def numba_funcify_LU(op, node, **kwargs): if dtype in complex_dtypes: NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) - @numba_njit(inline="always") + @numba_njit def lu(a): if check_finite: if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): @@ -153,7 +155,8 @@ def lu(a): return res - return lu + # We cannot cache LAPACK functions + return lu, None @numba_funcify.register(LUFactor) @@ -177,7 +180,8 @@ def lu_factor(a): return LU, piv - return lu_factor + # We cannot cache LAPACK functions + return lu_factor, None @numba_funcify.register(BlockDiagonal) @@ -250,7 +254,8 @@ def solve(a, b): res = solve_fn(a, b, lower, overwrite_a, overwrite_b, check_finite, transposed) return res - return solve + # We cannot cache LAPACK functions + return solve, None @numba_funcify.register(SolveTriangular) @@ -291,7 +296,8 @@ def solve_triangular(a, b): return res - return solve_triangular + # We cannot cache LAPACK functions + return solve_triangular, None @numba_funcify.register(CholeskySolve) @@ -320,7 +326,8 @@ def cho_solve(c, b): c, b, lower=lower, overwrite_b=overwrite_b, check_finite=check_finite ) - return cho_solve + # We cannot cache LAPACK functions + return cho_solve, None @numba_funcify.register(QR) @@ -413,4 +420,5 @@ def qr(a): f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." 
) - return qr + # We cannot cache LAPACK functions + return qr, None diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index fe0eda153e..8f30597887 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,9 +1,10 @@ import numpy as np from pytensor.graph import Type +from pytensor.link.numba.compile import compile_and_cache_numba_function_src, numba_njit from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit -from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.link.numba.dispatch.basic import generate_fallback_impl +from pytensor.link.utils import unique_name_generator from pytensor.tensor import TensorType from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( @@ -95,7 +96,7 @@ def {function_name}({", ".join(input_names)}): return np.asarray(z) """ - func = compile_function_src( + func = compile_and_cache_numba_function_src( subtensor_def_src, function_name=function_name, global_env=globals() | {"np": np}, diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 3a9d8767b9..531b695c23 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -2,9 +2,14 @@ import numpy as np +import pytensor.link.numba.compile +from pytensor.link.numba.compile import ( + compile_and_cache_numba_function_src, + create_tuple_string, +) from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify -from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.link.numba.dispatch.basic import numba_funcify +from pytensor.link.utils import unique_name_generator from pytensor.tensor.basic import ( Alloc, 
AllocEmpty, @@ -17,6 +22,7 @@ Split, TensorFromScalar, ) +from pytensor.utils import hash_from_code @numba_funcify.register(AllocEmpty) @@ -49,11 +55,16 @@ def allocempty({", ".join(shape_var_names)}): return np.empty(scalar_shape, dtype) """ - alloc_fn = compile_function_src( - alloc_def_src, "allocempty", {**globals(), **global_env} + alloc_fn = compile_and_cache_numba_function_src( + alloc_def_src, + "allocempty", + {**globals(), **global_env}, ) - return numba_basic.numba_njit(alloc_fn) + return ( + pytensor.link.numba.compile.numba_njit(alloc_fn), + hash_from_code(alloc_def_src), + ) @numba_funcify.register(Alloc) @@ -93,21 +104,25 @@ def alloc(val, {", ".join(shape_var_names)}): res[...] = val return res """ - alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) + alloc_fn = compile_and_cache_numba_function_src( + alloc_def_src, + "alloc", + {**globals(), **global_env}, + ) - return numba_basic.numba_njit(alloc_fn) + return pytensor.link.numba.compile.numba_njit(alloc_fn) @numba_funcify.register(ARange) def numba_funcify_ARange(op, **kwargs): dtype = np.dtype(op.dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def arange(start, stop, step): return np.arange( - numba_basic.to_scalar(start), - numba_basic.to_scalar(stop), - numba_basic.to_scalar(step), + start.item(), + stop.item(), + step.item(), dtype=dtype, ) @@ -116,7 +131,7 @@ def arange(start, stop, step): @numba_funcify.register(Join) def numba_funcify_Join(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def join(axis, *tensors): return np.concatenate(tensors, axis.item()) @@ -125,7 +140,7 @@ def join(axis, *tensors): @numba_funcify.register(Split) def numba_funcify_Split(op, **kwargs): - @numba_basic.numba_njit + @pytensor.link.numba.compile.numba_njit def split(tensor, axis, indices): return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item()) @@ -139,7 +154,7 @@ def 
numba_funcify_ExtractDiag(op, node, **kwargs): if node.inputs[0].type.ndim == 2: - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def extract_diag(x): out = np.diag(x, k=offset) @@ -154,7 +169,7 @@ def extract_diag(x): leading_dims = (slice(None),) * axis1 middle_dims = (slice(None),) * (axis2 - axis1 - 1) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def extract_diag(x): if offset >= 0: diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset)) @@ -179,7 +194,7 @@ def extract_diag(x): def numba_funcify_Eye(op, **kwargs): dtype = np.dtype(op.dtype) - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit(inline="always") def eye(N, M, k): return np.eye( numba_basic.to_scalar(N), @@ -212,16 +227,19 @@ def makevector({", ".join(input_names)}): return np.array({create_list_string(input_names)}, dtype=dtype) """ - makevector_fn = compile_function_src( - makevector_def_src, "makevector", {**globals(), **global_env} + makevector_fn = compile_and_cache_numba_function_src( + makevector_def_src, + "makevector", + {**globals(), **global_env}, + ) + return pytensor.link.numba.compile.numba_njit(makevector_fn), hash_from_code( + makevector_def_src ) - - return numba_basic.numba_njit(makevector_fn) @numba_funcify.register(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def tensor_from_scalar(x): return np.array(x) @@ -230,8 +248,8 @@ def tensor_from_scalar(x): @numba_funcify.register(ScalarFromTensor) def numba_funcify_ScalarFromTensor(op, **kwargs): - @numba_basic.numba_njit(inline="always") + @pytensor.link.numba.compile.numba_njit def scalar_from_tensor(x): - return numba_basic.to_scalar(x) + return x.item() - return scalar_from_tensor + return scalar_from_tensor, 0 diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py 
b/pytensor/link/numba/dispatch/vectorize_codegen.py index 060418cb6c..f2fe02eff8 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -4,7 +4,7 @@ import pickle from collections.abc import Callable, Sequence from textwrap import indent -from typing import Any, cast +from typing import Any import numba import numpy as np @@ -15,8 +15,8 @@ from numba.core.types.misc import NoneType from numba.np import arrayobj -from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.utils import compile_function_src +import pytensor.link.numba.compile +from pytensor.link.numba.compile import compile_and_cache_numba_function_src def encode_literals(literals: Sequence) -> str: @@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}): {indent(store_outputs, " " * 4)} """ global_env = {"core_op_fn": core_op_fn} - func = compile_function_src( - func_src, "store_core_outputs", {**globals(), **global_env} + + func = compile_and_cache_numba_function_src( + func_src, + "store_core_outputs", + {**globals(), **global_env}, ) - return cast(Callable, numba_basic.numba_njit(func)) + return pytensor.link.numba.compile.numba_njit(func) _jit_options = { @@ -74,7 +77,7 @@ def store_core_outputs({inp_signature}, {out_signature}): @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) def _vectorized( typingctx, - scalar_func, + core_func, input_bc_patterns, output_bc_patterns, output_dtypes, @@ -85,7 +88,7 @@ def _vectorized( size_type, ): arg_types = [ - scalar_func, + core_func, input_bc_patterns, output_bc_patterns, output_dtypes, @@ -173,16 +176,6 @@ def _vectorized( ) out_types[output_idx] = output_type - core_signature = typingctx.resolve_function_type( - scalar_func, - [ - *constant_inputs_types, - *core_input_types, - *core_out_types, - ], - {}, - ) - ret_type = types.Tuple(out_types) if len(output_dtypes) == 1: @@ -239,11 +232,21 @@ def codegen( 
output_core_shapes, ) + core_signature = typingctx.resolve_function_type( + core_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, + ) + make_loop_call( typingctx, ctx, builder, - scalar_func, + core_func, core_signature, iter_shape, constant_inputs, diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 59dc81e1b0..d8e8eb332a 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -4,16 +4,23 @@ class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" + def __init__(self, *args, vm: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.vm = vm + def fgraph_convert(self, fgraph, **kwargs): from pytensor.link.numba.dispatch import numba_funcify - return numba_funcify(fgraph, **kwargs) + return numba_funcify(fgraph, jit_nodes=self.vm, **kwargs) def jit_compile(self, fn): - from pytensor.link.numba.dispatch.basic import numba_njit + if self.vm: + return fn + else: + from pytensor.link.numba.compile import numba_njit - jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) - return jitted_fn + jitted_fn = numba_njit(fn, final_function=True) + return jitted_fn def create_thunk_inputs(self, storage_map): return [storage_map[n] for n in self.fgraph.inputs] diff --git a/pytensor/mod1.py b/pytensor/mod1.py new file mode 100644 index 0000000000..8ba7e5abe2 --- /dev/null +++ b/pytensor/mod1.py @@ -0,0 +1,6 @@ +import numba + + +@numba.extending.register_jitable(cache=True) +def foo(x): + return x + 1 diff --git a/pytensor/utils.py b/pytensor/utils.py index c81fb74f56..137d03f380 100644 --- a/pytensor/utils.py +++ b/pytensor/utils.py @@ -191,9 +191,8 @@ def hash_from_code(msg: str | bytes) -> str: # but Python 3 (unicode) strings don't. if isinstance(msg, str): msg = msg.encode() - # Python 3 does not like module names that start with - # a digit. 
- return "m" + hashlib.sha256(msg).hexdigest() + # Python 3 does not like module names that start with a digit. + return f"m{hashlib.sha256(msg).hexdigest()}" def uniq(seq: Sequence) -> list: diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 90589db337..4aaeb97663 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -12,7 +12,9 @@ from pytensor.compile.io import In, Out from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, explicit_graph_inputs +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting import rewrite_graph from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker @@ -1357,3 +1359,139 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): rng_val = np.random.default_rng() benchmark(f, rng_val) + + +@pytest.fixture(scope="module") +def radon_model(): + def halfnormal(name, *, sigma=1.0, model_logp): + log_value = pt.scalar(f"{name}_log") + value = pt.exp(log_value) + + logp = ( + -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) + ) + logp = pt.switch(value >= 0, logp, -np.inf) + model_logp.append(logp + value) + return value + + def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): + value = pt.scalar(name) if observed is None else pt.as_tensor(observed) + + logp = ( + -0.5 * (((value - mu) / sigma) ** 2) + - pt.log(pt.sqrt(2.0 * np.pi)) + - pt.log(sigma) + ) + model_logp.append(logp) + return value + + def zerosumnormal(name, *, sigma=1.0, size, model_logp): + raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) + n = raw_value.shape[0] + 1 + sum_vals = raw_value.sum(0, keepdims=True) + norm = sum_vals / (pt.sqrt(n) + n) + 
fill_value = norm - sum_vals / pt.sqrt(n) + value = pt.concatenate([raw_value, fill_value]) - norm + + shape = value.shape + _full_size = pt.prod(shape) + _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) + logp = pt.sum( + -0.5 * ((value / sigma) ** 2) + - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) + * (_degrees_of_freedom / _full_size) + ) + model_logp.append(logp) + return value + + rng = np.random.default_rng(1) + n_counties = 85 + county_idx = rng.integers(n_counties, size=919) + county_idx.sort() + floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) + log_radon = rng.normal(size=919) + + model_logp = [] + intercept = normal("intercept", sigma=10, model_logp=model_logp) + + # County effects + county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) + county_sd = halfnormal("county_sd", model_logp=model_logp) + county_effect = county_raw * county_sd + + # Global floor effect + floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) + + county_floor_raw = zerosumnormal( + "county_floor_raw", size=n_counties, model_logp=model_logp + ) + county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) + county_floor_effect = county_floor_raw * county_floor_sd + + mu = ( + intercept + + county_effect[county_idx] + + floor_effect * floor + + county_floor_effect[county_idx] * floor + ) + + sigma = halfnormal("sigma", model_logp=model_logp) + _ = normal( + "log_radon", + mu=mu, + sigma=sigma, + observed=log_radon, + model_logp=model_logp, + ) + + model_logp = pt.sum([logp.sum() for logp in model_logp]) + model_logp = rewrite_graph( + model_logp, include=("canonicalize", "stabilize"), clone=False + ) + params = list(explicit_graph_inputs(model_logp)) + model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) + + size = sum(int(np.prod(p.type.shape)) for p in params) + joined_inputs = pt.vector("joined_inputs", shape=(size,)) + idx = 0 + replacement = {} + for param in params: + 
param_shape = param.type.shape + param_size = int(np.prod(param_shape)) + replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape) + idx += param_size + assert idx == joined_inputs.type.shape[0] + + model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement) + return joined_inputs, [model_logp, model_dlogp] + + +@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA", "NUMBA_VM"]) +def test_radon_model_compile_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + + def compile_and_call_once(): + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True + ) + fn(x) + + benchmark(compile_and_call_once) + + +@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA", "NUMBA_VM"]) +def test_radon_model_call_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + + real_mode = "C_VM" if mode == "C_VM_NOGC" else mode + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True + ) + if mode == "C_VM_NOGC": + fn.vm.allow_gc = False + + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + benchmark(fn, x) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fd9a48111f..da058c60e9 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -8,6 +8,8 @@ import pytest import scipy +import pytensor.link.numba.cache +import pytensor.link.numba.compile from pytensor.compile import SymbolicInput @@ -26,7 +28,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.type import Type from pytensor.ifelse import ifelse -from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.linker import NumbaLinker from 
pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar @@ -324,7 +325,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented): else pytest.raises(NotImplementedError) ) with cm: - res = numba_basic.get_numba_type(v, force_scalar=force_scalar) + res = pytensor.link.numba.compile.get_numba_type(v, force_scalar=force_scalar) assert res == expected @@ -366,7 +367,9 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented): ], ) def test_create_numba_signature(v, expected, force_scalar): - res = numba_basic.create_numba_signature(v, force_scalar=force_scalar) + res = pytensor.link.numba.compile.create_numba_signature( + v, force_scalar=force_scalar + ) assert res == expected @@ -950,3 +953,22 @@ def test_mat_vec_dot_performance(dtype, benchmark): x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) benchmark(fn, A_test, x_test) + + +@pytest.mark.parametrize("use_cache", [False, True], ids=["no-cache", "cache"]) +@pytest.mark.parametrize("func", [pt.cos, pt.sin], ids=["cos", "sin"]) +def test_compile_time_benchmark(func, use_cache, benchmark): + x = pt.matrix("x") + y = func(x) + rng = np.random.default_rng(42) + x_test = rng.normal(size=(5, 3)) + + def compile(): + fn = function([x], y, mode="NUMBA_VM", trust_input=True) + return fn(x_test) + + pytensor.link.numba.cache.NUMBA_PYTENSOR_CACHE_ENABLED = use_cache + + res = compile() + np.testing.assert_allclose(res, y.eval({x: x_test})) + benchmark(compile) diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 625246e340..741021dff5 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -280,7 +280,7 @@ def test_ExtractDiag(val, offset): ) @pytest.mark.parametrize("reverse_axis", (False, True)) def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis): - from 
pytensor.link.numba.dispatch.basic import numba_njit + from pytensor.link.numba.compile import numba_njit if reverse_axis: axis1, axis2 = axis2, axis1 diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 4a78a1e9fe..aabe71b7dd 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1,4 +1,5 @@ import copy +import re import numpy as np import pytest @@ -126,7 +127,7 @@ _specialize_rewrites.position_cutoff = 2.01 _specialize_rewrites = optdb.query(_specialize_rewrites) -_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"], exclude=["inplace"]) _fast_run_rewrites = optdb.query(_fast_run_rewrites) @@ -306,7 +307,9 @@ def test_inconsistent_shared(self, shape_unsafe): # Error raised by Alloc Op with pytest.raises( ValueError, - match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", + match=re.escape( + "cannot assign slice of shape (3, 7) from input of shape (6, 7)" + ), ): f() @@ -473,6 +476,8 @@ def test_basic(self): x = scalar() y = scalar() f = function([x, y], assert_op(x, eq(x, y)), mode=mode) + f.dprint() + return assert f(1, 1) == 1 with pytest.raises(AssertionError): f(1, 0)