Skip to content

Commit 6d71d57

Browse files
committed
Move TypeCastingOp dispatcher to basic.py
This isn't strictly needed but it's a more intuitive placement
1 parent d12e7ac commit 6d71d57

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.compile import NUMBA
1111
from pytensor.compile.builders import OpFromGraph
1212
from pytensor.compile.function.types import add_supervisor_to_fgraph
13-
from pytensor.compile.ops import DeepCopyOp
13+
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
1414
from pytensor.graph.basic import Apply
1515
from pytensor.graph.fg import FunctionGraph
1616
from pytensor.graph.type import Type
@@ -328,6 +328,15 @@ def opfromgraph(*inputs):
328328
return opfromgraph
329329

330330

331+
@numba_funcify.register(TypeCastingOp)
332+
def numba_funcify_type_casting(op, **kwargs):
333+
@numba_njit
334+
def identity(x):
335+
return x
336+
337+
return identity
338+
339+
331340
@numba_funcify.register(DeepCopyOp)
332341
def numba_funcify_DeepCopyOp(op, node, **kwargs):
333342
if isinstance(node.inputs[0].type, TensorType):

pytensor/link/numba/dispatch/scalar.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44

5-
from pytensor.compile.ops import TypeCastingOp
65
from pytensor.graph.basic import Variable
76
from pytensor.link.numba.dispatch import basic as numba_basic
87
from pytensor.link.numba.dispatch.basic import (
@@ -197,7 +196,6 @@ def cast(x):
197196

198197

199198
@numba_funcify.register(Identity)
200-
@numba_funcify.register(TypeCastingOp)
201199
def numba_funcify_type_casting(op, **kwargs):
202200
@numba_basic.numba_njit
203201
def identity(x):

0 commit comments

Comments
 (0)