Skip to content

Commit

Permalink
Generalize serial_host to serial
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau committed Sep 6, 2023
1 parent 4f91a1b commit 15999ff
Show file tree
Hide file tree
Showing 13 changed files with 371 additions and 90 deletions.
21 changes: 17 additions & 4 deletions examples/heat_equation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void display(double time, ChunkType temp)
std::cout << " * temperature[y:"
<< ddc::get_domain<DDimY>(temp).size() / 2 << "] = {";
ddc::for_each(
ddc::policies::serial_host,
ddc::policies::serial,
ddc::get_domain<DDimX>(temp),
[=](ddc::DiscreteElement<DDimX> const ix) {
std::cout << std::setw(6) << temp_slice(ix);
Expand Down Expand Up @@ -238,7 +238,10 @@ int main(int argc, char** argv)

//! [initial output]
// display the initial data
ddc::deepcopy(ghosted_temp, ghosted_last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_temp,
ghosted_last_temp);
display(ddc::coordinate(time_domain.front()),
ghosted_temp[x_domain][y_domain]);
// time of the iteration where the last output happened
Expand All @@ -253,15 +256,19 @@ int main(int argc, char** argv)
//! [boundary conditions]
// Periodic boundary conditions
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_last_temp[x_pre_ghost][y_domain],
ghosted_last_temp[y_domain][x_domain_end]);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_last_temp[y_domain][x_post_ghost],
ghosted_last_temp[y_domain][x_domain_begin]);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_last_temp[x_domain][y_pre_ghost],
ghosted_last_temp[x_domain][y_domain_end]);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_last_temp[x_domain][y_post_ghost],
ghosted_last_temp[x_domain][y_domain_begin]);
//! [boundary conditions]
Expand Down Expand Up @@ -311,7 +318,10 @@ int main(int argc, char** argv)
//! [output]
if (iter - last_output >= t_output_period) {
last_output = iter;
ddc::deepcopy(ghosted_temp, ghosted_last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_temp,
ghosted_last_temp);
display(ddc::coordinate(iter),
ghosted_temp[x_domain][y_domain]);
}
Expand All @@ -325,7 +335,10 @@ int main(int argc, char** argv)

