Skip to content

Commit d239426

Browse files
committed
fft now implemented
1 parent 69ea032 commit d239426

File tree

7 files changed

+459
-29
lines changed

7 files changed

+459
-29
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
**Compilers:**
2828
Visual Studio: 2022
29-
GNU: 13.3, 14.2
30-
Clang: 18, 19
29+
GNU: 13.3, 14.2, 15.2
30+
Clang: 18, 19, 20
3131

3232
**Boost Versions:**
3333
1.73+

include/NumCpp/Core/Constants.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace nc::constants
3939
constexpr double pi = 3.141592653589793238462643383279502884; ///< Pi
4040
constexpr double twoPi = 2. * pi; ///< 2Pi
4141
const double nan = std::nan("1"); ///< NaN
42-
constexpr auto j = std::complex<double>(0, 1); // sqrt(-1) unit imaginary number
42+
constexpr auto j = std::complex<double>(0., 1.); // sqrt(-1) unit imaginary number
4343

44-
constexpr double DAYS_PER_WEEK = 7; ///< Number of days in a week
45-
constexpr double MINUTES_PER_HOUR = 60; ///< Number of minutes in an hour
46-
constexpr double SECONDS_PER_MINUTE = 60; ///< Number of seconds in a minute
47-
constexpr double MILLISECONDS_PER_SECOND = 1000; ///< Number of milliseconds in a second
44+
constexpr double DAYS_PER_WEEK = 7.; ///< Number of days in a week
45+
constexpr double MINUTES_PER_HOUR = 60.; ///< Number of minutes in an hour
46+
constexpr double SECONDS_PER_MINUTE = 60.; ///< Number of seconds in a minute
47+
constexpr double MILLISECONDS_PER_SECOND = 1000.; ///< Number of milliseconds in a second
4848
constexpr double SECONDS_PER_HOUR = MINUTES_PER_HOUR * SECONDS_PER_MINUTE; ///< Number of seconds in an hour
49-
constexpr double HOURS_PER_DAY = 24; ///< Number of hours in a day
49+
constexpr double HOURS_PER_DAY = 24.; ///< Number of hours in a day
5050
constexpr double MINUTES_PER_DAY = HOURS_PER_DAY * MINUTES_PER_HOUR; ///< Number of minutes in a day
5151
constexpr double SECONDS_PER_DAY = MINUTES_PER_DAY * SECONDS_PER_MINUTE; ///< Number of seconds in a day
5252
constexpr double MILLISECONDS_PER_DAY =

include/NumCpp/Functions/complex.hpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ namespace nc
4343
/// @param inReal: the real component of the complex number
4444
/// @return value
4545
///
46-
template<typename dtype>
46+
template<typename dtype, typename dtypeOut = dtype>
4747
auto complex(dtype inReal)
4848
{
4949
STATIC_ASSERT_ARITHMETIC(dtype);
50+
STATIC_ASSERT_ARITHMETIC(dtypeOut);
5051

51-
return std::complex<dtype>(inReal);
52+
return std::complex<dtypeOut>(inReal);
5253
}
5354

5455
//============================================================================
@@ -59,12 +60,13 @@ namespace nc
5960
/// @param inImag: the imaginary component of the complex number
6061
/// @return value
6162
///
62-
template<typename dtype>
63+
template<typename dtype, typename dtypeOut = dtype>
6364
auto complex(dtype inReal, dtype inImag)
6465
{
6566
STATIC_ASSERT_ARITHMETIC(dtype);
67+
STATIC_ASSERT_ARITHMETIC(dtypeOut);
6668

67-
return std::complex<dtype>(inReal, inImag);
69+
return std::complex<dtypeOut>(inReal, inImag);
6870
}
6971

7072
//============================================================================
@@ -74,14 +76,14 @@ namespace nc
7476
/// @param inReal: the real component of the complex number
7577
/// @return NdArray
7678
///
77-
template<typename dtype>
79+
template<typename dtype, typename dtypeOut = dtype, std::enable_if_t<std::is_arithmetic_v<dtype>, int> = 0>
7880
auto complex(const NdArray<dtype>& inReal)
7981
{
80-
NdArray<decltype(nc::complex(dtype{ 0 }))> returnArray(inReal.shape());
82+
NdArray<decltype(nc::complex(dtypeOut{ 0 }))> returnArray(inReal.shape());
8183
stl_algorithms::transform(inReal.cbegin(),
8284
inReal.cend(),
8385
returnArray.begin(),
84-
[](dtype real) -> auto { return nc::complex(real); });
86+
[](dtype real) -> auto { return nc::complex<dtype, dtypeOut>(real); });
8587

