diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index e509a74e..4541bd36 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,3 +1,9 @@ # Add "repro_snippets" to test_manipulation_functions.py (gh-384) da420a4a369ee9a587f91a61300f4eb4a2f5b8d8 - +# Add "repro_snippets" to more test modules (gh-392) +7203669453655693e4edb68b5c4de56a199edd83 +8096c9368288c0200ea226424bbead2cf6a5a51f +e807ffe526c7330691e8f39d31347dc2b3106de3 +bd42e84d2e5aae26ade8d70384e74effd1de89cb +f7e822883b7e24b5aa540e2413759a85128b42ef +a37f348ba27b6818e92fda8aee2406c653c671ea diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1f144c72..d239194e 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -147,51 +147,57 @@ def test_arange(dtype, data): kvds.insert(0, hh.KVD("step", step, 1)) kwargs = data.draw(hh.specified_kwargs(*kvds), label="kwargs") - out = xp.arange(*args, **kwargs) + repro_snippet = ph.format_snippet(f"xp.arange(*args, **kwargs) with {args = } and {kwargs = }") - if dtype is None: - if all_int: - ph.assert_default_int("arange", out.dtype) + try: + out = xp.arange(*args, **kwargs) + + if dtype is None: + if all_int: + ph.assert_default_int("arange", out.dtype) + else: + ph.assert_default_float("arange", out.dtype) else: - ph.assert_default_float("arange", out.dtype) - else: - ph.assert_kw_dtype("arange", kw_dtype=dtype, out_dtype=out.dtype) - f_sig = ", ".join(str(n) for n in args) - if len(kwargs) > 0: - f_sig += f", {ph.fmt_kw(kwargs)}" - f_func = f"[arange({f_sig})]" - assert out.ndim == 1, f"{out.ndim=}, but should be 1 [{f_func}]" - # We check size is roughly as expected to avoid edge cases e.g. - # - # >>> xp.arange(2, step=0.333333333333333) - # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0] - # >>> xp.arange(2, step=0.3333333333333333) - # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66] - # - # >>> start, stop, step = 0, 108086391056891901, 1080863910568919 - # >>> x = xp.arange(start, stop, step, dtype=xp.uint64) - # >>> x.size - # 100 - # >>> r = range(start, stop, step) - # >>> len(r) - # 101 - # - min_size = math.floor(size * 0.9) - max_size = max(math.ceil(size * 1.1), 1) - out_size = math.prod(out.shape) - assert ( - min_size <= out_size <= max_size - ), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}" - if dh.is_int_dtype(_dtype): - elements = list(r) - assume(out_size == len(elements)) - ph.assert_array_elements("arange", out=out, expected=xp.asarray(elements, dtype=_dtype)) - else: - assume(out_size == size) - if out_size > 0: - assert xp.equal( - out[0], xp.asarray(_start, dtype=out.dtype) - ), f"out[0]={out[0]}, but should be {_start} {f_func}" + ph.assert_kw_dtype("arange", kw_dtype=dtype, out_dtype=out.dtype) + f_sig = ", ".join(str(n) for n in args) + if len(kwargs) > 0: + f_sig += f", {ph.fmt_kw(kwargs)}" + f_func = f"[arange({f_sig})]" + assert out.ndim == 1, f"{out.ndim=}, but should be 1 [{f_func}]" + # We check size is roughly as expected to avoid edge cases e.g. + # + # >>> xp.arange(2, step=0.333333333333333) + # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0] + # >>> xp.arange(2, step=0.3333333333333333) + # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66] + # + # >>> start, stop, step = 0, 108086391056891901, 1080863910568919 + # >>> x = xp.arange(start, stop, step, dtype=xp.uint64) + # >>> x.size + # 100 + # >>> r = range(start, stop, step) + # >>> len(r) + # 101 + # + min_size = math.floor(size * 0.9) + max_size = max(math.ceil(size * 1.1), 1) + out_size = math.prod(out.shape) + assert ( + min_size <= out_size <= max_size + ), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}" + if dh.is_int_dtype(_dtype): + elements = list(r) + assume(out_size == len(elements)) + ph.assert_array_elements("arange", out=out, expected=xp.asarray(elements, dtype=_dtype)) + else: + assume(out_size == size) + if out_size > 0: + assert xp.equal( + out[0], xp.asarray(_start, dtype=out.dtype) + ), f"out[0]={out[0]}, but should be {_start} {f_func}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(shape=hh.shapes(min_side=1), data=st.data()) @@ -229,26 +235,30 @@ def test_asarray_scalars(shape, data): obj = sh.reshape(_obj, shape) note(f"{obj=}") - out = xp.asarray(obj, **kw) - - if dtype is None: - msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be " - if dtype_family == (xp.float32, xp.float64): - msg += "default floating-point dtype (float32 or float64)" - elif dtype_family == (xp.int32, xp.int64): - msg += "default integer dtype (int32 or int64)" + repro_snippet = ph.format_snippet(f"xp.asarray({obj!r}, **kw) with {kw = }") + try: + out = xp.asarray(obj, **kw) + + if dtype is None: + msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be " + if dtype_family == (xp.float32, xp.float64): + msg += "default floating-point dtype (float32 or float64)" + elif dtype_family == (xp.int32, xp.int64): + msg += "default integer dtype (int32 or int64)" + else: + msg += "boolean dtype" + msg += " [asarray()]" + assert out.dtype in dtype_family, msg else: - msg += "boolean dtype" - msg += " [asarray()]" - assert out.dtype in dtype_family, msg - else: - assert kw["dtype"] == _dtype # sanity check - ph.assert_kw_dtype("asarray", kw_dtype=_dtype, out_dtype=out.dtype) - ph.assert_shape("asarray", out_shape=out.shape, expected=shape) - for idx, v_expect in zip(sh.ndindex(out.shape), _obj): - v = scalar_type(out[idx]) - ph.assert_scalar_equals("asarray", type_=scalar_type, idx=idx, out=v, expected=v_expect, kw=kw) - + assert kw["dtype"] == _dtype # sanity check + ph.assert_kw_dtype("asarray", kw_dtype=_dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=shape) + for idx, v_expect in zip(sh.ndindex(out.shape), _obj): + v = scalar_type(out[idx]) + ph.assert_scalar_equals("asarray", type_=scalar_type, idx=idx, out=v, expected=v_expect, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise def scalar_eq(s1: Scalar, s2: Scalar) -> bool: if cmath.isnan(s1): @@ -273,53 +283,64 @@ def test_asarray_arrays(shape, dtypes, data): label="kw", ) - out = xp.asarray(x, **kw) + repro_snippet = ph.format_snippet(f"xp.asarray({x!r}, **kw) with {kw = }") + try: + out = xp.asarray(x, **kw) - dtype = kw.get("dtype", None) - if dtype is None: - ph.assert_dtype("asarray", in_dtype=x.dtype, out_dtype=out.dtype) - else: - ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype) - ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape) - ph.assert_array_elements("asarray", out=out, expected=x, kw=kw) - copy = kw.get("copy", None) - if copy is not None: - stype = dh.get_scalar_type(x.dtype) - idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") - old_value = stype(x[idx]) - scalar_strat = hh.from_dtype(dtypes.input_dtype).filter( - lambda n: not scalar_eq(n, old_value) - ) - value = data.draw( - scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), - label="mutating value", - ) - x[idx] = value - note(f"mutated {x=}") - # sanity check - ph.assert_scalar_equals( - "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" - ) - new_out_value = stype(out[idx]) - f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" - if copy: - assert scalar_eq( - new_out_value, old_value - ), f"{f_out}, but should be {old_value} even after x was mutated" + dtype = kw.get("dtype", None) + if dtype is None: + ph.assert_dtype("asarray", in_dtype=x.dtype, out_dtype=out.dtype) else: - assert scalar_eq( - new_out_value, value - ), f"{f_out}, but should be {value} after x was mutated" + ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape) + ph.assert_array_elements("asarray", out=out, expected=x, kw=kw) + copy = kw.get("copy", None) + if copy is not None: + stype = dh.get_scalar_type(x.dtype) + idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") + old_value = stype(x[idx]) + scalar_strat = hh.from_dtype(dtypes.input_dtype).filter( + lambda n: not scalar_eq(n, old_value) + ) + value = data.draw( + scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), + label="mutating value", + ) + x[idx] = value + note(f"mutated {x=}") + # sanity check + ph.assert_scalar_equals( + "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" + ) + new_out_value = stype(out[idx]) + f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" + if copy: + assert scalar_eq( + new_out_value, old_value + ), f"{f_out}, but should be {old_value} even after x was mutated" + else: + assert scalar_eq( + new_out_value, value + ), f"{f_out}, but should be {value} after x was mutated" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_empty(shape, kw): - out = xp.empty(shape, **kw) - if kw.get("dtype", None) is None: - ph.assert_default_float("empty", out.dtype) - else: - ph.assert_kw_dtype("empty", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("empty", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) + + repro_snippet = ph.format_snippet(f"xp.empty({shape!r}, **kw) with {kw = }") + try: + out = xp.empty(shape, **kw) + if kw.get("dtype", None) is None: + ph.assert_default_float("empty", out.dtype) + else: + ph.assert_kw_dtype("empty", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -327,13 +348,17 @@ def test_empty(shape, kw): kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), ) def test_empty_like(x, kw): - out = xp.empty_like(x, **kw) - if kw.get("dtype", None) is None: - ph.assert_dtype("empty_like", in_dtype=x.dtype, out_dtype=out.dtype) - else: - ph.assert_kw_dtype("empty_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("empty_like", out_shape=out.shape, expected=x.shape) - + repro_snippet = ph.format_snippet(f"xp.empty_like({x!r}, **kw) with {kw = }") + try: + out = xp.empty_like(x, **kw) + if kw.get("dtype", None) is None: + ph.assert_dtype("empty_like", in_dtype=x.dtype, out_dtype=out.dtype) + else: + ph.assert_kw_dtype("empty_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty_like", out_shape=out.shape, expected=x.shape) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( n_rows=hh.sqrt_sizes, @@ -344,21 +369,26 @@ def test_empty_like(x, kw): ), ) def test_eye(n_rows, n_cols, kw): - out = xp.eye(n_rows, n_cols, **kw) - if kw.get("dtype", None) is None: - ph.assert_default_float("eye", out.dtype) - else: - ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) - _n_cols = n_rows if n_cols is None else n_cols - ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) - k = kw.get("k", 0) - expected = xp.asarray( - [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], - dtype=out.dtype # Note: dtype already checked above. - ) - if 0 in expected.shape: - expected = xp.reshape(expected, (n_rows, _n_cols)) - ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) + repro_snippet = ph.format_snippet(f"xp.eye({n_rows!r}, {n_cols!r}, **kw) with {kw = }") + try: + out = xp.eye(n_rows, n_cols, **kw) + if kw.get("dtype", None) is None: + ph.assert_default_float("eye", out.dtype) + else: + ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) + _n_cols = n_rows if n_cols is None else n_cols + ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) + k = kw.get("k", 0) + expected = xp.asarray( + [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], + dtype=out.dtype # Note: dtype already checked above. + ) + if 0 in expected.shape: + expected = xp.reshape(expected, (n_rows, _n_cols)) + ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise default_unsafe_dtypes = [xp.uint64] @@ -388,40 +418,45 @@ def full_fill_values(draw) -> Union[bool, int, float, complex]: kw=st.shared(hh.kwargs(dtype=st.none() | hh.all_dtypes), key="full_kw"), ) def test_full(shape, fill_value, kw): - with hh.reject_overflow(): - out = xp.full(shape, fill_value, **kw) - if kw.get("dtype", None): - dtype = kw["dtype"] - elif isinstance(fill_value, bool): - dtype = xp.bool - elif isinstance(fill_value, int): - dtype = dh.default_int - elif isinstance(fill_value, float): - dtype = dh.default_float - else: - assert isinstance(fill_value, complex) # sanity check - dtype = dh.default_complex - # Ignore large components so we don't fail like - # - # >>> torch.fill(complex(0.0, 3.402823466385289e+38)) - # RuntimeError: value cannot be converted to complex without overflow - # - M = dh.dtype_ranges[dh.dtype_components[dtype]].max - assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag])) - if kw.get("dtype", None) is None: - if isinstance(fill_value, bool): - assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]" + repro_snippet = ph.format_snippet(f"xp.full({shape!r}, {fill_value!r}, **kw) with {kw = }") + try: + with hh.reject_overflow(): + out = xp.full(shape, fill_value, **kw) + if kw.get("dtype", None): + dtype = kw["dtype"] + elif isinstance(fill_value, bool): + dtype = xp.bool elif isinstance(fill_value, int): - ph.assert_default_int("full", out.dtype) + dtype = dh.default_int elif isinstance(fill_value, float): - ph.assert_default_float("full", out.dtype) + dtype = dh.default_float else: assert isinstance(fill_value, complex) # sanity check - ph.assert_default_complex("full", out.dtype) - else: - ph.assert_kw_dtype("full", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("full", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) - ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) + dtype = dh.default_complex + # Ignore large components so we don't fail like + # + # >>> torch.fill(complex(0.0, 3.402823466385289e+38)) + # RuntimeError: value cannot be converted to complex without overflow + # + M = dh.dtype_ranges[dh.dtype_components[dtype]].max + assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag])) + if kw.get("dtype", None) is None: + if isinstance(fill_value, bool): + assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]" + elif isinstance(fill_value, int): + ph.assert_default_int("full", out.dtype) + elif isinstance(fill_value, float): + ph.assert_default_float("full", out.dtype) + else: + assert isinstance(fill_value, complex) # sanity check + ph.assert_default_complex("full", out.dtype) + else: + ph.assert_kw_dtype("full", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) + ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data()) @@ -429,15 +464,20 @@ def test_full_like(kw, data): dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype") x = data.draw(hh.arrays(dtype=dtype, shape=hh.shapes()), label="x") fill_value = data.draw(hh.from_dtype(dtype), label="fill_value") - out = xp.full_like(x, fill_value, **kw) - dtype = kw.get("dtype", None) or x.dtype - if kw.get("dtype", None) is None: - ph.assert_dtype("full_like", in_dtype=x.dtype, out_dtype=out.dtype) - else: - ph.assert_kw_dtype("full_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("full_like", out_shape=out.shape, expected=x.shape) - ph.assert_fill("full_like", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) + repro_snippet = ph.format_snippet(f"xp.full_like({x!r}, {fill_value!r}, **kw) with {kw = }") + try: + out = xp.full_like(x, fill_value, **kw) + dtype = kw.get("dtype", None) or x.dtype + if kw.get("dtype", None) is None: + ph.assert_dtype("full_like", in_dtype=x.dtype, out_dtype=out.dtype) + else: + ph.assert_kw_dtype("full_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full_like", out_shape=out.shape, expected=x.shape) + ph.assert_fill("full_like", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) + except Exception as exc: + exc.add_note(repro_snippet) + raise finite_kw = {"allow_nan": False, "allow_infinity": False} @@ -467,29 +507,35 @@ def test_linspace(num, dtype, endpoint, data): ), label="kw", ) - out = xp.linspace(start, stop, num, **kw) - if dtype is None: - ph.assert_default_float("linspace", out.dtype) - else: - ph.assert_kw_dtype("linspace", kw_dtype=dtype, out_dtype=out.dtype) - ph.assert_shape("linspace", out_shape=out.shape, expected=num, kw=dict(start=start, stop=stop, num=num)) - f_func = f"[linspace({start}, {stop}, {num})]" - if num > 0: - assert xp.equal( - out[0], xp.asarray(start, dtype=out.dtype) - ), f"out[0]={out[0]}, but should be {start} {f_func}" - if endpoint: - if num > 1: + repro_snippet = ph.format_snippet(f"xp.linspace({start!r}, {stop!r}, {num!r}, **kw) with {kw = }") + try: + out = xp.linspace(start, stop, num, **kw) + + if dtype is None: + ph.assert_default_float("linspace", out.dtype) + else: + ph.assert_kw_dtype("linspace", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("linspace", out_shape=out.shape, expected=num, kw=dict(start=start, stop=stop, num=num)) + f_func = f"[linspace({start}, {stop}, {num})]" + if num > 0: assert xp.equal( - out[-1], xp.asarray(stop, dtype=out.dtype) - ), f"out[-1]={out[-1]}, but should be {stop} {f_func}" - else: - # linspace(..., num, endpoint=True) should return an array equivalent to - # the first num elements when endpoint=False - expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) - expected = expected[:-1] - ph.assert_array_elements("linspace", out=out, expected=expected) + out[0], xp.asarray(start, dtype=out.dtype) + ), f"out[0]={out[0]}, but should be {start} {f_func}" + if endpoint: + if num > 1: + assert xp.equal( + out[-1], xp.asarray(stop, dtype=out.dtype) + ), f"out[-1]={out[-1]}, but should be {stop} {f_func}" + else: + # linspace(..., num, endpoint=True) should return an array equivalent to + # the first num elements when endpoint=False + expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) + expected = expected[:-1] + ph.assert_array_elements("linspace", out=out, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(dtype=hh.numeric_dtypes, data=st.data()) @@ -510,9 +556,15 @@ def test_meshgrid(dtype, data): arrays.append(x) # sanity check # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE - out = xp.meshgrid(*arrays) - for i, x in enumerate(out): - ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") + + repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays) with {arrays = }") + try: + out = xp.meshgrid(*arrays) + for i, x in enumerate(out): + ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") + except Exception as exc: + exc.add_note(repro_snippet) + raise def make_one(dtype: DataType) -> Scalar: @@ -526,15 +578,20 @@ def make_one(dtype: DataType) -> Scalar: @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_ones(shape, kw): - out = xp.ones(shape, **kw) - if kw.get("dtype", None) is None: - ph.assert_default_float("ones", out.dtype) - else: - ph.assert_kw_dtype("ones", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("ones", out_shape=out.shape, expected=shape, - kw={'shape': shape, **kw}) - dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("ones", fill_value=make_one(dtype), dtype=dtype, out=out, kw=kw) + repro_snippet = ph.format_snippet(f"xp.ones({shape!r}, **kw) with {kw = }") + try: + out = xp.ones(shape, **kw) + if kw.get("dtype", None) is None: + ph.assert_default_float("ones", out.dtype) + else: + ph.assert_kw_dtype("ones", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones", out_shape=out.shape, expected=shape, + kw={'shape': shape, **kw}) + dtype = kw.get("dtype", None) or dh.default_float + ph.assert_fill("ones", fill_value=make_one(dtype), dtype=dtype, out=out, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -542,15 +599,20 @@ def test_ones(shape, kw): kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), ) def test_ones_like(x, kw): - out = xp.ones_like(x, **kw) - if kw.get("dtype", None) is None: - ph.assert_dtype("ones_like", in_dtype=x.dtype, out_dtype=out.dtype) - else: - ph.assert_kw_dtype("ones_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("ones_like", out_shape=out.shape, expected=x.shape, kw=kw) - dtype = kw.get("dtype", None) or x.dtype - ph.assert_fill("ones_like", fill_value=make_one(dtype), dtype=dtype, - out=out, kw=kw) + repro_snippet = ph.format_snippet(f"xp.ones_like({x!r}, **kw) with {kw = }") + try: + out = xp.ones_like(x, **kw) + if kw.get("dtype", None) is None: + ph.assert_dtype("ones_like", in_dtype=x.dtype, out_dtype=out.dtype) + else: + ph.assert_kw_dtype("ones_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones_like", out_shape=out.shape, expected=x.shape, kw=kw) + dtype = kw.get("dtype", None) or x.dtype + ph.assert_fill("ones_like", fill_value=make_one(dtype), dtype=dtype, + out=out, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise def make_zero(dtype: DataType) -> Scalar: @@ -564,15 +626,20 @@ def make_zero(dtype: DataType) -> Scalar: @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_zeros(shape, kw): - out = xp.zeros(shape, **kw) - if kw.get("dtype", None) is None: - ph.assert_default_float("zeros", out_dtype=out.dtype) - else: - ph.assert_kw_dtype("zeros", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("zeros", out_shape=out.shape, expected=shape, kw={'shape': shape, **kw}) - dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("zeros", fill_value=make_zero(dtype), dtype=dtype, out=out, - kw=kw) + repro_snippet = ph.format_snippet(f"xp.zeros({shape!r}, **kw) with {kw = }") + try: + out = xp.zeros(shape, **kw) + if kw.get("dtype", None) is None: + ph.assert_default_float("zeros", out_dtype=out.dtype) + else: + ph.assert_kw_dtype("zeros", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("zeros", out_shape=out.shape, expected=shape, kw={'shape': shape, **kw}) + dtype = kw.get("dtype", None) or dh.default_float + ph.assert_fill("zeros", fill_value=make_zero(dtype), dtype=dtype, out=out, + kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -580,13 +647,18 @@ def test_zeros(shape, kw): kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), ) def test_zeros_like(x, kw): - out = xp.zeros_like(x, **kw) - if kw.get("dtype", None) is None: - ph.assert_dtype("zeros_like", in_dtype=x.dtype, out_dtype=out.dtype) - else: - ph.assert_kw_dtype("zeros_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) - ph.assert_shape("zeros_like", out_shape=out.shape, expected=x.shape, - kw=kw) - dtype = kw.get("dtype", None) or x.dtype - ph.assert_fill("zeros_like", fill_value=make_zero(dtype), dtype=dtype, - out=out, kw=kw) + repro_snippet = ph.format_snippet(f"xp.zeros_like({x!r}, **kw) with {kw = }") + try: + out = xp.zeros_like(x, **kw) + if kw.get("dtype", None) is None: + ph.assert_dtype("zeros_like", in_dtype=x.dtype, out_dtype=out.dtype) + else: + ph.assert_kw_dtype("zeros_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("zeros_like", out_shape=out.shape, expected=x.shape, + kw=kw) + dtype = kw.get("dtype", None) or x.dtype + ph.assert_fill("zeros_like", fill_value=make_zero(dtype), dtype=dtype, + out=out, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index e50b621e..ad1b551d 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -81,12 +81,17 @@ def test_astype(x_dtype, dtype, kw, data): # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes))) - out = xp.astype(x, dtype, **kw) + repro_snippet = ph.format_snippet(f"xp.astype({x!r}, {dtype!r}, **kw) with {kw = }") + try: + out = xp.astype(x, dtype, **kw) - ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) - ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) - # TODO: test values - # TODO: test copy + ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) + # TODO: test values + # TODO: test copy + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -98,24 +103,29 @@ def test_broadcast_arrays(shapes, data): x = data.draw(hh.arrays(dtype=hh.all_dtypes, shape=shape), label=f"x{c}") arrays.append(x) - out = xp.broadcast_arrays(*arrays) - - expected_shape = sh.broadcast_shapes(*shapes) - for i, x in enumerate(arrays): - ph.assert_dtype( - "broadcast_arrays", - in_dtype=x.dtype, - out_dtype=out[i].dtype, - repr_name=f"out[{i}].dtype" - ) - ph.assert_result_shape( - "broadcast_arrays", - in_shapes=shapes, - out_shape=out[i].shape, - expected=expected_shape, - repr_name=f"out[{i}].shape", - ) - # TODO: test values + repro_snippet = ph.format_snippet(f"xp.broadcast_arrays(*arrays) with {arrays = }") + try: + out = xp.broadcast_arrays(*arrays) + + expected_shape = sh.broadcast_shapes(*shapes) + for i, x in enumerate(arrays): + ph.assert_dtype( + "broadcast_arrays", + in_dtype=x.dtype, + out_dtype=out[i].dtype, + repr_name=f"out[{i}].dtype" + ) + ph.assert_result_shape( + "broadcast_arrays", + in_shapes=shapes, + out_shape=out[i].shape, + expected=expected_shape, + repr_name=f"out[{i}].shape", + ) + # TODO: test values + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) @@ -127,29 +137,39 @@ def test_broadcast_to(x, data): label="shape", ) - out = xp.broadcast_to(x, shape) + repro_snippet = ph.format_snippet(f"xp.broadcast_to({x!r}, {shape!r})") + try: + out = xp.broadcast_to(x, shape) - ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape) - # TODO: test values + ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape) + # TODO: test values + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(_from=hh.all_dtypes, to=hh.all_dtypes) def test_can_cast(_from, to): - out = xp.can_cast(_from, to) + repro_snippet = ph.format_snippet(f"xp.can_cast({_from!r}, {to!r})") + try: + out = xp.can_cast(_from, to) - expected = False - for other in dh.all_dtypes: - if dh.promotion_table.get((_from, other)) == to: - expected = True - break + expected = False + for other in dh.all_dtypes: + if dh.promotion_table.get((_from, other)) == to: + expected = True + break - f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" - if expected: - # cross-kind casting is not explicitly disallowed. We can only test - # the cases where it should return True. TODO: if expected=False, - # check that the array library actually allows such casts. - assert out == expected, f"{out=}, but should be {expected} {f_func}" + f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" + if expected: + # cross-kind casting is not explicitly disallowed. We can only test + # the cases where it should return True. TODO: if expected=False, + # check that the array library actually allows such casts. + assert out == expected, f"{out=}, but should be {expected} {f_func}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes) @@ -160,7 +180,7 @@ def test_finfo(dtype): # np.float64 and np.asarray(1, dtype=np.float64).dtype are different xp.asarray(1, dtype=dtype).dtype, ): - repro_snippet = ph.format_snippet(f"xp.finfo({arg})") + repro_snippet = ph.format_snippet(f"xp.finfo({arg!r})") try: out = xp.finfo(arg) assert isinstance(out.bits, int) @@ -175,19 +195,24 @@ def test_finfo(dtype): @pytest.mark.min_version("2022.12") @pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes) def test_finfo_dtype(dtype): - out = xp.finfo(dtype) - - if dtype == xp.complex64: - assert out.dtype == xp.float32 - elif dtype == xp.complex128: - assert out.dtype == xp.float64 - else: - assert out.dtype == dtype + repro_snippet = ph.format_snippet(f"xp.finfo({dtype!r})") + try: + out = xp.finfo(dtype) + + if dtype == xp.complex64: + assert out.dtype == xp.float32 + elif dtype == xp.complex128: + assert out.dtype == xp.float64 + else: + assert out.dtype == dtype - # Guard vs. numpy.dtype.__eq__ lax comparison - assert not isinstance(out.dtype, str) - assert out.dtype is not float - assert out.dtype is not complex + # Guard vs. numpy.dtype.__eq__ lax comparison + assert not isinstance(out.dtype, str) + assert out.dtype is not float + assert out.dtype is not complex + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes) @@ -198,20 +223,30 @@ def test_iinfo(dtype): # np.int64 and np.asarray(1, dtype=np.int64).dtype are different xp.asarray(1, dtype=dtype).dtype, ): - out = xp.iinfo(arg) - assert isinstance(out.bits, int) - assert isinstance(out.max, int) - assert isinstance(out.min, int) + repro_snippet = ph.format_snippet(f"xp.iinfo({arg!r})") + try: + out = xp.iinfo(arg) + assert isinstance(out.bits, int) + assert isinstance(out.max, int) + assert isinstance(out.min, int) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes) def test_iinfo_dtype(dtype): - out = xp.iinfo(dtype) - assert out.dtype == dtype - # Guard vs. numpy.dtype.__eq__ lax comparison - assert not isinstance(out.dtype, str) - assert out.dtype is not int + repro_snippet = ph.format_snippet(f"xp.iinfo({dtype!r})") + try: + out = xp.iinfo(dtype) + assert out.dtype == dtype + # Guard vs. numpy.dtype.__eq__ lax comparison + assert not isinstance(out.dtype, str) + assert out.dtype is not int + except Exception as exc: + exc.add_note(repro_snippet) + raise def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]: @@ -224,29 +259,39 @@ def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]: kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple), ) def test_isdtype(dtype, kind): - out = xp.isdtype(dtype, kind) - - assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]" - _kinds = kind if isinstance(kind, tuple) else (kind,) - expected = False - for _kind in _kinds: - if isinstance(_kind, str): - if dtype in dh.kind_to_dtypes[_kind]: - expected = True - break - else: - if dtype == _kind: - expected = True - break - assert out == expected, f"{out=}, but should be {expected} [isdtype()]" + repro_snippet = ph.format_snippet(f"xp.isdtype({dtype!r}, {kind!r})") + try: + out = xp.isdtype(dtype, kind) + + assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]" + _kinds = kind if isinstance(kind, tuple) else (kind,) + expected = False + for _kind in _kinds: + if isinstance(_kind, str): + if dtype in dh.kind_to_dtypes[_kind]: + expected = True + break + else: + if dtype == _kind: + expected = True + break + assert out == expected, f"{out=}, but should be {expected} [isdtype()]" + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2024.12") class TestResultType: @given(dtypes=hh.mutually_promotable_dtypes(None)) def test_result_type(self, dtypes): - out = xp.result_type(*dtypes) - ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + repro_snippet = ph.format_snippet(f"xp.result_type(*dtypes) with {dtypes = }") + try: + out = xp.result_type(*dtypes) + ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(pair=hh.pair_of_mutually_promotable_dtypes(None)) def test_shuffled(self, pair): @@ -261,33 +306,55 @@ def test_arrays_and_dtypes(self, pair, data): s1, s2 = pair a2 = tuple(xp.empty(1, dtype=dt) for dt in s2) a_and_dt = data.draw(st.permutations(s1 + a2)) - out = xp.result_type(*a_and_dt) - ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + + repro_snippet = ph.format_snippet(f"xp.result_type(*a_and_dt) with {a_and_dt = }") + try: + out = xp.result_type(*a_and_dt) + ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data()) def test_with_scalars(self, dtypes, data): - out = xp.result_type(*dtypes) - - if out == xp.bool: - scalars = [True] - elif out in dh.all_int_dtypes: - scalars = [1] - elif out in dh.real_dtypes: - scalars = [1, 1.0] - elif out in dh.numeric_dtypes: - scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types - else: - raise ValueError(f"unknown dtype {out = }.") + + repro_snippet = ph.format_snippet(f"xp.result_type(*dtypes) with {dtypes = }") + try: + out = xp.result_type(*dtypes) + + if out == xp.bool: + scalars = [True] + elif out in dh.all_int_dtypes: + scalars = [1] + elif out in dh.real_dtypes: + scalars = [1, 1.0] + elif out in dh.numeric_dtypes: + scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types + else: + raise ValueError(f"unknown dtype {out = }.") + except Exception as exc: + exc.add_note(repro_snippet) + raise scalar = data.draw(st.sampled_from(scalars)) inputs = data.draw(st.permutations(dtypes + (scalar,))) - out_scalar = xp.result_type(*inputs) - assert out_scalar == out + repro_snippet = ph.format_snippet(f"xp.result_type(*inputs) with {inputs = }") + try: + out_scalar = xp.result_type(*inputs) + assert out_scalar == out + except Exception as exc: + exc.add_note(repro_snippet) + raise # retry with arrays arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes) inputs = data.draw(st.permutations(arrays + (scalar,))) - out_scalar = xp.result_type(*inputs) - assert out_scalar == out + repro_snippet = ph.format_snippet(f"xp.result_type(*inputs) with {inputs = }") + try: + out_scalar = xp.result_type(*inputs) + assert out_scalar == out + except Exception as exc: + exc.add_note(repro_snippet) + raise diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index b72c8030..8df475d8 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -33,23 +33,28 @@ def test_argmax(x, data): ) keepdims = kw.get("keepdims", False) - out = xp.argmax(x, **kw) + repro_snippet = ph.format_snippet(f"xp.argmax({x!r}, **kw) with {kw = }") + try: + out = xp.argmax(x, **kw) - ph.assert_default_index("argmax", out.dtype) - axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): - max_i = int(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = max(range(len(elements)), key=elements.__getitem__) - ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i, - expected=expected, kw=kw) + ph.assert_default_index("argmax", out.dtype) + axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + max_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(range(len(elements)), key=elements.__getitem__) + ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i, + expected=expected, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -70,22 +75,27 @@ def test_argmin(x, data): ) keepdims = kw.get("keepdims", False) - out = xp.argmin(x, **kw) + repro_snippet = ph.format_snippet(f"xp.argmin({x!r}, **kw) with {kw = }") + try: + out = xp.argmin(x, **kw) - ph.assert_default_index("argmin", out.dtype) - axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): - min_i = int(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = min(range(len(elements)), key=elements.__getitem__) - ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) + ph.assert_default_index("argmin", out.dtype) + axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + min_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(range(len(elements)), key=elements.__getitem__) + ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise # XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on @@ -115,23 +125,28 @@ def test_count_nonzero(x, data): assume(kw.get("axis", None) != ()) # TODO clarify in the spec - out = xp.count_nonzero(x, **kw) + repro_snippet = ph.format_snippet(f"xp.count_nonzero({x!r}, **kw) with {kw = }") + try: + out = xp.count_nonzero(x, **kw) - ph.assert_default_index("count_nonzero", out.dtype) - axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(x.dtype) + ph.assert_default_index("count_nonzero", out.dtype) + axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): - count = int(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = sum(el != 0 for el in elements) - ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + count = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(el != 0 for el in elements) + ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(hh.arrays(dtype=hh.all_dtypes, shape=())) @@ -143,39 +158,44 @@ def test_nonzero_zerodim_error(x): @pytest.mark.data_dependent_shapes @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1))) def test_nonzero(x): - out = xp.nonzero(x) - assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" - out_size = math.prod(out[0].shape) - for i in range(len(out)): - assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" - size_at = math.prod(out[i].shape) - assert size_at == out_size, ( - f"prod(out[{i}].shape)={size_at}, " - f"but should be prod(out[0].shape)={out_size}" - ) - ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") - indices = [] - if x.dtype == xp.bool: - for idx in sh.ndindex(x.shape): - if x[idx]: - indices.append(idx) - else: - for idx in sh.ndindex(x.shape): - if x[idx] != 0: - indices.append(idx) - if x.ndim == 0: - assert out_size == len( - indices - ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" - else: - for i in range(out_size): - idx = tuple(int(x[i]) for x in out) - f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" - f_element = f"x[{idx}]={x[idx]}" - assert idx in indices, f"{f_idx} results in {f_element}, a zero element" - assert ( - idx == indices[i] - ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" + repro_snippet = ph.format_snippet(f"xp.nonzero({x!r})") + try: + out = xp.nonzero(x) + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" + out_size = math.prod(out[0].shape) + for i in range(len(out)): + assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" + size_at = math.prod(out[i].shape) + assert size_at == out_size, ( + f"prod(out[{i}].shape)={size_at}, " + f"but should be prod(out[0].shape)={out_size}" + ) + ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") + indices = [] + if x.dtype == xp.bool: + for idx in sh.ndindex(x.shape): + if x[idx]: + indices.append(idx) + else: + for idx in sh.ndindex(x.shape): + if x[idx] != 0: + indices.append(idx) + if x.ndim == 0: + assert out_size == len( + indices + ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" + else: + for i in range(out_size): + idx = tuple(int(x[i]) for x in out) + f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" + f_element = f"x[{idx}]={x[idx]}" + assert idx in indices, f"{f_idx} results in {f_element}, a zero element" + assert ( + idx == indices[i] + ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -188,31 +208,36 @@ def test_where(shapes, dtypes, data): x1 = data.draw(hh.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") x2 = data.draw(hh.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") - out = xp.where(cond, x1, x2) - - shape = sh.broadcast_shapes(*shapes) - ph.assert_shape("where", out_shape=out.shape, expected=shape) - # TODO: generate indices without broadcasting arrays - _cond = xp.broadcast_to(cond, shape) - _x1 = xp.broadcast_to(x1, shape) - _x2 = xp.broadcast_to(x2, shape) - for idx in sh.ndindex(shape): - if _cond[idx]: - ph.assert_0d_equals( - "where", - x_repr=f"_x1[{idx}]", - x_val=_x1[idx], - out_repr=f"out[{idx}]", - out_val=out[idx] - ) - else: - ph.assert_0d_equals( - "where", - x_repr=f"_x2[{idx}]", - x_val=_x2[idx], - out_repr=f"out[{idx}]", - out_val=out[idx] - ) + repro_snippet = ph.format_snippet(f"xp.where({cond!r}, {x1!r}, {x2!r})") + try: + out = xp.where(cond, x1, x2) + + shape = sh.broadcast_shapes(*shapes) + ph.assert_shape("where", out_shape=out.shape, expected=shape) + # TODO: generate indices without broadcasting arrays + _cond = xp.broadcast_to(cond, shape) + _x1 = xp.broadcast_to(x1, shape) + _x2 = xp.broadcast_to(x2, shape) + for idx in sh.ndindex(shape): + if _cond[idx]: + ph.assert_0d_equals( + "where", + x_repr=f"_x1[{idx}]", + x_val=_x1[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] + ) + else: + ph.assert_0d_equals( + "where", + x_repr=f"_x2[{idx}]", + x_val=_x2[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2023.12") @@ -238,12 +263,17 @@ def test_searchsorted(data): label="x2", ) - out = xp.searchsorted(x1, x2, sorter=sorter) + repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})") + try: + out = xp.searchsorted(x1, x2, sorter=sorter) - ph.assert_dtype( - "searchsorted", - in_dtype=[x1.dtype, x2.dtype], - out_dtype=out.dtype, - expected=xp.__array_namespace_info__().default_dtypes()["indexing"], - ) - # TODO: shapes and values testing + ph.assert_dtype( + "searchsorted", + in_dtype=[x1.dtype, x2.dtype], + out_dtype=out.dtype, + expected=xp.__array_namespace_info__().default_dtypes()["indexing"], + ) + # TODO: shapes and values testing + except Exception as exc: + exc.add_note(repro_snippet) + raise diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index c9abaad1..4ca046cf 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -17,223 +17,242 @@ @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_all(x): - out = xp.unique_all(x) - - assert hasattr(out, "values") - assert hasattr(out, "indices") - assert hasattr(out, "inverse_indices") - assert hasattr(out, "counts") - - ph.assert_dtype( - "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" - ) - ph.assert_default_index( - "unique_all", out.indices.dtype, repr_name="out.indices.dtype" - ) - ph.assert_default_index( - "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" - ) - ph.assert_default_index( - "unique_all", out.counts.dtype, repr_name="out.counts.dtype" - ) - - assert ( - out.indices.shape == out.values.shape - ), f"{out.indices.shape=}, but should be {out.values.shape=}" - ph.assert_shape( - "unique_all", - out_shape=out.inverse_indices.shape, - expected=x.shape, - repr_name="out.inverse_indices.shape", - ) - assert ( - out.counts.shape == out.values.shape - ), f"{out.counts.shape=}, but should be {out.values.shape=}" - - scalar_type = dh.get_scalar_type(out.values.dtype) - counts = defaultdict(int) - firsts = {} - for i, idx in enumerate(sh.ndindex(x.shape)): - val = scalar_type(x[idx]) - if counts[val] == 0: - firsts[val] = i - counts[val] += 1 - - for idx in sh.ndindex(out.indices.shape): - val = scalar_type(out.values[idx]) - if cmath.isnan(val): - break - i = int(out.indices[idx]) - expected = firsts[val] - assert i == expected, ( - f"out.values[{idx}]={val} and out.indices[{idx}]={i}, " - f"but first occurence of {val} is at {expected}" + repro_snippet = ph.format_snippet(f"xp.unique_all({x!r}") + try: + out = xp.unique_all(x) + + assert hasattr(out, "values") + assert hasattr(out, "indices") + assert hasattr(out, "inverse_indices") + assert hasattr(out, "counts") + + ph.assert_dtype( + "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_all", out.indices.dtype, repr_name="out.indices.dtype" + ) + ph.assert_default_index( + "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" + ) + ph.assert_default_index( + "unique_all", out.counts.dtype, repr_name="out.counts.dtype" ) - for idx in sh.ndindex(out.inverse_indices.shape): - ridx = int(out.inverse_indices[idx]) - val = out.values[ridx] - expected = x[idx] - msg = ( - f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " - f"but should result in x[{idx}]={expected}" + assert ( + out.indices.shape == out.values.shape + ), f"{out.indices.shape=}, but should be {out.values.shape=}" + ph.assert_shape( + "unique_all", + out_shape=out.inverse_indices.shape, + expected=x.shape, + repr_name="out.inverse_indices.shape", ) - if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): - assert xp.isnan(val), msg - else: - assert val == expected, msg - - vals_idx = {} - nans = 0 - for idx in sh.ndindex(out.values.shape): - val = scalar_type(out.values[idx]) - count = int(out.counts[idx]) - if cmath.isnan(val): - nans += 1 - assert count == 1, ( - f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " - "but count should be 1 as NaNs are distinct" + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = defaultdict(int) + firsts = {} + for i, idx in enumerate(sh.ndindex(x.shape)): + val = scalar_type(x[idx]) + if counts[val] == 0: + firsts[val] = i + counts[val] += 1 + + for idx in sh.ndindex(out.indices.shape): + val = scalar_type(out.values[idx]) + if cmath.isnan(val): + break + i = int(out.indices[idx]) + expected = firsts[val] + assert i == expected, ( + f"out.values[{idx}]={val} and out.indices[{idx}]={i}, " + f"but first occurence of {val} is at {expected}" ) - else: - expected = counts[val] - assert ( - expected > 0 - ), f"out.values[{idx}]={val}, but {val} not in input array" - count = int(out.counts[idx]) - assert count == expected, ( - f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " - f"but should be {expected}" + + for idx in sh.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" ) - assert ( - val not in vals_idx.keys() - ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" - vals_idx[val] = idx + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg - if dh.is_float_dtype(out.values.dtype): - assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if cmath.isnan(k)) - assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + count = int(out.counts[idx]) + if cmath.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + + if dh.is_float_dtype(out.values.dtype): + assume(math.prod(x.shape) <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_counts(x): - out = xp.unique_counts(x) - assert hasattr(out, "values") - assert hasattr(out, "counts") - ph.assert_dtype( - "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" - ) - ph.assert_default_index( - "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" - ) - assert ( - out.counts.shape == out.values.shape - ), f"{out.counts.shape=}, but should be {out.values.shape=}" - scalar_type = dh.get_scalar_type(out.values.dtype) - counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) - vals_idx = {} - nans = 0 - for idx in sh.ndindex(out.values.shape): - val = scalar_type(out.values[idx]) - count = int(out.counts[idx]) - if cmath.isnan(val): - nans += 1 - assert count == 1, ( - f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " - "but count should be 1 as NaNs are distinct" - ) - else: - expected = counts[val] - assert ( - expected > 0 - ), f"out.values[{idx}]={val}, but {val} not in input array" + repro_snippet = ph.format_snippet(f"xp.unique_counts({x!r}") + try: + out = xp.unique_counts(x) + assert hasattr(out, "values") + assert hasattr(out, "counts") + ph.assert_dtype( + "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" + ) + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - assert count == expected, ( - f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " - f"but should be {expected}" - ) - assert ( - val not in vals_idx.keys() - ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" - vals_idx[val] = idx - if dh.is_float_dtype(out.values.dtype): - assume(math.prod(x.shape) <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if cmath.isnan(k)) - assert nans == expected, f"{nans} NaNs in out, but should be {expected}" - + if cmath.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.values.dtype): + assume(math.prod(x.shape) <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_inverse(x): - out = xp.unique_inverse(x) - assert hasattr(out, "values") - assert hasattr(out, "inverse_indices") - ph.assert_dtype( - "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" - ) - ph.assert_default_index( - "unique_inverse", - out.inverse_indices.dtype, - repr_name="out.inverse_indices.dtype", - ) - ph.assert_shape( - "unique_inverse", - out_shape=out.inverse_indices.shape, - expected=x.shape, - repr_name="out.inverse_indices.shape", - ) - scalar_type = dh.get_scalar_type(out.values.dtype) - distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) - vals_idx = {} - nans = 0 - for idx in sh.ndindex(out.values.shape): - val = scalar_type(out.values[idx]) - if cmath.isnan(val): - nans += 1 - else: - assert ( - val in distinct - ), f"out.values[{idx}]={val}, but {val} not in input array" - assert ( - val not in vals_idx.keys() - ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" - vals_idx[val] = idx - for idx in sh.ndindex(out.inverse_indices.shape): - ridx = int(out.inverse_indices[idx]) - val = out.values[ridx] - expected = x[idx] - msg = ( - f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " - f"but should result in x[{idx}]={expected}" + repro_snippet = ph.format_snippet(f"xp.unique_inverse({x!r}") + try: + out = xp.unique_inverse(x) + assert hasattr(out, "values") + assert hasattr(out, "inverse_indices") + ph.assert_dtype( + "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_inverse", + out.inverse_indices.dtype, + repr_name="out.inverse_indices.dtype", ) - if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): - assert xp.isnan(val), msg - else: - assert val == expected, msg - if dh.is_float_dtype(out.values.dtype): - assume(math.prod(x.shape) <= 128) # may not be representable - expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) - assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" + ph.assert_shape( + "unique_inverse", + out_shape=out.inverse_indices.shape, + expected=x.shape, + repr_name="out.inverse_indices.shape", + ) + scalar_type = dh.get_scalar_type(out.values.dtype) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + if cmath.isnan(val): + nans += 1 + else: + assert ( + val in distinct + ), f"out.values[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + for idx in sh.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" + ) + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg + if dh.is_float_dtype(out.values.dtype): + assume(math.prod(x.shape) <= 128) # may not be representable + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" + except Exception as exc: + exc.add_note(repro_snippet) + raise @given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_values(x): - out = xp.unique_values(x) - ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype) - scalar_type = dh.get_scalar_type(x.dtype) - distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) - vals_idx = {} - nans = 0 - for idx in sh.ndindex(out.shape): - val = scalar_type(out[idx]) - if cmath.isnan(val): - nans += 1 - else: - assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" - assert ( - val not in vals_idx.keys() - ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" - vals_idx[val] = idx - if dh.is_float_dtype(out.dtype): - assume(math.prod(x.shape) <= 128) # may not be representable - expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) - assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + repro_snippet = ph.format_snippet(f"xp.unique_values({x!r}") + try: + out = xp.unique_values(x) + ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype) + scalar_type = dh.get_scalar_type(x.dtype) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.shape): + val = scalar_type(out[idx]) + if cmath.isnan(val): + nans += 1 + else: + assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.dtype): + assume(math.prod(x.shape) <= 128) # may not be representable + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + except Exception as exc: + exc.add_note(repro_snippet) + raise diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index 3d25798c..3b4baa23 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -51,43 +51,47 @@ def test_argsort(x, data): label="kw", ) - out = xp.argsort(x, **kw) - - ph.assert_default_index("argsort", out.dtype) - ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw) - axis = kw.get("axis", -1) - axes = sh.normalize_axis(axis, x.ndim) - scalar_type = dh.get_scalar_type(x.dtype) - for indices in sh.axes_ndindex(x.shape, axes): - elements = [scalar_type(x[idx]) for idx in indices] - orders = list(range(len(elements))) - sorders = sorted( - orders, key=elements.__getitem__, reverse=kw.get("descending", False) - ) - if kw.get("stable", True): - for idx, o in zip(indices, sorders): - ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw) - else: - idx_elements = dict(zip(indices, elements)) - idx_orders = dict(zip(indices, orders)) - element_orders = {} - for e in set(elements): - element_orders[e] = [ - idx_orders[idx] for idx in indices if idx_elements[idx] == e - ] - selements = [elements[o] for o in sorders] - for idx, e in zip(indices, selements): - expected_orders = element_orders[e] - out_o = int(out[idx]) - if len(expected_orders) == 1: - ph.assert_scalar_equals( - "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw - ) - else: - assert_scalar_in_set( - "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw - ) + repro_snippet = ph.format_snippet(f"xp.argsort({x!r}, **kw) with {kw = }") + try: + out = xp.argsort(x, **kw) + ph.assert_default_index("argsort", out.dtype) + ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw) + axis = kw.get("axis", -1) + axes = sh.normalize_axis(axis, x.ndim) + scalar_type = dh.get_scalar_type(x.dtype) + for indices in sh.axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] + orders = list(range(len(elements))) + sorders = sorted( + orders, key=elements.__getitem__, reverse=kw.get("descending", False) + ) + if kw.get("stable", True): + for idx, o in zip(indices, sorders): + ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw) + else: + idx_elements = dict(zip(indices, elements)) + idx_orders = dict(zip(indices, orders)) + element_orders = {} + for e in set(elements): + element_orders[e] = [ + idx_orders[idx] for idx in indices if idx_elements[idx] == e + ] + selements = [elements[o] for o in sorders] + for idx, e in zip(indices, selements): + expected_orders = element_orders[e] + out_o = int(out[idx]) + if len(expected_orders) == 1: + ph.assert_scalar_equals( + "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw + ) + else: + assert_scalar_in_set( + "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -112,27 +116,32 @@ def test_sort(x, data): label="kw", ) - out = xp.sort(x, **kw) + repro_snippet = ph.format_snippet(f"xp.sort({x!r}, **kw) with {kw = }") + try: + out = xp.sort(x, **kw) - ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype) - ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw) - axis = kw.get("axis", -1) - axes = sh.normalize_axis(axis, x.ndim) - scalar_type = dh.get_scalar_type(x.dtype) - for indices in sh.axes_ndindex(x.shape, axes): - elements = [scalar_type(x[idx]) for idx in indices] - size = len(elements) - orders = sorted( - range(size), key=elements.__getitem__, reverse=kw.get("descending", False) - ) - for out_idx, o in zip(indices, orders): - x_idx = indices[o] - # TODO: error message when unstable should not imply just one idx - ph.assert_0d_equals( - "sort", - x_repr=f"x[{x_idx}]", - x_val=x[x_idx], - out_repr=f"out[{out_idx}]", - out_val=out[out_idx], - kw=kw, + ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype) + ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw) + axis = kw.get("axis", -1) + axes = sh.normalize_axis(axis, x.ndim) + scalar_type = dh.get_scalar_type(x.dtype) + for indices in sh.axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] + size = len(elements) + orders = sorted( + range(size), key=elements.__getitem__, reverse=kw.get("descending", False) ) + for out_idx, o in zip(indices, orders): + x_idx = indices[o] + # TODO: error message when unstable should not imply just one idx + ph.assert_0d_equals( + "sort", + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 0e3aa9d4..58204e78 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -42,55 +42,59 @@ def test_cumulative_sum(x, data): label="kw", ) - out = xp.cumulative_sum(x, **kw) - - expected_shape = list(x.shape) - if include_initial: - expected_shape[_axis] += 1 - expected_shape = tuple(expected_shape) - ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape) - - expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) - if expected_dtype is None: - # If a default uint cannot exist (i.e. in PyTorch which doesn't support - # uint32 or uint64), we skip testing the output dtype. - # See https://github.com/data-apis/array-api-tests/issues/106 - if x.dtype in dh.uint_dtypes: - assert dh.is_int_dtype(out.dtype) # sanity check - else: - ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) - - scalar_type = dh.get_scalar_type(out.dtype) - - for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): - x_arr = x[x_idx.raw] - out_arr = out[out_idx.raw] + repro_snippet = ph.format_snippet(f"xp.cumulative_sum({x!r}, **kw) with {kw = }") + try: + out = xp.cumulative_sum(x, **kw) + expected_shape = list(x.shape) if include_initial: - ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0) - - for n in range(x.shape[_axis]): - start = 1 if include_initial else 0 - out_val = out_arr[n + start] - assume(cmath.isfinite(out_val)) - elements = [] - for idx in range(n + 1): - s = scalar_type(x_arr[idx]) - elements.append(s) - expected = sum(elements) - if dh.is_int_dtype(out.dtype): - m, M = dh.dtype_ranges[out.dtype] - assume(m <= expected <= M) - ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, - idx=out_idx.raw, out=out_val, - expected=expected) - else: - condition_number = _sum_condition_number(elements) - assume(condition_number < 1e6) - ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type, - idx=out_idx.raw, out=out_val, - expected=expected) - + expected_shape[_axis] += 1 + expected_shape = tuple(expected_shape) + ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape) + + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + + scalar_type = dh.get_scalar_type(out.dtype) + + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): + x_arr = x[x_idx.raw] + out_arr = out[out_idx.raw] + + if include_initial: + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0) + + for n in range(x.shape[_axis]): + start = 1 if include_initial else 0 + out_val = out_arr[n + start] + assume(cmath.isfinite(out_val)) + elements = [] + for idx in range(n + 1): + s = scalar_type(x_arr[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) + else: + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2024.12") @@ -119,34 +123,39 @@ def test_cumulative_prod(x, data): label="kw", ) - out = xp.cumulative_prod(x, **kw) + repro_snippet = ph.format_snippet(f"xp.cumulative_prod({x!r}, **kw) with {kw = }") + try: + out = xp.cumulative_prod(x, **kw) - expected_shape = list(x.shape) - if include_initial: - expected_shape[_axis] += 1 - expected_shape = tuple(expected_shape) - ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape) - - expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) - if expected_dtype is None: - # If a default uint cannot exist (i.e. in PyTorch which doesn't support - # uint32 or uint64), we skip testing the output dtype. - # See https://github.com/data-apis/array-api-tests/issues/106 - if x.dtype in dh.uint_dtypes: - assert dh.is_int_dtype(out.dtype) # sanity check - else: - ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + expected_shape = list(x.shape) + if include_initial: + expected_shape[_axis] += 1 + expected_shape = tuple(expected_shape) + ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape) + + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) - scalar_type = dh.get_scalar_type(out.dtype) + scalar_type = dh.get_scalar_type(out.dtype) - for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): - #x_arr = x[x_idx.raw] - out_arr = out[out_idx.raw] + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): + #x_arr = x[x_idx.raw] + out_arr = out[out_idx.raw] - if include_initial: - ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1) + if include_initial: + ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1) - #TODO: add value testing of cumulative_prod + #TODO: add value testing of cumulative_prod + except Exception as exc: + exc.add_note(repro_snippet) + raise def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: @@ -169,22 +178,27 @@ def test_max(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") keepdims = kw.get("keepdims", False) - out = xp.max(x, **kw) - - ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - max_ = scalar_type(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = max(elements) - ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected) + repro_snippet = ph.format_snippet(f"xp.max({x!r}, **kw) with {kw = }") + try: + out = xp.max(x, **kw) + + ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + max_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(elements) + ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @given( @@ -199,14 +213,19 @@ def test_mean(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") keepdims = kw.get("keepdims", False) - out = xp.mean(x, **kw) + repro_snippet = ph.format_snippet(f"xp.mean({x!r}, **kw) with {kw = }") + try: + out = xp.mean(x, **kw) - ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - # Values testing mean is too finicky + ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + # Values testing mean is too finicky + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -222,22 +241,27 @@ def test_min(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") keepdims = kw.get("keepdims", False) - out = xp.min(x, **kw) - - ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - min_ = scalar_type(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = min(elements) - ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) + repro_snippet = ph.format_snippet(f"xp.min({x!r}, **kw) with {kw = }") + try: + out = xp.min(x, **kw) + + ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + min_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(elements) + ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise def _prod_condition_number(elements): @@ -250,6 +274,7 @@ def _prod_condition_number(elements): return abs_max / abs_min + @pytest.mark.unvectorized @given( x=hh.arrays( @@ -270,42 +295,47 @@ def test_prod(x, data): ) keepdims = kw.get("keepdims", False) - with hh.reject_overflow(): - out = xp.prod(x, **kw) - - dtype = kw.get("dtype", None) - expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) - if expected_dtype is None: - # If a default uint cannot exist (i.e. in PyTorch which doesn't support - # uint32 or uint64), we skip testing the output dtype. - # See https://github.com/data-apis/array-api-tests/issues/106 - if x.dtype in dh.uint_dtypes: - assert dh.is_int_dtype(out.dtype) # sanity check - else: - ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - prod = scalar_type(out[out_idx]) - assume(cmath.isfinite(prod)) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = math.prod(elements) - if dh.is_int_dtype(out.dtype): - m, M = dh.dtype_ranges[out.dtype] - assume(m <= expected <= M) - ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, - out=prod, expected=expected) + repro_snippet = ph.format_snippet(f"xp.prod({x!r}, **kw) with {kw = }") + try: + with hh.reject_overflow(): + out = xp.prod(x, **kw) + + dtype = kw.get("dtype", None) + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check else: - condition_number = _prod_condition_number(elements) - assume(condition_number < 1e15) - ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx, - out=prod, expected=expected) + ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + prod = scalar_type(out[out_idx]) + assume(cmath.isfinite(prod)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) + else: + condition_number = _prod_condition_number(elements) + assume(condition_number < 1e15) + ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.skip(reason="flaky") # TODO: fix! @@ -336,13 +366,18 @@ def test_std(x, data): ) keepdims = kw.get("keepdims", False) - out = xp.std(x, **kw) + repro_snippet = ph.format_snippet(f"xp.std({x!r}, **kw) with {kw = }") + try: + out = xp.std(x, **kw) - ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_keepdimable_shape( - "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - # We can't easily test the result(s) as standard deviation methods vary a lot + ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_keepdimable_shape( + "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + # We can't easily test the result(s) as standard deviation methods vary a lot + except Exception as exc: + exc.add_note(repro_snippet) + raise def _sum_condition_number(elements): @@ -354,6 +389,7 @@ def _sum_condition_number(elements): return sum_abs / abs_sum + # @pytest.mark.unvectorized @given( x=hh.arrays( @@ -374,44 +410,49 @@ def test_sum(x, data): ) keepdims = kw.get("keepdims", False) - with hh.reject_overflow(): - out = xp.sum(x, **kw) - - dtype = kw.get("dtype", None) - expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) - if expected_dtype is None: - # If a default uint cannot exist (i.e. in PyTorch which doesn't support - # uint32 or uint64), we skip testing the output dtype. - # See https://github.com/data-apis/array-api-tests/issues/160 - if x.dtype in dh.uint_dtypes: - assert dh.is_int_dtype(out.dtype) # sanity check - else: - ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - sum_ = scalar_type(out[out_idx]) - assume(cmath.isfinite(sum_)) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = sum(elements) - if dh.is_int_dtype(out.dtype): - m, M = dh.dtype_ranges[out.dtype] - assume(m <= expected <= M) - ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, - out=sum_, expected=expected) + repro_snippet = ph.format_snippet(f"xp.sum({x!r}, **kw) with {kw = }") + try: + with hh.reject_overflow(): + out = xp.sum(x, **kw) + + dtype = kw.get("dtype", None) + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/160 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check else: - # Avoid value testing for ill conditioned summations. See - # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and - # https://en.wikipedia.org/wiki/Condition_number. - condition_number = _sum_condition_number(elements) - assume(condition_number < 1e6) - ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) + ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + sum_ = scalar_type(out[out_idx]) + assume(cmath.isfinite(sum_)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, + out=sum_, expected=expected) + else: + # Avoid value testing for ill conditioned summations. See + # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and + # https://en.wikipedia.org/wiki/Condition_number. + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -443,10 +484,15 @@ def test_var(x, data): ) keepdims = kw.get("keepdims", False) - out = xp.var(x, **kw) - - ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_keepdimable_shape( - "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - # We can't easily test the result(s) as variance methods vary a lot + repro_snippet = ph.format_snippet(f"xp.var({x!r}, **kw) with {kw = }") + try: + out = xp.var(x, **kw) + + ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_keepdimable_shape( + "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + # We can't easily test the result(s) as variance methods vary a lot + except Exception as exc: + exc.add_note(repro_snippet) + raise