Skip to content

Commit 815bea2

Browse files
committed
added fftshift and ifftshift
1 parent 544e5f8 commit 815bea2

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

include/NumCpp/FFT/fftshift.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
3131
#include "NumCpp/Core/Types.hpp"
32+
#include "NumCpp/Functions/roll.hpp"
3233
#include "NumCpp/NdArray.hpp"
3334

3435
namespace nc::fft
@@ -51,6 +52,25 @@ namespace nc::fft
5152
{
5253
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
5354

54-
return {};
55+
switch (inAxis)
56+
{
57+
case Axis::NONE:
58+
{
59+
return roll(inX, inX.size() / 2, inAxis);
60+
}
61+
case Axis::COL:
62+
{
63+
return roll(inX, inX.numCols() / 2, inAxis);
64+
}
65+
case Axis::ROW:
66+
{
67+
return roll(inX, inX.numRows() / 2, inAxis);
68+
}
69+
default:
70+
{
71+
THROW_INVALID_ARGUMENT_ERROR("Unimplemented axis type.");
72+
return {};
73+
}
74+
}
5575
}
5676
} // namespace nc::fft

include/NumCpp/FFT/ifftshift.hpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,31 @@ namespace nc::fft
5050
{
5151
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
5252

53-
return {};
53+
switch (inAxis)
54+
{
55+
case Axis::NONE:
56+
{
57+
auto shift = inX.size() / 2;
58+
shift += inX.size() % 2 == 1 ? 1 : 0;
59+
return roll(inX, shift, inAxis);
60+
}
61+
case Axis::COL:
62+
{
63+
auto shift = inX.numCols() / 2;
64+
shift += inX.numCols() % 2 == 1 ? 1 : 0;
65+
return roll(inX, shift, inAxis);
66+
}
67+
case Axis::ROW:
68+
{
69+
auto shift = inX.numRows() / 2;
70+
shift += inX.numRows() % 2 == 1 ? 1 : 0;
71+
return roll(inX, shift, inAxis);
72+
}
73+
default:
74+
{
75+
THROW_INVALID_ARGUMENT_ERROR("Unimplemented axis type.");
76+
return {};
77+
}
78+
}
5479
}
5580
} // namespace nc::fft

test/pytest/test_fft.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,12 +611,44 @@ def test_fftfreq():
611611

612612
####################################################################################
613613
def test_fftshift():
614-
assert False
614+
for _ in range(50):
615+
n = np.random.randint(10, 1000)
616+
d = np.random.rand()
617+
freqs = np.fft.fftfreq(n, d)
618+
cFreqs = NumCpp.NdArray(1, freqs.size)
619+
cFreqs.setArray(freqs)
620+
assert np.array_equal(np.round(NumCpp.fftshift(cFreqs, NumCpp.Axis.NONE).flatten(), 8), np.round(np.fft.fftshift(freqs), 8))
621+
622+
dim0 = np.random.randint(10, 100)
623+
dim1 = np.random.randint(10, 100)
624+
n = dim0 * dim1
625+
d = np.random.rand()
626+
freqs = np.fft.fftfreq(n, d).reshape(dim0, dim1)
627+
cFreqs = NumCpp.NdArray(freqs.shape[0], freqs.shape[1])
628+
cFreqs.setArray(freqs)
629+
assert np.array_equal(np.round(NumCpp.fftshift(cFreqs, NumCpp.Axis.ROW), 8), np.round(np.fft.fftshift(freqs, axes=0), 8))
630+
assert np.array_equal(np.round(NumCpp.fftshift(cFreqs, NumCpp.Axis.COL), 8), np.round(np.fft.fftshift(freqs, axes=1), 8))
615631

616632

617633
####################################################################################
618634
def test_ifftshift():
619-
assert False
635+
for _ in range(50):
636+
n = np.random.randint(10, 1000)
637+
d = np.random.rand()
638+
freqs = np.fft.fftfreq(n, d)
639+
cFreqs = NumCpp.NdArray(1, freqs.size)
640+
cFreqs.setArray(freqs)
641+
assert np.array_equal(np.round(NumCpp.ifftshift(cFreqs, NumCpp.Axis.NONE).flatten(), 8), np.round(np.fft.ifftshift(freqs), 8))
642+
643+
dim0 = np.random.randint(10, 100)
644+
dim1 = np.random.randint(10, 100)
645+
n = dim0 * dim1
646+
d = np.random.rand()
647+
freqs = np.fft.fftfreq(n, d).reshape(dim0, dim1)
648+
cFreqs = NumCpp.NdArray(freqs.shape[0], freqs.shape[1])
649+
cFreqs.setArray(freqs)
650+
assert np.array_equal(np.round(NumCpp.ifftshift(cFreqs, NumCpp.Axis.ROW), 8), np.round(np.fft.ifftshift(freqs, axes=0), 8))
651+
assert np.array_equal(np.round(NumCpp.ifftshift(cFreqs, NumCpp.Axis.COL), 8), np.round(np.fft.ifftshift(freqs, axes=1), 8))
620652

621653

622654
####################################################################################

0 commit comments

Comments
 (0)