|
14 | 14 | "doesnt_raise", |
15 | 15 | "nargs", |
16 | 16 | "fmt_kw", |
| 17 | + "is_pos_zero", |
| 18 | + "is_neg_zero", |
17 | 19 | "assert_dtype", |
18 | 20 | "assert_kw_dtype", |
19 | 21 | "assert_default_float", |
|
22 | 24 | "assert_shape", |
23 | 25 | "assert_result_shape", |
24 | 26 | "assert_keepdimable_shape", |
| 27 | + "assert_0d_equals", |
25 | 28 | "assert_fill", |
26 | 29 | "assert_array", |
27 | 30 | ] |
@@ -69,6 +72,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: |
69 | 72 | return ", ".join(f"{k}={v}" for k, v in kw.items()) |
70 | 73 |
|
71 | 74 |
|
| 75 | +def is_pos_zero(n: float) -> bool: |
| 76 | + return n == 0 and math.copysign(1, n) == 1 |
| 77 | + |
| 78 | + |
| 79 | +def is_neg_zero(n: float) -> bool: |
| 80 | + return n == 0 and math.copysign(1, n) == -1 |
| 81 | + |
| 82 | + |
72 | 83 | def assert_dtype( |
73 | 84 | func_name: str, |
74 | 85 | in_dtype: Union[DataType, Sequence[DataType]], |
@@ -232,15 +243,28 @@ def assert_fill( |
232 | 243 | def assert_array(func_name: str, out: Array, expected: Array, /, **kw): |
233 | 244 | assert_dtype(func_name, out.dtype, expected.dtype) |
234 | 245 | assert_shape(func_name, out.shape, expected.shape, **kw) |
235 | | - msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}" |
| 246 | + f_func = f"[{func_name}({fmt_kw(kw)})]" |
236 | 247 | if dh.is_float_dtype(out.dtype): |
237 | | - neg_zeros = expected == -0.0 |
238 | | - assert xp.all((out == -0.0) == neg_zeros), msg |
239 | | - pos_zeros = expected == +0.0 |
240 | | - assert xp.all((out == +0.0) == pos_zeros), msg |
241 | | - nans = xp.isnan(expected) |
242 | | - assert xp.all(xp.isnan(out) == nans), msg |
243 | | - mask = ~(neg_zeros | pos_zeros | nans) |
244 | | - assert xp.all(out[mask] == expected[mask]), msg |
| 248 | + for idx in sh.ndindex(out.shape): |
| 249 | + at_out = out[idx] |
| 250 | + at_expected = expected[idx] |
| 251 | + msg = ( |
| 252 | + f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " |
| 253 | + f"{f_func}" |
| 254 | + ) |
| 255 | + if xp.isnan(at_expected): |
| 256 | + assert xp.isnan(at_out), msg |
| 257 | + elif at_expected == 0.0 or at_expected == -0.0: |
| 258 | + scalar_at_expected = float(at_expected) |
| 259 | + scalar_at_out = float(at_out) |
| 260 | + if is_pos_zero(scalar_at_expected): |
| 261 | + assert is_pos_zero(scalar_at_out), msg |
| 262 | + else: |
| 263 | + assert is_neg_zero(scalar_at_expected) # sanity check |
| 264 | + assert is_neg_zero(scalar_at_out), msg |
| 265 | + else: |
| 266 | + assert at_out == at_expected, msg |
245 | 267 | else: |
246 | | - assert xp.all(out == expected), msg |
| 268 | + assert xp.all(out == expected), ( |
| 269 | + f"out not as expected {f_func}\n" f"{out=}\n{expected=}" |
| 270 | + ) |
0 commit comments