77import inspect
88from typing import NamedTuple , Optional , Sequence , Tuple , Union
99
10- from ._helpers import array_namespace , _check_device , device , is_cupy_namespace
1110from ._typing import Array , Device , DType , Namespace
11+ from ._helpers import (
12+ array_namespace ,
13+ _check_device ,
14+ device as _get_device ,
15+ is_cupy_namespace as _is_cupy_namespace
16+ )
17+
1218
1319# These functions are modified from the NumPy versions.
1420
@@ -298,7 +304,7 @@ def cumulative_sum(
298304 initial_shape = list (x .shape )
299305 initial_shape [axis ] = 1
300306 res = xp .concatenate (
301- [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
307+ [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
302308 axis = axis ,
303309 )
304310 return res
@@ -328,7 +334,7 @@ def cumulative_prod(
328334 initial_shape = list (x .shape )
329335 initial_shape [axis ] = 1
330336 res = xp .concatenate (
331- [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
337+ [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
332338 axis = axis ,
333339 )
334340 return res
@@ -381,7 +387,7 @@ def _isscalar(a):
381387 if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
382388 max = None
383389
384- dev = device (x )
390+ dev = _get_device (x )
385391 if out is None :
386392 out = wrapped_xp .empty (result_shape , dtype = x .dtype , device = dev )
387393 out [()] = x
@@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
599605 out = xp .sign (x , ** kwargs )
600606 # CuPy sign() does not propagate nans. See
601607 # https://github.com/data-apis/array-api-compat/issues/136
602- if is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
608+ if _is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
603609 out [xp .isnan (x )] = xp .nan
604610 return out [()]
605611
@@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
611617 'reshape' , 'argsort' , 'sort' , 'nonzero' , 'ceil' , 'floor' , 'trunc' ,
612618 'matmul' , 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ,
613619 'unstack' , 'sign' ]
620+
621+ _all_ignore = ['inspect' , 'array_namespace' , 'NamedTuple' ]
0 commit comments