@@ -48,6 +48,7 @@ def is_numpy_array(x):
4848 is_array_api_obj
4949 is_cupy_array
5050 is_torch_array
51+ is_ndonnx_array
5152 is_dask_array
5253 is_jax_array
5354 is_pydata_sparse_array
@@ -78,11 +79,12 @@ def is_cupy_array(x):
7879 is_array_api_obj
7980 is_numpy_array
8081 is_torch_array
82+ is_ndonnx_array
8183 is_dask_array
8284 is_jax_array
8385 is_pydata_sparse_array
8486 """
85- # Avoid importing NumPy if it isn't already
87+ # Avoid importing CuPy if it isn't already
8688 if 'cupy' not in sys .modules :
8789 return False
8890
@@ -118,6 +120,33 @@ def is_torch_array(x):
118120 # TODO: Should we reject ndarray subclasses?
119121 return isinstance (x , torch .Tensor )
120122
123+ def is_ndonnx_array (x ):
124+ """
125+ Return True if `x` is a ndonnx Array.
126+
127+ This function does not import ndonnx if it has not already been imported
128+ and is therefore cheap to use.
129+
130+ See Also
131+ --------
132+
133+ array_namespace
134+ is_array_api_obj
135+ is_numpy_array
136+ is_cupy_array
137+ is_ndonnx_array
138+ is_dask_array
139+ is_jax_array
140+ is_pydata_sparse_array
141+ """
142+ # Avoid importing torch if it isn't already
143+ if 'ndonnx' not in sys .modules :
144+ return False
145+
146+ import ndonnx as ndx
147+
148+ return isinstance (x , ndx .Array )
149+
121150def is_dask_array (x ):
122151 """
123152 Return True if `x` is a dask.array Array.
@@ -133,6 +162,7 @@ def is_dask_array(x):
133162 is_numpy_array
134163 is_cupy_array
135164 is_torch_array
165+ is_ndonnx_array
136166 is_jax_array
137167 is_pydata_sparse_array
138168 """
@@ -160,6 +190,7 @@ def is_jax_array(x):
160190 is_numpy_array
161191 is_cupy_array
162192 is_torch_array
193+ is_ndonnx_array
163194 is_dask_array
164195 is_pydata_sparse_array
165196 """
@@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool:
188219 is_numpy_array
189220 is_cupy_array
190221 is_torch_array
222+ is_ndonnx_array
191223 is_dask_array
192224 is_jax_array
193225 """
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211243 is_numpy_array
212244 is_cupy_array
213245 is_torch_array
246+ is_ndonnx_array
214247 is_dask_array
215248 is_jax_array
216249 """
@@ -613,6 +646,7 @@ def size(x):
613646 "is_jax_array" ,
614647 "is_numpy_array" ,
615648 "is_torch_array" ,
649+ "is_ndonnx_array" ,
616650 "is_pydata_sparse_array" ,
617651 "size" ,
618652 "to_device" ,
0 commit comments