From 8e5cc941152d3f4019f70071c2bbfe12c46669f2 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 12:44:01 +0800
Subject: [PATCH 01/28] add paddle support in array-api-compat
---
.github/workflows/array-api-tests-paddle.yml | 11 +
array_api_compat/common/_helpers.py | 79 ++
array_api_compat/paddle/__init__.py | 28 +
array_api_compat/paddle/_aliases.py | 1153 ++++++++++++++++++
array_api_compat/paddle/_info.py | 373 ++++++
array_api_compat/paddle/fft.py | 92 ++
array_api_compat/paddle/linalg.py | 136 +++
array_api_compat/torch/fft.py | 26 +-
array_api_compat/torch/linalg.py | 76 +-
docs/index.md | 4 +
docs/supported-array-libraries.md | 23 +
requirements-dev.txt | 1 +
tests/_helpers.py | 13 +-
tests/test_array_namespace.py | 76 +-
tests/test_common.py | 28 +-
tests/test_isdtype.py | 2 +-
tests/test_no_dependencies.py | 8 +-
tests/test_vendoring.py | 26 +-
vendor_test/uses_paddle.py | 30 +
19 files changed, 2088 insertions(+), 97 deletions(-)
create mode 100644 .github/workflows/array-api-tests-paddle.yml
create mode 100644 array_api_compat/paddle/__init__.py
create mode 100644 array_api_compat/paddle/_aliases.py
create mode 100644 array_api_compat/paddle/_info.py
create mode 100644 array_api_compat/paddle/fft.py
create mode 100644 array_api_compat/paddle/linalg.py
create mode 100644 vendor_test/uses_paddle.py
diff --git a/.github/workflows/array-api-tests-paddle.yml b/.github/workflows/array-api-tests-paddle.yml
new file mode 100644
index 00000000..d4f88b00
--- /dev/null
+++ b/.github/workflows/array-api-tests-paddle.yml
@@ -0,0 +1,11 @@
+name: Array API Tests (Paddle Latest)
+
+on: [push, pull_request]
+
+jobs:
+ array-api-tests-paddle:
+ uses: ./.github/workflows/array-api-tests.yml
+ with:
+ package-name: paddle
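+      # Paddle does not currently provide uint16/uint32/uint64 tensor dtypes
+      # (an assumption of this patch), so those dtypes are skipped below.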
+ extra-env-vars: |
+ ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index b011f08d..ff2c213f 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -120,6 +120,33 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)
+def is_paddle_array(x):
+ """
+ Return True if `x` is a Paddle tensor.
+
+ This function does not import Paddle if it has not already been imported
+ and is therefore cheap to use.
+
+ See Also
+ --------
+
+ array_namespace
+ is_array_api_obj
+ is_numpy_array
+ is_cupy_array
+ is_dask_array
+ is_jax_array
+ is_pydata_sparse_array
+ """
+ # Avoid importing paddle if it isn't already
+ if 'paddle' not in sys.modules:
+ return False
+
+ import paddle
+
+ # TODO: Should we reject ndarray subclasses?
+ return paddle.is_tensor(x)
+
def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse_array(x) \
+ or is_paddle_array(x) \
or hasattr(x, '__array_namespace__')
def _compat_module_name():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+def is_paddle_namespace(xp) -> bool:
+ """
+ Returns True if `xp` is a Paddle namespace.
+
+ This includes both Paddle itself and the version wrapped by array-api-compat.
+
+ See Also
+ --------
+
+ array_namespace
+ is_numpy_namespace
+ is_cupy_namespace
+ is_ndonnx_namespace
+ is_dask_namespace
+ is_jax_namespace
+ is_pydata_sparse_namespace
+ is_array_api_strict_namespace
+ """
+ return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
+
+
def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +592,14 @@ def your_function(x, y):
else:
import jax.experimental.array_api as jnp
namespaces.add(jnp)
+ elif is_paddle_array(x):
+ if _use_compat:
+ _check_api_version(api_version)
+ from .. import paddle as paddle_namespace
+ namespaces.add(paddle_namespace)
+ else:
+ import paddle
+ namespaces.add(paddle)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner)
+ elif is_paddle_array(x):
+ raw_place_str = str(x.place)
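+        # Sketch of the expected values (an assumption): str(x.place) typically
+        # looks like "Place(cpu)", "Place(gpu:0)" or "Place(gpu_pinned)", which
+        # the checks below map onto "cpu" / "gpu".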
+ if "gpu_pinned" in raw_place_str:
+ return "cpu"
+ elif "cpu" in raw_place_str:
+ return "cpu"
+ elif "gpu" in raw_place_str:
+ return "gpu"
+ raise NotImplementedError(f"Unsupported device {raw_place_str}")
+
return x.device
# Prevent shadowing, used below
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)
+def _paddle_to_device(x, device, /, stream=None):
+ if stream is not None:
+ raise NotImplementedError(
+            "paddle.Tensor.to() does not support the stream argument yet"
+ )
+ return x.to(device)
+
+
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# In JAX v0.4.31 and older, this import adds to_device method to x.
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
+ elif is_paddle_array(x):
+ return _paddle_to_device(x, device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
@@ -819,6 +896,8 @@ def size(x):
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
+ "is_paddle_array",
+ "is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py
new file mode 100644
index 00000000..9f96fa9f
--- /dev/null
+++ b/array_api_compat/paddle/__init__.py
@@ -0,0 +1,28 @@
+from paddle import * # noqa: F403
+
+# Several names are not included in the above import *
+import paddle
+
+for n in dir(paddle):
+ if (
+ n.startswith("_")
+ or n.endswith("_")
+ or "gpu" in n
+ or "cpu" in n
+ or "backward" in n
+ ):
+ continue
+    exec(n + " = paddle." + n)
+
+asarray = paddle.to_tensor
+
+# These imports may overwrite names from the import * above.
+from ._aliases import * # noqa: F403
+
+# See the comment in the numpy __init__.py
+__import__(__package__ + ".linalg")
+
+__import__(__package__ + ".fft")
+
+from ..common._helpers import * # noqa: F403
+
+__array_api_version__ = "2023.12"
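+
+# Minimal usage sketch (assumes paddle is installed): array_namespace() on a
+# paddle.Tensor resolves to this compat namespace, e.g.
+#
+#   import paddle
+#   from array_api_compat import array_namespace
+#   xp = array_namespace(paddle.to_tensor([1.0, 2.0]))
+#   assert xp.__array_api_version__ == "2023.12"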
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
new file mode 100644
index 00000000..dabe2928
--- /dev/null
+++ b/array_api_compat/paddle/_aliases.py
@@ -0,0 +1,1153 @@
+from __future__ import annotations
+
+from functools import wraps as _wraps
+from builtins import all as _builtin_all, any as _builtin_any
+
+from ..common._aliases import (
+ matrix_transpose as _aliases_matrix_transpose,
+ vecdot as _aliases_vecdot,
+ clip as _aliases_clip,
+ unstack as _aliases_unstack,
+ cumulative_sum as _aliases_cumulative_sum,
+)
+from .._internal import get_xp
+
+from ._info import __array_namespace_info__
+
+import paddle
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from typing import List, Optional, Sequence, Tuple, Union
+ from ..common._typing import Device
+ from paddle import dtype as Dtype
+
+ array = paddle.Tensor
+
+_int_dtypes = {
+ paddle.uint8,
+ paddle.int8,
+ paddle.int16,
+ paddle.int32,
+ paddle.int64,
+}
+
+_array_api_dtypes = {
+ paddle.bool,
+ *_int_dtypes,
+ paddle.float32,
+ paddle.float64,
+ paddle.complex64,
+ paddle.complex128,
+}
+
+_promotion_table = {
+ # bool
+ (paddle.bool, paddle.bool): paddle.bool,
+ # ints
+ (paddle.int8, paddle.int8): paddle.int8,
+ (paddle.int8, paddle.int16): paddle.int16,
+ (paddle.int8, paddle.int32): paddle.int32,
+ (paddle.int8, paddle.int64): paddle.int64,
+ (paddle.int16, paddle.int8): paddle.int16,
+ (paddle.int16, paddle.int16): paddle.int16,
+ (paddle.int16, paddle.int32): paddle.int32,
+ (paddle.int16, paddle.int64): paddle.int64,
+ (paddle.int32, paddle.int8): paddle.int32,
+ (paddle.int32, paddle.int16): paddle.int32,
+ (paddle.int32, paddle.int32): paddle.int32,
+ (paddle.int32, paddle.int64): paddle.int64,
+ (paddle.int64, paddle.int8): paddle.int64,
+ (paddle.int64, paddle.int16): paddle.int64,
+ (paddle.int64, paddle.int32): paddle.int64,
+ (paddle.int64, paddle.int64): paddle.int64,
+ # uints
+ (paddle.uint8, paddle.uint8): paddle.uint8,
+ # ints and uints (mixed sign)
+ (paddle.int8, paddle.uint8): paddle.int16,
+ (paddle.int16, paddle.uint8): paddle.int16,
+ (paddle.int32, paddle.uint8): paddle.int32,
+ (paddle.int64, paddle.uint8): paddle.int64,
+ (paddle.uint8, paddle.int8): paddle.int16,
+ (paddle.uint8, paddle.int16): paddle.int16,
+ (paddle.uint8, paddle.int32): paddle.int32,
+ (paddle.uint8, paddle.int64): paddle.int64,
+ # floats
+ (paddle.float32, paddle.float32): paddle.float32,
+ (paddle.float32, paddle.float64): paddle.float64,
+ (paddle.float64, paddle.float32): paddle.float64,
+ (paddle.float64, paddle.float64): paddle.float64,
+ # complexes
+ (paddle.complex64, paddle.complex64): paddle.complex64,
+ (paddle.complex64, paddle.complex128): paddle.complex128,
+ (paddle.complex128, paddle.complex64): paddle.complex128,
+ (paddle.complex128, paddle.complex128): paddle.complex128,
+ # Mixed float and complex
+ (paddle.float32, paddle.complex64): paddle.complex64,
+ (paddle.float32, paddle.complex128): paddle.complex128,
+ (paddle.float64, paddle.complex64): paddle.complex128,
+ (paddle.float64, paddle.complex128): paddle.complex128,
+}
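+
+# For example, the table above encodes the array API rule that mixing signed
+# and unsigned 8-bit integers widens to a signed 16-bit integer:
+#   _promotion_table[(paddle.int8, paddle.uint8)] is paddle.int16
+# result_type() below consults this table before falling back to
+# paddle.result_type on empty tensors.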
+
+
+def _two_arg(f):
+ @_wraps(f)
+ def _f(x1, x2, /, **kwargs):
+ x1, x2 = _fix_promotion(x1, x2)
+ return f(x1, x2, **kwargs)
+
+ if _f.__doc__ is None:
+ _f.__doc__ = f"""\
+Array API compatibility wrapper for paddle.{f.__name__}.
+
+See the corresponding Paddle documentation and/or the array API specification
+for more details.
+
+"""
+ return _f
+
+
+def _fix_promotion(x1, x2, only_scalar=True):
+ if not isinstance(x1, paddle.Tensor) or not isinstance(x2, paddle.Tensor):
+ return x1, x2
+ if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
+ return x1, x2
+    # If an argument is 0-D, paddle (like pytorch) downcasts the other argument
+ if not only_scalar or x1.shape == ():
+ dtype = result_type(x1, x2)
+ x2 = x2.to(dtype)
+ if not only_scalar or x2.shape == ():
+ dtype = result_type(x1, x2)
+ x1 = x1.to(dtype)
+ return x1, x2
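+
+# Illustrative sketch: if x1 is a 0-D float64 tensor and x2 is a 1-D float32
+# tensor, result_type() gives float64, so _fix_promotion upcasts x2 to float64
+# instead of letting the 0-D operand be downcast.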
+
+
+def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
+ if len(arrays_and_dtypes) == 0:
+ raise TypeError("At least one array or dtype must be provided")
+ if len(arrays_and_dtypes) == 1:
+ x = arrays_and_dtypes[0]
+ if isinstance(x, paddle.dtype):
+ return x
+ return x.dtype
+ if len(arrays_and_dtypes) > 2:
+ return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
+
+ x, y = arrays_and_dtypes
+ xdt = x.dtype if not isinstance(x, paddle.dtype) else x
+ ydt = y.dtype if not isinstance(y, paddle.dtype) else y
+
+ if (xdt, ydt) in _promotion_table:
+ return _promotion_table[xdt, ydt]
+
+    # This path doesn't handle result_type(dtype, dtype) for non-array-API
+    # dtypes because paddle.result_type only accepts tensors. It does,
+    # however, allow cross-kind promotion.
+ x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
+ y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
+ return paddle.result_type(x, y)
+
+
+def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+ can_cast_dict = {
+ paddle.bfloat16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float32: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.float64: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.complex64: {
+ paddle.bfloat16: False,
+ paddle.float16: False,
+ paddle.float32: False,
+ paddle.float64: False,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.complex128: {
+ paddle.bfloat16: False,
+ paddle.float16: False,
+ paddle.float32: False,
+ paddle.float64: False,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: False,
+ paddle.int8: False,
+ paddle.int16: False,
+ paddle.int32: False,
+ paddle.int64: False,
+ paddle.bool: False,
+ },
+ paddle.uint8: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int8: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int16: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int32: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.int64: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: False,
+ },
+ paddle.bool: {
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
+ paddle.complex64: True,
+ paddle.complex128: True,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
+ },
+ }
+ return can_cast_dict[from_][to]
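+
+# For example, per the table above:
+#   can_cast(paddle.int32, paddle.float64)  -> True
+#   can_cast(paddle.float32, paddle.int32)  -> False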
+
+
+# Basic renames
+bitwise_invert = paddle.bitwise_not
+newaxis = None
+# Note: in the torch wrapper, conj is aliased to conj_physical because
+# torch.conj sets a lazy conjugation bit, which breaks conversion to other
+# libraries (https://github.com/data-apis/array-api-compat/issues/173).
+# Here we simply alias paddle.conj.
+conj = paddle.conj
+
+# Two-arg elementwise functions
+# These require a wrapper to do the correct type promotion on 0-D tensors
+add = _two_arg(paddle.add)
+atan2 = _two_arg(paddle.atan2)
+bitwise_and = _two_arg(paddle.bitwise_and)
+bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
+bitwise_or = _two_arg(paddle.bitwise_or)
+bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
+bitwise_xor = _two_arg(paddle.bitwise_xor)
+copysign = _two_arg(paddle.copysign)
+divide = _two_arg(paddle.divide)
+# Also a rename. paddle.equal does not broadcast
+equal = _two_arg(paddle.equal)
+floor_divide = _two_arg(paddle.floor_divide)
+greater = _two_arg(paddle.greater_than)
+greater_equal = _two_arg(paddle.greater_equal)
+hypot = _two_arg(paddle.hypot)
+less = _two_arg(paddle.less_than)
+less_equal = _two_arg(paddle.less_equal)
+logaddexp = _two_arg(paddle.logaddexp)
+# logical functions are not included here because they only accept bool in the
+# spec, so type promotion is irrelevant.
+maximum = _two_arg(paddle.maximum)
+minimum = _two_arg(paddle.minimum)
+multiply = _two_arg(paddle.multiply)
+not_equal = _two_arg(paddle.not_equal)
+pow = _two_arg(paddle.pow)
+remainder = _two_arg(paddle.remainder)
+subtract = _two_arg(paddle.subtract)
+
+# These wrappers mostly exist to map the array API's `keepdims` keyword onto
+# paddle's `keepdim` and to handle axis=() edge cases.
+
+
+def max(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ return paddle.amax(x, axis, keepdim=keepdims)
+
+
+def min(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ return paddle.min(x, axis, keepdim=keepdims)
+
+
+clip = get_xp(paddle)(_aliases_clip)
+unstack = get_xp(paddle)(_aliases_unstack)
+cumulative_sum = get_xp(paddle)(_aliases_cumulative_sum)
+
+
+# paddle.sort also returns a tuple
+# https://github.com/pytorch/pytorch/issues/70921
+def sort(
+ x: array,
+ /,
+ *,
+ axis: int = -1,
+ descending: bool = False,
+ stable: bool = True,
+ **kwargs,
+) -> array:
+ return paddle.sort(
+ x, axis=axis, descending=descending, stable=stable, **kwargs
+ ).values
+
+
+def _normalize_axes(axis, ndim):
+ axes = []
+ if ndim == 0 and axis:
+ # Better error message in this case
+ raise IndexError(f"Dimension out of range: {axis[0]}")
+ lower, upper = -ndim, ndim - 1
+ for a in axis:
+ if a < lower or a > upper:
+ # Match paddle error message (e.g., from sum())
+ raise IndexError(
+ f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
+ )
+ if a < 0:
+ a = a + ndim
+ if a in axes:
+ # Use IndexError instead of RuntimeError, and "axis" instead of "dim"
+ raise IndexError(f"Axis {a} appears multiple times in the list of axes")
+ axes.append(a)
+ return sorted(axes)
+
+
+def _axis_none_keepdims(x, ndim, keepdims):
+ # Apply keepdims when axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ # Note that this is only valid for the axis=None case.
+ if keepdims:
+ for i in range(ndim):
+ x = paddle.unsqueeze(x, 0)
+ return x
+
+
+def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
+ # Some reductions don't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ axes = _normalize_axes(axis, x.ndim)
+ for a in reversed(axes):
+ x = paddle.movedim(x, a, -1)
+ x = paddle.flatten(x, -len(axes))
+
+ out = f(x, -1, **kwargs)
+
+ if keepdims:
+ for a in axes:
+ out = paddle.unsqueeze(out, a)
+ return out
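+
+# For example, prod(x, axis=(0, 2)) below moves axes 0 and 2 to the end,
+# flattens them into a single trailing axis, applies paddle.prod once over that
+# axis, and re-inserts size-1 dimensions if keepdims=True.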
+
+
+def prod(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+
+ # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
+ # below because it still needs to upcast.
+ if axis == ():
+ if dtype is None:
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
+ if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
+ return x.to(paddle.int64)
+ return x.clone()
+ return x.to(dtype)
+
+ # paddle.prod doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ return _reduce_multiple_axes(
+ paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
+ )
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.prod(x, dtype=dtype, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res
+
+ return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+
+def sum(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+
+ # https://github.com/pytorch/pytorch/issues/29137.
+ # Make sure it upcasts.
+ if axis == ():
+ if dtype is None:
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
+ if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
+ return x.to(paddle.int64)
+ return x.clone()
+ return x.to(dtype)
+
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.sum(x, dtype=dtype, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res
+
+ return paddle.sum(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+
+def any(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+ if axis == ():
+ return x.to(paddle.bool)
+ # paddle.any doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs)
+ return res.to(paddle.bool)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.any(x, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res.to(paddle.bool)
+
+ # paddle.any doesn't return bool for uint8
+ return paddle.any(x, axis, keepdim=keepdims).to(paddle.bool)
+
+
+def all(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ x = paddle.asarray(x)
+ ndim = x.ndim
+ if axis == ():
+ return x.to(paddle.bool)
+ # paddle.all doesn't support multiple axes
+ # (https://github.com/pytorch/pytorch/issues/56586).
+ if isinstance(axis, tuple):
+ res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs)
+ return res.to(paddle.bool)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.all(x, **kwargs)
+ res = _axis_none_keepdims(res, ndim, keepdims)
+ return res.to(paddle.bool)
+
+ # paddle.all doesn't return bool for uint8
+ return paddle.all(x, axis, keepdim=keepdims).to(paddle.bool)
+
+
+def mean(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.clone(x)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.mean(x, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.mean(x, axis, keepdim=keepdims, **kwargs)
+
+
+def std(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # Note, float correction is not supported
+ # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+ # implement it here for now.
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle std() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction only can be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int or float equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.zeros_like(x)
+ if isinstance(axis, int):
+ axis = (axis,)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.std(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+def var(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+ **kwargs,
+) -> array:
+ # Note, float correction is not supported
+ # https://github.com/pytorch/pytorch/issues/61492. We don't try to
+ # implement it here for now.
+
+    if isinstance(correction, float):
+        _correction = int(correction)
+        if correction != _correction:
+            raise NotImplementedError(
+                "float correction in paddle var() is not yet supported"
+            )
+    elif isinstance(correction, int):
+        if correction not in [0, 1]:
+            raise NotImplementedError("correction only can be 0 or 1")
+        _correction = correction
+    else:
+        raise NotImplementedError("correction must be an int or float equal to 0 or 1")
+
+    _correction = bool(_correction)
+
+ # https://github.com/pytorch/pytorch/issues/29137
+ if axis == ():
+ return paddle.zeros_like(x)
+ if isinstance(axis, int):
+ axis = (axis,)
+ if axis is None:
+ # paddle doesn't support keepdims with axis=None
+ # (https://github.com/pytorch/pytorch/issues/71209)
+ res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
+ return res
+ return paddle.var(x, axis, unbiased=_correction, keepdim=keepdims, **kwargs)
+
+
+# paddle.concat doesn't support dim=None
+# https://github.com/pytorch/pytorch/issues/70925
+def concat(
+ arrays: Union[Tuple[array, ...], List[array]],
+ /,
+ *,
+ axis: Optional[int] = 0,
+ **kwargs,
+) -> array:
+ if axis is None:
+ arrays = tuple(ar.flatten() for ar in arrays)
+ axis = 0
+ return paddle.concat(arrays, axis, **kwargs)
+
+
+# The array API requires squeeze() to error when a squeezed dimension is not
+# of size 1, so validate the axes and squeeze them one at a time. (Issue links
+# inherited from the torch wrapper: https://github.com/pytorch/pytorch/issues/70924,
+# https://github.com/pytorch/pytorch/pull/89017.)
+def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
+ if isinstance(axis, int):
+ axis = (axis,)
+ for a in axis:
+ if x.shape[a] != 1:
+ raise ValueError("squeezed dimensions must be equal to 1")
+ axes = _normalize_axes(axis, x.ndim)
+    # Squeeze one axis at a time, adjusting later axes as earlier ones are removed.
+ sequence = [a - i for i, a in enumerate(axes)]
+ for a in sequence:
+ x = paddle.squeeze(x, a)
+ return x
+
+
+# Thin wrapper kept for API consistency; paddle.broadcast_to takes the target shape directly.
+def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
+ return paddle.broadcast_to(x, shape, **kwargs)
+
+
+# The array API's permute_dims maps to paddle.transpose, which takes a full permutation (perm).
+def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
+ if len(axes) == 2:
+ perm = list(range(x.ndim))
+ perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
+ axes = perm
+ return paddle.transpose(x, axes)
+
+
+# paddle.flip() does not accept axis=None, so default to flipping all axes
+# ourselves. (Issue link inherited from the torch wrapper:
+# https://github.com/pytorch/pytorch/issues/71210.)
+def flip(
+ x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
+) -> array:
+ if axis is None:
+ axis = tuple(range(x.ndim))
+ # paddle.flip doesn't accept dim as an int but the method does
+ # https://github.com/pytorch/pytorch/issues/18095
+ return x.flip(axis, **kwargs)
+
+
+def roll(
+ x: array,
+ /,
+ shift: Union[int, Tuple[int, ...]],
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ **kwargs,
+) -> array:
+ return paddle.roll(x, shift, axis, **kwargs)
+
+
+def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
+ if x.ndim == 0:
+ raise ValueError("nonzero() does not support zero-dimensional arrays")
+ return paddle.nonzero(x, as_tuple=True, **kwargs)
+
+
+def where(condition: array, x1: array, x2: array, /) -> array:
+ x1, x2 = _fix_promotion(x1, x2)
+ return paddle.where(condition, x1, x2)
+
+
+# paddle.reshape doesn't have the copy keyword
+def reshape(
+ x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+) -> array:
+ if copy is not None:
+ raise NotImplementedError("paddle.reshape doesn't yet support the copy keyword")
+ return paddle.reshape(x, shape, **kwargs)
+
+
+# paddle.arange doesn't support returning empty arrays
+# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
+# keyword argument combinations
+# (https://github.com/pytorch/pytorch/issues/70914)
+def arange(
+ start: Union[int, float],
+ /,
+ stop: Optional[Union[int, float]] = None,
+ step: Union[int, float] = 1,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if stop is None:
+ start, stop = 0, start
+ if step > 0 and stop <= start or step < 0 and stop >= start:
+ if dtype is None:
+ if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
+ dtype = paddle.int64
+ else:
+ dtype = paddle.float32
+ return paddle.empty([0], dtype=dtype, **kwargs).to(device)
+ return paddle.arange(start, stop, step, dtype=dtype, **kwargs).to(device)
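+
+# For example, arange(5, 5) is an empty range, so the branch above returns an
+# int64 array of shape (0,) instead of erroring; non-empty ranges are delegated
+# to paddle.arange directly.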
+
+
+# paddle.eye does not accept None as a default for the second argument and
+# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
+def eye(
+ n_rows: int,
+ n_cols: Optional[int] = None,
+ /,
+ *,
+ k: int = 0,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if n_cols is None:
+ n_cols = n_rows
+ z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
+ if abs(k) <= n_rows + n_cols:
+ z.diagonal(k).fill_(1)
+ return z
+
+
+# paddle.linspace doesn't have the endpoint parameter
+def linspace(
+ start: Union[int, float],
+ stop: Union[int, float],
+ /,
+ num: int,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ endpoint: bool = True,
+ **kwargs,
+) -> array:
+ if not endpoint:
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
+ :-1
+ ]
+ return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
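+
+# For example, with endpoint=False, linspace(0, 1, 5) is computed as
+# linspace(0, 1, 6)[:-1], i.e. [0.0, 0.2, 0.4, 0.6, 0.8].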
+
+
+# paddle.full does not accept an int size
+# https://github.com/pytorch/pytorch/issues/70906
+def full(
+ shape: Union[int, Tuple[int, ...]],
+ fill_value: Union[bool, int, float, complex],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ if isinstance(shape, int):
+ shape = (shape,)
+
+ return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)
+
+
+# ones, zeros, and empty do not accept shape as a keyword argument
+def ones(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.ones(shape, dtype=dtype, **kwargs).to(device)
+
+
+def zeros(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.zeros(shape, dtype=dtype, **kwargs).to(device)
+
+
+def empty(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ **kwargs,
+) -> array:
+ return paddle.empty(shape, dtype=dtype, **kwargs).to(device)
+
+
+# tril and triu take k as a positional argument, not a keyword argument
+
+
+def tril(x: array, /, *, k: int = 0) -> array:
+ return paddle.tril(x, k)
+
+
+def triu(x: array, /, *, k: int = 0) -> array:
+ return paddle.triu(x, k)
+
+
+# expand_dims is not provided by paddle; implement it via unsqueeze
+# (issue link inherited from the torch wrapper: https://github.com/pytorch/pytorch/issues/58742)
+def expand_dims(x: array, /, *, axis: int = 0) -> array:
+ return paddle.unsqueeze(x, axis)
+
+
+def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
+ return x.to(dtype, copy=copy)
+
+
+def broadcast_arrays(*arrays: array) -> List[array]:
+ shape = paddle.broadcast_shapes(*[a.shape for a in arrays])
+ return [paddle.broadcast_to(a, shape) for a in arrays]
+
+
+# Note that these named tuples aren't actually part of the standard namespace,
+# but I don't see any issue with exporting the names here regardless.
+from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
+
+
+# https://github.com/pytorch/pytorch/issues/70920
+def unique_all(x: array) -> UniqueAllResult:
+ # paddle.unique doesn't support returning indices.
+ # https://github.com/pytorch/pytorch/issues/36748. The workaround
+ # suggested in that issue doesn't actually function correctly (it relies
+ # on non-deterministic behavior of scatter()).
+ raise NotImplementedError(
+ "unique_all() not yet implemented for paddle (see https://github.com/pytorch/pytorch/issues/36748)"
+ )
+
+ # values, inverse_indices, counts = paddle.unique(x, return_counts=True, return_inverse=True)
+ # # paddle.unique incorrectly gives a 0 count for nan values.
+ # # https://github.com/pytorch/pytorch/issues/94106
+ # counts[paddle.isnan(values)] = 1
+ # return UniqueAllResult(values, indices, inverse_indices, counts)
+
+
+def unique_counts(x: array) -> UniqueCountsResult:
+ values, counts = paddle.unique(x, return_counts=True)
+
+ # paddle.unique incorrectly gives a 0 count for nan values.
+ # https://github.com/pytorch/pytorch/issues/94106
+ counts[paddle.isnan(values)] = 1
+ return UniqueCountsResult(values, counts)
+
+
+def unique_inverse(x: array) -> UniqueInverseResult:
+ values, inverse = paddle.unique(x, return_inverse=True)
+ return UniqueInverseResult(values, inverse)
+
+
+def unique_values(x: array) -> array:
+ return paddle.unique(x)
+
+
+def matmul(x1: array, x2: array, /, **kwargs) -> array:
+ # paddle.matmul doesn't type promote (but differently from _fix_promotion)
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return paddle.matmul(x1, x2, **kwargs)
+
+
+matrix_transpose = get_xp(paddle)(_aliases_matrix_transpose)
+_vecdot = get_xp(paddle)(_aliases_vecdot)
+
+
+def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return _vecdot(x1, x2, axis=axis)
+
+
+# paddle.tensordot uses dims instead of axes
+def tensordot(
+ x1: array,
+ x2: array,
+ /,
+ *,
+ axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
+ **kwargs,
+) -> array:
+ # Note: paddle.tensordot fails with integer dtypes when there is only 1
+ # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ return paddle.tensordot(x1, x2, axes=axes, **kwargs)
+
+
+def isdtype(
+ dtype: Dtype,
+ kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
+ *,
+ _tuple=True, # Disallow nested tuples
+) -> bool:
+ """
+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
+
+ Note that outside of this function, this compat library does not yet fully
+ support complex numbers.
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
+ for more details
+ """
+
+ def is_signed(dtype):
+ return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
+
+ def is_floating_point(dtype):
+ return dtype in [
+ paddle.float32,
+ paddle.float64,
+ paddle.float16,
+ paddle.bfloat16,
+ paddle.float8_e4m3fn,
+ paddle.float8_e5m2,
+ ]
+
+ def is_complex(dtype):
+ return dtype in [paddle.complex64, paddle.complex128]
+
+ if isinstance(kind, tuple) and _tuple:
+ return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
+
+ elif isinstance(kind, str):
+ if kind == "bool":
+ return dtype == paddle.bool
+ elif kind == "signed integer":
+ return dtype in _int_dtypes and is_signed(dtype)
+ elif kind == "unsigned integer":
+ return dtype in _int_dtypes and not is_signed(dtype)
+ elif kind == "integral":
+ return dtype in _int_dtypes
+ elif kind == "real floating":
+ return is_floating_point(dtype)
+ elif kind == "complex floating":
+ return is_complex(dtype)
+ elif kind == "numeric":
+ return isdtype(dtype, ("integral", "real floating", "complex floating"))
+ else:
+ raise ValueError(f"Unrecognized data type kind: {kind!r}")
+ else:
+ return dtype == kind
+
+
+def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
+ if axis is None:
+ if x.ndim != 1:
+ raise ValueError("axis must be specified when ndim > 1")
+ axis = 0
+    return paddle.index_select(x, indices, axis=axis, **kwargs)
+
+
+def sign(x: array, /) -> array:
+ # paddle sign() does not support complex numbers and does not propagate
+ # nans. See https://github.com/data-apis/array-api-compat/issues/136
+ if x.dtype.is_complex:
+ out = x / paddle.abs(x)
+ # sign(0) = 0 but the above formula would give nan
+ out[x == 0 + 0j] = 0 + 0j
+ return out
+ else:
+ out = paddle.sign(x)
+ if x.dtype.is_floating_point:
+ out[paddle.isnan(x)] = paddle.nan
+ return out
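+
+# For example, for a complex input the formula above gives
+# sign(3+4j) == (3+4j)/5 == 0.6+0.8j, while sign(0j) is forced to 0j.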
+
+
+__all__ = [
+ "__array_namespace_info__",
+ "result_type",
+ "can_cast",
+ "permute_dims",
+ "bitwise_invert",
+ "newaxis",
+ "conj",
+ "add",
+ "atan2",
+ "bitwise_and",
+ "bitwise_left_shift",
+ "bitwise_or",
+ "bitwise_right_shift",
+ "bitwise_xor",
+ "copysign",
+ "divide",
+ "equal",
+ "floor_divide",
+ "greater",
+ "greater_equal",
+ "hypot",
+ "less",
+ "less_equal",
+ "logaddexp",
+ "maximum",
+ "minimum",
+ "multiply",
+ "not_equal",
+ "pow",
+ "remainder",
+ "subtract",
+ "max",
+ "min",
+ "clip",
+ "unstack",
+ "cumulative_sum",
+ "sort",
+ "prod",
+ "sum",
+ "any",
+ "all",
+ "mean",
+ "std",
+ "var",
+ "concat",
+ "squeeze",
+ "broadcast_to",
+ "flip",
+ "roll",
+ "nonzero",
+ "where",
+ "reshape",
+ "arange",
+ "eye",
+ "linspace",
+ "full",
+ "ones",
+ "zeros",
+ "empty",
+ "tril",
+ "triu",
+ "expand_dims",
+ "astype",
+ "broadcast_arrays",
+ "UniqueAllResult",
+ "UniqueCountsResult",
+ "UniqueInverseResult",
+ "unique_all",
+ "unique_counts",
+ "unique_inverse",
+ "unique_values",
+ "matmul",
+ "matrix_transpose",
+ "vecdot",
+ "tensordot",
+ "isdtype",
+ "take",
+ "sign",
+]
+
+_all_ignore = ["paddle", "get_xp"]
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
new file mode 100644
index 00000000..1fe48356
--- /dev/null
+++ b/array_api_compat/paddle/_info.py
@@ -0,0 +1,373 @@
+"""
+Array API Inspection namespace
+
+This is the namespace for inspection functions as defined by the array API
+standard. See
+https://data-apis.org/array-api/latest/API_specification/inspection.html for
+more details.
+
+"""
+
+import paddle
+
+from functools import cache
+
+
+class __array_namespace_info__:
+ """
+    Get the array API inspection namespace for Paddle.
+
+ The array API inspection namespace defines the following functions:
+
+ - capabilities()
+ - default_device()
+ - default_dtypes()
+ - dtypes()
+ - devices()
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/inspection.html
+ for more details.
+
+ Returns
+ -------
+ info : ModuleType
+        The array API inspection namespace for Paddle.
+
+    Examples
+    --------
+    >>> info = xp.__array_namespace_info__()
+    >>> info.default_dtypes()
+    {'real floating': paddle.float32,
+     'complex floating': paddle.complex64,
+     'integral': paddle.int64,
+     'indexing': paddle.int64}
+
+ """
+
+ __module__ = "paddle"
+
+ def capabilities(self):
+ """
+ Return a dictionary of array API library capabilities.
+
+ The resulting dictionary has the following keys:
+
+        - **"boolean indexing"**: boolean indicating whether an array library
+          supports boolean indexing. Always ``True`` for Paddle.
+
+        - **"data-dependent shapes"**: boolean indicating whether an array
+          library supports data-dependent output shapes. Always ``True`` for
+          Paddle.
+
+ See
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
+ for more details.
+
+ See Also
+ --------
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Returns
+ -------
+ capabilities : dict
+ A dictionary of array API library capabilities.
+
+ Examples
+ --------
+ >>> info = np.__array_namespace_info__()
+ >>> info.capabilities()
+ {'boolean indexing': True,
+ 'data-dependent shapes': True}
+
+ """
+ return {
+ "boolean indexing": True,
+ "data-dependent shapes": True,
+ # 'max rank' will be part of the 2024.12 standard
+ # "max rank": 64,
+ }
+
+ def default_device(self):
+ """
+        The default device used for new Paddle arrays.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Returns
+ -------
+ device : str
+            The default device used for new Paddle arrays.
+
+ Examples
+ --------
+ >>> info = np.__array_namespace_info__()
+ >>> info.default_device()
+ 'cpu'
+
+ """
+ return paddle.device.get_device()
+
+ def default_dtypes(self, *, device=None):
+ """
+        The default data types used for new Paddle arrays.
+
+        Parameters
+        ----------
+        device : str, optional
+            The device to get the default data types for. Paddle uses the same
+            defaults on all devices, so this argument is ignored.
+
+ Returns
+ -------
+ dtypes : dict
+            A dictionary describing the default data types used for new Paddle
+            arrays.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.dtypes,
+ __array_namespace_info__.devices
+
+ Examples
+ --------
+ >>> info = np.__array_namespace_info__()
+ >>> info.default_dtypes()
+ {'real floating': paddle.float32,
+ 'complex floating': paddle.complex64,
+ 'integral': paddle.int64,
+ 'indexing': paddle.int64}
+
+ """
+        # paddle.get_default_dtype() returns a string such as "float32"; convert
+        # it to the corresponding paddle dtype object so the returned mapping
+        # contains dtypes, matching the docstring above. The defaults do not
+        # differ per device, so `device` is ignored.
+        default_floating = getattr(paddle, paddle.get_default_dtype())
+        default_complex = (
+            paddle.complex64 if default_floating == paddle.float32 else paddle.complex128
+        )
+        default_integral = paddle.int64
+ return {
+ "real floating": default_floating,
+ "complex floating": default_complex,
+ "integral": default_integral,
+ "indexing": default_integral,
+ }
+
+ def _dtypes(self, kind):
+ bool = paddle.bool
+ int8 = paddle.int8
+ int16 = paddle.int16
+ int32 = paddle.int32
+ int64 = paddle.int64
+ uint8 = paddle.uint8
+        # uint16, uint32, and uint64 are not supported by paddle, so we omit
+        # them from this function.
+ float32 = paddle.float32
+ float64 = paddle.float64
+ complex64 = paddle.complex64
+ complex128 = paddle.complex128
+
+ if kind is None:
+ return {
+ "bool": bool,
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ "float32": float32,
+ "float64": float64,
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if kind == "bool":
+ return {"bool": bool}
+ if kind == "signed integer":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ }
+ if kind == "unsigned integer":
+ return {
+ "uint8": uint8,
+ }
+ if kind == "integral":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ }
+ if kind == "real floating":
+ return {
+ "float32": float32,
+ "float64": float64,
+ }
+ if kind == "complex floating":
+ return {
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if kind == "numeric":
+ return {
+ "int8": int8,
+ "int16": int16,
+ "int32": int32,
+ "int64": int64,
+ "uint8": uint8,
+ "float32": float32,
+ "float64": float64,
+ "complex64": complex64,
+ "complex128": complex128,
+ }
+ if isinstance(kind, tuple):
+ res = {}
+ for k in kind:
+ res.update(self.dtypes(kind=k))
+ return res
+ raise ValueError(f"unsupported kind: {kind!r}")
+
+ @cache
+ def dtypes(self, *, device=None, kind=None):
+ """
+        The array API data types supported by Paddle.
+
+ Note that this function only returns data types that are defined by
+ the array API.
+
+ Parameters
+ ----------
+ device : str, optional
+ The device to get the data types for.
+ kind : str or tuple of str, optional
+ The kind of data types to return. If ``None``, all data types are
+ returned. If a string, only data types of that kind are returned.
+ If a tuple, a dictionary containing the union of the given kinds
+ is returned. The following kinds are supported:
+
+ - ``'bool'``: boolean data types (i.e., ``bool``).
+ - ``'signed integer'``: signed integer data types (i.e., ``int8``,
+ ``int16``, ``int32``, ``int64``).
+ - ``'unsigned integer'``: unsigned integer data types (i.e.,
+ ``uint8``, ``uint16``, ``uint32``, ``uint64``).
+ - ``'integral'``: integer data types. Shorthand for ``('signed
+ integer', 'unsigned integer')``.
+ - ``'real floating'``: real-valued floating-point data types
+ (i.e., ``float32``, ``float64``).
+ - ``'complex floating'``: complex floating-point data types (i.e.,
+ ``complex64``, ``complex128``).
+ - ``'numeric'``: numeric data types. Shorthand for ``('integral',
+ 'real floating', 'complex floating')``.
+
+ Returns
+ -------
+ dtypes : dict
+ A dictionary mapping the names of data types to the corresponding
+            Paddle data types.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.devices
+
+ Examples
+ --------
+        >>> info = xp.__array_namespace_info__()
+        >>> info.dtypes(kind='signed integer')
+        {'int8': paddle.int8,
+         'int16': paddle.int16,
+         'int32': paddle.int32,
+         'int64': paddle.int64}
+
+ """
+ res = self._dtypes(kind)
+ for k, v in res.copy().items():
+ try:
+                # paddle.empty does not take a device argument; constructing the
+                # tensor is enough to check that the dtype is supported.
+                paddle.empty((0,), dtype=v)
+            except Exception:
+                del res[k]
+ return res
+
+ @cache
+ def devices(self):
+ """
+        The devices supported by Paddle.
+
+        Returns
+        -------
+        devices : list of str
+            The devices supported by Paddle.
+
+ See Also
+ --------
+ __array_namespace_info__.capabilities,
+ __array_namespace_info__.default_device,
+ __array_namespace_info__.default_dtypes,
+ __array_namespace_info__.dtypes
+
+ Examples
+ --------
+        >>> info = xp.__array_namespace_info__()
+        >>> info.devices()
+        [Place(cpu), Place(gpu:0)]
+
+ """
+        # Paddle doesn't have a straightforward way to get the list of all
+        # currently supported devices. To do this, we first parse the error
+        # message of paddle.device.set_device() to get the list of all possible
+        # types of device:
+        try:
+            paddle.device.set_device("notadevice")
+        except (RuntimeError, ValueError) as e:
+            # The error message is something like:
+            # "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x'"
+            devices_names = (
+                e.args[0]
+                .split("The device must be a string which is like ")[1]
+                .split(", ")
+            )
+ devices_names = [
+ name.strip("'") for name in devices_names if ":" not in name
+ ]
+
+        # Next we need to check for different indices for different devices.
+        # Constructing a Place with a given index doesn't by itself verify that
+        # the device exists, so we have to try to create a tensor on it (which
+        # is why this function is cached).
+ devices = []
+ for device_name in devices_names:
+ i = 0
+ while True:
+                try:
+                    if device_name == "cpu":
+                        a = paddle.to_tensor([], place=paddle.CPUPlace())
+                    elif device_name == "gpu":
+                        a = paddle.to_tensor([], place=paddle.CUDAPlace(i))
+                    elif device_name == "xpu":
+                        a = paddle.to_tensor([], place=paddle.XPUPlace(i))
+                    else:
+                        break
+                    if a.place in devices:
+                        break
+                    devices.append(a.place)
+                except Exception:
+                    break
+                i += 1
+
+ return devices
diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py
new file mode 100644
index 00000000..15519b5a
--- /dev/null
+++ b/array_api_compat/paddle/fft.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import paddle
+
+ array = paddle.Tensor
+ from typing import Union, Sequence, Literal
+
+from paddle.fft import * # noqa: F403
+import paddle.fft
+
+
+def fftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def ifftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def rfftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def irfftn(
+ x: array,
+ /,
+ *,
+ s: Sequence[int] = None,
+ axes: Sequence[int] = None,
+ norm: Literal["backward", "ortho", "forward"] = "backward",
+ **kwargs,
+) -> array:
+ return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+
+
+def fftshift(
+ x: array,
+ /,
+ *,
+ axes: Union[int, Sequence[int]] = None,
+ **kwargs,
+) -> array:
+ return paddle.fft.fftshift(x, axes=axes, **kwargs)
+
+
+def ifftshift(
+ x: array,
+ /,
+ *,
+ axes: Union[int, Sequence[int]] = None,
+ **kwargs,
+) -> array:
+ return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+
+
+__all__ = paddle.fft.__all__ + [
+ "fftn",
+ "ifftn",
+ "rfftn",
+ "irfftn",
+ "fftshift",
+ "ifftshift",
+]
+
+_all_ignore = ["paddle"]
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
new file mode 100644
index 00000000..6ee57fcf
--- /dev/null
+++ b/array_api_compat/paddle/linalg.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import paddle
+
+ array = paddle.Tensor
+ from paddle import dtype as Dtype
+ from typing import Optional, Union, Tuple, Literal
+
+ inf = float("inf")
+
+from ._aliases import _fix_promotion, sum
+
+from paddle.linalg import * # noqa: F403
+
+# paddle.linalg doesn't define __all__
+# from paddle.linalg import __all__ as linalg_all
+from paddle import linalg as paddle_linalg
+
+linalg_all = [i for i in dir(paddle_linalg) if not i.startswith("_")]
+
+# outer is implemented in paddle but isn't in the linalg namespace
+from paddle import outer
+
+# These functions are in both the main and linalg namespaces
+from ._aliases import matmul, matrix_transpose, tensordot
+
+# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+
+
+# paddle.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
+ raise ValueError(
+ f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}"
+ )
+
+ if not (x1.shape[axis] == x2.shape[axis] == 3):
+ raise ValueError(
+ f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}"
+ )
+
+ x1, x2 = paddle.broadcast_tensors(x1, x2)
+ return paddle_linalg.cross(x1, x2, axis=axis)
+
+
+def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+ from ._aliases import isdtype
+
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+ # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ if x1.shape[axis] != x2.shape[axis]:
+ raise ValueError("x1 and x2 must have the same size along the given axis")
+
+ # paddle.linalg.vecdot doesn't support integer dtypes
+ if isdtype(x1.dtype, "integral") or isdtype(x2.dtype, "integral"):
+ if kwargs:
+ raise RuntimeError("vecdot kwargs not supported for integral dtypes")
+
+ x1_ = paddle.moveaxis(x1, axis, -1)
+ x2_ = paddle.moveaxis(x2, axis, -1)
+ x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+
+ res = x1_[..., None, :] @ x2_[..., None]
+ return res[..., 0, 0]
+ return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+
+
+def solve(x1: array, x2: array, /, **kwargs) -> array:
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+ if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
+ x2 = x2[None]
+ return paddle.linalg.solve(x1, x2, **kwargs)
+
+
+# paddle.trace doesn't support the offset argument and doesn't support stacking
+def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+ # Use our wrapped sum to make sure it does upcasting correctly
+ return sum(
+ paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype
+ )
+
+
+def vector_norm(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+ ord: Union[int, float, Literal[inf, -inf]] = 2,
+ **kwargs,
+) -> array:
+ # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+ if axis == ():
+ out = kwargs.get("out")
+ if out is None:
+ dtype = None
+ if x.dtype == paddle.complex64:
+ dtype = paddle.float32
+ elif x.dtype == paddle.complex128:
+ dtype = paddle.float64
+
+ out = paddle.zeros_like(x, dtype=dtype)
+
+ # The norm of a single scalar works out to abs(x) in every case except
+ # for ord=0, which is x != 0.
+ if ord == 0:
+ out[:] = x != 0
+ else:
+ out[:] = paddle.abs(x)
+ return out
+ return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+
+
+__all__ = linalg_all + [
+ "outer",
+ "matmul",
+ "matrix_transpose",
+ "tensordot",
+ "cross",
+ "vecdot",
+ "solve",
+ "trace",
+ "vector_norm",
+]
+
+_all_ignore = ["paddle_linalg", "sum"]
+
+del linalg_all
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 3c9117ee..59c306af 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import torch
- array = torch.Tensor
+ import paddle
+ array = paddle.Tensor
from typing import Union, Sequence, Literal
-from torch.fft import * # noqa: F403
-import torch.fft
+from paddle.fft import * # noqa: F403
+import paddle.fft
-# Several torch fft functions do not map axes to dim
+# Several paddle fft functions do not map axes to dim
def fftn(
x: array,
@@ -20,7 +20,7 @@ def fftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
def ifftn(
x: array,
@@ -31,7 +31,7 @@ def ifftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
def rfftn(
x: array,
@@ -42,7 +42,7 @@ def rfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
def irfftn(
x: array,
@@ -53,7 +53,7 @@ def irfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
+ return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
def fftshift(
x: array,
@@ -62,7 +62,7 @@ def fftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return torch.fft.fftshift(x, dim=axes, **kwargs)
+ return paddle.fft.fftshift(x, axes=axes, **kwargs)
def ifftshift(
x: array,
@@ -71,10 +71,10 @@ def ifftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return torch.fft.ifftshift(x, dim=axes, **kwargs)
+ return paddle.fft.ifftshift(x, axes=axes, **kwargs)
-__all__ = torch.fft.__all__ + [
+__all__ = paddle.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
"ifftshift",
]
-_all_ignore = ['torch']
+_all_ignore = ['paddle']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..5e4ee47b 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -2,86 +2,84 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import torch
- array = torch.Tensor
- from torch import dtype as Dtype
+ import paddle
+ array = paddle.Tensor
+ from paddle import dtype as Dtype
from typing import Optional, Union, Tuple, Literal
inf = float('inf')
from ._aliases import _fix_promotion, sum
-from torch.linalg import * # noqa: F403
+from paddle.linalg import * # noqa: F403
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+# paddle.linalg doesn't define __all__
+# from paddle.linalg import __all__ as linalg_all
+from paddle import linalg as paddle_linalg
+linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
-# outer is implemented in torch but aren't in the linalg namespace
-from torch import outer
+# outer is implemented in paddle but aren't in the linalg namespace
+from paddle import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
-# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
-# torch.cross also does not support broadcasting when it would add new
-# dimensions https://github.com/pytorch/pytorch/issues/39656
+# paddle.cross also does not support broadcasting when it would add new
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = torch.broadcast_tensors(x1, x2)
- return torch_linalg.cross(x1, x2, dim=axis)
+ x1, x2 = paddle.broadcast_tensors(x1, x2)
+ return paddle_linalg.cross(x1, x2, axis=axis)
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
- # torch.linalg.vecdot doesn't support integer dtypes
+ # paddle.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
- x1_ = torch.moveaxis(x1, axis, -1)
- x2_ = torch.moveaxis(x2, axis, -1)
- x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
+ x1_ = paddle.moveaxis(x1, axis, -1)
+ x2_ = paddle.moveaxis(x2, axis, -1)
+ x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
- return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
+ return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+ # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
# 1. x1.ndim - 1 == x2.ndim
# 2. x1.shape[:-1] == x2.shape
#
# See linalg_solve_is_vector_rhs in
# aten/src/ATen/native/LinearAlgebraUtils.h and
- # TORCH_META_FUNC(_linalg_solve_ex) in
- # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
+ # paddle_META_FUNC(_linalg_solve_ex) in
+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
#
# The easiest way to work around this is to prepend a size 1 dimension to
# x2, since x2 is already one dimension less than x1.
#
- # See https://github.com/pytorch/pytorch/issues/52915
+ # See https://github.com/pypaddle/pypaddle/issues/52915
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
- return torch.linalg.solve(x1, x2, **kwargs)
+ return paddle.linalg.solve(x1, x2, **kwargs)
-# torch.trace doesn't support the offset argument and doesn't support stacking
+# paddle.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+ return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: array,
@@ -92,30 +90,30 @@ def vector_norm(
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
- # torch.vector_norm incorrectly treats axis=() the same as axis=None
+ # paddle.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
if out is None:
dtype = None
- if x.dtype == torch.complex64:
- dtype = torch.float32
- elif x.dtype == torch.complex128:
- dtype = torch.float64
+ if x.dtype == paddle.complex64:
+ dtype = paddle.float32
+ elif x.dtype == paddle.complex128:
+ dtype = paddle.float64
- out = torch.zeros_like(x, dtype=dtype)
+ out = paddle.zeros_like(x, dtype=dtype)
# The norm of a single scalar works out to abs(x) in every case except
- # for ord=0, which is x != 0.
+ # for p=0, which is x != 0.
if ord == 0:
out[:] = (x != 0)
else:
- out[:] = torch.abs(x)
+ out[:] = paddle.abs(x)
return out
- return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
+ return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-_all_ignore = ['torch_linalg', 'sum']
+_all_ignore = ['paddle_linalg', 'sum']
del linalg_all
diff --git a/docs/index.md b/docs/index.md
index ef18265e..874c3866 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -60,6 +60,10 @@ import array_api_compat.torch as torch
import array_api_compat.dask as da
```
+```py
+import array_api_compat.paddle as paddle
+```
+
```{note}
There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These
support for these libraries is contained in the libraries themselves (JAX
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index a016a636..fa30ccd2 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -137,3 +137,26 @@ The minimum supported Dask version is 2023.12.0.
## [Sparse](https://sparse.pydata.org/en/stable/)
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
+
+## [Paddle](https://www.paddlepaddle.org.cn/)
+
+- Like NumPy/CuPy, we do not wrap the `paddle.Tensor` object. It is missing the
+ `__array_namespace__` and `to_device` methods, so the corresponding helper
+ functions {func}`~.array_namespace()` and {func}`~.to_device()` in this
+ library should be used instead.
+
+- Paddle does not have unsigned integer types other than `uint8`, and no
+ attempt is made to implement them here.
+
+- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std)
+ and
+ [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var)
+ do not support floating-point `correction` except for `0.0` and `1.0`.
+
+- The `stream` argument of the {func}`~.to_device()` helper is not supported.
+
+- As with NumPy, type annotations and positional-only arguments may not
+ exactly match the spec for functions that are not wrapped at all.
+
+The minimum supported PyTorch version is 1.13.
+
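A minimal usage sketch of the helper-based workflow described above (assuming `paddlepaddle` and this compat layer are installed; the device string is illustrative):

```py
import paddle
from array_api_compat import array_namespace, device, to_device

x = paddle.to_tensor([1.0, 2.0, 3.0])

xp = array_namespace(x)       # resolves to the wrapped array_api_compat.paddle namespace
y = xp.sum(x)

print(device(x))              # e.g. "cpu"
x_cpu = to_device(x, "cpu")   # note: the stream= argument is not supported for Paddle
```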
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c9d10f71..ae41a25e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,5 +4,6 @@ jax[cpu]
numpy
pytest
torch
+paddlepaddle
sparse >=0.15.1
ndonnx
diff --git a/tests/_helpers.py b/tests/_helpers.py
index e2a7e1d1..0321bcb4 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,12 +3,12 @@
import pytest
-wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
-all_libraries = wrapped_libraries + ["jax.numpy"]
+wrapped_libraries = ["numpy", "paddle"]
+all_libraries = wrapped_libraries + []
# `sparse` added array API support as of Python 3.10.
-if sys.version_info >= (3, 10):
- all_libraries.append('sparse')
+# if sys.version_info >= (3, 10):
+# all_libraries.append('sparse')
def import_(library, wrapper=False):
if library == 'cupy':
@@ -25,4 +25,9 @@ def import_(library, wrapper=False):
else:
library = 'array_api_compat.' + library
+ if library == 'paddle':
+ xp = import_module(library)
+ xp.asarray = xp.to_tensor
+ return xp
+
return import_module(library)
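The aliasing above exists because bare Paddle has no `asarray`; a short sketch of the equivalence (assuming the wrapped namespace keeps exposing `asarray`, as the `paddle/__init__.py` change earlier in this patch suggests):

```py
import paddle
import array_api_compat.paddle as xp

a = paddle.to_tensor([1, 2, 3])   # native spelling of the constructor
b = xp.asarray([1, 2, 3])         # array API spelling via the compat wrapper

print(paddle.equal_all(a, b))     # True
```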
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 9c26371c..4b494ec3 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -2,10 +2,11 @@
import sys
import warnings
-import jax
+# import jax
import numpy as np
import pytest
-import torch
+# import torch
+import paddle
import array_api_compat
from array_api_compat import array_namespace
@@ -72,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)
-def test_jax_zero_gradient():
- jx = jax.numpy.arange(4)
- jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
- assert (array_api_compat.get_namespace(jax_zero) is
- array_api_compat.get_namespace(jx))
+# def test_jax_zero_gradient():
+# jx = jax.numpy.arange(4)
+# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
+# assert (array_api_compat.get_namespace(jax_zero) is
+# array_api_compat.get_namespace(jx))
def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -86,26 +87,53 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
-def test_array_namespace_errors_torch():
- y = torch.asarray([1, 2])
+# def test_array_namespace_errors_torch():
+# y = torch.asarray([1, 2])
+# x = np.asarray([1, 2])
+# pytest.raises(TypeError, lambda: array_namespace(x, y))
+
+
+def test_array_namespace_errors_paddle():
+ y = paddle.to_tensor([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))
+
+# def test_api_version():
+# x = torch.asarray([1, 2])
+# torch_ = import_("torch", wrapper=True)
+# assert array_namespace(x, api_version="2023.12") == torch_
+# assert array_namespace(x, api_version=None) == torch_
+# assert array_namespace(x) == torch_
+# # Should issue a warning
+# with warnings.catch_warnings(record=True) as w:
+# assert array_namespace(x, api_version="2021.12") == torch_
+# assert len(w) == 1
+# assert "2021.12" in str(w[0].message)
+
+# # Should issue a warning
+# with warnings.catch_warnings(record=True) as w:
+# assert array_namespace(x, api_version="2022.12") == torch_
+# assert len(w) == 1
+# assert "2022.12" in str(w[0].message)
+
+# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
+
def test_api_version():
- x = torch.asarray([1, 2])
- torch_ = import_("torch", wrapper=True)
- assert array_namespace(x, api_version="2023.12") == torch_
- assert array_namespace(x, api_version=None) == torch_
- assert array_namespace(x) == torch_
+ x = paddle.asarray([1, 2])
+ paddle_ = import_("paddle", wrapper=True)
+ assert array_namespace(x, api_version="2023.12") == paddle_
+ assert array_namespace(x, api_version=None) == paddle_
+ assert array_namespace(x) == paddle_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2021.12") == torch_
+ assert array_namespace(x, api_version="2021.12") == paddle_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2022.12") == torch_
+ assert array_namespace(x, api_version="2022.12") == paddle_
assert len(w) == 1
assert "2022.12" in str(w[0].message)
@@ -130,3 +158,19 @@ def test_python_scalars():
assert array_namespace(a, 1j) == xp
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp
+
+def test_python_scalars():
+ a = paddle.to_tensor([1, 2])
+ xp = import_("paddle", wrapper=True)
+
+ pytest.raises(TypeError, lambda: array_namespace(1))
+ pytest.raises(TypeError, lambda: array_namespace(1.0))
+ pytest.raises(TypeError, lambda: array_namespace(1j))
+ pytest.raises(TypeError, lambda: array_namespace(True))
+ pytest.raises(TypeError, lambda: array_namespace(None))
+
+ assert array_namespace(a, 1) == xp
+ assert array_namespace(a, 1.0) == xp
+ assert array_namespace(a, 1j) == xp
+ assert array_namespace(a, True) == xp
+ assert array_namespace(a, None) == xp
diff --git a/tests/test_common.py b/tests/test_common.py
index e1cfa9eb..5c0b5826 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,8 +1,8 @@
from array_api_compat import ( # noqa: F401
- is_numpy_array, is_cupy_array, is_torch_array,
+ is_numpy_array, is_cupy_array, is_torch_array, is_paddle_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
- is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
+ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_paddle_namespace,
)
from array_api_compat import is_array_api_obj, device, to_device
@@ -16,20 +16,22 @@
is_array_functions = {
'numpy': 'is_numpy_array',
- 'cupy': 'is_cupy_array',
- 'torch': 'is_torch_array',
- 'dask.array': 'is_dask_array',
- 'jax.numpy': 'is_jax_array',
- 'sparse': 'is_pydata_sparse_array',
+ # 'cupy': 'is_cupy_array',
+ # 'torch': 'is_torch_array',
+ # 'dask.array': 'is_dask_array',
+ # 'jax.numpy': 'is_jax_array',
+ # 'sparse': 'is_pydata_sparse_array',
+ 'paddle': 'is_paddle_array',
}
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
- 'cupy': 'is_cupy_namespace',
- 'torch': 'is_torch_namespace',
- 'dask.array': 'is_dask_namespace',
- 'jax.numpy': 'is_jax_namespace',
- 'sparse': 'is_pydata_sparse_namespace',
+ # 'cupy': 'is_cupy_namespace',
+ # 'torch': 'is_torch_namespace',
+ # 'dask.array': 'is_dask_namespace',
+ # 'jax.numpy': 'is_jax_namespace',
+ # 'sparse': 'is_pydata_sparse_namespace',
+ 'paddle': 'is_paddle_namespace',
}
@@ -114,6 +116,8 @@ def test_asarray_cross_library(source_library, target_library, request):
@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
+ if library == 'paddle':
+ pytest.skip("Paddle does not support explicit copies")
# Note, we have this test here because the test suite currently doesn't
# test the copy flag to asarray() very rigorously. Once
# https://github.com/data-apis/array-api-tests/issues/241 is fixed we
diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py
index 6ad45d4c..e7b7d9c1 100644
--- a/tests/test_isdtype.py
+++ b/tests/test_isdtype.py
@@ -10,7 +10,7 @@
# Check the known dtypes by their string names
def _spec_dtypes(library):
- if library == 'torch':
+ if library in ['torch', 'paddle']:
# torch does not have unsigned integer dtypes
return {
'bool',
diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py
index a1fdf731..201f98ea 100644
--- a/tests/test_no_dependencies.py
+++ b/tests/test_no_dependencies.py
@@ -49,8 +49,12 @@ def _test_dependency(mod):
# TODO: Test that wrapper for library X doesn't depend on wrappers for library
# Y (except most array libraries actually do themselves depend on numpy).
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
- "jax.numpy", "sparse", "array_api_strict"])
+@pytest.mark.parametrize("library",
+ [
+ "numpy",
+ "paddle", "array_api_strict",
+ ]
+)
def test_numpy_dependency(library):
# This import is here because it imports numpy
from ._helpers import import_
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 70083b49..91ca1709 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -7,20 +7,26 @@ def test_vendoring_numpy():
uses_numpy._test_numpy()
-def test_vendoring_cupy():
- pytest.importorskip("cupy")
+# def test_vendoring_cupy():
+# pytest.importorskip("cupy")
- from vendor_test import uses_cupy
+# from vendor_test import uses_cupy
- uses_cupy._test_cupy()
+# uses_cupy._test_cupy()
-def test_vendoring_torch():
- from vendor_test import uses_torch
+# def test_vendoring_torch():
+# from vendor_test import uses_torch
- uses_torch._test_torch()
+# uses_torch._test_torch()
-def test_vendoring_dask():
- from vendor_test import uses_dask
- uses_dask._test_dask()
+# def test_vendoring_dask():
+# from vendor_test import uses_dask
+# uses_dask._test_dask()
+
+
+def test_vendoring_paddle():
+ from vendor_test import uses_paddle
+
+ uses_paddle._test_paddle()
diff --git a/vendor_test/uses_paddle.py b/vendor_test/uses_paddle.py
new file mode 100644
index 00000000..e92257a4
--- /dev/null
+++ b/vendor_test/uses_paddle.py
@@ -0,0 +1,30 @@
+# Basic test that vendoring works
+
+from .vendored._compat import (
+ is_paddle_array,
+ is_paddle_namespace,
+ paddle as paddle_compat,
+)
+
+import paddle
+
+def _test_paddle():
+ a = paddle_compat.to_tensor([1., 2., 3.])
+ b = paddle_compat.arange(3, dtype=paddle_compat.float64)
+ assert a.dtype == paddle_compat.float32 == paddle.float32
+ assert b.dtype == paddle_compat.float64 == paddle.float64
+
+ # paddle.expand_dims does not exist. Update this to use something else if it is added
+ res = paddle_compat.expand_dims(a, axis=0)
+ assert res.dtype == paddle_compat.float32 == paddle.float32
+ assert res.shape == [1, 3]
+ assert isinstance(res.shape, list)
+ assert isinstance(a, paddle.Tensor)
+ assert isinstance(b, paddle.Tensor)
+ assert isinstance(res, paddle.Tensor)
+
+ assert paddle.allclose(res, paddle.to_tensor([[1., 2., 3.]]))
+
+ assert is_paddle_array(res)
+ assert is_paddle_namespace(paddle) and is_paddle_namespace(paddle_compat)
+
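One Paddle quirk the vendoring test above leans on, sketched briefly (assuming only `paddlepaddle` is installed): `Tensor.shape` is a plain Python `list` rather than a tuple, which is why the assertions compare against `[1, 3]`; `unsqueeze` is Paddle's native counterpart of the array API's `expand_dims`.

```py
import paddle

t = paddle.zeros([2, 3])
print(t.shape, type(t.shape))              # [2, 3] <class 'list'>

# unsqueeze is Paddle's native equivalent of expand_dims
print(paddle.unsqueeze(t, axis=0).shape)   # [1, 2, 3]
```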
From 7118894fae4c1d101a8262a61248c4208fe83d56 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 12:45:05 +0800
Subject: [PATCH 02/28] update README
---
README.md | 4 +-
array_api_compat/torch/fft.py | 26 +++++------
array_api_compat/torch/linalg.py | 76 ++++++++++++++++----------------
3 files changed, 54 insertions(+), 52 deletions(-)
diff --git a/README.md b/README.md
index 4b0b0c9c..5c30919d 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
-NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
+NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).
-See the documentation for more details https://data-apis.org/array-api-compat/
+See the documentation for more details: https://data-apis.org/array-api-compat/
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 59c306af..3c9117ee 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import paddle
- array = paddle.Tensor
+ import torch
+ array = torch.Tensor
from typing import Union, Sequence, Literal
-from paddle.fft import * # noqa: F403
-import paddle.fft
+from torch.fft import * # noqa: F403
+import torch.fft
-# Several paddle fft functions do not map axes to dim
+# Several torch fft functions do not map axes to dim
def fftn(
x: array,
@@ -20,7 +20,7 @@ def fftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
def ifftn(
x: array,
@@ -31,7 +31,7 @@ def ifftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
def rfftn(
x: array,
@@ -42,7 +42,7 @@ def rfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def irfftn(
x: array,
@@ -53,7 +53,7 @@ def irfftn(
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
- return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+ return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def fftshift(
x: array,
@@ -62,7 +62,7 @@ def fftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return paddle.fft.fftshift(x, axes=axes, **kwargs)
+ return torch.fft.fftshift(x, dim=axes, **kwargs)
def ifftshift(
x: array,
@@ -71,10 +71,10 @@ def ifftshift(
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
- return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+ return torch.fft.ifftshift(x, dim=axes, **kwargs)
-__all__ = paddle.fft.__all__ + [
+__all__ = torch.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
"ifftshift",
]
-_all_ignore = ['paddle']
+_all_ignore = ['torch']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 5e4ee47b..e26198b9 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -2,84 +2,86 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- import paddle
- array = paddle.Tensor
- from paddle import dtype as Dtype
+ import torch
+ array = torch.Tensor
+ from torch import dtype as Dtype
from typing import Optional, Union, Tuple, Literal
inf = float('inf')
from ._aliases import _fix_promotion, sum
-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403
-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
-# outer is implemented in paddle but aren't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but aren't in the linalg namespace
+from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = paddle.broadcast_tensors(x1, x2)
- return paddle_linalg.cross(x1, x2, axis=axis)
+ x1, x2 = torch.broadcast_tensors(x1, x2)
+ return torch_linalg.cross(x1, x2, dim=axis)
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+ # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
- # paddle.linalg.vecdot doesn't support integer dtypes
+ # torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
- x1_ = paddle.moveaxis(x1, axis, -1)
- x2_ = paddle.moveaxis(x2, axis, -1)
- x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+ x1_ = torch.moveaxis(x1, axis, -1)
+ x2_ = torch.moveaxis(x2, axis, -1)
+ x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
- return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+ return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+ # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
# 1. x1.ndim - 1 == x2.ndim
# 2. x1.shape[:-1] == x2.shape
#
# See linalg_solve_is_vector_rhs in
# aten/src/ATen/native/LinearAlgebraUtils.h and
- # paddle_META_FUNC(_linalg_solve_ex) in
- # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+ # TORCH_META_FUNC(_linalg_solve_ex) in
+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
#
# The easiest way to work around this is to prepend a size 1 dimension to
# x2, since x2 is already one dimension less than x1.
#
- # See https://github.com/pypaddle/pypaddle/issues/52915
+ # See https://github.com/pytorch/pytorch/issues/52915
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
- return paddle.linalg.solve(x1, x2, **kwargs)
+ return torch.linalg.solve(x1, x2, **kwargs)
-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+ return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: array,
@@ -90,30 +92,30 @@ def vector_norm(
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
- # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+ # torch.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
if out is None:
dtype = None
- if x.dtype == paddle.complex64:
- dtype = paddle.float32
- elif x.dtype == paddle.complex128:
- dtype = paddle.float64
+ if x.dtype == torch.complex64:
+ dtype = torch.float32
+ elif x.dtype == torch.complex128:
+ dtype = torch.float64
- out = paddle.zeros_like(x, dtype=dtype)
+ out = torch.zeros_like(x, dtype=dtype)
# The norm of a single scalar works out to abs(x) in every case except
- # for p=0, which is x != 0.
+ # for ord=0, which is x != 0.
if ord == 0:
out[:] = (x != 0)
else:
- out[:] = paddle.abs(x)
+ out[:] = torch.abs(x)
return out
- return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+ return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']
del linalg_all
From 85dc3bafb4f3a9e9e351b8f1037e32360009dd1e Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:05:11 +0800
Subject: [PATCH 03/28] update promotion table and can_cast table
---
array_api_compat/common/_helpers.py | 3 +-
array_api_compat/paddle/_aliases.py | 121 ++++++++++++----------------
tests/_helpers.py | 2 +-
tests/test_all.py | 4 +-
tests/test_common.py | 11 ++-
5 files changed, 64 insertions(+), 77 deletions(-)
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index ff2c213f..ec6b3e0d 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -144,7 +144,6 @@ def is_paddle_array(x):
import paddle
- # TODO: Should we reject ndarray subclasses?
return paddle.is_tensor(x)
def is_ndonnx_array(x):
@@ -725,7 +724,7 @@ def device(x: Array, /) -> Device:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
- raise NotImplementedError(f"Unsupported device {raw_place_str}")
+ raise ValueError(f"Unsupported Paddle device: {x.place}")
return x.device
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index dabe2928..14d3de7f 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -42,37 +42,18 @@
paddle.complex128,
}
+# NOTE: Paddle's implicit promotion rules are a bit stricter than those of other frameworks,
+# see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
_promotion_table = {
# bool
(paddle.bool, paddle.bool): paddle.bool,
# ints
(paddle.int8, paddle.int8): paddle.int8,
- (paddle.int8, paddle.int16): paddle.int16,
- (paddle.int8, paddle.int32): paddle.int32,
- (paddle.int8, paddle.int64): paddle.int64,
- (paddle.int16, paddle.int8): paddle.int16,
(paddle.int16, paddle.int16): paddle.int16,
- (paddle.int16, paddle.int32): paddle.int32,
- (paddle.int16, paddle.int64): paddle.int64,
- (paddle.int32, paddle.int8): paddle.int32,
- (paddle.int32, paddle.int16): paddle.int32,
(paddle.int32, paddle.int32): paddle.int32,
- (paddle.int32, paddle.int64): paddle.int64,
- (paddle.int64, paddle.int8): paddle.int64,
- (paddle.int64, paddle.int16): paddle.int64,
- (paddle.int64, paddle.int32): paddle.int64,
(paddle.int64, paddle.int64): paddle.int64,
# uints
(paddle.uint8, paddle.uint8): paddle.uint8,
- # ints and uints (mixed sign)
- (paddle.int8, paddle.uint8): paddle.int16,
- (paddle.int16, paddle.uint8): paddle.int16,
- (paddle.int32, paddle.uint8): paddle.int32,
- (paddle.int64, paddle.uint8): paddle.int64,
- (paddle.uint8, paddle.int8): paddle.int16,
- (paddle.uint8, paddle.int16): paddle.int16,
- (paddle.uint8, paddle.int32): paddle.int32,
- (paddle.uint8, paddle.int64): paddle.int64,
# floats
(paddle.float32, paddle.float32): paddle.float32,
(paddle.float32, paddle.float64): paddle.float64,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float16: {
paddle.bfloat16: True,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float32: {
paddle.bfloat16: True,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.float64: {
paddle.bfloat16: True,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.complex64: {
- paddle.bfloat16: False,
- paddle.float16: False,
- paddle.float32: False,
- paddle.float64: False,
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.complex128: {
- paddle.bfloat16: False,
- paddle.float16: False,
- paddle.float32: False,
- paddle.float64: False,
+ paddle.bfloat16: True,
+ paddle.float16: True,
+ paddle.float32: True,
+ paddle.float64: True,
paddle.complex64: True,
paddle.complex128: True,
- paddle.uint8: False,
- paddle.int8: False,
- paddle.int16: False,
- paddle.int32: False,
- paddle.int64: False,
- paddle.bool: False,
+ paddle.uint8: True,
+ paddle.int8: True,
+ paddle.int16: True,
+ paddle.int32: True,
+ paddle.int64: True,
+ paddle.bool: True,
},
paddle.uint8: {
paddle.bfloat16: True,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int8: {
paddle.bfloat16: True,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int16: {
paddle.bfloat16: True,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int32: {
paddle.bfloat16: True,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.int64: {
paddle.bfloat16: True,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
paddle.int16: True,
paddle.int32: True,
paddle.int64: True,
- paddle.bool: False,
+ paddle.bool: True,
},
paddle.bool: {
paddle.bfloat16: True,
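A hedged sketch of how the relaxed tables read through the public API (assuming `array_api_compat.paddle` exposes `can_cast` and `result_type` built on the tables above; the commented outputs are what those tables imply):

```py
import paddle
import array_api_compat.paddle as xp

# Same-kind promotion still follows the trimmed _promotion_table ...
print(xp.result_type(paddle.float32, paddle.float64))   # paddle.float64

# ... while can_cast now mirrors what paddle.cast will actually accept,
# including float -> int and anything -> bool.
print(xp.can_cast(paddle.float32, paddle.int32))         # True
print(xp.can_cast(paddle.int64, paddle.bool))            # True
```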
diff --git a/tests/_helpers.py b/tests/_helpers.py
index 0321bcb4..07f1859a 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,7 +3,7 @@
import pytest
-wrapped_libraries = ["numpy", "paddle"]
+wrapped_libraries = ["numpy", "paddle", "torch"]
all_libraries = wrapped_libraries + []
# `sparse` added array API support as of Python 3.10.
diff --git a/tests/test_all.py b/tests/test_all.py
index 969d5cfb..7528b22e 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__
if set(dir_names) != set(all_names):
- assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
- assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
+ assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
+ assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
diff --git a/tests/test_common.py b/tests/test_common.py
index 5c0b5826..a46a2be2 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -17,7 +17,7 @@
is_array_functions = {
'numpy': 'is_numpy_array',
# 'cupy': 'is_cupy_array',
- # 'torch': 'is_torch_array',
+ 'torch': 'is_torch_array',
# 'dask.array': 'is_dask_array',
# 'jax.numpy': 'is_jax_array',
# 'sparse': 'is_pydata_sparse_array',
@@ -27,7 +27,7 @@
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
# 'cupy': 'is_cupy_namespace',
- # 'torch': 'is_torch_namespace',
+ 'torch': 'is_torch_namespace',
# 'dask.array': 'is_dask_namespace',
# 'jax.numpy': 'is_jax_namespace',
# 'sparse': 'is_pydata_sparse_namespace',
@@ -103,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
if source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
+ if source_library == "paddle" or target_library == "paddle":
+ pytest.skip(
+ reason=(
+ "paddle does not support implicit conversion from/to other framework "
+ "via 'asarray', dlpack is recommend now."
+ )
+ )
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
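Since the skip above points to DLPack as the supported route between Paddle and other libraries, here is a brief sketch (assuming `paddle.utils.dlpack` keeps its `to_dlpack`/`from_dlpack` pair and a NumPy version with `__dlpack__` support):

```py
import numpy as np
import paddle
from paddle.utils import dlpack

# NumPy -> Paddle through an explicit DLPack capsule
arr = np.asarray([1.0, 2.0, 3.0])
x = dlpack.from_dlpack(arr.__dlpack__())

# Paddle -> NumPy: .numpy() copies to host; to_dlpack(x) yields a capsule
# for consumers that accept capsules directly.
back = x.numpy()
capsule = dlpack.to_dlpack(x)
```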
From c5b82db6f6429a3ed6f4d08a7bac39a9ff6bfd1c Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:14:17 +0800
Subject: [PATCH 04/28] update doc
---
docs/supported-array-libraries.md | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index fa30ccd2..26a1c1c5 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -158,5 +158,4 @@ Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.
-The minimum supported PyTorch version is 1.13.
-
+The minimum supported Paddle version is 3.0.0.
From 7b99449f2634f05ecae2d4f046355e445ec481de Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:18:17 +0800
Subject: [PATCH 05/28] restore code
---
tests/_helpers.py | 8 ++---
tests/test_all.py | 4 +--
tests/test_array_namespace.py | 55 +++++++++++------------------------
tests/test_common.py | 16 +++++-----
tests/test_no_dependencies.py | 4 +--
tests/test_vendoring.py | 20 ++++++-------
6 files changed, 43 insertions(+), 64 deletions(-)
diff --git a/tests/_helpers.py b/tests/_helpers.py
index 07f1859a..801cd32d 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -3,12 +3,12 @@
import pytest
-wrapped_libraries = ["numpy", "paddle", "torch"]
-all_libraries = wrapped_libraries + []
+wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
+all_libraries = wrapped_libraries + ["jax.numpy"]
# `sparse` added array API support as of Python 3.10.
-# if sys.version_info >= (3, 10):
-# all_libraries.append('sparse')
+if sys.version_info >= (3, 10):
+ all_libraries.append('sparse')
def import_(library, wrapper=False):
if library == 'cupy':
diff --git a/tests/test_all.py b/tests/test_all.py
index 7528b22e..969d5cfb 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__
if set(dir_names) != set(all_names):
- assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
- assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
+ assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
+ assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 4b494ec3..cd25a931 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -5,7 +5,7 @@
# import jax
import numpy as np
import pytest
-# import torch
+import torch
import paddle
import array_api_compat
@@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)
-# def test_jax_zero_gradient():
-# jx = jax.numpy.arange(4)
-# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
-# assert (array_api_compat.get_namespace(jax_zero) is
-# array_api_compat.get_namespace(jx))
+def test_jax_zero_gradient():
+ jx = jax.numpy.arange(4)
+ jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
+ assert (array_api_compat.get_namespace(jax_zero) is
+ array_api_compat.get_namespace(jx))
def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -87,10 +87,10 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
-# def test_array_namespace_errors_torch():
-# y = torch.asarray([1, 2])
-# x = np.asarray([1, 2])
-# pytest.raises(TypeError, lambda: array_namespace(x, y))
+def test_array_namespace_errors_torch():
+ y = torch.asarray([1, 2])
+ x = np.asarray([1, 2])
+ pytest.raises(TypeError, lambda: array_namespace(x, y))
def test_array_namespace_errors_paddle():
@@ -98,42 +98,21 @@ def test_array_namespace_errors_paddle():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))
-
-# def test_api_version():
-# x = torch.asarray([1, 2])
-# torch_ = import_("torch", wrapper=True)
-# assert array_namespace(x, api_version="2023.12") == torch_
-# assert array_namespace(x, api_version=None) == torch_
-# assert array_namespace(x) == torch_
-# # Should issue a warning
-# with warnings.catch_warnings(record=True) as w:
-# assert array_namespace(x, api_version="2021.12") == torch_
-# assert len(w) == 1
-# assert "2021.12" in str(w[0].message)
-
-# # Should issue a warning
-# with warnings.catch_warnings(record=True) as w:
-# assert array_namespace(x, api_version="2022.12") == torch_
-# assert len(w) == 1
-# assert "2022.12" in str(w[0].message)
-
-# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
-
def test_api_version():
- x = paddle.asarray([1, 2])
- paddle_ = import_("paddle", wrapper=True)
- assert array_namespace(x, api_version="2023.12") == paddle_
- assert array_namespace(x, api_version=None) == paddle_
- assert array_namespace(x) == paddle_
+ x = torch.asarray([1, 2])
+ torch_ = import_("torch", wrapper=True)
+ assert array_namespace(x, api_version="2023.12") == torch_
+ assert array_namespace(x, api_version=None) == torch_
+ assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2021.12") == paddle_
+ assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2022.12") == paddle_
+ assert array_namespace(x, api_version="2022.12") == torch_
assert len(w) == 1
assert "2022.12" in str(w[0].message)
diff --git a/tests/test_common.py b/tests/test_common.py
index a46a2be2..23ac53d1 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -16,21 +16,21 @@
is_array_functions = {
'numpy': 'is_numpy_array',
- # 'cupy': 'is_cupy_array',
+ 'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
- # 'dask.array': 'is_dask_array',
- # 'jax.numpy': 'is_jax_array',
- # 'sparse': 'is_pydata_sparse_array',
+ 'dask.array': 'is_dask_array',
+ 'jax.numpy': 'is_jax_array',
+ 'sparse': 'is_pydata_sparse_array',
'paddle': 'is_paddle_array',
}
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
- # 'cupy': 'is_cupy_namespace',
+ 'cupy': 'is_cupy_namespace',
'torch': 'is_torch_namespace',
- # 'dask.array': 'is_dask_namespace',
- # 'jax.numpy': 'is_jax_namespace',
- # 'sparse': 'is_pydata_sparse_namespace',
+ 'dask.array': 'is_dask_namespace',
+ 'jax.numpy': 'is_jax_namespace',
+ 'sparse': 'is_pydata_sparse_namespace',
'paddle': 'is_paddle_namespace',
}
diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py
index 201f98ea..11a516ac 100644
--- a/tests/test_no_dependencies.py
+++ b/tests/test_no_dependencies.py
@@ -51,8 +51,8 @@ def _test_dependency(mod):
@pytest.mark.parametrize("library",
[
- "numpy",
- "paddle", "array_api_strict",
+ "numpy", "cupy", "numpy", "torch", "dask.array",
+ "jax.numpy", "sparse", "paddle", "array_api_strict"
]
)
def test_numpy_dependency(library):
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 91ca1709..3c9b5d92 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -7,23 +7,23 @@ def test_vendoring_numpy():
uses_numpy._test_numpy()
-# def test_vendoring_cupy():
-# pytest.importorskip("cupy")
+def test_vendoring_cupy():
+ pytest.importorskip("cupy")
-# from vendor_test import uses_cupy
+ from vendor_test import uses_cupy
-# uses_cupy._test_cupy()
+ uses_cupy._test_cupy()
-# def test_vendoring_torch():
-# from vendor_test import uses_torch
+def test_vendoring_torch():
+ from vendor_test import uses_torch
-# uses_torch._test_torch()
+ uses_torch._test_torch()
-# def test_vendoring_dask():
-# from vendor_test import uses_dask
-# uses_dask._test_dask()
+def test_vendoring_dask():
+ from vendor_test import uses_dask
+ uses_dask._test_dask()
def test_vendoring_paddle():
From bb40851d5b060886b08d32c2122df1186539754b Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:20:23 +0800
Subject: [PATCH 06/28] update docstring
---
array_api_compat/paddle/_info.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 1fe48356..97e78960 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -15,7 +15,7 @@
class __array_namespace_info__:
"""
- Get the array API inspection namespace for PyTorch.
+ Get the array API inspection namespace for Paddle.
The array API inspection namespace defines the following functions:
@@ -32,7 +32,7 @@ class __array_namespace_info__:
Returns
-------
info : ModuleType
- The array API inspection namespace for PyTorch.
+ The array API inspection namespace for Paddle.
Examples
--------
@@ -54,11 +54,11 @@ def capabilities(self):
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
- supports boolean indexing. Always ``True`` for PyTorch.
+ supports boolean indexing. Always ``True`` for Paddle.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
- PyTorch.
+ Paddle.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
@@ -93,7 +93,7 @@ def capabilities(self):
def default_device(self):
"""
- The default device used for new PyTorch arrays.
+ The default device used for new Paddle arrays.
See Also
--------
@@ -105,7 +105,7 @@ def default_device(self):
Returns
-------
device : str
- The default device used for new PyTorch arrays.
+ The default device used for new Paddle arrays.
Examples
--------
@@ -118,18 +118,18 @@ def default_device(self):
def default_dtypes(self, *, device=None):
"""
- The default data types used for new PyTorch arrays.
+ The default data types used for new Paddle arrays.
Parameters
----------
device : str, optional
- The device to get the default data types for. For PyTorch, only
+ The device to get the default data types for. For Paddle, only
``'cpu'`` is allowed.
Returns
-------
dtypes : dict
- A dictionary describing the default data types used for new PyTorch
+ A dictionary describing the default data types used for new Paddle
arrays.
See Also
@@ -244,7 +244,7 @@ def _dtypes(self, kind):
@cache
def dtypes(self, *, device=None, kind=None):
"""
- The array API data types supported by PyTorch.
+ The array API data types supported by Paddle.
Note that this function only returns data types that are defined by
the array API.
@@ -277,7 +277,7 @@ def dtypes(self, *, device=None, kind=None):
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
- PyTorch data types.
+ Paddle data types.
See Also
--------
@@ -307,12 +307,12 @@ def dtypes(self, *, device=None, kind=None):
@cache
def devices(self):
"""
- The devices supported by PyTorch.
+ The devices supported by Paddle.
Returns
-------
devices : list of str
- The devices supported by PyTorch.
+ The devices supported by Paddle.
See Also
--------
From a7163f903796684d6f25cb420d72aaadb416433c Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 14:35:33 +0800
Subject: [PATCH 07/28] refine more code
---
array_api_compat/paddle/_info.py | 3 +--
array_api_compat/paddle/linalg.py | 5 ++---
tests/test_array_namespace.py | 2 +-
3 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 97e78960..d8dab7ee 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -170,8 +170,7 @@ def _dtypes(self, kind):
int32 = paddle.int32
int64 = paddle.int64
uint8 = paddle.uint8
- # uint16, uint32, and uint64 are present in newer versions of pytorch,
- # but they aren't generally supported by the array API functions, so
+    # uint16, uint32, and uint64 are not fully supported in paddle, so
# we omit them from this function.
float32 = paddle.float32
float64 = paddle.float64
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 6ee57fcf..7ef04a90 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -28,11 +28,10 @@
from ._aliases import matmul, matrix_transpose, tensordot
# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-
+# first axis with size 3)
# paddle.cross also does not support broadcasting when it would add new
-# dimensions https://github.com/pytorch/pytorch/issues/39656
+# dimensions
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index cd25a931..e9e7458f 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -2,7 +2,7 @@
import sys
import warnings
-# import jax
+import jax
import numpy as np
import pytest
import torch
From ec461786832538cadb1af01bf41fa023621252fb Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 22:26:24 +0800
Subject: [PATCH 08/28] add suffix for test_python_scalars and add paddle
 index-url in requirements
---
requirements-dev.txt | 2 +-
tests/test_array_namespace.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ae41a25e..7ad022d7 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,6 +4,6 @@ jax[cpu]
numpy
pytest
torch
-paddlepaddle
+paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
sparse >=0.15.1
ndonnx
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index e9e7458f..4076c74c 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -122,7 +122,7 @@ def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace
-def test_python_scalars():
+def test_python_scalars_torch():
a = torch.asarray([1, 2])
xp = import_("torch", wrapper=True)
@@ -138,7 +138,7 @@ def test_python_scalars():
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp
-def test_python_scalars():
+def test_python_scalars_paddle():
a = paddle.to_tensor([1, 2])
xp = import_("paddle", wrapper=True)
From dfd448518b35867ea4ee99a474a51054ac02f2f0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 3 Dec 2024 14:53:05 +0800
Subject: [PATCH 09/28] update paddle code
---
array_api_compat/paddle/__init__.py | 12 +-
array_api_compat/paddle/_aliases.py | 388 ++++++++++++++++++++--------
array_api_compat/paddle/_info.py | 14 +-
array_api_compat/paddle/fft.py | 31 ++-
array_api_compat/paddle/linalg.py | 48 +++-
5 files changed, 349 insertions(+), 144 deletions(-)
diff --git a/array_api_compat/paddle/__init__.py b/array_api_compat/paddle/__init__.py
index 9f96fa9f..1016312d 100644
--- a/array_api_compat/paddle/__init__.py
+++ b/array_api_compat/paddle/__init__.py
@@ -4,16 +4,10 @@
import paddle
for n in dir(paddle):
- if (
- n.startswith("_")
- or n.endswith("_")
- or "gpu" in n
- or "cpu" in n
- or "backward" in n
- ):
+ if n.startswith("_") or n.endswith("_") or "gpu" in n or "cpu" in n or "backward" in n:
continue
- exec(n + " = paddle." + n)
- exec("asarray = paddle.to_tensor")
+ exec(f"{n} = paddle.{n}")
+
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 14d3de7f..601afa5f 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,14 +1,17 @@
from __future__ import annotations
+from typing import Literal
+import numpy as np
+
from functools import wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
from ..common._aliases import (
- matrix_transpose as _aliases_matrix_transpose,
- vecdot as _aliases_vecdot,
- clip as _aliases_clip,
unstack as _aliases_unstack,
- cumulative_sum as _aliases_cumulative_sum,
+)
+from ..common._typing import (
+ SupportsBufferProtocol,
+ NestedSequence,
)
from .._internal import get_xp
@@ -94,7 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True):
return x1, x2
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
return x1, x2
- # If an argument is 0-D pytorch downcasts the other argument
+ # If an argument is 0-D paddle downcasts the other argument
if not only_scalar or x1.shape == ():
dtype = result_type(x1, x2)
x2 = x2.to(dtype)
@@ -131,6 +134,12 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+ if paddle.is_tensor(from_):
+ from_ = from_.dtype
+
+    assert isinstance(from_, paddle.dtype), from_
+    assert isinstance(to, paddle.dtype), to
+
can_cast_dict = {
paddle.bfloat16: {
paddle.bfloat16: True,
@@ -341,9 +350,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
remainder = _two_arg(paddle.remainder)
subtract = _two_arg(paddle.subtract)
-# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
-# of 'axis'.
-
def max(
x: array,
@@ -352,12 +358,21 @@ def max(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
return paddle.amax(x, axis, keepdim=keepdims)
+def argmax(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ return paddle.argmax(x, axis, keepdim=keepdims)
+
+
def min(
x: array,
/,
@@ -365,19 +380,25 @@ def min(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
return paddle.min(x, axis, keepdim=keepdims)
-clip = get_xp(paddle)(_aliases_clip)
+def argmin(
+ x: array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> array:
+ return paddle.argmin(x, axis, keepdim=keepdims)
+
+
unstack = get_xp(paddle)(_aliases_unstack)
-cumulative_sum = get_xp(paddle)(_aliases_cumulative_sum)
# paddle.sort also returns a tuple
-# https://github.com/pytorch/pytorch/issues/70921
def sort(
x: array,
/,
@@ -387,9 +408,7 @@ def sort(
stable: bool = True,
**kwargs,
) -> array:
- return paddle.sort(
- x, axis=axis, descending=descending, stable=stable, **kwargs
- ).values
+ return paddle.sort(x, axis=axis, descending=descending, stable=stable, **kwargs)
def _normalize_axes(axis, ndim):
@@ -401,9 +420,7 @@ def _normalize_axes(axis, ndim):
for a in axis:
if a < lower or a > upper:
# Match paddle error message (e.g., from sum())
- raise IndexError(
- f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
- )
+ raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
if a < 0:
a = a + ndim
if a in axes:
@@ -415,7 +432,6 @@ def _normalize_axes(axis, ndim):
def _axis_none_keepdims(x, ndim, keepdims):
# Apply keepdims when axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
# Note that this is only valid for the axis=None case.
if keepdims:
for i in range(ndim):
@@ -425,7 +441,6 @@ def _axis_none_keepdims(x, ndim, keepdims):
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
x = paddle.movedim(x, a, -1)
@@ -448,10 +463,10 @@ def prod(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
+    # Handled separately from the logic below because it still needs to upcast.
if axis == ():
if dtype is None:
@@ -464,14 +479,10 @@ def prod(
return x.to(dtype)
# paddle.prod doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
- return _reduce_multiple_axes(
- paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
- )
+ return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
@@ -488,10 +499,10 @@ def sum(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
- # https://github.com/pytorch/pytorch/issues/29137.
# Make sure it upcasts.
if axis == ():
if dtype is None:
@@ -505,7 +516,6 @@ def sum(
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.sum(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
@@ -521,18 +531,17 @@ def any(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
if axis == ():
return x.to(paddle.bool)
# paddle.any doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(paddle.any, x, axis, keepdim=keepdims, **kwargs)
return res.to(paddle.bool)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.any(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(paddle.bool)
@@ -549,18 +558,17 @@ def all(
keepdims: bool = False,
**kwargs,
) -> array:
- x = paddle.asarray(x)
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
ndim = x.ndim
if axis == ():
return x.to(paddle.bool)
# paddle.all doesn't support multiple axes
- # (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(paddle.all, x, axis, keepdim=keepdims, **kwargs)
return res.to(paddle.bool)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.all(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(paddle.bool)
@@ -577,12 +585,10 @@ def mean(
keepdims: bool = False,
**kwargs,
) -> array:
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.clone(x)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.mean(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -599,15 +605,12 @@ def std(
**kwargs,
) -> array:
-    # Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
-    # implement it here for now.
+    # Note, float correction is not supported. We don't try to
+    # implement it here for now.
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError(
- "float correction in paddle std() is not yet supported"
- )
+ raise NotImplementedError("float correction in paddle std() is not yet supported")
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -616,14 +619,12 @@ def std(
_correction = bool(_correction)
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.std(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -640,7 +641,6 @@ def var(
**kwargs,
) -> array:
-    # Note, float correction is not supported
-    # https://github.com/pytorch/pytorch/issues/61492. We don't try to
-    # implement it here for now.
+    # Note, float correction is not supported. We don't try to
+    # implement it here for now.
# if isinstance(correction, float):
@@ -648,9 +648,7 @@ def var(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError(
- "float correction in paddle std() is not yet supported"
- )
+ raise NotImplementedError("float correction in paddle std() is not yet supported")
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -659,14 +657,12 @@ def var(
_correction = bool(_correction)
- # https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return paddle.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# paddle doesn't support keepdims with axis=None
- # (https://github.com/pytorch/pytorch/issues/71209)
res = paddle.var(x, tuple(range(x.ndim)), unbiased=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
@@ -674,7 +670,6 @@ def var(
# paddle.concat doesn't support dim=None
-# https://github.com/pytorch/pytorch/issues/70925
def concat(
arrays: Union[Tuple[array, ...], List[array]],
/,
@@ -688,9 +683,6 @@ def concat(
return paddle.concat(arrays, axis, **kwargs)
-# paddle.squeeze only accepts int dim and doesn't require it
-# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
-# added at https://github.com/pytorch/pytorch/pull/89017.
def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
if isinstance(axis, int):
axis = (axis,)
@@ -698,7 +690,7 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
if x.shape[a] != 1:
raise ValueError("squeezed dimensions must be equal to 1")
axes = _normalize_axes(axis, x.ndim)
- # Remove this once pytorch 1.14 is released with the above PR #89017.
+
sequence = [a - i for i, a in enumerate(axes)]
for a in sequence:
x = paddle.squeeze(x, a)
@@ -712,23 +704,15 @@ def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
# paddle.permute uses dims instead of axes
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
- if len(axes) == 2:
- perm = list(range(x.ndim))
- perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
- axes = perm
return paddle.transpose(x, axes)
-# The axis parameter doesn't work for flip() and roll()
-# https://github.com/pytorch/pytorch/issues/71210. Also paddle.flip() doesn't
-# accept axis=None
+# The axis parameter doesn't work for flip() and roll(). Also paddle.flip() doesn't
+# accept axis=None
-def flip(
- x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
-) -> array:
+def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
if axis is None:
axis = tuple(range(x.ndim))
# paddle.flip doesn't accept dim as an int but the method does
- # https://github.com/pytorch/pytorch/issues/18095
return x.flip(axis, **kwargs)
@@ -754,19 +738,48 @@ def where(condition: array, x1: array, x2: array, /) -> array:
return paddle.where(condition, x1, x2)
-# paddle.reshape doesn't have the copy keyword
-def reshape(
- x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.empty_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.zeros_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+ out = paddle.ones_like(x, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def full_like(
+ x: array,
+ /,
+ fill_value: bool | int | float | complex,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
) -> array:
- if copy is not None:
- raise NotImplementedError("paddle.reshape doesn't yet support the copy keyword")
+ out = paddle.full_like(x, fill_value, dtype=dtype)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+# paddle.reshape doesn't have the copy keyword
+def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
return paddle.reshape(x, shape, **kwargs)
-# paddle.arange doesn't support returning empty arrays
-# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
-# keyword argument combinations
+# paddle.arange doesn't support returning empty arrays, and doesn't support some
+# keyword argument combinations
-# (https://github.com/pytorch/pytorch/issues/70914)
def arange(
start: Union[int, float],
/,
@@ -790,7 +803,6 @@ def arange(
-# paddle.eye does not accept None as a default for the second argument and
-# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
+# paddle.eye does not accept None as a default for the second argument and
+# doesn't support off-diagonals
def eye(
n_rows: int,
n_cols: Optional[int] = None,
@@ -822,14 +834,11 @@ def linspace(
**kwargs,
) -> array:
if not endpoint:
- return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
- :-1
- ]
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
# paddle.full does not accept an int size
-# https://github.com/pytorch/pytorch/issues/70906
def full(
shape: Union[int, Tuple[int, ...]],
fill_value: Union[bool, int, float, complex],
@@ -886,17 +895,21 @@ def triu(x: array, /, *, k: int = 0) -> array:
return paddle.triu(x, k)
-# Functions that aren't in paddle https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return paddle.unsqueeze(x, axis)
-def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
- return x.to(dtype, copy=copy)
+def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
+ # if copy is not None:
+ # raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
+ t = x.to(dtype, device=device)
+ if copy:
+ t = t.detach().clone()
+ return t
def broadcast_arrays(*arrays: array) -> List[array]:
- shape = paddle.broadcast_shapes(*[a.shape for a in arrays])
+ shape = broadcast_shapes(*[a.shape for a in arrays])
return [paddle.broadcast_to(a, shape) for a in arrays]
@@ -905,28 +918,19 @@ def broadcast_arrays(*arrays: array) -> List[array]:
from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
-# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: array) -> UniqueAllResult:
- # paddle.unique doesn't support returning indices.
- # https://github.com/pytorch/pytorch/issues/36748. The workaround
- # suggested in that issue doesn't actually function correctly (it relies
- # on non-deterministic behavior of scatter()).
- raise NotImplementedError(
- "unique_all() not yet implemented for paddle (see https://github.com/pytorch/pytorch/issues/36748)"
-    )
+    values, indices, inverse_indices, counts = paddle.unique(
+        x,
+        return_index=True,
+        return_inverse=True,
+        return_counts=True,
+    )
+    return UniqueAllResult(values, indices, inverse_indices, counts)
- # values, inverse_indices, counts = paddle.unique(x, return_counts=True, return_inverse=True)
- # # paddle.unique incorrectly gives a 0 count for nan values.
- # # https://github.com/pytorch/pytorch/issues/94106
- # counts[paddle.isnan(values)] = 1
- # return UniqueAllResult(values, indices, inverse_indices, counts)
-
def unique_counts(x: array) -> UniqueCountsResult:
values, counts = paddle.unique(x, return_counts=True)
# paddle.unique incorrectly gives a 0 count for nan values.
- # https://github.com/pytorch/pytorch/issues/94106
counts[paddle.isnan(values)] = 1
return UniqueCountsResult(values, counts)
@@ -946,13 +950,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
return paddle.matmul(x1, x2, **kwargs)
-matrix_transpose = get_xp(paddle)(_aliases_matrix_transpose)
-_vecdot = get_xp(paddle)(_aliases_vecdot)
+def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
+ if indexing == "ij":
+ return paddle.meshgrid(*arrays)
+ else:
+ return [i.T for i in paddle.meshgrid(*arrays)]
+
+
+matrix_transpose = paddle.linalg.matrix_transpose
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
- return _vecdot(x1, x2, axis=axis)
+ return paddle.linalg.vecdot(x1, x2, axis=axis)
# paddle.tensordot uses dims instead of axes
@@ -965,7 +975,6 @@ def tensordot(
**kwargs,
) -> array:
-    # Note: paddle.tensordot fails with integer dtypes when there is only 1
-    # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+    # Note: paddle.tensordot fails with integer dtypes when there is only 1
+    # element in the axis.
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return paddle.tensordot(x1, x2, axes=axes, **kwargs)
@@ -990,16 +999,6 @@ def isdtype(
def is_signed(dtype):
return dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
- def is_floating_point(dtype):
- return dtype in [
- paddle.float32,
- paddle.float64,
- paddle.float16,
- paddle.bfloat16,
- paddle.float8_e4m3fn,
- paddle.float8_e5m2,
- ]
-
def is_complex(dtype):
return dtype in [paddle.complex64, paddle.complex128]
@@ -1016,7 +1015,7 @@ def is_complex(dtype):
elif kind == "integral":
return dtype in _int_dtypes
elif kind == "real floating":
- return is_floating_point(dtype)
+ return paddle.is_floating_point(dtype)
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1038,18 +1037,172 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
def sign(x: array, /) -> array:
# paddle sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
- if x.dtype.is_complex:
+ if paddle.is_complex(x):
out = x / paddle.abs(x)
# sign(0) = 0 but the above formula would give nan
out[x == 0 + 0j] = 0 + 0j
return out
else:
out = paddle.sign(x)
- if x.dtype.is_floating_point:
- out[paddle.isnan(x)] = paddle.nan
+ if paddle.is_floating_point(x):
+ out = paddle.where(paddle.isnan(x), paddle.nan, out)
return out
+def broadcast_shapes(*shapes: List[int]) -> List[int]:
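+    # paddle.broadcast_shape only accepts two shapes, so fold over the inputs pairwise.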
+ out_shape = shapes[0]
+ for i, shape in enumerate(shapes):
+ if i == 0:
+ continue
+ out_shape = paddle.broadcast_shape(out_shape, shape)
+
+ return out_shape
+
+
+# asarray adds the copy keyword on top of paddle's own tensor-creation functions.
+def asarray(
+ obj: Union[
+ array,
+ bool,
+ int,
+ float,
+ NestedSequence[bool | int | float],
+ SupportsBufferProtocol,
+ ],
+ /,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ copy: Optional[bool] = None,
+ **kwargs,
+) -> array:
+ """
+ Array API compatibility wrapper for asarray().
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+ if copy is False:
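+        # copy=False must not copy, so reuse the producer's buffer via DLPack when available.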
+ if hasattr(obj, "__dlpack__"):
+ obj = paddle.from_dlpack(obj.__dlpack__())
+ if device is not None:
+ obj = obj.to(device)
+ if dtype is not None:
+ obj = obj.to(dtype)
+ return obj
+ else:
+ raise NotImplementedError(
+ "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
+ )
+ elif copy is True:
+ obj = np.array(obj, copy=True)
+ return paddle.to_tensor(obj, dtype=dtype, place=device)
+ else:
+ if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
+ obj = np.array(obj, copy=False)
+ obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ if device is not None:
+ obj = obj.to(device)
+ return obj
+
+ return obj
+
+
+def clip(
+ x: array,
+ /,
+ min: Optional[Union[int, float, array]] = None,
+ max: Optional[Union[int, float, array]] = None,
+) -> array:
+ if min is None and max is None:
+ return x
+
+ def _isscalar(a):
+ return isinstance(a, (int, float, type(None)))
+
+ min_shape = [] if _isscalar(min) else min.shape
+ max_shape = [] if _isscalar(max) else max.shape
+
+ result_shape = broadcast_shapes(x.shape, min_shape, max_shape)
+
+ # np.clip does type promotion but the array API clip requires that the
+ # output have the same dtype as x. We do this instead of just downcasting
+ # the result of xp.clip() to handle some corner cases better (e.g.,
+ # avoiding uint64 -> float64 promotion).
+
+ # Note: cases where min or max overflow (integer) or round (float) in the
+ # wrong direction when downcasting to x.dtype are unspecified. This code
+ # just does whatever NumPy does when it downcasts in the assignment, but
+ # other behavior could be preferred, especially for integers. For example,
+ # this code produces:
+
+ # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
+ # -128
+
+ # but an answer of 0 might be preferred. See
+ # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
+
+ # At least handle the case of Python integers correctly (see
+ # https://github.com/numpy/numpy/pull/26892).
+ if type(min) is int and min <= paddle.iinfo(x.dtype).min:
+ min = None
+ if type(max) is int and max >= paddle.iinfo(x.dtype).max:
+ max = None
+
+ if out is None:
+ out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
+ if min is not None:
+ if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min):
+ # Avoid loss of precision due to paddle defaulting to float32
+ min = paddle.to_tensor(min, dtype=paddle.float64)
+ a = broadcast_to(paddle.to_tensor(min, place=x.place), result_shape)
+ ia = (out < a) | paddle.isnan(a)
+ # paddle requires an explicit cast here
+ out[ia] = astype(a[ia], out.dtype)
+ if max is not None:
+ if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(max):
+ max = paddle.to_tensor(max, dtype=paddle.float64)
+ b = broadcast_to(paddle.to_tensor(max, place=x.place), result_shape)
+ ib = (out > b) | paddle.isnan(b)
+ out[ib] = astype(b[ib], out.dtype)
+ # Return a scalar for 0-D
+ return out[()]
+
+
+def cumulative_sum(
+ x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
+) -> array:
+ if axis is None:
+ if x.ndim > 1:
+ raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+ axis = 0
+
+ res = paddle.cumsum(x, axis=axis, dtype=dtype)
+
+    # paddle.cumsum does not support include_initial
+ if include_initial:
+ initial_shape = list(x.shape)
+ initial_shape[axis] = 1
+ res = paddle.concat(
+ [paddle.zeros(shape=initial_shape, dtype=res.dtype).to(res.place), res],
+ axis=axis,
+ )
+ return res
+
+
+def searchsorted(
+ x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
+) -> array:
+ if sorter is None:
+ return paddle.searchsorted(x1, x2, right=(side == "right"))
+
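+    # paddle.searchsorted has no sorter argument, so apply the sorter permutation to x1 first.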
+ return paddle.searchsorted(
+ x1.take_along_axis(axis=-1, indices=sorter),
+ x2,
+ right=(side == "right"),
+ )
+
+
__all__ = [
"__array_namespace_info__",
"result_type",
@@ -1129,6 +1282,15 @@ def sign(x: array, /) -> array:
"isdtype",
"take",
"sign",
+ "broadcast_shapes",
+ "argmax",
+ "argmin",
+ "searchsorted",
+ "empty_like",
+ "zeros_like",
+ "ones_like",
+ "full_like",
+ "asarray",
]
_all_ignore = ["paddle", "get_xp"]
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index d8dab7ee..5d29e270 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -332,18 +332,12 @@ def devices(self):
# message of paddle.device to get the list of all possible types of
# device:
try:
- paddle.device("notadevice")
- except RuntimeError as e:
+ paddle.set_device("notadevice")
+ except ValueError as e:
# The error message is something like:
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
- devices_names = (
- e.args[0]
- .split("ValueError: The device must be a string which is like ")[1]
- .split(", ")
- )
- devices_names = [
- name.strip("'") for name in devices_names if ":" not in name
- ]
+ devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
+ devices_names = [name.strip("'") for name in devices_names if ":" not in name]
# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
diff --git a/array_api_compat/paddle/fft.py b/array_api_compat/paddle/fft.py
index 15519b5a..1442aed8 100644
--- a/array_api_compat/paddle/fft.py
+++ b/array_api_compat/paddle/fft.py
@@ -4,9 +4,10 @@
if TYPE_CHECKING:
import paddle
+ from ..common._typing import Device
array = paddle.Tensor
- from typing import Union, Sequence, Literal
+ from typing import Optional, Union, Sequence, Literal
from paddle.fft import * # noqa: F403
import paddle.fft
@@ -80,6 +81,32 @@ def ifftshift(
return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+def fftfreq(
+ n: int,
+ /,
+ *,
+ d: float = 1.0,
+ device: Optional[Device] = None,
+) -> array:
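+    # paddle.fft.fftfreq has no device argument, so move the result afterwards when one is given.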
+ out = paddle.fft.fftfreq(n, d)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
+def rfftfreq(
+ n: int,
+ /,
+ *,
+ d: float = 1.0,
+ device: Optional[Device] = None,
+) -> array:
+ out = paddle.fft.rfftfreq(n, d)
+ if device is not None:
+ out = out.to(device)
+ return out
+
+
__all__ = paddle.fft.__all__ + [
"fftn",
"ifftn",
@@ -87,6 +114,8 @@ def ifftshift(
"irfftn",
"fftshift",
"ifftshift",
+ "fftfreq",
+ "rfftfreq",
]
_all_ignore = ["paddle"]
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 7ef04a90..7dd1a266 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -12,7 +12,9 @@
inf = float("inf")
from ._aliases import _fix_promotion, sum
+from collections import namedtuple
+import paddle
from paddle.linalg import * # noqa: F403
# paddle.linalg doesn't define __all__
@@ -23,6 +25,7 @@
# outer is implemented in paddle but aren't in the linalg namespace
from paddle import outer
+import paddle
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
@@ -30,21 +33,18 @@
# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3)
+
# paddle.cross also does not support broadcasting when it would add new
# dimensions
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
- raise ValueError(
- f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}"
- )
+ raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
- raise ValueError(
- f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}"
- )
+ raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
- x1, x2 = paddle.broadcast_tensors(x1, x2)
+ x1, x2 = paddle.broadcast_tensors([x1, x2])
return paddle_linalg.cross(x1, x2, axis=axis)
@@ -64,7 +64,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
x1_ = paddle.moveaxis(x1, axis, -1)
x2_ = paddle.moveaxis(x2, axis, -1)
- x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+ x1_, x2_ = paddle.broadcast_tensors([x1_, x2_])
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
@@ -82,9 +82,7 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
# paddle.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
- return sum(
- paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype
- )
+ return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)
def vector_norm(
@@ -118,16 +116,44 @@ def vector_norm(
return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+def matrix_norm(
+ x: array,
+ /,
+ *,
+ keepdims: bool = False,
+ ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
+) -> array:
+ return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+
+
+def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
+ if rtol is None:
+ return paddle.linalg.pinv(x)
+
+ return paddle.linalg.pinv(x, rcond=rtol)
+
+
+def slogdet(x: array):
+ det = paddle.linalg.det(x)
+ sign = paddle.sign(det)
+ log_det = paddle.log(det)
+
+ slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
+ return slotdet(sign, log_det)
+
+
__all__ = linalg_all + [
"outer",
"matmul",
"matrix_transpose",
+ "matrix_norm",
"tensordot",
"cross",
"vecdot",
"solve",
"trace",
"vector_norm",
+ "slogdet",
]
_all_ignore = ["paddle_linalg", "sum"]
From 5ae8ec8c59106e5e9aa742dc794d9334f1c620f0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 3 Dec 2024 15:33:54 +0800
Subject: [PATCH 10/28] fix arange special-casing and bool dtype handling in asarray
---
array_api_compat/paddle/_aliases.py | 14 ++++----------
1 file changed, 4 insertions(+), 10 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 601afa5f..00130e23 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -790,15 +790,6 @@ def arange(
device: Optional[Device] = None,
**kwargs,
) -> array:
- if stop is None:
- start, stop = 0, start
- if step > 0 and stop <= start or step < 0 and stop >= start:
- if dtype is None:
- if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
- dtype = paddle.int64
- else:
- dtype = paddle.float32
- return paddle.empty([0], dtype=dtype, **kwargs).to(device)
return paddle.arange(start, stop, step, dtype=dtype, **kwargs).to(device)
@@ -1100,7 +1091,10 @@ def asarray(
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
- obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ if dtype != paddle.bool and dtype != "bool":
+ obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
+ else:
+ obj = paddle.to_tensor(obj, dtype=dtype)
if device is not None:
obj = obj.to(device)
return obj
From b10273b41058945d2969c00426d0bc2edbb015f5 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 10 Dec 2024 16:37:13 +0800
Subject: [PATCH 11/28] reformat long lines and return paddle dtypes from _info.default_dtypes
---
array_api_compat/paddle/_aliases.py | 67 ++++++++++++++++++++++-------
array_api_compat/paddle/_info.py | 22 ++++++++--
2 files changed, 69 insertions(+), 20 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 00130e23..989b4d85 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -420,7 +420,9 @@ def _normalize_axes(axis, ndim):
for a in axis:
if a < lower or a > upper:
# Match paddle error message (e.g., from sum())
- raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
+ raise IndexError(
+ f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
+ )
if a < 0:
a = a + ndim
if a in axes:
@@ -480,7 +482,9 @@ def prod(
# paddle.prod doesn't support multiple axes
if isinstance(axis, tuple):
- return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
+ return _reduce_multiple_axes(
+ paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
+ )
if axis is None:
# paddle doesn't support keepdims with axis=None
res = paddle.prod(x, dtype=dtype, **kwargs)
@@ -610,7 +614,9 @@ def std(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError("float correction in paddle std() is not yet supported")
+ raise NotImplementedError(
+ "float correction in paddle std() is not yet supported"
+ )
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -648,7 +654,9 @@ def var(
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
- raise NotImplementedError("float correction in paddle std() is not yet supported")
+ raise NotImplementedError(
+ "float correction in paddle std() is not yet supported"
+ )
elif isinstance(correction, int):
if correction not in [0, 1]:
raise NotImplementedError("correction only can be 0 or 1")
@@ -709,7 +717,9 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
 # The axis parameter doesn't work for flip() and roll(). Also paddle.flip() doesn't
# accept axis=None
-def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def flip(
+ x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
+) -> array:
if axis is None:
axis = tuple(range(x.ndim))
# paddle.flip doesn't accept dim as an int but the method does
@@ -738,21 +748,27 @@ def where(condition: array, x1: array, x2: array, /) -> array:
return paddle.where(condition, x1, x2)
-def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def empty_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.empty_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
return out
-def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def zeros_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.zeros_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
return out
-def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
+def ones_like(
+ x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> array:
out = paddle.ones_like(x, dtype=dtype)
if device is not None:
out = out.to(device)
@@ -774,7 +790,9 @@ def full_like(
# paddle.reshape doesn't have the copy keyword
-def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
+def reshape(
+ x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
+) -> array:
return paddle.reshape(x, shape, **kwargs)
@@ -825,7 +843,9 @@ def linspace(
**kwargs,
) -> array:
if not endpoint:
- return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
+ return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
+ :-1
+ ]
return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
@@ -890,7 +910,9 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
return paddle.unsqueeze(x, axis)
-def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
+def astype(
+ x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None
+) -> array:
# if copy is not None:
# raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
t = x.to(dtype, device=device)
@@ -1036,7 +1058,7 @@ def sign(x: array, /) -> array:
else:
out = paddle.sign(x)
if paddle.is_floating_point(x):
- out = paddle.where(paddle.isnan(x), paddle.nan, out)
+ out = paddle.where(paddle.isnan(x), paddle.full(x.shape, paddle.nan), out)
return out
@@ -1083,7 +1105,8 @@ def asarray(
return obj
else:
raise NotImplementedError(
- "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
+ "asarray(obj, ..., copy=False) is not supported "
+ "for obj do not has '__dlpack__()' method"
)
elif copy is True:
obj = np.array(obj, copy=True)
@@ -1164,11 +1187,18 @@ def _isscalar(a):
def cumulative_sum(
- x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
+ x: array,
+ /,
+ *,
+ axis: Optional[int] = None,
+ dtype: Optional[Dtype] = None,
+ include_initial: bool = False,
) -> array:
if axis is None:
if x.ndim > 1:
- raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+ raise ValueError(
+ "axis must be specified in cumulative_sum for more than one dimension"
+ )
axis = 0
res = paddle.cumsum(x, axis=axis, dtype=dtype)
@@ -1185,7 +1215,12 @@ def cumulative_sum(
def searchsorted(
- x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
+ x1: array,
+ x2: array,
+ /,
+ *,
+ side: Literal["left", "right"] = "left",
+ sorter: array | None = None,
) -> array:
if sorter is None:
return paddle.searchsorted(x1, x2, right=(side == "right"))
diff --git a/array_api_compat/paddle/_info.py b/array_api_compat/paddle/_info.py
index 5d29e270..6f079020 100644
--- a/array_api_compat/paddle/_info.py
+++ b/array_api_compat/paddle/_info.py
@@ -154,8 +154,16 @@ def default_dtypes(self, *, device=None):
# value here because this error doesn't represent a different default
# per-device.
default_floating = paddle.get_default_dtype()
- default_complex = "complex64" if default_floating == "float32" else "complex128"
- default_integral = "int64"
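+    # paddle.get_default_dtype() returns a string; convert it to the paddle dtype objects the inspection API expects.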
+ if default_floating in ["float16", "float32", "float64", "bfloat16"]:
+ default_floating = getattr(paddle, default_floating)
+ else:
+ raise ValueError(f"Unsupported default floating: {default_floating}")
+ default_complex = (
+ paddle.complex64
+ if default_floating == paddle.float32
+ else paddle.complex128
+ )
+ default_integral = paddle.int64
return {
"real floating": default_floating,
"complex floating": default_complex,
@@ -336,8 +344,14 @@ def devices(self):
except ValueError as e:
# The error message is something like:
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
- devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
- devices_names = [name.strip("'") for name in devices_names if ":" not in name]
+ devices_names = (
+ e.args[0]
+ .split("The device must be a string which is like ")[1]
+ .split(", ")
+ )
+ devices_names = [
+ name.strip("'") for name in devices_names if ":" not in name
+ ]
# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
From 8d2425ee538eca51698c35296269f2de114848aa Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Sat, 14 Dec 2024 17:26:40 +0800
Subject: [PATCH 12/28] fix moveaxis
---
array_api_compat/paddle/_aliases.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 989b4d85..31a1193b 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -445,7 +445,7 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
- x = paddle.movedim(x, a, -1)
+ x = paddle.moveaxis(x, a, -1)
x = paddle.flatten(x, -len(axes))
out = f(x, -1, **kwargs)
@@ -922,8 +922,7 @@ def astype(
def broadcast_arrays(*arrays: array) -> List[array]:
- shape = broadcast_shapes(*[a.shape for a in arrays])
- return [paddle.broadcast_to(a, shape) for a in arrays]
+ return paddle.broadcast_tensors(arrays)
# Note that these named tuples aren't actually part of the standard namespace,
From 7b8555e8ea57cd644be573e9c613c9d209f2467f Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 8 Jan 2025 22:13:31 +0800
Subject: [PATCH 13/28] fix default floating dtype of paddle.asarray
---
array_api_compat/paddle/_aliases.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 31a1193b..0cccdbc8 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1027,7 +1027,16 @@ def is_complex(dtype):
elif kind == "integral":
return dtype in _int_dtypes
elif kind == "real floating":
- return paddle.is_floating_point(dtype)
+ return dtype in [
+ paddle.framework.core.VarDesc.VarType.FP32,
+ paddle.framework.core.VarDesc.VarType.FP64,
+ paddle.framework.core.VarDesc.VarType.FP16,
+ paddle.framework.core.VarDesc.VarType.BF16,
+ paddle.framework.core.DataType.FLOAT32,
+ paddle.framework.core.DataType.FLOAT64,
+ paddle.framework.core.DataType.FLOAT16,
+ paddle.framework.core.DataType.BFLOAT16,
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1109,10 +1118,14 @@ def asarray(
)
elif copy is True:
obj = np.array(obj, copy=True)
+ if np.issubdtype(obj.dtype, np.floating):
+ obj = obj.astype(paddle.get_default_dtype())
return paddle.to_tensor(obj, dtype=dtype, place=device)
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
+ if np.issubdtype(obj.dtype, np.floating):
+ obj = obj.astype(paddle.get_default_dtype())
if dtype != paddle.bool and dtype != "bool":
obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
else:
From 603c8524b20917b6dd4b61d4106c60504458fe0d Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 11:33:37 +0800
Subject: [PATCH 14/28] use default_dtype only when dtype is None
---
array_api_compat/paddle/_aliases.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 0cccdbc8..c3e94cf1 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1118,13 +1118,13 @@ def asarray(
)
elif copy is True:
obj = np.array(obj, copy=True)
- if np.issubdtype(obj.dtype, np.floating):
+ if np.issubdtype(obj.dtype, np.floating) and dtype is None:
obj = obj.astype(paddle.get_default_dtype())
return paddle.to_tensor(obj, dtype=dtype, place=device)
else:
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
obj = np.array(obj, copy=False)
- if np.issubdtype(obj.dtype, np.floating):
+ if np.issubdtype(obj.dtype, np.floating) and dtype is None:
obj = obj.astype(paddle.get_default_dtype())
if dtype != paddle.bool and dtype != "bool":
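+                # Non-bool dtypes are imported zero-copy via DLPack; bool falls back to paddle.to_tensor.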
obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
From 742792f6635689ce9d67270f5cb649db6c357fe4 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 16:10:20 +0800
Subject: [PATCH 15/28] add floor and ceil with same return dtype
---
array_api_compat/paddle/_aliases.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index c3e94cf1..6f23ee20 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1036,7 +1036,7 @@ def is_complex(dtype):
paddle.framework.core.DataType.FLOAT64,
paddle.framework.core.DataType.FLOAT16,
paddle.framework.core.DataType.BFLOAT16,
- ]
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1137,6 +1137,14 @@ def asarray(
return obj
+def floor(x: array, /) -> array:
+ return paddle.floor(x).to(x.dtype)
+
+
+def ceil(x: array, /) -> array:
+ return paddle.ceil(x).to(x.dtype)
+
+
def clip(
x: array,
/,
From fd6eea032fb9b42ae3c84550a220c404d2ef14a0 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 9 Jan 2025 16:22:40 +0800
Subject: [PATCH 16/28] fix clip output handling and export ceil/floor
---
array_api_compat/paddle/_aliases.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 6f23ee20..622504d7 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -4,7 +4,7 @@
import numpy as np
from functools import wraps as _wraps
-from builtins import all as _builtin_all, any as _builtin_any
+from builtins import any as _builtin_any
from ..common._aliases import (
unstack as _aliases_unstack,
@@ -1036,7 +1036,7 @@ def is_complex(dtype):
paddle.framework.core.DataType.FLOAT64,
paddle.framework.core.DataType.FLOAT16,
paddle.framework.core.DataType.BFLOAT16,
- ]
+ ]
elif kind == "complex floating":
return is_complex(dtype)
elif kind == "numeric":
@@ -1186,8 +1186,7 @@ def _isscalar(a):
if type(max) is int and max >= paddle.iinfo(x.dtype).max:
max = None
- if out is None:
- out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
+ out = paddle.to_tensor(broadcast_to(x, result_shape), place=x.place)
if min is not None:
if paddle.is_tensor(x) and x.dtype == paddle.float64 and _isscalar(min):
# Avoid loss of precision due to paddle defaulting to float32
@@ -1203,7 +1202,7 @@ def _isscalar(a):
ib = (out > b) | paddle.isnan(b)
out[ib] = astype(b[ib], out.dtype)
# Return a scalar for 0-D
- return out[()]
+ return out
def cumulative_sum(
@@ -1340,6 +1339,8 @@ def searchsorted(
"ones_like",
"full_like",
"asarray",
+ "ceil",
+ "floor",
]
_all_ignore = ["paddle", "get_xp"]
From 6f32d63ccdfa672f02bbb0aba8b51fa47a83523c Mon Sep 17 00:00:00 2001
From: hongyuHe
Date: Mon, 31 Mar 2025 12:44:53 +0000
Subject: [PATCH 17/28] add dtype mapping for reductions, eye/take fixes, and more linalg wrappers
---
array_api_compat/paddle/_aliases.py | 36 +++++++++++++++++++++++++----
array_api_compat/paddle/linalg.py | 22 ++++++++++++++++++
vendor_test/vendored/_compat | 2 +-
3 files changed, 54 insertions(+), 6 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 622504d7..0e90b020 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -78,7 +78,7 @@
def _two_arg(f):
@_wraps(f)
def _f(x1, x2, /, **kwargs):
- x1, x2 = _fix_promotion(x1, x2)
+ # x1, x2 = _fix_promotion(x1, x2)
return f(x1, x2, **kwargs)
if _f.__doc__ is None:
@@ -312,6 +312,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
}
return can_cast_dict[from_][to]
+def _bitwise_or(x: array, y: array) -> array:
+    # paddle.bitwise_or only accepts tensors, so promote Python scalars first.
+ if not paddle.is_tensor(x):
+ x = paddle.to_tensor(x)
+ if not paddle.is_tensor(y):
+ y = paddle.to_tensor(y)
+ return paddle.bitwise_or(x, y)
# Basic renames
bitwise_invert = paddle.bitwise_not
@@ -326,7 +332,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
atan2 = _two_arg(paddle.atan2)
bitwise_and = _two_arg(paddle.bitwise_and)
bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
-bitwise_or = _two_arg(paddle.bitwise_or)
+bitwise_or = _two_arg(_bitwise_or)
bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
bitwise_xor = _two_arg(paddle.bitwise_xor)
copysign = _two_arg(paddle.copysign)
@@ -455,6 +461,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
out = paddle.unsqueeze(out, a)
return out
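+# Maps dtype enum names (e.g. "FLOAT32") to the string names accepted by paddle APIs.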
+_NP_2_PADDLE_DTYPE = {
+ "BOOL": 'bool',
+ "UINT8": 'uint8',
+ "INT8": 'int8',
+ "INT16": 'int16',
+ "INT32": 'int32',
+ "INT64": 'int64',
+ "FLOAT16": 'float16',
+ "BFLOAT16": 'bfloat16',
+ "FLOAT32": 'float32',
+ "FLOAT64": 'float64',
+ "COMPLEX128": 'complex128',
+ "COMPLEX64": 'complex64',
+}
def prod(
x: array,
@@ -469,6 +489,10 @@ def prod(
x = paddle.to_tensor(x)
ndim = x.ndim
+ if dtype is not None:
+ dtype = _NP_2_PADDLE_DTYPE[dtype.name]
     # Handled separately from the logic below because it still needs to upcast.
if axis == ():
if dtype is None:
@@ -825,7 +849,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
- if abs(k) <= n_rows + n_cols:
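+    # Only fill the diagonal when the matrix is non-empty.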
+ if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols:
z.diagonal(k).fill_(1)
return z
@@ -1052,6 +1076,10 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
if x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
axis = 0
+ if not paddle.is_tensor(indices):
+ indices = paddle.to_tensor(indices)
+ if not paddle.is_tensor(axis):
+ axis = paddle.to_tensor(axis)
return paddle.index_select(x, axis, indices, **kwargs)
@@ -1144,7 +1172,6 @@ def floor(x: array, /) -> array:
def ceil(x: array, /) -> array:
return paddle.ceil(x).to(x.dtype)
-
def clip(
x: array,
/,
@@ -1250,7 +1277,6 @@ def searchsorted(
right=(side == "right"),
)
-
__all__ = [
"__array_namespace_info__",
"result_type",
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 7dd1a266..4be84908 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -84,6 +84,8 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
# Use our wrapped sum to make sure it does upcasting correctly
return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)
+def diagonal(x: array, /, *, offset: int = 0, **kwargs) -> array:
+ return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
def vector_norm(
x: array,
@@ -141,6 +143,24 @@ def slogdet(x: array):
slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
return slotdet(sign, log_det)
+def tuple_to_namedtuple(data, fields):
+ nt_class = namedtuple('DynamicNameTuple', fields)
+ return nt_class(*data)
+
+def eigh(x: array):
+ return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors'])
+
+def qr(x: array, mode: Optional[str] = None) -> array:
+ if mode is None:
+ return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R'])
+
+ return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R'])
+
+
+def svd(x: array, full_matrices: Optional[bool] = None) -> array:
+    if full_matrices is None:
+ return tuple_to_namedtuple(paddle.linalg.svd(x), ['U', 'S', 'Vh'])
+ return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])
__all__ = linalg_all + [
"outer",
@@ -154,6 +174,8 @@ def slogdet(x: array):
"trace",
"vector_norm",
"slogdet",
+ "eigh",
+ "diagonal",
]
_all_ignore = ["paddle_linalg", "sum"]
diff --git a/vendor_test/vendored/_compat b/vendor_test/vendored/_compat
index ba484f91..07b6ab4f 120000
--- a/vendor_test/vendored/_compat
+++ b/vendor_test/vendored/_compat
@@ -1 +1 @@
-../../array_api_compat/
\ No newline at end of file
+../../array_api_compat
\ No newline at end of file
From 37785d442e890a579edf923fb3c174c5bbf64926 Mon Sep 17 00:00:00 2001
From: cangtianhuang
Date: Tue, 1 Apr 2025 15:05:10 +0800
Subject: [PATCH 18/28] Add broadcast_tensors alias, modify result_type
---
array_api_compat/paddle/_aliases.py | 44 ++++++++++++++++++++---------
1 file changed, 30 insertions(+), 14 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 622504d7..d19353e0 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import builtins
from typing import Literal
import numpy as np
@@ -112,25 +113,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
x = arrays_and_dtypes[0]
- if isinstance(x, paddle.dtype):
- return x
- return x.dtype
+ return x if isinstance(x, paddle.dtype) else x.dtype
if len(arrays_and_dtypes) > 2:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
x, y = arrays_and_dtypes
- xdt = x.dtype if not isinstance(x, paddle.dtype) else x
- ydt = y.dtype if not isinstance(y, paddle.dtype) else y
+ xdt = x if isinstance(x, paddle.dtype) else x.dtype
+ ydt = y if isinstance(y, paddle.dtype) else y.dtype
if (xdt, ydt) in _promotion_table:
- return _promotion_table[xdt, ydt]
-
- # This doesn't result_type(dtype, dtype) for non-array API dtypes
- # because paddle.result_type only accepts tensors. This does however, allow
- # cross-kind promotion.
- x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
- y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
- return paddle.result_type(x, y)
+ return _promotion_table[(xdt, ydt)]
+
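+    # Fallback for pairs not in the promotion table: pick the dtype that ranks higher in a coarse ordering.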
+ type_order = {
+ paddle.bool: 0,
+ paddle.int8: 1,
+ paddle.uint8: 2,
+ paddle.int16: 3,
+ paddle.int32: 4,
+ paddle.int64: 5,
+ paddle.float16: 6,
+ paddle.float32: 7,
+ paddle.float64: 8,
+ paddle.complex64: 9,
+ paddle.complex128: 10
+ }
+
+ return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
@@ -922,7 +930,15 @@ def astype(
def broadcast_arrays(*arrays: array) -> List[array]:
- return paddle.broadcast_tensors(arrays)
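+    # Cast to a common dtype for paddle.broadcast_tensors, then restore each array's original dtype.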
+ original_dtypes = [arr.dtype for arr in arrays]
+ if len(set(original_dtypes)) == 1:
+ return paddle.broadcast_tensors(arrays)
+ target_dtype = result_type(*arrays)
+ casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr
+ for arr in arrays]
+ broadcasted = paddle.broadcast_tensors(casted_arrays)
+ result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)]
+ return result
# Note that these named tuples aren't actually part of the standard namespace,
From 0651731fc5bda2a3362d1665da31ea12635f5963 Mon Sep 17 00:00:00 2001
From: cangtianhuang
Date: Tue, 1 Apr 2025 15:06:56 +0800
Subject: [PATCH 19/28] remove unused builtins import
---
array_api_compat/paddle/_aliases.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index d19353e0..88f71e7d 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import builtins
from typing import Literal
import numpy as np
From 372283ac98ddc4aa34b4dc9cf95a15a4f7353f15 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Tue, 1 Apr 2025 09:11:03 +0000
Subject: [PATCH 20/28] restore _fix_promotion in _two_arg
---
array_api_compat/paddle/_aliases.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 0e90b020..d0ff66b8 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -78,7 +78,7 @@
def _two_arg(f):
@_wraps(f)
def _f(x1, x2, /, **kwargs):
- # x1, x2 = _fix_promotion(x1, x2)
+ x1, x2 = _fix_promotion(x1, x2)
return f(x1, x2, **kwargs)
if _f.__doc__ is None:
From b946e8263e6d7527f9f4805a9d07f4e1a0ef11d1 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Thu, 3 Apr 2025 09:51:52 +0000
Subject: [PATCH 21/28] add bounds checks for expand_dims and take, use paddle.linalg.slogdet
---
array_api_compat/paddle/_aliases.py | 35 +++++++++++++++++++++--------
array_api_compat/paddle/linalg.py | 7 +-----
 2 files changed, 27 insertions(+), 15 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 1618e498..b32a7c94 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -938,6 +938,8 @@ def triu(x: array, /, *, k: int = 0) -> array:
def expand_dims(x: array, /, *, axis: int = 0) -> array:
+ if axis < -x.ndim - 1 or axis > x.ndim:
+ raise IndexError(f"Axis {axis} is out of bounds for array of dimension { x.ndim}")
return paddle.unsqueeze(x, axis)
@@ -1087,15 +1089,31 @@ def is_complex(dtype):
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
- if axis is None:
+ _axis = axis
+ if _axis is None:
if x.ndim != 1:
- raise ValueError("axis must be specified when ndim > 1")
- axis = 0
- if not paddle.is_tensor(indices):
- indices = paddle.to_tensor(indices)
- if not paddle.is_tensor(axis):
- axis = paddle.to_tensor(axis)
- return paddle.index_select(x, axis, indices, **kwargs)
+ raise ValueError("axis must be specified when x.ndim > 1")
+ _axis = 0
+ elif not isinstance(_axis, int):
+ raise TypeError(f"axis must be an integer, but received {type(_axis)}")
+
+ if not (-x.ndim <= _axis < x.ndim):
+ raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}")
+
+ if isinstance(indices, paddle.Tensor):
+ indices_tensor = indices
+ elif isinstance(indices, int):
+ indices_tensor = paddle.to_tensor([indices], dtype='int64')
+ else:
+ # Otherwise (e.g., list, tuple), convert directly
+ indices_tensor = paddle.to_tensor(indices, dtype='int64')
+ # Ensure indices is a 1D tensor
+ if indices_tensor.ndim == 0:
+ indices_tensor = indices_tensor.reshape([1])
+ elif indices_tensor.ndim > 1:
+ raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor")
+
+ return paddle.index_select(x, index=indices_tensor, axis=_axis)
def sign(x: array, /) -> array:
@@ -1261,7 +1279,6 @@ def cumulative_sum(
"axis must be specified in cumulative_sum for more than one dimension"
)
axis = 0
-
res = paddle.cumsum(x, axis=axis, dtype=dtype)
     # paddle.cumsum does not support include_initial
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 4be84908..0cdea3ba 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -136,12 +136,7 @@ def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
def slogdet(x: array):
- det = paddle.linalg.det(x)
- sign = paddle.sign(det)
- log_det = paddle.log(det)
-
- slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
- return slotdet(sign, log_det)
+ return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"])
def tuple_to_namedtuple(data, fields):
nt_class = namedtuple('DynamicNameTuple', fields)
From 34af083bf4d8a35bb4505c496f5177721a585a64 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Mon, 7 Apr 2025 03:47:18 +0000
Subject: [PATCH 22/28] add input validation to vecdot
---
array_api_compat/paddle/_aliases.py | 31 ++++++++++++++++++++++++++++-
1 file changed, 30 insertions(+), 1 deletion(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index b32a7c94..93c0f259 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -1014,7 +1014,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
- x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+ shape1 = x1.shape
+ shape2 = x2.shape
+ rank1 = len(shape1)
+ rank2 = len(shape2)
+ if rank1 == 0 or rank2 == 0:
+ raise ValueError(
+ f"Vector dot product requires non-scalar inputs (rank > 0). "
+ f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}."
+ )
+ try:
+ norm_axis1 = axis if axis >= 0 else rank1 + axis
+ if not (0 <= norm_axis1 < rank1):
+ raise IndexError # Axis out of bounds for x1
+ norm_axis2 = axis if axis >= 0 else rank2 + axis
+ if not (0 <= norm_axis2 < rank2):
+ raise IndexError # Axis out of bounds for x2
+ size1 = shape1[norm_axis1]
+ size2 = shape2[norm_axis2]
+ except IndexError:
+ raise ValueError(
+ f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) "
+ f"and/or {shape2} (rank {rank2})."
+ )
+
+ if size1 != size2:
+ raise ValueError(
+ f"Inputs must have the same dimension size along the reduction axis ({axis}). "
+ f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} "
+ f"along the normalized axis {norm_axis1} and {norm_axis2} respectively."
+ )
return paddle.linalg.vecdot(x1, x2, axis=axis)
From 0dbf7dd0320ea42564c07cffdbfa381664d3c4ab Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Thu, 1 May 2025 10:09:02 +0000
Subject: [PATCH 23/28] infer dtype in full() from fill_value and fix sign/svd
---
array_api_compat/paddle/_aliases.py | 14 ++++++++------
array_api_compat/paddle/linalg.py | 2 +-
2 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 93c0f259..969d8bc5 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -891,7 +891,11 @@ def full(
) -> array:
if isinstance(shape, int):
shape = (shape,)
-
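+    # Per the array API, a Python bool/int fill_value implies a bool/int64 result when dtype is not given.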
+    if dtype is None:
+        if isinstance(fill_value, bool):
+            dtype = "bool"
+        elif isinstance(fill_value, int):
+            dtype = "int64"
return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)
@@ -1148,11 +1152,9 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
def sign(x: array, /) -> array:
# paddle sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
- if paddle.is_complex(x):
- out = x / paddle.abs(x)
- # sign(0) = 0 but the above formula would give nan
- out[x == 0 + 0j] = 0 + 0j
- return out
+ if paddle.is_complex(x) and x.ndim == 0 and x.item() == 0j:
+ # Handle 0-D complex zero explicitly
+ return paddle.zeros_like(x, dtype=x.dtype)
else:
out = paddle.sign(x)
if paddle.is_floating_point(x):
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 0cdea3ba..4b5ac83d 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -154,7 +154,7 @@ def qr(x: array, mode: Optional[str] = None) -> array:
def svd(x: array, full_matrices: Optional[bool]= None) -> array:
if full_matrices is None :
- return tuple_to_namedtuple(paddle.linalg.svd(x), ['U', 'S', 'Vh'])
+ return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])
__all__ = linalg_all + [
From 912fe3e56739a4353c95f6111e05270dde6fcb86 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Mon, 12 May 2025 20:01:10 +0800
Subject: [PATCH 24/28] add paddle skip and xfail files
---
paddle-skips.txt | 6 +++
paddle-xfails.txt | 108 ++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 114 insertions(+)
create mode 100644 paddle-skips.txt
create mode 100644 paddle-xfails.txt
diff --git a/paddle-skips.txt b/paddle-skips.txt
new file mode 100644
index 00000000..094c553f
--- /dev/null
+++ b/paddle-skips.txt
@@ -0,0 +1,6 @@
+array_api_tests/test_array_object.py::test_getitem_masking
+array_api_tests/test_data_type_functions.py::test_result_type
+array_api_tests/test_data_type_functions.py::test_broadcast_arrays
+array_api_tests/test_manipulation_functions.py::test_roll
+array_api_tests/test_data_type_functions.py::test_broadcast_to
+array_api_tests/test_linalg.py::test_cholesky
diff --git a/paddle-xfails.txt b/paddle-xfails.txt
new file mode 100644
index 00000000..6998f374
--- /dev/null
+++ b/paddle-xfails.txt
@@ -0,0 +1,108 @@
+# Skip 'copy=...'
+array_api_tests/test_array_object.py::test_setitem
+array_api_tests/test_array_object.py::test_setitem_masking
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
+# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
+
+# Skip promotion test for 'Scalar op Tensor'
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
+
+# tests that torch also does not pass
+array_api_tests/test_creation_functions.py::test_asarray_scalars
+array_api_tests/test_creation_functions.py::test_asarray_arrays
+array_api_tests/test_creation_functions.py::test_empty_like
+array_api_tests/test_creation_functions.py::test_eye
+array_api_tests/test_creation_functions.py::test_full
+array_api_tests/test_creation_functions.py::test_full_like
+array_api_tests/test_creation_functions.py::test_linspace
+array_api_tests/test_creation_functions.py::test_ones
+array_api_tests/test_creation_functions.py::test_ones_like
+array_api_tests/test_creation_functions.py::test_zeros
+array_api_tests/test_creation_functions.py::test_zeros_like
+array_api_tests/test_fft.py::test_fft
+array_api_tests/test_fft.py::test_ifft
+array_api_tests/test_fft.py::test_fftn
+array_api_tests/test_fft.py::test_ifftn
+array_api_tests/test_fft.py::test_rfft
+array_api_tests/test_fft.py::test_irfft
+array_api_tests/test_fft.py::test_rfftn
+array_api_tests/test_fft.py::test_hfft
+array_api_tests/test_fft.py::test_ihfft
+array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
+array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
+array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
+array_api_tests/test_indexing_functions.py::test_take
+array_api_tests/test_linalg.py::test_linalg_matmul
+array_api_tests/test_linalg.py::test_qr
+array_api_tests/test_linalg.py::test_solve
+array_api_tests/test_manipulation_functions.py::test_concat
+array_api_tests/test_manipulation_functions.py::test_repeat
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_round
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
+array_api_tests/test_set_functions.py::test_unique_all
+array_api_tests/test_set_functions.py::test_unique_counts
+array_api_tests/test_set_functions.py::test_unique_inverse
+array_api_tests/test_set_functions.py::test_unique_values
+array_api_tests/test_signatures.py::test_func_signature[astype]
+array_api_tests/test_signatures.py::test_func_signature[repeat]
+array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
+array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
+array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+array_api_tests/test_signatures.py::test_array_method_signature[to_device]
+array_api_tests/test_sorting_functions.py::test_argsort
+array_api_tests/test_sorting_functions.py::test_sort
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
+
+# dtype promotion related
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor
+array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_searching_functions.py::test_where
From 13e27826caca1054c628448b432411950b1083a2 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Sat, 7 Jun 2025 09:24:19 +0000
Subject: [PATCH 25/28] update
---
array_api_compat/paddle/_aliases.py | 73 ++++++++++++++++++++++++-----
array_api_compat/paddle/linalg.py | 20 +++++++-
paddle-xfails.txt | 58 +++++++++++++++++++++++
3 files changed, 137 insertions(+), 14 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 969d8bc5..2c337f24 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -496,33 +496,56 @@ def prod(
x = paddle.to_tensor(x)
ndim = x.ndim
+ # handle reduction over empty (zero-size) inputs
+ if x.numel() == 0:
+ if dtype is not None:
+ output_dtype = _NP_2_PADDLE_DTYPE[dtype.name]
+ else:
+ if x.dtype == paddle.bool:
+ output_dtype = paddle.int64
+ else:
+ output_dtype = x.dtype
+
+ if axis is None:
+ return paddle.to_tensor(1, dtype=output_dtype)
+
+ if keepdims:
+ output_shape = list(x.shape)
+ if isinstance(axis, int):
+ axis = (axis,)
+ for ax in axis:
+ output_shape[ax] = 1
+ else:
+ output_shape = [dim for i, dim in enumerate(x.shape) if i not in (axis if isinstance(axis, tuple) else [axis])]
+ if not output_shape:
+ return paddle.to_tensor(1, dtype=output_dtype)
+
+ return paddle.ones(output_shape, dtype=output_dtype)
+
+
if dtype is not None:
- # import pdb
- # pdb.set_trace()
- dtype = _NP_2_PADDLE_DTYPE[dtype.name]
- # below because it still needs to upcast.
+ dtype = _NP_2_PADDLE_DTYPE[dtype.name]
+
if axis == ():
if dtype is None:
- # We can't upcast uint8 according to the spec because there is no
- # paddle.uint64, so at least upcast to int64 which is what sum does
- # when axis=None.
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
return x.to(paddle.int64)
return x.clone()
return x.to(dtype)
- # paddle.prod doesn't support multiple axes
if isinstance(axis, tuple):
return _reduce_multiple_axes(
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
)
+
if axis is None:
- # paddle doesn't support keepdims with axis=None
+ if dtype is None and x.dtype == paddle.int32:
+ dtype = 'int64'
res = paddle.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
-
- return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+ return paddle.prod(x, axis=axis, keepdim=keepdims, dtype=dtype, **kwargs)
def sum(
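[Note] The empty-input branch added to prod() above returns the multiplicative identity with the reduced shape, matching the reference NumPy behavior; a short NumPy illustration (shapes are made up):

    import numpy as np

    x = np.ones((0, 3))
    assert np.prod(x, axis=0).tolist() == [1.0, 1.0, 1.0]  # empty axis -> identity
    assert np.prod(x, axis=1).shape == (0,)                # nothing left to reduce
    assert np.prod(x) == 1.0                               # axis=None -> scalar 1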
@@ -771,7 +794,17 @@ def roll(
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
- return paddle.nonzero(x, as_tuple=True, **kwargs)
+
+ if paddle.is_floating_point(x) or paddle.is_complex(x):
+ # Use paddle.isclose() to determine which elements are
+ # "close enough" to zero.
+ zero_tensor = paddle.zeros(shape=x.shape, dtype=x.dtype)
+ is_zero_mask = paddle.isclose(x, zero_tensor)
+ is_nonzero_mask = paddle.logical_not(is_zero_mask)
+ return paddle.nonzero(is_nonzero_mask, as_tuple=True, **kwargs)
+
+ else:
+ return paddle.nonzero(x, as_tuple=True, **kwargs)
def where(condition: array, x1: array, x2: array, /) -> array:
@@ -1003,6 +1036,22 @@ def unique_values(x: array) -> array:
def matmul(x1: array, x2: array, /, **kwargs) -> array:
# paddle.matmul doesn't type promote (but differently from _fix_promotion)
+ d1 = x1.ndim
+ d2 = x2.ndim
+
+ if d1 == 0 or d2 == 0:
+ raise ValueError("matmul does not support 0-D (scalar) inputs.")
+
+ k1 = x1.shape[-1]
+
+ if d2 == 1:
+ k2 = x2.shape[0]
+ else:
+ k2 = x2.shape[-2]
+
+ if k1 != k2:
+ raise ValueError(f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: "
+ f"{k1} (dim -1 of x1) != {k2} (dim -2 of x2)")
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return paddle.matmul(x1, x2, **kwargs)
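[Note] The shape check added to matmul() above verifies that the contraction dimensions agree before dispatching to paddle.matmul; the same check on plain shapes (values are illustrative):

    s1, s2 = (2, 3), (3, 4)
    k1 = s1[-1]                               # last dim of x1
    k2 = s2[0] if len(s2) == 1 else s2[-2]    # second-to-last dim of x2, or its only dim
    assert k1 == k2 == 3                      # inner dimensions line up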
diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py
index 4b5ac83d..aa091c81 100644
--- a/array_api_compat/paddle/linalg.py
+++ b/array_api_compat/paddle/linalg.py
@@ -125,12 +125,24 @@ def matrix_norm(
keepdims: bool = False,
ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
) -> array:
- return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
-
+ res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+ if res.dtype == paddle.complex64:
+ res = paddle.cast(res, "float32")
+ if res.dtype == paddle.complex128:
+ res = paddle.cast(res, "float64")
+ return res
def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
if rtol is None:
return paddle.linalg.pinv(x)
+
+ # convert a scalar rtol to a 0-D tensor with the same dtype as x
+ if isinstance(rtol, (int, float)):
+ rtol = paddle.to_tensor(rtol, dtype=x.dtype)
+
+ # broadcast rtol to [..., 1]
+ if rtol.ndim > 0:
+ rtol = rtol.unsqueeze(-1)
return paddle.linalg.pinv(x, rcond=rtol)
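[Note] The rtol handling added to pinv() above turns a scalar tolerance into a tensor and appends a trailing axis so that a per-matrix rtol broadcasts against the singular values of each matrix in a batch; a NumPy sketch of the intended broadcasting (shapes are made up):

    import numpy as np

    s = np.ones((5, 4))           # singular values for a batch of 5 matrices
    rtol = np.full((5,), 1e-6)    # one tolerance per matrix
    cutoff = rtol[..., None] * s.max(axis=-1, keepdims=True)
    assert cutoff.shape == (5, 1) # broadcasts cleanly against s of shape (5, 4)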
@@ -157,6 +169,9 @@ def svd(x: array, full_matrices: Optional[bool]= None) -> array:
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])
+def svdvals(x: array) -> array:
+ return paddle.linalg.svd(x)[1]
+
__all__ = linalg_all + [
"outer",
"matmul",
@@ -171,6 +186,7 @@ def svd(x: array, full_matrices: Optional[bool]= None) -> array:
"slogdet",
"eigh",
"diagonal",
+ "svdvals"
]
_all_ignore = ["paddle_linalg", "sum"]
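[Note] svdvals() above simply drops the U and Vh factors returned by paddle.linalg.svd and keeps the singular values; the equivalent relationship expressed with NumPy (illustrative only):

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)
    u, s, vh = np.linalg.svd(x, full_matrices=False)
    assert np.allclose(np.linalg.svd(x, compute_uv=False), s)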
diff --git a/paddle-xfails.txt b/paddle-xfails.txt
index 6998f374..6d615857 100644
--- a/paddle-xfails.txt
+++ b/paddle-xfails.txt
@@ -106,3 +106,61 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
array_api_tests/test_searching_functions.py::test_where
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
+array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
+array_api_tests/test_linalg.py::test_outer
+array_api_tests/test_linalg.py::test_vecdot
+array_api_tests/test_operators_and_elementwise_functions.py::test_clip
+array_api_tests/test_manipulation_functions.py::test_stack
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide
+
+# these tests do not pass
+array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
+array_api_tests/test_signatures.py::test_func_signature[meshgrid]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
+array_api_tests/test_indexing_functions.py::test_take
+array_api_tests/test_linalg.py::test_linalg_vecdot
+array_api_tests/test_creation_functions.py::test_asarray_arrays
+
+array_api_tests/test_linalg.py::test_qr
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
+
+# these tests exceed the 800ms deadline
+array_api_tests/test_linalg.py::test_pinv
+array_api_tests/test_linalg.py::test_det
+
+# paddle only supports tensors with at most 9 dimensions, but these tests create 10-D tensors
+array_api_tests/test_linalg.py::test_tensordot
+array_api_tests/test_linalg.py::test_linalg_tensordot
\ No newline at end of file
From e6cf0119930345a78880f16e3696333ee05227f2 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Sat, 7 Jun 2025 09:55:18 +0000
Subject: [PATCH 26/28] update
---
array_api_compat/paddle/_aliases.py | 16 +++++++---------
array_api_compat/torch/_aliases.py | 3 ---
paddle-xfails.txt | 1 +
vendor_test/vendored/_compat | 2 +-
4 files changed, 9 insertions(+), 13 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index 2c337f24..e04b14c0 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -319,13 +319,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
}
return can_cast_dict[from_][to]
-def test_bitwise_or(x: array, y: array):
- if not paddle.is_tensor(x):
- x = paddle.to_tensor(x)
- if not paddle.is_tensor(y):
- y = paddle.to_tensor(y)
- return paddle.bitwise_or(x, y)
-
# Basic renames
bitwise_invert = paddle.bitwise_not
newaxis = None
@@ -339,7 +332,7 @@ def test_bitwise_or(x: array, y: array):
atan2 = _two_arg(paddle.atan2)
bitwise_and = _two_arg(paddle.bitwise_and)
bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
-bitwise_or = _two_arg(test_bitwise_or)
+bitwise_or = _two_arg(paddle.bitwise_or)
bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
bitwise_xor = _two_arg(paddle.bitwise_xor)
copysign = _two_arg(paddle.copysign)
@@ -527,6 +520,9 @@ def prod(
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
if axis == ():
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
if dtype is None:
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
return x.to(paddle.int64)
@@ -537,8 +533,10 @@ def prod(
return _reduce_multiple_axes(
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
)
-
+
+
if axis is None:
+ # paddle.prod doesn't support multiple axes
if dtype is None and x.dtype == paddle.int32:
dtype = 'int64'
res = paddle.prod(x, dtype=dtype, **kwargs)
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 792abc0a..5ac66bcb 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -611,9 +611,6 @@ def triu(x: array, /, *, k: int = 0) -> array:
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
- if axis == 2:
- import pdb
- pdb.set_trace()
return torch.unsqueeze(x, axis)
def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
diff --git a/paddle-xfails.txt b/paddle-xfails.txt
index 6d615857..b92267c2 100644
--- a/paddle-xfails.txt
+++ b/paddle-xfails.txt
@@ -156,6 +156,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_s
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor
# test exceeds the deadline of 800ms
array_api_tests/test_linalg.py::test_pinv
diff --git a/vendor_test/vendored/_compat b/vendor_test/vendored/_compat
index 07b6ab4f..ba484f91 120000
--- a/vendor_test/vendored/_compat
+++ b/vendor_test/vendored/_compat
@@ -1 +1 @@
-../../array_api_compat
\ No newline at end of file
+../../array_api_compat/
\ No newline at end of file
From 67aa9ef5d22876b8de2ab7784ba50674f85882b1 Mon Sep 17 00:00:00 2001
From: Hongyuhe
Date: Sat, 7 Jun 2025 09:59:02 +0000
Subject: [PATCH 27/28] update
---
array_api_compat/paddle/_aliases.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py
index e04b14c0..df45aef5 100644
--- a/array_api_compat/paddle/_aliases.py
+++ b/array_api_compat/paddle/_aliases.py
@@ -319,6 +319,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
}
return can_cast_dict[from_][to]
+
# Basic renames
bitwise_invert = paddle.bitwise_not
newaxis = None
@@ -520,15 +521,16 @@ def prod(
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
if axis == ():
- # We can't upcast uint8 according to the spec because there is no
- # paddle.uint64, so at least upcast to int64 which is what sum does
- # when axis=None.
if dtype is None:
+ # We can't upcast uint8 according to the spec because there is no
+ # paddle.uint64, so at least upcast to int64 which is what sum does
+ # when axis=None.
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
return x.to(paddle.int64)
return x.clone()
return x.to(dtype)
+ # paddle.prod doesn't support multiple axes
if isinstance(axis, tuple):
return _reduce_multiple_axes(
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
@@ -536,7 +538,7 @@ def prod(
if axis is None:
- # paddle.prod doesn't support multiple axes
+ # paddle doesn't support keepdims with axis=None
if dtype is None and x.dtype == paddle.int32:
dtype = 'int64'
res = paddle.prod(x, dtype=dtype, **kwargs)
@@ -1283,6 +1285,7 @@ def floor(x: array, /) -> array:
def ceil(x: array, /) -> array:
return paddle.ceil(x).to(x.dtype)
+
def clip(
x: array,
/,
@@ -1357,6 +1360,7 @@ def cumulative_sum(
"axis must be specified in cumulative_sum for more than one dimension"
)
axis = 0
+
res = paddle.cumsum(x, axis=axis, dtype=dtype)
# np.cumsum does not support include_initial
@@ -1387,6 +1391,7 @@ def searchsorted(
right=(side == "right"),
)
+
__all__ = [
"__array_namespace_info__",
"result_type",
From 9f45d0456fd52a99eafd9d9c6e2add7fdf74145b Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 10 Sep 2025 11:32:53 +0800
Subject: [PATCH 28/28] update
---
.github/workflows/array-api-tests-dask.yml | 11 +-
.../workflows/array-api-tests-numpy-1-21.yml | 11 -
.../workflows/array-api-tests-numpy-1-22.yml | 15 +
.../workflows/array-api-tests-numpy-1-26.yml | 4 +
.../workflows/array-api-tests-numpy-dev.yml | 4 +
.../array-api-tests-numpy-latest.yml | 6 +-
.github/workflows/array-api-tests-torch.yml | 6 +-
.github/workflows/array-api-tests.yml | 33 +-
.github/workflows/docs-build.yml | 6 +-
.github/workflows/docs-deploy.yml | 4 +-
.github/workflows/publish-package.yml | 39 +-
.github/workflows/ruff.yml | 4 +-
.github/workflows/tests.yml | 64 +-
array_api_compat/__init__.py | 8 +-
array_api_compat/_internal.py | 43 +-
array_api_compat/common/__init__.py | 2 +-
array_api_compat/common/_aliases.py | 559 +++++++-----
array_api_compat/common/_fft.py | 158 ++--
array_api_compat/common/_helpers.py | 835 +++++++++++-------
array_api_compat/common/_linalg.py | 158 +++-
array_api_compat/common/_typing.py | 186 +++-
array_api_compat/cupy/__init__.py | 14 +-
array_api_compat/cupy/_aliases.py | 128 ++-
array_api_compat/cupy/_info.py | 22 +-
array_api_compat/cupy/_typing.py | 62 +-
array_api_compat/cupy/fft.py | 16 +-
array_api_compat/cupy/linalg.py | 8 +-
array_api_compat/dask/array/__init__.py | 23 +-
array_api_compat/dask/array/_aliases.py | 374 +++++---
array_api_compat/dask/array/_info.py | 142 ++-
array_api_compat/dask/array/fft.py | 20 +-
array_api_compat/dask/array/linalg.py | 55 +-
array_api_compat/numpy/__init__.py | 36 +-
array_api_compat/numpy/_aliases.py | 166 ++--
array_api_compat/numpy/_info.py | 53 +-
array_api_compat/numpy/_typing.py | 69 +-
array_api_compat/numpy/fft.py | 21 +-
array_api_compat/numpy/linalg.py | 86 +-
array_api_compat/py.typed | 0
array_api_compat/torch/__init__.py | 29 +-
array_api_compat/torch/_aliases.py | 478 ++++++----
array_api_compat/torch/_info.py | 43 +-
array_api_compat/torch/_typing.py | 3 +
array_api_compat/torch/fft.py | 66 +-
array_api_compat/torch/linalg.py | 51 +-
cupy-xfails.txt | 33 +-
dask-skips.txt | 11 +-
dask-xfails.txt | 137 ++-
docs/changelog.md | 156 ++++
docs/dev/tests.md | 2 +-
docs/helper-functions.rst | 2 +
docs/index.md | 4 +-
docs/requirements.txt | 6 -
docs/supported-array-libraries.md | 30 +-
numpy-1-21-xfails.txt | 260 ------
numpy-1-22-xfails.txt | 175 ++++
numpy-1-26-xfails.txt | 48 +-
numpy-dev-xfails.txt | 47 +-
numpy-skips.txt | 11 -
numpy-xfails.txt | 46 +-
pyproject.toml | 120 +++
ruff.toml | 17 -
setup.py | 37 -
test_cupy.sh | 2 +-
tests/_helpers.py | 27 +-
tests/test_all.py | 329 ++++++-
tests/test_array_namespace.py | 161 ++--
tests/test_common.py | 295 +++++--
tests/test_copies_or_views.py | 64 ++
tests/test_cupy.py | 45 +
tests/test_dask.py | 183 ++++
tests/test_jax.py | 38 +
tests/test_torch.py | 119 +++
tests/test_vendoring.py | 2 +
torch-skips.txt | 11 -
torch-xfails.txt | 126 +--
vendor_test/uses_torch.py | 2 +-
77 files changed, 4413 insertions(+), 2254 deletions(-)
delete mode 100644 .github/workflows/array-api-tests-numpy-1-21.yml
create mode 100644 .github/workflows/array-api-tests-numpy-1-22.yml
create mode 100644 array_api_compat/py.typed
create mode 100644 array_api_compat/torch/_typing.py
delete mode 100644 docs/requirements.txt
delete mode 100644 numpy-1-21-xfails.txt
create mode 100644 numpy-1-22-xfails.txt
create mode 100644 pyproject.toml
delete mode 100644 ruff.toml
delete mode 100644 setup.py
create mode 100644 tests/test_copies_or_views.py
create mode 100644 tests/test_cupy.py
create mode 100644 tests/test_dask.py
create mode 100644 tests/test_jax.py
create mode 100644 tests/test_torch.py
diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml
index 78010233..ef430d9c 100644
--- a/.github/workflows/array-api-tests-dask.yml
+++ b/.github/workflows/array-api-tests-dask.yml
@@ -7,7 +7,14 @@ jobs:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
- package-version: '>= 2024.9.0'
module-name: dask.array
extra-requires: numpy
- pytest-extra-args: --disable-deadline --max-examples=5
+ # Dask is substantially slower than other libraries on unit tests.
+ # Reduce the number of examples to speed up CI, even though this means that this
+ # workflow is barely more than a smoke test, and one should expect extreme
+ # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run
+ # the full test suite with at least 200 examples.
+ pytest-extra-args: --max-examples=200 -n 4
+ python-versions: '[''3.10'', ''3.13'']'
+ extra-env-vars: |
+ ARRAY_API_TESTS_XFAIL_MARK=skip
diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml
deleted file mode 100644
index 2d81c3cd..00000000
--- a/.github/workflows/array-api-tests-numpy-1-21.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-name: Array API Tests (NumPy 1.21)
-
-on: [push, pull_request]
-
-jobs:
- array-api-tests-numpy-1-21:
- uses: ./.github/workflows/array-api-tests.yml
- with:
- package-name: numpy
- package-version: '== 1.21.*'
- xfails-file-extra: '-1-21'
diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml
new file mode 100644
index 00000000..83d4cf1d
--- /dev/null
+++ b/.github/workflows/array-api-tests-numpy-1-22.yml
@@ -0,0 +1,15 @@
+name: Array API Tests (NumPy 1.22)
+
+on: [push, pull_request]
+
+jobs:
+ array-api-tests-numpy-1-22:
+ uses: ./.github/workflows/array-api-tests.yml
+ with:
+ package-name: numpy
+ package-version: '== 1.22.*'
+ xfails-file-extra: '-1-22'
+ python-versions: '[''3.10'']'
+ pytest-extra-args: -n 4
+ extra-env-vars: |
+ ARRAY_API_TESTS_XFAIL_MARK=skip
diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml
index 660935f0..13124644 100644
--- a/.github/workflows/array-api-tests-numpy-1-26.yml
+++ b/.github/workflows/array-api-tests-numpy-1-26.yml
@@ -9,3 +9,7 @@ jobs:
package-name: numpy
package-version: '== 1.26.*'
xfails-file-extra: '-1-26'
+ python-versions: '[''3.10'', ''3.12'']'
+ pytest-extra-args: -n 4
+ extra-env-vars: |
+ ARRAY_API_TESTS_XFAIL_MARK=skip
diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml
index eef4269d..dec4c7ae 100644
--- a/.github/workflows/array-api-tests-numpy-dev.yml
+++ b/.github/workflows/array-api-tests-numpy-dev.yml
@@ -9,3 +9,7 @@ jobs:
package-name: numpy
extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
xfails-file-extra: '-dev'
+ python-versions: '[''3.11'', ''3.13'']'
+ pytest-extra-args: -n 4
+ extra-env-vars: |
+ ARRAY_API_TESTS_XFAIL_MARK=skip
diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml
index 36984345..65bbc9a2 100644
--- a/.github/workflows/array-api-tests-numpy-latest.yml
+++ b/.github/workflows/array-api-tests-numpy-latest.yml
@@ -1,4 +1,4 @@
-name: Array API Tests (NumPy Latest)
+name: Array API Tests (NumPy latest)
on: [push, pull_request]
@@ -7,3 +7,7 @@ jobs:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
+ python-versions: '[''3.10'', ''3.13'']'
+ pytest-extra-args: -n 4
+ extra-env-vars: |
+ ARRAY_API_TESTS_XFAIL_MARK=skip
diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml
index 56ab81a3..4b4b945e 100644
--- a/.github/workflows/array-api-tests-torch.yml
+++ b/.github/workflows/array-api-tests-torch.yml
@@ -1,4 +1,4 @@
-name: Array API Tests (PyTorch Latest)
+name: Array API Tests (PyTorch CPU)
on: [push, pull_request]
@@ -7,5 +7,9 @@ jobs:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: torch
+ extra-requires: '--index-url https://download.pytorch.org/whl/cpu'
extra-env-vars: |
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
+ ARRAY_API_TESTS_XFAIL_MARK=skip
+ python-versions: '[''3.10'', ''3.13'']'
+ pytest-extra-args: -n 4
diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml
index e0d5d84e..e3c0c9e0 100644
--- a/.github/workflows/array-api-tests.yml
+++ b/.github/workflows/array-api-tests.yml
@@ -16,6 +16,10 @@ on:
required: false
type: string
default: '>= 0'
+ python-versions:
+ required: true
+ type: string
+ description: JSON array of Python versions to test against.
pytest-extra-args:
required: false
type: string
@@ -30,52 +34,57 @@ on:
extra-env-vars:
required: false
type: string
- description: "Multiline string of environment variables to set for the test run."
+ description: Multiline string of environment variables to set for the test run.
env:
- PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
+ PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20"
jobs:
tests:
runs-on: ubuntu-latest
strategy:
+ fail-fast: false
matrix:
- # min version of dask we needs drops support for python 3.9
- python-version: ${{ inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']') || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }}
+ python-version: ${{ fromJson(inputs.python-versions) }}
steps:
- name: Checkout array-api-compat
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
path: array-api-compat
+
- name: Checkout array-api-tests
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
repository: data-apis/array-api-tests
submodules: 'true'
path: array-api-tests
+
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
+
- name: Set Extra Environment Variables
# Set additional environment variables if provided
if: inputs.extra-env-vars
run: |
echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV
+
- name: Install dependencies
- # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
- # to put this in the numpy 1.21 config file.
- if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
run: |
python -m pip install --upgrade pip
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
+ python -m pip install pytest-xdist
+
+ - name: Dump pip environment
+ run: pip freeze
+
- name: Run the array API testsuite (${{ inputs.package-name }})
- if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
- ARRAY_API_TESTS_VERSION: 2023.12
+ ARRAY_API_TESTS_VERSION: 2024.12
# This enables the NEP 50 type promotion behavior (without it a lot of
# tests fail on bad scalar type promotion behavior)
NPY_PROMOTION_STATE: weak
diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index 04c3aa66..305a9003 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -6,11 +6,11 @@ jobs:
docs-build:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v5
+ - uses: actions/checkout@v5
+ - uses: actions/setup-python@v6
- name: Install Dependencies
run: |
- python -m pip install -r docs/requirements.txt
+ python -m pip install .[docs]
- name: Build Docs
run: |
cd docs
diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml
index 9aa379de..42a3598f 100644
--- a/.github/workflows/docs-deploy.yml
+++ b/.github/workflows/docs-deploy.yml
@@ -11,9 +11,9 @@ jobs:
environment:
name: docs-deploy
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Download Artifact
- uses: dawidd6/action-download-artifact@v6
+ uses: dawidd6/action-download-artifact@v11
with:
workflow: docs-build.yml
name: docs-build
diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index ad021adc..bbfb2e80 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -30,24 +30,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.x'
- name: Install python-build and twine
run: |
- python -m pip install --upgrade pip setuptools
+ python -m pip install --upgrade pip "setuptools<=67"
python -m pip install build twine
python -m pip list
- name: Build a wheel and a sdist
run: |
- PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
+ #PYTHONWARNINGS=error,default::DeprecationWarning python -m build .
+ python -m build .
- name: Verify the distribution
run: twine check --strict dist/*
@@ -80,7 +81,7 @@ jobs:
steps:
- name: Download distribution artifact
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v5
with:
name: dist-artifact
path: dist
@@ -88,15 +89,21 @@ jobs:
- name: List all files
run: ls -lh dist
- - name: Publish distribution 📦 to Test PyPI
- # Publish to TestPyPI on tag events of if manually triggered
- # Compare to 'true' string as booleans get turned into strings in the console
- if: >-
- (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
- || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
- uses: pypa/gh-action-pypi-publish@v1.12.2
+ # - name: Publish distribution 📦 to Test PyPI
+ # # Publish to TestPyPI on tag events or if manually triggered
+ # # Compare to 'true' string as booleans get turned into strings in the console
+ # if: >-
+ # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags'))
+ # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true')
+ # uses: pypa/gh-action-pypi-publish@v1.13.0
+ # with:
+ # repository-url: https://test.pypi.org/legacy/
+ # print-hash: true
+
+ - name: Publish distribution 📦 to PyPI
+ if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
+ uses: pypa/gh-action-pypi-publish@v1.13.0
with:
- repository-url: https://test.pypi.org/legacy/
print-hash: true
- name: Create GitHub Release from a Tag
@@ -104,9 +111,3 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
with:
files: dist/*
-
- - name: Publish distribution 📦 to PyPI
- if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
- uses: pypa/gh-action-pypi-publish@v1.12.2
- with:
- print-hash: true
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index a9f0fd4b..4a2ffcff 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -5,9 +5,9 @@ jobs:
runs-on: ubuntu-latest
continue-on-error: true
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Install Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install dependencies
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index fcd43367..cfbb875f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -4,43 +4,55 @@ jobs:
tests:
runs-on: ubuntu-latest
strategy:
+ fail-fast: false
matrix:
- python-version: ['3.9', '3.10', '3.11', '3.12']
- numpy-version: ['1.21', '1.26', '2.0', 'dev']
- exclude:
- - python-version: '3.11'
- numpy-version: '1.21'
- - python-version: '3.12'
- numpy-version: '1.21'
- fail-fast: true
+ include:
+ - numpy-version: '1.22'
+ python-version: '3.10'
+ - numpy-version: '1.26'
+ python-version: '3.10'
+ - numpy-version: '1.26'
+ python-version: '3.12'
+ - numpy-version: 'latest'
+ python-version: '3.10'
+ - numpy-version: 'latest'
+ python-version: '3.13'
+ - numpy-version: 'dev'
+ python-version: '3.11'
+ - numpy-version: 'dev'
+ python-version: '3.13'
+
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v5
+ - uses: actions/checkout@v5
+ - uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
+ python -m pip install pytest
+
+ # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack
+ python -m pip install array-api-strict
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
if [ "${{ matrix.numpy-version }}" == "dev" ]; then
- PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
- elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then
- PIP_EXTRA='numpy==1.21.*'
+ python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
+ python -m pip install dask[array] jax[cpu] sparse ndonnx
+ elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then
+ python -m pip install 'numpy==1.22.*'
+ elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then
+ python -m pip install 'numpy==1.26.*'
else
- PIP_EXTRA='numpy==1.26.*'
+ python -m pip install numpy
+ python -m pip install dask[array] jax[cpu] sparse ndonnx
fi
- if [ "${{ matrix.python-version }}" == "3.9" ]; then
- sed -i '/^ndonnx/d' requirements-dev.txt
- fi
+ - name: Dump pip environment
+ run: pip freeze
- python -m pip install -r requirements-dev.txt $PIP_EXTRA
+ - name: Test it installs
+ run: python -m pip install .
- name: Run Tests
- run: |
- if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
- PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse")
- fi
- pytest -v "${PYTEST_EXTRA[@]}"
-
- # Make sure it installs
- python -m pip install .
+ run: pytest -v
diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py
index 30b1d852..a00e8cbc 100644
--- a/array_api_compat/__init__.py
+++ b/array_api_compat/__init__.py
@@ -1,9 +1,9 @@
"""
NumPy Array API compatibility library
-This is a small wrapper around NumPy and CuPy that is compatible with the
-Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
-https://numpy.org/neps/nep-0047-array-api-standard.html.
+This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
+compatible with the Array API standard https://data-apis.org/array-api/latest/.
+See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
@@ -17,6 +17,6 @@
this implementation for the default when working with NumPy arrays.
"""
-__version__ = '1.9.1'
+__version__ = '1.13.0.dev0'
from .common import * # noqa: F401, F403
diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 170a1ff9..baa39ded 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -2,10 +2,17 @@
Internal helpers
"""
+import importlib
+from collections.abc import Callable
from functools import wraps
from inspect import signature
+from types import ModuleType
+from typing import TypeVar
-def get_xp(xp):
+_T = TypeVar("_T")
+
+
+def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""
Decorator to automatically replace xp with the corresponding array module.
@@ -22,14 +29,14 @@ def func(x, /, xp, kwarg=None):
"""
- def inner(f):
+ def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
@wraps(f)
- def wrapped_f(*args, **kwargs):
+ def wrapped_f(*args: object, **kwargs: object) -> object:
return f(*args, xp=xp, **kwargs)
sig = signature(f)
new_sig = sig.replace(
- parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
+ parameters=[par for i, par in sig.parameters.items() if i != "xp"]
)
if wrapped_f.__doc__ is None:
@@ -40,7 +47,31 @@ def wrapped_f(*args, **kwargs):
specification for more details.
"""
- wrapped_f.__signature__ = new_sig
- return wrapped_f
+ wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
+ return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]
return inner
+
+
+def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
+ """Import everything from module, updating globals().
+ Returns __all__.
+ """
+ mod = importlib.import_module(mod_name)
+ # Neither of these two methods is sufficient by itself,
+ # depending on various idiosyncrasies of the libraries we're wrapping.
+ objs = {}
+ exec(f"from {mod.__name__} import *", objs)
+
+ for n in dir(mod):
+ if not n.startswith("_") and hasattr(mod, n):
+ objs[n] = getattr(mod, n)
+
+ globals_.update(objs)
+ return list(objs)
+
+
+__all__ = ["get_xp", "clone_module"]
+
+def __dir__() -> list[str]:
+ return __all__
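[Note] clone_module() above re-exports a backend module's public names into the caller's globals and returns them as __all__; a minimal usage sketch (the target module here is the standard-library math module, purely for illustration; the compat submodules point it at their wrapped backend instead):

    from array_api_compat._internal import clone_module

    __all__ = clone_module("math", globals())
    assert "sqrt" in __all__ and "pi" in globals()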
diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py
index 91ab1c40..82360807 100644
--- a/array_api_compat/common/__init__.py
+++ b/array_api_compat/common/__init__.py
@@ -1 +1 @@
-from ._helpers import * # noqa: F403
+from ._helpers import * # noqa: F403
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index d5405745..3587ef16 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -4,142 +4,172 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from typing import Optional, Sequence, Tuple, Union
- from ._typing import ndarray, Device, Dtype
-
-from typing import NamedTuple
import inspect
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, NamedTuple, cast
+
+from ._helpers import _check_device, array_namespace
+from ._helpers import device as _get_device
+from ._helpers import is_cupy_namespace
+from ._typing import Array, Device, DType, Namespace
-from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
+if TYPE_CHECKING:
+ # TODO: import from typing (requires Python >=3.13)
+ from typing_extensions import TypeIs
# These functions are modified from the NumPy versions.
-# Creation functions add the device keyword (which does nothing for NumPy)
+# Creation functions add the device keyword (which does nothing for NumPy and Dask)
+
def arange(
- start: Union[int, float],
+ start: float,
/,
- stop: Optional[Union[int, float]] = None,
- step: Union[int, float] = 1,
+ stop: float | None = None,
+ step: float = 1,
*,
- xp,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs
-) -> ndarray:
+ xp: Namespace,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
+
def empty(
- shape: Union[int, Tuple[int, ...]],
- xp,
+ shape: int | tuple[int, ...],
+ xp: Namespace,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
+
def empty_like(
- x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
- **kwargs
-) -> ndarray:
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
+
def eye(
n_rows: int,
- n_cols: Optional[int] = None,
+ n_cols: int | None = None,
/,
*,
- xp,
+ xp: Namespace,
k: int = 0,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
+
def full(
- shape: Union[int, Tuple[int, ...]],
- fill_value: Union[int, float],
- xp,
+ shape: int | tuple[int, ...],
+ fill_value: complex,
+ xp: Namespace,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
+
def full_like(
- x: ndarray,
+ x: Array,
/,
- fill_value: Union[int, float],
+ fill_value: complex,
*,
- xp,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ xp: Namespace,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
+
def linspace(
- start: Union[int, float],
- stop: Union[int, float],
+ start: float,
+ stop: float,
/,
num: int,
*,
- xp,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
+ xp: Namespace,
+ dtype: DType | None = None,
+ device: Device | None = None,
endpoint: bool = True,
- **kwargs,
-) -> ndarray:
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
+
def ones(
- shape: Union[int, Tuple[int, ...]],
- xp,
+ shape: int | tuple[int, ...],
+ xp: Namespace,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
+
def ones_like(
- x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
+
def zeros(
- shape: Union[int, Tuple[int, ...]],
- xp,
+ shape: int | tuple[int, ...],
+ xp: Namespace,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
+
def zeros_like(
- x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
- **kwargs,
-) -> ndarray:
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
+) -> Array:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
+
# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).
@@ -147,35 +177,37 @@ def zeros_like(
# The functions here return namedtuples (np.unique() returns a normal
# tuple).
+
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
- values: ndarray
- indices: ndarray
- inverse_indices: ndarray
- counts: ndarray
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
class UniqueCountsResult(NamedTuple):
- values: ndarray
- counts: ndarray
+ values: Array
+ counts: Array
class UniqueInverseResult(NamedTuple):
- values: ndarray
- inverse_indices: ndarray
+ values: Array
+ inverse_indices: Array
-def _unique_kwargs(xp):
+def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
# trying to parse version numbers, just check if equal_nan is in the
# signature.
s = inspect.signature(xp.unique)
- if 'equal_nan' in s.parameters:
- return {'equal_nan': False}
+ if "equal_nan" in s.parameters:
+ return {"equal_nan": False}
return {}
-def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
+
+def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
kwargs = _unique_kwargs(xp)
values, indices, inverse_indices, counts = xp.unique(
x,
@@ -195,20 +227,16 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
)
-def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
+def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
kwargs = _unique_kwargs(xp)
res = xp.unique(
- x,
- return_counts=True,
- return_index=False,
- return_inverse=False,
- **kwargs
+ x, return_counts=True, return_index=False, return_inverse=False, **kwargs
)
return UniqueCountsResult(*res)
-def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
+def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
kwargs = _unique_kwargs(xp)
values, inverse_indices = xp.unique(
x,
@@ -223,7 +251,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
return UniqueInverseResult(values, inverse_indices)
-def unique_values(x: ndarray, /, xp) -> ndarray:
+def unique_values(x: Array, /, xp: Namespace) -> Array:
kwargs = _unique_kwargs(xp)
return xp.unique(
x,
@@ -233,56 +261,58 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
**kwargs,
)
-def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
- if not copy and dtype == x.dtype:
- return x
- return x.astype(dtype=dtype, copy=copy)
# These functions have different keyword argument names
+
def std(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- correction: Union[int, float] = 0.0, # correction instead of ddof
+ axis: int | tuple[int, ...] | None = None,
+ correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
- **kwargs,
-) -> ndarray:
+ **kwargs: object,
+) -> Array:
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
+
def var(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- correction: Union[int, float] = 0.0, # correction instead of ddof
+ axis: int | tuple[int, ...] | None = None,
+ correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
- **kwargs,
-) -> ndarray:
+ **kwargs: object,
+) -> Array:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
+
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
# argument
+
def cumulative_sum(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- axis: Optional[int] = None,
- dtype: Optional[Dtype] = None,
+ axis: int | None = None,
+ dtype: DType | None = None,
include_initial: bool = False,
- **kwargs
-) -> ndarray:
+ **kwargs: object,
+) -> Array:
wrapped_xp = array_namespace(x)
# TODO: The standard is not clear about what should happen when x.ndim == 0.
if axis is None:
if x.ndim > 1:
- raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+ raise ValueError(
+ "axis must be specified in cumulative_sum for more than one dimension"
+ )
axis = 0
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
@@ -292,25 +322,69 @@ def cumulative_sum(
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
- [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+ [
+ wrapped_xp.zeros(
+ shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+ ),
+ res,
+ ],
axis=axis,
)
return res
+
+def cumulative_prod(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ axis: int | None = None,
+ dtype: DType | None = None,
+ include_initial: bool = False,
+ **kwargs: object,
+) -> Array:
+ wrapped_xp = array_namespace(x)
+
+ if axis is None:
+ if x.ndim > 1:
+ raise ValueError(
+ "axis must be specified in cumulative_prod for more than one dimension"
+ )
+ axis = 0
+
+ res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
+
+ # np.cumprod does not support include_initial
+ if include_initial:
+ initial_shape = list(x.shape)
+ initial_shape[axis] = 1
+ res = xp.concatenate(
+ [
+ wrapped_xp.ones(
+ shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+ ),
+ res,
+ ],
+ axis=axis,
+ )
+ return res
+
+
# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
def clip(
- x: ndarray,
+ x: Array,
/,
- min: Optional[Union[int, float, ndarray]] = None,
- max: Optional[Union[int, float, ndarray]] = None,
+ min: float | Array | None = None,
+ max: float | Array | None = None,
*,
- xp,
+ xp: Namespace,
# TODO: np.clip has other ufunc kwargs
- out: Optional[ndarray] = None,
-) -> ndarray:
- def _isscalar(a):
- return isinstance(a, (int, float, type(None)))
+ out: Array | None = None,
+) -> Array:
+ def _isscalar(a: object) -> TypeIs[float | None]:
+ return isinstance(a, int | float) or a is None
+
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
@@ -335,44 +409,51 @@ def _isscalar(a):
# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
-
# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
- if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
- min = None
- if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
- max = None
+ if wrapped_xp.isdtype(x.dtype, "integral"):
+ if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
+ min = None
+ if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
+ max = None
+ dev = _get_device(x)
if out is None:
- out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
- copy=True, device=device(x))
+ out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+ assert out is not None # workaround for a type-narrowing issue in pyright
+ out[()] = x
+
if min is not None:
- if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
- # Avoid loss of precision due to torch defaulting to float32
- min = wrapped_xp.asarray(min, dtype=xp.float64)
- a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
+ a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
+ a = xp.broadcast_to(a, result_shape)
ia = (out < a) | xp.isnan(a)
- # torch requires an explicit cast here
- out[ia] = wrapped_xp.astype(a[ia], out.dtype)
+ out[ia] = a[ia]
+
if max is not None:
- if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
- max = wrapped_xp.asarray(max, dtype=xp.float64)
- b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
+ b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
+ b = xp.broadcast_to(b, result_shape)
ib = (out > b) | xp.isnan(b)
- out[ib] = wrapped_xp.astype(b[ib], out.dtype)
+ out[ib] = b[ib]
+
# Return a scalar for 0-D
return out[()]
+
# Unlike transpose(), the axes argument to permute_dims() is required.
-def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
+def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
return xp.transpose(x, axes)
+
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
-def reshape(x: ndarray,
- /,
- shape: Tuple[int, ...],
- xp, copy: Optional[bool] = None,
- **kwargs) -> ndarray:
+def reshape(
+ x: Array,
+ /,
+ shape: tuple[int, ...],
+ xp: Namespace,
+ *,
+ copy: bool | None = None,
+ **kwargs: object,
+) -> Array:
if copy is True:
x = x.copy()
elif copy is False:
@@ -381,17 +462,24 @@ def reshape(x: ndarray,
return y
return xp.reshape(x, shape, **kwargs)
+
# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
- x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
- **kwargs,
-) -> ndarray:
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ axis: int = -1,
+ descending: bool = False,
+ stable: bool = True,
+ **kwargs: object,
+) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
- kwargs['kind'] = "stable"
+ kwargs["kind"] = "stable"
if not descending:
res = xp.argsort(x, axis=axis, **kwargs)
else:
@@ -408,69 +496,66 @@ def argsort(
res = max_i - res
return res
+
def sort(
- x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
- **kwargs,
-) -> ndarray:
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ axis: int = -1,
+ descending: bool = False,
+ stable: bool = True,
+ **kwargs: object,
+) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
- kwargs['kind'] = "stable"
+ kwargs["kind"] = "stable"
res = xp.sort(x, axis=axis, **kwargs)
if descending:
res = xp.flip(res, axis=axis)
return res
+
# nonzero should error for zero-dimensional arrays
-def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
+def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)
-# ceil, floor, and trunc return integers for integer inputs
-
-def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
- if xp.issubdtype(x.dtype, xp.integer):
- return x
- return xp.ceil(x, **kwargs)
-
-def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
- if xp.issubdtype(x.dtype, xp.integer):
- return x
- return xp.floor(x, **kwargs)
-
-def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
- if xp.issubdtype(x.dtype, xp.integer):
- return x
- return xp.trunc(x, **kwargs)
# linear algebra functions
-def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
+
+def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.matmul(x1, x2, **kwargs)
+
# Unlike transpose, matrix_transpose only transposes the last two axes.
-def matrix_transpose(x: ndarray, /, xp) -> ndarray:
+def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
if x.ndim < 2:
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
return xp.swapaxes(x, -1, -2)
-def tensordot(x1: ndarray,
- x2: ndarray,
- /,
- xp,
- *,
- axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
- **kwargs,
-) -> ndarray:
+
+def tensordot(
+ x1: Array,
+ x2: Array,
+ /,
+ xp: Namespace,
+ *,
+ axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+ **kwargs: object,
+) -> Array:
return xp.tensordot(x1, x2, axes=axes, **kwargs)
-def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
+
+def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
- if hasattr(xp, 'broadcast_tensors'):
+ if hasattr(xp, "broadcast_tensors"):
_broadcast = xp.broadcast_tensors
else:
_broadcast = xp.broadcast_arrays
@@ -482,11 +567,16 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
return res[..., 0, 0]
+
# isdtype is a new function in the 2022.12 array API specification.
+
def isdtype(
- dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
- *, _tuple=True, # Disallow nested tuples
+ dtype: DType,
+ kind: DType | str | tuple[DType | str, ...],
+ xp: Namespace,
+ *,
+ _tuple: bool = True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -499,21 +589,24 @@ def isdtype(
for more details
"""
if isinstance(kind, tuple) and _tuple:
- return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
+ return any(
+ isdtype(dtype, k, xp, _tuple=False)
+ for k in cast("tuple[DType | str, ...]", kind)
+ )
elif isinstance(kind, str):
- if kind == 'bool':
+ if kind == "bool":
return dtype == xp.bool_
- elif kind == 'signed integer':
+ elif kind == "signed integer":
return xp.issubdtype(dtype, xp.signedinteger)
- elif kind == 'unsigned integer':
+ elif kind == "unsigned integer":
return xp.issubdtype(dtype, xp.unsignedinteger)
- elif kind == 'integral':
+ elif kind == "integral":
return xp.issubdtype(dtype, xp.integer)
- elif kind == 'real floating':
+ elif kind == "real floating":
return xp.issubdtype(dtype, xp.floating)
- elif kind == 'complex floating':
+ elif kind == "complex floating":
return xp.issubdtype(dtype, xp.complexfloating)
- elif kind == 'numeric':
+ elif kind == "numeric":
return xp.issubdtype(dtype, xp.number)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
@@ -524,32 +617,86 @@ def isdtype(
# array_api_strict implementation will be very strict.
return dtype == kind
+
# unstack is a new function in the 2023.12 array API standard
-def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("Input array must be at least 1-d.")
return tuple(xp.moveaxis(x, axis, 0))
+
# numpy 1.26 does not use the standard definition for sign on complex numbers
-def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
- if isdtype(x.dtype, 'complex floating', xp=xp):
- out = (x/xp.abs(x, **kwargs))[...]
+
+def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
+ if isdtype(x.dtype, "complex floating", xp=xp):
+ out = (x / xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
- out[x == 0+0j] = 0+0j
+ out[x == 0j] = 0j
else:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
- if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
+ if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]
-__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
- 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
- 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
- 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
- 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
- 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
- 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
- 'unstack', 'sign']
+
+def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
+ # It is surprisingly difficult to recognize a dtype apart from an array.
+ # np.int64 is not the same as np.asarray(1).dtype!
+ try:
+ return xp.finfo(type_)
+ except (ValueError, TypeError):
+ return xp.finfo(type_.dtype)
+
+
+def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
+ try:
+ return xp.iinfo(type_)
+ except (ValueError, TypeError):
+ return xp.iinfo(type_.dtype)
+
+
+__all__ = [
+ "arange",
+ "empty",
+ "empty_like",
+ "eye",
+ "full",
+ "full_like",
+ "linspace",
+ "ones",
+ "ones_like",
+ "zeros",
+ "zeros_like",
+ "UniqueAllResult",
+ "UniqueCountsResult",
+ "UniqueInverseResult",
+ "unique_all",
+ "unique_counts",
+ "unique_inverse",
+ "unique_values",
+ "std",
+ "var",
+ "cumulative_sum",
+ "cumulative_prod",
+ "clip",
+ "permute_dims",
+ "reshape",
+ "argsort",
+ "sort",
+ "nonzero",
+ "matmul",
+ "matrix_transpose",
+ "tensordot",
+ "vecdot",
+ "isdtype",
+ "unstack",
+ "sign",
+ "finfo",
+ "iinfo",
+]
+
+def __dir__() -> list[str]:
+ return __all__
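[Editorial note, not part of the patch] The hunks above rework the generic wrappers in common/_aliases.py (clip, cumulative_sum, the new cumulative_prod, finfo/iinfo, and friends). A minimal usage sketch, assuming a build of array-api-compat that includes these changes and exposes them through the NumPy-backed namespace array_api_compat.numpy:

    # Illustrative sketch only; not part of the patch.
    import array_api_compat.numpy as xp

    x = xp.asarray([1, 2, 3])

    # cumulative_sum() gains include_initial, which prepends a zero along the axis.
    print(xp.cumulative_sum(x, include_initial=True))  # [0 1 3 6]

    # For integer dtypes, clip() drops Python int bounds at or beyond the dtype
    # limits instead of overflowing, so this call leaves the data unchanged.
    print(xp.clip(x, -2**70, 2**70))                    # [1 2 3]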
diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index 666b0b1f..18839d37 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,168 +1,195 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Union, Optional, Literal
+from collections.abc import Sequence
+from typing import Literal, TypeAlias
-if TYPE_CHECKING:
- from ._typing import Device, ndarray
- from collections.abc import Sequence
+from ._typing import Array, Device, DType, Namespace
+
+_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
def fft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftn(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- s: Sequence[int] = None,
- axes: Sequence[int] = None,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ s: Sequence[int] | None = None,
+ axes: Sequence[int] | None = None,
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifftn(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- s: Sequence[int] = None,
- axes: Sequence[int] = None,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ s: Sequence[int] | None = None,
+ axes: Sequence[int] | None = None,
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def rfft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def rfftn(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- s: Sequence[int] = None,
- axes: Sequence[int] = None,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ s: Sequence[int] | None = None,
+ axes: Sequence[int] | None = None,
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfftn(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- s: Sequence[int] = None,
- axes: Sequence[int] = None,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ s: Sequence[int] | None = None,
+ axes: Sequence[int] | None = None,
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def hfft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res
def ihfft(
- x: ndarray,
+ x: Array,
/,
- xp,
+ xp: Namespace,
*,
- n: Optional[int] = None,
+ n: int | None = None,
axis: int = -1,
- norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+ norm: _Norm = "backward",
+) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
-def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
+def fftfreq(
+ n: int,
+ /,
+ xp: Namespace,
+ *,
+ d: float = 1.0,
+ dtype: DType | None = None,
+ device: Device | None = None,
+) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
- return xp.fft.fftfreq(n, d=d)
+ res = xp.fft.fftfreq(n, d=d)
+ if dtype is not None:
+ return res.astype(dtype)
+ return res
-def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
+def rfftfreq(
+ n: int,
+ /,
+ xp: Namespace,
+ *,
+ d: float = 1.0,
+ dtype: DType | None = None,
+ device: Device | None = None,
+) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
- return xp.fft.rfftfreq(n, d=d)
+ res = xp.fft.rfftfreq(n, d=d)
+ if dtype is not None:
+ return res.astype(dtype)
+ return res
-def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+def fftshift(
+ x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
+) -> Array:
return xp.fft.fftshift(x, axes=axes)
-def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+def ifftshift(
+ x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
+) -> Array:
return xp.fft.ifftshift(x, axes=axes)
__all__ = [
@@ -181,3 +208,6 @@ def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> n
"fftshift",
"ifftshift",
]
+
+def __dir__() -> list[str]:
+ return __all__
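[Editorial note, not part of the patch] The fft wrappers above exist to keep single-precision inputs in single precision (the backend may upcast to complex128/float64) and to add a dtype keyword to fftfreq/rfftfreq. A minimal sketch, assuming the same NumPy-backed namespace as above:

    # Illustrative sketch only; not part of the patch.
    import array_api_compat.numpy as xp

    x = xp.asarray([0.0, 1.0, 0.0, -1.0], dtype=xp.float32)

    # The wrappers cast results back to single precision even if the backend upcasts.
    print(xp.fft.fft(x).dtype)                        # complex64
    print(xp.fft.irfft(xp.fft.rfft(x)).dtype)         # float32

    # fftfreq()/rfftfreq() accept a dtype keyword (and only the "cpu" device), per the hunk above.
    print(xp.fft.fftfreq(4, dtype=xp.float32).dtype)  # float32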
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index ec6b3e0d..d75de5c4 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -5,34 +5,95 @@
that are in __all__ are intended as additional helper functions for use by end
users of the compat library.
"""
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+import enum
+import inspect
+import math
+import sys
+import warnings
+from collections.abc import Collection, Hashable
+from functools import lru_cache
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Final,
+ Literal,
+ SupportsIndex,
+ TypeAlias,
+ TypeGuard,
+ cast,
+ overload,
+)
+
+from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
if TYPE_CHECKING:
- from typing import Optional, Union, Any
- from ._typing import Array, Device
+ import cupy as cp
+ import dask.array as da
+ import jax
+ import ndonnx as ndx
+ import numpy as np
+ import numpy.typing as npt
+ import sparse
+ import torch
+ import paddle
+
+ # TODO: import from typing (requires Python >=3.13)
+ from typing_extensions import TypeIs
+
+ _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
+
+ _ArrayApiObj: TypeAlias = (
+ npt.NDArray[Any]
+ | cp.ndarray
+ | da.Array
+ | jax.Array
+ | ndx.Array
+ | sparse.SparseArray
+ | torch.Tensor
+ | SupportsArrayNamespace[Any]
+ )
+
+_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
+_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
-import sys
-import math
-import inspect
-import warnings
-def _is_jax_zero_gradient_array(x):
+@lru_cache(100)
+def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
+ try:
+ mod = sys.modules[modname]
+ except KeyError:
+ return False
+ parent_cls = getattr(mod, clsname)
+ return issubclass(cls, parent_cls)
+
+
+def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
See https://github.com/google/jax/issues/20620.
"""
- if 'numpy' not in sys.modules or 'jax' not in sys.modules:
+ # Fast exit
+ try:
+ dtype = x.dtype # type: ignore[attr-defined]
+ except AttributeError:
+ return False
+ cls = cast(Hashable, type(dtype))
+ if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
+ return False
+
+ if "jax" not in sys.modules:
return False
- import numpy as np
import jax
+ # jax.float0 is a np.dtype([('float0', 'V')])
+ return dtype == jax.float0
- return isinstance(x, np.ndarray) and x.dtype == jax.float0
-def is_numpy_array(x):
+def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
"""
Return True if `x` is a NumPy array.
@@ -53,17 +114,15 @@ def is_numpy_array(x):
is_jax_array
is_pydata_sparse_array
"""
- # Avoid importing NumPy if it isn't already
- if 'numpy' not in sys.modules:
- return False
-
- import numpy as np
-
# TODO: Should we reject ndarray subclasses?
- return (isinstance(x, (np.ndarray, np.generic))
- and not _is_jax_zero_gradient_array(x))
+ cls = cast(Hashable, type(x))
+ return (
+ _issubclass_fast(cls, "numpy", "ndarray")
+ or _issubclass_fast(cls, "numpy", "generic")
+ ) and not _is_jax_zero_gradient_array(x)
+
-def is_cupy_array(x):
+def is_cupy_array(x: object) -> bool:
"""
Return True if `x` is a CuPy array.
@@ -84,16 +143,11 @@ def is_cupy_array(x):
is_jax_array
is_pydata_sparse_array
"""
- # Avoid importing CuPy if it isn't already
- if 'cupy' not in sys.modules:
- return False
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "cupy", "ndarray")
- import cupy as cp
-
- # TODO: Should we reject ndarray subclasses?
- return isinstance(x, (cp.ndarray, cp.generic))
-def is_torch_array(x):
+def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
"""
Return True if `x` is a PyTorch tensor.
@@ -111,20 +165,15 @@ def is_torch_array(x):
is_jax_array
is_pydata_sparse_array
"""
- # Avoid importing torch if it isn't already
- if 'torch' not in sys.modules:
- return False
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "torch", "Tensor")
- import torch
- # TODO: Should we reject ndarray subclasses?
- return isinstance(x, torch.Tensor)
-
-def is_paddle_array(x):
+def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
"""
- Return True if `x` is a Paddle tensor.
+ Return True if `x` is a ndonnx Array.
- This function does not import Paddle if it has not already been imported
+ This function does not import ndonnx if it has not already been imported
and is therefore cheap to use.
See Also
@@ -134,23 +183,20 @@ def is_paddle_array(x):
is_array_api_obj
is_numpy_array
is_cupy_array
+ is_ndonnx_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
- # Avoid importing paddle if it isn't already
- if 'paddle' not in sys.modules:
- return False
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "ndonnx", "Array")
- import paddle
- return paddle.is_tensor(x)
-
-def is_ndonnx_array(x):
+def is_dask_array(x: object) -> TypeIs[da.Array]:
"""
- Return True if `x` is a ndonnx Array.
+ Return True if `x` is a dask.array Array.
- This function does not import ndonnx if it has not already been imported
+ This function does not import dask if it has not already been imported
and is therefore cheap to use.
See Also
@@ -160,26 +206,23 @@ def is_ndonnx_array(x):
is_array_api_obj
is_numpy_array
is_cupy_array
+ is_torch_array
is_ndonnx_array
- is_dask_array
is_jax_array
is_pydata_sparse_array
"""
- # Avoid importing torch if it isn't already
- if 'ndonnx' not in sys.modules:
- return False
-
- import ndonnx as ndx
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "dask.array", "Array")
- return isinstance(x, ndx.Array)
-def is_dask_array(x):
+def is_jax_array(x: object) -> TypeIs[jax.Array]:
"""
- Return True if `x` is a dask.array Array.
+ Return True if `x` is a JAX array.
- This function does not import dask if it has not already been imported
+ This function does not import JAX if it has not already been imported
and is therefore cheap to use.
+
See Also
--------
@@ -189,22 +232,18 @@ def is_dask_array(x):
is_cupy_array
is_torch_array
is_ndonnx_array
- is_jax_array
+ is_dask_array
is_pydata_sparse_array
"""
- # Avoid importing dask if it isn't already
- if 'dask.array' not in sys.modules:
- return False
-
- import dask.array
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
- return isinstance(x, dask.array.Array)
-def is_jax_array(x):
+def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
"""
- Return True if `x` is a JAX array.
+ Return True if `x` is an array from the `sparse` package.
- This function does not import JAX if it has not already been imported
+ This function does not import `sparse` if it has not already been imported
and is therefore cheap to use.
@@ -218,24 +257,20 @@ def is_jax_array(x):
is_torch_array
is_ndonnx_array
is_dask_array
- is_pydata_sparse_array
+ is_jax_array
"""
- # Avoid importing jax if it isn't already
- if 'jax' not in sys.modules:
- return False
-
- import jax
+ # TODO: Account for other backends.
+ cls = cast(Hashable, type(x))
+ return _issubclass_fast(cls, "sparse", "SparseArray")
- return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
-def is_pydata_sparse_array(x) -> bool:
+def is_paddle_array(x):
"""
- Return True if `x` is an array from the `sparse` package.
+ Return True if `x` is a Paddle tensor.
- This function does not import `sparse` if it has not already been imported
+ This function does not import Paddle if it has not already been imported
and is therefore cheap to use.
-
See Also
--------
@@ -243,21 +278,20 @@ def is_pydata_sparse_array(x) -> bool:
is_array_api_obj
is_numpy_array
is_cupy_array
- is_torch_array
- is_ndonnx_array
is_dask_array
is_jax_array
+ is_pydata_sparse_array
"""
- # Avoid importing jax if it isn't already
- if 'sparse' not in sys.modules:
+ # Avoid importing paddle if it isn't already
+ if 'paddle' not in sys.modules:
return False
- import sparse
+ import paddle
- # TODO: Account for other backends.
- return isinstance(x, sparse.SparseArray)
+ return paddle.is_tensor(x)
-def is_array_api_obj(x):
+
+def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return True if `x` is an array API compatible array object.
@@ -272,20 +306,34 @@ def is_array_api_obj(x):
is_dask_array
is_jax_array
"""
- return is_numpy_array(x) \
- or is_cupy_array(x) \
- or is_torch_array(x) \
- or is_dask_array(x) \
- or is_jax_array(x) \
- or is_pydata_sparse_array(x) \
- or is_paddle_array(x) \
- or hasattr(x, '__array_namespace__')
-
-def _compat_module_name():
- assert __name__.endswith('.common._helpers')
- return __name__.removesuffix('.common._helpers')
-
-def is_numpy_namespace(xp) -> bool:
+ return (
+ hasattr(x, '__array_namespace__')
+ or _is_array_api_cls(cast(Hashable, type(x)))
+ )
+
+
+@lru_cache(100)
+def _is_array_api_cls(cls: type) -> bool:
+ return (
+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
+ _issubclass_fast(cls, "numpy", "ndarray")
+ or _issubclass_fast(cls, "numpy", "generic")
+ or _issubclass_fast(cls, "cupy", "ndarray")
+ or _issubclass_fast(cls, "torch", "Tensor")
+ or _issubclass_fast(cls, "dask.array", "Array")
+ or _issubclass_fast(cls, "sparse", "SparseArray")
+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
+ or _issubclass_fast(cls, "jax", "Array")
+ )
+
+
+def _compat_module_name() -> str:
+ assert __name__.endswith(".common._helpers")
+ return __name__.removesuffix(".common._helpers")
+
+
+@lru_cache(100)
+def is_numpy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
@@ -303,9 +351,11 @@ def is_numpy_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
+ return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
-def is_cupy_namespace(xp) -> bool:
+
+@lru_cache(100)
+def is_cupy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
@@ -323,9 +373,11 @@ def is_cupy_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
+ return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
+
-def is_torch_namespace(xp) -> bool:
+@lru_cache(100)
+def is_torch_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
@@ -343,14 +395,12 @@ def is_torch_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+ return xp.__name__ in {"torch", _compat_module_name() + ".torch"}
-def is_paddle_namespace(xp) -> bool:
+def is_ndonnx_namespace(xp: Namespace) -> bool:
"""
- Returns True if `xp` is a Paddle namespace.
-
- This includes both Paddle itself and the version wrapped by array-api-compat.
+ Returns True if `xp` is an NDONNX namespace.
See Also
--------
@@ -358,18 +408,21 @@ def is_paddle_namespace(xp) -> bool:
array_namespace
is_numpy_namespace
is_cupy_namespace
- is_ndonnx_namespace
+ is_torch_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
+ return xp.__name__ == "ndonnx"
-def is_ndonnx_namespace(xp):
+@lru_cache(100)
+def is_dask_namespace(xp: Namespace) -> bool:
"""
- Returns True if `xp` is an NDONNX namespace.
+ Returns True if `xp` is a Dask namespace.
+
+ This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
See Also
--------
@@ -378,18 +431,20 @@ def is_ndonnx_namespace(xp):
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
- is_dask_namespace
+ is_ndonnx_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ == 'ndonnx'
+ return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
-def is_dask_namespace(xp):
+
+def is_jax_namespace(xp: Namespace) -> bool:
"""
- Returns True if `xp` is a Dask namespace.
+ Returns True if `xp` is a JAX namespace.
- This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
+ This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
+ older versions of JAX.
See Also
--------
@@ -399,18 +454,16 @@ def is_dask_namespace(xp):
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
- is_jax_namespace
+ is_dask_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
+ return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"}
-def is_jax_namespace(xp):
- """
- Returns True if `xp` is a JAX namespace.
- This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
- older versions of JAX.
+def is_pydata_sparse_namespace(xp: Namespace) -> bool:
+ """
+ Returns True if `xp` is a pydata/sparse namespace.
See Also
--------
@@ -421,14 +474,17 @@ def is_jax_namespace(xp):
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
- is_pydata_sparse_namespace
+ is_jax_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
+ return xp.__name__ == "sparse"
+
-def is_pydata_sparse_namespace(xp):
+def is_paddle_namespace(xp) -> bool:
"""
- Returns True if `xp` is a pydata/sparse namespace.
+ Returns True if `xp` is a Paddle namespace.
+
+ This includes both Paddle itself and the version wrapped by array-api-compat.
See Also
--------
@@ -436,15 +492,16 @@ def is_pydata_sparse_namespace(xp):
array_namespace
is_numpy_namespace
is_cupy_namespace
- is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
+ is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
- return xp.__name__ == 'sparse'
+ return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
+
-def is_array_api_strict_namespace(xp):
+def is_array_api_strict_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.
@@ -460,16 +517,105 @@ def is_array_api_strict_namespace(xp):
is_jax_namespace
is_pydata_sparse_namespace
"""
- return xp.__name__ == 'array_api_strict'
+ return xp.__name__ == "array_api_strict"
-def _check_api_version(api_version):
- if api_version in ['2021.12', '2022.12']:
- warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
- elif api_version is not None and api_version not in ['2021.12', '2022.12',
- '2023.12']:
- raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
-def array_namespace(*xs, api_version=None, use_compat=None):
+def _check_api_version(api_version: str | None) -> None:
+ if api_version in _API_VERSIONS_OLD:
+ warnings.warn(
+ f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
+ )
+ elif api_version is not None and api_version not in _API_VERSIONS:
+ raise ValueError(
+ "Only the 2024.12 version of the array API specification is currently supported"
+ )
+
+
+class _ClsToXPInfo(enum.Enum):
+ SCALAR = 0
+ MAYBE_JAX_ZERO_GRADIENT = 1
+
+
+@lru_cache(100)
+def _cls_to_namespace(
+ cls: type,
+ api_version: str | None,
+ use_compat: bool | None,
+) -> tuple[Namespace | None, _ClsToXPInfo | None]:
+ if use_compat not in (None, True, False):
+ raise ValueError("use_compat must be None, True, or False")
+ _use_compat = use_compat in (None, True)
+ cls_ = cast(Hashable, cls) # Make mypy happy
+
+ if (
+ _issubclass_fast(cls_, "numpy", "ndarray")
+ or _issubclass_fast(cls_, "numpy", "generic")
+ ):
+ if use_compat is True:
+ _check_api_version(api_version)
+ from .. import numpy as xp
+ elif use_compat is False:
+ import numpy as xp # type: ignore[no-redef]
+ else:
+ # NumPy 2.0+ have __array_namespace__; however they are not
+ # yet fully array API compatible.
+ from .. import numpy as xp # type: ignore[no-redef]
+ return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
+
+ # Note: this must happen _after_ the test for np.generic,
+ # because np.float64 and np.complex128 are subclasses of float and complex.
+ if issubclass(cls, int | float | complex | type(None)):
+ return None, _ClsToXPInfo.SCALAR
+
+ if _issubclass_fast(cls_, "cupy", "ndarray"):
+ if _use_compat:
+ _check_api_version(api_version)
+ from .. import cupy as xp # type: ignore[no-redef]
+ else:
+ import cupy as xp # type: ignore[no-redef]
+ return xp, None
+
+ if _issubclass_fast(cls_, "torch", "Tensor"):
+ if _use_compat:
+ _check_api_version(api_version)
+ from .. import torch as xp # type: ignore[no-redef]
+ else:
+ import torch as xp # type: ignore[no-redef]
+ return xp, None
+
+ if _issubclass_fast(cls_, "dask.array", "Array"):
+ if _use_compat:
+ _check_api_version(api_version)
+ from ..dask import array as xp # type: ignore[no-redef]
+ else:
+ import dask.array as xp # type: ignore[no-redef]
+ return xp, None
+
+ # Backwards compatibility for jax<0.4.32
+ if _issubclass_fast(cls_, "jax", "Array"):
+ return _jax_namespace(api_version, use_compat), None
+
+ return None, None
+
+
+def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace:
+ if use_compat:
+ raise ValueError("JAX does not have an array-api-compat wrapper")
+ import jax.numpy as jnp
+ if not hasattr(jnp, "__array_namespace_info__"):
+ # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
+ # For older JAX versions, it is available via jax.experimental.array_api.
+ # jnp.Array objects gain the __array_namespace__ method.
+ import jax.experimental.array_api # noqa: F401
+ # Test api_version
+ return jnp.empty(0).__array_namespace__(api_version=api_version)
+
+
+def array_namespace(
+ *xs: Array | complex | None,
+ api_version: str | None = None,
+ use_compat: bool | None = None,
+) -> Namespace:
"""
Get the array API compatible namespace for the arrays `xs`.
@@ -481,7 +627,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
api_version: str
The newest version of the spec that you need support for (currently
- the compat library wrapped APIs support v2023.12).
+ the compat library wrapped APIs support v2024.12).
use_compat: bool or None
If None (the default), the native namespace will be returned if it is
@@ -533,125 +679,85 @@ def your_function(x, y):
is_pydata_sparse_array
"""
- if use_compat not in [None, True, False]:
- raise ValueError("use_compat must be None, True, or False")
-
- _use_compat = use_compat in [None, True]
-
- namespaces = set()
+ namespaces: set[Namespace] = set()
for x in xs:
- if is_numpy_array(x):
- from .. import numpy as numpy_namespace
- import numpy as np
- if use_compat is True:
- _check_api_version(api_version)
- namespaces.add(numpy_namespace)
- elif use_compat is False:
- namespaces.add(np)
- else:
- # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
- # compatible.
- namespaces.add(numpy_namespace)
- elif is_cupy_array(x):
- if _use_compat:
- _check_api_version(api_version)
- from .. import cupy as cupy_namespace
- namespaces.add(cupy_namespace)
- else:
- import cupy as cp
- namespaces.add(cp)
- elif is_torch_array(x):
- if _use_compat:
- _check_api_version(api_version)
- from .. import torch as torch_namespace
- namespaces.add(torch_namespace)
- else:
- import torch
- namespaces.add(torch)
- elif is_dask_array(x):
- if _use_compat:
- _check_api_version(api_version)
- from ..dask import array as dask_namespace
- namespaces.add(dask_namespace)
- else:
- import dask.array as da
- namespaces.add(da)
- elif is_jax_array(x):
- if use_compat is True:
- _check_api_version(api_version)
- raise ValueError("JAX does not have an array-api-compat wrapper")
- elif use_compat is False:
- import jax.numpy as jnp
- else:
- # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
- # For older JAX versions, it is available via jax.experimental.array_api.
- import jax.numpy
- if hasattr(jax.numpy, "__array_api_version__"):
- jnp = jax.numpy
- else:
- import jax.experimental.array_api as jnp
- namespaces.add(jnp)
- elif is_paddle_array(x):
- if _use_compat:
- _check_api_version(api_version)
- from .. import paddle as paddle_namespace
- namespaces.add(paddle_namespace)
- else:
- import paddle
- namespaces.add(paddle)
- elif is_pydata_sparse_array(x):
- if use_compat is True:
- _check_api_version(api_version)
- raise ValueError("`sparse` does not have an array-api-compat wrapper")
- else:
- import sparse
- # `sparse` is already an array namespace. We do not have a wrapper
- # submodule for it.
- namespaces.add(sparse)
- elif hasattr(x, '__array_namespace__'):
- if use_compat is True:
- raise ValueError("The given array does not have an array-api-compat wrapper")
- namespaces.add(x.__array_namespace__(api_version=api_version))
- elif isinstance(x, (bool, int, float, complex, type(None))):
+ xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
+ if info is _ClsToXPInfo.SCALAR:
continue
- else:
- # TODO: Support Python scalars?
- raise TypeError(f"{type(x).__name__} is not a supported array type")
-
- if not namespaces:
- raise TypeError("Unrecognized array input")
- if len(namespaces) != 1:
+ if (
+ info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
+ and _is_jax_zero_gradient_array(x)
+ ):
+ xp = _jax_namespace(api_version, use_compat)
+
+ if xp is None:
+ get_ns = getattr(x, "__array_namespace__", None)
+ if get_ns is None:
+ raise TypeError(f"{type(x).__name__} is not a supported array type")
+ if use_compat:
+ raise ValueError(
+ "The given array does not have an array-api-compat wrapper"
+ )
+ xp = get_ns(api_version=api_version)
+
+ namespaces.add(xp)
+
+ try:
+ (xp,) = namespaces
+ return xp
+ except ValueError:
+ if not namespaces:
+ raise TypeError(
+ "array_namespace requires at least one non-scalar array input"
+ )
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
- xp, = namespaces
-
- return xp
# backwards compatibility alias
get_namespace = array_namespace
-def _check_device(xp, device):
- if xp == sys.modules.get('numpy'):
- if device not in ["cpu", None]:
+
+def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
+ """
+ Validate dummy device on device-less array backends.
+
+ Notes
+ -----
+ This function is also invoked by CuPy, which does have multiple devices
+ if there are multiple GPUs available.
+ However, CuPy multi-device support is currently impossible
+ without using the global device or a context manager:
+
+ https://github.com/data-apis/array-api-compat/pull/293
+ """
+ if bare_xp is sys.modules.get("numpy"):
+ if device not in ("cpu", None):
raise ValueError(f"Unsupported device for NumPy: {device!r}")
+ elif bare_xp is sys.modules.get("dask.array"):
+ if device not in ("cpu", _DASK_DEVICE, None):
+ raise ValueError(f"Unsupported device for Dask: {device!r}")
+
+
# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
class _dask_device:
- def __repr__(self):
+ def __repr__(self) -> Literal["DASK_DEVICE"]:
return "DASK_DEVICE"
+
_DASK_DEVICE = _dask_device()
+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
-def device(x: Array, /) -> Device:
+def device(x: _ArrayApiObj, /) -> Device:
"""
Hardware device the array data resides on.
@@ -686,36 +792,36 @@ def device(x: Array, /) -> Device:
if is_numpy_array(x):
return "cpu"
elif is_dask_array(x):
- # Peek at the metadata of the jax array to determine type
- try:
- import numpy as np
- if isinstance(x._meta, np.ndarray):
- # Must be on CPU since backed by numpy
- return "cpu"
- except ImportError:
- pass
+ # Peek at the metadata of the Dask array to determine type
+ if is_numpy_array(x._meta):
+ # Must be on CPU since backed by numpy
+ return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
- # JAX has .device() as a method, but it is being deprecated so that it
- # can become a property, in accordance with the standard. In order for
- # this function to not break when JAX makes the flip, we check for
- # both here.
- if inspect.ismethod(x.device):
- return x.device()
+ # FIXME Jitted JAX arrays do not have a device attribute
+ # https://github.com/jax-ml/jax/issues/26000
+ # Return None in this case. Note that this workaround breaks
+ # the standard and will result in new arrays being created on the
+ # default device instead of the same device as the input array(s).
+ x_device = getattr(x, "device", None)
+ # Older JAX releases had .device() as a method, which has been replaced
+ # with a property in accordance with the standard.
+ if inspect.ismethod(x_device):
+ return x_device()
else:
- return x.device
+ return x_device
elif is_pydata_sparse_array(x):
# `sparse` will gain `.device`, so check for this first.
- x_device = getattr(x, 'device', None)
+ x_device = getattr(x, "device", None)
if x_device is not None:
return x_device
# Everything but DOK has this attr.
try:
- inner = x.data
+ inner = x.data # pyright: ignore
except AttributeError:
return "cpu"
# Return the device of the constituent array
- return device(inner)
+ return device(inner) # pyright: ignore
elif is_paddle_array(x):
raw_place_str = str(x.place)
if "gpu_pinned" in raw_place_str:
@@ -725,65 +831,64 @@ def device(x: Array, /) -> Device:
elif "gpu" in raw_place_str:
return "gpu"
raise ValueError(f"Unsupported Paddle device: {x.place}")
+ return x.device # type: ignore # pyright: ignore
- return x.device
# Prevent shadowing, used below
_device = device
+
# Based on cupy.array_api.Array.to_device
-def _cupy_to_device(x, device, /, stream=None):
+def _cupy_to_device(
+ x: cp.ndarray,
+ device: Device,
+ /,
+ stream: int | Any | None = None,
+) -> cp.ndarray:
import cupy as cp
- from cupy.cuda import Device as _Device
- from cupy.cuda import stream as stream_module
- from cupy_backends.cuda.api import runtime
- if device == x.device:
- return x
- elif device == "cpu":
+ if device == "cpu":
# allowing us to use `to_device(x, "cpu")`
# is useful for portable test swapping between
# host and device backends
return x.get()
- elif not isinstance(device, _Device):
- raise ValueError(f"Unsupported device {device!r}")
- else:
- # see cupy/cupy#5985 for the reason how we handle device/stream here
- prev_device = runtime.getDevice()
- prev_stream: stream_module.Stream = None
- if stream is not None:
- prev_stream = stream_module.get_current_stream()
- # stream can be an int as specified in __dlpack__, or a CuPy stream
- if isinstance(stream, int):
- stream = cp.cuda.ExternalStream(stream)
- elif isinstance(stream, cp.cuda.Stream):
- pass
- else:
- raise ValueError('the input stream is not recognized')
- stream.use()
- try:
- runtime.setDevice(device.id)
- arr = x.copy()
- finally:
- runtime.setDevice(prev_device)
- if stream is not None:
- prev_stream.use()
- return arr
-
-def _torch_to_device(x, device, /, stream=None):
+ if not isinstance(device, cp.cuda.Device):
+ raise TypeError(f"Unsupported device type {device!r}")
+
+ if stream is None:
+ with device:
+ return cp.asarray(x)
+
+ # stream can be an int as specified in __dlpack__, or a CuPy stream
+ if isinstance(stream, int):
+ stream = cp.cuda.ExternalStream(stream)
+ elif not isinstance(stream, cp.cuda.Stream):
+ raise TypeError(f"Unsupported stream type {stream!r}")
+
+ with device, stream:
+ return cp.asarray(x)
+
+
+def _torch_to_device(
+ x: torch.Tensor,
+ device: torch.device | str | int,
+ /,
+ stream: int | Any | None = None,
+) -> torch.Tensor:
if stream is not None:
raise NotImplementedError
return x.to(device)
-def _paddle_to_device(x, device, /, stream=None):
+
+def _paddle_to_device(x: paddle.Tensor, device, /, stream=None):
if stream is not None:
raise NotImplementedError(
- "paddle.Tensor.to() do not support stream argument yet"
+ "paddle.Tensor.to() do not support 'stream' argument yet"
)
- return x.to(device)
+ return x.to(device=device)
-def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
+def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -803,7 +908,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
a ``device`` object (see the `Device Support `__
section of the array API specification).
- stream: Optional[Union[int, Any]]
+ stream: int | Any | None
stream object to use during copy. In addition to the types supported
in ``array.__dlpack__``, implementations may choose to support any
library-specific stream object with the caveat that any code using
@@ -835,7 +940,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
if is_numpy_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
- if device == 'cpu':
+ if device == "cpu":
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_cupy_array(x):
@@ -847,13 +952,17 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
- if device == 'cpu':
+ if device == "cpu":
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
if not hasattr(x, "__array_namespace__"):
- # In JAX v0.4.31 and older, this import adds to_device method to x.
- import jax.experimental.array_api # noqa: F401
+ # In JAX v0.4.31 and older, this import adds to_device method to x...
+ import jax.experimental.array_api # noqa: F401 # pyright: ignore
+
+ # ... but only on eager JAX. It won't work inside jax.jit.
+ if not hasattr(x, "to_device"):
+ return x
return x.to_device(device, stream=stream)
elif is_paddle_array(x):
return _paddle_to_device(x, device, stream=stream)
@@ -861,21 +970,140 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# Perform trivial check to return the same array if
# device is same instead of err-ing.
return x
- return x.to_device(device, stream=stream)
+ return x.to_device(device, stream=stream) # pyright: ignore
+
-def size(x):
+@overload
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+@overload
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
"""
Return the total number of elements of x.
This is equivalent to `x.size` according to the `standard
`__.
+
This helper is included because PyTorch defines `size` in an
:external+torch:meth:`incompatible way `.
-
+ It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
+ the standard requires None.
"""
+ # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
if None in x.shape:
return None
- return math.prod(x.shape)
+ out = math.prod(cast("Collection[SupportsIndex]", x.shape))
+ # dask.array.Array.shape can contain NaN
+ return None if math.isnan(out) else out
+
+
+@lru_cache(100)
+def _is_writeable_cls(cls: type) -> bool | None:
+ if (
+ _issubclass_fast(cls, "numpy", "generic")
+ or _issubclass_fast(cls, "jax", "Array")
+ or _issubclass_fast(cls, "sparse", "SparseArray")
+ ):
+ return False
+ if _is_array_api_cls(cls):
+ return True
+ return None
+
+
+def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
+ """
+ Return False if ``x.__setitem__`` is expected to raise; True otherwise.
+ Return False if `x` is not an array API compatible object.
+
+ Warning
+ -------
+ As there is no standard way to check if an array is writeable without actually
+ writing to it, this function blindly returns True for all unknown array types.
+ """
+ cls = cast(Hashable, type(x))
+ if _issubclass_fast(cls, "numpy", "ndarray"):
+ return cast("npt.NDArray", x).flags.writeable
+ res = _is_writeable_cls(cls)
+ if res is not None:
+ return res
+ return hasattr(x, '__array_namespace__')
+
+
+@lru_cache(100)
+def _is_lazy_cls(cls: type) -> bool | None:
+ if (
+ _issubclass_fast(cls, "numpy", "ndarray")
+ or _issubclass_fast(cls, "numpy", "generic")
+ or _issubclass_fast(cls, "cupy", "ndarray")
+ or _issubclass_fast(cls, "torch", "Tensor")
+ or _issubclass_fast(cls, "paddle", "Tensor")
+ or _issubclass_fast(cls, "sparse", "SparseArray")
+ ):
+ return False
+ if (
+ _issubclass_fast(cls, "jax", "Array")
+ or _issubclass_fast(cls, "dask.array", "Array")
+ or _issubclass_fast(cls, "ndonnx", "Array")
+ ):
+ return True
+ return None
+
+
+def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
+ """Return True if x is potentially a future or it may be otherwise impossible or
+ expensive to eagerly read its contents, regardless of their size, e.g. by
+ calling ``bool(x)`` or ``float(x)``.
+
+ Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
+ cheap as long as the array has the right dtype and size.
+
+ Note
+ ----
+ This function errs on the side of caution for array types that may or may not be
+ lazy, e.g. JAX arrays, by always returning True for them.
+ """
+ # **JAX note:** while it is possible to determine if you're inside or outside
+ # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
+ # as we do below for unknown arrays, this is not recommended by JAX best practices.
+
+ # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
+ # This behaviour, while impossible to change without breaking backwards
+ # compatibility, is highly detrimental to performance as the whole graph will end
+ # up being computed multiple times.
+
+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
+ # exclusively get them once they leave a jax.grad JIT context.
+ cls = cast(Hashable, type(x))
+ res = _is_lazy_cls(cls)
+ if res is not None:
+ return res
+
+ if not hasattr(x, "__array_namespace__"):
+ return False
+
+ # Unknown Array API compatible object. Note that this test may have dire consequences
+ # in terms of performance, e.g. for a lazy object that eagerly computes the graph
+ # on __bool__ (dask is one such example, which however is special-cased above).
+
+ # Select a single point of the array
+ s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
+ if s is None:
+ return True
+ xp = array_namespace(x)
+ if s > 1:
+ x = xp.reshape(x, (-1,))[0]
+ # Cast to dtype=bool and deal with size 0 arrays
+ x = xp.any(x)
+
+ try:
+ bool(x)
+ return False
+    # The Array API standard dictates that __bool__ should raise TypeError if the
+ # output cannot be defined.
+ # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
+ except Exception:
+ return True
+
__all__ = [
"array_namespace",
@@ -899,8 +1127,11 @@ def size(x):
"is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
+ "is_writeable_array",
+ "is_lazy_array",
"size",
"to_device",
]
-_all_ignore = ['sys', 'math', 'inspect', 'warnings']
+def __dir__() -> list[str]:
+ return __all__
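[Editorial note, not part of the patch] The refactor of common/_helpers.py above keeps the public surface intact while adding type annotations, lru_cache-based dispatch, and the new is_writeable_array/is_lazy_array helpers. A minimal sketch of the entry points touched in this file, assuming a build that includes these hunks:

    # Illustrative sketch only; not part of the patch.
    import numpy as np
    from array_api_compat import (
        array_namespace, device, to_device, size,
        is_writeable_array, is_lazy_array,
    )

    x = np.asarray([1.0, 2.0, 3.0])
    xp = array_namespace(x)            # dispatches on type(x); returns the wrapped namespace

    print(xp.__name__)                 # array_api_compat.numpy
    print(device(x), size(x))          # cpu 3
    print(is_writeable_array(x))       # True  (np.ndarray with the writeable flag set)
    print(is_lazy_array(x))            # False (NumPy arrays are eager)
    print(to_device(x, "cpu") is x)    # True  (no-op for NumPy; other devices raise)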
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index bfa1f1b9..69672af7 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,85 +1,114 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, NamedTuple
-if TYPE_CHECKING:
- from typing import Literal, Optional, Tuple, Union
- from ._typing import ndarray
-
import math
+from typing import Literal, NamedTuple, cast
import numpy as np
+
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
- from numpy.core.numeric import normalize_axis_tuple
+ from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp
+from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
+from ._typing import Array, DType, JustFloat, JustInt, Namespace
+
# These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
+def cross(
+ x1: Array,
+ x2: Array,
+ /,
+ xp: Namespace,
+ *,
+ axis: int = -1,
+ **kwargs: object,
+) -> Array:
return xp.cross(x1, x2, axis=axis, **kwargs)
-def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.outer(x1, x2, **kwargs)
class EighResult(NamedTuple):
- eigenvalues: ndarray
- eigenvectors: ndarray
+ eigenvalues: Array
+ eigenvectors: Array
class QRResult(NamedTuple):
- Q: ndarray
- R: ndarray
+ Q: Array
+ R: Array
class SlogdetResult(NamedTuple):
- sign: ndarray
- logabsdet: ndarray
+ sign: Array
+ logabsdet: Array
class SVDResult(NamedTuple):
- U: ndarray
- S: ndarray
- Vh: ndarray
+ U: Array
+ S: Array
+ Vh: Array
# These functions are the same as their NumPy counterparts except they return
# a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
+def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
return EighResult(*xp.linalg.eigh(x, **kwargs))
-def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
- **kwargs) -> QRResult:
+def qr(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ mode: Literal["reduced", "complete"] = "reduced",
+ **kwargs: object,
+) -> QRResult:
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
-def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
+def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
-def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ full_matrices: bool = True,
+ **kwargs: object,
+) -> SVDResult:
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
# These functions have additional keyword arguments
# The upper keyword argument is new from NumPy
-def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
+def cholesky(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ upper: bool = False,
+ **kwargs: object,
+) -> Array:
L = xp.linalg.cholesky(x, **kwargs)
if upper:
U = get_xp(xp)(matrix_transpose)(L)
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
- U = xp.conj(U)
+ U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
return U
return L
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
# Note that it has a different semantic meaning from tol and rcond.
-def matrix_rank(x: ndarray,
- /,
- xp,
- *,
- rtol: Optional[Union[float, ndarray]] = None,
- **kwargs) -> ndarray:
+def matrix_rank(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ rtol: float | Array | None = None,
+ **kwargs: object,
+) -> Array:
# this is different from xp.linalg.matrix_rank, which supports 1
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
- S = get_xp(xp)(svdvals)(x, **kwargs)
+ S: Array = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
@@ -88,7 +117,14 @@ def matrix_rank(x: ndarray,
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
return xp.count_nonzero(S > tol, axis=-1)
-def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
+def pinv(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ rtol: float | Array | None = None,
+ **kwargs: object,
+) -> Array:
# this is different from xp.linalg.pinv, which does not multiply the
# default tolerance by max(M, N).
if rtol is None:
@@ -97,15 +133,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k
# These functions are new in the array API spec
-def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
+def matrix_norm(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ keepdims: bool = False,
+ ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
+) -> Array:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# xp.linalg.svd(compute_uv=False).
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
+def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
return xp.linalg.svd(x, compute_uv=False)
-def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
+def vector_norm(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ axis: int | tuple[int, ...] | None = None,
+ keepdims: bool = False,
+ ord: JustInt | JustFloat = 2,
+) -> Array:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
# it so the input is 1-D (for axis=None), or reshape so that norm is done
@@ -117,7 +168,10 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas
# xp.linalg.norm() only supports a single axis for vector norm.
- normalized_axis = normalize_axis_tuple(axis, x.ndim)
+ normalized_axis = cast(
+ "tuple[int, ...]",
+ normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
+ )
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
newshape = axis + rest
_x = xp.transpose(x, newshape).reshape(
@@ -133,8 +187,14 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
- _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
- for i in _axis:
+ axes = cast(
+ "tuple[int, ...]",
+ normalize_axis_tuple( # pyright: ignore[reportCallIssue]
+ range(x.ndim) if axis is None else axis,
+ x.ndim,
+ ),
+ )
+ for i in axes:
shape[i] = 1
res = xp.reshape(res, tuple(shape))
@@ -143,14 +203,28 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
# xp.diagonal and xp.trace operate on the first two axes whereas these
# operates on the last two
-def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
-def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
- return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
+def trace(
+ x: Array,
+ /,
+ xp: Namespace,
+ *,
+ offset: int = 0,
+ dtype: DType | None = None,
+ **kwargs: object,
+) -> Array:
+ return xp.asarray(
+ xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
+ )
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
+
+
+def __dir__() -> list[str]:
+ return __all__
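As a quick sanity check of the axis-tuple handling above, a hedged sketch using the compat NumPy namespace (assuming vector_norm is re-exported through array_api_compat.numpy.linalg as in the existing wrappers):

>>> import array_api_compat.numpy as xp
>>> x = xp.reshape(xp.arange(24, dtype=xp.float64), (2, 3, 4))
>>> xp.linalg.vector_norm(x, axis=(1, 2)).shape
(2,)
>>> xp.linalg.vector_norm(x, axis=(1, 2), keepdims=True).shape
(2, 1, 1)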
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 07f3850d..11b00bd1 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,23 +1,189 @@
from __future__ import annotations
-__all__ = [
- "NestedSequence",
- "SupportsBufferProtocol",
-]
-
+from collections.abc import Mapping
+from types import ModuleType as Namespace
from typing import (
- Any,
- TypeVar,
+ TYPE_CHECKING,
+ Literal,
Protocol,
+ TypeAlias,
+ TypedDict,
+ TypeVar,
+ final,
)
+if TYPE_CHECKING:
+ from _typeshed import Incomplete
+
+ SupportsBufferProtocol: TypeAlias = Incomplete
+ Array: TypeAlias = Incomplete
+ Device: TypeAlias = Incomplete
+ DType: TypeAlias = Incomplete
+else:
+ SupportsBufferProtocol = object
+ Array = object
+ Device = object
+ DType = object
+
+
_T_co = TypeVar("_T_co", covariant=True)
+
+# These "Just" types are equivalent to the `Just` type from the `optype` library,
+# apart from them not being `@runtime_checkable`.
+# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
+# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
+@final
+class JustInt(Protocol): # type: ignore[misc]
+ @property # type: ignore[override]
+ def __class__(self, /) -> type[int]: ...
+ @__class__.setter
+ def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustFloat(Protocol): # type: ignore[misc]
+ @property # type: ignore[override]
+ def __class__(self, /) -> type[float]: ...
+ @__class__.setter
+ def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustComplex(Protocol): # type: ignore[misc]
+ @property # type: ignore[override]
+ def __class__(self, /) -> type[complex]: ...
+ @__class__.setter
+ def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
+
+
class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
-SupportsBufferProtocol = Any
-Array = Any
-Device = Any
+class SupportsArrayNamespace(Protocol[_T_co]):
+ def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+ @property
+ def shape(self, /) -> _T_co: ...
+
+
+# Return type of `__array_namespace_info__.capabilities`
+Capabilities = TypedDict(
+ "Capabilities",
+ {
+ "boolean indexing": bool,
+ "data-dependent shapes": bool,
+ "max dimensions": int,
+ },
+)
+
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+ "DefaultDTypes",
+ {
+ "real floating": DType,
+ "complex floating": DType,
+ "integral": DType,
+ "indexing": DType,
+ },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+ "bool",
+ "signed integer",
+ "unsigned integer",
+ "integral",
+ "real floating",
+ "complex floating",
+ "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
+
+
+# `__array_namespace_info__.dtypes(kind="bool")`
+class DTypesBool(TypedDict):
+ bool: DType
+
+
+# `__array_namespace_info__.dtypes(kind="signed integer")`
+class DTypesSigned(TypedDict):
+ int8: DType
+ int16: DType
+ int32: DType
+ int64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="unsigned integer")`
+class DTypesUnsigned(TypedDict):
+ uint8: DType
+ uint16: DType
+ uint32: DType
+ uint64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="integral")`
+class DTypesIntegral(DTypesSigned, DTypesUnsigned):
+ pass
+
+
+# `__array_namespace_info__.dtypes(kind="real floating")`
+class DTypesReal(TypedDict):
+ float32: DType
+ float64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="complex floating")`
+class DTypesComplex(TypedDict):
+ complex64: DType
+ complex128: DType
+
+
+# `__array_namespace_info__.dtypes(kind="numeric")`
+class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
+ pass
+
+
+# `__array_namespace_info__.dtypes(kind=None)` (default)
+class DTypesAll(DTypesBool, DTypesNumeric):
+ pass
+
+
+# `__array_namespace_info__.dtypes(kind=?)` (fallback)
+DTypesAny: TypeAlias = Mapping[str, DType]
+
+
+__all__ = [
+ "Array",
+ "Capabilities",
+ "DType",
+ "DTypeKind",
+ "DTypesAny",
+ "DTypesAll",
+ "DTypesBool",
+ "DTypesNumeric",
+ "DTypesIntegral",
+ "DTypesSigned",
+ "DTypesUnsigned",
+ "DTypesReal",
+ "DTypesComplex",
+ "DefaultDTypes",
+ "Device",
+ "HasShape",
+ "Namespace",
+ "JustInt",
+ "JustFloat",
+ "JustComplex",
+ "NestedSequence",
+ "SupportsArrayNamespace",
+ "SupportsBufferProtocol",
+]
+
+
+def __dir__() -> list[str]:
+ return __all__
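For context, a small illustrative sketch of how the Just* protocols are meant to be consumed in annotations such as vector_norm's ord parameter: at runtime the values are ordinary ints/floats, while a static checker treats JustInt as "exactly int" and would reject bool. The helper below is hypothetical, not part of the library:

>>> from array_api_compat.common._typing import JustFloat, JustInt
>>> def describe_ord(ord: JustInt | JustFloat) -> str:
...     # At runtime `ord` is a plain int or float; Just* only constrains
...     # static type checking (e.g. a bool argument would be flagged).
...     return f"ord={ord} ({type(ord).__name__})"
>>> describe_ord(2)
'ord=2 (int)'
>>> describe_ord(2.5)
'ord=2.5 (float)'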
diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py
index d8685761..af003c5a 100644
--- a/array_api_compat/cupy/__init__.py
+++ b/array_api_compat/cupy/__init__.py
@@ -1,3 +1,4 @@
+from typing import Final
from cupy import * # noqa: F403
# from cupy import * doesn't overwrite these builtin names
@@ -5,12 +6,19 @@
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
+from ._info import __array_namespace_info__ # noqa: F401
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
-
__import__(__package__ + '.fft')
-from ..common._helpers import * # noqa: F401,F403
+__array_api_version__: Final = '2024.12'
+
+__all__ = sorted(
+ {name for name in globals() if not name.startswith("__")}
+ - {"Final", "_aliases", "_info", "_typing"}
+ | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
+)
-__array_api_version__ = '2023.12'
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 3627fb6b..2e512fc8 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -1,16 +1,13 @@
from __future__ import annotations
+from builtins import bool as py_bool
+
import cupy as cp
-from ..common import _aliases
+from ..common import _aliases, _helpers
+from ..common._typing import NestedSequence, SupportsBufferProtocol
from .._internal import get_xp
-
-from ._info import __array_namespace_info__
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from typing import Optional, Union
- from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
+from ._typing import Array, Device, DType
bool = cp.bool_
@@ -46,43 +43,34 @@
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
-astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
-ceil = get_xp(cp)(_aliases.ceil)
-floor = get_xp(cp)(_aliases.floor)
-trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)
+finfo = get_xp(cp)(_aliases.finfo)
+iinfo = get_xp(cp)(_aliases.iinfo)
-_copy_default = object()
# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
- obj: Union[
- ndarray,
- bool,
- int,
- float,
- NestedSequence[bool | int | float],
- SupportsBufferProtocol,
- ],
+ obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- copy: Optional[bool] = _copy_default,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ copy: py_bool | None = None,
+ **kwargs: object,
+) -> Array:
"""
Array API compatibility wrapper for asarray().
@@ -90,25 +78,66 @@ def asarray(
specification for more details.
"""
with cp.cuda.Device(device):
- # cupy is like NumPy 1.26 (except without _CopyMode). See the comments
- # in asarray in numpy/_aliases.py.
- if copy is not _copy_default:
- # A future version of CuPy will change the meaning of copy=False
- # to mean no-copy. We don't know for certain what version it will
- # be yet, so to avoid breaking that version, we use a different
- # default value for copy so asarray(obj) with no copy kwarg will
- # always do the copy-if-needed behavior.
-
- # This will still need to be updated to remove the
- # NotImplementedError for copy=False, but at least this won't
- # break the default or existing behavior.
- if copy is None:
- copy = False
- elif copy is False:
- raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
- kwargs['copy'] = copy
-
- return cp.array(obj, dtype=dtype, **kwargs)
+ if copy is None:
+ return cp.asarray(obj, dtype=dtype, **kwargs)
+ else:
+ res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
+ if not copy and res is not obj:
+ raise ValueError("Unable to avoid copy while creating an array as requested")
+ return res
+
+
+def astype(
+ x: Array,
+ dtype: DType,
+ /,
+ *,
+ copy: py_bool = True,
+ device: Device | None = None,
+) -> Array:
+ if device is None:
+ return x.astype(dtype=dtype, copy=copy)
+ out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
+ return out.copy() if copy and out is x else out
+
+
+# cupy.count_nonzero does not have keepdims
+def count_nonzero(
+ x: Array,
+ axis: int | tuple[int, ...] | None = None,
+ keepdims: py_bool = False,
+) -> Array:
+ result = cp.count_nonzero(x, axis)
+ if keepdims:
+ if axis is None:
+ return cp.reshape(result, [1]*x.ndim)
+ return cp.expand_dims(result, axis)
+ return result
+
+# ceil, floor, and trunc return integers for integer inputs
+
+def ceil(x: Array, /) -> Array:
+ if cp.issubdtype(x.dtype, cp.integer):
+ return x.copy()
+ return cp.ceil(x)
+
+
+def floor(x: Array, /) -> Array:
+ if cp.issubdtype(x.dtype, cp.integer):
+ return x.copy()
+ return cp.floor(x)
+
+
+def trunc(x: Array, /) -> Array:
+ if cp.issubdtype(x.dtype, cp.integer):
+ return x.copy()
+ return cp.trunc(x)
+
+
+# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
+ return cp.take_along_axis(x, indices, axis=axis)
+
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
@@ -127,10 +156,13 @@ def asarray(
else:
unstack = get_xp(cp)(_aliases.unstack)
-__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
+__all__ = _aliases.__all__ + ['asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
- 'concat', 'pow', 'sign']
+ 'bool', 'concat', 'count_nonzero', 'pow', 'sign',
+ 'ceil', 'floor', 'trunc', 'take_along_axis']
+
-_all_ignore = ['cp', 'get_xp']
+def __dir__() -> list[str]:
+ return __all__
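A hedged usage sketch for the new CuPy wrappers (requires a CUDA-capable CuPy install; outputs are indicative only):

>>> import array_api_compat.cupy as xp   # needs a CUDA device
>>> x = xp.asarray([[0, 1], [2, 0]])
>>> xp.count_nonzero(x, keepdims=True).shape
(1, 1)
>>> xp.ceil(xp.asarray([1, 2])).dtype    # integer input is passed through unchanged
dtype('int64')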
diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index 4440807d..78e48a33 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -26,6 +26,7 @@
complex128,
)
+
class __array_namespace_info__:
"""
Get the array API inspection namespace for CuPy.
@@ -49,7 +50,7 @@ class __array_namespace_info__:
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
@@ -94,14 +95,14 @@ def capabilities(self):
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
- 'data-dependent shapes': True}
+ 'data-dependent shapes': True,
+ 'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
- # 'max rank' will be part of the 2024.12 standard
- # "max rank": 64,
+ "max dimensions": 64,
}
def default_device(self):
@@ -117,7 +118,7 @@ def default_device(self):
Returns
-------
- device : str
+ device : Device
The default device used for new CuPy arrays.
Examples
@@ -126,6 +127,15 @@ def default_device(self):
>>> info.default_device()
Device(0)
+ Notes
+ -----
+ This method returns the static default device when CuPy is initialized.
+ However, the *current* device used by creation functions (``empty`` etc.)
+ can be changed globally or with a context manager.
+
+ See Also
+ --------
+ https://github.com/data-apis/array-api/issues/835
"""
return cuda.Device(0)
@@ -312,7 +322,7 @@ def devices(self):
Returns
-------
- devices : list of str
+ devices : list[Device]
The devices supported by CuPy.
See Also
diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py
index f3d9aab6..e5c202dc 100644
--- a/array_api_compat/cupy/_typing.py
+++ b/array_api_compat/cupy/_typing.py
@@ -1,46 +1,30 @@
from __future__ import annotations
-__all__ = [
- "ndarray",
- "Device",
- "Dtype",
-]
+__all__ = ["Array", "DType", "Device"]
-import sys
-from typing import (
- Union,
- TYPE_CHECKING,
-)
-
-from cupy import (
- ndarray,
- dtype,
- int8,
- int16,
- int32,
- int64,
- uint8,
- uint16,
- uint32,
- uint64,
- float32,
- float64,
-)
+from typing import TYPE_CHECKING
+import cupy as cp
+from cupy import ndarray as Array
from cupy.cuda.device import Device
-if TYPE_CHECKING or sys.version_info >= (3, 9):
- Dtype = dtype[Union[
- int8,
- int16,
- int32,
- int64,
- uint8,
- uint16,
- uint32,
- uint64,
- float32,
- float64,
- ]]
+if TYPE_CHECKING:
+ # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
+ DType = cp.dtype[
+ cp.intp
+ | cp.int8
+ | cp.int16
+ | cp.int32
+ | cp.int64
+ | cp.uint8
+ | cp.uint16
+ | cp.uint32
+ | cp.uint64
+ | cp.float32
+ | cp.float64
+ | cp.complex64
+ | cp.complex128
+ | cp.bool_
+ ]
else:
- Dtype = dtype
+ DType = cp.dtype
diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py
index 307e0f72..53a9a454 100644
--- a/array_api_compat/cupy/fft.py
+++ b/array_api_compat/cupy/fft.py
@@ -1,10 +1,11 @@
-from cupy.fft import * # noqa: F403
+from cupy.fft import * # noqa: F403
+
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as fft_all
-_n = {}
-exec('from cupy.fft import *', _n)
-del _n['__builtins__']
+_n: dict[str, object] = {}
+exec("from cupy.fft import *", _n)
+del _n["__builtins__"]
fft_all = list(_n)
del _n
@@ -30,7 +31,6 @@
__all__ = fft_all + _fft.__all__
-del get_xp
-del cp
-del fft_all
-del _fft
+def __dir__() -> list[str]:
+ return __all__
+
diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py
index 7fcdd498..da301574 100644
--- a/array_api_compat/cupy/linalg.py
+++ b/array_api_compat/cupy/linalg.py
@@ -2,7 +2,7 @@
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
-_n = {}
+_n: dict[str, object] = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
@@ -43,7 +43,5 @@
__all__ = linalg_all + _linalg.__all__
-del get_xp
-del cp
-del linalg_all
-del _linalg
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py
index b49be6cf..f78aa8b3 100644
--- a/array_api_compat/dask/array/__init__.py
+++ b/array_api_compat/dask/array/__init__.py
@@ -1,9 +1,26 @@
-from dask.array import * # noqa: F403
+from typing import Final
+
+from ..._internal import clone_module
+
+__all__ = clone_module("dask.array", globals())
# These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from . import _aliases
+from ._aliases import * # type: ignore[assignment] # noqa: F403
+from ._info import __array_namespace_info__ # noqa: F401
-__array_api_version__ = '2023.12'
+__array_api_version__: Final = "2024.12"
+del Final
+# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
+
+__all__ = sorted(
+ set(__all__)
+ | set(_aliases.__all__)
+ | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
+)
+
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index ee2d88c0..54d323b2 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -1,66 +1,101 @@
-from __future__ import annotations
+# pyright: reportPrivateUsage=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
-from ...common import _aliases
-from ...common._helpers import _check_device
+from __future__ import annotations
-from ..._internal import get_xp
+from builtins import bool as py_bool
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
-from ._info import __array_namespace_info__
+if TYPE_CHECKING:
+ from typing_extensions import TypeIs
+import dask.array as da
import numpy as np
+from numpy import bool_ as bool
from numpy import (
- # Dtypes
- iinfo,
- finfo,
- bool_ as bool,
+ can_cast,
+ complex64,
+ complex128,
float32,
float64,
int8,
int16,
int32,
int64,
+ result_type,
uint8,
uint16,
uint32,
uint64,
- complex64,
- complex128,
- can_cast,
- result_type,
)
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from typing import Optional, Union
-
- from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
-
-import dask.array as da
+from ..._internal import get_xp
+from ...common import _aliases, _helpers, array_namespace
+from ...common._typing import (
+ Array,
+ Device,
+ DType,
+ NestedSequence,
+ SupportsBufferProtocol,
+)
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)
-astype = _aliases.astype
+
+
+# da.astype doesn't respect copy=True
+def astype(
+ x: Array,
+ dtype: DType,
+ /,
+ *,
+ copy: py_bool = True,
+ device: Device | None = None,
+) -> Array:
+ """
+ Array API compatibility wrapper for astype().
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+ # TODO: respect device keyword?
+ _helpers._check_device(da, device)
+
+ if not copy and dtype == x.dtype:
+ return x
+ x = x.astype(dtype)
+ return x.copy() if copy else x
+
# Common aliases
+
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
-
-# TODO: delete the xp stuff, it shouldn't be necessary
-def _dask_arange(
- start: Union[int, float],
+def arange(
+ start: float,
/,
- stop: Optional[Union[int, float]] = None,
- step: Union[int, float] = 1,
+ stop: float | None = None,
+ step: float = 1,
*,
- xp,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object,
) -> Array:
- _check_device(xp, device)
- args = [start]
+ """
+ Array API compatibility wrapper for arange().
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+ # TODO: respect device keyword?
+ _helpers._check_device(da, device)
+
+ args: list[Any] = [start]
if stop is not None:
args.append(stop)
else:
@@ -68,13 +103,12 @@ def _dask_arange(
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
- return xp.arange(*args, dtype=dtype, **kwargs)
-arange = get_xp(da)(_dask_arange)
-eye = get_xp(da)(_aliases.eye)
+ return da.arange(*args, dtype=dtype, **kwargs)
+
-linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
+linspace = get_xp(da)(_aliases.linspace)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
@@ -86,6 +120,7 @@ def _dask_arange(
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
@@ -97,31 +132,23 @@ def _dask_arange(
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
-
nonzero = get_xp(da)(_aliases.nonzero)
-ceil = get_xp(np)(_aliases.ceil)
-floor = get_xp(np)(_aliases.floor)
-trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
+finfo = get_xp(np)(_aliases.finfo)
+iinfo = get_xp(np)(_aliases.iinfo)
+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
- obj: Union[
- Array,
- bool,
- int,
- float,
- NestedSequence[bool | int | float],
- SupportsBufferProtocol,
- ],
+ obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- copy: "Optional[Union[bool, np._CopyMode]]" = None,
- **kwargs,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ copy: py_bool | None = None,
+ **kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for asarray().
@@ -129,89 +156,214 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
+ # TODO: respect device keyword?
+ _helpers._check_device(da, device)
+
+ if isinstance(obj, da.Array):
+ if dtype is not None and dtype != obj.dtype:
+ if copy is False:
+ raise ValueError("Unable to avoid copy when changing dtype")
+ obj = obj.astype(dtype)
+ return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
+
if copy is False:
- # copy=False is not yet implemented in dask
- raise NotImplementedError("copy=False is not yet implemented")
- elif copy is True:
- if isinstance(obj, da.Array) and dtype is None:
- return obj.copy()
- # Go through numpy, since dask copy is no-op by default
- obj = np.array(obj, dtype=dtype, copy=True)
- return da.array(obj, dtype=dtype)
- else:
- if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
- obj = np.asarray(obj, dtype=dtype)
- return da.from_array(obj)
- return obj
-
- return da.asarray(obj, dtype=dtype, **kwargs)
-
-from dask.array import (
- # Element wise aliases
- arccos as acos,
- arccosh as acosh,
- arcsin as asin,
- arcsinh as asinh,
- arctan as atan,
- arctan2 as atan2,
- arctanh as atanh,
- left_shift as bitwise_left_shift,
- right_shift as bitwise_right_shift,
- invert as bitwise_invert,
- power as pow,
- # Other
- concatenate as concat,
-)
+ raise ValueError(
+ "Unable to avoid copy when converting a non-dask object to dask"
+ )
+
+ # copy=None to be uniform across dask < 2024.12 and >= 2024.12
+ # see https://github.com/dask/dask/pull/11524/
+ obj = np.array(obj, dtype=dtype, copy=True)
+ return da.from_array(obj)
+
+
+# Element wise aliases
+from dask.array import arccos as acos
+from dask.array import arccosh as acosh
+from dask.array import arcsin as asin
+from dask.array import arcsinh as asinh
+from dask.array import arctan as atan
+from dask.array import arctan2 as atan2
+from dask.array import arctanh as atanh
+
+# Other
+from dask.array import concatenate as concat
+from dask.array import invert as bitwise_invert
+from dask.array import left_shift as bitwise_left_shift
+from dask.array import power as pow
+from dask.array import right_shift as bitwise_right_shift
+
# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
-@get_xp(da)
def clip(
x: Array,
/,
- min: Optional[Union[int, float, Array]] = None,
- max: Optional[Union[int, float, Array]] = None,
- *,
- xp,
+ min: float | Array | None = None,
+ max: float | Array | None = None,
) -> Array:
- def _isscalar(a):
- return isinstance(a, (int, float, type(None)))
+ """
+ Array API compatibility wrapper for clip().
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+
+ def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
+ return a is None or isinstance(a, (int, float))
+
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
# TODO: This won't handle dask unknown shapes
- import numpy as np
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
if min is not None:
- min = xp.broadcast_to(xp.asarray(min), result_shape)
+ min = da.broadcast_to(da.asarray(min), result_shape)
if max is not None:
- max = xp.broadcast_to(xp.asarray(max), result_shape)
+ max = da.broadcast_to(da.asarray(max), result_shape)
if min is None and max is None:
- return xp.positive(x)
+ return da.positive(x)
if min is None:
- return astype(xp.minimum(x, max), x.dtype)
+ return astype(da.minimum(x, max), x.dtype)
if max is None:
- return astype(xp.maximum(x, min), x.dtype)
+ return astype(da.maximum(x, min), x.dtype)
+
+ return astype(da.minimum(da.maximum(x, min), max), x.dtype)
+
+
+def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
+ """
+ Make sure that Array is not broken into multiple chunks along axis.
+
+ Returns
+ -------
+ x : Array
+ The input Array with a single chunk along axis.
+ restore : Callable[Array, Array]
+ function to apply to the output to rechunk it back into reasonable chunks
+ """
+ if axis < 0:
+ axis += x.ndim
+ if x.numblocks[axis] < 2:
+ return x, lambda x: x
+
+ # Break chunks on other axes in an attempt to keep chunk size low
+ x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
+
+ # Rather than reconstructing the original chunks, which can be a
+ # very expensive affair, just break down oversized chunks without
+ # incurring any transfers over the network.
+ # This has the downside of a risk of overchunking if the array is
+ # then used in operations against other arrays that match the
+ # original chunking pattern.
+ return x, lambda x: x.rechunk()
+
+
+def sort(
+ x: Array,
+ /,
+ *,
+ axis: int = -1,
+ descending: py_bool = False,
+ stable: py_bool = True,
+) -> Array:
+ """
+ Array API compatibility layer around the lack of sort() in Dask.
+
+ Warnings
+ --------
+ This function temporarily rechunks the array along `axis` to a single chunk.
+ This can be extremely inefficient and can lead to out-of-memory errors.
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+ """
+ x, restore = _ensure_single_chunk(x, axis)
- return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
+ meta_xp = array_namespace(x._meta)
+ x = da.map_blocks(
+ meta_xp.sort,
+ x,
+ axis=axis,
+ meta=x._meta,
+ dtype=x.dtype,
+ descending=descending,
+ stable=stable,
+ )
-# exclude these from all since dask.array has no sorting functions
-_da_unsupported = ['sort', 'argsort']
+ return restore(x)
+
+
+def argsort(
+ x: Array,
+ /,
+ *,
+ axis: int = -1,
+ descending: py_bool = False,
+ stable: py_bool = True,
+) -> Array:
+ """
+ Array API compatibility layer around the lack of argsort() in Dask.
+
+ See the corresponding documentation in the array library and/or the array API
+ specification for more details.
+
+ Warnings
+ --------
+ This function temporarily rechunks the array along `axis` into a single chunk.
+ This can be extremely inefficient and can lead to out-of-memory errors.
+ """
+ x, restore = _ensure_single_chunk(x, axis)
+
+ meta_xp = array_namespace(x._meta)
+ dtype = meta_xp.argsort(x._meta).dtype
+ meta = meta_xp.astype(x._meta, dtype)
+ x = da.map_blocks(
+ meta_xp.argsort,
+ x,
+ axis=axis,
+ meta=meta,
+ dtype=dtype,
+ descending=descending,
+ stable=stable,
+ )
+
+ return restore(x)
+
+
+# dask.array.count_nonzero does not have keepdims
+def count_nonzero(
+ x: Array,
+ axis: int | None = None,
+ keepdims: py_bool = False,
+) -> Array:
+ result = da.count_nonzero(x, axis)
+ if keepdims:
+ if axis is None:
+ return da.reshape(result, [1] * x.ndim)
+ return da.expand_dims(result, axis)
+ return result
-_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
-__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
- 'acosh', 'asin', 'asinh', 'atan', 'atan2',
- 'atanh', 'bitwise_left_shift', 'bitwise_invert',
- 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
- 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
- 'uint8', 'uint16', 'uint32', 'uint64',
- 'complex64', 'complex128', 'iinfo', 'finfo',
- 'can_cast', 'result_type']
+__all__ = [
+ "count_nonzero",
+ "bool",
+ "int8", "int16", "int32", "int64",
+ "uint8", "uint16", "uint32", "uint64",
+ "float32", "float64",
+ "complex64", "complex128",
+ "asarray", "astype", "can_cast", "result_type",
+ "pow",
+ "concat",
+ "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
+ "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
+] # fmt: skip
+__all__ += _aliases.__all__
-_all_ignore = ["get_xp", "da", "np"]
+def __dir__() -> list[str]:
+ return __all__
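A minimal sketch of the new sort/argsort compatibility layer (a small single-chunk input, so no rechunking actually occurs; exact reprs may differ across dask versions):

>>> import array_api_compat.dask.array as xp
>>> x = xp.asarray([3, 1, 2])
>>> xp.sort(x).compute()
array([1, 2, 3])
>>> xp.argsort(x, descending=True).compute()
array([0, 2, 1])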
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index d3b12dc9..2f39fc4b 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -7,25 +7,50 @@
more details.
"""
+
+# pyright: reportPrivateUsage=false
+
+from __future__ import annotations
+
+from typing import Literal, TypeAlias, overload
+
+import dask.array as da
+from numpy import bool_ as bool
from numpy import (
+ complex64,
+ complex128,
dtype,
- bool_ as bool,
- intp,
+ float32,
+ float64,
int8,
int16,
int32,
int64,
+ intp,
uint8,
uint16,
uint32,
uint64,
- float32,
- float64,
- complex64,
- complex128,
)
-from ...common._helpers import _DASK_DEVICE
+from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device
+from ...common._typing import (
+ Capabilities,
+ DefaultDTypes,
+ DType,
+ DTypeKind,
+ DTypesAll,
+ DTypesAny,
+ DTypesBool,
+ DTypesComplex,
+ DTypesIntegral,
+ DTypesNumeric,
+ DTypesReal,
+ DTypesSigned,
+ DTypesUnsigned,
+)
+Device: TypeAlias = Literal["cpu"] | _dask_device
+
class __array_namespace_info__:
"""
@@ -50,7 +75,7 @@ class __array_namespace_info__:
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
@@ -59,20 +84,31 @@ class __array_namespace_info__:
"""
- __module__ = 'dask.array'
+ __module__ = "dask.array"
- def capabilities(self):
+ def capabilities(self) -> Capabilities:
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
- supports boolean indexing. Always ``False`` for Dask.
+ supports boolean indexing.
+
+ Dask supports boolean indexing as long as both the index
+ and the indexed arrays have known shapes.
+ Note however that the output .shape and .size properties
+ will contain a non-compliant math.nan instead of None.
- **"data-dependent shapes"**: boolean indicating whether an array
- library supports data-dependent output shapes. Always ``False`` for
- Dask.
+ library supports data-dependent output shapes.
+
+ Dask implements unique_values et al.
+ Note however that the output .shape and .size properties
+ will contain a non-compliant math.nan instead of None.
+
+ - **"max dimensions"**: integer indicating the maximum number of
+ dimensions supported by the array library.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
@@ -92,20 +128,20 @@ def capabilities(self):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
- 'data-dependent shapes': True}
+ 'data-dependent shapes': True,
+ 'max dimensions': 64}
"""
return {
- "boolean indexing": False,
- "data-dependent shapes": False,
- # 'max rank' will be part of the 2024.12 standard
- # "max rank": 64,
+ "boolean indexing": True,
+ "data-dependent shapes": True,
+ "max dimensions": 64,
}
- def default_device(self):
+ def default_device(self) -> Device:
"""
The default device used for new Dask arrays.
@@ -120,19 +156,19 @@ def default_device(self):
Returns
-------
- device : str
+ device : Device
The default device used for new Dask arrays.
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_device()
'cpu'
"""
return "cpu"
- def default_dtypes(self, *, device=None):
+ def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes:
"""
The default data types used for new Dask arrays.
@@ -163,7 +199,7 @@ def default_dtypes(self, *, device=None):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
@@ -171,11 +207,7 @@ def default_dtypes(self, *, device=None):
'indexing': dask.int64}
"""
- if device not in ["cpu", _DASK_DEVICE, None]:
- raise ValueError(
- 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
- f' {device}'
- )
+ _check_device(da, device)
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
@@ -183,7 +215,41 @@ def default_dtypes(self, *, device=None):
"indexing": dtype(intp),
}
- def dtypes(self, *, device=None, kind=None):
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: None = None
+ ) -> DTypesAll: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["bool"]
+ ) -> DTypesBool: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["signed integer"]
+ ) -> DTypesSigned: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["unsigned integer"]
+ ) -> DTypesUnsigned: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["integral"]
+ ) -> DTypesIntegral: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["real floating"]
+ ) -> DTypesReal: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["complex floating"]
+ ) -> DTypesComplex: ...
+ @overload
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: Literal["numeric"]
+ ) -> DTypesNumeric: ...
+ def dtypes(
+ self, /, *, device: Device | None = None, kind: DTypeKind | None = None
+ ) -> DTypesAny:
"""
The array API data types supported by Dask.
@@ -229,7 +295,7 @@ def dtypes(self, *, device=None, kind=None):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': dask.int8,
'int16': dask.int16,
@@ -237,11 +303,7 @@ def dtypes(self, *, device=None, kind=None):
'int64': dask.int64}
"""
- if device not in ["cpu", _DASK_DEVICE, None]:
- raise ValueError(
- 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
- f' {device}'
- )
+ _check_device(da, device)
if kind is None:
return {
"bool": dtype(bool),
@@ -311,13 +373,13 @@ def dtypes(self, *, device=None, kind=None):
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
- res = {}
+ res: dict[str, DType] = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
- def devices(self):
+ def devices(self) -> list[Device]:
"""
The devices supported by Dask.
@@ -325,7 +387,7 @@ def devices(self):
Returns
-------
- devices : list of str
+ devices : list[Device]
The devices supported by Dask.
See Also
@@ -337,7 +399,7 @@ def devices(self):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.devices()
['cpu', DASK_DEVICE]
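For illustration, the updated capabilities dictionary can be queried like this (a sketch; values as declared above):

>>> import array_api_compat.dask.array as xp
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()["max dimensions"]
64
>>> info.default_device()
'cpu'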
diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py
index aebd86f7..44b68e73 100644
--- a/array_api_compat/dask/array/fft.py
+++ b/array_api_compat/dask/array/fft.py
@@ -1,12 +1,6 @@
-from dask.array.fft import * # noqa: F403
-# dask.array.fft doesn't have __all__. If it is added, replace this with
-#
-# from dask.array.fft import __all__ as linalg_all
-_n = {}
-exec('from dask.array.fft import *', _n)
-del _n['__builtins__']
-fft_all = list(_n)
-del _n
+from ..._internal import clone_module
+
+__all__ = clone_module("dask.array.fft", globals())
from ...common import _fft
from ..._internal import get_xp
@@ -16,9 +10,7 @@
fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)
-__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"]
+__all__ += ["fftfreq", "rfftfreq"]
-del get_xp
-del da
-del fft_all
-del _fft
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index 49c26d8b..6b3c1011 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -1,33 +1,20 @@
from __future__ import annotations
-from ...common import _linalg
-from ..._internal import get_xp
+from typing import Literal
-# Exports
-from dask.array.linalg import * # noqa: F403
-from dask.array import outer
+import dask.array as da
-# These functions are in both the main and linalg namespaces
-from dask.array import matmul, tensordot
-from ._aliases import matrix_transpose, vecdot
+# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
+from dask.array import matmul, outer, tensordot
-import dask.array as da
+# Exports
+from ..._internal import clone_module, get_xp
+from ...common import _linalg
+from ...common._typing import Array
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from ...common._typing import Array
- from typing import Literal
+__all__ = clone_module("dask.array.linalg", globals())
-# dask.array.linalg doesn't have __all__. If it is added, replace this with
-#
-# from dask.array.linalg import __all__ as linalg_all
-_n = {}
-exec('from dask.array.linalg import *', _n)
-del _n['__builtins__']
-if 'annotations' in _n:
- del _n['annotations']
-linalg_all = list(_n)
-del _n
+from ._aliases import matrix_transpose, vecdot
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
@@ -37,8 +24,11 @@
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
-def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
- **kwargs) -> QRResult:
+def qr( # type: ignore[no-redef]
+ x: Array,
+ mode: Literal["reduced", "complete"] = "reduced",
+ **kwargs: object,
+) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
@@ -51,7 +41,7 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
# Wrap the svd functions to not pass full_matrices to dask
# when full_matrices=False (as that is the default behavior for dask),
# and dask doesn't have the full_matrices keyword
-def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef]
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, coerce_signs=False, **kwargs)
@@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array:
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
-__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
- "matrix_transpose", "vecdot", "EighResult",
- "QRResult", "SlogdetResult", "SVDResult", "qr",
- "cholesky", "matrix_rank", "matrix_norm", "svdvals",
- "vector_norm", "diagonal"]
+__all__ += ["trace", "outer", "matmul", "tensordot",
+ "matrix_transpose", "vecdot", "EighResult",
+ "QRResult", "SlogdetResult", "SVDResult", "qr",
+ "cholesky", "matrix_rank", "matrix_norm", "svdvals",
+ "vector_norm", "diagonal"]
-_all_ignore = ['get_xp', 'da', 'linalg_all']
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index 9bdbf312..23379e44 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -1,10 +1,17 @@
-from numpy import * # noqa: F403
+# ruff: noqa: PLC0414
+from typing import Final
-# from numpy import * doesn't overwrite these builtin names
-from numpy import abs, max, min, round # noqa: F401
+from .._internal import clone_module
+
+# This needs to be loaded explicitly before cloning
+import numpy.typing # noqa: F401
+
+__all__ = clone_module("numpy", globals())
# These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from . import _aliases
+from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403
+from ._info import __array_namespace_info__ # noqa: F401
# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
@@ -13,18 +20,19 @@
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
-__import__(__package__ + '.linalg')
+__import__(__package__ + ".linalg")
-__import__(__package__ + '.fft')
+__import__(__package__ + ".fft")
-from .linalg import matrix_transpose, vecdot # noqa: F401
+from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
-from ..common._helpers import * # noqa: F403
+__array_api_version__: Final = "2024.12"
-try:
- # Used in asarray(). Not present in older versions.
- from numpy import _CopyMode # noqa: F401
-except ImportError:
- pass
+__all__ = sorted(
+ set(__all__)
+ | set(_aliases.__all__)
+ | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
+)
-__array_api_version__ = '2023.12'
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 2bfc98ff..87b3c2f3 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -1,17 +1,16 @@
+# pyright: reportPrivateUsage=false
from __future__ import annotations
-from ..common import _aliases
+from builtins import bool as py_bool
+from typing import Any, cast
-from .._internal import get_xp
-
-from ._info import __array_namespace_info__
+import numpy as np
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from typing import Optional, Union
- from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
+from .._internal import get_xp
+from ..common import _aliases, _helpers
+from ..common._typing import NestedSequence, SupportsBufferProtocol
+from ._typing import Array, Device, DType
-import numpy as np
bool = np.bool_
# Basic renames
@@ -46,96 +45,147 @@
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
-astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
clip = get_xp(np)(_aliases.clip)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
-ceil = get_xp(np)(_aliases.ceil)
-floor = get_xp(np)(_aliases.floor)
-trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
+finfo = get_xp(np)(_aliases.finfo)
+iinfo = get_xp(np)(_aliases.iinfo)
-def _supports_buffer_protocol(obj):
- try:
- memoryview(obj)
- except TypeError:
- return False
- return True
# asarray also adds the copy keyword, which is not present in numpy 1.0.
# asarray() is different enough between numpy, cupy, and dask, the logic
# complicated enough that it's easier to define it separately for each module
# rather than trying to combine everything into one function in common/
def asarray(
- obj: Union[
- ndarray,
- bool,
- int,
- float,
- NestedSequence[bool | int | float],
- SupportsBufferProtocol,
- ],
+ obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- copy: "Optional[Union[bool, np._CopyMode]]" = None,
- **kwargs,
-) -> ndarray:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ copy: py_bool | None = None,
+ **kwargs: Any,
+) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
- if device not in ["cpu", None]:
- raise ValueError(f"Unsupported device for NumPy: {device!r}")
-
- if hasattr(np, '_CopyMode'):
- if copy is None:
- copy = np._CopyMode.IF_NEEDED
- elif copy is False:
- copy = np._CopyMode.NEVER
- elif copy is True:
- copy = np._CopyMode.ALWAYS
- else:
- # Not present in older NumPys. In this case, we cannot really support
- # copy=False.
- if copy is False:
- raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
+ _helpers._check_device(np, device)
+
+ # None is unsupported in NumPy 1.0, but we can use an internal enum
+ # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API
+ if copy is None:
+ copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined]
+ elif copy is False:
+ copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined]
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
+
+def astype(
+ x: Array,
+ dtype: DType,
+ /,
+ *,
+ copy: py_bool = True,
+ device: Device | None = None,
+) -> Array:
+ _helpers._check_device(np, device)
+ return x.astype(dtype=dtype, copy=copy)
+
+
+# count_nonzero returns a python int for axis=None and keepdims=False
+# https://github.com/numpy/numpy/issues/17562
+def count_nonzero(
+ x: Array,
+ axis: int | tuple[int, ...] | None = None,
+ keepdims: py_bool = False,
+) -> Array:
+ # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
+ # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
+ result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
+ if axis is None and not keepdims:
+ return np.asarray(result)
+ return result
+
+
+# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
+ return np.take_along_axis(x, indices, axis=axis)
+
+
+# ceil, floor, and trunc return integers for integer inputs in NumPy < 2
+
+def ceil(x: Array, /) -> Array:
+ if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+ return x.copy()
+ return np.ceil(x)
+
+
+def floor(x: Array, /) -> Array:
+ if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+ return x.copy()
+ return np.floor(x)
+
+
+def trunc(x: Array, /) -> Array:
+ if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
+ return x.copy()
+ return np.trunc(x)
+
+
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np, 'vecdot'):
+if hasattr(np, "vecdot"):
vecdot = np.vecdot
else:
- vecdot = get_xp(np)(_aliases.vecdot)
+ vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment]
-if hasattr(np, 'isdtype'):
+if hasattr(np, "isdtype"):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)
-if hasattr(np, 'unstack'):
+if hasattr(np, "unstack"):
unstack = np.unstack
else:
unstack = get_xp(np)(_aliases.unstack)
-__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
- 'acos', 'acosh', 'asin', 'asinh', 'atan',
- 'atan2', 'atanh', 'bitwise_left_shift',
- 'bitwise_invert', 'bitwise_right_shift',
- 'concat', 'pow']
-
-_all_ignore = ['np', 'get_xp']
+__all__ = _aliases.__all__ + [
+ "asarray",
+ "astype",
+ "acos",
+ "acosh",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "ceil",
+ "floor",
+ "trunc",
+ "bitwise_left_shift",
+ "bitwise_invert",
+ "bitwise_right_shift",
+ "bool",
+ "concat",
+ "count_nonzero",
+ "pow",
+ "take_along_axis"
+]
+
+
+def __dir__() -> list[str]:
+ return __all__
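A brief sketch of the new NumPy wrappers (assuming the compat namespace re-exports them as listed in __all__ above; outputs shown for illustration):

>>> import array_api_compat.numpy as xp
>>> xp.count_nonzero(xp.asarray([0, 1, 2]))   # 0-d array, not a Python int
array(2)
>>> a = xp.asarray([[4, 1], [2, 3]])
>>> xp.take_along_axis(a, xp.argsort(a))      # axis defaults to -1
array([[1, 4],
       [2, 3]])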
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 62f7ae62..c625c13e 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,24 +7,29 @@
more details.
"""
+from __future__ import annotations
+
+from numpy import bool_ as bool
from numpy import (
+ complex64,
+ complex128,
dtype,
- bool_ as bool,
- intp,
+ float32,
+ float64,
int8,
int16,
int32,
int64,
+ intp,
uint8,
uint16,
uint32,
uint64,
- float32,
- float64,
- complex64,
- complex128,
)
+from ..common._typing import DefaultDTypes
+from ._typing import Device, DType
+
class __array_namespace_info__:
"""
@@ -94,14 +99,14 @@ def capabilities(self):
>>> info = np.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
- 'data-dependent shapes': True}
+ 'data-dependent shapes': True,
+ 'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
- # 'max rank' will be part of the 2024.12 standard
- # "max rank": 64,
+ "max dimensions": 64,
}
def default_device(self):
@@ -119,7 +124,7 @@ def default_device(self):
Returns
-------
- device : str
+ device : Device
The default device used for new NumPy arrays.
Examples
@@ -131,7 +136,11 @@ def default_device(self):
"""
return "cpu"
- def default_dtypes(self, *, device=None):
+ def default_dtypes(
+ self,
+ *,
+ device: Device | None = None,
+ ) -> DefaultDTypes:
"""
The default data types used for new NumPy arrays.
@@ -183,7 +192,12 @@ def default_dtypes(self, *, device=None):
"indexing": dtype(intp),
}
- def dtypes(self, *, device=None, kind=None):
+ def dtypes(
+ self,
+ *,
+ device: Device | None = None,
+ kind: str | tuple[str, ...] | None = None,
+ ) -> dict[str, DType]:
"""
The array API data types supported by NumPy.
@@ -260,7 +274,7 @@ def dtypes(self, *, device=None, kind=None):
"complex128": dtype(complex128),
}
if kind == "bool":
- return {"bool": bool}
+ return {"bool": dtype(bool)}
if kind == "signed integer":
return {
"int8": dtype(int8),
@@ -312,13 +326,13 @@ def dtypes(self, *, device=None, kind=None):
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
- res = {}
+ res: dict[str, DType] = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
- def devices(self):
+ def devices(self) -> list[Device]:
"""
The devices supported by NumPy.
@@ -326,7 +340,7 @@ def devices(self):
Returns
-------
- devices : list of str
+ devices : list[Device]
The devices supported by NumPy.
See Also
@@ -344,3 +358,10 @@ def devices(self):
"""
return ["cpu"]
+
+
+__all__ = ["__array_namespace_info__"]
+
+
+def __dir__() -> list[str]:
+ return __all__
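Illustration (not part of the patch): the tuple branch of dtypes() recurses once per kind and merges the per-kind dictionaries. A small usage sketch:

```python
import array_api_compat.numpy as xp

info = xp.__array_namespace_info__()
bools = info.dtypes(kind="bool")                       # {'bool': dtype('bool')}
mixed = info.dtypes(kind=("bool", "signed integer"))   # merged dictionaries
assert set(bools) <= set(mixed)
```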
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index c5ebb5ab..b5fa188c 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -1,46 +1,29 @@
from __future__ import annotations
-__all__ = [
- "ndarray",
- "Device",
- "Dtype",
-]
-
-import sys
-from typing import (
- Literal,
- Union,
- TYPE_CHECKING,
-)
-
-from numpy import (
- ndarray,
- dtype,
- int8,
- int16,
- int32,
- int64,
- uint8,
- uint16,
- uint32,
- uint64,
- float32,
- float64,
-)
-
-Device = Literal["cpu"]
-if TYPE_CHECKING or sys.version_info >= (3, 9):
- Dtype = dtype[Union[
- int8,
- int16,
- int32,
- int64,
- uint8,
- uint16,
- uint32,
- uint64,
- float32,
- float64,
- ]]
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+
+import numpy as np
+
+Device: TypeAlias = Literal["cpu"]
+
+if TYPE_CHECKING:
+
+ # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
+ DType: TypeAlias = np.dtype[
+ np.bool_
+ | np.integer[Any]
+ | np.float32
+ | np.float64
+ | np.complex64
+ | np.complex128
+ ]
+ Array: TypeAlias = np.ndarray[Any, DType]
else:
- Dtype = dtype
+ DType: TypeAlias = np.dtype
+ Array: TypeAlias = np.ndarray
+
+__all__ = ["Array", "DType", "Device"]
+
+
+def __dir__() -> list[str]:
+ return __all__
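Illustration (not part of the patch): downstream code can use these aliases for annotations. Note that _typing is a private module of the compat layer, and the function below is made up for the example.

```python
from __future__ import annotations

import numpy as np
from array_api_compat.numpy._typing import Array, Device, DType


def zeros_like_on(x: Array, device: Device = "cpu") -> Array:
    # Device is Literal["cpu"] for NumPy, so any other value is a type error
    dt: DType = x.dtype
    return np.zeros(x.shape, dtype=dt)
```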
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 28667594..a492feb8 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -1,10 +1,11 @@
-from numpy.fft import * # noqa: F403
-from numpy.fft import __all__ as fft_all
+import numpy as np
-from ..common import _fft
-from .._internal import get_xp
+from .._internal import clone_module
-import numpy as np
+__all__ = clone_module("numpy.fft", globals())
+
+from .._internal import get_xp
+from ..common import _fft
fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
@@ -21,9 +22,9 @@
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)
-__all__ = fft_all + _fft.__all__
-del get_xp
-del np
-del fft_all
-del _fft
+__all__ = sorted(set(__all__) | set(_fft.__all__))
+
+def __dir__() -> list[str]:
+ return __all__
+
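Illustration (not part of the patch): clone_module is only used here, not defined; judging from its usage it copies a module's public names into the caller's globals and returns the resulting __all__. A rough sketch of such a helper, not the actual implementation:

```python
import importlib


def clone_module_sketch(mod_name: str, globals_: dict) -> list[str]:
    """Copy a module's public names into globals_ and return them."""
    mod = importlib.import_module(mod_name)
    names = [n for n in getattr(mod, "__all__", dir(mod)) if not n.startswith("_")]
    for name in names:
        globals_[name] = getattr(mod, name)
    return names
```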
diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py
index 8f01593b..7168441c 100644
--- a/array_api_compat/numpy/linalg.py
+++ b/array_api_compat/numpy/linalg.py
@@ -1,14 +1,20 @@
-from numpy.linalg import * # noqa: F403
-from numpy.linalg import __all__ as linalg_all
-import numpy as _np
+# pyright: reportAttributeAccessIssue=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+from __future__ import annotations
+
+import numpy as np
+
+from .._internal import clone_module, get_xp
from ..common import _linalg
-from .._internal import get_xp
-# These functions are in both the main and linalg namespaces
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
+__all__ = clone_module("numpy.linalg", globals())
-import numpy as np
+# These functions are in both the main and linalg namespaces
+from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
+from ._typing import Array
cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
@@ -38,19 +44,28 @@
# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.
+
# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
-def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
+def solve(x1: Array, x2: Array, /) -> Array:
try:
- from numpy.linalg._linalg import (
- _makearray, _assert_stacked_2d, _assert_stacked_square,
- _commonType, isComplexType, _raise_linalgerror_singular
+ from numpy.linalg._linalg import ( # type: ignore[attr-defined]
+ _assert_stacked_2d,
+ _assert_stacked_square,
+ _commonType,
+ _makearray,
+ _raise_linalgerror_singular,
+ isComplexType,
)
except ImportError:
- from numpy.linalg.linalg import (
- _makearray, _assert_stacked_2d, _assert_stacked_square,
- _commonType, isComplexType, _raise_linalgerror_singular
+ from numpy.linalg.linalg import ( # type: ignore[attr-defined]
+ _assert_stacked_2d,
+ _assert_stacked_square,
+ _commonType,
+ _makearray,
+ _raise_linalgerror_singular,
+ isComplexType,
)
from numpy.linalg import _umath_linalg
@@ -61,6 +76,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
t, result_t = _commonType(x1, x2)
# This part is different from np.linalg.solve
+ gufunc: np.ufunc
if x2.ndim == 1:
gufunc = _umath_linalg.solve1
else:
@@ -68,23 +84,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
# This does nothing currently but is left in because it will be relevant
# when complex dtype support is added to the spec in 2022.
- signature = 'DD->D' if isComplexType(t) else 'dd->d'
- with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
- over='ignore', divide='ignore', under='ignore'):
- r = gufunc(x1, x2, signature=signature)
+ signature = "DD->D" if isComplexType(t) else "dd->d"
+ with np.errstate(
+ call=_raise_linalgerror_singular,
+ invalid="call",
+ over="ignore",
+ divide="ignore",
+ under="ignore",
+ ):
+ r: Array = gufunc(x1, x2, signature=signature)
return wrap(r.astype(result_t, copy=False))
+
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np.linalg, 'vector_norm'):
+if hasattr(np.linalg, "vector_norm"):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
-__all__ = linalg_all + _linalg.__all__ + ['solve']
-del get_xp
-del np
-del linalg_all
-del _linalg
+_all = [
+ "LinAlgError",
+ "cond",
+ "det",
+ "eig",
+ "eigvals",
+ "eigvalsh",
+ "inv",
+ "lstsq",
+ "matrix_power",
+ "multi_dot",
+ "norm",
+ "solve",
+ "tensorinv",
+ "tensorsolve",
+ "vector_norm",
+]
+__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all))
+
+def __dir__() -> list[str]:
+ return __all__
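Illustration (not part of the patch): solve() above routes exactly 1-D right-hand sides through the solve1 gufunc and everything else through the regular batched gufunc. A minimal usage sketch:

```python
import numpy as np
from array_api_compat.numpy import linalg as xpl

A = np.eye(3)
b = np.ones(3)        # exactly 1-D: solve1 branch
B = np.ones((3, 2))   # 2-D: regular batched solve branch
print(xpl.solve(A, b).shape)  # (3,)
print(xpl.solve(A, B).shape)  # (3, 2)
```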
diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py
index cfa3acf8..6cbb6ec2 100644
--- a/array_api_compat/torch/__init__.py
+++ b/array_api_compat/torch/__init__.py
@@ -1,24 +1,25 @@
-from torch import * # noqa: F403
+from typing import Final
-# Several names are not included in the above import *
-import torch
-for n in dir(torch):
- if (n.startswith('_')
- or n.endswith('_')
- or 'cuda' in n
- or 'cpu' in n
- or 'backward' in n):
- continue
- exec(n + ' = torch.' + n)
+from .._internal import clone_module
+
+__all__ = clone_module("torch", globals())
# These imports may overwrite names from the import * above.
+from . import _aliases
from ._aliases import * # noqa: F403
+from ._info import __array_namespace_info__ # noqa: F401
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
-
__import__(__package__ + '.fft')
-from ..common._helpers import * # noqa: F403
+__array_api_version__: Final = '2024.12'
+
+__all__ = sorted(
+ set(__all__)
+ | set(_aliases.__all__)
+ | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
+)
-__array_api_version__ = '2023.12'
+def __dir__() -> list[str]:
+ return __all__
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 5ac66bcb..91161d24 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -1,27 +1,16 @@
from __future__ import annotations
-from functools import wraps as _wraps
+from collections.abc import Sequence
+from functools import reduce as _reduce, wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
-
-from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
- vecdot as _aliases_vecdot,
- clip as _aliases_clip,
- unstack as _aliases_unstack,
- cumulative_sum as _aliases_cumulative_sum,
- )
-from .._internal import get_xp
-
-from ._info import __array_namespace_info__
+from typing import Any, Literal
import torch
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- from typing import List, Optional, Sequence, Tuple, Union
- from ..common._typing import Device
- from torch import dtype as Dtype
-
- array = torch.Tensor
+from .._internal import get_xp
+from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
+from ._typing import Array, Device, DType
_int_dtypes = {
torch.uint8,
@@ -30,6 +19,12 @@
torch.int32,
torch.int64,
}
+try:
+ # torch >=2.3
+ _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+except AttributeError:
+ pass
+
_array_api_dtypes = {
torch.bool,
@@ -40,47 +35,23 @@
torch.complex128,
}
-_promotion_table = {
- # bool
- (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
# ints
- (torch.int8, torch.int8): torch.int8,
(torch.int8, torch.int16): torch.int16,
(torch.int8, torch.int32): torch.int32,
(torch.int8, torch.int64): torch.int64,
- (torch.int16, torch.int8): torch.int16,
- (torch.int16, torch.int16): torch.int16,
(torch.int16, torch.int32): torch.int32,
(torch.int16, torch.int64): torch.int64,
- (torch.int32, torch.int8): torch.int32,
- (torch.int32, torch.int16): torch.int32,
- (torch.int32, torch.int32): torch.int32,
(torch.int32, torch.int64): torch.int64,
- (torch.int64, torch.int8): torch.int64,
- (torch.int64, torch.int16): torch.int64,
- (torch.int64, torch.int32): torch.int64,
- (torch.int64, torch.int64): torch.int64,
- # uints
- (torch.uint8, torch.uint8): torch.uint8,
# ints and uints (mixed sign)
- (torch.int8, torch.uint8): torch.int16,
- (torch.int16, torch.uint8): torch.int16,
- (torch.int32, torch.uint8): torch.int32,
- (torch.int64, torch.uint8): torch.int64,
(torch.uint8, torch.int8): torch.int16,
(torch.uint8, torch.int16): torch.int16,
(torch.uint8, torch.int32): torch.int32,
(torch.uint8, torch.int64): torch.int64,
# floats
- (torch.float32, torch.float32): torch.float32,
(torch.float32, torch.float64): torch.float64,
- (torch.float64, torch.float32): torch.float64,
- (torch.float64, torch.float64): torch.float64,
# complexes
- (torch.complex64, torch.complex64): torch.complex64,
(torch.complex64, torch.complex128): torch.complex128,
- (torch.complex128, torch.complex64): torch.complex128,
- (torch.complex128, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
@@ -88,6 +59,9 @@
(torch.float64, torch.complex128): torch.complex128,
}
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
def _two_arg(f):
@_wraps(f)
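Illustration (not part of the patch): the two update() calls added a few lines above replace the previously hand-enumerated table; the first mirrors every entry so lookups are order-independent, the second fills in the trivial (dtype, dtype) pairs. The same trick on a toy table:

```python
# Toy version of the symmetrize-and-add-identity trick used for _promotion_table.
table = {("int8", "int16"): "int16", ("float32", "float64"): "float64"}
table.update({(b, a): c for (a, b), c in table.items()})                     # mirror
table.update({(a, a): a for a in ("int8", "int16", "float32", "float64")})   # identity
assert table[("int16", "int8")] == "int16"
assert table[("float64", "float64")] == "float64"
```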
@@ -118,23 +92,50 @@ def _fix_promotion(x1, x2, only_scalar=True):
x1 = x1.to(dtype)
return x1, x2
-def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
- if len(arrays_and_dtypes) == 0:
- raise TypeError("At least one array or dtype must be provided")
- if len(arrays_and_dtypes) == 1:
+
+_py_scalars = (bool, int, float, complex)
+
+
+def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
+ num = len(arrays_and_dtypes)
+
+ if num == 0:
+ raise ValueError("At least one array or dtype must be provided")
+
+ elif num == 1:
x = arrays_and_dtypes[0]
if isinstance(x, torch.dtype):
return x
return x.dtype
- if len(arrays_and_dtypes) > 2:
- return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
- x, y = arrays_and_dtypes
- xdt = x.dtype if not isinstance(x, torch.dtype) else x
- ydt = y.dtype if not isinstance(y, torch.dtype) else y
+ if num == 2:
+ x, y = arrays_and_dtypes
+ return _result_type(x, y)
+
+ else:
+ # sort scalars so that they are treated last
+ scalars, others = [], []
+ for x in arrays_and_dtypes:
+ if isinstance(x, _py_scalars):
+ scalars.append(x)
+ else:
+ others.append(x)
+ if not others:
+ raise ValueError("At least one array or dtype must be provided")
+
+ # combine left-to-right
+ return _reduce(_result_type, others + scalars)
+
+
+def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType:
+ if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
+ xdt = x if isinstance(x, torch.dtype) else x.dtype
+ ydt = y if isinstance(y, torch.dtype) else y.dtype
- if (xdt, ydt) in _promotion_table:
- return _promotion_table[xdt, ydt]
+ try:
+ return _promotion_table[xdt, ydt]
+ except KeyError:
+ pass
# This doesn't result_type(dtype, dtype) for non-array API dtypes
# because torch.result_type only accepts tensors. This does however, allow
@@ -143,7 +144,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
return torch.result_type(x, y)
-def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+
+def can_cast(from_: DType | Array, to: DType, /) -> bool:
if not isinstance(from_, torch.dtype):
from_ = from_.dtype
return torch.can_cast(from_, to)
@@ -185,29 +187,58 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)
+
+def asarray(
+ obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
+ /,
+ *,
+ dtype: DType | None = None,
+ device: Device | None = None,
+ copy: bool | None = None,
+ **kwargs: Any,
+) -> Array:
+ # torch.asarray does not respect input->output device propagation
+ # https://github.com/pytorch/pytorch/issues/150199
+ if device is None and isinstance(obj, torch.Tensor):
+ device = obj.device
+ return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
-def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amax(x, axis, keepdims=keepdims)
-def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amin(x, axis, keepdims=keepdims)
-clip = get_xp(torch)(_aliases_clip)
-unstack = get_xp(torch)(_aliases_unstack)
-cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
+clip = get_xp(torch)(_aliases.clip)
+unstack = get_xp(torch)(_aliases.unstack)
+cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
+finfo = get_xp(torch)(_aliases.finfo)
+iinfo = get_xp(torch)(_aliases.iinfo)
+
# torch.sort also returns a tuple
# https://github.com/pytorch/pytorch/issues/70921
-def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
+def sort(
+ x: Array,
+ /,
+ *,
+ axis: int = -1,
+ descending: bool = False,
+ stable: bool = True,
+ **kwargs: object,
+) -> Array:
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
def _normalize_axes(axis, ndim):
@@ -252,28 +283,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
out = torch.unsqueeze(out, a)
return out
-def prod(x: array,
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+ """
+ Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+ Works around https://github.com/pytorch/pytorch/issues/29137
+ """
+ if dtype is not None:
+ return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    # The spec says unsigned ints should upcast to uint64, but torch has no
+    # torch.uint64, so at least upcast to int64, which is what sum/prod do
+    # when axis=None.
+ if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+ return x.to(torch.int64)
+
+ return x.clone()
+
+
+def prod(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- dtype: Optional[Dtype] = None,
+ axis: int | tuple[int, ...] | None = None,
+ dtype: DType | None = None,
keepdims: bool = False,
- **kwargs) -> array:
- x = torch.asarray(x)
- ndim = x.ndim
+ **kwargs: object) -> Array:
- # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
- # below because it still needs to upcast.
if axis == ():
- if dtype is None:
- # We can't upcast uint8 according to the spec because there is no
- # torch.uint64, so at least upcast to int64 which is what sum does
- # when axis=None.
- if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
- return x.to(torch.int64)
- return x.clone()
- return x.to(dtype)
-
+ return _sum_prod_no_axis(x, dtype)
# torch.prod doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
@@ -282,51 +320,38 @@ def prod(x: array,
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.prod(x, dtype=dtype, **kwargs)
- res = _axis_none_keepdims(res, ndim, keepdims)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
-def sum(x: array,
+def sum(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- dtype: Optional[Dtype] = None,
+ axis: int | tuple[int, ...] | None = None,
+ dtype: DType | None = None,
keepdims: bool = False,
- **kwargs) -> array:
- x = torch.asarray(x)
- ndim = x.ndim
+ **kwargs: object) -> Array:
- # https://github.com/pytorch/pytorch/issues/29137.
- # Make sure it upcasts.
if axis == ():
- if dtype is None:
- # We can't upcast uint8 according to the spec because there is no
- # torch.uint64, so at least upcast to int64 which is what sum does
- # when axis=None.
- if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
- return x.to(torch.int64)
- return x.clone()
- return x.to(dtype)
-
+ return _sum_prod_no_axis(x, dtype)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.sum(x, dtype=dtype, **kwargs)
- res = _axis_none_keepdims(res, ndim, keepdims)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
-def any(x: array,
+def any(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
- **kwargs) -> array:
- x = torch.asarray(x)
- ndim = x.ndim
+ **kwargs: object) -> Array:
+
if axis == ():
return x.to(torch.bool)
# torch.any doesn't support multiple axes
@@ -338,20 +363,19 @@ def any(x: array,
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.any(x, **kwargs)
- res = _axis_none_keepdims(res, ndim, keepdims)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
return res.to(torch.bool)
# torch.any doesn't return bool for uint8
return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
-def all(x: array,
+def all(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
- **kwargs) -> array:
- x = torch.asarray(x)
- ndim = x.ndim
+ **kwargs: object) -> Array:
+
if axis == ():
return x.to(torch.bool)
# torch.all doesn't support multiple axes
@@ -363,18 +387,18 @@ def all(x: array,
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.all(x, **kwargs)
- res = _axis_none_keepdims(res, ndim, keepdims)
+ res = _axis_none_keepdims(res, x.ndim, keepdims)
return res.to(torch.bool)
# torch.all doesn't return bool for uint8
return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
-def mean(x: array,
+def mean(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
- **kwargs) -> array:
+ **kwargs: object) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
@@ -386,13 +410,13 @@ def mean(x: array,
return res
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
-def std(x: array,
+def std(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- correction: Union[int, float] = 0.0,
+ axis: int | tuple[int, ...] | None = None,
+ correction: float = 0.0,
keepdims: bool = False,
- **kwargs) -> array:
+ **kwargs: object) -> Array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
@@ -417,13 +441,13 @@ def std(x: array,
return res
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
-def var(x: array,
+def var(x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
- correction: Union[int, float] = 0.0,
+ axis: int | tuple[int, ...] | None = None,
+ correction: float = 0.0,
keepdims: bool = False,
- **kwargs) -> array:
+ **kwargs: object) -> Array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
@@ -446,11 +470,11 @@ def var(x: array,
# torch.concat doesn't support dim=None
# https://github.com/pytorch/pytorch/issues/70925
-def concat(arrays: Union[Tuple[array, ...], List[array]],
+def concat(arrays: tuple[Array, ...] | list[Array],
/,
*,
- axis: Optional[int] = 0,
- **kwargs) -> array:
+ axis: int | None = 0,
+ **kwargs: object) -> Array:
if axis is None:
arrays = tuple(ar.flatten() for ar in arrays)
axis = 0
@@ -459,7 +483,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]],
# torch.squeeze only accepts int dim and doesn't require it
# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
# added at https://github.com/pytorch/pytorch/pull/89017.
-def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
+def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array:
if isinstance(axis, int):
axis = (axis,)
for a in axis:
@@ -473,41 +497,83 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
return x
# torch.broadcast_to uses size instead of shape
-def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
+def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array:
return torch.broadcast_to(x, shape, **kwargs)
# torch.permute uses dims instead of axes
-def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
+def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
return torch.permute(x, axes)
# The axis parameter doesn't work for flip() and roll()
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
# accept axis=None
-def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array:
if axis is None:
axis = tuple(range(x.ndim))
# torch.flip doesn't accept dim as an int but the method does
# https://github.com/pytorch/pytorch/issues/18095
return x.flip(axis, **kwargs)
-def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array:
return torch.roll(x, shift, axis, **kwargs)
-def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
+def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return torch.nonzero(x, as_tuple=True, **kwargs)
-def where(condition: array, x1: array, x2: array, /) -> array:
+
+# torch uses `dim` instead of `axis`
+def diff(
+ x: Array,
+ /,
+ *,
+ axis: int = -1,
+ n: int = 1,
+ prepend: Array | None = None,
+ append: Array | None = None,
+) -> Array:
+ return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
+
+
+# torch uses `dim` instead of `axis`, does not have keepdims
+def count_nonzero(
+ x: Array,
+ /,
+ *,
+ axis: int | tuple[int, ...] | None = None,
+ keepdims: bool = False,
+) -> Array:
+ result = torch.count_nonzero(x, dim=axis)
+ if keepdims:
+ if isinstance(axis, int):
+ return result.unsqueeze(axis)
+ elif isinstance(axis, tuple):
+ n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
+ sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
+ return torch.reshape(result, sh)
+ return _axis_none_keepdims(result, x.ndim, keepdims)
+ else:
+ return result
+
+
+# "repeat" is torch.repeat_interleave; also the dim argument
+def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
+ return torch.repeat_interleave(x, repeats, axis)
+
+
+def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array:
x1, x2 = _fix_promotion(x1, x2)
return torch.where(condition, x1, x2)
+
# torch.reshape doesn't have the copy keyword
-def reshape(x: array,
+def reshape(x: Array,
/,
- shape: Tuple[int, ...],
- copy: Optional[bool] = None,
- **kwargs) -> array:
+ shape: tuple[int, ...],
+ *,
+ copy: bool | None = None,
+ **kwargs: object) -> Array:
if copy is not None:
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
return torch.reshape(x, shape, **kwargs)
@@ -516,14 +582,14 @@ def reshape(x: array,
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
# keyword argument combinations
# (https://github.com/pytorch/pytorch/issues/70914)
-def arange(start: Union[int, float],
+def arange(start: float,
/,
- stop: Optional[Union[int, float]] = None,
- step: Union[int, float] = 1,
+ stop: float | None = None,
+ step: float = 1,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
if stop is None:
start, stop = 0, start
if step > 0 and stop <= start or step < 0 and stop >= start:
@@ -538,13 +604,13 @@ def arange(start: Union[int, float],
# torch.eye does not accept None as a default for the second argument and
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
def eye(n_rows: int,
- n_cols: Optional[int] = None,
+ n_cols: int | None = None,
/,
*,
k: int = 0,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
if n_cols is None:
n_cols = n_rows
z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
@@ -553,70 +619,81 @@ def eye(n_rows: int,
return z
# torch.linspace doesn't have the endpoint parameter
-def linspace(start: Union[int, float],
- stop: Union[int, float],
+def linspace(start: float,
+ stop: float,
/,
num: int,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
+ dtype: DType | None = None,
+ device: Device | None = None,
endpoint: bool = True,
- **kwargs) -> array:
+ **kwargs: object) -> Array:
if not endpoint:
return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
# torch.full does not accept an int size
# https://github.com/pytorch/pytorch/issues/70906
-def full(shape: Union[int, Tuple[int, ...]],
- fill_value: Union[bool, int, float, complex],
+def full(shape: int | tuple[int, ...],
+ fill_value: complex,
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
if isinstance(shape, int):
shape = (shape,)
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
# ones, zeros, and empty do not accept shape as a keyword argument
-def ones(shape: Union[int, Tuple[int, ...]],
+def ones(shape: int | tuple[int, ...],
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
-def zeros(shape: Union[int, Tuple[int, ...]],
+def zeros(shape: int | tuple[int, ...],
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
-def empty(shape: Union[int, Tuple[int, ...]],
+def empty(shape: int | tuple[int, ...],
*,
- dtype: Optional[Dtype] = None,
- device: Optional[Device] = None,
- **kwargs) -> array:
+ dtype: DType | None = None,
+ device: Device | None = None,
+ **kwargs: object) -> Array:
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
# tril and triu do not call the keyword argument k
-def tril(x: array, /, *, k: int = 0) -> array:
+def tril(x: Array, /, *, k: int = 0) -> Array:
return torch.tril(x, k)
-def triu(x: array, /, *, k: int = 0) -> array:
+def triu(x: Array, /, *, k: int = 0) -> Array:
return torch.triu(x, k)
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
-def expand_dims(x: array, /, *, axis: int = 0) -> array:
+def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
return torch.unsqueeze(x, axis)
-def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
- return x.to(dtype, copy=copy)
-def broadcast_arrays(*arrays: array) -> List[array]:
+def astype(
+ x: Array,
+ dtype: DType,
+ /,
+ *,
+ copy: bool = True,
+ device: Device | None = None,
+) -> Array:
+ if device is not None:
+ return x.to(device, dtype=dtype, copy=copy)
+ return x.to(dtype=dtype, copy=copy)
+
+
+def broadcast_arrays(*arrays: Array) -> list[Array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
return [torch.broadcast_to(a, shape) for a in arrays]
@@ -626,7 +703,7 @@ def broadcast_arrays(*arrays: array) -> List[array]:
UniqueInverseResult)
# https://github.com/pytorch/pytorch/issues/70920
-def unique_all(x: array) -> UniqueAllResult:
+def unique_all(x: Array) -> UniqueAllResult:
# torch.unique doesn't support returning indices.
# https://github.com/pytorch/pytorch/issues/36748. The workaround
# suggested in that issue doesn't actually function correctly (it relies
@@ -639,7 +716,7 @@ def unique_all(x: array) -> UniqueAllResult:
# counts[torch.isnan(values)] = 1
# return UniqueAllResult(values, indices, inverse_indices, counts)
-def unique_counts(x: array) -> UniqueCountsResult:
+def unique_counts(x: Array) -> UniqueCountsResult:
values, counts = torch.unique(x, return_counts=True)
# torch.unique incorrectly gives a 0 count for nan values.
@@ -647,27 +724,34 @@ def unique_counts(x: array) -> UniqueCountsResult:
counts[torch.isnan(values)] = 1
return UniqueCountsResult(values, counts)
-def unique_inverse(x: array) -> UniqueInverseResult:
+def unique_inverse(x: Array) -> UniqueInverseResult:
values, inverse = torch.unique(x, return_inverse=True)
return UniqueInverseResult(values, inverse)
-def unique_values(x: array) -> array:
+def unique_values(x: Array) -> Array:
return torch.unique(x)
-def matmul(x1: array, x2: array, /, **kwargs) -> array:
+def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array:
# torch.matmul doesn't type promote (but differently from _fix_promotion)
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.matmul(x1, x2, **kwargs)
-matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
-_vecdot = get_xp(torch)(_aliases_vecdot)
+matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
+_vecdot = get_xp(torch)(_aliases.vecdot)
-def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return _vecdot(x1, x2, axis=axis)
# torch.tensordot uses dims instead of axes
-def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
+def tensordot(
+ x1: Array,
+ x2: Array,
+ /,
+ *,
+ axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+ **kwargs: object,
+) -> Array:
# Note: torch.tensordot fails with integer dtypes when there is only 1
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -675,8 +759,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
def isdtype(
- dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
- *, _tuple=True, # Disallow nested tuples
+ dtype: DType,
+ kind: DType | str | tuple[DType | str, ...],
+ *,
+ _tuple: bool = True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -710,14 +796,19 @@ def isdtype(
else:
return dtype == kind
-def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
+def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array:
if axis is None:
if x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
axis = 0
return torch.index_select(x, axis, indices, **kwargs)
-def sign(x: array, /) -> array:
+
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
+ return torch.take_along_dim(x, indices, dim=axis)
+
+
+def sign(x: Array, /) -> Array:
# torch sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
if x.dtype.is_complex:
@@ -732,14 +823,21 @@ def sign(x: array, /) -> array:
return out
-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
+    # torch.meshgrid defaults to indexing='ij'; the standard defaults to 'xy'
+    # TODO: is the return type a list or a tuple
+    return list(torch.meshgrid(*arrays, indexing=indexing))
+
+
+__all__ = ['asarray', 'result_type', 'can_cast',
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
- 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
+ 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
+ 'diff', 'divide',
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
- 'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
+ 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
@@ -747,6 +845,4 @@ def sign(x: array, /) -> array:
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
- 'take', 'sign']
-
-_all_ignore = ['torch', 'get_xp']
+ 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
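Illustration (not part of the patch): result_type() above combines arrays and dtypes first and folds Python scalars in last, so a weak scalar cannot widen an already-decided promotion. A small sketch (expected outputs per the promotion table in this file):

```python
import torch
import array_api_compat.torch as xp

a = torch.ones((), dtype=torch.int8)
b = torch.ones((), dtype=torch.uint8)
# int8/uint8 promote to int16 via the table; the trailing Python int is
# combined last and should not widen the result further
print(xp.result_type(a, b, 1))   # expected: torch.int16
```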
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 264caa9e..818e5d37 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -34,7 +34,7 @@ class __array_namespace_info__:
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': numpy.float64,
'complex floating': numpy.complex128,
@@ -76,17 +76,17 @@ def capabilities(self):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
- 'data-dependent shapes': True}
+ 'data-dependent shapes': True,
+ 'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
- # 'max rank' will be part of the 2024.12 standard
- # "max rank": 64,
+ "max dimensions": 64,
}
def default_device(self):
@@ -102,15 +102,24 @@ def default_device(self):
Returns
-------
- device : str
+ device : Device
The default device used for new PyTorch arrays.
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_device()
- 'cpu'
+ device(type='cpu')
+ Notes
+ -----
+ This method returns the static default device when PyTorch is initialized.
+ However, the *current* device used by creation functions (``empty`` etc.)
+ can be changed at runtime.
+
+ See Also
+ --------
+ https://github.com/data-apis/array-api/issues/835
"""
return torch.device("cpu")
@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
Parameters
----------
- device : str, optional
- The device to get the default data types for. For PyTorch, only
- ``'cpu'`` is allowed.
+ device : Device, optional
+ The device to get the default data types for.
+ Unused for PyTorch, as all devices use the same default dtypes.
Returns
-------
@@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': torch.float32,
'complex floating': torch.complex64,
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
Parameters
----------
- device : str, optional
+ device : Device, optional
The device to get the data types for.
+ Unused for PyTorch, as all devices use the same dtypes.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
@@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': numpy.int8,
'int16': numpy.int16,
@@ -310,7 +320,7 @@ def devices(self):
Returns
-------
- devices : list of str
+ devices : list[Device]
The devices supported by PyTorch.
See Also
@@ -322,7 +332,7 @@ def devices(self):
Examples
--------
- >>> info = np.__array_namespace_info__()
+ >>> info = xp.__array_namespace_info__()
>>> info.devices()
[device(type='cpu'), device(type='mps', index=0), device(type='meta')]
@@ -333,6 +343,7 @@ def devices(self):
# device:
try:
torch.device('notadevice')
+ raise AssertionError("unreachable") # pragma: nocover
except RuntimeError as e:
# The error message is something like:
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py
new file mode 100644
index 00000000..52670871
--- /dev/null
+++ b/array_api_compat/torch/_typing.py
@@ -0,0 +1,3 @@
+__all__ = ["Array", "Device", "DType"]
+
+from torch import device as Device, dtype as DType, Tensor as Array
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 3c9117ee..76342980 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -1,86 +1,82 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- import torch
- array = torch.Tensor
- from typing import Union, Sequence, Literal
+from collections.abc import Sequence
+from typing import Literal
-from torch.fft import * # noqa: F403
+import torch
import torch.fft
+from ._typing import Array
+from .._internal import clone_module
+
+__all__ = clone_module("torch.fft", globals())
+
# Several torch fft functions do not map axes to dim
def fftn(
- x: array,
+ x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
- **kwargs,
-) -> array:
+ **kwargs: object,
+) -> Array:
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
def ifftn(
- x: array,
+ x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
- **kwargs,
-) -> array:
+ **kwargs: object,
+) -> Array:
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
def rfftn(
- x: array,
+ x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
- **kwargs,
-) -> array:
+ **kwargs: object,
+) -> Array:
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def irfftn(
- x: array,
+ x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
- **kwargs,
-) -> array:
+ **kwargs: object,
+) -> Array:
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def fftshift(
- x: array,
+ x: Array,
/,
*,
- axes: Union[int, Sequence[int]] = None,
- **kwargs,
-) -> array:
+ axes: int | Sequence[int] = None,
+ **kwargs: object,
+) -> Array:
return torch.fft.fftshift(x, dim=axes, **kwargs)
def ifftshift(
- x: array,
+ x: Array,
/,
*,
- axes: Union[int, Sequence[int]] = None,
- **kwargs,
-) -> array:
+ axes: int | Sequence[int] = None,
+ **kwargs: object,
+) -> Array:
return torch.fft.ifftshift(x, dim=axes, **kwargs)
-__all__ = torch.fft.__all__ + [
- "fftn",
- "ifftn",
- "rfftn",
- "irfftn",
- "fftshift",
- "ifftshift",
-]
+__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"]
-_all_ignore = ['torch']
+def __dir__() -> list[str]:
+ return __all__
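Illustration (not part of the patch): the wrappers above exist only to translate the standard's `axes` keyword into torch's `dim`. A quick usage sketch:

```python
import torch
import array_api_compat.torch.fft as xpfft

x = torch.randn(4, 8)
y = xpfft.fftn(x, axes=(0, 1))    # forwarded as dim=(0, 1)
z = xpfft.fftshift(x, axes=1)     # forwarded as dim=1
print(y.shape, z.shape)
```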
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..08271d22 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1,42 +1,35 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
- import torch
- array = torch.Tensor
- from torch import dtype as Dtype
- from typing import Optional, Union, Tuple, Literal
- inf = float('inf')
+import torch
+import torch.linalg
-from ._aliases import _fix_promotion, sum
-
-from torch.linalg import * # noqa: F403
+from .._internal import clone_module
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+__all__ = clone_module("torch.linalg", globals())
# outer is implemented in torch but aren't in the linalg namespace
from torch import outer
+from ._aliases import _fix_promotion, sum
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
+from ._typing import Array, DType
+from ..common._typing import JustInt, JustFloat
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
x1, x2 = torch.broadcast_tensors(x1, x2)
- return torch_linalg.cross(x1, x2, dim=axis)
+ return torch.linalg.cross(x1, x2, dim=axis)
-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -58,7 +51,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
-def solve(x1: array, x2: array, /, **kwargs) -> array:
+def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
@@ -79,19 +72,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
return torch.linalg.solve(x1, x2, **kwargs)
# torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
- x: array,
+ x: Array,
/,
*,
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
- ord: Union[int, float, Literal[inf, -inf]] = 2,
- **kwargs,
-) -> array:
+ # JustFloat stands for inf | -inf, which are not valid for Literal
+ ord: JustInt | JustFloat = 2,
+ **kwargs: object,
+) -> Array:
# torch.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
@@ -113,9 +107,8 @@ def vector_norm(
return out
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
-__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
- 'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-
-_all_ignore = ['torch_linalg', 'sum']
+__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot',
+ 'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
-del linalg_all
+def __dir__() -> list[str]:
+ return __all__
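Illustration (not part of the patch): trace() goes through the wrapped sum(), so small integer inputs are upcast instead of accumulating in their original dtype, and vector_norm() defaults to ord=2. A minimal sketch:

```python
import torch
from array_api_compat.torch import linalg as xpl

x = torch.eye(3, dtype=torch.uint8)
t = xpl.trace(x)
print(t, t.dtype)                 # expected: tensor(3) with an int64 dtype

v = torch.tensor([3.0, 4.0])
print(xpl.vector_norm(v))         # expected: tensor(5.)
```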
diff --git a/cupy-xfails.txt b/cupy-xfails.txt
index fb7b03da..0a91cafe 100644
--- a/cupy-xfails.txt
+++ b/cupy-xfails.txt
@@ -11,12 +11,10 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)]
# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172)
array_api_tests/test_array_object.py::test_getitem
-# copy=False is not yet implemented
-array_api_tests/test_creation_functions.py::test_asarray_arrays
-
-# finfo test is testing that the result is a float instead of float32 (see
-# also https://github.com/data-apis/array-api/issues/405)
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
# Some array attributes are missing, and we do not wrap the array object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
@@ -36,6 +34,16 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
# floating point inaccuracy
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
+# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1)
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
# cupy (arg)min/max wrong with infinities
# https://github.com/cupy/cupy/issues/7424
@@ -173,10 +181,23 @@ array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
array_api_tests/test_fft.py::test_fftn
array_api_tests/test_fft.py::test_ifftn
array_api_tests/test_fft.py::test_rfftn
+
+# observed during the 1.10 release process; likely related to the xfails above
+array_api_tests/test_fft.py::test_irfftn
# 2023.12 support
# cupy.ndaray cannot be specified as `repeats` argument.
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0]
+
+# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8)
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
diff --git a/dask-skips.txt b/dask-skips.txt
index 63a09e4b..a16a8588 100644
--- a/dask-skips.txt
+++ b/dask-skips.txt
@@ -1,2 +1,9 @@
-# slow and not implemented in dask
-array_api_tests/test_linalg.py::test_matrix_power
+# NOTE: dask tests run on a very small number of examples in CI due to
+# slowness. This causes very high flakiness in the tests.
+# Before changing this file, please run with at least 200 examples.
+
+# Passes, but extremely slow
+array_api_tests/test_linalg.py::test_outer
+
+# Hangs
+array_api_tests/test_creation_functions.py::test_eye
diff --git a/dask-xfails.txt b/dask-xfails.txt
index 1e9c421c..3efb4f96 100644
--- a/dask-xfails.txt
+++ b/dask-xfails.txt
@@ -1,73 +1,46 @@
-# This fails in dask
-# import dask.array as da
-# a = da.array([1]).reshape((1,1))
-# key = (0, slice(None, None, -1))
-# a[key] = da.array([1])
-
-# Failing hypothesis test case
-#x=dask.array
-#| Draw 1 (key): (slice(None, None, None), slice(None, None, None))
-#| Draw 2 (value): dask.array
-
-# Various shape mismatches e.g.
-ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2)
-array_api_tests/test_array_object.py::test_setitem
+# NOTE: dask tests run on a very small number of examples in CI due to
+# slowness. This causes very high flakiness in the tests.
+# Before changing this file, please run with at least 200 examples.
-# Fails since bad upcast from uint8 -> int64
-# MRE:
-# a = da.array(0, dtype="uint8")
-# b = da.array(False)
-# a[b] = 0
-array_api_tests/test_array_object.py::test_setitem_masking
+# Broken edge case with shape 0
+# https://github.com/dask/dask/issues/11800
+array_api_tests/test_array_object.py::test_setitem
# Various indexing errors
array_api_tests/test_array_object.py::test_getitem_masking
-# asarray(copy=False) is not yet implemented
-# copied from numpy xfails, TODO: should this pass with dask?
-array_api_tests/test_creation_functions.py::test_asarray_arrays
-
# zero division error, and typeerror: tuple indices must be integers or slices not tuple
array_api_tests/test_creation_functions.py::test_eye
-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
-# out[-1]=dask.aray but should be some floating number
+# out[-1]=dask.array but should be some floating number
# (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace
-# out.shape=(2,) but should be (1,)
+# Shape mismatch
array_api_tests/test_indexing_functions.py::test_take
-# out=-0, but should be +0
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-
-# output is nan but should be infinity
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
+# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
+array_api_tests/test_indexing_functions.py::test_take_along_axis
-# No sorting in dask
-array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
-array_api_tests/test_has_names.py::test_has_names[sorting-sort]
-array_api_tests/test_sorting_functions.py::test_argsort
-array_api_tests/test_sorting_functions.py::test_sort
-array_api_tests/test_signatures.py::test_func_signature[argsort]
-array_api_tests/test_signatures.py::test_func_signature[sort]
-
-# Array methods and attributes not already on np.ndarray cannot be wrapped
+# Array methods and attributes not already on da.Array cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
-# Fails because shape is NaN since we don't materialize it yet
+# Data-dependent output shape
+# These tests fail as array-api-tests doesn't cope with unknown shapes
+# Also, output shape is (math.nan, ) instead of (None, )
+# Also, da.unique() doesn't accept equals_nan which causes non-compliant
+# output when there are NaNs in the input.
array_api_tests/test_searching_functions.py::test_nonzero
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_counts
-
-# Different error but same cause as above, we're just trying to do ndindex on nan shape
array_api_tests/test_set_functions.py::test_unique_inverse
array_api_tests/test_set_functions.py::test_unique_values
@@ -75,24 +48,17 @@ array_api_tests/test_set_functions.py::test_unique_values
# fails for ndim > 2
array_api_tests/test_linalg.py::test_svdvals
-array_api_tests/test_linalg.py::test_cholesky
-# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
+
+# dtype mismatch: got uint64, but should be uint8; NPY_PROMOTION_STATE=weak doesn't help
array_api_tests/test_linalg.py::test_tensordot
# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
array_api_tests/test_linalg.py::test_linalg_tensordot
-# AssertionError: out.shape=(1,), but should be () [linalg.vector_norm(keepdims=True)]
-array_api_tests/test_linalg.py::test_vector_norm
-
# ZeroDivisionError in dask's normalize_chunks/auto_chunks internals
array_api_tests/test_linalg.py::test_inv
array_api_tests/test_linalg.py::test_matrix_power
-# did not raise error for invalid shapes
-array_api_tests/test_linalg.py::test_matmul
-array_api_tests/test_linalg.py::test_linalg_matmul
-
# Linalg - these don't exist in dask
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
@@ -105,6 +71,7 @@ array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
+array_api_tests/test_linalg.py::test_matrix_rank
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,17 +82,10 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
-array_api_tests/test_linalg.py::test_matrix_norm
-array_api_tests/test_linalg.py::test_matrix_rank
-
-# missing mode kw
-# https://github.com/dask/dask/issues/10388
-array_api_tests/test_linalg.py::test_qr
-
# Constructing the input arrays fails with a weird shape error...
array_api_tests/test_linalg.py::test_solve
-# missing full_matrics kw
+# missing full_matrices kw
# https://github.com/dask/dask/issues/10389
# also only supports 2-d inputs
array_api_tests/test_linalg.py::test_svd
@@ -140,18 +100,51 @@ array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
-# Some cases unsupported by dask
-array_api_tests/test_manipulation_functions.py::test_roll
-
# No mT on dask array
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
+# Edge case of args near 2**63
+# https://github.com/dask/dask/issues/11706
+array_api_tests/test_creation_functions.py::test_arange
+
+# da.searchsorted with a sorter argument is not supported
+array_api_tests/test_searching_functions.py::test_searchsorted
# 2023.12 support
array_api_tests/test_manipulation_functions.py::test_repeat
-array_api_tests/test_searching_functions.py::test_searchsorted
-array_api_tests/test_signatures.py::test_func_signature[astype]
+
+# 2024.12 support
+array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[1]
+array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None]
+array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1]
+array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None]
+array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
+array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
+array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
+
+array_api_tests/test_linalg.py::test_cholesky
+array_api_tests/test_linalg.py::test_linalg_matmul
+array_api_tests/test_linalg.py::test_matmul
+array_api_tests/test_linalg.py::test_matrix_norm
+array_api_tests/test_linalg.py::test_qr
+array_api_tests/test_manipulation_functions.py::test_roll
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.)
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
diff --git a/docs/changelog.md b/docs/changelog.md
index 29a18a5c..6f6c1251 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,5 +1,161 @@
# Changelog
+## 1.12.0 (2025-05-13)
+
+
+### Major changes
+
+- The build system has been updated to use `pyproject.toml` instead of `setup.py`.
+- Support for Python 3.9 has been dropped. The minimum supported Python version is now
+ 3.10; the minimum supported NumPy version is 1.22.
+- The `linalg` extension works correctly with `pytorch>=2.7`.
+- Multiple improvements to handling of devices in CuPy and PyTorch backends.
+ Support for multiple devices in CuPy is still immature and you should use
+ context managers rather than relying on input-output device propagation or
+ on the `device` parameter. Please report any issues you encounter.
+
+### Minor changes
+
+- `finfo` and `iinfo` functions now accept array arguments, in accordance with the
+ Array API spec;
+- `torch.asarray` function propagates the device of the input array. This works around
+ the [pytorch issue #150199](https://github.com/pytorch/pytorch/issues/150199);
+- `torch.repeat` function is now available;
+- `torch.count_nonzero` function now correctly handles the case of a tuple `axis`
+  argument and `keepdims=True`;
+- `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the
+ array API specification;
+- `cupy.asarray` function now implements the `copy=False` argument, albeit
+  at the cost of potentially making a temporary copy.
+- In `numpy.take_along_axis` and `cupy.take_along_axis` the `axis` parameter now
+  defaults to -1, in accordance with the Array API spec (see the sketch below).
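+
+A minimal sketch of two of the changes above, assuming this release and NumPy >= 1.22
+are installed:
+
+```python
+import array_api_compat.numpy as xp
+
+x = xp.asarray([[0.1, 0.2], [0.3, 0.4]], dtype=xp.float32)
+print(xp.finfo(x).eps)             # finfo now accepts an array argument
+
+idx = xp.asarray([[1, 0], [0, 1]])
+print(xp.take_along_axis(x, idx))  # axis defaults to -1 (the last axis)
+```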
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Lucas Colley,
+Neil Girdhar,
+Joren Hammudoglu,
+Guido Imperiale
+
+
+## 1.11.2 (2025-03-20)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+- fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple
+ issues with scalar arguments.
+- fix several issues with `clip` wrappers. Previously, `clip` rejected behaviors
+  which are unspecified by the 2024.12 standard but allowed by the array
+  libraries (see the sketch after this list).
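+
+A minimal sketch of the fixed behaviors, assuming this release and PyTorch are installed:
+
+```python
+import array_api_compat.torch as xp
+
+x = xp.asarray([1.0, 2.0], dtype=xp.float32)
+print(xp.result_type(x, 1.0))   # Python scalars are accepted -> torch.float32
+print(xp.clip(x, 0.5, 1.5))     # scalar bounds are allowed   -> [1.0, 1.5]
+```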
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+Magnus Dalen Kvalevåg
+
+
+## 1.11.1 (2025-03-04)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+### Major Changes
+
+- fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in
+  several array libraries (torch, dask, cupy); work around numpy returning python
+  ints for some input combinations (see the sketch below).
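+
+A minimal sketch of the wrapped behavior, assuming this release and NumPy are installed:
+
+```python
+import array_api_compat.numpy as xp
+
+x = xp.asarray([[0, 1, 2], [3, 0, 0]])
+print(xp.count_nonzero(x, axis=1, keepdims=True))  # [[2], [1]]
+print(xp.count_nonzero(x, axis=(0, 1)))            # 3, returned as an array per the spec
+```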
+
+### Minor Changes
+
+- running the self-tests does not require all array libraries. Missing libraries are
+  skipped.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+
+
+## 1.11.0 (2025-02-27)
+
+### Major Changes
+
+This release targets the 2024.12 Array API revision. This includes
+
+ - `__array_api_version__` for the wrapped APIs is now set to `2024.12`;
+ - Wrappers for `count_nonzero`;
+ - Wrappers for `cumulative_prod`;
+ - Wrappers for `take_along_axis` (with the exception of Dask);
+ - Wrappers for `diff`;
+ - `__capabilities__` dict contains a `max_dimensions` key;
+ - Python scalars are accepted as arguments to `result_type`;
+ - `fft.fftfreq` and `fft.rfftfreq` functions now accept an optional `dtype`
+ argument to control the output data type.
+
+Note that these wrappers, as well as other 2024.12 features, are relatively undertested
+in this release, and may have rough edges. Please report any issues you encounter
+in [the issue tracker](https://github.com/data-apis/array-api-compat/issues).
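+
+A minimal sketch exercising a few of the new 2024.12 wrappers listed above, assuming
+this release and NumPy are installed:
+
+```python
+import array_api_compat.numpy as xp
+
+x = xp.asarray([1.0, 2.0, 4.0, 7.0])
+print(xp.cumulative_prod(x))                       # [ 1.  2.  8. 56.]
+print(xp.diff(x))                                  # [1. 2. 3.]
+print(xp.fft.fftfreq(4, dtype=xp.float32).dtype)   # float32
+print(xp.__array_api_version__)                    # 2024.12
+```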
+
+New functions to test properties of arrays:
+ - `is_writeable_array` (benefits NumPy, JAX, Sparse)
+ - `is_lazy_array` (benefits JAX, Dask, ndonnx)
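+
+A minimal sketch of the new helpers, assuming NumPy is installed:
+
+```python
+import numpy as np
+from array_api_compat import is_writeable_array, is_lazy_array
+
+x = np.arange(3)
+x.flags.writeable = False
+print(is_writeable_array(x))   # False: the array is read-only
+print(is_lazy_array(x))        # False: NumPy arrays are eager
+```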
+
+Improved support for JAX:
+ - Workarounds for the `.device` attribute and the `to_device` function
+   not working correctly within `jax.jit`
+
+### Minor Changes
+
+- Several improvements to `dask.array` wrappers:
+
+ - `size` returns None for arrays of unknown shapes.
+ - `astype(..., copy=True)` always copies, independently of the Dask version.
+ - implementations of `sort` and `argsort` are now available. Note that these
+ implementations are relatively crude, and might be memory intensive.
+  - `asarray` no longer accidentally materializes the Dask graph.
+
+- `torch` wrappers contain unsigned integer dtypes of widths >8 bits, `uint16`,
+  `uint32` and `uint64` if the PyTorch version is at least 2.3. Note that the
+  unsigned integer support is incomplete in PyTorch itself, see
+  [gh-253](https://github.com/data-apis/array-api-compat/pull/253).
+
+### Authors
+
+The following users contributed to this release:
+
+Athan Reines
+Guido Imperiale
+Evgeni Burovski
+Lucas Colley
+Ralf Gommers
+Thomas Li
+
+
+## 1.10.0 (2024-12-25)
+
+### Major Changes
+
+- New function `is_writeable_array` adds transparent support for readonly
+ arrays, such as JAX arrays or numpy arrays with `.flags.writeable=False`.
+
+- `asarray(..., copy=None)` with `dask` backend always copies, so that
+ `copy=None` and `copy=True` are equivalent for the `dask` backend.
+ This change is made to be forward compatible with the `dask==2024.12`
+ release.
+
+
+### Minor Changes
+
+- `array_namespace` accepts (and ignores) `None` and python scalars (int, float,
+  complex, bool). This change simplifies downstream adoption for
+  functions where arguments can be either arrays or scalars.
+
+- `vecdot` conjugates its first argument, as stipulated by the Array API spec.
+  Previously, conjugation of the first argument was missing (see the sketch below).
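+
+A minimal sketch of both changes, assuming this release and NumPy are installed:
+
+```python
+import numpy as np
+from array_api_compat import array_namespace
+
+x = np.asarray([1 + 1j, 2 - 1j])
+xp = array_namespace(x, 2.0, None)   # scalars and None are accepted and ignored
+print(xp.vecdot(x, x))               # conj(x) @ x -> (7+0j)
+```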
+
+
## 1.9.1 (2024-10-29)
### Major Changes
diff --git a/docs/dev/tests.md b/docs/dev/tests.md
index 6d9d1d7b..18fb7cf5 100644
--- a/docs/dev/tests.md
+++ b/docs/dev/tests.md
@@ -7,7 +7,7 @@ the array API standard. There are also array-api-compat specific tests in
These tests should be limited to things that are not tested by the test suite,
e.g., tests for [helper functions](../helper-functions.rst) or for behavior
that is not strictly required by the standard. To run these tests, install the
-dependencies from `requirements-dev.txt` (array-api-compat has [no hard
+dependencies from the `dev` optional group (array-api-compat has [no hard
runtime dependencies](no-dependencies)).
array-api-tests is run against all supported libraries on CI
diff --git a/docs/helper-functions.rst b/docs/helper-functions.rst
index f44dc070..155eda9a 100644
--- a/docs/helper-functions.rst
+++ b/docs/helper-functions.rst
@@ -51,6 +51,8 @@ yet.
.. autofunction:: is_jax_array
.. autofunction:: is_pydata_sparse_array
.. autofunction:: is_ndonnx_array
+.. autofunction:: is_writeable_array
+.. autofunction:: is_lazy_array
.. autofunction:: is_numpy_namespace
.. autofunction:: is_cupy_namespace
.. autofunction:: is_torch_namespace
diff --git a/docs/index.md b/docs/index.md
index 874c3866..b3d9a44f 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this
requires making backwards incompatible changes in many cases, so this will
take some time.
-Currently all libraries here are implemented against the [2023.12
-version](https://data-apis.org/array-api/2023.12/) of the standard.
+Currently all libraries here are implemented against the [2024.12
+version](https://data-apis.org/array-api/2024.12/) of the standard.
## Installation
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index dbec7740..00000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-furo
-linkify-it-py
-myst-parser
-sphinx
-sphinx-copybutton
-sphinx-autobuild
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index 26a1c1c5..96452c13 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -36,23 +36,16 @@ deviations from the standard should be noted:
50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and
https://github.com/numpy/numpy/issues/22341)
-- `asarray()` does not support `copy=False`.
-
- Functions which are not wrapped may not have the same type annotations
as the spec.
- Functions which are not wrapped may not use positional-only arguments.
-The minimum supported NumPy version is 1.21. However, this older version of
+The minimum supported NumPy version is 1.22. However, this older version of
NumPy has a few issues:
- `unique_*` will not compare nans as unequal.
-- `finfo()` has no `smallest_normal`.
- No `from_dlpack` or `__dlpack__`.
-- `argmax()` and `argmin()` do not have `keepdims`.
-- `qr()` doesn't support matrix stacks.
-- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not
- supported even in the latest NumPy).
- Type promotion behavior will be value based for 0-D arrays (and there is no
`NPY_PROMOTION_STATE=weak` to disable this).
@@ -72,8 +65,8 @@ version.
attribute in the spec. Use the {func}`~.size()` helper function as a
portable workaround.
-- PyTorch does not have unsigned integer types other than `uint8`, and no
- attempt is made to implement them here.
+- PyTorch has incomplete support for unsigned integer types other
+ than `uint8`, and no attempt is made to implement them here.
- PyTorch has type promotion semantics that differ from the array API
specification for 0-D tensor objects. The array functions in this wrapper
@@ -100,8 +93,6 @@ version.
- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.
-The minimum supported PyTorch version is 1.13.
-
(jax-support)=
## [JAX](https://jax.readthedocs.io/en/latest/)
@@ -131,13 +122,17 @@ For `linalg`, several methods are missing, for example:
- `matrix_rank`
Other methods may only be partially implemented or return incorrect results at times.
-The minimum supported Dask version is 2023.12.0.
-
(sparse-support)=
## [Sparse](https://sparse.pydata.org/en/stable/)
Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
+(ndonnx-support)=
+## [ndonnx](https://github.com/quantco/ndonnx)
+
+Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.
+
+(paddle-support)=
## [Paddle](https://www.paddlepaddle.org.cn/)
- Like NumPy/CuPy, we do not wrap the `paddle.Tensor` object. It is missing the
@@ -158,4 +153,9 @@ Similar to JAX, `sparse` Array API support is contained directly in `sparse`.
- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.
-The minimum supported PyTorch version is 3.0.0.
+The minimum supported Paddle version is `3.2.0`.
+
+(array-api-strict-support)=
+## [array-api-strict](https://data-apis.org/array-api-strict/)
+
+array-api-strict exists only to test support for the Array API, so it does not need any wrappers.
diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt
deleted file mode 100644
index 459b33e3..00000000
--- a/numpy-1-21-xfails.txt
+++ /dev/null
@@ -1,260 +0,0 @@
-# asarray(copy=False) is not yet implemented
-array_api_tests/test_creation_functions.py::test_asarray_arrays
-
-# https://github.com/data-apis/array-api-tests/issues/195
-array_api_tests/test_creation_functions.py::test_linspace
-
-# finfo(float32).eps returns float32 but should return float
-array_api_tests/test_data_type_functions.py::test_finfo[float32]
-
-# Array methods and attributes not already on np.ndarray cannot be wrapped
-array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
-array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
-array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
-array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
-
-# linalg tests require https://github.com/data-apis/array-api-tests/pull/101
-# cleanups. Also some tests are using .mT
-array_api_tests/test_linalg.py::test_eigvalsh
-array_api_tests/test_linalg.py::test_solve
-array_api_tests/test_linalg.py::test_trace
-
-# Array methods and attributes not already on np.ndarray cannot be wrapped
-array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
-array_api_tests/test_signatures.py::test_array_method_signature[to_device]
-
-# NumPy deviates in some special cases for floordiv
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
-array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-
-# https://github.com/numpy/numpy/issues/21213
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
-
-# NumPy 1.21 specific XFAILS
-############################
-
-# finfo has no smallest_normal
-array_api_tests/test_data_type_functions.py::test_finfo[float64]
-
-# dlpack stuff
-array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
-array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
-array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
-array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
-
-# qr() doesn't support matrix stacks
-array_api_tests/test_linalg.py::test_qr
-
-# cross has some promotion bug that is fixed in newer numpy versions
-array_api_tests/test_linalg.py::test_cross
-
-# vector_norm with ord=-1 which has since been fixed
-# https://github.com/numpy/numpy/issues/21083
-array_api_tests/test_linalg.py::test_vector_norm
-
-# argmax and argmin do not support keepdims
-array_api_tests/test_searching_functions.py::test_argmax
-array_api_tests/test_searching_functions.py::test_argmin
-array_api_tests/test_signatures.py::test_func_signature[argmax]
-array_api_tests/test_signatures.py::test_func_signature[argmin]
-
-# unique doesn't support comparing nans as unequal
-array_api_tests/test_set_functions.py::test_unique_all
-array_api_tests/test_set_functions.py::test_unique_counts
-array_api_tests/test_set_functions.py::test_unique_inverse
-array_api_tests/test_set_functions.py::test_unique_values
-
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
-
-# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with
-# type promotion issues
-array_api_tests/test_manipulation_functions.py::test_concat
-array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
-array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
-array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
-array_api_tests/test_operators_and_elementwise_functions.py::test_maximum
-array_api_tests/test_operators_and_elementwise_functions.py::test_minimum
-array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
-array_api_tests/test_searching_functions.py::test_where
-array_api_tests/test_special_cases.py::test_binary[__add__((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is +infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is -infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is NaN and not x2_i == 0) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is +0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is +0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[add((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2]
-array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2]
-array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2]
-array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is +0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is -0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is +0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is -0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i > 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) < 1 and x2_i is +infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(abs(x1_i) > 1 and x2_i is -infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +0 and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +infinity and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is +infinity and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[pow(x1_i is NaN and not x2_i == 0) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0]
-
-# 2023.12 support
-array_api_tests/test_searching_functions.py::test_searchsorted
-array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_func_signature[astype]
-array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-# uint64 repeats not supported
-array_api_tests/test_manipulation_functions.py::test_repeat
diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt
new file mode 100644
index 00000000..5df1b6d7
--- /dev/null
+++ b/numpy-1-22-xfails.txt
@@ -0,0 +1,175 @@
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
+array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
+
+# Array methods and attributes not already on np.ndarray cannot be wrapped
+array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
+array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
+
+# Array methods and attributes not already on np.ndarray cannot be wrapped
+array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
+array_api_tests/test_signatures.py::test_array_method_signature[to_device]
+
+# NumPy deviates in some special cases for floordiv
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+
+# https://github.com/numpy/numpy/issues/21213
+array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
+
+# NumPy 1.22 specific XFAILS
+############################
+
+# cross has some promotion bug that is fixed in newer numpy versions
+array_api_tests/test_linalg.py::test_cross
+
+# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0.
+# Fixed in newer numpy versions.
+array_api_tests/test_creation_functions.py::test_linspace
+
+# vector_norm with ord=-1 which has since been fixed
+# https://github.com/numpy/numpy/issues/21083
+array_api_tests/test_linalg.py::test_vector_norm
+
+# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with
+# type promotion issues
+# NOTE: some of these may not fail until one runs array-api-tests with
+# --max-examples 100000
+array_api_tests/test_manipulation_functions.py::test_concat
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
+array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
+array_api_tests/test_operators_and_elementwise_functions.py::test_maximum
+array_api_tests/test_operators_and_elementwise_functions.py::test_minimum
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
+array_api_tests/test_searching_functions.py::test_where
+array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
+
+# 2023.12 support
+array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
+array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
+array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
+# uint64 repeats not supported
+array_api_tests/test_manipulation_functions.py::test_repeat
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt
index 57b80e7e..98cb9f6c 100644
--- a/numpy-1-26-xfails.txt
+++ b/numpy-1-26-xfails.txt
@@ -1,8 +1,7 @@
-# asarray(copy=False) is not yet implemented
-array_api_tests/test_creation_functions.py::test_asarray_arrays
-
-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
# Array methods and attributes not already on np.ndarray cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
@@ -35,21 +34,40 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
# https://github.com/numpy/numpy/issues/21213
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
-
# 2023.12 support
-array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt
index 23a83e1e..972d2346 100644
--- a/numpy-dev-xfails.txt
+++ b/numpy-dev-xfails.txt
@@ -1,17 +1,7 @@
-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
-
-# https://github.com/numpy/numpy/issues/21213
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
-array_api_tests/test_statistical_functions.py::test_cumulative_sum
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
# The test suite cannot properly get the signature for vecdot
# https://github.com/numpy/numpy/pull/26237
@@ -19,9 +9,32 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
# 2023.12 support
-# Argument 'device' missing from signature
-array_api_tests/test_signatures.py::test_func_signature[astype]
-array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
# uint64 repeats not supported
array_api_tests/test_manipulation_functions.py::test_repeat
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
diff --git a/numpy-skips.txt b/numpy-skips.txt
index cbf7235b..e69de29b 100644
--- a/numpy-skips.txt
+++ b/numpy-skips.txt
@@ -1,11 +0,0 @@
-# These tests cause a core dump on CI, so we have to skip them entirely
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
diff --git a/numpy-xfails.txt b/numpy-xfails.txt
index 1c9d98f6..632b4ec3 100644
--- a/numpy-xfails.txt
+++ b/numpy-xfails.txt
@@ -1,7 +1,25 @@
-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]
-# NumPy deviates in some special cases for floordiv
+# The test suite cannot properly get the signature for vecdot
+# https://github.com/numpy/numpy/pull/26237
+array_api_tests/test_signatures.py::test_func_signature[vecdot]
+array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
+
+# 2023.12 support
+# uint64 repeats not supported
+array_api_tests/test_manipulation_functions.py::test_repeat
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
@@ -21,27 +39,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-# https://github.com/numpy/numpy/issues/21213
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
-
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
-
-# The test suite cannot properly get the signature for vecdot
-# https://github.com/numpy/numpy/pull/26237
-array_api_tests/test_signatures.py::test_func_signature[vecdot]
-array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
-
-# 2023.12 support
-array_api_tests/test_searching_functions.py::test_searchsorted
-array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
-array_api_tests/test_signatures.py::test_func_signature[astype]
-array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-# uint64 repeats not supported
-array_api_tests/test_manipulation_functions.py::test_repeat
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..ec054417
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,120 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "array-api-compat"
+dynamic = ["version"]
+description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard"
+readme = "README.md"
+requires-python = ">=3.10"
+license = "MIT"
+authors = [{name = "Consortium for Python Data API Standards"}]
+classifiers = [
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "Typing :: Typed",
+]
+
+[project.optional-dependencies]
+cupy = ["cupy"]
+dask = ["dask>=2024.9.0"]
+jax = ["jax"]
+# Note: array-api-compat follows scikit-learn's minimum dependencies, which
+# support much older versions of NumPy than SPEC0 recommends.
+numpy = ["numpy>=1.22"]
+pytorch = ["torch"]
+sparse = ["sparse>=0.15.1"]
+ndonnx = ["ndonnx"]
+docs = [
+ "furo",
+ "linkify-it-py",
+ "myst-parser",
+ "sphinx",
+ "sphinx-copybutton",
+ "sphinx-autobuild",
+]
+dev = [
+ "array-api-strict",
+ "dask[array]>=2024.9.0",
+ "jax[cpu]",
+ "ndonnx",
+ "numpy>=1.22",
+ "pytest",
+ "torch",
+ "sparse>=0.15.1",
+]
+
+[project.urls]
+homepage = "https://data-apis.org/array-api-compat/"
+repository = "https://github.com/data-apis/array-api-compat/"
+
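+# Version is read dynamically from array_api_compat.__version__ at build time.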
+[tool.setuptools.dynamic]
+version = {attr = "array_api_compat.__version__"}
+
+[tool.setuptools.packages.find]
+include = ["array_api_compat*"]
+namespaces = false
+
+[tool.ruff.lint]
+preview = true
+select = [
+    # Defaults
+    "E4", "E7", "E9", "F",
+    # Undefined export
+    "F822",
+    # Useless import alias
+    "PLC0414",
+]
+
+ignore = [
+ # Module import not at top of file
+ "E402",
+ # Do not use bare `except`
+ "E722"
+]
+
+
+[tool.mypy]
+files = ["array_api_compat"]
+disallow_incomplete_defs = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = false # TODO
+ignore_missing_imports = false
+no_implicit_optional = true
+show_error_codes = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[[tool.mypy.overrides]]
+module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"]
+ignore_missing_imports = true
+
+
+[tool.pyright]
+include = ["src", "tests"]
+pythonPlatform = "All"
+
+reportAny = false
+reportExplicitAny = false
+# missing type stubs
+reportAttributeAccessIssue = false
+reportUnknownMemberType = false
+reportUnknownVariableType = false
+# Redundant with mypy checks
+reportMissingImports = false
+reportMissingTypeStubs = false
+# false positives for input validation
+reportUnreachable = false
+# ruff handles this
+reportUnusedParameter = false
+
+executionEnvironments = [
+ { root = "array_api_compat" },
+]
diff --git a/ruff.toml b/ruff.toml
deleted file mode 100644
index 72e111b5..00000000
--- a/ruff.toml
+++ /dev/null
@@ -1,17 +0,0 @@
-[lint]
-preview = true
-select = [
-# Defaults
-"E4", "E7", "E9", "F",
-# Undefined export
-"F822",
-# Useless import alias
-"PLC0414"
-]
-
-ignore = [
- # Module import not at top of file
- "E402",
- # Do not use bare `except`
- "E722"
-]
diff --git a/setup.py b/setup.py
deleted file mode 100644
index d0a28404..00000000
--- a/setup.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from setuptools import setup, find_packages
-
-with open("README.md", "r") as fh:
- long_description = fh.read()
-
-import array_api_compat
-
-setup(
- name='array_api_compat',
- version=array_api_compat.__version__,
- packages=find_packages(include=["array_api_compat*"]),
- author="Consortium for Python Data API Standards",
- description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard",
- long_description=long_description,
- long_description_content_type="text/markdown",
- url="https://data-apis.org/array-api-compat/",
- license="MIT",
- extras_require={
- "numpy": "numpy",
- "cupy": "cupy",
- "jax": "jax",
- "pytorch": "pytorch",
- "dask": "dask",
- "sparse": "sparse >=0.15.1",
- },
- python_requires=">=3.9",
- classifiers=[
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
- "Programming Language :: Python :: 3.13",
- "License :: OSI Approved :: MIT License",
- "Operating System :: OS Independent",
- ],
-)
diff --git a/test_cupy.sh b/test_cupy.sh
index 2e176aa1..a6974333 100755
--- a/test_cupy.sh
+++ b/test_cupy.sh
@@ -26,5 +26,5 @@ mkdir -p $SCRIPT_DIR/.hypothesis
ln -s $SCRIPT_DIR/.hypothesis .hypothesis
export ARRAY_API_TESTS_MODULE=array_api_compat.cupy
-export ARRAY_API_TESTS_VERSION=2023.12
+export ARRAY_API_TESTS_VERSION=2024.12
pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"
diff --git a/tests/_helpers.py b/tests/_helpers.py
index 801cd32d..6826bd36 100644
--- a/tests/_helpers.py
+++ b/tests/_helpers.py
@@ -1,18 +1,14 @@
from importlib import import_module
-import sys
import pytest
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
-all_libraries = wrapped_libraries + ["jax.numpy"]
-
-# `sparse` added array API support as of Python 3.10.
-if sys.version_info >= (3, 10):
- all_libraries.append('sparse')
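+# Libraries listed below, unlike wrapped_libraries, are expected to provide an
+# Array API compliant namespace on their own, without an array-api-compat wrapper.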
+all_libraries = wrapped_libraries + [
+ "array_api_strict", "jax.numpy", "ndonnx", "sparse"
+]
def import_(library, wrapper=False):
- if library == 'cupy':
- pytest.importorskip(library)
+ pytest.importorskip(library)
if wrapper:
if 'jax' in library:
# JAX v0.4.32 implements the array API directly in jax.numpy
@@ -20,9 +16,7 @@ def import_(library, wrapper=False):
jax_numpy = import_module("jax.numpy")
if not hasattr(jax_numpy, "__array_api_version__"):
library = 'jax.experimental.array_api'
- elif library.startswith('sparse'):
- library = 'sparse'
- else:
+ elif library in wrapped_libraries:
library = 'array_api_compat.' + library
if library == 'paddle':
@@ -31,3 +25,14 @@ def import_(library, wrapper=False):
return xp
return import_module(library)
+
+
+def xfail(request: pytest.FixtureRequest, reason: str) -> None:
+ """
+ XFAIL the currently running test.
+
+    Unlike ``pytest.xfail``, allow the rest of the test to execute instead of
+    halting it immediately, so that it may result in an XPASS.
+ xref https://github.com/pandas-dev/pandas/issues/38902
+ """
+ request.node.add_marker(pytest.mark.xfail(reason=reason))
diff --git a/tests/test_all.py b/tests/test_all.py
index 969d5cfb..c36aef67 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -1,44 +1,311 @@
-"""
-Test that files that define __all__ aren't missing any exports.
+"""Test exported names"""
-You can add names that shouldn't be exported to _all_ignore, like
+import builtins
-_all_ignore = ['sys']
+import numpy as np
+import pytest
-This is preferable to del-ing the names as this will break any name that is
-used inside of a function. Note that names starting with an underscore are automatically ignored.
-"""
+from array_api_compat._internal import clone_module
+from ._helpers import wrapped_libraries
-import sys
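+# Names required by the Array API Standard, keyed by submodule
+# ("" denotes the top-level namespace).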
+NAMES = {
+ "": [
+ # Inspection
+ "__array_api_version__",
+ "__array_namespace_info__",
+ # Submodules
+ "fft",
+ "linalg",
+ # Constants
+ "e",
+ "inf",
+ "nan",
+ "newaxis",
+ "pi",
+ # Creation Functions
+ "arange",
+ "asarray",
+ "empty",
+ "empty_like",
+ "eye",
+ "from_dlpack",
+ "full",
+ "full_like",
+ "linspace",
+ "meshgrid",
+ "ones",
+ "ones_like",
+ "tril",
+ "triu",
+ "zeros",
+ "zeros_like",
+ # Data Type Functions
+ "astype",
+ "can_cast",
+ "finfo",
+ "iinfo",
+ "isdtype",
+ "result_type",
+ # Data Types
+ "bool",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ # Elementwise Functions
+ "abs",
+ "acos",
+ "acosh",
+ "add",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bitwise_and",
+ "bitwise_left_shift",
+ "bitwise_invert",
+ "bitwise_or",
+ "bitwise_right_shift",
+ "bitwise_xor",
+ "ceil",
+ "clip",
+ "conj",
+ "copysign",
+ "cos",
+ "cosh",
+ "divide",
+ "equal",
+ "exp",
+ "expm1",
+ "floor",
+ "floor_divide",
+ "greater",
+ "greater_equal",
+ "hypot",
+ "imag",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "less",
+ "less_equal",
+ "log",
+ "log1p",
+ "log2",
+ "log10",
+ "logaddexp",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "maximum",
+ "minimum",
+ "multiply",
+ "negative",
+ "nextafter",
+ "not_equal",
+ "positive",
+ "pow",
+ "real",
+ "reciprocal",
+ "remainder",
+ "round",
+ "sign",
+ "signbit",
+ "sin",
+ "sinh",
+ "square",
+ "sqrt",
+ "subtract",
+ "tan",
+ "tanh",
+ "trunc",
+ # Indexing Functions
+ "take",
+ "take_along_axis",
+ # Linear Algebra Functions
+ "matmul",
+ "matrix_transpose",
+ "tensordot",
+ "vecdot",
+ # Manipulation Functions
+ "broadcast_arrays",
+ "broadcast_to",
+ "concat",
+ "expand_dims",
+ "flip",
+ "moveaxis",
+ "permute_dims",
+ "repeat",
+ "reshape",
+ "roll",
+ "squeeze",
+ "stack",
+ "tile",
+ "unstack",
+ # Searching Functions
+ "argmax",
+ "argmin",
+ "count_nonzero",
+ "nonzero",
+ "searchsorted",
+ "where",
+ # Set Functions
+ "unique_all",
+ "unique_counts",
+ "unique_inverse",
+ "unique_values",
+ # Sorting Functions
+ "argsort",
+ "sort",
+ # Statistical Functions
+ "cumulative_prod",
+ "cumulative_sum",
+ "max",
+ "mean",
+ "min",
+ "prod",
+ "std",
+ "sum",
+ "var",
+ # Utility Functions
+ "all",
+ "any",
+ "diff",
+ ],
+ "fft": [
+ "fft",
+ "ifft",
+ "fftn",
+ "ifftn",
+ "rfft",
+ "irfft",
+ "rfftn",
+ "irfftn",
+ "hfft",
+ "ihfft",
+ "fftfreq",
+ "rfftfreq",
+ "fftshift",
+ "ifftshift",
+ ],
+ "linalg": [
+ "cholesky",
+ "cross",
+ "det",
+ "diagonal",
+ "eigh",
+ "eigvalsh",
+ "inv",
+ "matmul",
+ "matrix_norm",
+ "matrix_power",
+ "matrix_rank",
+ "matrix_transpose",
+ "outer",
+ "pinv",
+ "qr",
+ "slogdet",
+ "solve",
+ "svd",
+ "svdvals",
+ "tensordot",
+ "trace",
+ "vecdot",
+ "vector_norm",
+ ],
+}
-from ._helpers import import_, wrapped_libraries
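+# Standard names known to be missing from a given (library, submodule) pair.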
+XFAILS = {
+ ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [],
+ ("dask.array", ""): ["from_dlpack", "take_along_axis"],
+ ("dask.array", "linalg"): [
+ "cross",
+ "det",
+ "eigh",
+ "eigvalsh",
+ "matrix_power",
+ "pinv",
+ "slogdet",
+ ],
+}
-import pytest
-@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
-def test_all(library):
- import_(library, wrapper=True)
+def all_names(mod):
+ """Return all names available in a module."""
+ objs = {}
+ clone_module(mod.__name__, objs)
+ return set(objs)
+
+
+def get_mod(library, module, *, compat):
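+    """Import `library`, optionally through array-api-compat, and return the requested submodule (or the namespace itself)."""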
+ if compat:
+ library = f"array_api_compat.{library}"
+ xp = pytest.importorskip(library)
+ return getattr(xp, module) if module else xp
+
+
+@pytest.mark.parametrize("module", list(NAMES))
+@pytest.mark.parametrize("library", wrapped_libraries)
+def test_array_api_names(library, module):
+ """Test that __all__ isn't missing any exports
+ dictated by the Standard.
+ """
+ mod = get_mod(library, module, compat=True)
+ missing = set(NAMES[module]) - all_names(mod)
+ xfail = set(XFAILS.get((library, module), []))
+ xpass = xfail - missing
+ fails = missing - xfail
+ assert not xpass, f"Names in XFAILS are defined: {xpass}"
+ assert not fails, f"Missing exports: {fails}"
+
+
+@pytest.mark.parametrize("module", list(NAMES))
+@pytest.mark.parametrize("library", wrapped_libraries)
+def test_compat_doesnt_hide_names(library, module):
+ """The base namespace can have more names than the ones explicitly exported
+ by array-api-compat. Test that we're not suppressing them.
+ """
+ bare_mod = get_mod(library, module, compat=False)
+ compat_mod = get_mod(library, module, compat=True)
+
+ missing = all_names(bare_mod) - all_names(compat_mod)
+ missing = {name for name in missing if not name.startswith("_")}
+ assert not missing, f"Non-Array API names have been hidden: {missing}"
- for mod_name in sys.modules:
- if not mod_name.startswith('array_api_compat.' + library):
- continue
- module = sys.modules[mod_name]
+@pytest.mark.parametrize("module", list(NAMES))
+@pytest.mark.parametrize("library", wrapped_libraries)
+def test_compat_doesnt_add_names(library, module):
+ """Test that array-api-compat isn't adding names to the namespace
+ besides those defined by the Array API Standard.
+ """
+ bare_mod = get_mod(library, module, compat=False)
+ compat_mod = get_mod(library, module, compat=True)
- # TODO: We should define __all__ in the __init__.py files and test it
- # there too.
- if not hasattr(module, '__all__'):
- continue
+ aapi_names = set(NAMES[module])
+ spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names
+ # Quietly ignore *Result dataclasses
+ spurious = {name for name in spurious if not name.endswith("Result")}
+ assert not spurious, (
+ f"array-api-compat is adding non-Array API names: {spurious}"
+ )
- dir_names = [n for n in dir(module) if not n.startswith('_')]
- if '__array_namespace_info__' in dir(module):
- dir_names.append('__array_namespace_info__')
- ignore_all_names = getattr(module, '_all_ignore', [])
- ignore_all_names += ['annotations', 'TYPE_CHECKING']
- dir_names = set(dir_names) - set(ignore_all_names)
- all_names = module.__all__
- if set(dir_names) != set(all_names):
- assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
- assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
+@pytest.mark.parametrize(
+ "name", [name for name in NAMES[""] if hasattr(builtins, name)]
+)
+@pytest.mark.parametrize("library", wrapped_libraries)
+def test_builtins_collision(library, name):
+ """Test that xp.bool is not accidentally builtins.bool, etc."""
+ xp = pytest.importorskip(f"array_api_compat.{library}")
+ assert getattr(xp, name) is not getattr(builtins, name)
diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py
index 4076c74c..311efc37 100644
--- a/tests/test_array_namespace.py
+++ b/tests/test_array_namespace.py
@@ -2,68 +2,72 @@
import sys
import warnings
-import jax
import numpy as np
import pytest
-import torch
-import paddle
import array_api_compat
from array_api_compat import array_namespace
-from ._helpers import import_, all_libraries, wrapped_libraries
+from ._helpers import all_libraries, wrapped_libraries, xfail
+
@pytest.mark.parametrize("use_compat", [True, False, None])
-@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
-@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
-def test_array_namespace(library, api_version, use_compat):
- xp = import_(library)
+@pytest.mark.parametrize(
+ "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"]
+)
+@pytest.mark.parametrize("library", all_libraries)
+def test_array_namespace(request, library, api_version, use_compat):
+ xp = pytest.importorskip(library)
array = xp.asarray([1.0, 2.0, 3.0])
- if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
+ if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
- namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
+ if (library == "sparse" and api_version in ("2023.12", "2024.12")) or (
+ library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12")
+ ):
+ xfail(request, "Unsupported API version")
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', UserWarning)
+ namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
if use_compat is False or use_compat is None and library not in wrapped_libraries:
- if library == "jax.numpy" and use_compat is None:
- import jax.numpy
- if hasattr(jax.numpy, "__array_api_version__"):
- # JAX v0.4.32 or later uses jax.numpy directly
- assert namespace == jax.numpy
- else:
- # JAX v0.4.31 or earlier uses jax.experimental.array_api
- import jax.experimental.array_api
- assert namespace == jax.experimental.array_api
+ if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"):
+ # Backwards compatibility for JAX <0.4.32
+ import jax.experimental.array_api
+ assert namespace == jax.experimental.array_api
else:
assert namespace == xp
+ elif library == "dask.array":
+ assert namespace == array_api_compat.dask.array
else:
- if library == "dask.array":
- assert namespace == array_api_compat.dask.array
- else:
- assert namespace == getattr(array_api_compat, library)
+ assert namespace == getattr(array_api_compat, library)
if library == "numpy":
# check that the same namespace is returned for NumPy scalars
- scalar_namespace = array_api_compat.array_namespace(
- xp.float64(0.0), api_version=api_version, use_compat=use_compat
- )
- assert scalar_namespace == namespace
-
- # Check that array_namespace works even if jax.experimental.array_api
- # hasn't been imported yet (it monkeypatches __array_namespace__
- # onto JAX arrays, but we should support them regardless). The only way to
- # do this is to use a subprocess, since we cannot un-import it and another
- # test probably already imported it.
- if library == "jax.numpy" and sys.version_info >= (3, 9):
- code = f"""\
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', UserWarning)
+
+ scalar_namespace = array_namespace(
+ xp.float64(0.0), api_version=api_version, use_compat=use_compat
+ )
+ assert scalar_namespace == namespace
+
+
+def test_jax_backwards_compat():
+ """On JAX <0.4.32, test that array_namespace works even if
+ jax.experimental.array_api has not been imported yet.
+ """
+ pytest.importorskip("jax")
+ code = """\
import sys
import jax.numpy
import array_api_compat
-array = jax.numpy.asarray([1.0, 2.0, 3.0])
+array = jax.numpy.asarray([1.0, 2.0, 3.0])
assert 'jax.experimental.array_api' not in sys.modules
-namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
+namespace = array_api_compat.array_namespace(array)
if hasattr(jax.numpy, '__array_api_version__'):
assert namespace == jax.numpy
@@ -71,13 +75,15 @@ def test_array_namespace(library, api_version, use_compat):
import jax.experimental.array_api
assert namespace == jax.experimental.array_api
"""
- subprocess.run([sys.executable, "-c", code], check=True)
+ subprocess.check_call([sys.executable, "-c", code])
+
def test_jax_zero_gradient():
+ jax = pytest.importorskip("jax")
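+    # Gradients of integer inputs use JAX's special float0 dtype; check that
+    # array_namespace() still recognises them as JAX arrays.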
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
- assert (array_api_compat.get_namespace(jax_zero) is
- array_api_compat.get_namespace(jx))
+ assert array_namespace(jax_zero) is array_namespace(jx)
+
def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -87,60 +93,31 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
-def test_array_namespace_errors_torch():
- y = torch.asarray([1, 2])
- x = np.asarray([1, 2])
- pytest.raises(TypeError, lambda: array_namespace(x, y))
+@pytest.mark.parametrize("library", all_libraries)
+def test_array_namespace_many_args(library):
+ xp = pytest.importorskip(library)
+ a = xp.asarray(1)
+ b = xp.asarray(2)
+ assert array_namespace(a, b) is array_namespace(a)
-def test_array_namespace_errors_paddle():
- y = paddle.to_tensor([1, 2])
- x = np.asarray([1, 2])
- pytest.raises(TypeError, lambda: array_namespace(x, y))
-
-def test_api_version():
- x = torch.asarray([1, 2])
- torch_ = import_("torch", wrapper=True)
- assert array_namespace(x, api_version="2023.12") == torch_
- assert array_namespace(x, api_version=None) == torch_
- assert array_namespace(x) == torch_
- # Should issue a warning
- with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2021.12") == torch_
- assert len(w) == 1
- assert "2021.12" in str(w[0].message)
-
- # Should issue a warning
- with warnings.catch_warnings(record=True) as w:
- assert array_namespace(x, api_version="2022.12") == torch_
- assert len(w) == 1
- assert "2022.12" in str(w[0].message)
-
- pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
-def test_get_namespace():
- # Backwards compatible wrapper
- assert array_api_compat.get_namespace is array_api_compat.array_namespace
+def test_array_namespace_mismatch():
+ xp = pytest.importorskip("array_api_strict")
+ with pytest.raises(TypeError, match="Multiple namespaces"):
+ array_namespace(np.asarray(1), xp.asarray(1))
-def test_python_scalars_torch():
- a = torch.asarray([1, 2])
- xp = import_("torch", wrapper=True)
- pytest.raises(TypeError, lambda: array_namespace(1))
- pytest.raises(TypeError, lambda: array_namespace(1.0))
- pytest.raises(TypeError, lambda: array_namespace(1j))
- pytest.raises(TypeError, lambda: array_namespace(True))
- pytest.raises(TypeError, lambda: array_namespace(None))
+def test_get_namespace():
+ # Backwards compatible wrapper
+ assert array_api_compat.get_namespace is array_namespace
- assert array_namespace(a, 1) == xp
- assert array_namespace(a, 1.0) == xp
- assert array_namespace(a, 1j) == xp
- assert array_namespace(a, True) == xp
- assert array_namespace(a, None) == xp
-def test_python_scalars_paddle():
- a = paddle.to_tensor([1, 2])
- xp = import_("paddle", wrapper=True)
+@pytest.mark.parametrize("library", all_libraries)
+def test_python_scalars(library):
+ xp = pytest.importorskip(library)
+ a = xp.asarray([1, 2])
+ xp = array_namespace(a)
pytest.raises(TypeError, lambda: array_namespace(1))
pytest.raises(TypeError, lambda: array_namespace(1.0))
@@ -148,8 +125,8 @@ def test_python_scalars_paddle():
pytest.raises(TypeError, lambda: array_namespace(True))
pytest.raises(TypeError, lambda: array_namespace(None))
- assert array_namespace(a, 1) == xp
- assert array_namespace(a, 1.0) == xp
- assert array_namespace(a, 1j) == xp
- assert array_namespace(a, True) == xp
- assert array_namespace(a, None) == xp
+ assert array_namespace(a, 1) is xp
+ assert array_namespace(a, 1.0) is xp
+ assert array_namespace(a, 1j) is xp
+ assert array_namespace(a, True) is xp
+ assert array_namespace(a, None) is xp
diff --git a/tests/test_common.py b/tests/test_common.py
index 23ac53d1..575127a0 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,18 +1,25 @@
+import math
+
+import pytest
+import numpy as np
+import array
+from numpy.testing import assert_equal
+
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array, is_paddle_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
+ is_ndonnx_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
- is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_paddle_namespace,
+ is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
+ is_array_api_strict_namespace, is_ndonnx_namespace, is_paddle_namespace
)
-from array_api_compat import is_array_api_obj, device, to_device
-
-from ._helpers import import_, wrapped_libraries, all_libraries
+from array_api_compat import (
+ device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
+)
+from array_api_compat.common._helpers import _DASK_DEVICE
+from ._helpers import all_libraries, import_, wrapped_libraries, xfail
-import pytest
-import numpy as np
-import array
-from numpy.testing import assert_allclose
is_array_functions = {
'numpy': 'is_numpy_array',
@@ -21,6 +28,7 @@
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
'sparse': 'is_pydata_sparse_array',
+ 'ndonnx': 'is_ndonnx_array',
'paddle': 'is_paddle_array',
}
@@ -31,6 +39,8 @@
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
+ 'array_api_strict': 'is_array_api_strict_namespace',
+ 'ndonnx': 'is_ndonnx_namespace',
'paddle': 'is_paddle_namespace',
}
@@ -57,18 +67,154 @@ def test_is_xp_namespace(library, func):
assert is_func(lib) == (func == is_namespace_functions[library])
+@pytest.mark.parametrize('library', all_libraries)
+def test_xp_is_array_generics(library):
+ """
+    Test that selecting a scalar from an xp.ndarray always returns
+    an object that matches exactly one of the is_*_array functions:
+    the one for the same library, or is_numpy_array.
+ """
+ lib = import_(library)
+ x = lib.asarray([1, 2, 3])
+ x0 = x[0]
+
+ matches = []
+ for library2, func in is_array_functions.items():
+ is_func = globals()[func]
+ if is_func(x0):
+ matches.append(library2)
+
+ if library == "array_api_strict":
+ # There is no is_array_api_strict_array() function
+ assert matches == []
+ else:
+ assert matches in ([library], ["numpy"])
+
+
@pytest.mark.parametrize("library", all_libraries)
-def test_device(library):
- xp = import_(library, wrapper=True)
+def test_is_writeable_array(library):
+ lib = import_(library)
+ x = lib.asarray([1, 2, 3])
+ if is_writeable_array(x):
+ x[1] = 4
+ else:
+ with pytest.raises((TypeError, ValueError)):
+ x[1] = 4
+
+
+def test_is_writeable_array_numpy():
+ x = np.asarray([1, 2, 3])
+ assert is_writeable_array(x)
+ x.flags.writeable = False
+ assert not is_writeable_array(x)
+
+
+@pytest.mark.parametrize("library", all_libraries)
+def test_size(library):
+ xp = import_(library)
+ x = xp.asarray([1, 2, 3])
+ assert size(x) == 3
+
+
+@pytest.mark.parametrize("library", all_libraries)
+def test_size_none(library):
+ if library == "sparse":
+ pytest.skip("No arange(); no indexing by sparse arrays")
+
+ xp = import_(library)
+ x = xp.arange(10)
+ x = x[x < 5]
+
+ # dask.array now has shape=(nan, ) and size=nan
+ # ndonnx now has shape=(None, ) and size=None
+ # Eager libraries have shape=(5, ) and size=5
+ assert size(x) in (None, 5)
- # We can't test much for device() and to_device() other than that
- # x.to_device(x.device) works.
+@pytest.mark.parametrize("library", all_libraries)
+def test_is_lazy_array(library):
+ lib = import_(library)
+ x = lib.asarray([1, 2, 3])
+ assert isinstance(is_lazy_array(x), bool)
+
+
+@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
+def test_is_lazy_array_nan_size(shape, monkeypatch):
+ """Test is_lazy_array() on an unknown Array API compliant object
+ with NaN (like Dask) or None (like ndonnx) in its shape
+ """
+ xp = import_("array_api_strict")
+ x = xp.asarray(1)
+ assert not is_lazy_array(x)
+ monkeypatch.setattr(type(x), "shape", shape)
+ assert is_lazy_array(x)
+
+
+@pytest.mark.parametrize("exc", [TypeError, AssertionError])
+def test_is_lazy_array_bool_raises(exc, monkeypatch):
+ """Test is_lazy_array() on an unknown Array API compliant object
+ where calling bool() raises:
+    - TypeError: e.g. jitted JAX. This is the proper exception that lazy
+      arrays should raise as per the Array API specification
+    - something else: e.g. Dask, where bool() triggers compute(),
+      which can raise any kind of exception
+ """
+ xp = import_("array_api_strict")
+ x = xp.asarray(1)
+ assert not is_lazy_array(x)
+
+ def __bool__(self):
+ raise exc("Hello world")
+
+ monkeypatch.setattr(type(x), "__bool__", __bool__)
+ assert is_lazy_array(x)
+
+
+@pytest.mark.parametrize(
+ 'func',
+ list(is_array_functions.values())
+ + ["is_array_api_obj", "is_lazy_array", "is_writeable_array"]
+)
+def test_is_array_any_object(func):
+ """Test that is_*_array functions return False and don't raise on non-array objects
+ """
+ func = globals()[func]
+
+ # These objects are missing attributes such as __name__
+ assert not func(object())
+ assert not func(None)
+ assert not func(1)
+
+ class C:
+ pass
+
+ assert not func(C())
+
+
+@pytest.mark.parametrize("library", all_libraries)
+def test_device_to_device(library, request):
+ if library == "ndonnx":
+ xfail(request, reason="Stub raises ValueError")
+ if library == "sparse":
+ xfail(request, reason="No __array_namespace_info__()")
+ if library == "array_api_strict":
+ if np.__version__ < "2":
+ xfail(request, reason="no copy argument of np.asarray")
+
+ xp = import_(library, wrapper=True)
+ devices = xp.__array_namespace_info__().devices()
+
+ # Default device
x = xp.asarray([1, 2, 3])
dev = device(x)
- x2 = to_device(x, dev)
- assert device(x) == device(x2)
+ for dev in devices:
+ if dev is None: # JAX >=0.5.3
+ continue
+ if dev is _DASK_DEVICE: # TODO this needs a better design
+ continue
+ y = to_device(x, dev)
+ assert device(y) == dev
@pytest.mark.parametrize("library", wrapped_libraries)
@@ -87,20 +233,30 @@ def test_to_device_host(library):
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
- assert_allclose(x, expected)
+ assert_equal(x, expected)
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
- # Allow rest of test to execute instead of immediately xfailing
- # xref https://github.com/pandas-dev/pandas/issues/38902
-
# TODO: remove xfail once
# https://github.com/dask/dask/issues/8260 is resolved
- request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
- if source_library == "cupy" and target_library != "cupy":
+ xfail(request, reason="Bug in dask raising error on conversion")
+
+ elif (
+ source_library == "ndonnx"
+ and target_library not in ("array_api_strict", "ndonnx", "numpy")
+ ):
+ xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
+ elif source_library == "ndonnx" and target_library == "numpy":
+ xfail(request, reason="produces numpy array of ndonnx scalar arrays")
+ elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"):
+ xfail(request, reason="unable to infer dtype")
+
+ elif source_library == "jax.numpy" and target_library == "torch":
+ xfail(request, reason="casts int to float")
+ elif source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
if source_library == "paddle" or target_library == "paddle":
@@ -112,14 +268,17 @@ def test_asarray_cross_library(source_library, target_library, request):
)
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
+
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_array_functions[target_library]]
- a = src_lib.asarray([1, 2, 3])
+ a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
b = tgt_lib.asarray(a)
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
+ assert b.dtype == tgt_lib.int32
+
@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
@@ -132,91 +291,99 @@ def test_asarray_copy(library):
xp = import_(library, wrapper=True)
asarray = xp.asarray
is_lib_func = globals()[is_array_functions[library]]
- all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
-
- if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
- supports_copy_false = False
- elif library in ['cupy', 'dask.array']:
- supports_copy_false = False
- else:
- supports_copy_false = True
a = asarray([1])
b = asarray(a, copy=True)
assert is_lib_func(b)
a[0] = 0
- assert all(b[0] == 1)
- assert all(a[0] == 0)
+ assert b[0] == 1
+ assert a[0] == 0
a = asarray([1])
- if supports_copy_false:
- b = asarray(a, copy=False)
- assert is_lib_func(b)
- a[0] = 0
- assert all(b[0] == 0)
- else:
- pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
- a = asarray([1])
- if supports_copy_false:
- pytest.raises(ValueError, lambda: asarray(a, copy=False,
- dtype=xp.float64))
- else:
- pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
+ # Test copy=False within the same namespace
+ b = asarray(a, copy=False)
+ assert is_lib_func(b)
+ a[0] = 0
+ assert b[0] == 0
+ with pytest.raises(ValueError):
+ asarray(a, copy=False, dtype=xp.float64)
+ # copy=None defaults to False when possible
a = asarray([1])
b = asarray(a, copy=None)
assert is_lib_func(b)
a[0] = 0
- assert all(b[0] == 0)
+ assert b[0] == 0
+ # copy=None defaults to True when impossible
a = asarray([1.0], dtype=xp.float32)
assert a.dtype == xp.float32
b = asarray(a, dtype=xp.float64, copy=None)
assert is_lib_func(b)
assert b.dtype == xp.float64
a[0] = 0.0
- assert all(b[0] == 1.0)
+ assert b[0] == 1.0
+ # copy=None defaults to False when possible
a = asarray([1.0], dtype=xp.float64)
assert a.dtype == xp.float64
b = asarray(a, dtype=xp.float64, copy=None)
assert is_lib_func(b)
assert b.dtype == xp.float64
a[0] = 0.0
- assert all(b[0] == 0.0)
+ assert b[0] == 0.0
# Python built-in types
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
- asarray(obj, copy=True) # No error
- asarray(obj, copy=None) # No error
- if supports_copy_false:
- pytest.raises(ValueError, lambda: asarray(obj, copy=False))
- else:
- pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
+ asarray(obj, copy=True) # No error
+ asarray(obj, copy=None) # No error
+
+ with pytest.raises(ValueError):
+ asarray(obj, copy=False)
# Use the standard library array to test the buffer protocol
- a = array.array('f', [1.0])
+ a = array.array("f", [1.0])
b = asarray(a, copy=True)
assert is_lib_func(b)
a[0] = 0.0
- assert all(b[0] == 1.0)
+ assert b[0] == 1.0
- a = array.array('f', [1.0])
- if supports_copy_false:
+ a = array.array("f", [1.0])
+ if library in ("cupy", "dask.array"):
+ with pytest.raises(ValueError):
+ asarray(a, copy=False)
+ else:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0.0
- assert all(b[0] == 0.0)
- else:
- pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
+ assert b[0] == 0.0
- a = array.array('f', [1.0])
+ a = array.array("f", [1.0])
b = asarray(a, copy=None)
assert is_lib_func(b)
a[0] = 0.0
- if library == 'cupy':
+ if library in ("cupy", "dask.array"):
# A copy is required for libraries where the default device is not CPU
- assert all(b[0] == 1.0)
+        # dask changed the behaviour of copy=None to mean copy=True in 2024.12;
+        # this wrapper ensures the same behaviour in older versions too.
+ # https://github.com/dask/dask/pull/11524/
+ assert b[0] == 1.0
else:
- assert all(b[0] == 0.0)
+ # copy=None defaults to False when possible
+ assert b[0] == 0.0
+
+
+@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])
+def test_clip_out(library):
+ """Test non-standard out= parameter for clip()
+
+ (see "Avoid Restricting Behavior that is Outside the Scope of the Standard"
+ in https://data-apis.org/array-api-compat/dev/special-considerations.html)
+ """
+ xp = import_(library, wrapper=True)
+ x = xp.asarray([10, 20, 30])
+ out = xp.zeros_like(x)
+ xp.clip(x, 15, 25, out=out)
+ expect = xp.asarray([15, 20, 25])
+ assert xp.all(out == expect)
diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py
new file mode 100644
index 00000000..ec8995f7
--- /dev/null
+++ b/tests/test_copies_or_views.py
@@ -0,0 +1,64 @@
+"""
+A collection of tests to make sure that wrapped namespaces agree with the bare ones
+on whether to return a view or a copy of inputs.
+"""
+import pytest
+from ._helpers import import_, wrapped_libraries
+
+
+FUNC_INPUTS = [
+ # func_name, arr_input, dtype, scalar_value
+ ('abs', [1, 2], 'int8', 3),
+ ('abs', [1, 2], 'float32', 3.),
+ ('ceil', [1, 2], 'int8', 3),
+ ('clip', [1, 2], 'int8', 3),
+ ('conj', [1, 2], 'int8', 3),
+ ('floor', [1, 2], 'int8', 3),
+ ('imag', [1j, 2j], 'complex64', 3),
+ ('positive', [1, 2], 'int8', 3),
+ ('real', [1., 2.], 'float32', 3.),
+ ('round', [1, 2], 'int8', 3),
+ ('sign', [0, 0], 'float32', 3),
+ ('trunc', [1, 2], 'int8', 3),
+ ('trunc', [1, 2], 'float32', 3),
+]
+
+
+def ensure_unary(func, arr):
+ """Make a trivial unary function from func."""
+ if func.__name__ == 'clip':
+ return lambda x: func(x, arr[0], arr[1])
+ return func
+
+
+def is_view(func, a, value):
+ """Apply `func`, mutate the output; does the input change?"""
+ b = func(a)
+ b[0] = value
+ return a[0] == value
+
+
+@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict'])
+@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS])
+def test_view_or_copy(inputs, xp_name):
+ bare_xp = import_(xp_name, wrapper=False)
+ wrapped_xp = import_(xp_name, wrapper=True)
+
+ func_name, arr_input, dtype_str, value = inputs
+ dtype = getattr(bare_xp, dtype_str)
+
+ bare_func = getattr(bare_xp, func_name)
+ bare_func = ensure_unary(bare_func, arr_input)
+
+ wrapped_func = getattr(wrapped_xp, func_name)
+ wrapped_func = ensure_unary(wrapped_func, arr_input)
+
+ # bare namespace: mutate the output, does the input change?
+ a = bare_xp.asarray(arr_input, dtype=dtype)
+ is_view_bare = is_view(bare_func, a, value)
+
+ # wrapped namespace: mutate the output, does the input change?
+ a1 = wrapped_xp.asarray(arr_input, dtype=dtype)
+ is_view_wrapped = is_view(wrapped_func, a1, value)
+
+ assert is_view_bare == is_view_wrapped
diff --git a/tests/test_cupy.py b/tests/test_cupy.py
new file mode 100644
index 00000000..4745b983
--- /dev/null
+++ b/tests/test_cupy.py
@@ -0,0 +1,45 @@
+import pytest
+from array_api_compat import device, to_device
+
+xp = pytest.importorskip("array_api_compat.cupy")
+from cupy.cuda import Stream
+
+
+@pytest.mark.parametrize(
+ "make_stream",
+ [
+ lambda: Stream(),
+ lambda: Stream(non_blocking=True),
+ lambda: Stream(null=True),
+ lambda: Stream(ptds=True),
+ ],
+)
+def test_to_device_with_stream(make_stream):
+ devices = xp.__array_namespace_info__().devices()
+
+ a = xp.asarray([1, 2, 3])
+ for dev in devices:
+ # Streams are device-specific and must be created within
+ # the context of the device...
+ with dev:
+ stream = make_stream()
+ # ... however, to_device() does not need to be inside the
+ # device context.
+ b = to_device(a, dev, stream=stream)
+ assert device(b) == dev
+
+
+def test_to_device_with_dlpack_stream():
+ devices = xp.__array_namespace_info__().devices()
+
+ a = xp.asarray([1, 2, 3])
+ for dev in devices:
+ # Streams are device-specific and must be created within
+ # the context of the device...
+ with dev:
+ s1 = Stream()
+
+ # ... however, to_device() does not need to be inside the
+ # device context.
+ b = to_device(a, dev, stream=s1.ptr)
+ assert device(b) == dev
diff --git a/tests/test_dask.py b/tests/test_dask.py
new file mode 100644
index 00000000..fb0a84d4
--- /dev/null
+++ b/tests/test_dask.py
@@ -0,0 +1,183 @@
+from contextlib import contextmanager
+
+import numpy as np
+import pytest
+
+try:
+ import dask
+ import dask.array as da
+except ImportError:
+ pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")
+
+from array_api_compat import array_namespace
+
+
+@pytest.fixture
+def xp():
+ """Fixture returning the wrapped dask namespace"""
+ return array_namespace(da.empty(0))
+
+
+@contextmanager
+def assert_no_compute():
+ """
+ Context manager that raises if at any point inside it anything calls compute()
+ or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
+ """
+
+ def get(dsk, *args, **kwargs):
+ raise AssertionError("Called compute() or persist()")
+
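+    # Installing get() as the scheduler makes any graph execution raise.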
+ with dask.config.set(scheduler=get):
+ yield
+
+
+def test_assert_no_compute():
+ """Test the assert_no_compute context manager"""
+ a = da.asarray(True)
+ with pytest.raises(AssertionError, match="Called compute"):
+ with assert_no_compute():
+ bool(a)
+
+ # Exiting the context manager restores the original scheduler
+ assert bool(a) is True
+
+
+# Test no_compute for functions that use generic _aliases with xp=np
+
+
+def test_unary_ops_no_compute(xp):
+ with assert_no_compute():
+ a = xp.asarray([1.5, -1.5])
+ xp.ceil(a)
+ xp.floor(a)
+ xp.trunc(a)
+ xp.sign(a)
+
+
+def test_matmul_tensordot_no_compute(xp):
+ A = da.ones((4, 4), chunks=2)
+ B = da.zeros((4, 4), chunks=2)
+ with assert_no_compute():
+ xp.matmul(A, B)
+ xp.tensordot(A, B)
+
+
+# Test no_compute for functions that are fully bespoke for dask
+
+
+def test_asarray_no_compute(xp):
+ with assert_no_compute():
+ a = xp.arange(10)
+ xp.asarray(a)
+ xp.asarray(a, dtype=np.int16)
+ xp.asarray(a, dtype=a.dtype)
+ xp.asarray(a, copy=True)
+ xp.asarray(a, copy=True, dtype=np.int16)
+ xp.asarray(a, copy=True, dtype=a.dtype)
+ xp.asarray(a, copy=False)
+ xp.asarray(a, copy=False, dtype=a.dtype)
+
+
+@pytest.mark.parametrize("copy", [True, False])
+def test_astype_no_compute(xp, copy):
+ with assert_no_compute():
+ a = xp.arange(10)
+ xp.astype(a, np.int16, copy=copy)
+ xp.astype(a, a.dtype, copy=copy)
+
+
+def test_clip_no_compute(xp):
+ with assert_no_compute():
+ a = xp.arange(10)
+ xp.clip(a)
+ xp.clip(a, 1)
+ xp.clip(a, 1, 8)
+
+
+@pytest.mark.parametrize("chunks", (5, 10))
+def test_sort_argsort_nocompute(xp, chunks):
+ with assert_no_compute():
+ a = xp.arange(10, chunks=chunks)
+ xp.sort(a)
+ xp.argsort(a)
+
+
+def test_generators_are_lazy(xp):
+ """
+ Test that generator functions are fully lazy, e.g. that
+ da.ones(n) is not implemented as da.asarray(np.ones(n))
+ """
+ size = 100_000_000_000 # 800 GB
+ chunks = size // 10 # 10x 80 GB chunks
+
+ with assert_no_compute():
+ xp.zeros(size, chunks=chunks)
+ xp.ones(size, chunks=chunks)
+ xp.empty(size, chunks=chunks)
+ xp.full(size, fill_value=123, chunks=chunks)
+ a = xp.arange(size, chunks=chunks)
+ xp.zeros_like(a)
+ xp.ones_like(a)
+ xp.empty_like(a)
+ xp.full_like(a, fill_value=123)
+
+
+@pytest.mark.parametrize("axis", [0, 1])
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_chunks(xp, func, axis):
+ """Test that sort and argsort are functionally correct when
+    the array is chunked along the sort axis, i.e. the sort is
+ not just local to each chunk.
+ """
+ a = da.random.random((10, 10), chunks=(5, 5))
+ actual = getattr(xp, func)(a, axis=axis)
+ expect = getattr(np, func)(a.compute(), axis=axis)
+ np.testing.assert_array_equal(actual, expect)
+
+
+@pytest.mark.parametrize(
+ "shape,chunks",
+ [
+ # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
+ # Sort chunks can be 128 MiB each; no need for final rechunk.
+ ((20_000, 20_000), "auto"),
+ # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
+ # Must sort on two 1.5 GiB chunks; benefits from final rechunk.
+ ((2, 2**30 * 3 // 16), "auto"),
+ # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
+ # Surely the user must know what they're doing, so don't
+ # perform the final rechunk.
+ ((2, 2**30 * 3 // 16), (1, -1)),
+ ],
+)
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_chunk_size(xp, func, shape, chunks):
+ """
+ Test that sort and argsort produce reasonably-sized chunks
+    in the output array, even if they had to go through a single
+    huge chunk to perform the operation.
+ """
+ a = da.random.random(shape, chunks=chunks)
+ b = getattr(xp, func)(a)
+ max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
+ assert (
+ max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
+ or b.chunks == a.chunks
+ )
+
+
+@pytest.mark.parametrize("func", ["sort", "argsort"])
+def test_sort_argsort_meta(xp, func):
+ """Test meta-namespace other than numpy"""
+ mxp = pytest.importorskip("array_api_strict")
+ typ = type(mxp.asarray(0))
+ a = da.random.random(10)
+ b = a.map_blocks(mxp.asarray)
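+    # Each block (and hence _meta) is now an array_api_strict.Array rather than numpy.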
+ assert isinstance(b._meta, typ)
+ c = getattr(xp, func)(b)
+ assert isinstance(c._meta, typ)
+ d = c.compute()
+ # Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
+ assert isinstance(d, typ)
+ np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))
diff --git a/tests/test_jax.py b/tests/test_jax.py
new file mode 100644
index 00000000..285958d4
--- /dev/null
+++ b/tests/test_jax.py
@@ -0,0 +1,38 @@
+from numpy.testing import assert_equal
+import pytest
+
+from array_api_compat import device, to_device
+
+try:
+ import jax
+ import jax.numpy as jnp
+except ImportError:
+ pytestmark = pytest.skip(allow_module_level=True, reason="jax not found")
+
+HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
+
+
+@pytest.mark.parametrize(
+ "func",
+ [
+ lambda x: jnp.zeros(1, device=device(x)),
+ lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
+ lambda x: jnp.zeros_like(jnp.empty(1, device=device(x))),
+ lambda x: jnp.full(1, fill_value=0, device=device(x)),
+ pytest.param(
+ lambda x: jnp.asarray([0], device=device(x)),
+ marks=pytest.mark.skipif(
+ not HAS_JAX_0_4_31, reason="asarray() has no device= parameter"
+ ),
+ ),
+ lambda x: to_device(jnp.zeros(1), device(x)),
+ ]
+)
+def test_device_jit(func):
+ # Test the workaround for https://github.com/jax-ml/jax/issues/26000,
+ # as well as for the to_device() method missing from JAX < 0.4.31,
+ # when inside jax.jit, even after importing jax.experimental.array_api
+
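+ # every func should produce a zero array on x's device, both eagerly and under jit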
+ x = jnp.ones(1)
+ assert_equal(func(x), jnp.asarray([0]))
+ assert_equal(jax.jit(func)(x), jnp.asarray([0]))
diff --git a/tests/test_torch.py b/tests/test_torch.py
new file mode 100644
index 00000000..7adb4ab3
--- /dev/null
+++ b/tests/test_torch.py
@@ -0,0 +1,119 @@
+"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
+"""
+import itertools
+
+import pytest
+
+try:
+ import torch
+except ImportError:
+ pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found")
+
+from array_api_compat import torch as xp
+
+
+class TestResultType:
+ def test_empty(self):
+ with pytest.raises(ValueError):
+ xp.result_type()
+
+ def test_one_arg(self):
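+ # a single python scalar or invalid object should raise, not be promoted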
+ for x in [1, 1.0, 1j, '...', None]:
+ with pytest.raises((ValueError, AttributeError)):
+ xp.result_type(x)
+
+ for x in [xp.float32, xp.int64, torch.complex64]:
+ assert xp.result_type(x) == x
+
+ for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
+ assert xp.result_type(x) == x.dtype
+
+ def test_two_args(self):
+ # Only test combinations whose behavior the spec leaves unspecified
+
+ # scalar & tensor, or tensor & tensor
+ for x, y in [
+ (1., 1j),
+ (1j, xp.arange(3)),
+ (True, xp.asarray(3.)),
+ (xp.ones(3) == 1, 1j*xp.ones(3)),
+ ]:
+ assert xp.result_type(x, y) == torch.result_type(x, y)
+
+ # dtype, scalar
+ for x, y in [
+ (1j, xp.int64),
+ (True, xp.float64),
+ ]:
+ assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
+
+ # dtype, dtype
+ for x, y in [
+ (xp.bool, xp.complex64)
+ ]:
+ xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
+ assert xp.result_type(x, y) == torch.result_type(xt, yt)
+
+ def test_multi_arg(self):
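+ # the expected dtypes below assume a float32 default dtype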
+ torch.set_default_dtype(torch.float32)
+
+ args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
+ assert xp.result_type(*args) == torch.float16
+
+ args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
+ assert xp.result_type(*args) == xp.complex64
+
+ args = [1, 2, 3j, xp.float64, 4, 5, 6]
+ assert xp.result_type(*args) == xp.complex128
+
+ args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
+ assert xp.result_type(*args) == xp.complex128
+
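+ # the promoted dtype must not depend on argument order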
+ i64 = xp.ones(1, dtype=xp.int64)
+ f16 = xp.ones(1, dtype=xp.float16)
+ for i in itertools.permutations([i64, f16, 1.0, 1.0]):
+ assert xp.result_type(*i) == xp.float16, f"{i}"
+
+ with pytest.raises(ValueError):
+ xp.result_type(1, 2, 3, 4)
+
+
+ @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
+ @pytest.mark.parametrize("dtype_a",
+ (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
+ )
+ @pytest.mark.parametrize("dtype_b",
+ (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
+ )
+ def test_gh_273(self, default_dt, dtype_a, dtype_b):
+ # Regression test for https://github.com/data-apis/array-api-compat/issues/273
+
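+ # result_type must be symmetric in its array arguments for every default dtype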
+ try:
+ prev_default = torch.get_default_dtype()
+ default_dtype = getattr(torch, default_dt)
+ torch.set_default_dtype(default_dtype)
+
+ a = xp.asarray([2, 1], dtype=dtype_a)
+ b = xp.asarray([1, -1], dtype=dtype_b)
+ dtype_1 = xp.result_type(a, b, 1.0)
+ dtype_2 = xp.result_type(b, a, 1.0)
+ assert dtype_1 == dtype_2
+ finally:
+ torch.set_default_dtype(prev_default)
+
+
+def test_meshgrid():
+ """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
+
+ x, y = xp.asarray([1, 2]), xp.asarray([4])
+
+ X, Y = xp.meshgrid(x, y)
+
+ # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
+ X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
+
+ assert X.shape == X_xy.shape
+ assert xp.all(X == X_xy)
+
+ assert Y.shape == Y_xy.shape
+ assert xp.all(Y == Y_xy)
diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py
index 3c9b5d92..1309d2db 100644
--- a/tests/test_vendoring.py
+++ b/tests/test_vendoring.py
@@ -16,12 +16,14 @@ def test_vendoring_cupy():
def test_vendoring_torch():
+ pytest.importorskip("torch")
from vendor_test import uses_torch
uses_torch._test_torch()
def test_vendoring_dask():
+ pytest.importorskip("dask")
from vendor_test import uses_dask
uses_dask._test_dask()
diff --git a/torch-skips.txt b/torch-skips.txt
index cbf7235b..e69de29b 100644
--- a/torch-skips.txt
+++ b/torch-skips.txt
@@ -1,11 +0,0 @@
-# These tests cause a core dump on CI, so we have to skip them entirely
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)]
-array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
diff --git a/torch-xfails.txt b/torch-xfails.txt
index c972659e..989df0c8 100644
--- a/torch-xfails.txt
+++ b/torch-xfails.txt
@@ -8,31 +8,15 @@ array_api_tests/test_array_object.py::test_getitem
array_api_tests/test_array_object.py::test_setitem
# Masking doesn't support 0 dimensions in the mask
array_api_tests/test_array_object.py::test_getitem_masking
-# torch doesn't have uint dtypes other than uint8
-array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)]
-array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)]
-array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)]
-array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)]
-array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)]
-array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)]
# Overflow error from large inputs
array_api_tests/test_creation_functions.py::test_arange
# pytorch linspace bug (should be fixed in torch 2.0)
-array_api_tests/test_creation_functions.py::test_linspace
-
-# torch doesn't have higher uint dtypes
-array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
-array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
-array_api_tests/test_data_type_functions.py::test_iinfo[uint64]
# We cannot wrap the tensor object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
-# tensordot doesn't allow integer dtypes in some corner cases
-array_api_tests/test_linalg.py::test_tensordot
-
# We cannot wrap the tensor object
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
@@ -45,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
@@ -61,12 +49,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_acos
array_api_tests/test_operators_and_elementwise_functions.py::test_atan
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
-# overflow near float max
-array_api_tests/test_operators_and_elementwise_functions.py::test_log1p
-
-# torch doesn't handle shifting by more than the bit size correctly
-# https://github.com/pytorch/pytorch/issues/70904
-array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
# Torch bug for remainder in some cases with large values
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
@@ -77,11 +59,6 @@ array_api_tests/test_set_functions.py::test_unique_all
# (https://github.com/pytorch/pytorch/issues/94106)
array_api_tests/test_set_functions.py::test_unique_inverse
-# The test suite incorrectly divides by 0 here
-# (https://github.com/data-apis/array-api-tests/issues/170)
-array_api_tests/test_signatures.py::test_func_signature[floor_divide]
-array_api_tests/test_signatures.py::test_func_signature[remainder]
-array_api_tests/test_signatures.py::test_array_method_signature[__mod__]
# We cannot add attributes to the tensor object
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
@@ -90,13 +67,6 @@ array_api_tests/test_signatures.py::test_array_method_signature[to_device]
# We do not attempt to work around special-case differences (most are on
# tensor methods which we couldn't fix anyway).
-array_api_tests/test_special_cases.py::test_binary[__add__((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i]
-array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i]
-array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is +infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is -infinity) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is +infinity and isfinite(x2_i)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -infinity and isfinite(x2_i)) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0]
@@ -121,41 +91,6 @@ array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i <
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is +infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is -infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is NaN and not x2_i == 0) -> NaN]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is +0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is +0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i < 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
-array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
-array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
@@ -164,7 +99,6 @@ array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinit
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0]
-array_api_tests/test_special_cases.py::test_iop[__iadd__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
@@ -176,31 +110,57 @@ array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0
# Float correction is not supported by pytorch
# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_special_cases.py::test_empty_arrays[std]
-array_api_tests/test_special_cases.py::test_empty_arrays[var]
-array_api_tests/test_special_cases.py::test_nan_propagation[std]
-array_api_tests/test_special_cases.py::test_nan_propagation[var]
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
-# The test suite is incorrectly checking sums that have loss of significance
-# (https://github.com/data-apis/array-api-tests/issues/168)
-array_api_tests/test_statistical_functions.py::test_sum
-array_api_tests/test_statistical_functions.py::test_prod
# These functions do not yet support complex numbers
-array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
array_api_tests/test_operators_and_elementwise_functions.py::test_round
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_values
+# finfo/iinfo.dtype is a string instead of a dtype
+array_api_tests/test_data_type_functions.py::test_finfo_dtype
+array_api_tests/test_data_type_functions.py::test_iinfo_dtype
+
# 2023.12 support
-array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
+# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers
array_api_tests/test_manipulation_functions.py::test_repeat
-array_api_tests/test_signatures.py::test_func_signature[repeat]
# Argument 'device' missing from signature
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
# Argument 'max_version' missing from signature
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-# Argument 'device' missing from signature
-array_api_tests/test_signatures.py::test_func_signature[astype]
+
+# 2024.12 support
+array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
+array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_signatures.py::test_array_method_signature[__and__]
+array_api_tests/test_signatures.py::test_array_method_signature[__lshift__]
+array_api_tests/test_signatures.py::test_array_method_signature[__or__]
+array_api_tests/test_signatures.py::test_array_method_signature[__rshift__]
+array_api_tests/test_signatures.py::test_array_method_signature[__xor__]
+
+# 2024.12 support: binary functions reject python scalar arguments
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
+
+# https://github.com/pytorch/pytorch/issues/149815
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor]
diff --git a/vendor_test/uses_torch.py b/vendor_test/uses_torch.py
index 5804aaff..747ecd51 100644
--- a/vendor_test/uses_torch.py
+++ b/vendor_test/uses_torch.py
@@ -23,7 +23,7 @@ def _test_torch():
assert isinstance(b, torch.Tensor)
assert isinstance(res, torch.Tensor)
- torch.testing.assert_allclose(res, [[1., 2., 3.]])
+ torch.testing.assert_close(res, torch.as_tensor([[1., 2., 3.]]))
assert is_torch_array(res)
assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)