@@ -30,6 +30,46 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030 return xps .boolean_dtypes () | all_integer_dtypes ()
3131
3232
33+ class OnewayPromotableDtypes (NamedTuple ):
34+ input_dtype : DataType
35+ result_dtype : DataType
36+
37+
38+ @st .composite
39+ def oneway_promotable_dtypes (
40+ draw , dtypes : List [DataType ]
41+ ) -> st .SearchStrategy [OnewayPromotableDtypes ]:
42+ """Return a strategy for input dtypes that promote to result dtypes."""
43+ d1 , d2 = draw (hh .mutually_promotable_dtypes (dtypes = dtypes ))
44+ result_dtype = dh .result_type (d1 , d2 )
45+ if d1 == result_dtype :
46+ return OnewayPromotableDtypes (d2 , d1 )
47+ elif d2 == result_dtype :
48+ return OnewayPromotableDtypes (d1 , d2 )
49+ else :
50+ reject ()
51+
52+
53+ class OnewayBroadcastableShapes (NamedTuple ):
54+ input_shape : Shape
55+ result_shape : Shape
56+
57+
58+ @st .composite
59+ def oneway_broadcastable_shapes (draw ) -> st .SearchStrategy [OnewayBroadcastableShapes ]:
60+ """Return a strategy for input shapes that broadcast to result shapes."""
61+ result_shape = draw (hh .shapes (min_side = 1 ))
62+ input_shape = draw (
63+ xps .broadcastable_shapes (
64+ result_shape ,
65+ # Override defaults so bad shapes are less likely to be generated.
66+ max_side = None if result_shape == () else max (result_shape ),
67+ max_dims = len (result_shape ),
68+ ).filter (lambda s : sh .broadcast_shapes (result_shape , s ) == result_shape )
69+ )
70+ return OnewayBroadcastableShapes (input_shape , result_shape )
71+
72+
3373def mock_int_dtype (n : int , dtype : DataType ) -> int :
3474 """Returns equivalent of `n` that mocks `dtype` behaviour."""
3575 nbits = dh .dtype_nbits [dtype ]
@@ -306,8 +346,14 @@ def __repr__(self):
306346
307347
308348def make_binary_params (
309- elwise_func_name : str , dtypes_strat : st . SearchStrategy [DataType ]
349+ elwise_func_name : str , dtypes : List [DataType ]
310350) -> List [Param [BinaryParamContext ]]:
351+ if hh .FILTER_UNDEFINED_DTYPES :
352+ dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
353+ shared_oneway_dtypes = st .shared (oneway_promotable_dtypes (dtypes ))
354+ left_dtypes = shared_oneway_dtypes .map (lambda D : D .result_dtype )
355+ right_dtypes = shared_oneway_dtypes .map (lambda D : D .input_dtype )
356+
311357 def make_param (
312358 func_name : str , func_type : FuncType , right_is_scalar : bool
313359 ) -> Param [BinaryParamContext ]:
@@ -318,26 +364,29 @@ def make_param(
318364 left_sym = "x1"
319365 right_sym = "x2"
320366
321- shared_dtypes = st .shared (dtypes_strat )
322367 if right_is_scalar :
323- left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes (** shapes_kw ))
324- right_strat = shared_dtypes .flatmap (
325- lambda d : xps .from_dtype (d , ** finite_kw )
326- )
368+ left_strat = xps .arrays (dtype = left_dtypes , shape = hh .shapes (** shapes_kw ))
369+ right_strat = right_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
327370 else :
328371 if func_type is FuncType .IOP :
329- shared_shapes = st .shared (hh .shapes (** shapes_kw ))
330- left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
331- right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
372+ shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
373+ left_strat = xps .arrays (
374+ dtype = left_dtypes ,
375+ shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
376+ )
377+ right_strat = xps .arrays (
378+ dtype = right_dtypes ,
379+ shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
380+ )
332381 else :
333382 mutual_shapes = st .shared (
334383 hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
335384 )
336385 left_strat = xps .arrays (
337- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
386+ dtype = left_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
338387 )
339388 right_strat = xps .arrays (
340- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
389+ dtype = right_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
341390 )
342391
343392 if func_type is FuncType .FUNC :
@@ -514,7 +563,7 @@ def test_acosh(x):
514563 )
515564
516565
517- @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes () ))
566+ @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , dh .numeric_dtypes ))
518567@given (data = st .data ())
519568def test_add (ctx , data ):
520569 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -579,7 +628,7 @@ def test_atanh(x):
579628
580629
581630@pytest .mark .parametrize (
582- "ctx" , make_binary_params ("bitwise_and" , boolean_and_all_integer_dtypes () )
631+ "ctx" , make_binary_params ("bitwise_and" , dh . bool_and_all_int_dtypes )
583632)
584633@given (data = st .data ())
585634def test_bitwise_and (ctx , data ):
@@ -598,7 +647,7 @@ def test_bitwise_and(ctx, data):
598647
599648
600649@pytest .mark .parametrize (
601- "ctx" , make_binary_params ("bitwise_left_shift" , all_integer_dtypes () )
650+ "ctx" , make_binary_params ("bitwise_left_shift" , dh . all_int_dtypes )
602651)
603652@given (data = st .data ())
604653def test_bitwise_left_shift (ctx , data ):
@@ -638,7 +687,7 @@ def test_bitwise_invert(ctx, data):
638687
639688
640689@pytest .mark .parametrize (
641- "ctx" , make_binary_params ("bitwise_or" , boolean_and_all_integer_dtypes () )
690+ "ctx" , make_binary_params ("bitwise_or" , dh . bool_and_all_int_dtypes )
642691)
643692@given (data = st .data ())
644693def test_bitwise_or (ctx , data ):
@@ -657,7 +706,7 @@ def test_bitwise_or(ctx, data):
657706
658707
659708@pytest .mark .parametrize (
660- "ctx" , make_binary_params ("bitwise_right_shift" , all_integer_dtypes () )
709+ "ctx" , make_binary_params ("bitwise_right_shift" , dh . all_int_dtypes )
661710)
662711@given (data = st .data ())
663712def test_bitwise_right_shift (ctx , data ):
@@ -678,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
678727
679728
680729@pytest .mark .parametrize (
681- "ctx" , make_binary_params ("bitwise_xor" , boolean_and_all_integer_dtypes () )
730+ "ctx" , make_binary_params ("bitwise_xor" , dh . bool_and_all_int_dtypes )
682731)
683732@given (data = st .data ())
684733def test_bitwise_xor (ctx , data ):
@@ -720,7 +769,7 @@ def test_cosh(x):
720769 unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
721770
722771
723- @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , xps . floating_dtypes () ))
772+ @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh . float_dtypes ))
724773@given (data = st .data ())
725774def test_divide (ctx , data ):
726775 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -743,7 +792,7 @@ def test_divide(ctx, data):
743792 )
744793
745794
746- @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps . scalar_dtypes () ))
795+ @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , dh . all_dtypes ))
747796@given (data = st .data ())
748797def test_equal (ctx , data ):
749798 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -795,9 +844,7 @@ def test_floor(x):
795844 unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
796845
797846
798- @pytest .mark .parametrize (
799- "ctx" , make_binary_params ("floor_divide" , xps .numeric_dtypes ())
800- )
847+ @pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .numeric_dtypes ))
801848@given (data = st .data ())
802849def test_floor_divide (ctx , data ):
803850 left = data .draw (
@@ -816,7 +863,7 @@ def test_floor_divide(ctx, data):
816863 binary_param_assert_against_refimpl (ctx , left , right , res , "//" , operator .floordiv )
817864
818865
819- @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , xps .numeric_dtypes () ))
866+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , dh .numeric_dtypes ))
820867@given (data = st .data ())
821868def test_greater (ctx , data ):
822869 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -836,9 +883,7 @@ def test_greater(ctx, data):
836883 )
837884
838885
839- @pytest .mark .parametrize (
840- "ctx" , make_binary_params ("greater_equal" , xps .numeric_dtypes ())
841- )
886+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater_equal" , dh .numeric_dtypes ))
842887@given (data = st .data ())
843888def test_greater_equal (ctx , data ):
844889 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -882,7 +927,7 @@ def test_isnan(x):
882927 unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
883928
884929
885- @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , xps .numeric_dtypes () ))
930+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .numeric_dtypes ))
886931@given (data = st .data ())
887932def test_less (ctx , data ):
888933 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -902,7 +947,7 @@ def test_less(ctx, data):
902947 )
903948
904949
905- @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , xps .numeric_dtypes () ))
950+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , dh .numeric_dtypes ))
906951@given (data = st .data ())
907952def test_less_equal (ctx , data ):
908953 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1014,7 +1059,7 @@ def test_logical_xor(x1, x2):
10141059 )
10151060
10161061
1017- @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , xps .numeric_dtypes () ))
1062+ @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
10181063@given (data = st .data ())
10191064def test_multiply (ctx , data ):
10201065 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1047,7 +1092,7 @@ def test_negative(ctx, data):
10471092 )
10481093
10491094
1050- @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps . scalar_dtypes () ))
1095+ @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , dh . all_dtypes ))
10511096@given (data = st .data ())
10521097def test_not_equal (ctx , data ):
10531098 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1079,7 +1124,7 @@ def test_positive(ctx, data):
10791124 ph .assert_array (ctx .func_name , out , x )
10801125
10811126
1082- @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , xps .numeric_dtypes () ))
1127+ @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , dh .numeric_dtypes ))
10831128@given (data = st .data ())
10841129def test_pow (ctx , data ):
10851130 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1103,7 +1148,7 @@ def test_pow(ctx, data):
11031148 )
11041149
11051150
1106- @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes () ))
1151+ @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .numeric_dtypes ))
11071152@given (data = st .data ())
11081153def test_remainder (ctx , data ):
11091154 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1174,7 +1219,7 @@ def test_sqrt(x):
11741219 )
11751220
11761221
1177- @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , xps .numeric_dtypes () ))
1222+ @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , dh .numeric_dtypes ))
11781223@given (data = st .data ())
11791224def test_subtract (ctx , data ):
11801225 left = data .draw (ctx .left_strat , label = ctx .left_sym )
0 commit comments