@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030 return xps .boolean_dtypes () | all_integer_dtypes ()
3131
3232
33+ class OnewayPromotableDtypes (NamedTuple ):
34+ input_dtype : DataType
35+ result_dtype : DataType
36+
37+
38+ @st .composite
39+ def oneway_promotable_dtypes (
40+ draw , dtypes : List [DataType ]
41+ ) -> st .SearchStrategy [OnewayPromotableDtypes ]:
42+ """Return a strategy for input dtypes that promote to result dtypes."""
43+ d1 , d2 = draw (hh .mutually_promotable_dtypes (dtypes = dtypes ))
44+ result_dtype = dh .result_type (d1 , d2 )
45+ if d1 == result_dtype :
46+ return OnewayPromotableDtypes (d2 , d1 )
47+ elif d2 == result_dtype :
48+ return OnewayPromotableDtypes (d1 , d2 )
49+ else :
50+ reject ()
51+
52+
3353class OnewayBroadcastableShapes (NamedTuple ):
3454 input_shape : Shape
3555 result_shape : Shape
@@ -326,8 +346,14 @@ def __repr__(self):
326346
327347
328348def make_binary_params (
329- elwise_func_name : str , dtypes_strat : st . SearchStrategy [DataType ]
349+ elwise_func_name : str , dtypes : List [DataType ]
330350) -> List [Param [BinaryParamContext ]]:
351+ if hh .FILTER_UNDEFINED_DTYPES :
352+ dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
353+ shared_oneway_dtypes = st .shared (oneway_promotable_dtypes (dtypes ))
354+ left_dtypes = shared_oneway_dtypes .map (lambda D : D .result_dtype )
355+ right_dtypes = shared_oneway_dtypes .map (lambda D : D .input_dtype )
356+
331357 def make_param (
332358 func_name : str , func_type : FuncType , right_is_scalar : bool
333359 ) -> Param [BinaryParamContext ]:
@@ -338,32 +364,29 @@ def make_param(
338364 left_sym = "x1"
339365 right_sym = "x2"
340366
341- shared_dtypes = st .shared (dtypes_strat )
342367 if right_is_scalar :
343- left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes (** shapes_kw ))
344- right_strat = shared_dtypes .flatmap (
345- lambda d : xps .from_dtype (d , ** finite_kw )
346- )
368+ left_strat = xps .arrays (dtype = left_dtypes , shape = hh .shapes (** shapes_kw ))
369+ right_strat = right_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
347370 else :
348371 if func_type is FuncType .IOP :
349372 shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
350373 left_strat = xps .arrays (
351- dtype = shared_dtypes ,
374+ dtype = left_dtypes ,
352375 shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
353376 )
354377 right_strat = xps .arrays (
355- dtype = shared_dtypes ,
378+ dtype = right_dtypes ,
356379 shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
357380 )
358381 else :
359382 mutual_shapes = st .shared (
360383 hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
361384 )
362385 left_strat = xps .arrays (
363- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
386+ dtype = left_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
364387 )
365388 right_strat = xps .arrays (
366- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
389+ dtype = right_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
367390 )
368391
369392 if func_type is FuncType .FUNC :
@@ -540,7 +563,7 @@ def test_acosh(x):
540563 )
541564
542565
543- @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes () ))
566+ @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , dh .numeric_dtypes ))
544567@given (data = st .data ())
545568def test_add (ctx , data ):
546569 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -605,7 +628,7 @@ def test_atanh(x):
605628
606629
607630@pytest .mark .parametrize (
608- "ctx" , make_binary_params ("bitwise_and" , boolean_and_all_integer_dtypes () )
631+ "ctx" , make_binary_params ("bitwise_and" , dh . bool_and_all_int_dtypes )
609632)
610633@given (data = st .data ())
611634def test_bitwise_and (ctx , data ):
@@ -624,7 +647,7 @@ def test_bitwise_and(ctx, data):
624647
625648
626649@pytest .mark .parametrize (
627- "ctx" , make_binary_params ("bitwise_left_shift" , all_integer_dtypes () )
650+ "ctx" , make_binary_params ("bitwise_left_shift" , dh . all_int_dtypes )
628651)
629652@given (data = st .data ())
630653def test_bitwise_left_shift (ctx , data ):
@@ -664,7 +687,7 @@ def test_bitwise_invert(ctx, data):
664687
665688
666689@pytest .mark .parametrize (
667- "ctx" , make_binary_params ("bitwise_or" , boolean_and_all_integer_dtypes () )
690+ "ctx" , make_binary_params ("bitwise_or" , dh . bool_and_all_int_dtypes )
668691)
669692@given (data = st .data ())
670693def test_bitwise_or (ctx , data ):
@@ -683,7 +706,7 @@ def test_bitwise_or(ctx, data):
683706
684707
685708@pytest .mark .parametrize (
686- "ctx" , make_binary_params ("bitwise_right_shift" , all_integer_dtypes () )
709+ "ctx" , make_binary_params ("bitwise_right_shift" , dh . all_int_dtypes )
687710)
688711@given (data = st .data ())
689712def test_bitwise_right_shift (ctx , data ):
@@ -704,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
704727
705728
706729@pytest .mark .parametrize (
707- "ctx" , make_binary_params ("bitwise_xor" , boolean_and_all_integer_dtypes () )
730+ "ctx" , make_binary_params ("bitwise_xor" , dh . bool_and_all_int_dtypes )
708731)
709732@given (data = st .data ())
710733def test_bitwise_xor (ctx , data ):
@@ -746,7 +769,7 @@ def test_cosh(x):
746769 unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
747770
748771
749- @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , xps . floating_dtypes () ))
772+ @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh . float_dtypes ))
750773@given (data = st .data ())
751774def test_divide (ctx , data ):
752775 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -769,7 +792,7 @@ def test_divide(ctx, data):
769792 )
770793
771794
772- @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps . scalar_dtypes () ))
795+ @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , dh . all_dtypes ))
773796@given (data = st .data ())
774797def test_equal (ctx , data ):
775798 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -821,9 +844,7 @@ def test_floor(x):
821844 unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
822845
823846
824- @pytest .mark .parametrize (
825- "ctx" , make_binary_params ("floor_divide" , xps .numeric_dtypes ())
826- )
847+ @pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .numeric_dtypes ))
827848@given (data = st .data ())
828849def test_floor_divide (ctx , data ):
829850 left = data .draw (
@@ -842,7 +863,7 @@ def test_floor_divide(ctx, data):
842863 binary_param_assert_against_refimpl (ctx , left , right , res , "//" , operator .floordiv )
843864
844865
845- @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , xps .numeric_dtypes () ))
866+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , dh .numeric_dtypes ))
846867@given (data = st .data ())
847868def test_greater (ctx , data ):
848869 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -862,9 +883,7 @@ def test_greater(ctx, data):
862883 )
863884
864885
865- @pytest .mark .parametrize (
866- "ctx" , make_binary_params ("greater_equal" , xps .numeric_dtypes ())
867- )
886+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater_equal" , dh .numeric_dtypes ))
868887@given (data = st .data ())
869888def test_greater_equal (ctx , data ):
870889 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -908,7 +927,7 @@ def test_isnan(x):
908927 unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
909928
910929
911- @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , xps .numeric_dtypes () ))
930+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .numeric_dtypes ))
912931@given (data = st .data ())
913932def test_less (ctx , data ):
914933 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -928,7 +947,7 @@ def test_less(ctx, data):
928947 )
929948
930949
931- @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , xps .numeric_dtypes () ))
950+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , dh .numeric_dtypes ))
932951@given (data = st .data ())
933952def test_less_equal (ctx , data ):
934953 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1040,7 +1059,7 @@ def test_logical_xor(x1, x2):
10401059 )
10411060
10421061
1043- @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , xps .numeric_dtypes () ))
1062+ @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
10441063@given (data = st .data ())
10451064def test_multiply (ctx , data ):
10461065 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1073,7 +1092,7 @@ def test_negative(ctx, data):
10731092 )
10741093
10751094
1076- @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps . scalar_dtypes () ))
1095+ @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , dh . all_dtypes ))
10771096@given (data = st .data ())
10781097def test_not_equal (ctx , data ):
10791098 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1105,7 +1124,7 @@ def test_positive(ctx, data):
11051124 ph .assert_array (ctx .func_name , out , x )
11061125
11071126
1108- @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , xps .numeric_dtypes () ))
1127+ @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , dh .numeric_dtypes ))
11091128@given (data = st .data ())
11101129def test_pow (ctx , data ):
11111130 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1129,7 +1148,7 @@ def test_pow(ctx, data):
11291148 )
11301149
11311150
1132- @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes () ))
1151+ @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .numeric_dtypes ))
11331152@given (data = st .data ())
11341153def test_remainder (ctx , data ):
11351154 left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1200,7 +1219,7 @@ def test_sqrt(x):
12001219 )
12011220
12021221
1203- @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , xps .numeric_dtypes () ))
1222+ @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , dh .numeric_dtypes ))
12041223@given (data = st .data ())
12051224def test_subtract (ctx , data ):
12061225 left = data .draw (ctx .left_strat , label = ctx .left_sym )
0 commit comments