File tree Expand file tree Collapse file tree 3 files changed +27
-6
lines changed Expand file tree Collapse file tree 3 files changed +27
-6
lines changed Original file line number Diff line number Diff line change @@ -595,11 +595,29 @@ def your_function(x, y):
595595# backwards compatibility alias
596596get_namespace = array_namespace
597597
598- def _check_device (xp , device ):
599- if xp == sys .modules .get ('numpy' ):
598+
599+ def _check_device (bare_xp , device ):
600+ """
601+ Validate dummy device on device-less array backends.
602+
603+ Notes
604+ -----
605+ This function is also invoked by CuPy, which does have multiple devices
606+ if there are multiple GPUs available.
607+ However, CuPy multi-device support is currently impossible
608+ without using the global device or a context manager:
609+
610+ https://github.com/data-apis/array-api-compat/pull/293
611+ """
612+ if bare_xp == sys .modules .get ('numpy' ):
600613 if device not in ["cpu" , None ]:
601614 raise ValueError (f"Unsupported device for NumPy: { device !r} " )
602615
616+ elif bare_xp is sys .modules .get ('dask.array' ):
617+ if device not in ("cpu" , _DASK_DEVICE ):
618+ raise ValueError (f"Unsupported device for Dask: { device !r} " )
619+
620+
603621# Placeholder object to represent the dask device
604622# when the array backend is not the CPU.
605623# (since it is not easy to tell which device a dask array is on)
Original file line number Diff line number Diff line change 2525)
2626import dask .array as da
2727
28- from ...common import _aliases , array_namespace
28+ from ...common import _aliases , _helpers , array_namespace
2929from ...common ._typing import (
3030 Array ,
3131 Device ,
@@ -56,6 +56,7 @@ def astype(
5656 specification for more details.
5757 """
5858 # TODO: respect device keyword?
59+ _helpers ._check_device (da , device )
5960
6061 if not copy and dtype == x .dtype :
6162 return x
@@ -86,6 +87,7 @@ def arange(
8687 specification for more details.
8788 """
8889 # TODO: respect device keyword?
90+ _helpers ._check_device (da , device )
8991
9092 args = [start ]
9193 if stop is not None :
@@ -155,6 +157,7 @@ def asarray(
155157 specification for more details.
156158 """
157159 # TODO: respect device keyword?
160+ _helpers ._check_device (da , device )
158161
159162 if isinstance (obj , da .Array ):
160163 if dtype is not None and dtype != obj .dtype :
Original file line number Diff line number Diff line change 33from typing import Optional , Union
44
55from .._internal import get_xp
6- from ..common import _aliases
6+ from ..common import _aliases , _helpers
77from ..common ._typing import NestedSequence , SupportsBufferProtocol
88from ._info import __array_namespace_info__
99from ._typing import Array , Device , DType
@@ -95,8 +95,7 @@ def asarray(
9595 See the corresponding documentation in the array library and/or the array API
9696 specification for more details.
9797 """
98- if device not in ["cpu" , None ]:
99- raise ValueError (f"Unsupported device for NumPy: { device !r} " )
98+ _helpers ._check_device (np , device )
10099
101100 if hasattr (np , '_CopyMode' ):
102101 if copy is None :
@@ -122,6 +121,7 @@ def astype(
122121 copy : bool = True ,
123122 device : Optional [Device ] = None ,
124123) -> Array :
124+ _helpers ._check_device (np , device )
125125 return x .astype (dtype = dtype , copy = copy )
126126
127127
You can’t perform that action at this time.
0 commit comments