@@ -29,7 +29,7 @@ import sys
2929
3030import numpy as np
3131
32- if np.lib.NumpyVersion(np.__version__) >= " 2.0.0a0 " :
32+ if np.lib.NumpyVersion(np.__version__) >= " 2.0.0 " :
3333 from numpy._core._multiarray_tests import internal_overlap
3434else :
3535 from numpy.core._multiarray_tests import internal_overlap
@@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
389389 x_arr = _process_arguments(x, n, axis, & axis_, & n_, & in_place, & xnd, 0 )
390390 x_type = cnp.PyArray_TYPE(x_arr)
391391
392- if out is not None :
393- in_place = 0
394- elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
392+ if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
395393 # we can operate in place if requested.
396394 if in_place:
397395 if not cnp.PyArray_ISONESEGMENT(x_arr):
@@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
416414 x_type = cnp.PyArray_TYPE(x_arr)
417415 in_place = 1
418416
417+ f_arr = None
418+ if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
419+ f_type = cnp.NPY_CFLOAT
420+ else :
421+ f_type = cnp.NPY_CDOUBLE
422+
423+ if out is not None :
424+ out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
425+ _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
426+ if x is out:
427+ in_place = 1
428+ elif (
429+ _get_element_strides(x) == _get_element_strides(out)
430+ and not np.shares_memory(x, out)
431+ ):
432+ # out array that is used in OneMKL c2c FFT must have the same stride
433+ # as input array and must have no common elements with input array.
434+ # If these conditions are not met, we need to allocate a new array,
435+ # which is done later.
436+ # TODO: check to see if the same stride condition can be relaxed
437+ f_arr = < cnp.ndarray> out
438+ in_place = 0
439+
419440 if in_place:
420441 _cache_capsule = _tls_dfti_cache_capsule()
421442 _cache = < DftiCache * > cpython.pycapsule.PyCapsule_GetPointer(
@@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
453474 ind[axis_] = slice (0 , n_, None )
454475 x_arr = x_arr[tuple (ind)]
455476
456- return x_arr
457- else :
458- if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
459- f_type = cnp.NPY_CFLOAT
477+ if out is not None :
478+ out[...] = x_arr
479+ return out
460480 else :
461- f_type = cnp.NPY_CDOUBLE
462-
463- if out is None :
481+ return x_arr
482+ else :
483+ if f_arr is None :
464484 f_arr = _allocate_result(x_arr, n_, axis_, f_type)
465- else :
466- out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
467- _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
468- # out array that is used in OneMKL c2c FFT must have the exact same
469- # stride as input array. If not, we need to allocate a new array.
470- # TODO: check to see if this condition can be relaxed
471- if _get_element_strides(x) == _get_element_strides(out):
472- f_arr = < cnp.ndarray> out
473- else :
474- f_arr = _allocate_result(x_arr, n_, axis_, f_type)
475485
476486 # call out-of-place FFT
477487 _cache_capsule = _tls_dfti_cache_capsule()
@@ -612,9 +622,10 @@ def _r2c_fft1d_impl(
612622 # be compared directly.
613623 # TODO: currently instead of this condition, we check both input
614624 # and output to be c_contig or f_contig, relax this condition
625+ # In addition, input and output data sets must have no common elements; this holds for both the C- and F-contiguous cases, so the contiguity test is parenthesized (`and` binds tighter than `or`)
615626 c_contig = x.flags.c_contiguous and out.flags.c_contiguous
616627 f_contig = x.flags.f_contiguous and out.flags.f_contiguous
617- if c_contig or f_contig:
628+ if (c_contig or f_contig) and not np.shares_memory(x, out):
618629 f_arr = < cnp.ndarray> out
619630 else :
620631 f_arr = _allocate_result(x_arr, f_shape, axis_, f_type)
@@ -715,9 +726,10 @@ def _c2r_fft1d_impl(
715726 # strides cannot be compared directly.
716727 # TODO: currently instead of this condition, we check both input
717728 # and output to be c_contig or f_contig, relax this condition
729+ # Also input and output data sets must have no common elements; this holds for both the C- and F-contiguous cases, so the contiguity test is parenthesized (`and` binds tighter than `or`)
718730 c_contig = x.flags.c_contiguous and out.flags.c_contiguous
719731 f_contig = x.flags.f_contiguous and out.flags.f_contiguous
720- if c_contig or f_contig:
732+ if (c_contig or f_contig) and not np.shares_memory(x, out):
721733 f_arr = < cnp.ndarray> out
722734 else :
723735 f_arr = _allocate_result(x_arr, n_, axis_, f_type)
@@ -755,13 +767,13 @@ def _c2r_fft1d_impl(
755767
756768
757769def _direct_fftnd (
758- x , direction = + 1 , double fsc = 1.0 , out = None
770+ x , direction = + 1 , double fsc = 1.0 , in_place = 0 , out = None
759771):
760772 """ Perform n-dimensional FFT over all axes"""
761773 cdef int err
762774 cdef cnp.ndarray x_arr " xxnd_arrayObject"
763775 cdef cnp.ndarray f_arr " ffnd_arrayObject"
764- cdef int in_place, x_type, f_type
776+ cdef int x_type, f_type
765777
766778 if direction not in [- 1 , + 1 ]:
767779 raise ValueError (" Direction of FFT should +1 or -1" )
@@ -779,7 +791,7 @@ def _direct_fftnd(
779791 raise ValueError (" An input argument x is not an array-like object" )
780792
781793 # a copy was made, so we can work in place.
782- in_place = 1 if _datacopied(x_arr, x) else 0
794+ in_place = 1 if _datacopied(x_arr, x) else in_place
783795
784796 x_type = cnp.PyArray_TYPE(x_arr)
785797 if (
@@ -798,15 +810,35 @@ def _direct_fftnd(
798810 assert x_type == cnp.NPY_CDOUBLE
799811 in_place = 1
800812
801- if out is not None :
802- in_place = 0
803-
804813 if in_place:
805814 if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT:
806815 in_place = 1
807816 else :
808817 in_place = 0
809818
819+ f_arr = None
820+ if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
821+ f_type = cnp.NPY_CDOUBLE
822+ else :
823+ f_type = cnp.NPY_CFLOAT
824+
825+ if out is not None :
826+ out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
827+ _validate_out_array(out, x, out_dtype)
828+ if x is out:
829+ in_place = 1
830+ elif (
831+ _get_element_strides(x) == _get_element_strides(out)
832+ and not np.shares_memory(x, out)
833+ ):
834+ # out array that is used in OneMKL c2c FFT must have the same stride
835+ # as input array and must have no common elements with input array.
836+ # If these conditions are not met, we need to allocate a new array,
837+ # which is done later.
838+ # TODO: check to see if the same stride condition can be relaxed
839+ f_arr = < cnp.ndarray> out
840+ in_place = 0
841+
810842 if in_place:
811843 if x_type == cnp.NPY_CDOUBLE:
812844 if direction == 1 :
@@ -821,24 +853,14 @@ def _direct_fftnd(
821853 else :
822854 raise ValueError (" An input argument x is not complex type array" )
823855
824- return x_arr
825- else :
826- if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
827- f_type = cnp.NPY_CDOUBLE
856+ if out is not None :
857+ out[...] = x_arr
858+ return out
828859 else :
829- f_type = cnp.NPY_CFLOAT
830- if out is None :
860+ return x_arr
861+ else :
862+ if f_arr is None :
831863 f_arr = _allocate_result(x_arr, - 1 , 0 , f_type)
832- else :
833- out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
834- _validate_out_array(out, x, out_dtype)
835- # out array that is used in OneMKL c2c FFT must have the exact same
836- # stride as input array. If not, we need to allocate a new array.
837- # TODO: check to see if this condition can be relaxed
838- if _get_element_strides(x) == _get_element_strides(out):
839- f_arr = < cnp.ndarray> out
840- else :
841- f_arr = _allocate_result(x_arr, - 1 , 0 , f_type)
842864
843865 if x_type == cnp.NPY_CDOUBLE:
844866 if direction == 1 :
0 commit comments