Skip to content

Commit 2f1eadb

Browse files
committed
ENH: setdiff1d delegate function
1 parent ebe9a5b commit 2f1eadb

File tree

4 files changed

+67
-41
lines changed

4 files changed

+67
-41
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
one_hot,
1212
pad,
1313
partition,
14+
setdiff1d,
1415
sinc,
1516
)
1617
from ._lib._at import at
@@ -21,7 +22,6 @@
2122
default_dtype,
2223
kron,
2324
nunique,
24-
setdiff1d,
2525
)
2626
from ._lib._lazy import lazy_apply
2727

src/array_api_extra/_delegation.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,60 @@ def pad(
553553
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
554554

555555

556+
def setdiff1d(
557+
x1: Array | complex,
558+
x2: Array | complex,
559+
/,
560+
*,
561+
assume_unique: bool = False,
562+
xp: ModuleType | None = None,
563+
) -> Array:
564+
"""
565+
Find the set difference of two arrays.
566+
567+
Return the unique values in `x1` that are not in `x2`.
568+
569+
Parameters
570+
----------
571+
x1 : array | int | float | complex | bool
572+
Input array.
573+
x2 : array
574+
Input comparison array.
575+
assume_unique : bool
576+
If ``True``, the input arrays are both assumed to be unique, which
577+
can speed up the calculation. Default is ``False``.
578+
xp : array_namespace, optional
579+
The standard-compatible namespace for `x1` and `x2`. Default: infer.
580+
581+
Returns
582+
-------
583+
array
584+
1D array of values in `x1` that are not in `x2`. The result
585+
is sorted when `assume_unique` is ``False``, but otherwise only sorted
586+
if the input is sorted.
587+
588+
Examples
589+
--------
590+
>>> import array_api_strict as xp
591+
>>> import array_api_extra as xpx
592+
593+
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
594+
>>> x2 = xp.asarray([3, 4, 5, 6])
595+
>>> xpx.setdiff1d(x1, x2, xp=xp)
596+
Array([1, 2], dtype=array_api_strict.int64)
597+
"""
598+
599+
if xp is None:
600+
xp = array_namespace(x1, x2)
601+
602+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
603+
# https://github.com/microsoft/pyright/issues/10103
604+
x1_, x2_ = asarrays(x1, x2, xp=xp)
605+
return xp.setdiff1d(x1_, x2_, assume_unique=assume_unique)
606+
607+
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
608+
609+
556610
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
557611
r"""
558612
Return the normalized sinc function.

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -715,44 +715,10 @@ def setdiff1d(
715715
/,
716716
*,
717717
assume_unique: bool = False,
718-
xp: ModuleType | None = None,
719-
) -> Array:
720-
"""
721-
Find the set difference of two arrays.
722-
723-
Return the unique values in `x1` that are not in `x2`.
724-
725-
Parameters
726-
----------
727-
x1 : array | int | float | complex | bool
728-
Input array.
729-
x2 : array
730-
Input comparison array.
731-
assume_unique : bool
732-
If ``True``, the input arrays are both assumed to be unique, which
733-
can speed up the calculation. Default is ``False``.
734-
xp : array_namespace, optional
735-
The standard-compatible namespace for `x1` and `x2`. Default: infer.
736-
737-
Returns
738-
-------
739-
array
740-
1D array of values in `x1` that are not in `x2`. The result
741-
is sorted when `assume_unique` is ``False``, but otherwise only sorted
742-
if the input is sorted.
743-
744-
Examples
745-
--------
746-
>>> import array_api_strict as xp
747-
>>> import array_api_extra as xpx
718+
xp: ModuleType,
719+
) -> Array: # numpydoc ignore=PR01,RT01
720+
"""See docstring in `array_api_extra._delegation.py`."""
748721

749-
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
750-
>>> x2 = xp.asarray([3, 4, 5, 6])
751-
>>> xpx.setdiff1d(x1, x2, xp=xp)
752-
Array([1, 2], dtype=array_api_strict.int64)
753-
"""
754-
if xp is None:
755-
xp = array_namespace(x1, x2)
756722
# https://github.com/microsoft/pyright/issues/10103
757723
x1_, x2_ = asarrays(x1, x2, xp=xp)
758724

tests/test_funcs.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@
3434
)
3535
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3636
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
37-
from array_api_extra._lib._utils._compat import (
38-
device as get_device,
39-
)
37+
from array_api_extra._lib._utils._compat import device as get_device
38+
from array_api_extra._lib._utils._compat import is_jax_namespace
4039
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
4140
from array_api_extra._lib._utils._typing import Array, Device
4241
from array_api_extra.testing import lazy_xp_function
@@ -1271,6 +1270,10 @@ def test_shapes(
12711270
):
12721271
x1 = xp.zeros(shape1)
12731272
x2 = xp.zeros(shape2)
1273+
1274+
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
1275+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1276+
12741277
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12751278
xp_assert_equal(actual, xp.empty((0,)))
12761279

@@ -1283,6 +1286,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
12831286
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12841287
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
12851288

1289+
if is_jax_namespace(xp) and assume_unique:
1290+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1291+
12861292
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
12871293
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
12881294

0 commit comments

Comments
 (0)