Skip to content

Commit 932e387

Browse files
committed
Implemented xtensor FFT
1 parent 22ad9ea commit 932e387

File tree

14 files changed

+615
-3
lines changed

14 files changed

+615
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,4 @@ __pycache__
6262

6363
# Generated files
6464
*.pc
65+
.vscode/settings.json

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ set(XTENSOR_HEADERS
140140
${XTENSOR_INCLUDE_DIR}/xtensor/xfixed.hpp
141141
${XTENSOR_INCLUDE_DIR}/xtensor/xfunction.hpp
142142
${XTENSOR_INCLUDE_DIR}/xtensor/xfunctor_view.hpp
143+
${XTENSOR_INCLUDE_DIR}/xtensor/xfft.hpp
143144
${XTENSOR_INCLUDE_DIR}/xtensor/xgenerator.hpp
144145
${XTENSOR_INCLUDE_DIR}/xtensor/xhistogram.hpp
145146
${XTENSOR_INCLUDE_DIR}/xtensor/xindex_view.hpp
@@ -199,6 +200,7 @@ target_link_libraries(xtensor INTERFACE xtl)
199200

200201
OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
201202
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
203+
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
202204
OPTION(BUILD_TESTS "xtensor test suite" OFF)
203205
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
204206
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
@@ -219,6 +221,10 @@ if(XTENSOR_CHECK_DIMENSION)
219221
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
220222
endif()
221223

224+
if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
225+
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
226+
endif()
227+
222228
if(DEFAULT_COLUMN_MAJOR)
223229
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
224230
endif()

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# ![xtensor](docs/source/xtensor.svg)
22

3-
![linux](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml/badge.svg)
4-
![osx](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml/badge.svg)
5-
![windows](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml/badge.svg)
3+
[![GHA Linux](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml)
4+
[![GHA OSX](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml)
5+
[![GHA Windows](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml)
66
[![Documentation](http://readthedocs.org/projects/xtensor/badge/?version=latest)](https://xtensor.readthedocs.io/en/latest/?badge=latest)
77
[![Doxygen -> gh-pages](https://github.com/xtensor-stack/xtensor/workflows/gh-pages/badge.svg)](https://xtensor-stack.github.io/xtensor)
88
[![Binder](https://mybinder.org/badge.svg)](https://mybinder.org/v2/gh/xtensor-stack/xtensor/stable?filepath=notebooks%2Fxtensor.ipynb)

docs/source/api/container_index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ xexpression API is actually implemented in ``xstrided_container`` and ``xcontain
3333
xindex_view
3434
xfunctor_view
3535
xrepeat
36+
xfft

docs/source/xfft.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
2+
Distributed under the terms of the BSD 3-Clause License.
3+
The full license is in the file LICENSE, distributed with this software.
4+
xfft
5+
====
6+
7+
Defined in ``xtensor/xfft.hpp``
8+
9+
.. doxygenfunction:: xt::fft::convolve
10+
:project: xtensor
11+
:members:
12+
13+
.. doxygentypedef:: xt::fft
14+
:project: xtensor
15+
16+
.. doxygentypedef:: xt::ifft
17+
:project: xtensor

include/xtensor/xbroadcast.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,29 @@ namespace xt
118118
return linear_end(c.expression());
119119
}
120120

121+
/*************************************
122+
* overlapping_memory_checker_traits *
123+
*************************************/
124+
125+
template <class E>
126+
struct overlapping_memory_checker_traits<
127+
E,
128+
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
129+
{
130+
static bool check_overlap(const E& expr, const memory_range& dst_range)
131+
{
132+
if (expr.size() == 0)
133+
{
134+
return false;
135+
}
136+
else
137+
{
138+
using ChildE = std::decay_t<decltype(expr.expression())>;
139+
return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
140+
}
141+
}
142+
};
143+
121144
/**
122145
* @class xbroadcast
123146
* @brief Broadcasted xexpression to a specified shape.

include/xtensor/xfft.hpp

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#ifdef XTENSOR_USE_TBB
2+
#include <oneapi/tbb.h>
3+
#endif
4+
#include <stdexcept>
5+
6+
#include <xtl/xcomplex.hpp>
7+
8+
#include <xtensor/xarray.hpp>
9+
#include <xtensor/xaxis_slice_iterator.hpp>
10+
#include <xtensor/xbuilder.hpp>
11+
#include <xtensor/xcomplex.hpp>
12+
#include <xtensor/xmath.hpp>
13+
#include <xtensor/xnoalias.hpp>
14+
#include <xtensor/xview.hpp>
15+
16+
namespace xt
17+
{
18+
namespace fft
19+
{
20+
namespace detail
21+
{
22+
template <
23+
class E,
24+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
25+
inline auto radix2(E&& e)
26+
{
27+
using namespace xt::placeholders;
28+
using namespace std::complex_literals;
29+
using value_type = typename std::decay_t<E>::value_type;
30+
using precision = typename value_type::value_type;
31+
auto N = e.size();
32+
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
33+
// check for power of 2
34+
if (!powerOfTwo || N == 0)
35+
{
36+
// TODO: Replace implementation with dft
37+
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
38+
}
39+
auto pi = xt::numeric_constants<precision>::PI;
40+
xt::xtensor<value_type, 1> ev = e;
41+
if (N <= 1)
42+
{
43+
return ev;
44+
}
45+
else
46+
{
47+
#ifdef XTENSOR_USE_TBB
48+
xt::xtensor<value_type, 1> even;
49+
xt::xtensor<value_type, 1> odd;
50+
oneapi::tbb::parallel_invoke(
51+
[&]
52+
{
53+
even = radix2(xt::view(ev, xt::range(0, _, 2)));
54+
},
55+
[&]
56+
{
57+
odd = radix2(xt::view(ev, xt::range(1, _, 2)));
58+
}
59+
);
60+
#else
61+
auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
62+
auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
63+
#endif
64+
65+
auto range = xt::arange<double>(N / 2);
66+
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
67+
auto t = exp * odd;
68+
auto first_half = even + t;
69+
auto second_half = even - t;
70+
// TODO: should be a call to stack if performance was improved
71+
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
72+
xt::view(spectrum, xt::range(0, N / 2)) = first_half;
73+
xt::view(spectrum, xt::range(N / 2, N)) = second_half;
74+
return spectrum;
75+
}
76+
}
77+
78+
template <typename E>
79+
auto transform_bluestein(E&& data)
80+
{
81+
using value_type = typename std::decay_t<E>::value_type;
82+
using precision = typename value_type::value_type;
83+
84+
// Find a power-of-2 convolution length m such that m >= n * 2 + 1
85+
const std::size_t n = data.size();
86+
size_t m = std::ceil(std::log2(n * 2 + 1));
87+
m = std::pow(2, m);
88+
89+
// Trignometric table
90+
auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
91+
xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2);
92+
i %= (n * 2);
93+
94+
auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
95+
auto j = std::complex<precision>(0, 1);
96+
exp_table = xt::exp(-angles * j);
97+
98+
// Temporary vectors and preprocessing
99+
auto av = xt::empty<std::complex<precision>>({m});
100+
xt::view(av, xt::range(0, n)) = data * exp_table;
101+
102+
103+
auto bv = xt::empty<std::complex<precision>>({m});
104+
xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
105+
xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
106+
::xt::conj(xt::flip(exp_table)),
107+
xt::range(xt::placeholders::_, -1)
108+
);
109+
110+
// Convolution
111+
auto xv = radix2(av);
112+
auto yv = radix2(bv);
113+
auto spectrum_k = xv * yv;
114+
auto complex_args = xt::conj(spectrum_k);
115+
auto fft_res = radix2(complex_args);
116+
auto cv = xt::conj(fft_res) / m;
117+
118+
return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
119+
}
120+
} // namespace detail
121+
122+
/**
123+
* @brief 1D FFT of an Nd array along a specified axis
124+
* @param e an Nd expression to be transformed to the fourier domain
125+
* @param axis the axis along which to perform the 1D FFT
126+
* @return a transformed xarray of the specified precision
127+
*/
128+
template <
129+
class E,
130+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
131+
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
132+
{
133+
using value_type = typename std::decay_t<E>::value_type;
134+
using precision = typename value_type::value_type;
135+
const auto saxis = xt::normalize_axis(e.dimension(), axis);
136+
const size_t N = e.shape(saxis);
137+
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
138+
xt::xarray<std::complex<precision>> out = xt::eval(e);
139+
auto begin = xt::axis_slice_begin(out, saxis);
140+
auto end = xt::axis_slice_end(out, saxis);
141+
for (auto iter = begin; iter != end; iter++)
142+
{
143+
if (powerOfTwo)
144+
{
145+
xt::noalias(*iter) = detail::radix2(*iter);
146+
}
147+
else
148+
{
149+
xt::noalias(*iter) = detail::transform_bluestein(*iter);
150+
}
151+
}
152+
return out;
153+
}
154+
155+
/**
156+
* @breif 1D FFT of an Nd array along a specified axis
157+
* @param e an Nd expression to be transformed to the fourier domain
158+
* @param axis the axis along which to perform the 1D FFT
159+
* @return a transformed xarray of the specified precision
160+
*/
161+
template <
162+
class E,
163+
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
164+
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
165+
{
166+
using value_type = typename std::decay<E>::type::value_type;
167+
return fft(xt::cast<std::complex<value_type>>(e), axis);
168+
}
169+
170+
template <
171+
class E,
172+
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
173+
auto ifft(E&& e, std::ptrdiff_t axis = -1)
174+
{
175+
// check the length of the data on that axis
176+
const std::size_t n = e.shape(axis);
177+
if (n == 0)
178+
{
179+
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
180+
}
181+
auto complex_args = xt::conj(e);
182+
auto fft_res = xt::fft::fft(complex_args, axis);
183+
fft_res = xt::conj(fft_res);
184+
return fft_res;
185+
}
186+
187+
template <
188+
class E,
189+
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
190+
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
191+
{
192+
using value_type = typename std::decay<E>::type::value_type;
193+
return ifft(xt::cast<std::complex<value_type>>(e), axis);
194+
}
195+
196+
/*
197+
* @brief performs a circular fft convolution xvec and yvec must
198+
* be the same shape.
199+
* @param xvec first array of the convolution
200+
* @param yvec second array of the convolution
201+
* @param axis axis along which to perform the convolution
202+
*/
203+
template <typename E1, typename E2>
204+
auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
205+
{
206+
// we could broadcast but that could get complicated???
207+
if (xvec.dimension() != yvec.dimension())
208+
{
209+
XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
210+
}
211+
212+
auto saxis = xt::normalize_axis(xvec.dimension(), axis);
213+
if (xvec.shape(saxis) != yvec.shape(saxis))
214+
{
215+
XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
216+
}
217+
218+
const std::size_t n = xvec.shape(saxis);
219+
220+
auto xv = fft(xvec, axis);
221+
auto yv = fft(yvec, axis);
222+
223+
auto begin_x = xt::axis_slice_begin(xv, saxis);
224+
auto end_x = xt::axis_slice_end(xv, saxis);
225+
auto iter_y = xt::axis_slice_begin(yv, saxis);
226+
227+
for (auto iter = begin_x; iter != end_x; iter++)
228+
{
229+
(*iter) = (*iter_y++) * (*iter);
230+
}
231+
232+
auto outvec = ifft(xv, axis);
233+
234+
// Scaling (because this FFT implementation omits it)
235+
outvec = outvec / n;
236+
237+
return outvec;
238+
}
239+
240+
}
241+
} // namespace xt

include/xtensor/xfunction.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,42 @@ namespace xt
162162
{
163163
};
164164

165+
/*************************************
166+
* overlapping_memory_checker_traits *
167+
*************************************/
168+
169+
template <class E>
170+
struct overlapping_memory_checker_traits<
171+
E,
172+
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
173+
{
174+
template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
175+
static bool check_tuple(const std::tuple<T...>&, const memory_range&)
176+
{
177+
return false;
178+
}
179+
180+
template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
181+
static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
182+
{
183+
using ChildE = std::decay_t<decltype(std::get<I>(t))>;
184+
return overlapping_memory_checker_traits<ChildE>::check_overlap(std::get<I>(t), dst_range)
185+
|| check_tuple<I + 1>(t, dst_range);
186+
}
187+
188+
static bool check_overlap(const E& expr, const memory_range& dst_range)
189+
{
190+
if (expr.size() == 0)
191+
{
192+
return false;
193+
}
194+
else
195+
{
196+
return check_tuple(expr.arguments(), dst_range);
197+
}
198+
}
199+
};
200+
165201
/*************
166202
* xfunction *
167203
*************/

0 commit comments

Comments
 (0)