@@ -65,7 +65,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6565
6666
6767@wraps (xps .arrays )
68- def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
68+ def arrays_no_scalars (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
6969 """xps.arrays() without the crazy large numbers."""
7070 if isinstance (dtype , SearchStrategy ):
7171 return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
@@ -78,6 +78,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7878 return xps .arrays (dtype , * args , elements = elements , ** kwargs )
7979
8080
81+ def _f (a , flag ):
82+ return a [()] if a .ndim == 0 and flag else a
83+
84+
85+ @wraps (xps .arrays )
86+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
87+ """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
88+
89+ Is only relevant for numpy: on all other libraries, array[()] is no-op.
90+ """
91+ return builds (_f , arrays_no_scalars (dtype , * args , elements = elements , ** kwargs ), booleans ())
92+
93+
8194_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
8295_sorted_dtypes = [d for category in _dtype_categories for d in category ]
8396
0 commit comments