File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed
pytensor/link/numba/dispatch Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change 1010from pytensor .compile import NUMBA
1111from pytensor .compile .builders import OpFromGraph
1212from pytensor .compile .function .types import add_supervisor_to_fgraph
13- from pytensor .compile .ops import DeepCopyOp
13+ from pytensor .compile .ops import DeepCopyOp , TypeCastingOp
1414from pytensor .graph .basic import Apply
1515from pytensor .graph .fg import FunctionGraph
1616from pytensor .graph .type import Type
@@ -328,6 +328,15 @@ def opfromgraph(*inputs):
328328 return opfromgraph
329329
330330
331+ @numba_funcify .register (TypeCastingOp )
332+ def numba_funcify_type_casting (op , ** kwargs ):
333+ @numba_njit
334+ def identity (x ):
335+ return x
336+
337+ return identity
338+
339+
331340@numba_funcify .register (DeepCopyOp )
332341def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
333342 if isinstance (node .inputs [0 ].type , TensorType ):
Original file line number Diff line number Diff line change 22
33import numpy as np
44
5- from pytensor .compile .ops import TypeCastingOp
65from pytensor .graph .basic import Variable
76from pytensor .link .numba .dispatch import basic as numba_basic
87from pytensor .link .numba .dispatch .basic import (
@@ -197,7 +196,6 @@ def cast(x):
197196
198197
199198@numba_funcify .register (Identity )
200- @numba_funcify .register (TypeCastingOp )
201199def numba_funcify_type_casting (op , ** kwargs ):
202200 @numba_basic .numba_njit
203201 def identity (x ):
You can’t perform that action at this time.
0 commit comments