Skip to content

Commit 8096c93

Browse files
committed
ENH: add "repro snippets" to test_data_type_functions.py
1 parent 7203669 commit 8096c93

File tree

1 file changed

+163
-96
lines changed

1 file changed

+163
-96
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 163 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,17 @@ def test_astype(x_dtype, dtype, kw, data):
8181
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
8282
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
8383

84-
out = xp.astype(x, dtype, **kw)
84+
repro_snippet = ph.format_snippet(f"xp.astype({x!r}, {dtype!r}, **kw) with {kw = }")
85+
try:
86+
out = xp.astype(x, dtype, **kw)
8587

86-
ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
87-
ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw)
88-
# TODO: test values
89-
# TODO: test copy
88+
ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
89+
ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw)
90+
# TODO: test values
91+
# TODO: test copy
92+
except Exception as exc:
93+
exc.add_note(repro_snippet)
94+
raise
9095

9196

9297
@given(
@@ -98,24 +103,29 @@ def test_broadcast_arrays(shapes, data):
98103
x = data.draw(hh.arrays(dtype=hh.all_dtypes, shape=shape), label=f"x{c}")
99104
arrays.append(x)
100105

101-
out = xp.broadcast_arrays(*arrays)
102-
103-
expected_shape = sh.broadcast_shapes(*shapes)
104-
for i, x in enumerate(arrays):
105-
ph.assert_dtype(
106-
"broadcast_arrays",
107-
in_dtype=x.dtype,
108-
out_dtype=out[i].dtype,
109-
repr_name=f"out[{i}].dtype"
110-
)
111-
ph.assert_result_shape(
112-
"broadcast_arrays",
113-
in_shapes=shapes,
114-
out_shape=out[i].shape,
115-
expected=expected_shape,
116-
repr_name=f"out[{i}].shape",
117-
)
118-
# TODO: test values
106+
repro_snippet = ph.format_snippet(f"xp.broadcast_arrays(*arrays) with {arrays = }")
107+
try:
108+
out = xp.broadcast_arrays(*arrays)
109+
110+
expected_shape = sh.broadcast_shapes(*shapes)
111+
for i, x in enumerate(arrays):
112+
ph.assert_dtype(
113+
"broadcast_arrays",
114+
in_dtype=x.dtype,
115+
out_dtype=out[i].dtype,
116+
repr_name=f"out[{i}].dtype"
117+
)
118+
ph.assert_result_shape(
119+
"broadcast_arrays",
120+
in_shapes=shapes,
121+
out_shape=out[i].shape,
122+
expected=expected_shape,
123+
repr_name=f"out[{i}].shape",
124+
)
125+
# TODO: test values
126+
except Exception as exc:
127+
exc.add_note(repro_snippet)
128+
raise
119129

120130

121131
@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
@@ -127,29 +137,39 @@ def test_broadcast_to(x, data):
127137
label="shape",
128138
)
129139

130-
out = xp.broadcast_to(x, shape)
140+
repro_snippet = ph.format_snippet(f"xp.broadcast_to({x!r}, {shape!r})")
141+
try:
142+
out = xp.broadcast_to(x, shape)
131143

132-
ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype)
133-
ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape)
134-
# TODO: test values
144+
ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype)
145+
ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape)
146+
# TODO: test values
147+
except Exception as exc:
148+
exc.add_note(repro_snippet)
149+
raise
135150

136151

137152
@given(_from=hh.all_dtypes, to=hh.all_dtypes)
138153
def test_can_cast(_from, to):
139-
out = xp.can_cast(_from, to)
154+
repro_snippet = ph.format_snippet(f"xp.can_cast({_from!r}, {to!r})")
155+
try:
156+
out = xp.can_cast(_from, to)
140157

141-
expected = False
142-
for other in dh.all_dtypes:
143-
if dh.promotion_table.get((_from, other)) == to:
144-
expected = True
145-
break
158+
expected = False
159+
for other in dh.all_dtypes:
160+
if dh.promotion_table.get((_from, other)) == to:
161+
expected = True
162+
break
146163

