1+ from collections .abc import Mapping
12from functools import lru_cache
2- from typing import NamedTuple , Tuple , Union
3+ from typing import Any , NamedTuple , Sequence , Tuple , Union
34from warnings import warn
45
56from . import _array_module as xp
3637]
3738
3839
40+ class EqualityMapping (Mapping ):
41+ """
42+ Mapping that uses equality for indexing
43+
44+ Typical mappings (e.g. the built-in dict) use hashing for indexing. This
45+ isn't ideal for the Array API, as no __hash__() method is specified for
46+ dtype objects - but __eq__() is!
47+
48+ See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
49+ """
50+
51+ def __init__ (self , key_value_pairs : Sequence [Tuple [Any , Any ]]):
52+ keys = [k for k , _ in key_value_pairs ]
53+ for i , key in enumerate (keys ):
54+ if not (key == key ): # specifically checking __eq__, not __neq__
55+ raise ValueError ("Key {key!r} does not have equality with itself" )
56+ other_keys = keys [:]
57+ other_keys .pop (i )
58+ for other_key in other_keys :
59+ if key == other_key :
60+ raise ValueError ("Key {key!r} has equality with key {other_key!r}" )
61+ self ._key_value_pairs = key_value_pairs
62+
63+ def __getitem__ (self , key ):
64+ for k , v in self ._key_value_pairs :
65+ if key == k :
66+ return v
67+ else :
68+ raise KeyError (f"{ key !r} not found" )
69+
70+ def __iter__ (self ):
71+ return (k for k , _ in self ._key_value_pairs )
72+
73+ def __len__ (self ):
74+ return len (self ._key_value_pairs )
75+
76+ def __str__ (self ):
77+ return "{" + ", " .join (f"{ k !r} : { v !r} " for k , v in self ._key_value_pairs ) + "}"
78+
79+ def __repr__ (self ):
80+ return f"EqualityMapping({ self } )"
81+
82+
3983_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
4084_int_names = ("int8" , "int16" , "int32" , "int64" )
4185_float_names = ("float32" , "float64" )
5195bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
5296
5397
54- dtype_to_name = { getattr (xp , name ): name for name in _dtype_names }
98+ dtype_to_name = EqualityMapping ([( getattr (xp , name ), name ) for name in _dtype_names ])
5599
56100
57- dtype_to_scalars = {
58- xp .bool : [bool ],
59- ** {d : [int ] for d in all_int_dtypes },
60- ** {d : [int , float ] for d in float_dtypes },
61- }
101+ dtype_to_scalars = EqualityMapping (
102+ [
103+ (xp .bool , [bool ]),
104+ * [(d , [int ]) for d in all_int_dtypes ],
105+ * [(d , [int , float ]) for d in float_dtypes ],
106+ ]
107+ )
62108
63109
64110def is_int_dtype (dtype ):
@@ -90,31 +136,32 @@ class MinMax(NamedTuple):
90136 max : Union [int , float ]
91137
92138
93- dtype_ranges = {
94- xp .int8 : MinMax (- 128 , + 127 ),
95- xp .int16 : MinMax (- 32_768 , + 32_767 ),
96- xp .int32 : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
97- xp .int64 : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
98- xp .uint8 : MinMax (0 , + 255 ),
99- xp .uint16 : MinMax (0 , + 65_535 ),
100- xp .uint32 : MinMax (0 , + 4_294_967_295 ),
101- xp .uint64 : MinMax (0 , + 18_446_744_073_709_551_615 ),
102- xp .float32 : MinMax (- 3.4028234663852886e+38 , 3.4028234663852886e+38 ),
103- xp .float64 : MinMax (- 1.7976931348623157e+308 , 1.7976931348623157e+308 ),
104- }
139+ dtype_ranges = EqualityMapping (
140+ [
141+ (xp .int8 , MinMax (- 128 , + 127 )),
142+ (xp .int16 , MinMax (- 32_768 , + 32_767 )),
143+ (xp .int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 )),
144+ (xp .int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 )),
145+ (xp .uint8 , MinMax (0 , + 255 )),
146+ (xp .uint16 , MinMax (0 , + 65_535 )),
147+ (xp .uint32 , MinMax (0 , + 4_294_967_295 )),
148+ (xp .uint64 , MinMax (0 , + 18_446_744_073_709_551_615 )),
149+ (xp .float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 )),
150+ (xp .float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 )),
151+ ]
152+ )
105153
106- dtype_nbits = {
107- ** { d : 8 for d in [xp .int8 , xp .uint8 ]},
108- ** { d : 16 for d in [xp .int16 , xp .uint16 ]},
109- ** { d : 32 for d in [xp .int32 , xp .uint32 , xp .float32 ]},
110- ** { d : 64 for d in [xp .int64 , xp .uint64 , xp .float64 ]},
111- }
154+ dtype_nbits = EqualityMapping (
155+ [( d , 8 ) for d in [xp .int8 , xp .uint8 ]]
156+ + [( d , 16 ) for d in [xp .int16 , xp .uint16 ]]
157+ + [( d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
158+ + [( d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
159+ )
112160
113161
114- dtype_signed = {
115- ** {d : True for d in int_dtypes },
116- ** {d : False for d in uint_dtypes },
117- }
162+ dtype_signed = EqualityMapping (
163+ [(d , True ) for d in int_dtypes ] + [(d , False ) for d in uint_dtypes ]
164+ )
118165
119166
120167if isinstance (xp .asarray , _UndefinedStub ):
@@ -137,52 +184,51 @@ class MinMax(NamedTuple):
137184 default_uint = xp .uint64
138185
139186
140- _numeric_promotions = {
187+ _numeric_promotions = [
141188 # ints
142- (xp .int8 , xp .int8 ): xp .int8 ,
143- (xp .int8 , xp .int16 ): xp .int16 ,
144- (xp .int8 , xp .int32 ): xp .int32 ,
145- (xp .int8 , xp .int64 ): xp .int64 ,
146- (xp .int16 , xp .int16 ): xp .int16 ,
147- (xp .int16 , xp .int32 ): xp .int32 ,
148- (xp .int16 , xp .int64 ): xp .int64 ,
149- (xp .int32 , xp .int32 ): xp .int32 ,
150- (xp .int32 , xp .int64 ): xp .int64 ,
151- (xp .int64 , xp .int64 ): xp .int64 ,
189+ (( xp .int8 , xp .int8 ), xp .int8 ) ,
190+ (( xp .int8 , xp .int16 ), xp .int16 ) ,
191+ (( xp .int8 , xp .int32 ), xp .int32 ) ,
192+ (( xp .int8 , xp .int64 ), xp .int64 ) ,
193+ (( xp .int16 , xp .int16 ), xp .int16 ) ,
194+ (( xp .int16 , xp .int32 ), xp .int32 ) ,
195+ (( xp .int16 , xp .int64 ), xp .int64 ) ,
196+ (( xp .int32 , xp .int32 ), xp .int32 ) ,
197+ (( xp .int32 , xp .int64 ), xp .int64 ) ,
198+ (( xp .int64 , xp .int64 ), xp .int64 ) ,
152199 # uints
153- (xp .uint8 , xp .uint8 ): xp .uint8 ,
154- (xp .uint8 , xp .uint16 ): xp .uint16 ,
155- (xp .uint8 , xp .uint32 ): xp .uint32 ,
156- (xp .uint8 , xp .uint64 ): xp .uint64 ,
157- (xp .uint16 , xp .uint16 ): xp .uint16 ,
158- (xp .uint16 , xp .uint32 ): xp .uint32 ,
159- (xp .uint16 , xp .uint64 ): xp .uint64 ,
160- (xp .uint32 , xp .uint32 ): xp .uint32 ,
161- (xp .uint32 , xp .uint64 ): xp .uint64 ,
162- (xp .uint64 , xp .uint64 ): xp .uint64 ,
200+ (( xp .uint8 , xp .uint8 ), xp .uint8 ) ,
201+ (( xp .uint8 , xp .uint16 ), xp .uint16 ) ,
202+ (( xp .uint8 , xp .uint32 ), xp .uint32 ) ,
203+ (( xp .uint8 , xp .uint64 ), xp .uint64 ) ,
204+ (( xp .uint16 , xp .uint16 ), xp .uint16 ) ,
205+ (( xp .uint16 , xp .uint32 ), xp .uint32 ) ,
206+ (( xp .uint16 , xp .uint64 ), xp .uint64 ) ,
207+ (( xp .uint32 , xp .uint32 ), xp .uint32 ) ,
208+ (( xp .uint32 , xp .uint64 ), xp .uint64 ) ,
209+ (( xp .uint64 , xp .uint64 ), xp .uint64 ) ,
163210 # ints and uints (mixed sign)
164- (xp .int8 , xp .uint8 ): xp .int16 ,
165- (xp .int8 , xp .uint16 ): xp .int32 ,
166- (xp .int8 , xp .uint32 ): xp .int64 ,
167- (xp .int16 , xp .uint8 ): xp .int16 ,
168- (xp .int16 , xp .uint16 ): xp .int32 ,
169- (xp .int16 , xp .uint32 ): xp .int64 ,
170- (xp .int32 , xp .uint8 ): xp .int32 ,
171- (xp .int32 , xp .uint16 ): xp .int32 ,
172- (xp .int32 , xp .uint32 ): xp .int64 ,
173- (xp .int64 , xp .uint8 ): xp .int64 ,
174- (xp .int64 , xp .uint16 ): xp .int64 ,
175- (xp .int64 , xp .uint32 ): xp .int64 ,
211+ (( xp .int8 , xp .uint8 ), xp .int16 ) ,
212+ (( xp .int8 , xp .uint16 ), xp .int32 ) ,
213+ (( xp .int8 , xp .uint32 ), xp .int64 ) ,
214+ (( xp .int16 , xp .uint8 ), xp .int16 ) ,
215+ (( xp .int16 , xp .uint16 ), xp .int32 ) ,
216+ (( xp .int16 , xp .uint32 ), xp .int64 ) ,
217+ (( xp .int32 , xp .uint8 ), xp .int32 ) ,
218+ (( xp .int32 , xp .uint16 ), xp .int32 ) ,
219+ (( xp .int32 , xp .uint32 ), xp .int64 ) ,
220+ (( xp .int64 , xp .uint8 ), xp .int64 ) ,
221+ (( xp .int64 , xp .uint16 ), xp .int64 ) ,
222+ (( xp .int64 , xp .uint32 ), xp .int64 ) ,
176223 # floats
177- (xp .float32 , xp .float32 ): xp .float32 ,
178- (xp .float32 , xp .float64 ): xp .float64 ,
179- (xp .float64 , xp .float64 ): xp .float64 ,
180- }
181- promotion_table = {
182- (xp .bool , xp .bool ): xp .bool ,
183- ** _numeric_promotions ,
184- ** {(d2 , d1 ): res for (d1 , d2 ), res in _numeric_promotions .items ()},
185- }
224+ ((xp .float32 , xp .float32 ), xp .float32 ),
225+ ((xp .float32 , xp .float64 ), xp .float64 ),
226+ ((xp .float64 , xp .float64 ), xp .float64 ),
227+ ]
228+ _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
229+ _promotion_table = list (set (_numeric_promotions ))
230+ _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
231+ promotion_table = EqualityMapping (_promotion_table )
186232
187233
188234def result_type (* dtypes : DataType ):
0 commit comments