1616import pytest
1717from hypothesis import assume , given
1818from hypothesis .strategies import (booleans , composite , none , tuples , integers ,
19- shared , sampled_from , data , just )
19+ shared , sampled_from , one_of , data , just )
2020from ndindex import iter_indices
2121
2222from .array_helpers import assert_exactly_equal , asarray
2626 invertible_matrices , two_mutual_arrays ,
2727 mutually_promotable_dtypes , one_d_shapes ,
2828 two_mutually_broadcastable_shapes ,
29- SQRT_MAX_ARRAY_SIZE , finite_matrices )
29+ SQRT_MAX_ARRAY_SIZE , finite_matrices ,
30+ rtol_shared_matrix_shapes , rtols )
3031from . import dtype_helpers as dh
3132from . import pytest_helpers as ph
3233from . import shape_helpers as sh
3738
3839pytestmark = pytest .mark .ci
3940
40-
41-
4241# Standin strategy for not yet implemented tests
4342todo = none ()
4443
45- def _test_stacks (f , * args , res = None , dims = 2 , true_val = None , matrix_axes = (- 2 , - 1 ),
44+ def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
45+ matrix_axes = (- 2 , - 1 ),
4646 assert_equal = assert_exactly_equal , ** kw ):
4747 """
4848 Test that f(*args, **kw) maps across stacks of matrices
4949
50- dims is the number of dimensions f(*args) should have for a single n x m
51- matrix stack.
50+ dims is the number of dimensions f(*args, *kw ) should have for a single n
51+ x m matrix stack.
5252
5353 matrix_axes are the axes along which matrices (or vectors) are stacked in
5454 the input.
@@ -65,9 +65,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
6565
6666 shapes = [x .shape for x in args ]
6767
68+ # Assume the result is stacked along the last 'dims' axes of matrix_axes.
69+ # This holds for all the functions tested in this file
70+ res_axes = matrix_axes [::- 1 ][:dims ]
71+
6872 for (x_idxes , (res_idx ,)) in zip (
6973 iter_indices (* shapes , skip_axes = matrix_axes ),
70- iter_indices (res .shape , skip_axes = tuple ( range ( - dims , 0 )) )):
74+ iter_indices (res .shape , skip_axes = res_axes )):
7175 x_idxes = [x_idx .raw for x_idx in x_idxes ]
7276 res_idx = res_idx .raw
7377
@@ -159,26 +163,18 @@ def test_cross(x1_x2_kw):
159163 assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "cross() did not return the correct dtype"
160164 assert res .shape == shape , "cross() did not return the correct shape"
161165
162- # cross is too different from other functions to use _test_stacks, and it
163- # is the only function that works the way it does, so it's not really
164- # worth generalizing _test_stacks to handle it.
165- a = axis if axis >= 0 else axis + len (shape )
166- for _idx in sh .ndindex (shape [:a ] + shape [a + 1 :]):
167- idx = _idx [:a ] + (slice (None ),) + _idx [a :]
168- assert len (idx ) == len (shape ), "Invalid index. This indicates a bug in the test suite."
169- res_stack = res [idx ]
170- x1_stack = x1 [idx ]
171- x2_stack = x2 [idx ]
172- assert x1_stack .shape == x2_stack .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
173- decomp_res_stack = linalg .cross (x1_stack , x2_stack )
174- assert_exactly_equal (res_stack , decomp_res_stack )
175-
176- exact_cross = asarray ([
177- x1_stack [1 ]* x2_stack [2 ] - x1_stack [2 ]* x2_stack [1 ],
178- x1_stack [2 ]* x2_stack [0 ] - x1_stack [0 ]* x2_stack [2 ],
179- x1_stack [0 ]* x2_stack [1 ] - x1_stack [1 ]* x2_stack [0 ],
180- ], dtype = res .dtype )
181- assert_exactly_equal (res_stack , exact_cross )
166+ def exact_cross (a , b ):
167+ assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
168+ return asarray ([
169+ a [1 ]* b [2 ] - a [2 ]* b [1 ],
170+ a [2 ]* b [0 ] - a [0 ]* b [2 ],
171+ a [0 ]* b [1 ] - a [1 ]* b [0 ],
172+ ], dtype = res .dtype )
173+
174+ # We don't want to pass in **kw here because that would pass axis to
175+ # cross() on a single stack, but the axis is not meaningful on unstacked
176+ # vectors.
177+ _test_stacks (linalg .cross , x1 , x2 , dims = 1 , matrix_axes = (axis ,), res = res , true_val = exact_cross )
182178
183179@pytest .mark .xp_extension ('linalg' )
184180@given (
@@ -313,14 +309,30 @@ def test_matmul(x1, x2):
313309 assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
314310 _test_stacks (_array_module .matmul , x1 , x2 , res = res )
315311
312+ matrix_norm_shapes = shared (matrix_shapes ())
313+
316314@pytest .mark .xp_extension ('linalg' )
317315@given (
318- x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ()),
319- kw = kwargs (axis = todo , keepdims = todo , ord = todo )
316+ x = finite_matrices (),
317+ kw = kwargs (keepdims = booleans (),
318+ ord = sampled_from ([- float ('inf' ), - 2 , - 2 , 1 , 2 , float ('inf' ), 'fro' , 'nuc' ]))
320319)
321320def test_matrix_norm (x , kw ):
322- # res = linalg.matrix_norm(x, **kw)
323- pass
321+ res = linalg .matrix_norm (x , ** kw )
322+
323+ keepdims = kw .get ('keepdims' , False )
324+ # TODO: Check that the ord values give the correct norms.
325+ # ord = kw.get('ord', 'fro')
326+
327+ if keepdims :
328+ expected_shape = x .shape [:- 2 ] + (1 , 1 )
329+ else :
330+ expected_shape = x .shape [:- 2 ]
331+ assert res .shape == expected_shape , f"matrix_norm({ keepdims = } ) did not return the correct shape"
332+ assert res .dtype == x .dtype , "matrix_norm() did not return the correct dtype"
333+
334+ _test_stacks (linalg .matrix_norm , x , ** kw , dims = 2 if keepdims else 0 ,
335+ res = res )
324336
325337matrix_power_n = shared (integers (- 1000 , 1000 ), key = 'matrix_power n' )
326338@pytest .mark .xp_extension ('linalg' )
@@ -347,12 +359,11 @@ def test_matrix_power(x, n):
347359
348360@pytest .mark .xp_extension ('linalg' )
349361@given (
350- x = xps . arrays ( dtype = xps . floating_dtypes (), shape = shapes () ),
351- kw = kwargs (rtol = todo )
362+ x = finite_matrices ( shape = rtol_shared_matrix_shapes ),
363+ kw = kwargs (rtol = rtols )
352364)
353365def test_matrix_rank (x , kw ):
354- # res = linalg.matrix_rank(x, **kw)
355- pass
366+ linalg .matrix_rank (x , ** kw )
356367
357368@given (
358369 x = xps .arrays (dtype = dtypes , shape = matrix_shapes ()),
@@ -397,12 +408,11 @@ def test_outer(x1, x2):
397408
398409@pytest .mark .xp_extension ('linalg' )
399410@given (
400- x = xps . arrays ( dtype = xps . floating_dtypes (), shape = shapes () ),
401- kw = kwargs (rtol = todo )
411+ x = finite_matrices ( shape = rtol_shared_matrix_shapes ),
412+ kw = kwargs (rtol = rtols )
402413)
403414def test_pinv (x , kw ):
404- # res = linalg.pinv(x, **kw)
405- pass
415+ linalg .pinv (x , ** kw )
406416
407417@pytest .mark .xp_extension ('linalg' )
408418@given (
@@ -482,7 +492,7 @@ def solve_args():
482492 Strategy for the x1 and x2 arguments to test_solve()
483493
484494 solve() takes x1, x2, where x1 is any stack of square invertible matrices
485- of shape (..., M, M), and x2 is either shape (..., M ) or (..., M, K),
495+ of shape (..., M, M), and x2 is either shape (M, ) or (..., M, K),
486496 where the ... parts of x1 and x2 are broadcast compatible.
487497 """
488498 stack_shapes = shared (two_mutually_broadcastable_shapes )
@@ -492,30 +502,22 @@ def solve_args():
492502 pair [0 ])))
493503
494504 @composite
495- def x2_shapes (draw ):
496- end = draw (xps .array_shapes (min_dims = 0 , max_dims = 1 , min_side = 0 ,
497- max_side = SQRT_MAX_ARRAY_SIZE ))
498- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + end
505+ def _x2_shapes (draw ):
506+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
507+ return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
499508
500- x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes ())
509+ x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
510+ x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes )
501511 return x1 , x2
502512
503513@pytest .mark .xp_extension ('linalg' )
504514@given (* solve_args ())
505515def test_solve (x1 , x2 ):
506- # TODO: solve() is currently ambiguous, in that some inputs can be
507- # interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
508- # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
509- # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
510- # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
511- # (2, 2, 2, 2).
512-
513- # res = linalg.solve(x1, x2)
514- pass
516+ linalg .solve (x1 , x2 )
515517
516518@pytest .mark .xp_extension ('linalg' )
517519@given (
518- x = finite_matrices ,
520+ x = finite_matrices () ,
519521 kw = kwargs (full_matrices = booleans ())
520522)
521523def test_svd (x , kw ):
@@ -551,7 +553,7 @@ def test_svd(x, kw):
551553
552554@pytest .mark .xp_extension ('linalg' )
553555@given (
554- x = finite_matrices ,
556+ x = finite_matrices () ,
555557)
556558def test_svdvals (x ):
557559 res = linalg .svdvals (x )
0 commit comments