From 5a1ca0405df60843c0a38c39f6b524200b2f49b8 Mon Sep 17 00:00:00 2001 From: ccui Date: Fri, 7 Feb 2025 16:46:33 -0500 Subject: [PATCH 1/2] Add function to access field of a medium --- include/compute/fields/data_access.tpp | 357 ++------------------ include/compute/fields/simulation_field.hpp | 17 + 2 files changed, 45 insertions(+), 329 deletions(-) diff --git a/include/compute/fields/data_access.tpp b/include/compute/fields/data_access.tpp index 1fc1cae91..7ea80a67c 100644 --- a/include/compute/fields/data_access.tpp +++ b/include/compute/fields/data_access.tpp @@ -20,17 +20,7 @@ KOKKOS_FORCEINLINE_FUNCTION void impl_load_on_device(const int iglob, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < ViewType::components; ++icomp) { @@ -79,17 +69,7 @@ impl_load_on_device(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { @@ -152,17 +132,7 @@ impl_load_on_device(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { @@ -208,17 +178,7 @@ inline void impl_load_on_host(const int iglob, const WavefieldType &field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < ViewType::components; ++icomp) { @@ -266,17 +226,7 @@ inline void impl_load_on_host(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { @@ -338,17 +288,7 @@ inline void impl_load_on_host(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { @@ -501,17 +441,7 @@ impl_store_on_device(const int iglob, const ViewType &point_field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { @@ -547,18 +477,7 @@ impl_store_on_device(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -620,18 +539,7 @@ impl_store_on_device(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { Kokkos::Experimental::where(mask, point_field.displacement(icomp)) @@ -676,18 +584,7 @@ inline void impl_store_on_host(const int iglob, const ViewType &point_field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { curr_field.h_field(iglob, icomp) = point_field.displacement(icomp); @@ -722,18 +619,7 @@ inline void impl_store_on_host(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -795,18 +681,7 @@ inline void impl_store_on_host(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { Kokkos::Experimental::where(mask, point_field.displacement(icomp)) @@ -968,18 +843,7 @@ impl_add_on_device(const int iglob, const ViewType &point_field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { curr_field.field(iglob, icomp) += point_field.displacement(icomp); @@ -1014,18 +878,7 @@ impl_add_on_device(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -1087,18 +940,7 @@ impl_add_on_device(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { typename ViewType::simd::datatype lhs; @@ -1163,18 +1005,7 @@ inline void impl_add_on_host(const int iglob, const ViewType &point_field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { curr_field.h_field(iglob, icomp) += point_field.displacement(icomp); @@ -1209,18 +1040,7 @@ inline void impl_add_on_host(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -1281,18 +1101,7 @@ inline void impl_add_on_host(const specfem::point::simd_assembly_index &index, mask_type mask([&](std::size_t lane) { return index.mask(lane); }); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { typename ViewType::simd::datatype lhs; @@ -1472,18 +1281,7 @@ impl_atomic_add_on_device(const int iglob, const ViewType &point_field, constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { Kokkos::atomic_add(&curr_field.field(iglob, icomp), @@ -1522,18 +1320,7 @@ impl_atomic_add_on_device(const typename ViewType::simd::mask_type &mask, constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -1586,18 +1373,7 @@ inline void impl_atomic_add_on_host(const int iglob, const ViewType &point_field constexpr static bool StoreMassMatrix = ViewType::store_mass_matrix; constexpr static auto MediumType = ViewType::medium_tag; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int icomp = 0; icomp < ViewType::components; ++icomp) { if constexpr (StoreDisplacement) { Kokkos::atomic_add(&curr_field.h_field(iglob, icomp), @@ -1635,18 +1411,7 @@ inline void impl_atomic_add_on_host(const typename ViewType::simd::mask_type &ma constexpr static auto MediumType = ViewType::medium_tag; constexpr static int components = ViewType::components; - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); for (int lane = 0; lane < ViewType::simd::size(); ++lane) { if (!mask[lane]) { continue; @@ -1793,18 +1558,7 @@ impl_load_on_device(const MemberType &team, const int ispec, Kokkos::DefaultExecutionSpace>, "Calling team must have a device execution space"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, NGLL * NGLL), [&](const int &xz) { int iz, ix; @@ -1855,18 +1609,7 @@ inline void impl_load_on_host(const MemberType &team, const int ispec, std::is_same_v, "Calling team must have a host execution space"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, NGLL * NGLL), [&](const int &xz) { int iz, ix; @@ -1929,18 +1672,7 @@ impl_load_on_device(const MemberType &team, const IteratorType &iterator, typename ViewType::memory_space>::accessible, "Calling team must have access to the memory space of the view"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, iterator.chunk_size()), [&](const int &i) { const auto iterator_index = iterator(i); @@ -2006,18 +1738,7 @@ impl_load_on_device(const MemberType &team, const IteratorType &iterator, Kokkos::DefaultExecutionSpace>, "Calling team must have a device execution space"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, iterator.chunk_size()), [&](const int &i) { const auto iterator_index = iterator(i); @@ -2089,18 +1810,7 @@ inline void impl_load_on_host(const MemberType &team, const IteratorType &iterat typename ViewType::memory_space>::accessible, "Calling team must have access to the memory space of the view"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, iterator.chunk_size()), [&](const int &i) { const auto iterator_index = iterator(i); @@ -2164,18 +1874,7 @@ inline void impl_load_on_host(const MemberType &team, const IteratorType &iterat typename ViewType::memory_space>::accessible, "Calling team must have access to the memory space of the view"); - const auto &curr_field = - [&]() -> const specfem::compute::impl::field_impl< - specfem::dimension::type::dim2, MediumType> & { - if constexpr (MediumType == specfem::element::medium_tag::elastic) { - return field.elastic; - } else if constexpr (MediumType == specfem::element::medium_tag::acoustic) { - return field.acoustic; - } else { - static_assert("medium type not supported"); - } - }(); - + const auto &curr_field = field.template get_field(); Kokkos::parallel_for( Kokkos::TeamThreadRange(team, iterator.chunk_size()), [&](const int &i) { const auto iterator_index = iterator(i); diff --git a/include/compute/fields/simulation_field.hpp b/include/compute/fields/simulation_field.hpp index 70335c809..2fe7fb9b2 100644 --- a/include/compute/fields/simulation_field.hpp +++ b/include/compute/fields/simulation_field.hpp @@ -94,6 +94,23 @@ struct simulation_field { } } + /** + * @brief Returns the field for a given medium + * + */ + template + KOKKOS_INLINE_FUNCTION constexpr specfem::compute::impl::field_impl< + specfem::dimension::type::dim2, MediumTag> const & + get_field() const { + if constexpr (MediumTag == specfem::element::medium_tag::elastic) { + return elastic; + } else if constexpr (MediumTag == specfem::element::medium_tag::acoustic) { + return acoustic; + } else { + static_assert("medium type not supported"); + } + } + int nglob = 0; ///< Number of global degrees of freedom int nspec; ///< Number of spectral elements int ngllz; ///< Number of quadrature points in z direction From 0dee5ec9b085fc9641b71ebe2778495f958553fb Mon Sep 17 00:00:00 2001 From: Congyue Cui Date: Wed, 26 Feb 2025 14:40:46 -0500 Subject: [PATCH 2/2] Fix typo in impl_store_on_device. --- include/compute/fields/data_access.tpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/compute/fields/data_access.tpp b/include/compute/fields/data_access.tpp index 7ea80a67c..94f27af49 100644 --- a/include/compute/fields/data_access.tpp +++ b/include/compute/fields/data_access.tpp @@ -543,7 +543,7 @@ impl_store_on_device(const specfem::point::simd_assembly_index &index, if constexpr (StoreDisplacement) { for (int icomp = 0; icomp < components; ++icomp) { Kokkos::Experimental::where(mask, point_field.displacement(icomp)) - .copy_to(&curr_field.h_field(iglob, icomp), tag_type()); + .copy_to(&curr_field.field(iglob, icomp), tag_type()); } }