3535 torch .complex128 ,
3636}
3737
38- _promotion_table = {
39- # bool
40- (torch .bool , torch .bool ): torch .bool ,
38+ _promotion_table = {
4139 # ints
42- (torch .int8 , torch .int8 ): torch .int8 ,
4340 (torch .int8 , torch .int16 ): torch .int16 ,
4441 (torch .int8 , torch .int32 ): torch .int32 ,
4542 (torch .int8 , torch .int64 ): torch .int64 ,
46- (torch .int16 , torch .int8 ): torch .int16 ,
47- (torch .int16 , torch .int16 ): torch .int16 ,
4843 (torch .int16 , torch .int32 ): torch .int32 ,
4944 (torch .int16 , torch .int64 ): torch .int64 ,
50- (torch .int32 , torch .int8 ): torch .int32 ,
51- (torch .int32 , torch .int16 ): torch .int32 ,
52- (torch .int32 , torch .int32 ): torch .int32 ,
5345 (torch .int32 , torch .int64 ): torch .int64 ,
54- (torch .int64 , torch .int8 ): torch .int64 ,
55- (torch .int64 , torch .int16 ): torch .int64 ,
56- (torch .int64 , torch .int32 ): torch .int64 ,
57- (torch .int64 , torch .int64 ): torch .int64 ,
58- # uints
59- (torch .uint8 , torch .uint8 ): torch .uint8 ,
6046 # ints and uints (mixed sign)
61- (torch .int8 , torch .uint8 ): torch .int16 ,
62- (torch .int16 , torch .uint8 ): torch .int16 ,
63- (torch .int32 , torch .uint8 ): torch .int32 ,
64- (torch .int64 , torch .uint8 ): torch .int64 ,
6547 (torch .uint8 , torch .int8 ): torch .int16 ,
6648 (torch .uint8 , torch .int16 ): torch .int16 ,
6749 (torch .uint8 , torch .int32 ): torch .int32 ,
6850 (torch .uint8 , torch .int64 ): torch .int64 ,
6951 # floats
70- (torch .float32 , torch .float32 ): torch .float32 ,
7152 (torch .float32 , torch .float64 ): torch .float64 ,
72- (torch .float64 , torch .float32 ): torch .float64 ,
73- (torch .float64 , torch .float64 ): torch .float64 ,
7453 # complexes
75- (torch .complex64 , torch .complex64 ): torch .complex64 ,
7654 (torch .complex64 , torch .complex128 ): torch .complex128 ,
77- (torch .complex128 , torch .complex64 ): torch .complex128 ,
78- (torch .complex128 , torch .complex128 ): torch .complex128 ,
7955 # Mixed float and complex
8056 (torch .float32 , torch .complex64 ): torch .complex64 ,
8157 (torch .float32 , torch .complex128 ): torch .complex128 ,
8258 (torch .float64 , torch .complex64 ): torch .complex128 ,
8359 (torch .float64 , torch .complex128 ): torch .complex128 ,
8460}
8561
62+ _promotion_table .update ({(b , a ): c for (a , b ), c in _promotion_table .items ()})
63+ _promotion_table .update ({(a , a ): a for a in _array_api_dtypes })
64+
8665
8766def _two_arg (f ):
8867 @_wraps (f )
@@ -150,13 +129,18 @@ def result_type(
150129 return _reduce (_result_type , others + scalars )
151130
152131
153- def _result_type (x , y ):
132+ def _result_type (
133+ x : Array | DType | bool | int | float | complex ,
134+ y : Array | DType | bool | int | float | complex ,
135+ ) -> DType :
154136 if not (isinstance (x , _py_scalars ) or isinstance (y , _py_scalars )):
155- xdt = x . dtype if not isinstance (x , torch .dtype ) else x
156- ydt = y . dtype if not isinstance (y , torch .dtype ) else y
137+ xdt = x if isinstance (x , torch .dtype ) else x . dtype
138+ ydt = y if isinstance (y , torch .dtype ) else y . dtype
157139
158- if ( xdt , ydt ) in _promotion_table :
140+ try :
159141 return _promotion_table [xdt , ydt ]
142+ except KeyError :
143+ pass
160144
161145 # This doesn't result_type(dtype, dtype) for non-array API dtypes
162146 # because torch.result_type only accepts tensors. This does however, allow
@@ -301,27 +285,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
301285 out = torch .unsqueeze (out , a )
302286 return out
303287
288+
289+ def _sum_prod_no_axis (x : Array , dtype : DType | None ) -> Array :
290+ """
291+ Implements `sum(..., axis=())` and `prod(..., axis=())`.
292+
293+ Works around https://github.com/pytorch/pytorch/issues/29137
294+ """
295+ if dtype is not None :
296+ return x .clone () if dtype == x .dtype else x .to (dtype )
297+
298+ # We can't upcast uint8 according to the spec because there is no
299+ # torch.uint64, so at least upcast to int64 which is what prod does
300+ # when axis=None.
301+ if x .dtype in (torch .uint8 , torch .int8 , torch .int16 , torch .int32 ):
302+ return x .to (torch .int64 )
303+
304+ return x .clone ()
305+
306+
304307def prod (x : Array ,
305308 / ,
306309 * ,
307310 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
308311 dtype : Optional [DType ] = None ,
309312 keepdims : bool = False ,
310313 ** kwargs ) -> Array :
311- ndim = x .ndim
312314
313- # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
314- # below because it still needs to upcast.
315315 if axis == ():
316- if dtype is None :
317- # We can't upcast uint8 according to the spec because there is no
318- # torch.uint64, so at least upcast to int64 which is what sum does
319- # when axis=None.
320- if x .dtype in [torch .int8 , torch .int16 , torch .int32 , torch .uint8 ]:
321- return x .to (torch .int64 )
322- return x .clone ()
323- return x .to (dtype )
324-
316+ return _sum_prod_no_axis (x , dtype )
325317 # torch.prod doesn't support multiple axes
326318 # (https://github.com/pytorch/pytorch/issues/56586).
327319 if isinstance (axis , tuple ):
@@ -330,7 +322,7 @@ def prod(x: Array,
330322 # torch doesn't support keepdims with axis=None
331323 # (https://github.com/pytorch/pytorch/issues/71209)
332324 res = torch .prod (x , dtype = dtype , ** kwargs )
333- res = _axis_none_keepdims (res , ndim , keepdims )
325+ res = _axis_none_keepdims (res , x . ndim , keepdims )
334326 return res
335327
336328 return torch .prod (x , axis , dtype = dtype , keepdims = keepdims , ** kwargs )
@@ -343,25 +335,14 @@ def sum(x: Array,
343335 dtype : Optional [DType ] = None ,
344336 keepdims : bool = False ,
345337 ** kwargs ) -> Array :
346- ndim = x .ndim
347338
348- # https://github.com/pytorch/pytorch/issues/29137.
349- # Make sure it upcasts.
350339 if axis == ():
351- if dtype is None :
352- # We can't upcast uint8 according to the spec because there is no
353- # torch.uint64, so at least upcast to int64 which is what sum does
354- # when axis=None.
355- if x .dtype in [torch .int8 , torch .int16 , torch .int32 , torch .uint8 ]:
356- return x .to (torch .int64 )
357- return x .clone ()
358- return x .to (dtype )
359-
340+ return _sum_prod_no_axis (x , dtype )
360341 if axis is None :
361342 # torch doesn't support keepdims with axis=None
362343 # (https://github.com/pytorch/pytorch/issues/71209)
363344 res = torch .sum (x , dtype = dtype , ** kwargs )
364- res = _axis_none_keepdims (res , ndim , keepdims )
345+ res = _axis_none_keepdims (res , x . ndim , keepdims )
365346 return res
366347
367348 return torch .sum (x , axis , dtype = dtype , keepdims = keepdims , ** kwargs )
@@ -372,7 +353,7 @@ def any(x: Array,
372353 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
373354 keepdims : bool = False ,
374355 ** kwargs ) -> Array :
375- ndim = x . ndim
356+
376357 if axis == ():
377358 return x .to (torch .bool )
378359 # torch.any doesn't support multiple axes
@@ -384,7 +365,7 @@ def any(x: Array,
384365 # torch doesn't support keepdims with axis=None
385366 # (https://github.com/pytorch/pytorch/issues/71209)
386367 res = torch .any (x , ** kwargs )
387- res = _axis_none_keepdims (res , ndim , keepdims )
368+ res = _axis_none_keepdims (res , x . ndim , keepdims )
388369 return res .to (torch .bool )
389370
390371 # torch.any doesn't return bool for uint8
@@ -396,7 +377,7 @@ def all(x: Array,
396377 axis : Optional [Union [int , Tuple [int , ...]]] = None ,
397378 keepdims : bool = False ,
398379 ** kwargs ) -> Array :
399- ndim = x . ndim
380+
400381 if axis == ():
401382 return x .to (torch .bool )
402383 # torch.all doesn't support multiple axes
@@ -408,7 +389,7 @@ def all(x: Array,
408389 # torch doesn't support keepdims with axis=None
409390 # (https://github.com/pytorch/pytorch/issues/71209)
410391 res = torch .all (x , ** kwargs )
411- res = _axis_none_keepdims (res , ndim , keepdims )
392+ res = _axis_none_keepdims (res , x . ndim , keepdims )
412393 return res .to (torch .bool )
413394
414395 # torch.all doesn't return bool for uint8
0 commit comments