Skip to content

Commit

Permalink
Apply review of jbigot
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau committed Mar 15, 2024
1 parent 6ec18f6 commit 6aaee6f
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/first_steps.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Their type is not specified because we use C++
all of the same type: `DiscreteDomain<DDimX>` that represents a set of contiguous points in the
discretization of `X`.

\ref ddc::uniform_point_sampling_init_ghosted "init_ghosted" takes as parameters the coordinate of the first and last discretized points, the
\ref ddc::UniformPointSampling::init_ghosted "init_ghosted" takes as parameters the coordinate of the first and last discretized points, the
number of discretized points in the domain and the number of additional points on each side of the
domain.
The fours `DiscreteDomain`s returned are:
Expand Down
4 changes: 3 additions & 1 deletion include/ddc/discrete_space.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ auto extract_after(Tuple&& t, std::index_sequence<Ids...>)
template <class DDim, class... Args>
void init_discrete_space(Args&&... args)
{
static_assert(!std::is_same_v<DDim, typename DDim::discrete_dimension_type>);
static_assert(
!std::is_same_v<DDim, typename DDim::discrete_dimension_type>,
"Discrete dimensions should inherit from the discretization, not use an alias");
if (detail::g_discrete_space_dual<DDim>) {
throw std::runtime_error("Discrete space function already initialized.");
}
Expand Down
42 changes: 42 additions & 0 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,18 @@ struct kwArgs_core
template <typename DDim, typename... DDimX>
int N(ddc::DiscreteDomain<DDimX...> x_mesh)
{
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
return ddc::get<DDim>(x_mesh.extents());
}

template <typename DDim, typename... DDimX>
double a(ddc::DiscreteDomain<DDimX...> x_mesh)
{
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
return ((2 * N<DDim>(x_mesh) - 1) * coordinate(ddc::select<DDim>(x_mesh).front())
- coordinate(ddc::select<DDim>(x_mesh).back()))
/ 2 / (N<DDim>(x_mesh) - 1);
Expand All @@ -343,6 +349,9 @@ double a(ddc::DiscreteDomain<DDimX...> x_mesh)
template <typename DDim, typename... DDimX>
double b(ddc::DiscreteDomain<DDimX...> x_mesh)
{
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
return ((2 * N<DDim>(x_mesh) - 1) * coordinate(ddc::select<DDim>(x_mesh).back())
- coordinate(ddc::select<DDim>(x_mesh).front()))
/ 2 / (N<DDim>(x_mesh) - 1);
Expand All @@ -366,6 +375,9 @@ void core(
static_assert(
Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
"MemorySpace has to be accessible for ExecutionSpace.");
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");

std::array<int, sizeof...(DDimX)> n = {(int)ddc::get<DDimX>(mesh.extents())...};
int idist = 1;
Expand Down Expand Up @@ -569,6 +581,12 @@ template <typename DDimFx, typename DDimX>
typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> FourierSampling(
ddc::DiscreteDomain<DDimX> x_mesh)
{
static_assert(
is_uniform_sampling_v<DDimX>,
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
is_periodic_sampling_v<DDimFx>,
"DDimFx dimensions should derive from PeriodicPointSampling");
auto [impl, ddom] = DDimFx::template init<DDimFx>(
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(0),
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(
Expand All @@ -586,6 +604,12 @@ namespace ddc {
template <typename... DDimFx, typename... DDimX>
void init_fourier_space(ddc::DiscreteDomain<DDimX...> x_mesh)
{
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");
return (ddc::init_discrete_space<DDimFx>(
ddc::detail::fft::FourierSampling<DDimFx>(ddc::select<DDimX>(x_mesh))),
...);
Expand All @@ -595,6 +619,12 @@ void init_fourier_space(ddc::DiscreteDomain<DDimX...> x_mesh)
template <typename... DDimFx, typename... DDimX>
ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
{
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");
return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
ddc::DiscreteElement<DDimFx>(0),
ddc::DiscreteVector<DDimFx>(
Expand Down Expand Up @@ -631,6 +661,12 @@ void fft(
std::experimental::
layout_right> && std::is_same_v<layout_out, std::experimental::layout_right>,
"Layouts must be right-handed");
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
Expand Down Expand Up @@ -662,6 +698,12 @@ void ifft(
std::experimental::
layout_right> && std::is_same_v<layout_out, std::experimental::layout_right>,
"Layouts must be right-handed");
static_assert(
(is_uniform_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
Expand Down
48 changes: 24 additions & 24 deletions include/ddc/kernels/splines/greville_interpolation_points.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@

namespace ddc {

template <class Sampling>
struct UniformSamplingHidden
: ddc::UniformPointSampling<typename Sampling::continuous_dimension_type>
{
};

template <class Sampling>
struct NonUniformSamplingHidden
: ddc::NonUniformPointSampling<typename Sampling::continuous_dimension_type>
{
};

template <class BSplines, ddc::BoundCond BcXmin, ddc::BoundCond BcXmax>
class GrevilleInterpolationPoints
{
using tag_type = typename BSplines::tag_type;

template <class Sampling>
struct IntermediateUniformSampling
: UniformPointSampling<typename Sampling::continuous_dimension_type>
{
};

template <class Sampling>
struct IntermediateNonUniformSampling
: NonUniformPointSampling<typename Sampling::continuous_dimension_type>
{
};

template <class Sampling, typename U = BSplines, class = std::enable_if_t<U::is_uniform()>>
static auto uniform_greville_points()
{
Expand Down Expand Up @@ -116,8 +116,8 @@ class GrevilleInterpolationPoints
{
using SamplingImpl = typename Sampling::template Impl<Sampling, Kokkos::HostSpace>;
if constexpr (U::is_uniform()) {
using HiddenSampling = UniformSamplingHidden<Sampling>;
auto points_wo_bcs = uniform_greville_points<HiddenSampling>();
using IntermediateSampling = IntermediateUniformSampling<Sampling>;
auto points_wo_bcs = uniform_greville_points<IntermediateSampling>();
int const n_break_points = ddc::discrete_space<BSplines>().ncells() + 1;
int const npoints = ddc::discrete_space<BSplines>().nbasis() - N_BE_MIN - N_BE_MAX;
std::vector<double> points_with_bcs(npoints);
Expand All @@ -136,15 +136,15 @@ class GrevilleInterpolationPoints
}
} else {
points_with_bcs[0]
= points_wo_bcs.coordinate(ddc::DiscreteElement<HiddenSampling>(0));
= points_wo_bcs.coordinate(ddc::DiscreteElement<IntermediateSampling>(0));
}

int const n_start
= (BcXmin == ddc::BoundCond::GREVILLE) ? BSplines::degree() / 2 + 1 : 1;
int const domain_size = n_break_points - 2;
ddc::DiscreteDomain<HiddenSampling> const
domain(ddc::DiscreteElement<HiddenSampling>(1),
ddc::DiscreteVector<HiddenSampling>(domain_size));
ddc::DiscreteDomain<IntermediateSampling> const
domain(ddc::DiscreteElement<IntermediateSampling>(1),
ddc::DiscreteVector<IntermediateSampling>(domain_size));

// Copy central points
ddc::for_each(domain, [&](auto ip) {
Expand All @@ -167,25 +167,25 @@ class GrevilleInterpolationPoints
}
} else {
points_with_bcs[npoints - 1]
= points_wo_bcs.coordinate(ddc::DiscreteElement<HiddenSampling>(
= points_wo_bcs.coordinate(ddc::DiscreteElement<IntermediateSampling>(
ddc::discrete_space<BSplines>().ncells() - 1
+ BSplines::degree() % 2));
}
return SamplingImpl(points_with_bcs);
} else {
using HiddenSampling = NonUniformSamplingHidden<Sampling>;
using IntermediateSampling = IntermediateNonUniformSampling<Sampling>;
if constexpr (N_BE_MIN == 0 && N_BE_MAX == 0) {
return non_uniform_greville_points<Sampling>();
} else {
auto points_wo_bcs = non_uniform_greville_points<HiddenSampling>();
auto points_wo_bcs = non_uniform_greville_points<IntermediateSampling>();
// All points are Greville points. Extract unnecessary points near the boundary
std::vector<double> points_with_bcs(points_wo_bcs.size() - N_BE_MIN - N_BE_MAX);
int constexpr n_start = N_BE_MIN;

using length = ddc::DiscreteVector<HiddenSampling>;
using length = ddc::DiscreteVector<IntermediateSampling>;

ddc::DiscreteDomain<HiddenSampling> const
domain(ddc::DiscreteElement<HiddenSampling>(n_start),
ddc::DiscreteDomain<IntermediateSampling> const
domain(ddc::DiscreteElement<IntermediateSampling>(n_start),
length(points_with_bcs.size()));

points_with_bcs[0] = points_wo_bcs.coordinate(domain.front());
Expand Down
7 changes: 6 additions & 1 deletion include/ddc/non_uniform_point_sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ class NonUniformPointSampling : NonUniformPointSamplingBase
};

template <class DDim>
constexpr bool is_non_uniform_sampling_v = std::is_base_of_v<NonUniformPointSamplingBase, DDim>;
struct is_non_uniform_sampling : public std::is_base_of<NonUniformPointSamplingBase, DDim>
{
};

template <class DDim>
constexpr bool is_non_uniform_sampling_v = is_non_uniform_sampling<DDim>::value;

template <
class DDimImpl,
Expand Down
9 changes: 7 additions & 2 deletions include/ddc/periodic_sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class PeriodicSampling : PeriodicSamplingBase
assert(a < b);
assert(n > 1);
assert(n_period > 1);
double discretization_step {(b - a) / (n - 1)};
Real discretization_step {(b - a) / (n - 1)};
Impl<DDim, Kokkos::HostSpace>
disc(a - n_ghosts_before.value() * discretization_step,
discretization_step,
Expand Down Expand Up @@ -231,7 +231,12 @@ class PeriodicSampling : PeriodicSamplingBase
};

template <class DDim>
constexpr bool is_periodic_sampling_v = std::is_base_of_v<PeriodicSamplingBase, DDim>;
struct is_periodic_sampling : public std::is_base_of<PeriodicSamplingBase, DDim>
{
};

template <class DDim>
constexpr bool is_periodic_sampling_v = is_periodic_sampling<DDim>::value;

template <
class DDimImpl,
Expand Down
7 changes: 6 additions & 1 deletion include/ddc/uniform_point_sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,12 @@ class UniformPointSampling : UniformPointSamplingBase
};

template <class DDim>
constexpr bool is_uniform_sampling_v = std::is_base_of_v<UniformPointSamplingBase, DDim>;
struct is_uniform_sampling : public std::is_base_of<UniformPointSamplingBase, DDim>
{
};

template <class DDim>
constexpr bool is_uniform_sampling_v = is_uniform_sampling<DDim>::value;

template <
class DDimImpl,
Expand Down

0 comments on commit 6aaee6f

Please sign in to comment.