//! [final output]
if (last_output < time_domain.back()) {
ddc::deepcopy(ghosted_temp, ghosted_last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
ghosted_temp,
ghosted_last_temp);
display(ddc::coordinate(time_domain.back()),
ghosted_temp[x_domain][y_domain]);
}
Expand Down
17 changes: 13 additions & 4 deletions examples/heat_equation_spectral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void display(double time, ChunkType temp)
std::cout << " * temperature[y:"
<< ddc::get_domain<DDimY>(temp).size() / 2 << "] = {";
ddc::for_each(
ddc::policies::serial_host,
ddc::policies::serial,
ddc::get_domain<DDimX>(temp),
[=](ddc::DiscreteElement<DDimX> const ix) {
std::cout << std::setw(6) << temp_slice(ix);
Expand Down Expand Up @@ -200,7 +200,10 @@ int main(int argc, char** argv)

//! [initial output]
// display the initial data
ddc::deepcopy(_host_temp, _last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
_host_temp,
_last_temp);
display(ddc::coordinate(time_domain.front()),
_host_temp[x_domain][y_domain]);
// time of the iteration where the last output happened
Expand Down Expand Up @@ -273,7 +276,10 @@ int main(int argc, char** argv)
//! [output]
if (iter - last_output >= t_output_period) {
last_output = iter;
ddc::deepcopy(_host_temp, _last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
_host_temp,
_last_temp);
display(ddc::coordinate(iter),
_host_temp[x_domain][y_domain]);
}
Expand All @@ -287,7 +293,10 @@ int main(int argc, char** argv)

//! [final output]
if (last_output < time_domain.back()) {
ddc::deepcopy(_host_temp, _last_temp);
ddc::deepcopy(
ddc::policies::parallel_device,
_host_temp,
_last_temp);
display(ddc::coordinate(time_domain.back()),
_host_temp[x_domain][y_domain]);
}
Expand Down
54 changes: 53 additions & 1 deletion include/ddc/deepcopy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <Kokkos_Core.hpp>

#include "ddc/chunk_span.hpp"
#include "ddc/for_each.hpp"

namespace ddc {

Expand All @@ -16,7 +17,7 @@ namespace ddc {
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class ChunkSrc>
auto deepcopy(ChunkDst&& dst, ChunkSrc&& src)
auto deepcopy(parallel_host_policy, ChunkDst&& dst, ChunkSrc&& src)
{
static_assert(is_borrowed_chunk_v<ChunkDst>);
static_assert(is_borrowed_chunk_v<ChunkSrc>);
Expand All @@ -28,4 +29,55 @@ auto deepcopy(ChunkDst&& dst, ChunkSrc&& src)
return dst.span_view();
}

/** Copy the content of a borrowed chunk into another
* @param[out] dst the borrowed chunk in which to copy
* @param[in] src the borrowed chunk from which to copy
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class ChunkSrc>
auto deepcopy(parallel_device_policy, ChunkDst&& dst, ChunkSrc&& src)
{
return deepcopy(
policies::parallel_host,
std::forward<ChunkDst>(dst),
std::forward<ChunkSrc>(src));
}

/** Copy the content of a borrowed chunk into another
* @param[out] dst the borrowed chunk in which to copy
* @param[in] src the borrowed chunk from which to copy
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class ChunkSrc>
constexpr auto deepcopy(serial_policy, ChunkDst&& dst, ChunkSrc&& src)
{
static_assert(is_borrowed_chunk_v<ChunkDst>);
static_assert(is_borrowed_chunk_v<ChunkSrc>);
static_assert(
std::is_assignable_v<chunk_reference_t<ChunkDst>, chunk_reference_t<ChunkSrc>>,
"Not assignable");
assert(dst.domain().extents() == src.domain().extents());
KOKKOS_ENSURES((Kokkos::SpaceAccessibility<
DDC_CURRENT_KOKKOS_SPACE,
typename std::remove_cv_t<std::remove_reference_t<ChunkSrc>>::memory_space>::
accessible));
KOKKOS_ENSURES((Kokkos::SpaceAccessibility<
DDC_CURRENT_KOKKOS_SPACE,
typename std::remove_cv_t<std::remove_reference_t<ChunkDst>>::memory_space>::
accessible));
for_each(policies::serial, dst.domain(), [&](auto elem) { dst(elem) = src(elem); });
return dst.span_view();
}

/** Copy the content of a borrowed chunk into another
* @param[out] dst the borrowed chunk in which to copy
* @param[in] src the borrowed chunk from which to copy
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class ChunkSrc>
constexpr auto deepcopy(ChunkDst&& dst, ChunkSrc&& src)
{
return deepcopy(policies::serial, std::forward<ChunkDst>(dst), std::forward<ChunkSrc>(src));
}

} // namespace ddc
10 changes: 5 additions & 5 deletions include/ddc/discrete_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,31 +173,31 @@ class DiscreteDomain
template <
std::size_t N = sizeof...(DDims),
class DDim0 = std::enable_if_t<N == 1, std::tuple_element_t<0, std::tuple<DDims...>>>>
auto begin() const
constexpr auto begin() const
{
return DiscreteDomainIterator<DDim0>(front());
}

template <
std::size_t N = sizeof...(DDims),
class DDim0 = std::enable_if_t<N == 1, std::tuple_element_t<0, std::tuple<DDims...>>>>
auto end() const
constexpr auto end() const
{
return DiscreteDomainIterator<DDim0>(m_element_end);
}

template <
std::size_t N = sizeof...(DDims),
class DDim0 = std::enable_if_t<N == 1, std::tuple_element_t<0, std::tuple<DDims...>>>>
auto cbegin() const
constexpr auto cbegin() const
{
return DiscreteDomainIterator<DDim0>(front());
}

template <
std::size_t N = sizeof...(DDims),
class DDim0 = std::enable_if_t<N == 1, std::tuple_element_t<0, std::tuple<DDims...>>>>
auto cend() const
constexpr auto cend() const
{
return DiscreteDomainIterator<DDim0>(m_element_end);
}
Expand Down Expand Up @@ -459,7 +459,7 @@ struct DiscreteDomainIterator

using difference_type = std::ptrdiff_t;

DiscreteDomainIterator() = default;
constexpr DiscreteDomainIterator() = default;

constexpr explicit DiscreteDomainIterator(DiscreteElement<DDim> value) : m_value(value) {}

Expand Down
46 changes: 45 additions & 1 deletion include/ddc/fill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#pragma once

#include <type_traits>
#include <utility>
#include <variant>

#include <Kokkos_Core.hpp>

#include "ddc/chunk_span.hpp"
#include "ddc/detail/macros.hpp"
#include "ddc/for_each.hpp"

namespace ddc {

Expand All @@ -16,12 +20,52 @@ namespace ddc {
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class T>
auto fill(ChunkDst&& dst, T const& value)
auto fill(parallel_host_policy, ChunkDst&& dst, T const& value)
{
static_assert(is_borrowed_chunk_v<ChunkDst>);
static_assert(std::is_assignable_v<chunk_reference_t<ChunkDst>, T>, "Not assignable");
Kokkos::deep_copy(dst.allocation_kokkos_view(), value);
return dst.span_view();
}

/** Fill a borrowed chunk with a given value
* @param[out] dst the borrowed chunk in which to copy
* @param[in] value the value to fill `dst`
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class T>
auto fill(parallel_device_policy, ChunkDst&& dst, T const& value)
{
return fill(ddc::policies::parallel_host, std::forward<ChunkDst>(dst), value);
}

/** Fill a borrowed chunk with a given value
* @param[out] dst the borrowed chunk in which to copy
* @param[in] value the value to fill `dst`
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class T>
constexpr auto fill(serial_policy, ChunkDst&& dst, T const& value)
{
static_assert(is_borrowed_chunk_v<ChunkDst>);
static_assert(std::is_assignable_v<chunk_reference_t<ChunkDst>, T>, "Not assignable");
KOKKOS_ENSURES((Kokkos::SpaceAccessibility<
DDC_CURRENT_KOKKOS_SPACE,
typename std::remove_cv_t<std::remove_reference_t<ChunkDst>>::memory_space>::
accessible));
for_each(policies::serial, dst.domain(), [&](auto elem) { dst(elem) = value; });
return dst.span_view();
}

/** Fill a borrowed chunk with a given value
* @param[out] dst the borrowed chunk in which to copy
* @param[in] value the value to fill `dst`
* @return dst as a ChunkSpan
*/
template <class ChunkDst, class T>
constexpr auto fill(ChunkDst&& dst, T const& value)
{
return fill(policies::serial, std::forward<ChunkDst>(dst), value);
}

} // namespace ddc
Loading

0 comments on commit 15999ff

Please sign in to comment.