Skip to content

Commit fb60edb

Browse files
committed
Make Numba the default backend
1 parent aefe989 commit fb60edb

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

pytensor/compile/mode.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@
5050
"jax": JAXLinker(),
5151
"pytorch": PytorchLinker(),
5252
"numba": NumbaLinker(),
53+
"numba_vm": NumbaLinker(),
5354
}
5455

5556

@@ -63,9 +64,8 @@ def register_linker(name, linker):
6364
# If a string is passed as the optimizer argument in the constructor
6465
# for Mode, it will be used as the key to retrieve the real optimizer
6566
# in this dictionary
66-
exclude = []
67-
if not config.cxx:
68-
exclude = ["cxx_only"]
67+
68+
exclude = ["cxx_only", "BlasOpt"]
6969
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
7070
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
7171
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
@@ -351,6 +351,11 @@ def __setstate__(self, state):
351351
optimizer = predefined_optimizers[optimizer]
352352
if isinstance(optimizer, RewriteDatabaseQuery):
353353
self.provided_optimizer = optimizer
354+
355+
# Force numba-required rewrites if using NumbaLinker
356+
if isinstance(linker, NumbaLinker):
357+
optimizer = optimizer.including("numba")
358+
354359
self._optimizer = optimizer
355360
self.call_time = 0
356361
self.fn_time = 0
@@ -448,16 +453,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
448453
# string as the key
449454
# Use VM_linker to allow lazy evaluation by default.
450455
FAST_COMPILE = Mode(
451-
VMLinker(use_cloop=False, c_thunks=False),
452-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
456+
NumbaLinker(),
457+
# TODO: Fast_compile should just use python code, CHANGE ME!
458+
RewriteDatabaseQuery(
459+
include=["fast_compile", "numba"],
460+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
461+
),
462+
)
463+
FAST_RUN = Mode(
464+
NumbaLinker(),
465+
RewriteDatabaseQuery(
466+
include=["fast_run", "numba"],
467+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
468+
),
453469
)
454-
if config.cxx:
455-
FAST_RUN = Mode("cvm", "fast_run")
456-
else:
457-
FAST_RUN = Mode(
458-
"vm",
459-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
460-
)
461470

462471
NUMBA = Mode(
463472
NumbaLinker(),
@@ -574,6 +583,7 @@ def register_mode(name, mode):
574583
Add a `Mode` which can be referred to by `name` in `function`.
575584
576585
"""
586+
# TODO: Remove me
577587
if name in predefined_modes:
578588
raise ValueError(f"Mode name already taken: {name}")
579589
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -370,11 +370,21 @@ def add_compile_configvars():
370370

371371
if rc == 0 and config.cxx != "":
372372
# Keep the default linker the same as the one for the mode FAST_RUN
373-
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
373+
linker_options = [
374+
"cvm",
375+
"c|py",
376+
"py",
377+
"c",
378+
"c|py_nogc",
379+
"vm",
380+
"vm_nogc",
381+
"cvm_nogc",
382+
"jax",
383+
]
374384
else:
375385
# g++ is not present or the user disabled it,
376386
# linker should default to python only.
377-
linker_options = ["py", "vm_nogc"]
387+
linker_options = ["py", "vm", "vm_nogc", "jax"]
378388
if type(config).cxx.is_default:
379389
# If the user provided an empty value for cxx, do not warn.
380390
_logger.warning(
@@ -388,7 +398,7 @@ def add_compile_configvars():
388398
"linker",
389399
"Default linker used if the pytensor flags mode is Mode",
390400
# Not mutable because the default mode is cached after the first use.
391-
EnumStr("cvm", linker_options, mutable=False),
401+
EnumStr("numba", linker_options, mutable=False),
392402
in_c_key=False,
393403
)
394404

tests/tensor/rewriting/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import re
23

34
import numpy as np
45
import pytest
@@ -306,7 +307,9 @@ def test_inconsistent_shared(self, shape_unsafe):
306307
# Error raised by Alloc Op
307308
with pytest.raises(
308309
ValueError,
309-
match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
310+
match=re.escape(
311+
"cannot assign slice of shape (3, 7) from input of shape (6, 7)"
312+
),
310313
):
311314
f()
312315

@@ -1203,6 +1206,7 @@ def test_sum_bool_upcast(self):
12031206
f(5)
12041207

12051208

1209+
@pytest.mark.xfail(reason="Numba does not support float16")
12061210
class TestLocalOptAllocF16(TestLocalOptAlloc):
12071211
dtype = "float16"
12081212

0 commit comments

Comments (0)