Skip to content

Commit

Permalink
Merge host and device access functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
icui committed Feb 26, 2025
1 parent 450c926 commit b09b9c2
Show file tree
Hide file tree
Showing 7 changed files with 501 additions and 1,397 deletions.
68 changes: 39 additions & 29 deletions include/compute/compute_mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ struct mesh {
*/

/**
* @brief Load quadrature data for a spectral element on the device
* @brief Load quadrature data for a spectral element on host or device
*
* @ingroup QuadratureDataAccess
*
Expand All @@ -239,11 +239,11 @@ struct mesh {
* @param quadrature Quadrature data
* @param element_quadrature Quadrature data for the element (output)
*/
template <typename MemberType, typename ViewType>
KOKKOS_FUNCTION void
load_on_device(const MemberType &team,
const specfem::compute::quadrature &quadrature,
ViewType &element_quadrature) {
template <bool on_device, typename MemberType, typename ViewType>
KOKKOS_INLINE_FUNCTION void
impl_load(const MemberType &team,
const specfem::compute::quadrature &quadrature,
ViewType &element_quadrature) {

constexpr bool store_hprime_gll = ViewType::store_hprime_gll;

Expand All @@ -260,15 +260,45 @@ load_on_device(const MemberType &team,
int ix, iz;
sub2ind(xz, NGLL, iz, ix);
if constexpr (store_hprime_gll) {
element_quadrature.hprime_gll(iz, ix) = quadrature.gll.hprime(iz, ix);
element_quadrature.hprime_gll(iz, ix) =
on_device ? quadrature.gll.hprime(iz, ix)
: quadrature.gll.h_hprime(iz, ix);
}
if constexpr (store_weight_times_hprime_gll) {
element_quadrature.hprime_wgll(ix, iz) =
quadrature.gll.hprime(iz, ix) * quadrature.gll.weights(iz);
on_device
? quadrature.gll.hprime(iz, ix) * quadrature.gll.weights(iz)
: quadrature.gll.h_hprime(iz, ix) *
quadrature.gll.h_weights(iz);
}
});
}

/**
* @defgroup QuadratureDataAccess
*
*/

/**
* @brief Load quadrature data for a spectral element on the device
*
* @ingroup QuadratureDataAccess
*
* @tparam MemberType Member type. Needs to be a Kokkos::TeamPolicy member type
* @tparam ViewType View type. Needs to be of @ref specfem::element::quadrature
* @param team Team member
* @param quadrature Quadrature data
* @param element_quadrature Quadrature data for the element (output)
*/
template <typename MemberType, typename ViewType>
KOKKOS_FUNCTION void
load_on_device(const MemberType &team,
const specfem::compute::quadrature &quadrature,
ViewType &element_quadrature) {

impl_load<true>(team, quadrature, element_quadrature);
}

/**
* @brief Load quadrature data for a spectral element on the host
*
Expand All @@ -284,27 +314,7 @@ template <typename MemberType, typename ViewType>
void load_on_host(const MemberType &team,
const specfem::compute::quadrature &quadrature,
ViewType &element_quadrature) {

constexpr bool store_hprime_gll = ViewType::store_hprime_gll;
constexpr bool store_weight_times_hprime_gll =
ViewType::store_weight_times_hprime_gll;
constexpr int NGLL = ViewType::ngll;

Kokkos::parallel_for(
Kokkos::TeamThreadRange(team, NGLL * NGLL), [=](const int &xz) {
int ix, iz;
sub2ind(xz, NGLL, iz, ix);
if constexpr (store_hprime_gll) {
element_quadrature.hprime_gll(iz, ix) =
quadrature.gll.h_hprime(iz, ix);
}
if constexpr (store_weight_times_hprime_gll) {
element_quadrature.hprime_wgll(ix, iz) =
quadrature.gll.h_hprime(iz, ix) * quadrature.gll.h_weights(iz);
}
});

return;
impl_load<false>(team, quadrature, element_quadrature);
}

} // namespace compute
Expand Down
109 changes: 32 additions & 77 deletions include/compute/compute_partial_derivatives.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ struct partial_derivatives {
*
*/

template <typename PointPartialDerivativesType,
template <bool on_device, typename PointPartialDerivativesType,
typename std::enable_if_t<
PointPartialDerivativesType::simd::using_simd, int> = 0>
KOKKOS_FORCEINLINE_FUNCTION void impl_load_on_device(
KOKKOS_FORCEINLINE_FUNCTION void impl_load(
const specfem::point::simd_index<PointPartialDerivativesType::dimension>
&index,
const specfem::compute::partial_derivatives &derivatives,
Expand All @@ -95,84 +95,33 @@ KOKKOS_FORCEINLINE_FUNCTION void impl_load_on_device(
mask_type mask([&](std::size_t lane) { return index.mask(lane); });

Kokkos::Experimental::where(mask, partial_derivatives.xix)
.copy_from(&derivatives.xix(ispec, iz, ix), tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.gammax)
.copy_from(&derivatives.gammax(ispec, iz, ix), tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.xiz)
.copy_from(&derivatives.xiz(ispec, iz, ix), tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.gammaz)
.copy_from(&derivatives.gammaz(ispec, iz, ix), tag_type());
if constexpr (StoreJacobian) {
Kokkos::Experimental::where(mask, partial_derivatives.jacobian)
.copy_from(&derivatives.jacobian(ispec, iz, ix), tag_type());
}
}

template <typename PointPartialDerivativesType,
typename std::enable_if_t<
!PointPartialDerivativesType::simd::using_simd, int> = 0>
KOKKOS_FORCEINLINE_FUNCTION void impl_load_on_device(
const specfem::point::index<PointPartialDerivativesType::dimension> &index,
const specfem::compute::partial_derivatives &derivatives,
PointPartialDerivativesType &partial_derivatives) {

const int ispec = index.ispec;
const int iz = index.iz;
const int ix = index.ix;

constexpr static bool StoreJacobian =
PointPartialDerivativesType::store_jacobian;

partial_derivatives.xix = derivatives.xix(ispec, iz, ix);
partial_derivatives.gammax = derivatives.gammax(ispec, iz, ix);
partial_derivatives.xiz = derivatives.xiz(ispec, iz, ix);
partial_derivatives.gammaz = derivatives.gammaz(ispec, iz, ix);
if constexpr (StoreJacobian) {
partial_derivatives.jacobian = derivatives.jacobian(ispec, iz, ix);
}
}

template <typename PointPointPartialDerivativesType,
typename std::enable_if_t<
PointPointPartialDerivativesType::simd::using_simd, int> = 0>
inline void
impl_load_on_host(const specfem::point::simd_index<
PointPointPartialDerivativesType::dimension> &index,
const specfem::compute::partial_derivatives &derivatives,
PointPointPartialDerivativesType &partial_derivatives) {

const int ispec = index.ispec;
const int nspec = derivatives.nspec;
const int iz = index.iz;
const int ix = index.ix;

constexpr static bool StoreJacobian =
PointPointPartialDerivativesType::store_jacobian;

using simd = typename PointPointPartialDerivativesType::simd;
using mask_type = typename simd::mask_type;
using tag_type = typename simd::tag_type;

mask_type mask([&](std::size_t lane) { return index.mask(lane); });

Kokkos::Experimental::where(mask, partial_derivatives.xix)
.copy_from(&derivatives.h_xix(ispec, iz, ix), tag_type());
.copy_from(on_device ? &derivatives.xix(ispec, iz, ix)
: &derivatives.h_xix(ispec, iz, ix),
tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.gammax)
.copy_from(&derivatives.h_gammax(ispec, iz, ix), tag_type());
.copy_from(on_device ? &derivatives.gammax(ispec, iz, ix)
: &derivatives.h_gammax(ispec, iz, ix),
tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.xiz)
.copy_from(&derivatives.h_xiz(ispec, iz, ix), tag_type());
.copy_from(on_device ? &derivatives.xiz(ispec, iz, ix)
: &derivatives.h_xiz(ispec, iz, ix),
tag_type());
Kokkos::Experimental::where(mask, partial_derivatives.gammaz)
.copy_from(&derivatives.h_gammaz(ispec, iz, ix), tag_type());
.copy_from(on_device ? &derivatives.gammaz(ispec, iz, ix)
: &derivatives.h_gammaz(ispec, iz, ix),
tag_type());
if constexpr (StoreJacobian) {
Kokkos::Experimental::where(mask, partial_derivatives.jacobian)
.copy_from(&derivatives.h_jacobian(ispec, iz, ix), tag_type());
.copy_from(on_device ? &derivatives.jacobian(ispec, iz, ix)
: &derivatives.h_jacobian(ispec, iz, ix),
tag_type());
}
}

template <typename PointPartialDerivativesType,
template <bool on_device, typename PointPartialDerivativesType,
typename std::enable_if_t<
!PointPartialDerivativesType::simd::using_simd, int> = 0>
inline void impl_load_on_host(
KOKKOS_FORCEINLINE_FUNCTION void impl_load(
const specfem::point::index<PointPartialDerivativesType::dimension> &index,
const specfem::compute::partial_derivatives &derivatives,
PointPartialDerivativesType &partial_derivatives) {
Expand All @@ -184,12 +133,18 @@ inline void impl_load_on_host(
constexpr static bool StoreJacobian =
PointPartialDerivativesType::store_jacobian;

partial_derivatives.xix = derivatives.h_xix(ispec, iz, ix);
partial_derivatives.gammax = derivatives.h_gammax(ispec, iz, ix);
partial_derivatives.xiz = derivatives.h_xiz(ispec, iz, ix);
partial_derivatives.gammaz = derivatives.h_gammaz(ispec, iz, ix);
partial_derivatives.xix = on_device ? derivatives.xix(ispec, iz, ix)
: derivatives.h_xix(ispec, iz, ix);
partial_derivatives.gammax = on_device ? derivatives.gammax(ispec, iz, ix)
: derivatives.h_gammax(ispec, iz, ix);
partial_derivatives.xiz = on_device ? derivatives.xiz(ispec, iz, ix)
: derivatives.h_xiz(ispec, iz, ix);
partial_derivatives.gammaz = on_device ? derivatives.gammaz(ispec, iz, ix)
: derivatives.h_gammaz(ispec, iz, ix);
if constexpr (StoreJacobian) {
partial_derivatives.jacobian = derivatives.h_jacobian(ispec, iz, ix);
partial_derivatives.jacobian = on_device
? derivatives.jacobian(ispec, iz, ix)
: derivatives.h_jacobian(ispec, iz, ix);
}
}

Expand Down Expand Up @@ -276,7 +231,7 @@ KOKKOS_FORCEINLINE_FUNCTION void
load_on_device(const IndexType &index,
const specfem::compute::partial_derivatives &derivatives,
PointPartialDerivativesType &partial_derivatives) {
impl_load_on_device(index, derivatives, partial_derivatives);
impl_load<true>(index, derivatives, partial_derivatives);
}

/**
Expand All @@ -302,7 +257,7 @@ inline void
load_on_host(const IndexType &index,
const specfem::compute::partial_derivatives &derivatives,
PointPartialDerivativesType &partial_derivatives) {
impl_load_on_host(index, derivatives, partial_derivatives);
impl_load<false>(index, derivatives, partial_derivatives);
}

/**
Expand Down
Loading

0 comments on commit b09b9c2

Please sign in to comment.