8688
return returnArray;
8789
}
@@ -94,21 +96,38 @@ namespace nc
9496
/// @param inImag: the imaginary component of the complex number
9597
/// @return NdArray
9698
///
97-
template<typename dtype>
99+
template<typename dtype, typename dtypeOut = dtype>
98100
auto complex(const NdArray<dtype>& inReal, const NdArray<dtype>& inImag)
99101
{
100102
if (inReal.shape() != inImag.shape())
101103
{
102104
THROW_INVALID_ARGUMENT_ERROR("Input real array must be the same shape as input imag array");
103105
}
104106

105-
NdArray<decltype(nc::complex(dtype{ 0 }, dtype{ 0 }))> returnArray(inReal.shape());
107+
NdArray<decltype(nc::complex(dtypeOut{ 0 }, dtypeOut{ 0 }))> returnArray(inReal.shape());
106108
stl_algorithms::transform(inReal.cbegin(),
107109
inReal.cend(),
108110
inImag.cbegin(),
109111
returnArray.begin(),
110-
[](dtype real, dtype imag) -> auto { return nc::complex(real, imag); });
112+
[](dtype real, dtype imag) -> auto
113+
{ return nc::complex<dtype, dtypeOut>(real, imag); });
111114

112115
return returnArray;
113116
}
117+
118+
//============================================================================
119+
// Method Description:
120+
/// Returns a std::complex from the input real and imag components
121+
///
122+
/// @param inReal: the real component of the complex number
123+
/// @return NdArray
124+
///
125+
template<typename dtype, typename dtypeOut = dtype>
126+
auto complex(const NdArray<std::complex<dtype>>& inArray)
127+
{
128+
STATIC_ASSERT_ARITHMETIC(dtype);
129+
STATIC_ASSERT_ARITHMETIC(dtypeOut);
130+
131+
return inArray.template astype<std::complex<dtypeOut>>();
132+
}
114133
} // namespace nc

include/NumCpp/Functions/fft.hpp

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,54 @@
2929

3030
#include <complex>
3131

32+
#include "NumCpp/Core/Constants.hpp"
3233
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
34+
#include "NumCpp/Core/Internal/StlAlgorithms.hpp"
3335
#include "NumCpp/Core/Types.hpp"
36+
#include "NumCpp/Functions/complex.hpp"
3437
#include "NumCpp/NdArray.hpp"
3538

