22
33from functools import reduce as _reduce, wraps as _wraps
44from builtins import all as _builtin_all, any as _builtin_any
5- from typing import List, Optional, Sequence, Tuple, Union
5+ from typing import Any, List, Optional, Sequence, Tuple, Union
66
77import torch
88
99from .._internal import get_xp
1010from ..common import _aliases
11+ from ..common._typing import NestedSequence, SupportsBufferProtocol
1112from ._info import __array_namespace_info__
1213from ._typing import Array, Device, DType
1314
@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208remainder = _two_arg(torch.remainder)
208209subtract = _two_arg(torch.subtract)
209210
211+
def asarray(
    obj: (
        Array
        | bool | int | float | complex
        | NestedSequence[bool | int | float | complex]
        | SupportsBufferProtocol
    ),
    /,
    *,
    dtype: DType | None = None,
    device: Device | None = None,
    copy: bool | None = None,
    **kwargs: Any,
) -> Array:
    """Convert *obj* to a torch tensor, preserving the input's device.

    Thin wrapper over ``torch.asarray`` that works around
    https://github.com/pytorch/pytorch/issues/150199 — upstream
    ``torch.asarray`` does not propagate the input tensor's device to
    the output when ``device`` is left unspecified.
    """
    target = device
    if target is None and isinstance(obj, torch.Tensor):
        # No explicit device requested: keep the tensor where it already is.
        target = obj.device
    return torch.asarray(obj, dtype=dtype, device=target, copy=copy, **kwargs)
231+
232+
210233# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234# of 'axis'.
212235
@@ -282,7 +305,6 @@ def prod(x: Array,
282305 dtype: Optional[DType] = None,
283306 keepdims: bool = False,
284307 **kwargs) -> Array:
285- x = torch.asarray(x)
286308 ndim = x.ndim
287309
288310 # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
318340 dtype: Optional[DType] = None,
319341 keepdims: bool = False,
320342 **kwargs) -> Array:
321- x = torch.asarray(x)
322343 ndim = x.ndim
323344
324345 # https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
348369 axis: Optional[Union[int, Tuple[int, ...]]] = None,
349370 keepdims: bool = False,
350371 **kwargs) -> Array:
351- x = torch.asarray(x)
352372 ndim = x.ndim
353373 if axis == ():
354374 return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
373393 axis: Optional[Union[int, Tuple[int, ...]]] = None,
374394 keepdims: bool = False,
375395 **kwargs) -> Array:
376- x = torch.asarray(x)
377396 ndim = x.ndim
378397 if axis == ():
379398 return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
816835 return out
817836
818837
819- __all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
838+ __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
820839 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
821840 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
822841 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
0 commit comments