@@ -120,6 +120,33 @@ def is_torch_array(x):
120120 # TODO: Should we reject ndarray subclasses?
121121 return isinstance (x , torch .Tensor )
122122
123+ def is_paddle_array (x ):
124+ """
125+ Return True if `x` is a Paddle tensor.
126+
127+ This function does not import Paddle if it has not already been imported
128+ and is therefore cheap to use.
129+
130+ See Also
131+ --------
132+
133+ array_namespace
134+ is_array_api_obj
135+ is_numpy_array
136+ is_cupy_array
137+ is_dask_array
138+ is_jax_array
139+ is_pydata_sparse_array
140+ """
141+ # Avoid importing paddle if it isn't already
142+ if 'paddle' not in sys .modules :
143+ return False
144+
145+ import paddle
146+
147+ # TODO: Should we reject ndarray subclasses?
148+ return paddle .is_tensor (x )
149+
123150def is_ndonnx_array (x ):
124151 """
125152 Return True if `x` is a ndonnx Array.
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
252279 or is_dask_array (x ) \
253280 or is_jax_array (x ) \
254281 or is_pydata_sparse_array (x ) \
282+ or is_paddle_array (x ) \
255283 or hasattr (x , '__array_namespace__' )
256284
257285def _compat_module_name ():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
319347 return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320348
321349
350+ def is_paddle_namespace (xp ) -> bool :
351+ """
352+ Returns True if `xp` is a Paddle namespace.
353+
354+ This includes both Paddle itself and the version wrapped by array-api-compat.
355+
356+ See Also
357+ --------
358+
359+ array_namespace
360+ is_numpy_namespace
361+ is_cupy_namespace
362+ is_ndonnx_namespace
363+ is_dask_namespace
364+ is_jax_namespace
365+ is_pydata_sparse_namespace
366+ is_array_api_strict_namespace
367+ """
368+ return xp .__name__ in {'paddle' , _compat_module_name () + '.paddle' }
369+
370+
322371def is_ndonnx_namespace (xp ):
323372 """
324373 Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +592,14 @@ def your_function(x, y):
543592 else :
544593 import jax .experimental .array_api as jnp
545594 namespaces .add (jnp )
595+ elif is_paddle_array (x ):
596+ if _use_compat :
597+ _check_api_version (api_version )
598+ from .. import paddle as paddle_namespace
599+ namespaces .add (paddle_namespace )
600+ else :
601+ import paddle
602+ namespaces .add (paddle )
546603 elif is_pydata_sparse_array (x ):
547604 if use_compat is True :
548605 _check_api_version (api_version )
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
660717 return "cpu"
661718 # Return the device of the constituent array
662719 return device (inner )
720+ elif is_paddle_array (x ):
721+ raw_place_str = str (x .place )
722+ if "gpu_pinned" in raw_place_str :
723+ return "cpu"
724+ elif "cpu" in raw_place_str :
725+ return "cpu"
726+ elif "gpu" in raw_place_str :
727+ return "gpu"
728+ raise NotImplementedError (f"Unsupported device { raw_place_str } " )
729+
663730 return x .device
664731
665732# Prevent shadowing, used below
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
709776 raise NotImplementedError
710777 return x .to (device )
711778
779+ def _paddle_to_device (x , device , / , stream = None ):
780+ if stream is not None :
781+ raise NotImplementedError (
782+ "paddle.Tensor.to() do not support stream argument yet"
783+ )
784+ return x .to (device )
785+
786+
712787def to_device (x : Array , device : Device , / , * , stream : Optional [Union [int , Any ]] = None ) -> Array :
713788 """
714789 Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
781856 # In JAX v0.4.31 and older, this import adds to_device method to x.
782857 import jax .experimental .array_api # noqa: F401
783858 return x .to_device (device , stream = stream )
859+ elif is_paddle_array (x ):
860+ return _paddle_to_device (x , device , stream = stream )
784861 elif is_pydata_sparse_array (x ) and device == _device (x ):
785862 # Perform trivial check to return the same array if
786863 # device is same instead of err-ing.
@@ -819,6 +896,8 @@ def size(x):
819896 "is_torch_namespace" ,
820897 "is_ndonnx_array" ,
821898 "is_ndonnx_namespace" ,
899+ "is_paddle_array" ,
900+ "is_paddle_namespace" ,
822901 "is_pydata_sparse_array" ,
823902 "is_pydata_sparse_namespace" ,
824903 "size" ,
0 commit comments