@@ -598,6 +598,44 @@ def your_function(x, y):
598598get_namespace = array_namespace
599599
600600
601+ def _device_ctx (
602+ bare_xp : Namespace , device : Device , like : Array | None = None
603+ ) -> Generator [None ]:
604+ """Context manager which changes the current device in CuPy.
605+
606+ Used internally by array creation functions in common._aliases.
607+ """
608+ if device is None :
609+ if like is None :
610+ return contextlib .nullcontext ()
611+ device = _device (like )
612+
613+ if bare_xp is sys .modules .get ('numpy' ):
614+ if device != "cpu" :
615+ raise ValueError (f"Unsupported device for NumPy: { device !r} " )
616+ return contextlib .nullcontext ()
617+
618+ if bare_xp is sys .modules .get ('dask.array' ):
619+ if device not in ("cpu" , _DASK_DEVICE ):
620+ raise ValueError (f"Unsupported device for Dask: { device !r} " )
621+ return contextlib .nullcontext ()
622+
623+ if bare_xp is sys .modules .get ('cupy' ):
624+ if not isinstance (device , bare_xp .cuda .Device ):
625+ raise TypeError (f"device is not a cupy.cuda.Device: { device !r} " )
626+ return device
627+
628+ # PyTorch doesn't have a "current device" context manager and you
629+ # can't use array creation functions from common._aliases.
630+ raise AssertionError ("unreachable" ) # pragma: nocover
631+
632+
633+ def _check_device (bare_xp : Namespace , device : Device ) -> None :
634+ """Validate dummy device on device-less array backends."""
635+ with _device_ctx (bare_xp , device ):
636+ pass
637+
638+
601639# Placeholder object to represent the dask device
602640# when the array backend is not the CPU.
603641# (since it is not easy to tell which device a dask array is on)
@@ -607,7 +645,6 @@ def __repr__(self):
607645
608646_DASK_DEVICE = _dask_device ()
609647
610-
611648# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
612649# or cupy.ndarray. They are not included in array objects of this library
613650# because this library just reuses the respective ndarray classes without
@@ -799,43 +836,6 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
799836 return x .to_device (device , stream = stream )
800837
801838
802- def _device_ctx (
803- bare_xp : Namespace , device : Device , like : Array | None = None
804- ) -> Generator [None ]:
805- """Context manager which changes the current device in CuPy.
806-
807- Used internally by array creation functions in common._aliases.
808- """
809- if device is None :
810- if like is None :
811- return contextlib .nullcontext ()
812- device = _device (like )
813-
814- if bare_xp is sys .modules .get ('numpy' ):
815- if device != "cpu" :
816- raise ValueError (f"Unsupported device for NumPy: { device !r} " )
817- return contextlib .nullcontext ()
818-
819- if bare_xp is sys .modules .get ('dask.array' ):
820- if device not in ("cpu" , _DASK_DEVICE ):
821- raise ValueError (f"Unsupported device for Dask: { device !r} " )
822- return contextlib .nullcontext ()
823-
824- if bare_xp is sys .modules .get ('cupy' ):
825- if not isinstance (device , bare_xp .cuda .Device ):
826- raise TypeError (f"device is not a cupy.cuda.Device: { device !r} " )
827- return device
828-
829- # PyTorch doesn't have a "current device" context manager and you
830- # can't use array creation functions from common._aliases.
831- raise AssertionError ("unreachable" ) # pragma: nocover
832-
833-
834- def _check_device (bare_xp : Namespace , device : Device ) -> None :
835- with _device_ctx (bare_xp , device ):
836- pass
837-
838-
839839def size (x : Array ) -> int | None :
840840 """
841841 Return the total number of elements of x.
0 commit comments