3639
namespace nc
3740
{
41+
namespace detail
42+
{
43+
//===========================================================================
44+
// Method Description:
45+
/// Fast Fourier Transform
46+
///
47+
/// @param x the data
48+
/// @param n Length of the transformed axis of the output.
49+
///
50+
NdArray<std::complex<double>> fft(const NdArray<std::complex<double>>& x, uint32 n)
51+
{
52+
if (n == 0)
53+
{
54+
return {};
55+
}
56+
57+
auto result = NdArray<std::complex<double>>(1, n);
58+
59+
stl_algorithms::for_each(result.begin(),
60+
result.end(),
61+
[n, &x, &result](auto& resultElement)
62+
{
63+
const auto k = static_cast<double>(&resultElement - result.data());
64+
const auto minusTwoPiKOverN = -constants::twoPi * k / static_cast<double>(n);
65+
resultElement = std::complex<double>{ 0., 0. };
66+
std::for_each(x.begin(),
67+
x.begin() + std::min(n, x.size()),
68+
[minusTwoPiKOverN, &resultElement, &x, n](const auto& value)
69+
{
70+
const auto m = static_cast<double>(&value - x.data());
71+
const auto angle = minusTwoPiKOverN * m;
72+
resultElement += (value * std::polar(1., angle));
73+
});
74+
});
75+
76+
return result;
77+
}
78+
} // namespace detail
79+
3880
//===========================================================================
3981
// Method Description:
4082
/// Compute the one-dimensional discrete Fourier Transform.
@@ -56,15 +98,29 @@ namespace nc
5698
{
5799
case Axis::NONE:
58100
{
59-
return {};
101+
const auto data = nc::complex<dtype, double>(inArray);
102+
return detail::fft(data, inN);
60103
}
61104
case Axis::COL:
62105
{
63-
return {};
106+
auto data = nc::complex<dtype, double>(inArray);
107+
const auto& shape = inArray.shape();
108+
auto result = NdArray<std::complex<double>>(shape.rows, inN);
109+
const auto dataColSlice = data.cSlice();
110+
const auto resultColSlice = result.cSlice();
111+
112+
for (uint32 row = 0; row < data.numRows(); ++row)
113+
{
114+
const auto rowData = data(row, dataColSlice);
115+
const auto rowResult = detail::fft(rowData, inN);
116+
result.put(row, resultColSlice, rowResult);
117+
}
118+
119+
return result;
64120
}
65121
case Axis::ROW:
66122
{
67-
return fft(inArray.transpose(), inN, Axis::COL);
123+
return fft(inArray.transpose(), inN, Axis::COL).transpose();
68124
}
69125
default:
70126
{
@@ -133,15 +189,29 @@ namespace nc
133189
{
134190
case Axis::NONE:
135191
{
136-
return {};
192+
const auto data = nc::complex<dtype, double>(inArray);
193+
return detail::fft(data, inN);
137194
}
138195
case Axis::COL:
139196
{
140-
return {};
197+
const auto data = nc::complex<dtype, double>(inArray);
198+
const auto& shape = inArray.shape();
199+
auto result = NdArray<std::complex<double>>(shape.rows, inN);
200+
const auto dataColSlice = data.cSlice();
201+
const auto resultColSlice = result.cSlice();
202+
203+
for (uint32 row = 0; row < data.numRows(); ++row)
204+
{
205+
const auto rowData = data(row, dataColSlice);
206+
const auto rowResult = detail::fft(rowData, inN);
207+
result.put(row, resultColSlice, rowResult);
208+
}
209+
210+
return result;
141211
}
142212
case Axis::ROW:
143213
{
144-
return fft(inArray.transpose(), inN, Axis::COL);
214+
return fft(inArray.transpose(), inN, Axis::COL).transpose();
145215
}
146216
default:
147217
{

test/pytest/src/Functions.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,38 @@ namespace FunctionsInterface
11421142

11431143
//================================================================================
11441144

1145+
template<typename dtype>
1146+
pbArrayGeneric fft(const NdArray<dtype>& inArray, Axis inAxis)
1147+
{
1148+
return nc2pybind(nc::fft(inArray, inAxis));
1149+
}
1150+
1151+
//================================================================================
1152+
1153+
template<typename dtype>
1154+
pbArrayGeneric fftN(const NdArray<dtype>& inArray, uint32 inN, Axis inAxis)
1155+
{
1156+
return nc2pybind(nc::fft(inArray, inN, inAxis));
1157+
}
1158+
1159+
//================================================================================
1160+
1161+
template<typename dtype>
1162+
pbArrayGeneric fftComplex(const NdArray<std::complex<dtype>>& inArray, Axis inAxis)
1163+
{
1164+
return nc2pybind(nc::fft(inArray, inAxis));
1165+
}
1166+
1167+
//================================================================================
1168+
1169+
template<typename dtype>
1170+
pbArrayGeneric fftComplexN(const NdArray<std::complex<dtype>>& inArray, uint32 inN, Axis inAxis)
1171+
{
1172+
return nc2pybind(nc::fft(inArray, inN, inAxis));
1173+
}
1174+
1175+
//================================================================================
1176+
11451177
pbArrayGeneric find(const NdArray<bool>& inArray)
11461178
{
11471179
return nc2pybind(nc::find(inArray));
@@ -3414,6 +3446,10 @@ void initFunctions(pb11::module& m)
34143446
m.def("eyeShape", &FunctionsInterface::eyeShape<double>);
34153447
m.def("eyeShapeComplex", &FunctionsInterface::eyeShape<ComplexDouble>);
34163448

3449+
m.def("fft", &FunctionsInterface::fft<double>);
3450+
m.def("fft", &FunctionsInterface::fftN<double>);
3451+
m.def("fft", &FunctionsInterface::fftComplex<double>);
3452+
m.def("fft", &FunctionsInterface::fftComplexN<double>);
34173453
m.def("fillDiagonal", &fillDiagonal<double>);
34183454
m.def("find", &FunctionsInterface::find);
34193455
m.def("findN", &FunctionsInterface::findN);

0 commit comments

Comments
 (0)