@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3030 return xps .boolean_dtypes () | all_integer_dtypes ()
3131
3232
33+ class OnewayBroadcastableShapes (NamedTuple ):
34+ input_shape : Shape
35+ result_shape : Shape
36+
37+
38+ @st .composite
39+ def oneway_broadcastable_shapes (draw ) -> st .SearchStrategy [OnewayBroadcastableShapes ]:
40+ """Return a strategy for input shapes that broadcast to result shapes."""
41+ result_shape = draw (hh .shapes (min_side = 1 ))
42+ input_shape = draw (
43+ xps .broadcastable_shapes (
44+ result_shape ,
45+ # Override defaults so bad shapes are less likely to be generated.
46+ max_side = None if result_shape == () else max (result_shape ),
47+ max_dims = len (result_shape ),
48+ ).filter (lambda s : sh .broadcast_shapes (result_shape , s ) == result_shape )
49+ )
50+ return OnewayBroadcastableShapes (input_shape , result_shape )
51+
52+
3353def mock_int_dtype (n : int , dtype : DataType ) -> int :
3454 """Returns equivalent of `n` that mocks `dtype` behaviour."""
3555 nbits = dh .dtype_nbits [dtype ]
@@ -326,9 +346,15 @@ def make_param(
326346 )
327347 else :
328348 if func_type is FuncType .IOP :
329- shared_shapes = st .shared (hh .shapes (** shapes_kw ))
330- left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
331- right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
349+ shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
350+ left_strat = xps .arrays (
351+ dtype = shared_dtypes ,
352+ shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
353+ )
354+ right_strat = xps .arrays (
355+ dtype = shared_dtypes ,
356+ shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
357+ )
332358 else :
333359 mutual_shapes = st .shared (
334360 hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
0 commit comments