3030)
3131from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
3232from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
33- from array_api_extra ._lib ._utils ._compat import device as get_device
33+ from array_api_extra ._lib ._utils ._compat import device as get_device , is_jax_namespace
3434from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
3535from array_api_extra ._lib ._utils ._typing import Array , Device
3636from array_api_extra .testing import lazy_xp_function
@@ -1228,6 +1228,10 @@ def test_shapes(
12281228 ):
12291229 x1 = xp .zeros (shape1 )
12301230 x2 = xp .zeros (shape2 )
1231+
1232+ if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
1233+ pytest .xfail (reason = "jax#32335 fixed with jax>=0.8.0" )
1234+
12311235 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
12321236 xp_assert_equal (actual , xp .empty ((0 ,)))
12331237
@@ -1240,6 +1244,9 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
12401244 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
12411245 xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
12421246
1247+ if is_jax_namespace (xp ) and assume_unique :
1248+ pytest .xfail (reason = "jax#32335 fixed with jax>=0.8.0" )
1249+
12431250 actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
12441251 xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
12451252
0 commit comments