11from collections .abc import Mapping
22from functools import lru_cache
3- from typing import NamedTuple , Tuple , Union
3+ from typing import Any , NamedTuple , Sequence , Tuple , Union
44from warnings import warn
55
66from . import _array_module as xp
@@ -48,8 +48,8 @@ class EqualityMapping(Mapping):
4848 See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
4949 """
5050
51- def __init__ (self , mapping : Mapping ):
52- keys = list ( mapping . keys ())
51+ def __init__ (self , key_value_pairs : Sequence [ Tuple [ Any , Any ]] ):
52+ keys = [ k for k , _ in key_value_pairs ]
5353 for i , key in enumerate (keys ):
5454 if not (key == key ): # specifically checking __eq__, not __neq__
5555 raise ValueError ("Key {key!r} does not have equality with itself" )
@@ -58,23 +58,26 @@ def __init__(self, mapping: Mapping):
5858 for other_key in other_keys :
5959 if key == other_key :
6060 raise ValueError ("Key {key!r} has equality with key {other_key!r}" )
61- self ._mapping = mapping
61+ self ._key_value_pairs = key_value_pairs
6262
6363 def __getitem__ (self , key ):
64- for k , v in self ._mapping . items () :
64+ for k , v in self ._key_value_pairs :
6565 if key == k :
6666 return v
6767 else :
6868 raise KeyError (f"{ key !r} not found" )
6969
7070 def __iter__ (self ):
71- return iter ( self ._mapping )
71+ return ( k for k , _ in self ._key_value_pairs )
7272
7373 def __len__ (self ):
74- return len (self ._mapping )
74+ return len (self ._key_value_pairs )
75+
76+ def __str__ (self ):
77+ return "{" + ", " .join (f"{ k !r} : { v !r} " for k , v in self ._key_value_pairs ) + "}"
7578
7679 def __repr__ (self ):
77- return f"EqualityMapping({ self . _mapping !r } )"
80+ return f"EqualityMapping({ self } )"
7881
7982
8083_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
@@ -92,15 +95,15 @@ def __repr__(self):
9295bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
9396
9497
95- dtype_to_name = EqualityMapping ({ getattr (xp , name ): name for name in _dtype_names } )
98+ dtype_to_name = EqualityMapping ([( getattr (xp , name ), name ) for name in _dtype_names ] )
9699
97100
98101dtype_to_scalars = EqualityMapping (
99- {
100- xp .bool : [bool ],
101- ** { d : [int ] for d in all_int_dtypes } ,
102- ** { d : [int , float ] for d in float_dtypes } ,
103- }
102+ [
103+ ( xp .bool , [bool ]) ,
104+ * [( d , [int ]) for d in all_int_dtypes ] ,
105+ * [( d , [int , float ]) for d in float_dtypes ] ,
106+ ]
104107)
105108
106109
@@ -134,35 +137,30 @@ class MinMax(NamedTuple):
134137
135138
136139dtype_ranges = EqualityMapping (
137- {
138- xp .int8 : MinMax (- 128 , + 127 ),
139- xp .int16 : MinMax (- 32_768 , + 32_767 ),
140- xp .int32 : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
141- xp .int64 : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
142- xp .uint8 : MinMax (0 , + 255 ),
143- xp .uint16 : MinMax (0 , + 65_535 ),
144- xp .uint32 : MinMax (0 , + 4_294_967_295 ),
145- xp .uint64 : MinMax (0 , + 18_446_744_073_709_551_615 ),
146- xp .float32 : MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ),
147- xp .float64 : MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ),
148- }
140+ [
141+ ( xp .int8 , MinMax (- 128 , + 127 ) ),
142+ ( xp .int16 , MinMax (- 32_768 , + 32_767 ) ),
143+ ( xp .int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 ) ),
144+ ( xp .int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ) ),
145+ ( xp .uint8 , MinMax (0 , + 255 ) ),
146+ ( xp .uint16 , MinMax (0 , + 65_535 ) ),
147+ ( xp .uint32 , MinMax (0 , + 4_294_967_295 ) ),
148+ ( xp .uint64 , MinMax (0 , + 18_446_744_073_709_551_615 ) ),
149+ ( xp .float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ) ),
150+ ( xp .float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ) ),
151+ ]
149152)
150153
151154dtype_nbits = EqualityMapping (
152- {
153- ** {d : 8 for d in [xp .int8 , xp .uint8 ]},
154- ** {d : 16 for d in [xp .int16 , xp .uint16 ]},
155- ** {d : 32 for d in [xp .int32 , xp .uint32 , xp .float32 ]},
156- ** {d : 64 for d in [xp .int64 , xp .uint64 , xp .float64 ]},
157- }
155+ [(d , 8 ) for d in [xp .int8 , xp .uint8 ]]
156+ + [(d , 16 ) for d in [xp .int16 , xp .uint16 ]]
157+ + [(d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
158+ + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
158159)
159160
160161
161162dtype_signed = EqualityMapping (
162- {
163- ** {d : True for d in int_dtypes },
164- ** {d : False for d in uint_dtypes },
165- }
163+ [(d , True ) for d in int_dtypes ] + [(d , False ) for d in uint_dtypes ]
166164)
167165
168166
@@ -186,54 +184,51 @@ class MinMax(NamedTuple):
186184 default_uint = xp .uint64
187185
188186
189- _numeric_promotions = {
187+ _numeric_promotions = [
190188 # ints
191- (xp .int8 , xp .int8 ): xp .int8 ,
192- (xp .int8 , xp .int16 ): xp .int16 ,
193- (xp .int8 , xp .int32 ): xp .int32 ,
194- (xp .int8 , xp .int64 ): xp .int64 ,
195- (xp .int16 , xp .int16 ): xp .int16 ,
196- (xp .int16 , xp .int32 ): xp .int32 ,
197- (xp .int16 , xp .int64 ): xp .int64 ,
198- (xp .int32 , xp .int32 ): xp .int32 ,
199- (xp .int32 , xp .int64 ): xp .int64 ,
200- (xp .int64 , xp .int64 ): xp .int64 ,
189+ (( xp .int8 , xp .int8 ), xp .int8 ) ,
190+ (( xp .int8 , xp .int16 ), xp .int16 ) ,
191+ (( xp .int8 , xp .int32 ), xp .int32 ) ,
192+ (( xp .int8 , xp .int64 ), xp .int64 ) ,
193+ (( xp .int16 , xp .int16 ), xp .int16 ) ,
194+ (( xp .int16 , xp .int32 ), xp .int32 ) ,
195+ (( xp .int16 , xp .int64 ), xp .int64 ) ,
196+ (( xp .int32 , xp .int32 ), xp .int32 ) ,
197+ (( xp .int32 , xp .int64 ), xp .int64 ) ,
198+ (( xp .int64 , xp .int64 ), xp .int64 ) ,
201199 # uints
202- (xp .uint8 , xp .uint8 ): xp .uint8 ,
203- (xp .uint8 , xp .uint16 ): xp .uint16 ,
204- (xp .uint8 , xp .uint32 ): xp .uint32 ,
205- (xp .uint8 , xp .uint64 ): xp .uint64 ,
206- (xp .uint16 , xp .uint16 ): xp .uint16 ,
207- (xp .uint16 , xp .uint32 ): xp .uint32 ,
208- (xp .uint16 , xp .uint64 ): xp .uint64 ,
209- (xp .uint32 , xp .uint32 ): xp .uint32 ,
210- (xp .uint32 , xp .uint64 ): xp .uint64 ,
211- (xp .uint64 , xp .uint64 ): xp .uint64 ,
200+ (( xp .uint8 , xp .uint8 ), xp .uint8 ) ,
201+ (( xp .uint8 , xp .uint16 ), xp .uint16 ) ,
202+ (( xp .uint8 , xp .uint32 ), xp .uint32 ) ,
203+ (( xp .uint8 , xp .uint64 ), xp .uint64 ) ,
204+ (( xp .uint16 , xp .uint16 ), xp .uint16 ) ,
205+ (( xp .uint16 , xp .uint32 ), xp .uint32 ) ,
206+ (( xp .uint16 , xp .uint64 ), xp .uint64 ) ,
207+ (( xp .uint32 , xp .uint32 ), xp .uint32 ) ,
208+ (( xp .uint32 , xp .uint64 ), xp .uint64 ) ,
209+ (( xp .uint64 , xp .uint64 ), xp .uint64 ) ,
212210 # ints and uints (mixed sign)
213- (xp .int8 , xp .uint8 ): xp .int16 ,
214- (xp .int8 , xp .uint16 ): xp .int32 ,
215- (xp .int8 , xp .uint32 ): xp .int64 ,
216- (xp .int16 , xp .uint8 ): xp .int16 ,
217- (xp .int16 , xp .uint16 ): xp .int32 ,
218- (xp .int16 , xp .uint32 ): xp .int64 ,
219- (xp .int32 , xp .uint8 ): xp .int32 ,
220- (xp .int32 , xp .uint16 ): xp .int32 ,
221- (xp .int32 , xp .uint32 ): xp .int64 ,
222- (xp .int64 , xp .uint8 ): xp .int64 ,
223- (xp .int64 , xp .uint16 ): xp .int64 ,
224- (xp .int64 , xp .uint32 ): xp .int64 ,
211+ (( xp .int8 , xp .uint8 ), xp .int16 ) ,
212+ (( xp .int8 , xp .uint16 ), xp .int32 ) ,
213+ (( xp .int8 , xp .uint32 ), xp .int64 ) ,
214+ (( xp .int16 , xp .uint8 ), xp .int16 ) ,
215+ (( xp .int16 , xp .uint16 ), xp .int32 ) ,
216+ (( xp .int16 , xp .uint32 ), xp .int64 ) ,
217+ (( xp .int32 , xp .uint8 ), xp .int32 ) ,
218+ (( xp .int32 , xp .uint16 ), xp .int32 ) ,
219+ (( xp .int32 , xp .uint32 ), xp .int64 ) ,
220+ (( xp .int64 , xp .uint8 ), xp .int64 ) ,
221+ (( xp .int64 , xp .uint16 ), xp .int64 ) ,
222+ (( xp .int64 , xp .uint32 ), xp .int64 ) ,
225223 # floats
226- (xp .float32 , xp .float32 ): xp .float32 ,
227- (xp .float32 , xp .float64 ): xp .float64 ,
228- (xp .float64 , xp .float64 ): xp .float64 ,
229- }
230- promotion_table = EqualityMapping (
231- {
232- (xp .bool , xp .bool ): xp .bool ,
233- ** _numeric_promotions ,
234- ** {(d2 , d1 ): res for (d1 , d2 ), res in _numeric_promotions .items ()},
235- }
236- )
224+ ((xp .float32 , xp .float32 ), xp .float32 ),
225+ ((xp .float32 , xp .float64 ), xp .float64 ),
226+ ((xp .float64 , xp .float64 ), xp .float64 ),
227+ ]
228+ _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
229+ _promotion_table = list (set (_numeric_promotions ))
230+ _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
231+ promotion_table = EqualityMapping (_promotion_table )
237232
238233
239234def result_type (* dtypes : DataType ):
0 commit comments