@@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
116116_py_scalars = (bool , int , float , complex )
117117
118118
119- def result_type (* arrays_and_dtypes : Array | DType | complex ) -> DType :
119+ def result_type (
120+ * arrays_and_dtypes : Array | DType | bool | int | float | complex
121+ ) -> DType :
120122 num = len (arrays_and_dtypes )
121123
122124 if num == 0 :
@@ -550,10 +552,16 @@ def count_nonzero(
550552 return result
551553
552554
553- def where (condition : Array , x1 : Array , x2 : Array , / ) -> Array :
555+ def where (
556+ condition : Array ,
557+ x1 : Array | bool | int | float | complex ,
558+ x2 : Array | bool | int | float | complex ,
559+ / ,
560+ ) -> Array :
554561 x1 , x2 = _fix_promotion (x1 , x2 )
555562 return torch .where (condition , x1 , x2 )
556563
564+
557565# torch.reshape doesn't have the copy keyword
558566def reshape (x : Array ,
559567 / ,
@@ -622,7 +630,7 @@ def linspace(start: Union[int, float],
622630# torch.full does not accept an int size
623631# https://github.com/pytorch/pytorch/issues/70906
624632def full (shape : Union [int , Tuple [int , ...]],
625- fill_value : complex ,
633+ fill_value : bool | int | float | complex ,
626634 * ,
627635 dtype : Optional [DType ] = None ,
628636 device : Optional [Device ] = None ,
0 commit comments