Skip to content

Commit cda8763

Browse files
committed
added fft2 and ifft2
1 parent 815bea2 commit cda8763

File tree

8 files changed

+1059
-920
lines changed

8 files changed

+1059
-920
lines changed

include/NumCpp/FFT/fft.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,11 @@ namespace nc::fft
6363
const auto k = static_cast<double>(&resultElement - result.data());
6464
const auto minusTwoPiKOverN = -constants::twoPi * k / static_cast<double>(n);
6565
resultElement = std::complex<double>{ 0., 0. };
66-
std::for_each(x.begin(),
67-
x.begin() + std::min(n, x.size()),
68-
[minusTwoPiKOverN, &resultElement, &x](const auto& value)
69-
{
70-
const auto m = static_cast<double>(&value - x.data());
71-
const auto angle = minusTwoPiKOverN * m;
72-
resultElement += (value * std::polar(1., angle));
73-
});
66+
for (auto m = 0u; m < std::min(n, x.size()); ++m)
67+
{
68+
const auto angle = minusTwoPiKOverN * static_cast<double>(m);
69+
resultElement += (x[m] * std::polar(1., angle));
70+
}
7471
});
7572

7673
return result;

include/NumCpp/FFT/fft2.hpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,35 @@ namespace nc::fft
4646
///
4747
inline NdArray<std::complex<double>> fft2_internal(const NdArray<std::complex<double>>& x, const Shape& shape)
4848
{
49-
return {};
49+
if (shape.rows == 0 || shape.cols == 0)
50+
{
51+
return {};
52+
}
53+
54+
auto result = NdArray<std::complex<double>>(shape.rows, shape.cols);
55+
56+
stl_algorithms::for_each(result.begin(),
57+
result.end(),
58+
[&](auto& resultElement)
59+
{
60+
const auto i = &resultElement - result.data();
61+
const auto k = static_cast<double>(i / shape.cols);
62+
const auto l = static_cast<double>(i % shape.cols);
63+
resultElement = std::complex<double>{ 0., 0. };
64+
for (auto m = 0u; m < std::min(shape.rows, x.numRows()); ++m)
65+
{
66+
for (auto n = 0u; n < std::min(shape.cols, x.numCols()); ++n)
67+
{
68+
const auto angle =
69+
-constants::twoPi *
70+
(((static_cast<double>(m) * k) / static_cast<double>(shape.rows)) +
71+
((static_cast<double>(n) * l) / static_cast<double>(shape.cols)));
72+
resultElement += (x(m, n) * std::polar(1., angle));
73+
}
74+
}
75+
});
76+
77+
return result;
5078
}
5179
} // namespace detail
5280

@@ -66,7 +94,8 @@ namespace nc::fft
6694
{
6795
STATIC_ASSERT_ARITHMETIC(dtype);
6896

69-
return {};
97+
const auto data = nc::complex<dtype, double>(inArray);
98+
return detail::fft2_internal(data, inShape);
7099
}
71100

72101
//===========================================================================
@@ -103,7 +132,8 @@ namespace nc::fft
103132
{
104133
STATIC_ASSERT_ARITHMETIC(dtype);
105134

106-
return {};
135+
const auto data = nc::complex<dtype, double>(inArray);
136+
return detail::fft2_internal(data, inShape);
107137
}
108138

109139
//============================================================================

include/NumCpp/FFT/ifft.hpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace nc::fft
4242
{
4343
//===========================================================================
4444
// Method Description:
45-
/// Fast Fourier Transform
45+
/// Inverse Fast Fourier Transform
4646
///
4747
/// @param x the data
4848
/// @param n Length of the transformed axis of the output.
@@ -63,14 +63,12 @@ namespace nc::fft
6363
const auto m = static_cast<double>(&resultElement - result.data());
6464
const auto minusTwoPiKOverN = constants::twoPi * m / static_cast<double>(n);
6565
resultElement = std::complex<double>{ 0., 0. };
66-
std::for_each(x.begin(),
67-
x.begin() + std::min(n, x.size()),
68-
[minusTwoPiKOverN, &resultElement, &x](const auto& value)
69-
{
70-
const auto k = static_cast<double>(&value - x.data());
71-
const auto angle = minusTwoPiKOverN * k;
72-
resultElement += (value * std::polar(1., angle));
73-
});
66+
for (auto k = 0u; k < std::min(n, x.size()); ++k)
67+
{
68+
const auto angle = minusTwoPiKOverN * static_cast<double>(k);
69+
resultElement += (x[k] * std::polar(1., angle));
70+
}
71+
7472
resultElement /= n;
7573
});
7674

