Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
one_hot,
pad,
partition,
setdiff1d,
sinc,
)
from ._lib._at import at
Expand All @@ -21,7 +22,6 @@
default_dtype,
kron,
nunique,
setdiff1d,
)
from ._lib._lazy import lazy_apply

Expand Down
152 changes: 103 additions & 49 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._lib._utils._typing import Array, DType

__all__ = [
"atleast_nd",
"cov",
"expand_dims",
"isclose",
Expand All @@ -29,6 +30,55 @@
]


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Parameters
----------
x : array
Input array.
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)

>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True
"""
if xp is None:
xp = array_namespace(x)

if 1 <= ndim <= 3 and (
is_numpy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
or is_cupy_namespace(xp)
or is_torch_namespace(xp)
):
return getattr(xp, f"atleast_{ndim}d")(x)

return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix.
Expand Down Expand Up @@ -197,55 +247,6 @@ def expand_dims(
return _funcs.expand_dims(a, axis=axis, xp=xp)


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Parameters
----------
x : array
Input array.
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)

>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True
"""
if xp is None:
xp = array_namespace(x)

if 1 <= ndim <= 3 and (
is_numpy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
or is_cupy_namespace(xp)
or is_torch_namespace(xp)
):
return getattr(xp, f"atleast_{ndim}d")(x)

return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def isclose(
a: Array | complex,
b: Array | complex,
Expand Down Expand Up @@ -553,6 +554,59 @@ def pad(
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def setdiff1d(
x1: Array | complex,
x2: Array | complex,
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""

if xp is None:
xp = array_namespace(x1, x2)

if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
x1, x2 = asarrays(x1, x2, xp=xp)
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)

return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
r"""
Return the normalized sinc function.
Expand Down
40 changes: 3 additions & 37 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,44 +715,10 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.

Return the unique values in `x1` that are not in `x2`.

Parameters
----------
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""

>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(x1, x2)
# https://github.com/microsoft/pyright/issues/10103
x1_, x2_ = asarrays(x1, x2, xp=xp)

Expand Down
19 changes: 14 additions & 5 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
sinc,
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import (
device as get_device,
)
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._compat import is_jax_namespace
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -1264,25 +1263,35 @@ def test_assume_unique(self, xp: ModuleType):
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
def test_shapes(
self,
request: pytest.FixtureRequest,
assume_unique: bool,
shape1: tuple[int, ...],
shape2: tuple[int, ...],
xp: ModuleType,
):
x1 = xp.zeros(shape1)
x2 = xp.zeros(shape2)

if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

@assume_unique
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
def test_python_scalar(
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
):
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
x2 = 3
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))

if is_jax_namespace(xp) and assume_unique:
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x2, x1, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

Expand Down