3333#include " NumCpp/Core/Internal/StaticAsserts.hpp"
3434#include " NumCpp/Core/Internal/StlAlgorithms.hpp"
3535#include " NumCpp/Core/Types.hpp"
36- #include " NumCpp/FFT/ifft .hpp"
36+ #include " NumCpp/FFT/ifft2 .hpp"
3737#include " NumCpp/Functions/complex.hpp"
3838#include " NumCpp/Functions/real.hpp"
3939#include " NumCpp/NdArray.hpp"
@@ -49,9 +49,49 @@ namespace nc::fft
4949 // / @param x the data
5050 // / @param shape Shape (length of each transformed axis) of the output
5151 // /
52- inline NdArray<double > internal_irfft2 (const NdArray<std::complex <double >>& x, const Shape& inShape )
52+ inline NdArray<double > irfft2_internal (const NdArray<std::complex <double >>& x, const Shape& shape )
5353 {
54- return {};
54+ if (x.size () == 0 || shape.rows == 0 || shape.cols == 0 )
55+ {
56+ return {};
57+ }
58+
59+ const auto necessaryInputPoints = shape.cols / 2 + 1 ;
60+ auto input = NdArray<std::complex <double >>{};
61+ if (x.numCols () > necessaryInputPoints)
62+ {
63+ input = x (x.rSlice (), Slice (necessaryInputPoints + 1 ));
64+ }
65+ else if (x.numCols () < necessaryInputPoints)
66+ {
67+ input = NdArray<std::complex <double >>(shape.rows , necessaryInputPoints).zeros ();
68+ input.put (x.rSlice (), x.cSlice (), x);
69+ }
70+ else
71+ {
72+ input = x;
73+ }
74+
75+ auto realN = 2 * (input.numCols () - 1 );
76+ realN += shape.cols % 2 == 1 ? 1 : 0 ;
77+ auto fullOutput = NdArray<std::complex <double >>(shape.rows , realN).zeros ();
78+ for (auto row = 0u ; row < input.numRows (); ++row)
79+ {
80+ stl_algorithms::copy (input.begin (row), input.end (row), fullOutput.begin (row));
81+ }
82+ stl_algorithms::transform (fullOutput.begin (0 ) + 1 ,
83+ fullOutput.begin (0 ) + input.numCols (),
84+ fullOutput.rbegin (0 ),
85+ [](const auto & value) { return std::conj (value); });
86+ for (auto col = 1u ; col < input.numCols (); ++col)
87+ {
88+ stl_algorithms::transform (input.colbegin (col) + 1 ,
89+ input.colend (col),
90+ fullOutput.rcolbegin (fullOutput.numCols () - col),
91+ [](const auto & value) { return std::conj (value); });
92+ }
93+
94+ return real (ifft2_internal (fullOutput, shape));
5595 }
5696 } // namespace detail
5797
@@ -71,7 +111,8 @@ namespace nc::fft
71111 {
72112 STATIC_ASSERT_ARITHMETIC (dtype);
73113
74- return {};
114+ const auto data = nc::complex <dtype, double >(inArray);
115+ return detail::irfft2_internal (data, inShape);
75116 }
76117
77118 // ============================================================================
@@ -89,6 +130,8 @@ namespace nc::fft
89130 {
90131 STATIC_ASSERT_ARITHMETIC (dtype);
91132
92- return irfft2 (inArray, inArray.shape ());
133+ const auto & shape = inArray.shape ();
134+ const auto newCols = 2 * (shape.cols - 1 );
135+ return irfft2 (inArray, { shape.rows , newCols });
93136 }
94137} // namespace nc::fft
0 commit comments