@@ -34,7 +34,7 @@ class __array_namespace_info__:
 
     Examples
     --------
-    >>> info = np.__array_namespace_info__()
+    >>> info = xp.__array_namespace_info__()
     >>> info.default_dtypes()
     {'real floating': numpy.float64,
      'complex floating': numpy.complex128,
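
The examples in this patch switch from ``np`` to ``xp`` to emphasise that the inspection object is reached through a generic array API namespace. A minimal sketch of what ``xp`` stands for, assuming array-api-compat is installed and that ``array_namespace`` is used as the entry point:

# Hedged sketch: obtain the wrapped torch namespace and its inspection object.
import torch
from array_api_compat import array_namespace

x = torch.asarray([1.0, 2.0])
xp = array_namespace(x)                  # torch-compatible namespace
info = xp.__array_namespace_info__()     # the object documented in this file
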
@@ -76,16 +76,16 @@ def capabilities(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.capabilities()
         {'boolean indexing': True,
-         'data-dependent shapes': True}
+         'data-dependent shapes': True,
+         'max dimensions': 64}
 
         """
         return {
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
             "max dimensions": 64,
         }
 
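
A sketch of how the capabilities dictionary might be consumed downstream, using the keys shown in this hunk; the ``array_api_compat.torch`` import path and the availability of ``unique_values`` (required by the standard) are assumptions here, not something this diff shows:

# Hedged sketch: guard features that not every array API library supports.
import array_api_compat.torch as xp

info = xp.__array_namespace_info__()
caps = info.capabilities()

if caps.get("data-dependent shapes", False):
    # unique_values has a data-dependent output shape
    values = xp.unique_values(xp.asarray([1, 1, 2]))

max_ndim = caps.get("max dimensions")  # 64 for PyTorch per this patch
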
@@ -102,15 +102,24 @@ def default_device(self):
 
         Returns
         -------
-        device : str
+        device : Device
             The default device used for new PyTorch arrays.
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_device()
-        'cpu'
+        device(type='cpu')
 
+        Notes
+        -----
+        This method returns the static default device when PyTorch is initialized.
+        However, the *current* device used by creation functions (``empty`` etc.)
+        can be changed at runtime.
+
+        See Also
+        --------
+        https://github.com/data-apis/array-api/issues/835
         """
         return torch.device("cpu")
 
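
A short sketch of the caveat the new Notes section describes: ``default_device()`` reports the static default (hard-coded to CPU in this patch), while torch's *current* default device for creation functions can be changed at runtime. ``torch.set_default_device`` requires PyTorch >= 2.0; the ``array_api_compat.torch`` import path is an assumption:

# Hedged sketch of the static-default vs. runtime-default mismatch.
import torch
import array_api_compat.torch as xp

info = xp.__array_namespace_info__()
print(info.default_device())        # device(type='cpu')

torch.set_default_device("meta")    # change the runtime default
print(torch.empty(2, 2).device)     # device(type='meta')
print(info.default_device())        # still device(type='cpu')
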
@@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None):
 
         Parameters
         ----------
-        device : str, optional
-            The device to get the default data types for. For PyTorch, only
-            ``'cpu'`` is allowed.
+        device : Device, optional
+            The device to get the default data types for.
+            Unused for PyTorch, as all devices use the same default dtypes.
 
         Returns
         -------
@@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.default_dtypes()
         {'real floating': torch.float32,
          'complex floating': torch.complex64,
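
A small usage sketch for the defaults returned above, so callers avoid hard-coding float32/float64; the ``array_api_compat.torch`` import path is assumed:

# Hedged sketch: create an array with the library's default real floating dtype.
import array_api_compat.torch as xp

info = xp.__array_namespace_info__()
float_dtype = info.default_dtypes()["real floating"]   # torch.float32
x = xp.zeros((3,), dtype=float_dtype)
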
@@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None):
 
         Parameters
         ----------
-        device : str, optional
+        device : Device, optional
             The device to get the data types for.
+            Unused for PyTorch, as all devices use the same dtypes.
         kind : str or tuple of str, optional
             The kind of data types to return. If ``None``, all data types are
             returned. If a string, only data types of that kind are returned.
@@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.dtypes(kind='signed integer')
         {'int8': numpy.int8,
          'int16': numpy.int16,
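
A sketch of the ``kind`` filtering documented here, including the tuple form described in the Parameters section; the ``array_api_compat.torch`` import path is assumed:

# Hedged sketch: filter supported dtypes by kind, or by a tuple of kinds.
import array_api_compat.torch as xp

info = xp.__array_namespace_info__()
signed = info.dtypes(kind="signed integer")                 # {'int8': ..., 'int16': ..., ...}
numeric = info.dtypes(kind=("integral", "real floating"))   # union of the two kinds
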
@@ -310,7 +320,7 @@ def devices(self):
 
         Returns
         -------
-        devices : list of str
+        devices : list[Device]
             The devices supported by PyTorch.
 
         See Also
@@ -322,7 +332,7 @@ def devices(self):
 
         Examples
         --------
-        >>> info = np.__array_namespace_info__()
+        >>> info = xp.__array_namespace_info__()
         >>> info.devices()
         [device(type='cpu'), device(type='mps', index=0), device(type='meta')]
 
@@ -333,6 +343,7 @@ def devices(self):
         # device:
         try:
             torch.device('notadevice')
+            raise AssertionError("unreachable")  # pragma: nocover
         except RuntimeError as e:
             # The error message is something like:
             # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
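
The added ``raise AssertionError("unreachable")`` makes the code fail loudly (and stay visible to coverage) if ``torch.device('notadevice')`` ever stops raising. The rest of ``devices()`` is outside this hunk; a hedged sketch of the parsing approach the comment hints at, with the helper name being illustrative only:

# Hedged sketch: recover the device-type names from the RuntimeError message.
import torch

def _device_types():
    try:
        torch.device("notadevice")
        raise AssertionError("unreachable")  # pragma: nocover
    except RuntimeError as e:
        # message: "Expected one of cpu, cuda, ... device type at start of device string: notadevice"
        msg = str(e)
        listing = msg.split("Expected one of ")[1].split(" device type")[0]
        return [name.strip() for name in listing.split(",")]
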