Skip to content

Commit 8b9c772

Browse files
committed
Implement Numba Op dispatch cache key
1 parent d604423 commit 8b9c772

File tree

17 files changed

+448
-276
lines changed

17 files changed

+448
-276
lines changed

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
from hashlib import sha256
12
from typing import cast
23

34
from numba.core.extending import overload
45
from numba.np.unsafe.ndarray import to_fixed_tuple
56

6-
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
7+
from pytensor.link.numba.cache import compile_numba_function_src
8+
from pytensor.link.numba.dispatch.basic import (
9+
numba_funcify_and_cache_key,
10+
numba_njit,
11+
register_funcify_and_cache_key,
12+
)
713
from pytensor.link.numba.dispatch.vectorize_codegen import (
814
_jit_options,
915
_vectorized,
1016
encode_literals,
1117
store_core_outputs,
1218
)
13-
from pytensor.link.utils import compile_function_src
1419
from pytensor.tensor import TensorVariable, get_vector_length
1520
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
1621

1722

18-
@numba_funcify.register(BlockwiseWithCoreShape)
23+
@register_funcify_and_cache_key(BlockwiseWithCoreShape)
1924
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
2025
[blockwise_node] = op.fgraph.apply_nodes
2126
blockwise_op: Blockwise = blockwise_node.op
@@ -28,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
2833
cast(tuple[TensorVariable], node.inputs[:nin]),
2934
propagate_unbatched_core_inputs=True,
3035
)
31-
core_op_fn = numba_funcify(
36+
core_op_fn, core_op_key = numba_funcify_and_cache_key(
3237
core_op,
3338
node=core_node,
3439
parent_node=node,
@@ -56,7 +61,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
5661
src += ")"
5762

5863
to_tuple = numba_njit(
59-
compile_function_src(
64+
compile_numba_function_src(
6065
src,
6166
"to_tuple",
6267
global_env={"to_fixed_tuple": to_fixed_tuple},
@@ -85,9 +90,27 @@ def blockwise(*inputs_and_core_shapes):
8590
def ov_blockwise(*inputs_and_core_shapes):
8691
return blockwise_wrapper
8792

88-
# The outer caller won't create a wrapper for our overloaded function, so we do it here
89-
@numba_njit
90-
def blockwise_wrapped(*args):
91-
return blockwise(*args)
93+
if core_op_key is None:
94+
# We were told the core op cannot be cached
95+
# The outer caller won't create a wrapper for our overloaded function, so we do it here
96+
@numba_njit
97+
def blockwise_wrapped(*args):
98+
return blockwise(*args)
9299

93-
return blockwise_wrapped
100+
return blockwise_wrapped, None
101+
else:
102+
blockwise_key = "_".join(
103+
map(
104+
str,
105+
(
106+
type(op),
107+
type(blockwise_op),
108+
tuple(blockwise_op.destroy_map.items()),
109+
blockwise_op.signature,
110+
input_bc_patterns,
111+
core_op_key,
112+
),
113+
)
114+
)
115+
blockwise_key = sha256(blockwise_key.encode()).hexdigest()
116+
return blockwise, blockwise_key

pytensor/link/numba/dispatch/compile_ops.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from hashlib import sha256
2+
13
import numpy as np
24

35
from pytensor.compile.builders import OpFromGraph
@@ -8,14 +10,16 @@
810
from pytensor.ifelse import IfElse
911
from pytensor.link.numba.dispatch import basic as numba_basic
1012
from pytensor.link.numba.dispatch.basic import (
11-
numba_funcify,
13+
numba_funcify_and_cache_key,
1214
numba_njit,
15+
register_funcify_and_cache_key,
16+
register_funcify_default_op_cache_key,
1317
)
1418
from pytensor.raise_op import CheckAndRaise
1519
from pytensor.tensor.type import TensorType
1620

1721

18-
@numba_funcify.register(OpFromGraph)
22+
@register_funcify_and_cache_key(OpFromGraph)
1923
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
2024
_ = kwargs.pop("storage_map", None)
2125

@@ -30,7 +34,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
3034
accept_inplace=True,
3135
)
3236
NUMBA.optimizer(fgraph)
33-
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
37+
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(op.fgraph, **kwargs)
3438

3539
if len(op.fgraph.outputs) == 1:
3640

@@ -44,10 +48,23 @@ def opfromgraph(*inputs):
4448
def opfromgraph(*inputs):
4549
return fgraph_fn(*inputs)
4650

47-
return opfromgraph
51+
if fgraph_cache_key is None:
52+
# Can't cache the inner graph
53+
ofg_cache_key = None
54+
else:
55+
ofg_cache_key = sha256(
56+
str(
57+
(
58+
type(op),
59+
fgraph_cache_key,
60+
)
61+
).encode()
62+
).hexdigest()
63+
64+
return opfromgraph, ofg_cache_key
4865

4966

50-
@numba_funcify.register(TypeCastingOp)
67+
@register_funcify_default_op_cache_key(TypeCastingOp)
5168
def numba_funcify_type_casting(op, **kwargs):
5269
@numba_njit
5370
def identity(x):
@@ -56,7 +73,7 @@ def identity(x):
5673
return identity
5774

5875

59-
@numba_funcify.register(DeepCopyOp)
76+
@register_funcify_default_op_cache_key(DeepCopyOp)
6077
def numba_funcify_DeepCopyOp(op, node, **kwargs):
6178
if isinstance(node.inputs[0].type, TensorType):
6279

@@ -73,7 +90,7 @@ def deepcopy(x):
7390
return deepcopy
7491

7592

76-
@numba_funcify.register(IfElse)
93+
@register_funcify_default_op_cache_key(IfElse)
7794
def numba_funcify_IfElse(op, **kwargs):
7895
n_outs = op.n_outs
7996

@@ -102,7 +119,7 @@ def ifelse(cond, *args):
102119
return ifelse
103120

104121

105-
@numba_funcify.register(CheckAndRaise)
122+
@register_funcify_and_cache_key(CheckAndRaise)
106123
def numba_funcify_CheckAndRaise(op, node, **kwargs):
107124
error = op.exc_type
108125
msg = op.msg
@@ -114,4 +131,5 @@ def check_and_raise(x, *conditions):
114131
raise error(msg)
115132
return x
116133

117-
return check_and_raise
134+
cache_key = sha256(str((type(op), error, msg)).encode()).hexdigest()
135+
return check_and_raise, cache_key

0 commit comments

Comments (0)