5050 "jax" : JAXLinker (),
5151 "pytorch" : PytorchLinker (),
5252 "numba" : NumbaLinker (),
53+ "numba_vm" : NumbaLinker (),
5354}
5455
5556
@@ -63,9 +64,8 @@ def register_linker(name, linker):
6364# If a string is passed as the optimizer argument in the constructor
6465# for Mode, it will be used as the key to retrieve the real optimizer
6566# in this dictionary
66- exclude = []
67- if not config .cxx :
68- exclude = ["cxx_only" ]
67+
68+ exclude = ["cxx_only" , "BlasOpt" ]
6969OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
7070# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
7171OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ], exclude = exclude )
@@ -351,6 +351,11 @@ def __setstate__(self, state):
351351 optimizer = predefined_optimizers [optimizer ]
352352 if isinstance (optimizer , RewriteDatabaseQuery ):
353353 self .provided_optimizer = optimizer
354+
355+ # Force numba-required rewrites if using NumbaLinker
356+ if isinstance (linker , NumbaLinker ):
357+ optimizer = optimizer .including ("numba" )
358+
354359 self ._optimizer = optimizer
355360 self .call_time = 0
356361 self .fn_time = 0
@@ -448,16 +453,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
448453# string as the key
449454# Use VM_linker to allow lazy evaluation by default.
450455FAST_COMPILE = Mode (
451- VMLinker (use_cloop = False , c_thunks = False ),
452- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
456+ NumbaLinker (),
457+ # TODO: Fast_compile should just use python code, CHANGE ME!
458+ RewriteDatabaseQuery (
459+ include = ["fast_compile" , "numba" ],
460+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
461+ ),
462+ )
463+ FAST_RUN = Mode (
464+ NumbaLinker (),
465+ RewriteDatabaseQuery (
466+ include = ["fast_run" , "numba" ],
467+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
468+ ),
453469)
454- if config .cxx :
455- FAST_RUN = Mode ("cvm" , "fast_run" )
456- else :
457- FAST_RUN = Mode (
458- "vm" ,
459- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
460- )
461470
462471NUMBA = Mode (
463472 NumbaLinker (),
@@ -574,6 +583,7 @@ def register_mode(name, mode):
574583 Add a `Mode` which can be referred to by `name` in `function`.
575584
576585 """
586+ # TODO: Remove me
577587 if name in predefined_modes :
578588 raise ValueError (f"Mode name already taken: { name } " )
579589 predefined_modes [name ] = mode
0 commit comments