11import pytest
22from hypothesis import given , strategies as st
3- from array_api_tests .dtype_helpers import available_kinds
3+ from array_api_tests .dtype_helpers import available_kinds , dtype_names
44
55from . import xp
66
77pytestmark = pytest .mark .min_version ("2023.12" )
88
99
10- def test_array_namespace_info ():
11- out = xp .__array_namespace_info__ ()
10+ class TestInspection :
11+ def test_capabilities (self ):
12+ out = xp .__array_namespace_info__ ()
1213
13- capabilities = out .capabilities ()
14- assert isinstance (capabilities , dict )
14+ capabilities = out .capabilities ()
15+ assert isinstance (capabilities , dict )
1516
16- out .default_device ()
17+ expected_attr = {"boolean indexing" : bool , "data-dependent shapes" : bool }
18+ if xp .__array_api_version__ >= "2024.12" :
19+ expected_attr .update (** {"max dimensions" : int })
20+
21+ for attr , typ in expected_attr .items ():
22+ assert attr in capabilities , f'capabilites is missing "{ attr } ".'
23+ assert isinstance (capabilities [attr ], typ )
24+
25+ assert capabilities .get ("max dimensions" , 100500 ) > 0
26+
27+ def test_devices (self ):
28+ out = xp .__array_namespace_info__ ()
29+
30+ assert hasattr (out , "devices" )
31+ assert hasattr (out , "default_device" )
32+
33+ assert isinstance (out .devices (), list )
34+ if out .default_device () is not None :
35+ # Per https://github.com/data-apis/array-api/issues/923
36+ # default_device() can return None. Otherwise, it must be a valid device.
37+ assert out .default_device () in out .devices ()
38+
39+ def test_default_dtypes (self ):
40+ out = xp .__array_namespace_info__ ()
41+
42+ for device in xp .__array_namespace_info__ ().devices ():
43+ default_dtypes = out .default_dtypes (device = device )
44+ assert isinstance (default_dtypes , dict )
45+ expected_subset = (
46+ {"real floating" , "complex floating" , "integral" }
47+ & available_kinds ()
48+ | {"indexing" }
49+ )
50+ assert expected_subset .issubset (set (default_dtypes .keys ()))
1751
18- default_dtypes = out .default_dtypes ()
19- assert isinstance (default_dtypes , dict )
20- expected_subset = {"real floating" , "complex floating" , "integral" } & available_kinds () | {"indexing" }
21- assert expected_subset .issubset (set (default_dtypes .keys ()))
2252
23- devices = out .devices ()
24- assert isinstance (devices , list )
25-
26-
2753atomic_kinds = [
2854 "bool" ,
2955 "signed integer" ,
@@ -34,12 +60,21 @@ def test_array_namespace_info():
3460
3561
3662@given (
37- st .one_of (
63+ kind = st .one_of (
3864 st .none (),
3965 st .sampled_from (atomic_kinds + ["integral" , "numeric" ]),
4066 st .lists (st .sampled_from (atomic_kinds ), unique = True , min_size = 1 ).map (tuple ),
67+ ),
68+ device = st .one_of (
69+ st .none (),
70+ st .sampled_from (xp .__array_namespace_info__ ().devices ())
4171 )
4272)
43- def test_array_namespace_info_dtypes (kind ):
44- out = xp .__array_namespace_info__ ().dtypes (kind = kind )
73+ def test_array_namespace_info_dtypes (kind , device ):
74+ out = xp .__array_namespace_info__ ().dtypes (kind = kind , device = device )
4575 assert isinstance (out , dict )
76+
77+ for name , dtyp in out .items ():
78+ assert name in dtype_names
79+ xp .empty (1 , dtype = dtyp , device = device ) # check `dtyp` is a valid dtype
80+
0 commit comments