4545from . import _array_module as xp
4646from ._array_module import linalg
4747
48+
4849def assert_equal (x , y , msg_extra = None ):
4950 extra = '' if not msg_extra else f' ({ msg_extra } )'
5051 if x .dtype in dh .all_float_dtypes :
@@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
6061 else :
6162 assert_exactly_equal (x , y , msg_extra = msg_extra )
6263
64+
6365def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
6466 matrix_axes = (- 2 , - 1 ),
6567 res_axes = None ,
@@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
106108 if true_val :
107109 assert_equal (decomp_res_stack , true_val (* x_stacks , ** kw ), msg_extra )
108110
111+
109112def _test_namedtuple (res , fields , func_name ):
110113 """
111114 Test that res is a namedtuple with the correct fields.
@@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
121124 assert hasattr (res , field ), f"{ func_name } () result namedtuple doesn't have the '{ field } ' field"
122125 assert res [i ] is getattr (res , field ), f"{ func_name } () result namedtuple '{ field } ' field is not in position { i } "
123126
127+
124128@pytest .mark .unvectorized
125129@pytest .mark .xp_extension ('linalg' )
126130@given (
@@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):
901905
902906 _test_stacks (linalg .trace , x , ** kw , res = res , dims = 0 , true_val = true_trace )
903907
908+
909+ def _conj (x ):
910+ # XXX: replace with xp.dtype when all array libraries implement it
911+ if x .dtype in (xp .complex64 , xp .complex128 ):
912+ return xp .conj (x )
913+ else :
914+ return x
915+
916+
904917def _test_vecdot (namespace , x1 , x2 , data ):
905918 vecdot = namespace .vecdot
906919 broadcasted_shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
@@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
925938 ph .assert_result_shape ("vecdot" , in_shapes = [x1 .shape , x2 .shape ],
926939 out_shape = res .shape , expected = expected_shape )
927940
928- if x1 .dtype in dh .int_dtypes :
929- def true_val (x , y , axis = - 1 ):
930- return xp .sum (xp .multiply (x , y ), dtype = res .dtype )
931- else :
932- true_val = None
941+ def true_val (x , y , axis = - 1 ):
942+ return xp .sum (xp .multiply (_conj (x ), y ), dtype = res .dtype )
933943
934944 _test_stacks (vecdot , x1 , x2 , res = res , dims = 0 ,
935945 matrix_axes = (axis ,), true_val = true_val )
@@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
944954def test_linalg_vecdot (x1 , x2 , data ):
945955 _test_vecdot (linalg , x1 , x2 , data )
946956
957+
947958@pytest .mark .unvectorized
948959@given (
949960 * two_mutual_arrays (dh .numeric_dtypes , mutually_broadcastable_shapes (2 , min_dims = 1 )),
@@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
952963def test_vecdot (x1 , x2 , data ):
953964 _test_vecdot (_array_module , x1 , x2 , data )
954965
966+
955967# Insanely large orders might not work. There isn't a limit specified in the
956968# spec, so we just limit to reasonable values here.
957969max_ord = 100
958970
971+
959972@pytest .mark .unvectorized
960973@pytest .mark .xp_extension ('linalg' )
961974@given (
0 commit comments