147-
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
148-
if expected:
149-
# cross-kind casting is not explicitly disallowed. We can only test
150-
# the cases where it should return True. TODO: if expected=False,
151-
# check that the array library actually allows such casts.
152-
assert out == expected, f"{out=}, but should be {expected} {f_func}"
164+
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
165+
if expected:
166+
# cross-kind casting is not explicitly disallowed. We can only test
167+
# the cases where it should return True. TODO: if expected=False,
168+
# check that the array library actually allows such casts.
169+
assert out == expected, f"{out=}, but should be {expected} {f_func}"
170+
except Exception as exc:
171+
exc.add_note(repro_snippet)
172+
raise
153173

154174

155175
@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
@@ -160,7 +180,7 @@ def test_finfo(dtype):
160180
# np.float64 and np.asarray(1, dtype=np.float64).dtype are different
161181
xp.asarray(1, dtype=dtype).dtype,
162182
):
163-
repro_snippet = ph.format_snippet(f"xp.finfo({arg})")
183+
repro_snippet = ph.format_snippet(f"xp.finfo({arg!r})")
164184
try:
165185
out = xp.finfo(arg)
166186
assert isinstance(out.bits, int)
@@ -175,19 +195,24 @@ def test_finfo(dtype):
175195
@pytest.mark.min_version("2022.12")
176196
@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
177197
def test_finfo_dtype(dtype):
178-
out = xp.finfo(dtype)
179-
180-
if dtype == xp.complex64:
181-
assert out.dtype == xp.float32
182-
elif dtype == xp.complex128:
183-
assert out.dtype == xp.float64
184-
else:
185-
assert out.dtype == dtype
198+
repro_snippet = ph.format_snippet(f"xp.finfo({dtype!r})")
199+
try:
200+
out = xp.finfo(dtype)
201+
202+
if dtype == xp.complex64:
203+
assert out.dtype == xp.float32
204+
elif dtype == xp.complex128:
205+
assert out.dtype == xp.float64
206+
else:
207+
assert out.dtype == dtype
186208

187-
# Guard vs. numpy.dtype.__eq__ lax comparison
188-
assert not isinstance(out.dtype, str)
189-
assert out.dtype is not float
190-
assert out.dtype is not complex
209+
# Guard vs. numpy.dtype.__eq__ lax comparison
210+
assert not isinstance(out.dtype, str)
211+
assert out.dtype is not float
212+
assert out.dtype is not complex
213+
except Exception as exc:
214+
exc.add_note(repro_snippet)
215+
raise
191216

192217

193218
@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
@@ -198,20 +223,30 @@ def test_iinfo(dtype):
198223
# np.int64 and np.asarray(1, dtype=np.int64).dtype are different
199224
xp.asarray(1, dtype=dtype).dtype,
200225
):
201-
out = xp.iinfo(arg)
202-
assert isinstance(out.bits, int)
203-
assert isinstance(out.max, int)
204-
assert isinstance(out.min, int)
226+
repro_snippet = ph.format_snippet(f"xp.iinfo({arg!r})")
227+
try:
228+
out = xp.iinfo(arg)
229+
assert isinstance(out.bits, int)
230+
assert isinstance(out.max, int)
231+
assert isinstance(out.min, int)
232+
except Exception as exc:
233+
exc.add_note(repro_snippet)
234+
raise
205235

206236

207237
@pytest.mark.min_version("2022.12")
208238
@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
209239
def test_iinfo_dtype(dtype):
210-
out = xp.iinfo(dtype)
211-
assert out.dtype == dtype
212-
# Guard vs. numpy.dtype.__eq__ lax comparison
213-
assert not isinstance(out.dtype, str)
214-
assert out.dtype is not int
240+
repro_snippet = ph.format_snippet(f"xp.iinfo({dtype!r})")
241+
try:
242+
out = xp.iinfo(dtype)
243+
assert out.dtype == dtype
244+
# Guard vs. numpy.dtype.__eq__ lax comparison
245+
assert not isinstance(out.dtype, str)
246+
assert out.dtype is not int
247+
except Exception as exc:
248+
exc.add_note(repro_snippet)
249+
raise
215250

216251

