11from __future__ import annotations
22
3- from ...common import _aliases
3+ from typing import Callable
4+
5+ from ...common import _aliases , array_namespace
46
57from ..._internal import get_xp
68
2931)
3032
3133from typing import TYPE_CHECKING
34+
3235if TYPE_CHECKING :
3336 from typing import Optional , Union
3437
35- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38+ from ...common ._typing import (
39+ Device ,
40+ Dtype ,
41+ Array ,
42+ NestedSequence ,
43+ SupportsBufferProtocol ,
44+ )
3645
3746import dask .array as da
3847
3948isdtype = get_xp (np )(_aliases .isdtype )
4049unstack = get_xp (da )(_aliases .unstack )
4150
51+
4252# da.astype doesn't respect copy=True
4353def astype (
4454 x : Array ,
4555 dtype : Dtype ,
4656 / ,
4757 * ,
4858 copy : bool = True ,
49- device : Optional [Device ] = None
59+ device : Optional [Device ] = None ,
5060) -> Array :
5161 """
5262 Array API compatibility wrapper for astype().
@@ -61,8 +71,10 @@ def astype(
6171 x = x .astype (dtype )
6272 return x .copy () if copy else x
6373
74+
6475# Common aliases
6576
77+
6678# This arange func is modified from the common one to
6779# not pass stop/step as keyword arguments, which will cause
6880# an error with dask
@@ -189,6 +201,7 @@ def asarray(
189201 concatenate as concat ,
190202)
191203
204+
192205# dask.array.clip does not work unless all three arguments are provided.
193206# Furthermore, the masking workaround in common._aliases.clip cannot work with
194207# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +218,10 @@ def clip(
205218 See the corresponding documentation in the array library and/or the array API
206219 specification for more details.
207220 """
221+
208222 def _isscalar (a ):
209223 return isinstance (a , (int , float , type (None )))
224+
210225 min_shape = () if _isscalar (min ) else min .shape
211226 max_shape = () if _isscalar (max ) else max .shape
212227
@@ -228,12 +243,99 @@ def _isscalar(a):
228243
229244 return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
230245
231- # exclude these from all since dask.array has no sorting functions
232- _da_unsupported = ['sort' , 'argsort' ]
233246
234- _common_aliases = [alias for alias in _aliases .__all__ if alias not in _da_unsupported ]
247+ def _ensure_single_chunk (x : Array , axis : int ) -> tuple [Array , Callable [[Array ], Array ]]:
248+ """
249+ Make sure that Array is not broken into multiple chunks along axis.
250+
251+ Returns
252+ -------
253+ x : Array
254+ The input Array with a single chunk along axis.
255+ restore : Callable[Array, Array]
256+ function to apply to the output to rechunk it back into reasonable chunks
257+ """
258+ if axis < 0 :
259+ axis += x .ndim
260+ if x .numblocks [axis ] < 2 :
261+ return x , lambda x : x
262+
263+ # Break chunks on other axes in an attempt to keep chunk size low
264+ x = x .rechunk ({i : - 1 if i == axis else "auto" for i in range (x .ndim )})
265+
266+ # Rather than reconstructing the original chunks, which can be a
267+ # very expensive affair, just break down oversized chunks without
268+ # incurring in any transfers over the network.
269+ # This has the downside of a risk of overchunking if the array is
270+ # then used in operations against other arrays that match the
271+ # original chunking pattern.
272+ return x , lambda x : x .rechunk ()
273+
274+
275+ def sort (
276+ x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
277+ ) -> Array :
278+ """
279+ Array API compatibility layer around the lack of sort() in Dask.
280+
281+ Warnings
282+ --------
283+ This function temporarily rechunks the array along `axis` to a single chunk.
284+ This can be extremely inefficient and can lead to out-of-memory errors.
285+
286+ See the corresponding documentation in the array library and/or the array API
287+ specification for more details.
288+ """
289+ x , restore = _ensure_single_chunk (x , axis )
290+
291+ meta_xp = array_namespace (x ._meta )
292+ x = da .map_blocks (
293+ meta_xp .sort ,
294+ x ,
295+ axis = axis ,
296+ meta = x ._meta ,
297+ dtype = x .dtype ,
298+ descending = descending ,
299+ stable = stable ,
300+ )
301+
302+ return restore (x )
235303
236- __all__ = _common_aliases + ['__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
304+
305+ def argsort (
306+ x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
307+ ) -> Array :
308+ """
309+ Array API compatibility layer around the lack of argsort() in Dask.
310+
311+ See the corresponding documentation in the array library and/or the array API
312+ specification for more details.
313+
314+ Warnings
315+ --------
316+ This function temporarily rechunks the array along `axis` into a single chunk.
317+ This can be extremely inefficient and can lead to out-of-memory errors.
318+ """
319+ x , restore = _ensure_single_chunk (x , axis )
320+
321+ meta_xp = array_namespace (x ._meta )
322+ dtype = meta_xp .argsort (x ._meta ).dtype
323+ meta = meta_xp .astype (x ._meta , dtype )
324+ x = da .map_blocks (
325+ meta_xp .argsort ,
326+ x ,
327+ axis = axis ,
328+ meta = meta ,
329+ dtype = dtype ,
330+ descending = descending ,
331+ stable = stable ,
332+ )
333+
334+ return restore (x )
335+
336+
337+ __all__ = _aliases .__all__ + [
338+ '__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
237339 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
238340 'atanh' , 'bitwise_left_shift' , 'bitwise_invert' ,
239341 'bitwise_right_shift' , 'concat' , 'pow' , 'iinfo' , 'finfo' , 'can_cast' ,
@@ -242,4 +344,4 @@ def _isscalar(a):
242344 'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
243345 'can_cast' , 'result_type' ]
244346
245- _all_ignore = ["get_xp" , "da" , "np" ]
347+ _all_ignore = ["Callable" , "array_namespace" , " get_xp" , "da" , "np" ]
0 commit comments