3131 "Info" ,
3232]
3333
34- from dataclasses import dataclass
3534from typing import (
3635 Any ,
3736 List ,
4544 Protocol ,
4645)
4746from enum import Enum
47+ from .data_types import DType
4848
49- array = TypeVar ("array" , bound = "array_ " )
49+ array = TypeVar ("array" , bound = "Array " )
5050device = TypeVar ("device" )
51- dtype = TypeVar ("dtype" )
51+ dtype = TypeVar ("dtype" , bound = DType )
52+ device_ = TypeVar ("device_" ) # only used in this file
53+ dtype_ = TypeVar ("dtype_" , bound = DType ) # only used in this file
5254SupportsDLPack = TypeVar ("SupportsDLPack" )
5355SupportsBufferProtocol = TypeVar ("SupportsBufferProtocol" )
5456PyCapsule = TypeVar ("PyCapsule" )
@@ -88,7 +90,7 @@ def __len__(self, /) -> int:
8890 ...
8991
9092
91- class Info (Protocol ):
93+ class Info (Protocol [ device ] ):
9294 """Namespace returned by `__array_namespace_info__`."""
9395
9496 def capabilities (self ) -> Capabilities :
@@ -147,12 +149,12 @@ def dtypes(
147149)
148150
149151
150- class _array (Protocol [array , dtype , device ]):
152+ class Array (Protocol [array , dtype_ , device_ , PyCapsule ]): # type: ignore
151153 def __init__ (self : array ) -> None :
152154 """Initialize the attributes for the array object class."""
153155
154156 @property
155- def dtype (self : array ) -> dtype :
157+ def dtype (self : array ) -> dtype_ :
156158 """
157159 Data type of the array elements.
158160
@@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
163165 """
164166
165167 @property
166- def device (self : array ) -> device :
168+ def device (self : array ) -> device_ :
167169 """
168170 Hardware device the array data resides on.
169171
@@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
625627 ONE_API = 14
626628 """
627629
628- def __eq__ (self : array , other : Union [int , float , bool , array ], / ) -> array :
630+ def __eq__ (self : array , other : Union [int , float , bool , array ], / ) -> array : # type: ignore
629631 r"""
630632 Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
631633
@@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
10721074 Added complex data type support.
10731075 """
10741076
1075- def __ne__ (self : array , other : Union [int , float , bool , array ], / ) -> array :
1077+ def __ne__ (self : array , other : Union [int , float , bool , array ], / ) -> array : # type: ignore
10761078 """
10771079 Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
10781080
@@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
13421344 """
13431345
13441346 def to_device (
1345- self : array , device : device , / , * , stream : Optional [Union [int , Any ]] = None
1347+ self : array , device : device_ , / , * , stream : Optional [Union [int , Any ]] = None
13461348 ) -> array :
13471349 """
13481350 Copy the array from the device on which it currently resides to the specified ``device``.
0 commit comments