|
17 | 17 | from hypothesis import assume, given |
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers, |
19 | 19 | shared, sampled_from, data, just) |
| 20 | +from ndindex import iter_indices |
20 | 21 |
|
21 | | -from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity |
| 22 | +from .array_helpers import assert_exactly_equal, asarray |
22 | 23 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, |
23 | 24 | square_matrix_shapes, symmetric_matrices, |
24 | 25 | positive_definite_matrices, MAX_ARRAY_SIZE, |
|
43 | 44 | # Standin strategy for not yet implemented tests |
44 | 45 | todo = none() |
45 | 46 |
|
46 | | -def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): |
| 47 | +def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), |
| 48 | + assert_equal=assert_exactly_equal, **kw): |
47 | 49 | """ |
48 | 50 | Test that f(*args, **kw) maps across stacks of matrices |
49 | 51 |
|
50 | | - dims is the number of dimensions f should have for a single n x m matrix |
51 | | - stack. |
| 52 | + dims is the number of dimensions f(*args) should have for a single n x m |
| 53 | + matrix stack. |
| 54 | +
|
| 55 | + matrix_axes are the axes along which matrices (or vectors) are stacked in |
| 56 | + the input. |
| 57 | +
|
| 58 | + true_val may be a function such that true_val(*x_stacks, **kw) gives the |
| 59 | + true value for f on a stack. |
| 60 | +
|
| 61 | + res should be the result of f(*args, **kw). It is computed if not passed |
| 62 | + in. |
52 | 63 |
|
53 | | - true_val may be a function such that true_val(*x_stacks) gives the true |
54 | | - value for f on a stack |
55 | 64 | """ |
56 | 65 | if res is None: |
57 | 66 | res = f(*args, **kw) |
58 | 67 |
|
59 | | - shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape |
60 | | - for x in args]) |
61 | | - for _idx in sh.ndindex(shape[:-2]): |
62 | | - idx = _idx + (slice(None),)*dims |
63 | | - res_stack = res[idx] |
64 | | - x_stacks = [x[_idx + (...,)] for x in args] |
| 68 | + shapes = [x.shape for x in args] |
| 69 | + |
| 70 | + for (x_idxes, (res_idx,)) in zip( |
| 71 | + iter_indices(*shapes, skip_axes=matrix_axes), |
| 72 | + iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): |
| 73 | + x_idxes = [x_idx.raw for x_idx in x_idxes] |
| 74 | + res_idx = res_idx.raw |
| 75 | + |
| 76 | + res_stack = res[res_idx] |
| 77 | + x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] |
65 | 78 | decomp_res_stack = f(*x_stacks, **kw) |
66 | | - assert_exactly_equal(res_stack, decomp_res_stack) |
| 79 | + assert_equal(res_stack, decomp_res_stack) |
67 | 80 | if true_val: |
68 | | - assert_exactly_equal(decomp_res_stack, true_val(*x_stacks)) |
| 81 | + assert_equal(decomp_res_stack, true_val(*x_stacks)) |
69 | 82 |
|
70 | 83 | def _test_namedtuple(res, fields, func_name): |
71 | 84 | """ |
@@ -452,10 +465,12 @@ def test_slogdet(x): |
452 | 465 |
|
453 | 466 | # Check that when the determinant is 0, the sign and logabsdet are (0, |
454 | 467 | # -inf). |
455 | | - d = linalg.det(x) |
456 | | - zero_det = equal(d, zero(d.shape, d.dtype)) |
457 | | - assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) |
458 | | - assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) |
| 468 | + # TODO: This test does not necessarily hold exactly. Update it to test it |
| 469 | + # approximately. |
| 470 | + # d = linalg.det(x) |
| 471 | + # zero_det = equal(d, zero(d.shape, d.dtype)) |
| 472 | + # assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) |
| 473 | + # assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) |
459 | 474 |
|
460 | 475 | # More generally, det(x) should equal sign*exp(logabsdet), but this does |
461 | 476 | # not hold exactly due to floating-point loss of precision. |
@@ -614,7 +629,7 @@ def true_trace(x_stack): |
614 | 629 |
|
615 | 630 | @given( |
616 | 631 | dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), |
617 | | - shape=shapes(), |
| 632 | + shape=shapes(min_dims=1), |
618 | 633 | data=data(), |
619 | 634 | ) |
620 | 635 | def test_vecdot(dtypes, shape, data): |
|
0 commit comments