3434)
3535from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
3636from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
37- from array_api_extra ._lib ._utils ._compat import (
38- device as get_device ,
39- )
37+ from array_api_extra ._lib ._utils ._compat import device as get_device , is_jax_namespace
4038from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
4139from array_api_extra ._lib ._utils ._typing import Array , Device
4240from array_api_extra .testing import lazy_xp_function
@@ -1271,6 +1269,10 @@ def test_shapes(
12711269 ):
12721270 x1 = xp .zeros (shape1 )
12731271 x2 = xp .zeros (shape2 )
1272+
1273+ if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
1274+ pytest .xfail (reason = "jax#32335 fixed with jax>=0.8.0" )
1275+
12741276 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
12751277 xp_assert_equal (actual , xp .empty ((0 ,)))
12761278
@@ -1283,6 +1285,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
12831285 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
12841286 xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
12851287
1288+ if is_jax_namespace (xp ) and assume_unique :
1289+ pytest .xfail (reason = "jax#32335 fixed with jax>=0.8.0" )
1290+
12861291 actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
12871292 xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
12881293
0 commit comments