include/NumCpp/FFT/ifft2.hpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,43 @@ namespace nc::fft
3939
{
4040
//===========================================================================
4141
// Method Description:
42-
/// Fast Fourier Transform
42+
/// Inverse Fast Fourier Transform
4343
///
4444
/// @param x the data
4545
/// @param shape Shape (length of each transformed axis) of the output
4646
///
4747
inline NdArray<std::complex<double>> ifft2_internal(const NdArray<std::complex<double>>& x, const Shape& shape)
4848
{
49-
return {};
49+
if (shape.rows == 0 || shape.cols == 0)
50+
{
51+
return {};
52+
}
53+
54+
auto result = NdArray<std::complex<double>>(shape.rows, shape.cols);
55+
56+
stl_algorithms::for_each(result.begin(),
57+
result.end(),
58+
[&](auto& resultElement)
59+
{
60+
const auto i = &resultElement - result.data();
61+
const auto m = static_cast<double>(i / shape.cols);
62+
const auto n = static_cast<double>(i % shape.cols);
63+
resultElement = std::complex<double>{ 0., 0. };
64+
for (auto k = 0u; k < std::min(shape.rows, x.numRows()); ++k)
65+
{
66+
for (auto l = 0u; l < std::min(shape.cols, x.numCols()); ++l)
67+
{
68+
const auto angle =
69+
constants::twoPi *
70+
(((static_cast<double>(k) * m) / static_cast<double>(shape.rows)) +
71+
((static_cast<double>(l) * n) / static_cast<double>(shape.cols)));
72+
resultElement += (x(k, l) * std::polar(1., angle));
73+
}
74+
}
75+
resultElement /= shape.size();
76+
});
77+
78+
return result;
5079
}
5180
} // namespace detail
5281

@@ -62,11 +91,12 @@ namespace nc::fft
6291
/// @return NdArray
6392
///
6493
template<typename dtype>
65-
NdArray<double> ifft2(const NdArray<dtype>& inArray, const Shape& inShape)
94+
NdArray<std::complex<double>> ifft2(const NdArray<dtype>& inArray, const Shape& inShape)
6695
{
6796
STATIC_ASSERT_ARITHMETIC(dtype);
6897

69-
return {};
98+
const auto data = nc::complex<dtype, double>(inArray);
99+
return detail::ifft2_internal(data, inShape);
70100
}
71101

72102
//===========================================================================
@@ -80,7 +110,7 @@ namespace nc::fft
80110
/// @return NdArray
81111
///
82112
template<typename dtype>
83-
NdArray<double> ifft2(const NdArray<dtype>& inArray)
113+
NdArray<std::complex<double>> ifft2(const NdArray<dtype>& inArray)
84114
{
85115
STATIC_ASSERT_ARITHMETIC(dtype);
86116

@@ -103,7 +133,8 @@ namespace nc::fft
103133
{
104134
STATIC_ASSERT_ARITHMETIC(dtype);
105135

106-
return {};
136+
const auto data = nc::complex<dtype, double>(inArray);
137+
return detail::ifft2_internal(data, inShape);
107138
}
108139

109140
//============================================================================

include/NumCpp/FFT/irfft.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace nc::fft
4444
{
4545
//===========================================================================
4646
// Method Description:
47-
/// Fast Fourier Transform
47+
/// Inverse Fast Fourier Transform
4848
///
4949
/// @param x the data
5050
/// @param n Length of the transformed axis of the output.

include/NumCpp/FFT/irfft2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace nc::fft
4444
{
4545
//===========================================================================
4646
// Method Description:
47-
/// Fast Fourier Transform
47+
/// Inverse Fast Fourier Transform
4848
///
4949
/// @param x the data
5050
/// @param shape Shape (length of each transformed axis) of the output

include/NumCpp/FFT/rfft.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,11 @@ namespace nc::fft
6464
const auto k = static_cast<double>(&resultElement - result.data());
6565
const auto minusTwoPiKOverN = -constants::twoPi * k / static_cast<double>(n);
6666
resultElement = std::complex<double>{ 0., 0. };
67-
std::for_each(x.begin(),
68-
x.begin() + std::min(n, x.size()),
69-
[minusTwoPiKOverN, &resultElement, &x](const auto& value)
70-
{
71-
const auto m = static_cast<double>(&value - x.data());
72-
const auto angle = minusTwoPiKOverN * m;
73-
resultElement += (value * std::polar(1., angle));
74-
});
67+
for (auto m = 0u; m < std::min(n, x.size()); ++m)
68+
{
69+
const auto angle = minusTwoPiKOverN * static_cast<double>(m);
70+
resultElement += (x[m] * std::polar(1., angle));
71+
}
7572
});
7673

7774
return result;

0 commit comments

Comments
 (0)