22
33from typing import TYPE_CHECKING , Protocol , TypeVar
44
5- from ._types import (
6- dtype as Dtype ,
7- device as Device ,
8- Any ,
9- PyCapsule ,
10- Enum ,
11- ellipsis ,
12- )
5+ if TYPE_CHECKING :
6+ from ._types import (
7+ device as Device ,
8+ Any ,
9+ PyCapsule ,
10+ Enum ,
11+ ellipsis ,
12+ )
13+ from .data_types import DType
1314
14- Self = TypeVar ("Self " , bound = "Array" )
15- # NOTE: when working with py3.11+ this can be ``typing.Self ``.
15+ array = TypeVar ("array " , bound = "Array" )
16+ # NOTE: when working with py3.11+ this can be ``typing.array ``.
1617
1718
1819class Array (Protocol ):
@@ -21,7 +22,7 @@ def __init__(self) -> None:
2122 ...
2223
2324 @property
24- def dtype (self ) -> Dtype :
25+ def dtype (self ) -> DType :
2526 """
2627 Data type of the array elements.
2728
@@ -45,7 +46,7 @@ def device(self) -> Device:
4546 ...
4647
4748 @property
48- def mT (self : Self ) -> Self :
49+ def mT (self : array ) -> array :
4950 """
5051 Transpose of a matrix (or a stack of matrices).
5152
@@ -109,7 +110,7 @@ def size(self) -> int | None:
109110 ...
110111
111112 @property
112- def T (self : Self ) -> Self :
113+ def T (self : array ) -> array :
113114 """
114115 Transpose of the array.
115116
@@ -126,7 +127,7 @@ def T(self: Self) -> Self:
126127 """
127128 ...
128129
129- def __abs__ (self : Self , / ) -> Self :
130+ def __abs__ (self : array , / ) -> array :
130131 """
131132 Calculates the absolute value for each element of an array instance.
132133
@@ -156,7 +157,7 @@ def __abs__(self: Self, /) -> Self:
156157 """
157158 ...
158159
159- def __add__ (self : Self , other : int | float | Self , / ) -> Self :
160+ def __add__ (self : array , other : int | float | array , / ) -> array :
160161 """
161162 Calculates the sum for each element of an array instance with the respective element of the array ``other``.
162163
@@ -183,7 +184,7 @@ def __add__(self: Self, other: int | float | Self, /) -> Self:
183184 """
184185 ...
185186
186- def __and__ (self : Self , other : int | bool | Self , / ) -> Self :
187+ def __and__ (self : array , other : int | bool | array , / ) -> array :
187188 """
188189 Evaluates ``self_i & other_i`` for each element of an array instance with the respective element of the array ``other``.
189190
@@ -394,7 +395,7 @@ def __dlpack_device__(self, /) -> tuple[Enum, int]:
394395 # Note that __eq__ returns an array while `object.__eq__` returns a bool.
395396 # Hence Mypy will complain that this violates the Liskov substitution
396397 # principle - ignore that.
397- def __eq__ (self : Self , other : int | float | bool | Self , / ) -> Self : # xtype : ignore
398+ def __eq__ (self : array , other : int | float | bool | array , / ) -> array : # type : ignore[override]
398399 r"""
399400 Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
400401
@@ -448,7 +449,7 @@ def __float__(self, /) -> float:
448449 """
449450 ...
450451
451- def __floordiv__ (self : Self , other : int | float | Self , / ) -> Self :
452+ def __floordiv__ (self : array , other : int | float | array , / ) -> array :
452453 """
453454 Evaluates ``self_i // other_i`` for each element of an array instance with the respective element of the array ``other``.
454455
@@ -473,7 +474,7 @@ def __floordiv__(self: Self, other: int | float | Self, /) -> Self:
473474 """
474475 ...
475476
476- def __ge__ (self : Self , other : int | float | Self , / ) -> Self :
477+ def __ge__ (self : array , other : int | float | array , / ) -> array :
477478 """
478479 Computes the truth value of ``self_i >= other_i`` for each element of an array instance with the respective element of the array ``other``.
479480
@@ -499,10 +500,10 @@ def __ge__(self: Self, other: int | float | Self, /) -> Self:
499500 ...
500501
501502 def __getitem__ (
502- self : Self ,
503- key : int | slice | ellipsis | tuple [int | slice | ellipsis , ...] | Self ,
503+ self : array ,
504+ key : int | slice | ellipsis | tuple [int | slice | ellipsis , ...] | array ,
504505 / ,
505- ) -> Self :
506+ ) -> array :
506507 """
507508 Returns ``self[key]``.
508509
@@ -520,7 +521,7 @@ def __getitem__(
520521 """
521522 ...
522523
523- def __gt__ (self : Self , other : int | float | Self , / ) -> Self :
524+ def __gt__ (self : array , other : int | float | array , / ) -> array :
524525 """
525526 Computes the truth value of ``self_i > other_i`` for each element of an array instance with the respective element of the array ``other``.
526527
@@ -605,7 +606,7 @@ def __int__(self, /) -> int:
605606 """
606607 ...
607608
608- def __invert__ (self : Self , / ) -> Self :
609+ def __invert__ (self : array , / ) -> array :
609610 """
610611 Evaluates ``~self_i`` for each element of an array instance.
611612
@@ -625,7 +626,7 @@ def __invert__(self: Self, /) -> Self:
625626 """
626627 ...
627628
628- def __le__ (self : Self , other : int | float | Self , / ) -> Self :
629+ def __le__ (self : array , other : int | float | array , / ) -> array :
629630 """
630631 Computes the truth value of ``self_i <= other_i`` for each element of an array instance with the respective element of the array ``other``.
631632
@@ -650,7 +651,7 @@ def __le__(self: Self, other: int | float | Self, /) -> Self:
650651 """
651652 ...
652653
653- def __lshift__ (self : Self , other : int | Self , / ) -> Self :
654+ def __lshift__ (self : array , other : int | array , / ) -> array :
654655 """
655656 Evaluates ``self_i << other_i`` for each element of an array instance with the respective element of the array ``other``.
656657
@@ -672,7 +673,7 @@ def __lshift__(self: Self, other: int | Self, /) -> Self:
672673 """
673674 ...
674675
675- def __lt__ (self : Self , other : int | float | Self , / ) -> Self :
676+ def __lt__ (self : array , other : int | float | array , / ) -> array :
676677 """
677678 Computes the truth value of ``self_i < other_i`` for each element of an array instance with the respective element of the array ``other``.
678679
@@ -697,7 +698,7 @@ def __lt__(self: Self, other: int | float | Self, /) -> Self:
697698 """
698699 ...
699700
700- def __matmul__ (self : Self , other : Self , / ) -> Self :
701+ def __matmul__ (self : array , other : array , / ) -> array :
701702 """
702703 Computes the matrix product.
703704
@@ -746,7 +747,7 @@ def __matmul__(self: Self, other: Self, /) -> Self:
746747 """
747748 ...
748749
749- def __mod__ (self : Self , other : int | float | Self , / ) -> Self :
750+ def __mod__ (self : array , other : int | float | array , / ) -> array :
750751 """
751752 Evaluates ``self_i % other_i`` for each element of an array instance with the respective element of the array ``other``.
752753
@@ -771,7 +772,7 @@ def __mod__(self: Self, other: int | float | Self, /) -> Self:
771772 """
772773 ...
773774
774- def __mul__ (self : Self , other : int | float | Self , / ) -> Self :
775+ def __mul__ (self : array , other : int | float | array , / ) -> array :
775776 r"""
776777 Calculates the product for each element of an array instance with the respective element of the array ``other``.
777778
@@ -802,7 +803,7 @@ def __mul__(self: Self, other: int | float | Self, /) -> Self:
802803 ...
803804
804805 # See note above __eq__ method for explanation of the `type: ignore`
805- def __ne__ (self : Self , other : int | float | bool | Self , / ) -> Self : # type: ignore
806+ def __ne__ (self : array , other : int | float | bool | array , / ) -> array : # type: ignore[override]
806807 """
807808 Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
808809
@@ -830,7 +831,7 @@ def __ne__(self: Self, other: int | float | bool | Self, /) -> Self: # type: ig
830831 """
831832 ...
832833
833- def __neg__ (self : Self , / ) -> Self :
834+ def __neg__ (self : array , / ) -> array :
834835 """
835836 Evaluates ``-self_i`` for each element of an array instance.
836837
@@ -861,7 +862,7 @@ def __neg__(self: Self, /) -> Self:
861862 """
862863 ...
863864
864- def __or__ (self : Self , other : int | bool | Self , / ) -> Self :
865+ def __or__ (self : array , other : int | bool | array , / ) -> array :
865866 """
866867 Evaluates ``self_i | other_i`` for each element of an array instance with the respective element of the array ``other``.
867868
@@ -883,7 +884,7 @@ def __or__(self: Self, other: int | bool | Self, /) -> Self:
883884 """
884885 ...
885886
886- def __pos__ (self : Self , / ) -> Self :
887+ def __pos__ (self : array , / ) -> array :
887888 """
888889 Evaluates ``+self_i`` for each element of an array instance.
889890
@@ -908,7 +909,7 @@ def __pos__(self: Self, /) -> Self:
908909 """
909910 ...
910911
911- def __pow__ (self : Self , other : int | float | Self , / ) -> Self :
912+ def __pow__ (self : array , other : int | float | array , / ) -> array :
912913 r"""
913914 Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``.
914915
@@ -940,7 +941,7 @@ def __pow__(self: Self, other: int | float | Self, /) -> Self:
940941 """
941942 ...
942943
943- def __rshift__ (self : Self , other : int | Self , / ) -> Self :
944+ def __rshift__ (self : array , other : int | array , / ) -> array :
944945 """
945946 Evaluates ``self_i >> other_i`` for each element of an array instance with the respective element of the array ``other``.
946947
@@ -963,9 +964,9 @@ def __rshift__(self: Self, other: int | Self, /) -> Self:
963964 ...
964965
965966 def __setitem__ (
966- self : Self ,
967- key : int | slice | ellipsis | tuple [int | slice | ellipsis , ...] | Self ,
968- value : int | float | bool | Self ,
967+ self : array ,
968+ key : int | slice | ellipsis | tuple [int | slice | ellipsis , ...] | array ,
969+ value : int | float | bool | array ,
969970 / ,
970971 ) -> None :
971972 """
@@ -991,7 +992,7 @@ def __setitem__(
991992 """
992993 ...
993994
994- def __sub__ (self : Self , other : int | float | Self , / ) -> Self :
995+ def __sub__ (self : array , other : int | float | array , / ) -> array :
995996 """
996997 Calculates the difference for each element of an array instance with the respective element of the array ``other``.
997998
@@ -1020,7 +1021,7 @@ def __sub__(self: Self, other: int | float | Self, /) -> Self:
10201021 """
10211022 ...
10221023
1023- def __truediv__ (self : Self , other : int | float | Self , / ) -> Self :
1024+ def __truediv__ (self : array , other : int | float | array , / ) -> array :
10241025 r"""
10251026 Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
10261027
@@ -1052,7 +1053,7 @@ def __truediv__(self: Self, other: int | float | Self, /) -> Self:
10521053 """
10531054 ...
10541055
1055- def __xor__ (self : Self , other : int | bool | Self , / ) -> Self :
1056+ def __xor__ (self : array , other : int | bool | array , / ) -> array :
10561057 """
10571058 Evaluates ``self_i ^ other_i`` for each element of an array instance with the respective element of the array ``other``.
10581059
@@ -1075,8 +1076,8 @@ def __xor__(self: Self, other: int | bool | Self, /) -> Self:
10751076 ...
10761077
10771078 def to_device (
1078- self : Self , device : Device , / , * , stream : int | Any | None = None
1079- ) -> Self :
1079+ self : array , device : Device , / , * , stream : int | Any | None = None
1080+ ) -> array :
10801081 """
10811082 Copy the array from the device on which it currently resides to the specified ``device``.
10821083
@@ -1101,6 +1102,4 @@ def to_device(
11011102 ...
11021103
11031104
1104- array = Array
1105-
1106- __all__ = ["array" ]
1105+ __all__ = ["Array" ]
0 commit comments