Skip to content

Commit 6887ffb

Browse files
committed
add xfail according to jax issue
1 parent 6efe3b9 commit 6887ffb

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,10 @@ def setdiff1d(
462462
if xp is None:
463463
xp = array_namespace(x1, x2)
464464

465-
if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
466-
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)
465+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
466+
# https://github.com/microsoft/pyright/issues/10103
467+
x1_, x2_ = asarrays(x1, x2, xp=xp)
468+
return xp.setdiff1d(x1_, x2_, assume_unique=assume_unique)
467469

468470
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
469471

tests/test_funcs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3232
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
33-
from array_api_extra._lib._utils._compat import device as get_device
33+
from array_api_extra._lib._utils._compat import device as get_device, is_jax_namespace
3434
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
3535
from array_api_extra._lib._utils._typing import Array, Device
3636
from array_api_extra.testing import lazy_xp_function
@@ -1228,6 +1228,10 @@ def test_shapes(
12281228
):
12291229
x1 = xp.zeros(shape1)
12301230
x2 = xp.zeros(shape2)
1231+
1232+
if is_jax_namespace(xp) and assume_unique and shape1!=(1,):
1233+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1234+
12311235
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12321236
xp_assert_equal(actual, xp.empty((0,)))
12331237

@@ -1240,6 +1244,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
12401244
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12411245
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
12421246

1247+
if is_jax_namespace(xp) and assume_unique:
1248+
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1249+
12431250
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
12441251
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
12451252

0 commit comments

Comments
 (0)