1212import math
1313import sys
1414import warnings
15- from collections .abc import Collection
15+ from collections .abc import Collection , Hashable
16+ from functools import lru_cache
1617from typing import (
1718 TYPE_CHECKING ,
1819 Any ,
6162_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
6263
6364
65+ @lru_cache (100 )
66+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
67+ try :
68+ mod = sys .modules [modname ]
69+ except KeyError :
70+ return False
71+ parent_cls = getattr (mod , clsname )
72+ return issubclass (cls , parent_cls )
73+
74+
6475def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
6576 """Return True if `x` is a zero-gradient array.
6677
6778 These arrays are a design quirk of Jax that may one day be removed.
6879 See https://github.com/google/jax/issues/20620.
6980 """
70- if "numpy" not in sys .modules or "jax" not in sys .modules :
81+ # Fast exit
82+ try :
83+ dtype = x .dtype # type: ignore[attr-defined]
84+ except AttributeError :
85+ return False
86+ cls = cast (Hashable , type (dtype ))
87+ if not _issubclass_fast (cls , "numpy.dtypes" , "VoidDType" ):
7188 return False
7289
73- import jax
74- import numpy as np
90+ if " jax" not in sys . modules :
91+ return False
7592
76- jax_float0 = cast ("np.dtype[np.void]" , jax .float0 )
77- return (
78- isinstance (x , np .ndarray )
79- and cast ("npt.NDArray[np.void]" , x ).dtype == jax_float0
80- )
93+ import jax
94+ # jax.float0 is a np.dtype([('float0', 'V')])
95+ return dtype == jax .float0
8196
8297
8398def is_numpy_array (x : object ) -> TypeGuard [npt .NDArray [Any ]]:
@@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
101116 is_jax_array
102117 is_pydata_sparse_array
103118 """
104- # Avoid importing NumPy if it isn't already
105- if "numpy" not in sys .modules :
106- return False
107-
108- import numpy as np
109-
110119 # TODO: Should we reject ndarray subclasses?
111- return (isinstance (x , (np .ndarray , np .generic ))
112- and not _is_jax_zero_gradient_array (x )) # pyright: ignore[reportUnknownArgumentType] # fmt: skip
120+ cls = cast (Hashable , type (x ))
121+ return (
122+ _issubclass_fast (cls , "numpy" , "ndarray" )
123+ or _issubclass_fast (cls , "numpy" , "generic" )
124+ ) and not _is_jax_zero_gradient_array (x )
113125
114126
115127def is_cupy_array (x : object ) -> bool :
@@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool:
133145 is_jax_array
134146 is_pydata_sparse_array
135147 """
136- # Avoid importing CuPy if it isn't already
137- if "cupy" not in sys .modules :
138- return False
139-
140- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
141-
142- # TODO: Should we reject ndarray subclasses?
143- return isinstance (x , cp .ndarray ) # pyright: ignore[reportUnknownMemberType]
148+ cls = cast (Hashable , type (x ))
149+ return _issubclass_fast (cls , "cupy" , "ndarray" )
144150
145151
146152def is_torch_array (x : object ) -> TypeIs [torch .Tensor ]:
@@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
161167 is_jax_array
162168 is_pydata_sparse_array
163169 """
164- # Avoid importing torch if it isn't already
165- if "torch" not in sys .modules :
166- return False
167-
168- import torch
169-
170- # TODO: Should we reject ndarray subclasses?
171- return isinstance (x , torch .Tensor )
170+ cls = cast (Hashable , type (x ))
171+ return _issubclass_fast (cls , "torch" , "Tensor" )
172172
173173
174174def is_ndonnx_array (x : object ) -> TypeIs [ndx .Array ]:
@@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
190190 is_jax_array
191191 is_pydata_sparse_array
192192 """
193- # Avoid importing torch if it isn't already
194- if "ndonnx" not in sys .modules :
195- return False
196-
197- import ndonnx as ndx
198-
199- return isinstance (x , ndx .Array )
193+ cls = cast (Hashable , type (x ))
194+ return _issubclass_fast (cls , "ndonnx" , "Array" )
200195
201196
202197def is_dask_array (x : object ) -> TypeIs [da .Array ]:
@@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
218213 is_jax_array
219214 is_pydata_sparse_array
220215 """
221- # Avoid importing dask if it isn't already
222- if "dask.array" not in sys .modules :
223- return False
224-
225- import dask .array
226-
227- return isinstance (x , dask .array .Array )
216+ cls = cast (Hashable , type (x ))
217+ return _issubclass_fast (cls , "dask.array" , "Array" )
228218
229219
230220def is_jax_array (x : object ) -> TypeIs [jax .Array ]:
@@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
247237 is_dask_array
248238 is_pydata_sparse_array
249239 """
250- # Avoid importing jax if it isn't already
251- if "jax" not in sys .modules :
252- return False
253-
254- import jax
255-
256- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
240+ cls = cast (Hashable , type (x ))
241+ return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
257242
258243
259244def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
276261 is_dask_array
277262 is_jax_array
278263 """
279- # Avoid importing jax if it isn't already
280- if "sparse" not in sys .modules :
281- return False
282-
283- import sparse # pyright: ignore[reportMissingTypeStubs]
284-
285264 # TODO: Account for other backends.
286- return isinstance (x , sparse .SparseArray )
265+ cls = cast (Hashable , type (x ))
266+ return _issubclass_fast (cls , "sparse" , "SparseArray" )
287267
288268
289269def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
@@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
302282 is_jax_array
303283 """
304284 return (
305- is_numpy_array (x )
306- or is_cupy_array (x )
307- or is_torch_array (x )
308- or is_dask_array (x )
309- or is_jax_array (x )
310- or is_pydata_sparse_array (x )
311- or hasattr (x , "__array_namespace__" )
285+ hasattr (x , '__array_namespace__' )
286+ or _is_array_api_cls (cast (Hashable , type (x )))
287+ )
288+
289+
290+ @lru_cache (100 )
291+ def _is_array_api_cls (cls : type ) -> bool :
292+ return (
293+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
294+ _issubclass_fast (cls , "numpy" , "ndarray" )
295+ or _issubclass_fast (cls , "numpy" , "generic" )
296+ or _issubclass_fast (cls , "cupy" , "ndarray" )
297+ or _issubclass_fast (cls , "torch" , "Tensor" )
298+ or _issubclass_fast (cls , "dask.array" , "Array" )
299+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
300+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
301+ or _issubclass_fast (cls , "jax" , "Array" )
312302 )
313303
314304
@@ -317,6 +307,7 @@ def _compat_module_name() -> str:
317307 return __name__ .removesuffix (".common._helpers" )
318308
319309
310+ @lru_cache (100 )
320311def is_numpy_namespace (xp : Namespace ) -> bool :
321312 """
322313 Returns True if `xp` is a NumPy namespace.
@@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
338329 return xp .__name__ in {"numpy" , _compat_module_name () + ".numpy" }
339330
340331
332+ @lru_cache (100 )
341333def is_cupy_namespace (xp : Namespace ) -> bool :
342334 """
343335 Returns True if `xp` is a CuPy namespace.
@@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
359351 return xp .__name__ in {"cupy" , _compat_module_name () + ".cupy" }
360352
361353
354+ @lru_cache (100 )
362355def is_torch_namespace (xp : Namespace ) -> bool :
363356 """
364357 Returns True if `xp` is a PyTorch namespace.
@@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
399392 return xp .__name__ == "ndonnx"
400393
401394
395+ @lru_cache (100 )
402396def is_dask_namespace (xp : Namespace ) -> bool :
403397 """
404398 Returns True if `xp` is a Dask namespace.
@@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
939933 return None if math .isnan (out ) else out
940934
941935
936+ @lru_cache (100 )
937+ def _is_writeable_cls (cls : type ) -> bool | None :
938+ if (
939+ _issubclass_fast (cls , "numpy" , "generic" )
940+ or _issubclass_fast (cls , "jax" , "Array" )
941+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
942+ ):
943+ return False
944+ if _is_array_api_cls (cls ):
945+ return True
946+ return None
947+
948+
942949def is_writeable_array (x : object ) -> bool :
943950 """
944951 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool:
949956 As there is no standard way to check if an array is writeable without actually
950957 writing to it, this function blindly returns True for all unknown array types.
951958 """
952- if is_numpy_array (x ):
953- return x .flags .writeable
954- if is_jax_array (x ) or is_pydata_sparse_array (x ):
959+ cls = cast (Hashable , type (x ))
960+ if _issubclass_fast (cls , "numpy" , "ndarray" ):
961+ return cast ("npt.NDArray" , x ).flags .writeable
962+ res = _is_writeable_cls (cls )
963+ if res is not None :
964+ return res
965+ return hasattr (x , '__array_namespace__' )
966+
967+
968+ @lru_cache (100 )
969+ def _is_lazy_cls (cls : type ) -> bool | None :
970+ if (
971+ _issubclass_fast (cls , "numpy" , "ndarray" )
972+ or _issubclass_fast (cls , "numpy" , "generic" )
973+ or _issubclass_fast (cls , "cupy" , "ndarray" )
974+ or _issubclass_fast (cls , "torch" , "Tensor" )
975+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
976+ ):
955977 return False
956- return is_array_api_obj (x )
978+ if (
979+ _issubclass_fast (cls , "jax" , "Array" )
980+ or _issubclass_fast (cls , "dask.array" , "Array" )
981+ or _issubclass_fast (cls , "ndonnx" , "Array" )
982+ ):
983+ return True
984+ return None
957985
958986
959987def is_lazy_array (x : object ) -> bool :
@@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool:
969997 This function errs on the side of caution for array types that may or may not be
970998 lazy, e.g. JAX arrays, by always returning True for them.
971999 """
972- if (
973- is_numpy_array (x )
974- or is_cupy_array (x )
975- or is_torch_array (x )
976- or is_pydata_sparse_array (x )
977- ):
978- return False
979-
9801000 # **JAX note:** while it is possible to determine if you're inside or outside
9811001 # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
9821002 # as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool:
9861006 # compatibility, is highly detrimental to performance as the whole graph will end
9871007 # up being computed multiple times.
9881008
989- if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
990- return True
1009+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
1010+ # exclusively get them once they leave a jax.grad JIT context.
1011+ cls = cast (Hashable , type (x ))
1012+ res = _is_lazy_cls (cls )
1013+ if res is not None :
1014+ return res
9911015
992- if not is_array_api_obj ( x ):
1016+ if not hasattr ( x , "__array_namespace__" ):
9931017 return False
9941018
9951019 # Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool:
10421066 "to_device" ,
10431067]
10441068
1045- _all_ignore = [" sys" , " math" , " inspect" , " warnings" ]
1069+ _all_ignore = ['lru_cache' , ' sys' , ' math' , ' inspect' , ' warnings' ]
10461070
10471071def __dir__ () -> list [str ]:
10481072 return __all__
0 commit comments