1+ from hashlib import sha256
2+
13import numpy as np
24
35from pytensor .compile .builders import OpFromGraph
810from pytensor .ifelse import IfElse
911from pytensor .link .numba .dispatch import basic as numba_basic
1012from pytensor .link .numba .dispatch .basic import (
11- numba_funcify ,
13+ numba_funcify_and_cache_key ,
1214 numba_njit ,
15+ register_funcify_and_cache_key ,
16+ register_funcify_default_op_cache_key ,
1317)
1418from pytensor .raise_op import CheckAndRaise
1519from pytensor .tensor .type import TensorType
1620
1721
18- @numba_funcify . register (OpFromGraph )
22+ @register_funcify_and_cache_key (OpFromGraph )
1923def numba_funcify_OpFromGraph (op , node = None , ** kwargs ):
2024 _ = kwargs .pop ("storage_map" , None )
2125
@@ -30,7 +34,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
3034 accept_inplace = True ,
3135 )
3236 NUMBA .optimizer (fgraph )
33- fgraph_fn = numba_njit ( numba_funcify ( op .fgraph , ** kwargs ) )
37+ fgraph_fn , fgraph_cache_key = numba_funcify_and_cache_key ( op .fgraph , ** kwargs )
3438
3539 if len (op .fgraph .outputs ) == 1 :
3640
@@ -44,10 +48,23 @@ def opfromgraph(*inputs):
4448 def opfromgraph (* inputs ):
4549 return fgraph_fn (* inputs )
4650
47- return opfromgraph
51+ if fgraph_cache_key is None :
52+ # Can't cache the inner graph
53+ ofg_cache_key = None
54+ else :
55+ ofg_cache_key = sha256 (
56+ str (
57+ (
58+ type (op ),
59+ fgraph_cache_key ,
60+ )
61+ ).encode ()
62+ ).hexdigest ()
63+
64+ return opfromgraph , ofg_cache_key
4865
4966
50- @numba_funcify . register (TypeCastingOp )
67+ @register_funcify_default_op_cache_key (TypeCastingOp )
5168def numba_funcify_type_casting (op , ** kwargs ):
5269 @numba_njit
5370 def identity (x ):
@@ -56,7 +73,7 @@ def identity(x):
5673 return identity
5774
5875
59- @numba_funcify . register (DeepCopyOp )
76+ @register_funcify_default_op_cache_key (DeepCopyOp )
6077def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
6178 if isinstance (node .inputs [0 ].type , TensorType ):
6279
@@ -73,7 +90,7 @@ def deepcopy(x):
7390 return deepcopy
7491
7592
76- @numba_funcify . register (IfElse )
93+ @register_funcify_default_op_cache_key (IfElse )
7794def numba_funcify_IfElse (op , ** kwargs ):
7895 n_outs = op .n_outs
7996
@@ -102,7 +119,7 @@ def ifelse(cond, *args):
102119 return ifelse
103120
104121
105- @numba_funcify . register (CheckAndRaise )
122+ @register_funcify_and_cache_key (CheckAndRaise )
106123def numba_funcify_CheckAndRaise (op , node , ** kwargs ):
107124 error = op .exc_type
108125 msg = op .msg
@@ -114,4 +131,5 @@ def check_and_raise(x, *conditions):
114131 raise error (msg )
115132 return x
116133
117- return check_and_raise
134+ cache_key = sha256 (str ((type (op ), error , msg )).encode ()).hexdigest ()
135+ return check_and_raise , cache_key
0 commit comments