Skip to content

Commit 62f3d6c

Browse files
committed
format test_func.py
1 parent e64813e commit 62f3d6c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/test_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
)
3535
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3636
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
37-
from array_api_extra._lib._utils._compat import device as get_device, is_jax_namespace
37+
from array_api_extra._lib._utils._compat import device as get_device
38+
from array_api_extra._lib._utils._compat import is_jax_namespace
3839
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
3940
from array_api_extra._lib._utils._typing import Array, Device
4041
from array_api_extra.testing import lazy_xp_function
@@ -1270,7 +1271,7 @@ def test_shapes(
12701271
x1 = xp.zeros(shape1)
12711272
x2 = xp.zeros(shape2)
12721273

1273-
if is_jax_namespace(xp) and assume_unique and shape1!=(1,):
1274+
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
12741275
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
12751276

12761277
actual = setdiff1d(x1, x2, assume_unique=assume_unique)

0 commit comments

Comments
 (0)