Skip to content

Commit ebf088a

Browse files
committed
implemented irfft2
1 parent d0d89e8 commit ebf088a

File tree

4 files changed

+66
-12
lines changed

4 files changed

+66
-12
lines changed

include/NumCpp/Core/Constants.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace nc::constants
3636
constexpr double c = 3.e8; ///< speed of light
3737
constexpr double e = 2.718281828459045; ///< eulers number
3838
constexpr double inf = std::numeric_limits<double>::infinity(); ///< infinity
39-
constexpr double pi = 3.141592653589793238462643383279502884; ///< Pi
39+
constexpr double pi = 3.141592653589793; ///< Pi
4040
constexpr double twoPi = 2. * pi; ///< 2Pi
4141
const double nan = std::nan("1"); ///< NaN
4242
constexpr auto j = std::complex<double>(0., 1.); // sqrt(-1) unit imaginary number

include/NumCpp/FFT/irfft.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ namespace nc::fft
5656
return {};
5757
}
5858

59-
const auto isOdd = n % 2 == 1;
60-
6159
const auto necessaryInputPoints = n / 2 + 1;
6260
auto input = NdArray<std::complex<double>>{};
6361
if (x.size() > necessaryInputPoints)
@@ -75,7 +73,7 @@ namespace nc::fft
7573
}
7674

7775
auto realN = 2 * (input.size() - 1);
78-
realN += isOdd ? 1 : 0;
76+
realN += n % 2 == 1 ? 1 : 0;
7977
auto fullOutput = NdArray<std::complex<double>>(1, realN);
8078
stl_algorithms::copy(input.begin(), input.end(), fullOutput.begin());
8179
stl_algorithms::transform(fullOutput.begin() + 1,

include/NumCpp/FFT/irfft2.hpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
3434
#include "NumCpp/Core/Internal/StlAlgorithms.hpp"
3535
#include "NumCpp/Core/Types.hpp"
36-
#include "NumCpp/FFT/ifft.hpp"
36+
#include "NumCpp/FFT/ifft2.hpp"
3737
#include "NumCpp/Functions/complex.hpp"
3838
#include "NumCpp/Functions/real.hpp"
3939
#include "NumCpp/NdArray.hpp"
@@ -49,9 +49,49 @@ namespace nc::fft
4949
/// @param x the data
5050
/// @param shape Shape (length of each transformed axis) of the output
5151
///
52-
inline NdArray<double> internal_irfft2(const NdArray<std::complex<double>>& x, const Shape& inShape)
52+
inline NdArray<double> irfft2_internal(const NdArray<std::complex<double>>& x, const Shape& shape)
5353
{
54-
return {};
54+
if (x.size() == 0 || shape.rows == 0 || shape.cols == 0)
55+
{
56+
return {};
57+
}
58+
59+
const auto necessaryInputPoints = shape.cols / 2 + 1;
60+
auto input = NdArray<std::complex<double>>{};
61+
if (x.numCols() > necessaryInputPoints)
62+
{
63+
input = x(x.rSlice(), Slice(necessaryInputPoints + 1));
64+
}
65+
else if (x.numCols() < necessaryInputPoints)
66+
{
67+
input = NdArray<std::complex<double>>(shape.rows, necessaryInputPoints).zeros();
68+
input.put(x.rSlice(), x.cSlice(), x);
69+
}
70+
else
71+
{
72+
input = x;
73+
}
74+
75+
auto realN = 2 * (input.numCols() - 1);
76+
realN += shape.cols % 2 == 1 ? 1 : 0;
77+
auto fullOutput = NdArray<std::complex<double>>(shape.rows, realN).zeros();
78+
for (auto row = 0u; row < input.numRows(); ++row)
79+
{
80+
stl_algorithms::copy(input.begin(row), input.end(row), fullOutput.begin(row));
81+
}
82+
stl_algorithms::transform(fullOutput.begin(0) + 1,
83+
fullOutput.begin(0) + input.numCols(),
84+
fullOutput.rbegin(0),
85+
[](const auto& value) { return std::conj(value); });
86+
for (auto col = 1u; col < input.numCols(); ++col)
87+
{
88+
stl_algorithms::transform(input.colbegin(col) + 1,
89+
input.colend(col),
90+
fullOutput.rcolbegin(fullOutput.numCols() - col),
91+
[](const auto& value) { return std::conj(value); });
92+
}
93+
94+
return real(ifft2_internal(fullOutput, shape));
5595
}
5696
} // namespace detail
5797

@@ -71,7 +111,8 @@ namespace nc::fft
71111
{
72112
STATIC_ASSERT_ARITHMETIC(dtype);
73113

74-
return {};
114+
const auto data = nc::complex<dtype, double>(inArray);
115+
return detail::irfft2_internal(data, inShape);
75116
}
76117

77118
//============================================================================
@@ -89,6 +130,8 @@ namespace nc::fft
89130
{
90131
STATIC_ASSERT_ARITHMETIC(dtype);
91132

92-
return irfft2(inArray, inArray.shape());
133+
const auto& shape = inArray.shape();
134+
const auto newCols = 2 * (shape.cols - 1);
135+
return irfft2(inArray, { shape.rows, newCols });
93136
}
94137
} // namespace nc::fft

test/pytest/test_fft.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import NumCppPy as NumCpp # noqa E402
44

5-
NUM_TRIALS = 1
6-
ROUNDING_DIGITS = 6
5+
NUM_TRIALS = 5
6+
ROUNDING_DIGITS = 5
77

88

99
####################################################################################
@@ -1080,4 +1080,17 @@ def test_rfft2():
10801080

10811081
####################################################################################
10821082
def test_irfft2():
1083-
assert False
1083+
for _ in range(NUM_TRIALS):
1084+
shapeInput = np.random.randint(
1085+
10,
1086+
30,
1087+
[
1088+
2,
1089+
],
1090+
)
1091+
data = np.random.randint(0, 100, shapeInput)
1092+
rfft2 = np.fft.rfft2(data)
1093+
cShape = NumCpp.Shape(*rfft2.shape)
1094+
cArray = NumCpp.NdArrayComplexDouble(cShape)
1095+
cArray.setArray(rfft2)
1096+
assert np.array_equal(np.round(NumCpp.irfft2(cArray), ROUNDING_DIGITS), np.round(np.fft.irfft2(rfft2), ROUNDING_DIGITS))

0 commit comments

Comments
 (0)