217252
def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
@@ -224,29 +259,39 @@ def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
224259
kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
225260
)
226261
def test_isdtype(dtype, kind):
227-
out = xp.isdtype(dtype, kind)
228-
229-
assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
230-
_kinds = kind if isinstance(kind, tuple) else (kind,)
231-
expected = False
232-
for _kind in _kinds:
233-
if isinstance(_kind, str):
234-
if dtype in dh.kind_to_dtypes[_kind]:
235-
expected = True
236-
break
237-
else:
238-
if dtype == _kind:
239-
expected = True
240-
break
241-
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
262+
repro_snippet = ph.format_snippet(f"xp.isdtype({dtype!r}, {kind!r})")
263+
try:
264+
out = xp.isdtype(dtype, kind)
265+
266+
assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
267+
_kinds = kind if isinstance(kind, tuple) else (kind,)
268+
expected = False
269+
for _kind in _kinds:
270+
if isinstance(_kind, str):
271+
if dtype in dh.kind_to_dtypes[_kind]:
272+
expected = True
273+
break
274+
else:
275+
if dtype == _kind:
276+
expected = True
277+
break
278+
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
279+
except Exception as exc:
280+
exc.add_note(repro_snippet)
281+
raise
242282

243283

244284
@pytest.mark.min_version("2024.12")
245285
class TestResultType:
246286
@given(dtypes=hh.mutually_promotable_dtypes(None))
247287
def test_result_type(self, dtypes):
248-
out = xp.result_type(*dtypes)
249-
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
288+
repro_snippet = ph.format_snippet(f"xp.result_type(*dtypes) with {dtypes = }")
289+
try:
290+
out = xp.result_type(*dtypes)
291+
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
292+
except Exception as exc:
293+
exc.add_note(repro_snippet)
294+
raise
250295

251296
@given(pair=hh.pair_of_mutually_promotable_dtypes(None))
252297
def test_shuffled(self, pair):
@@ -261,33 +306,55 @@ def test_arrays_and_dtypes(self, pair, data):
261306
s1, s2 = pair
262307
a2 = tuple(xp.empty(1, dtype=dt) for dt in s2)
263308
a_and_dt = data.draw(st.permutations(s1 + a2))
264-
out = xp.result_type(*a_and_dt)
265-
ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
309+
310+
repro_snippet = ph.format_snippet(f"xp.result_type(*a_and_dt) with {a_and_dt = }")
311+
try:
312+
out = xp.result_type(*a_and_dt)
313+
ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
314+
except Exception as exc:
315+
exc.add_note(repro_snippet)
316+
raise
266317

267318
@given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
268319
def test_with_scalars(self, dtypes, data):
269-
out = xp.result_type(*dtypes)
270-
271-
if out == xp.bool:
272-
scalars = [True]
273-
elif out in dh.all_int_dtypes:
274-
scalars = [1]
275-
elif out in dh.real_dtypes:
276-
scalars = [1, 1.0]
277-
elif out in dh.numeric_dtypes:
278-
scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
279-
else:
280-
raise ValueError(f"unknown dtype {out = }.")
320+
321+
repro_snippet = ph.format_snippet(f"xp.result_type(*dtypes) with {dtypes = }")
322+
try:
323+
out = xp.result_type(*dtypes)
324+
325+
if out == xp.bool:
326+
scalars = [True]
327+
elif out in dh.all_int_dtypes:
328+
scalars = [1]
329+
elif out in dh.real_dtypes:
330+
scalars = [1, 1.0]
331+
elif out in dh.numeric_dtypes:
332+
scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
333+
else:
334+
raise ValueError(f"unknown dtype {out = }.")
335+
except Exception as exc:
336+
exc.add_note(repro_snippet)
337+
raise
281338

282339
scalar = data.draw(st.sampled_from(scalars))
283340
inputs = data.draw(st.permutations(dtypes + (scalar,)))
284341

285-
out_scalar = xp.result_type(*inputs)
286-
assert out_scalar == out
342+
repro_snippet = ph.format_snippet(f"xp.result_type(*inputs) with {inputs = }")
343+
try:
344+
out_scalar = xp.result_type(*inputs)
345+
assert out_scalar == out
346+
except Exception as exc:
347+
exc.add_note(repro_snippet)
348+
raise
287349

288350
# retry with arrays
289351
arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
290352
inputs = data.draw(st.permutations(arrays + (scalar,)))
291-
out_scalar = xp.result_type(*inputs)
292-
assert out_scalar == out
293353

354+
repro_snippet = ph.format_snippet(f"xp.result_type(*inputs) with {inputs = }")
355+
try:
356+
out_scalar = xp.result_type(*inputs)
357+
assert out_scalar == out
358+
except Exception as exc:
359+
exc.add_note(repro_snippet)
360+
raise

0 commit comments

Comments
 (0)