@@ -44,23 +44,19 @@ def _check_norm(norm):
4444 )
4545
4646
47- def _check_shapes_for_direct (xs , shape , axes ):
47+ def _check_shapes_for_direct (s , shape , axes ):
4848 if len (axes ) > 7 : # Intel MKL supports up to 7D
4949 return False
50- if not ( len (xs ) == len (shape ) ):
51- # full-dimensional transform
50+ if len (s ) != len (shape ):
51+ # not a full-dimensional transform
5252 return False
53- if not ( len (set (axes )) == len (axes ) ):
53+ if len (set (axes )) != len (axes ):
5454 # repeated axes
5555 return False
56- for xsi , ai in zip (xs , axes ):
57- try :
58- sh_ai = shape [ai ]
59- except IndexError :
60- raise ValueError ("Invalid axis (%d) specified" % ai )
61-
62- if not (xsi == sh_ai ):
63- return False
56+ new_shape = tuple (shape [ax ] for ax in axes )
57+ if tuple (s ) != new_shape :
58+ # trimming or padding is needed
59+ return False
6460 return True
6561
6662
@@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape):
7874 return np .sqrt (fsc )
7975
8076
81- def _cook_nd_args (a , s = None , axes = None , invreal = False ):
82- if s is None :
83- shapeless = True
84- if axes is None :
85- s = list (a .shape )
86- else :
87- try :
88- s = [a .shape [i ] for i in axes ]
89- except IndexError :
90- # fake s designed to trip the ValueError further down
91- s = range (len (axes ) + 1 )
92- pass
93- else :
94- shapeless = False
95- s = list (s )
96- if axes is None :
97- axes = list (range (- len (s ), 0 ))
98- if len (s ) != len (axes ):
99- raise ValueError ("Shape and axes have different lengths." )
100- if invreal and shapeless :
101- s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
102- return s , axes
103-
104-
10577# copied from scipy.fft module
10678# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
10779def _datacopied (arr , original ):
@@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape):
129101 return m_ind
130102
131103
132- # copied from scipy.fftpack.helper
133- def _init_nd_shape_and_axes (x , shape , axes ):
134- """Handle shape and axes arguments for n-dimensional transforms.
135- Returns the shape and axes in a standard form, taking into account negative
136- values and checking for various potential errors.
137- Parameters
138- ----------
139- x : array_like
140- The input array.
141- shape : int or array_like of ints or None
142- The shape of the result. If both `shape` and `axes` (see below) are
143- None, `shape` is ``x.shape``; if `shape` is None but `axes` is
144- not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
145- If `shape` is -1, the size of the corresponding dimension of `x` is
146- used.
147- axes : int or array_like of ints or None
148- Axes along which the calculation is computed.
149- The default is over all axes.
150- Negative indices are automatically converted to their positive
151- counterpart.
152- Returns
153- -------
154- shape : array
155- The shape of the result. It is a 1D integer array.
156- axes : array
157- The shape of the result. It is a 1D integer array.
158- """
159- x = np .asarray (x )
160- noshape = shape is None
161- noaxes = axes is None
162-
163- if noaxes :
164- axes = np .arange (x .ndim , dtype = np .intc )
165- else :
166- axes = np .atleast_1d (axes )
167-
168- if axes .size == 0 :
169- axes = axes .astype (np .intc )
170-
171- if not axes .ndim == 1 :
172- raise ValueError ("when given, axes values must be a scalar or vector" )
173- if not np .issubdtype (axes .dtype , np .integer ):
174- raise ValueError ("when given, axes values must be integers" )
175-
176- axes = np .where (axes < 0 , axes + x .ndim , axes )
177-
178- if axes .size != 0 and (axes .max () >= x .ndim or axes .min () < 0 ):
179- raise ValueError ("axes exceeds dimensionality of input" )
180- if axes .size != 0 and np .unique (axes ).shape != axes .shape :
181- raise ValueError ("all axes must be unique" )
182-
183- if not noshape :
184- shape = np .atleast_1d (shape )
185- elif np .isscalar (x ):
186- shape = np .array ([], dtype = np .intc )
187- elif noaxes :
188- shape = np .array (x .shape , dtype = np .intc )
189- else :
190- shape = np .take (x .shape , axes )
191-
192- if shape .size == 0 :
193- shape = shape .astype (np .intc )
194-
195- if shape .ndim != 1 :
196- raise ValueError ("when given, shape values must be a scalar or vector" )
197- if not np .issubdtype (shape .dtype , np .integer ):
198- raise ValueError ("when given, shape values must be integers" )
199- if axes .shape != shape .shape :
200- raise ValueError (
201- "when given, axes and shape arguments have to be of the same length"
202- )
203-
204- shape = np .where (shape == - 1 , np .array (x .shape )[axes ], shape )
205- if shape .size != 0 and (shape < 1 ).any ():
206- raise ValueError (f"invalid number of data points ({ shape } ) specified" )
207-
208- return shape , axes
209-
210-
211104def _iter_complementary (x , axes , func , kwargs , result ):
212- if axes is None :
213- # s and axes are None, direct N-D FFT
214- return func (x , ** kwargs , out = result )
215105 x_shape = x .shape
216106 nd = x .ndim
217107 r = list (range (nd ))
@@ -260,9 +150,6 @@ def _iter_fftnd(
260150 direction = + 1 ,
261151 scale_function = lambda ind : 1.0 ,
262152):
263- a = np .asarray (a )
264- s , axes = _init_nd_shape_and_axes (a , s , axes )
265-
266153 # Combine the two, but in reverse, to end with the first axis given.
267154 axes_and_s = list (zip (axes , s ))[::- 1 ]
268155 # We try to use in-place calculations where possible, which is
@@ -309,13 +196,14 @@ def _output_dtype(dt):
309196def _pad_array (arr , s , axes ):
310197 """Pads array arr with zeros to attain shape s associated with axes"""
311198 arr_shape = arr .shape
199+ new_shape = tuple (arr_shape [ax ] for ax in axes )
200+ if tuple (s ) == new_shape :
201+ return arr
202+
312203 no_padding = True
313204 pad_widths = [(0 , 0 )] * len (arr_shape )
314205 for si , ai in zip (s , axes ):
315- try :
316- shp_i = arr_shape [ai ]
317- except IndexError :
318- raise ValueError (f"Invalid axis { ai } specified" )
206+ shp_i = arr_shape [ai ]
319207 if si > shp_i :
320208 no_padding = False
321209 pad_widths [ai ] = (0 , si - shp_i )
@@ -345,14 +233,14 @@ def _trim_array(arr, s, axes):
345233 """
346234
347235 arr_shape = arr .shape
236+ new_shape = tuple (arr_shape [ax ] for ax in axes )
237+ if tuple (s ) == new_shape :
238+ return arr
239+
348240 no_trim = True
349241 ind = [slice (None , None , None )] * len (arr_shape )
350242 for si , ai in zip (s , axes ):
351- try :
352- shp_i = arr_shape [ai ]
353- except IndexError :
354- raise ValueError (f"Invalid axis { ai } specified" )
355- if si < shp_i :
243+ if si < arr_shape [ai ]:
356244 no_trim = False
357245 ind [ai ] = slice (None , si , None )
358246 if no_trim :
@@ -383,16 +271,11 @@ def _c2c_fftnd_impl(
383271 if direction not in [- 1 , + 1 ]:
384272 raise ValueError ("Direction of FFT should +1 or -1" )
385273
274+ x = np .asarray (x )
386275 valid_dtypes = [np .complex64 , np .complex128 , np .float32 , np .float64 ]
387276 # _direct_fftnd requires complex type, and full-dimensional transform
388- if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
389- _direct = s is None and axes is None
390- if _direct :
391- _direct = x .ndim <= 7 # Intel MKL only supports FFT up to 7D
392- if not _direct :
393- xs , xa = _cook_nd_args (x , s , axes )
394- if _check_shapes_for_direct (xs , x .shape , xa ):
395- _direct = True
277+ if x .size != 0 and x .ndim > 1 :
278+ _direct = _check_shapes_for_direct (s , x .shape , axes )
396279 _direct = _direct and x .dtype in valid_dtypes
397280 else :
398281 _direct = False
@@ -405,14 +288,23 @@ def _c2c_fftnd_impl(
405288 out = out ,
406289 )
407290 else :
408- if s is None and x .dtype in valid_dtypes :
409- x = np .asarray (x )
291+ new_shape = tuple (x .shape [ax ] for ax in axes )
292+ if (
293+ tuple (s ) == new_shape
294+ and x .dtype in valid_dtypes
295+ and len (set (axes )) == len (axes )
296+ ):
410297 if out is None :
411298 res = np .empty_like (x , dtype = _output_dtype (x .dtype ))
412299 else :
413300 _validate_out_array (out , x , _output_dtype (x .dtype ))
414301 res = out
415302
303+ # MKL is capable of doing batch N-D FFT, it is not required to
304+ # manually loop over the batches as done in _iter_complementary and
305+ # it is the reason for bad performance mentioned in the gh-issue-#67
306+ # TODO: implement a batch N-D FFT using MKL
307+ # _iter_complementary performs batches of N-D FFT
416308 return _iter_complementary (
417309 x ,
418310 axes ,
@@ -434,14 +326,9 @@ def _c2c_fftnd_impl(
434326
435327def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
436328 a = np .asarray (x )
437- no_trim = (s is None ) and (axes is None )
438- s , axes = _cook_nd_args (a , s , axes )
439- axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
440329 la = axes [- 1 ]
441-
442330 # trim array, so that rfft avoids doing unnecessary computations
443- if not no_trim :
444- a = _trim_array (a , s , axes )
331+ a = _trim_array (a , s , axes )
445332
446333 # last axis is not included since we calculate r2c FFT separately
447334 # and not in the loop
@@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
453340 a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
454341 res = a
455342 if len (s ) > 1 :
456-
457343 len_axes = len (axes )
458344 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
459- if not no_trim :
460- ss = list (s )
461- ss [- 1 ] = a .shape [la ]
462- a = _pad_array (a , tuple (ss ), axes )
345+ ss = list (s )
346+ ss [- 1 ] = a .shape [la ]
347+ a = _pad_array (a , tuple (ss ), axes )
463348 # a series of ND c2c FFTs along last axis
464349 ss , aa = _remove_axis (s , axes , - 1 )
465350 ind = [slice (None , None , 1 )] * len (s )
@@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
494379
495380def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
496381 a = np .asarray (x )
497- no_trim = (s is None ) and (axes is None )
498- s , axes = _cook_nd_args (a , s , axes , invreal = True )
499- axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
500382 la = axes [- 1 ]
501- if not no_trim :
502- a = _trim_array (a , s , axes )
503383 if len (s ) > 1 :
504384 len_axes = len (axes )
505385 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
506- if not no_trim :
507- a = _pad_array (a , s , axes )
386+ a = _trim_array ( a , s , axes )
387+ a = _pad_array (a , s , axes )
508388 # a series of ND c2c FFTs along last axis
509389 # due to need to write into a, we must copy
510390 a = a if _datacopied (a , x ) else a .copy ()
@@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
521401 tind = tuple (ind )
522402 a_inp = a [tind ]
523403 # out has real dtype and cannot be used in intermediate steps
524- # ss and aa are reversed since np.irfftn uses forward order but
525- # np .ifftn uses reverse order see numpy-gh-28950
404+ # ss and aa are reversed since np.fft. irfftn uses forward order
405+ # but np.fft .ifftn uses reverse order see numpy-gh-28950
526406 _ = _c2c_fftnd_impl (
527407 a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
528408 )
0 commit comments