88from collections .abc import Sequence
99from typing import TYPE_CHECKING , Any , NamedTuple , cast
1010
11- from ._helpers import _check_device , array_namespace
11+ from ._helpers import _device_ctx , array_namespace
1212from ._helpers import device as _get_device
1313from ._helpers import is_cupy_namespace
1414from ._typing import Array , Device , DType , Namespace
@@ -33,8 +33,8 @@ def arange(
3333 device : Device | None = None ,
3434 ** kwargs : object ,
3535) -> Array :
36- _check_device (xp , device )
37- return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
36+ with _device_ctx (xp , device ):
37+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
3838
3939
4040def empty (
@@ -45,8 +45,8 @@ def empty(
4545 device : Device | None = None ,
4646 ** kwargs : object ,
4747) -> Array :
48- _check_device (xp , device )
49- return xp .empty (shape , dtype = dtype , ** kwargs )
48+ with _device_ctx (xp , device ):
49+ return xp .empty (shape , dtype = dtype , ** kwargs )
5050
5151
5252def empty_like (
@@ -58,8 +58,8 @@ def empty_like(
5858 device : Device | None = None ,
5959 ** kwargs : object ,
6060) -> Array :
61- _check_device (xp , device )
62- return xp .empty_like (x , dtype = dtype , ** kwargs )
61+ with _device_ctx (xp , device , like = x ):
62+ return xp .empty_like (x , dtype = dtype , ** kwargs )
6363
6464
6565def eye (
@@ -73,8 +73,8 @@ def eye(
7373 device : Device | None = None ,
7474 ** kwargs : object ,
7575) -> Array :
76- _check_device (xp , device )
77- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
76+ with _device_ctx (xp , device ):
77+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
7878
7979
8080def full (
@@ -86,8 +86,8 @@ def full(
8686 device : Device | None = None ,
8787 ** kwargs : object ,
8888) -> Array :
89- _check_device (xp , device )
90- return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
89+ with _device_ctx (xp , device ):
90+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
9191
9292
9393def full_like (
@@ -100,8 +100,8 @@ def full_like(
100100 device : Device | None = None ,
101101 ** kwargs : object ,
102102) -> Array :
103- _check_device (xp , device )
104- return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
103+ with _device_ctx (xp , device , like = x ):
104+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
105105
106106
107107def linspace (
@@ -116,8 +116,8 @@ def linspace(
116116 endpoint : bool = True ,
117117 ** kwargs : object ,
118118) -> Array :
119- _check_device (xp , device )
120- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
119+ with _device_ctx (xp , device ):
120+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
121121
122122
123123def ones (
@@ -128,8 +128,8 @@ def ones(
128128 device : Device | None = None ,
129129 ** kwargs : object ,
130130) -> Array :
131- _check_device (xp , device )
132- return xp .ones (shape , dtype = dtype , ** kwargs )
131+ with _device_ctx (xp , device ):
132+ return xp .ones (shape , dtype = dtype , ** kwargs )
133133
134134
135135def ones_like (
@@ -141,8 +141,8 @@ def ones_like(
141141 device : Device | None = None ,
142142 ** kwargs : object ,
143143) -> Array :
144- _check_device (xp , device )
145- return xp .ones_like (x , dtype = dtype , ** kwargs )
144+ with _device_ctx (xp , device , like = x ):
145+ return xp .ones_like (x , dtype = dtype , ** kwargs )
146146
147147
148148def zeros (
@@ -153,8 +153,8 @@ def zeros(
153153 device : Device | None = None ,
154154 ** kwargs : object ,
155155) -> Array :
156- _check_device (xp , device )
157- return xp .zeros (shape , dtype = dtype , ** kwargs )
156+ with _device_ctx (xp , device ):
157+ return xp .zeros (shape , dtype = dtype , ** kwargs )
158158
159159
160160def zeros_like (
@@ -166,8 +166,8 @@ def zeros_like(
166166 device : Device | None = None ,
167167 ** kwargs : object ,
168168) -> Array :
169- _check_device (xp , device )
170- return xp .zeros_like (x , dtype = dtype , ** kwargs )
169+ with _device_ctx (xp , device , like = x ):
170+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
171171
172172
173173# np.unique() is split into four functions in the array API:
0 commit comments