@@ -33,23 +33,28 @@ def test_argmax(x, data):
3333 )
3434 keepdims = kw .get ("keepdims" , False )
3535
36- out = xp .argmax (x , ** kw )
36+ repro_snippet = ph .format_snippet (f"xp.argmax({ x !r} , **kw) with { kw = } " )
37+ try :
38+ out = xp .argmax (x , ** kw )
3739
38- ph .assert_default_index ("argmax" , out .dtype )
39- axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
40- ph .assert_keepdimable_shape (
41- "argmax" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
42- )
43- scalar_type = dh .get_scalar_type (x .dtype )
44- for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
45- max_i = int (out [out_idx ])
46- elements = []
47- for idx in indices :
48- s = scalar_type (x [idx ])
49- elements .append (s )
50- expected = max (range (len (elements )), key = elements .__getitem__ )
51- ph .assert_scalar_equals ("argmax" , type_ = int , idx = out_idx , out = max_i ,
52- expected = expected , kw = kw )
40+ ph .assert_default_index ("argmax" , out .dtype )
41+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
42+ ph .assert_keepdimable_shape (
43+ "argmax" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
44+ )
45+ scalar_type = dh .get_scalar_type (x .dtype )
46+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
47+ max_i = int (out [out_idx ])
48+ elements = []
49+ for idx in indices :
50+ s = scalar_type (x [idx ])
51+ elements .append (s )
52+ expected = max (range (len (elements )), key = elements .__getitem__ )
53+ ph .assert_scalar_equals ("argmax" , type_ = int , idx = out_idx , out = max_i ,
54+ expected = expected , kw = kw )
55+ except Exception as exc :
56+ exc .add_note (repro_snippet )
57+ raise
5358
5459
5560@given (
@@ -70,22 +75,27 @@ def test_argmin(x, data):
7075 )
7176 keepdims = kw .get ("keepdims" , False )
7277
73- out = xp .argmin (x , ** kw )
78+ repro_snippet = ph .format_snippet (f"xp.argmin({ x !r} , **kw) with { kw = } " )
79+ try :
80+ out = xp .argmin (x , ** kw )
7481
75- ph .assert_default_index ("argmin" , out .dtype )
76- axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
77- ph .assert_keepdimable_shape (
78- "argmin" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
79- )
80- scalar_type = dh .get_scalar_type (x .dtype )
81- for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
82- min_i = int (out [out_idx ])
83- elements = []
84- for idx in indices :
85- s = scalar_type (x [idx ])
86- elements .append (s )
87- expected = min (range (len (elements )), key = elements .__getitem__ )
88- ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
82+ ph .assert_default_index ("argmin" , out .dtype )
83+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
84+ ph .assert_keepdimable_shape (
85+ "argmin" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
86+ )
87+ scalar_type = dh .get_scalar_type (x .dtype )
88+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
89+ min_i = int (out [out_idx ])
90+ elements = []
91+ for idx in indices :
92+ s = scalar_type (x [idx ])
93+ elements .append (s )
94+ expected = min (range (len (elements )), key = elements .__getitem__ )
95+ ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
96+ except Exception as exc :
97+ exc .add_note (repro_snippet )
98+ raise
8999
90100
91101# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
@@ -115,23 +125,28 @@ def test_count_nonzero(x, data):
115125
116126 assume (kw .get ("axis" , None ) != ()) # TODO clarify in the spec
117127
118- out = xp .count_nonzero (x , ** kw )
128+ repro_snippet = ph .format_snippet (f"xp.count_nonzero({ x !r} , **kw) with { kw = } " )
129+ try :
130+ out = xp .count_nonzero (x , ** kw )
119131
120- ph .assert_default_index ("count_nonzero" , out .dtype )
121- axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
122- ph .assert_keepdimable_shape (
123- "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
124- )
125- scalar_type = dh .get_scalar_type (x .dtype )
132+ ph .assert_default_index ("count_nonzero" , out .dtype )
133+ axes = sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
134+ ph .assert_keepdimable_shape (
135+ "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw
136+ )
137+ scalar_type = dh .get_scalar_type (x .dtype )
126138
127- for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
128- count = int (out [out_idx ])
129- elements = []
130- for idx in indices :
131- s = scalar_type (x [idx ])
132- elements .append (s )
133- expected = sum (el != 0 for el in elements )
134- ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
139+ for indices , out_idx in zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
140+ count = int (out [out_idx ])
141+ elements = []
142+ for idx in indices :
143+ s = scalar_type (x [idx ])
144+ elements .append (s )
145+ expected = sum (el != 0 for el in elements )
146+ ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
147+ except Exception as exc :
148+ exc .add_note (repro_snippet )
149+ raise
135150
136151
137152@given (hh .arrays (dtype = hh .all_dtypes , shape = ()))
@@ -143,39 +158,44 @@ def test_nonzero_zerodim_error(x):
143158@pytest .mark .data_dependent_shapes
144159@given (hh .arrays (dtype = hh .all_dtypes , shape = hh .shapes (min_dims = 1 , min_side = 1 )))
145160def test_nonzero (x ):
146- out = xp .nonzero (x )
147- assert len (out ) == x .ndim , f"{ len (out )= } , but should be { x .ndim = } "
148- out_size = math .prod (out [0 ].shape )
149- for i in range (len (out )):
150- assert out [i ].ndim == 1 , f"out[{ i } ].ndim={ x .ndim } , but should be 1"
151- size_at = math .prod (out [i ].shape )
152- assert size_at == out_size , (
153- f"prod(out[{ i } ].shape)={ size_at } , "
154- f"but should be prod(out[0].shape)={ out_size } "
155- )
156- ph .assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i } ].dtype" )
157- indices = []
158- if x .dtype == xp .bool :
159- for idx in sh .ndindex (x .shape ):
160- if x [idx ]:
161- indices .append (idx )
162- else :
163- for idx in sh .ndindex (x .shape ):
164- if x [idx ] != 0 :
165- indices .append (idx )
166- if x .ndim == 0 :
167- assert out_size == len (
168- indices
169- ), f"prod(out[0].shape)={ out_size } , but should be { len (indices )} "
170- else :
171- for i in range (out_size ):
172- idx = tuple (int (x [i ]) for x in out )
173- f_idx = f"Extrapolated index (x[{ i } ] for x in out)={ idx } "
174- f_element = f"x[{ idx } ]={ x [idx ]} "
175- assert idx in indices , f"{ f_idx } results in { f_element } , a zero element"
176- assert (
177- idx == indices [i ]
178- ), f"{ f_idx } is in the wrong position, should be { indices .index (idx )} "
161+ repro_snippet = ph .format_snippet (f"xp.nonzero({ x !r} )" )
162+ try :
163+ out = xp .nonzero (x )
164+ assert len (out ) == x .ndim , f"{ len (out )= } , but should be { x .ndim = } "
165+ out_size = math .prod (out [0 ].shape )
166+ for i in range (len (out )):
167+ assert out [i ].ndim == 1 , f"out[{ i } ].ndim={ x .ndim } , but should be 1"
168+ size_at = math .prod (out [i ].shape )
169+ assert size_at == out_size , (
170+ f"prod(out[{ i } ].shape)={ size_at } , "
171+ f"but should be prod(out[0].shape)={ out_size } "
172+ )
173+ ph .assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i } ].dtype" )
174+ indices = []
175+ if x .dtype == xp .bool :
176+ for idx in sh .ndindex (x .shape ):
177+ if x [idx ]:
178+ indices .append (idx )
179+ else :
180+ for idx in sh .ndindex (x .shape ):
181+ if x [idx ] != 0 :
182+ indices .append (idx )
183+ if x .ndim == 0 :
184+ assert out_size == len (
185+ indices
186+ ), f"prod(out[0].shape)={ out_size } , but should be { len (indices )} "
187+ else :
188+ for i in range (out_size ):
189+ idx = tuple (int (x [i ]) for x in out )
190+ f_idx = f"Extrapolated index (x[{ i } ] for x in out)={ idx } "
191+ f_element = f"x[{ idx } ]={ x [idx ]} "
192+ assert idx in indices , f"{ f_idx } results in { f_element } , a zero element"
193+ assert (
194+ idx == indices [i ]
195+ ), f"{ f_idx } is in the wrong position, should be { indices .index (idx )} "
196+ except Exception as exc :
197+ exc .add_note (repro_snippet )
198+ raise
179199
180200
181201@given (
@@ -188,31 +208,36 @@ def test_where(shapes, dtypes, data):
188208 x1 = data .draw (hh .arrays (dtype = dtypes [0 ], shape = shapes [1 ]), label = "x1" )
189209 x2 = data .draw (hh .arrays (dtype = dtypes [1 ], shape = shapes [2 ]), label = "x2" )
190210
191- out = xp .where (cond , x1 , x2 )
192-
193- shape = sh .broadcast_shapes (* shapes )
194- ph .assert_shape ("where" , out_shape = out .shape , expected = shape )
195- # TODO: generate indices without broadcasting arrays
196- _cond = xp .broadcast_to (cond , shape )
197- _x1 = xp .broadcast_to (x1 , shape )
198- _x2 = xp .broadcast_to (x2 , shape )
199- for idx in sh .ndindex (shape ):
200- if _cond [idx ]:
201- ph .assert_0d_equals (
202- "where" ,
203- x_repr = f"_x1[{ idx } ]" ,
204- x_val = _x1 [idx ],
205- out_repr = f"out[{ idx } ]" ,
206- out_val = out [idx ]
207- )
208- else :
209- ph .assert_0d_equals (
210- "where" ,
211- x_repr = f"_x2[{ idx } ]" ,
212- x_val = _x2 [idx ],
213- out_repr = f"out[{ idx } ]" ,
214- out_val = out [idx ]
215- )
211+ repro_snippet = ph .format_snippet (f"xp.where({ cond !r} , { x1 !r} , { x2 !r} )" )
212+ try :
213+ out = xp .where (cond , x1 , x2 )
214+
215+ shape = sh .broadcast_shapes (* shapes )
216+ ph .assert_shape ("where" , out_shape = out .shape , expected = shape )
217+ # TODO: generate indices without broadcasting arrays
218+ _cond = xp .broadcast_to (cond , shape )
219+ _x1 = xp .broadcast_to (x1 , shape )
220+ _x2 = xp .broadcast_to (x2 , shape )
221+ for idx in sh .ndindex (shape ):
222+ if _cond [idx ]:
223+ ph .assert_0d_equals (
224+ "where" ,
225+ x_repr = f"_x1[{ idx } ]" ,
226+ x_val = _x1 [idx ],
227+ out_repr = f"out[{ idx } ]" ,
228+ out_val = out [idx ]
229+ )
230+ else :
231+ ph .assert_0d_equals (
232+ "where" ,
233+ x_repr = f"_x2[{ idx } ]" ,
234+ x_val = _x2 [idx ],
235+ out_repr = f"out[{ idx } ]" ,
236+ out_val = out [idx ]
237+ )
238+ except Exception as exc :
239+ exc .add_note (repro_snippet )
240+ raise
216241
217242
218243@pytest .mark .min_version ("2023.12" )
@@ -238,12 +263,17 @@ def test_searchsorted(data):
238263 label = "x2" ,
239264 )
240265
241- out = xp .searchsorted (x1 , x2 , sorter = sorter )
266+ repro_snippet = ph .format_snippet (f"xp.searchsorted({ x1 !r} , { x2 !r} , sorter={ sorter !r} )" )
267+ try :
268+ out = xp .searchsorted (x1 , x2 , sorter = sorter )
242269
243- ph .assert_dtype (
244- "searchsorted" ,
245- in_dtype = [x1 .dtype , x2 .dtype ],
246- out_dtype = out .dtype ,
247- expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
248- )
249- # TODO: shapes and values testing
270+ ph .assert_dtype (
271+ "searchsorted" ,
272+ in_dtype = [x1 .dtype , x2 .dtype ],
273+ out_dtype = out .dtype ,
274+ expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
275+ )
276+ # TODO: shapes and values testing
277+ except Exception as exc :
278+ exc .add_note (repro_snippet )
279+ raise
0 commit comments