File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 3434)
3535from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
3636from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
37- from array_api_extra ._lib ._utils ._compat import device as get_device , is_jax_namespace
37+ from array_api_extra ._lib ._utils ._compat import device as get_device
38+ from array_api_extra ._lib ._utils ._compat import is_jax_namespace
3839from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
3940from array_api_extra ._lib ._utils ._typing import Array , Device
4041from array_api_extra .testing import lazy_xp_function
@@ -1270,7 +1271,7 @@ def test_shapes(
12701271 x1 = xp .zeros (shape1 )
12711272 x2 = xp .zeros (shape2 )
12721273
1273- if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
1274+ if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
12741275 pytest .xfail (reason = "jax#32335 fixed with jax>=0.8.0" )
12751276
12761277 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
You can’t perform that action at this time.
0 commit comments