1616
1717NamespaceT_co = TypeVar ("NamespaceT_co" , covariant = True , default = ModuleType )
1818DTypeT_co = TypeVar ("DTypeT_co" , covariant = True )
19+ DeviceT_co = TypeVar ("DeviceT_co" , covariant = True , default = object )
1920
2021
2122class HasArrayNamespace (Protocol [NamespaceT_co ]):
@@ -74,11 +75,11 @@ def dtype(self, /) -> DTypeT_co:
7475 ...
7576
7677
77- class HasDevice (Protocol ):
78+ class HasDevice (Protocol [ DeviceT_co ] ):
7879 """Protocol for array classes that have a device attribute."""
7980
8081 @property
81- def device (self ) -> object : # TODO: more specific type
82+ def device (self ) -> DeviceT_co :
8283 """Hardware device the array data resides on."""
8384 ...
8485
@@ -191,7 +192,7 @@ def T(self) -> Self: # noqa: N802
191192class Array (
192193 # ------ Attributes -------
193194 HasDType [DTypeT_co ],
194- HasDevice ,
195+ HasDevice [ DeviceT_co ] ,
195196 HasMatrixTranspose ,
196197 HasNDim ,
197198 HasShape ,
@@ -200,14 +201,18 @@ class Array(
200201 # ------- Methods ---------
201202 HasArrayNamespace [NamespaceT_co ],
202203 # -------------------------
203- Protocol [DTypeT_co , NamespaceT_co ],
204+ Protocol [DTypeT_co , DeviceT_co , NamespaceT_co ],
204205):
205206 """Array API specification for array object attributes and methods.
206207
207- The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
208- NamespaceT]`` where:
208+ The type is: ``Array[+DTypeT, +DeviceT = object, + NamespaceT = ModuleType] =
209+ Array[DTypeT, DeviceT, NamespaceT]`` where:
209210
210211 - `DTypeT` is the data type of the array elements.
212+ - `DeviceT` is the type of the device attribute. It defaults to `object` to
213+ enable skipping device specification. Array objects supporting device
214+ management can specify a more specific type if they use types (as opposed
215+ to object instances) to distinguish between different devices.
211216 - `NamespaceT` is the type of the array namespace. It defaults to
212217 `ModuleType`, which is the most common form of array namespace (e.g.,
213218 `numpy`, `cupy`, etc.). However, it can be any type, e.g. a
0 commit comments