11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Union , Optional , Literal
3+ from collections .abc import Sequence
4+ from typing import Union , Optional , Literal
45
5- if TYPE_CHECKING :
6- from ._typing import Device , ndarray , DType
7- from collections .abc import Sequence
6+ from ._typing import Device , Array , Dtype , Namespace
87
98# Note: NumPy fft functions improperly upcast float32 and complex64 to
109# complex128, which is why we require wrapping them all here.
1110
1211def fft (
13- x : ndarray ,
12+ x : Array ,
1413 / ,
15- xp ,
14+ xp : Namespace ,
1615 * ,
1716 n : Optional [int ] = None ,
1817 axis : int = - 1 ,
1918 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
20- ) -> ndarray :
19+ ) -> Array :
2120 res = xp .fft .fft (x , n = n , axis = axis , norm = norm )
2221 if x .dtype in [xp .float32 , xp .complex64 ]:
2322 return res .astype (xp .complex64 )
2423 return res
2524
2625def ifft (
27- x : ndarray ,
26+ x : Array ,
2827 / ,
29- xp ,
28+ xp : Namespace ,
3029 * ,
3130 n : Optional [int ] = None ,
3231 axis : int = - 1 ,
3332 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
34- ) -> ndarray :
33+ ) -> Array :
3534 res = xp .fft .ifft (x , n = n , axis = axis , norm = norm )
3635 if x .dtype in [xp .float32 , xp .complex64 ]:
3736 return res .astype (xp .complex64 )
3837 return res
3938
4039def fftn (
41- x : ndarray ,
40+ x : Array ,
4241 / ,
43- xp ,
42+ xp : Namespace ,
4443 * ,
4544 s : Sequence [int ] = None ,
4645 axes : Sequence [int ] = None ,
4746 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
48- ) -> ndarray :
47+ ) -> Array :
4948 res = xp .fft .fftn (x , s = s , axes = axes , norm = norm )
5049 if x .dtype in [xp .float32 , xp .complex64 ]:
5150 return res .astype (xp .complex64 )
5251 return res
5352
5453def ifftn (
55- x : ndarray ,
54+ x : Array ,
5655 / ,
57- xp ,
56+ xp : Namespace ,
5857 * ,
5958 s : Sequence [int ] = None ,
6059 axes : Sequence [int ] = None ,
6160 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
62- ) -> ndarray :
61+ ) -> Array :
6362 res = xp .fft .ifftn (x , s = s , axes = axes , norm = norm )
6463 if x .dtype in [xp .float32 , xp .complex64 ]:
6564 return res .astype (xp .complex64 )
6665 return res
6766
6867def rfft (
69- x : ndarray ,
68+ x : Array ,
7069 / ,
71- xp ,
70+ xp : Namespace ,
7271 * ,
7372 n : Optional [int ] = None ,
7473 axis : int = - 1 ,
7574 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
76- ) -> ndarray :
75+ ) -> Array :
7776 res = xp .fft .rfft (x , n = n , axis = axis , norm = norm )
7877 if x .dtype == xp .float32 :
7978 return res .astype (xp .complex64 )
8079 return res
8180
8281def irfft (
83- x : ndarray ,
82+ x : Array ,
8483 / ,
85- xp ,
84+ xp : Namespace ,
8685 * ,
8786 n : Optional [int ] = None ,
8887 axis : int = - 1 ,
8988 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
90- ) -> ndarray :
89+ ) -> Array :
9190 res = xp .fft .irfft (x , n = n , axis = axis , norm = norm )
9291 if x .dtype == xp .complex64 :
9392 return res .astype (xp .float32 )
9493 return res
9594
9695def rfftn (
97- x : ndarray ,
96+ x : Array ,
9897 / ,
99- xp ,
98+ xp : Namespace ,
10099 * ,
101100 s : Sequence [int ] = None ,
102101 axes : Sequence [int ] = None ,
103102 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
104- ) -> ndarray :
103+ ) -> Array :
105104 res = xp .fft .rfftn (x , s = s , axes = axes , norm = norm )
106105 if x .dtype == xp .float32 :
107106 return res .astype (xp .complex64 )
108107 return res
109108
110109def irfftn (
111- x : ndarray ,
110+ x : Array ,
112111 / ,
113- xp ,
112+ xp : Namespace ,
114113 * ,
115114 s : Sequence [int ] = None ,
116115 axes : Sequence [int ] = None ,
117116 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
118- ) -> ndarray :
117+ ) -> Array :
119118 res = xp .fft .irfftn (x , s = s , axes = axes , norm = norm )
120119 if x .dtype == xp .complex64 :
121120 return res .astype (xp .float32 )
122121 return res
123122
124123def hfft (
125- x : ndarray ,
124+ x : Array ,
126125 / ,
127- xp ,
126+ xp : Namespace ,
128127 * ,
129128 n : Optional [int ] = None ,
130129 axis : int = - 1 ,
131130 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
132- ) -> ndarray :
131+ ) -> Array :
133132 res = xp .fft .hfft (x , n = n , axis = axis , norm = norm )
134133 if x .dtype in [xp .float32 , xp .complex64 ]:
135134 return res .astype (xp .float32 )
136135 return res
137136
138137def ihfft (
139- x : ndarray ,
138+ x : Array ,
140139 / ,
141- xp ,
140+ xp : Namespace ,
142141 * ,
143142 n : Optional [int ] = None ,
144143 axis : int = - 1 ,
145144 norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
146- ) -> ndarray :
145+ ) -> Array :
147146 res = xp .fft .ihfft (x , n = n , axis = axis , norm = norm )
148147 if x .dtype in [xp .float32 , xp .complex64 ]:
149148 return res .astype (xp .complex64 )
@@ -152,12 +151,12 @@ def ihfft(
152151def fftfreq (
153152 n : int ,
154153 / ,
155- xp ,
154+ xp : Namespace ,
156155 * ,
157156 d : float = 1.0 ,
158- dtype : Optional [DType ] = None ,
159- device : Optional [Device ] = None
160- ) -> ndarray :
157+ dtype : Optional [Dtype ] = None ,
158+ device : Optional [Device ] = None ,
159+ ) -> Array :
161160 if device not in ["cpu" , None ]:
162161 raise ValueError (f"Unsupported device { device !r} " )
163162 res = xp .fft .fftfreq (n , d = d )
@@ -168,23 +167,27 @@ def fftfreq(
168167def rfftfreq (
169168 n : int ,
170169 / ,
171- xp ,
170+ xp : Namespace ,
172171 * ,
173172 d : float = 1.0 ,
174- dtype : Optional [DType ] = None ,
175- device : Optional [Device ] = None
176- ) -> ndarray :
173+ dtype : Optional [Dtype ] = None ,
174+ device : Optional [Device ] = None ,
175+ ) -> Array :
177176 if device not in ["cpu" , None ]:
178177 raise ValueError (f"Unsupported device { device !r} " )
179178 res = xp .fft .rfftfreq (n , d = d )
180179 if dtype is not None :
181180 return res .astype (dtype )
182181 return res
183182
184- def fftshift (x : ndarray , / , xp , * , axes : Union [int , Sequence [int ]] = None ) -> ndarray :
183+ def fftshift (
184+ x : Array , / , xp : Namespace , * , axes : Union [int , Sequence [int ]] = None
185+ ) -> Array :
185186 return xp .fft .fftshift (x , axes = axes )
186187
187- def ifftshift (x : ndarray , / , xp , * , axes : Union [int , Sequence [int ]] = None ) -> ndarray :
188+ def ifftshift (
189+ x : Array , / , xp : Namespace , * , axes : Union [int , Sequence [int ]] = None
190+ ) -> Array :
188191 return xp .fft .ifftshift (x , axes = axes )
189192
190193__all__ = [
0 commit comments