diff --git a/cpputil/usid/.clang-format b/cpputil/usid/.clang-format new file mode 100644 index 0000000..09cc831 --- /dev/null +++ b/cpputil/usid/.clang-format @@ -0,0 +1,60 @@ +--- +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -2 +ConstructorInitializerIndentWidth: 4 +#AlignEscapedNewlinesLeft: false +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AlwaysBreakTemplateDeclarations: true +AlwaysBreakBeforeMultilineStrings: false +BreakBeforeBinaryOperators: false +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BinPackParameters: false +BinPackArguments: false +ColumnLimit: 120 +ConstructorInitializerAllOnOneLineOrOnePerLine: false +DerivePointerAlignment: false +ExperimentalAutoDetectBinPacking: false +IndentCaseLabels: false +IndentWrappedFunctionNames: false +IndentFunctionDeclarationAfterType: false +MaxEmptyLinesToKeep: 1 +KeepEmptyLinesAtTheStartOfBlocks: true +NamespaceIndentation: All +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakString: 1000 +PenaltyBreakFirstLessLess: 120 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +SpacesBeforeTrailingComments: 1 +Cpp11BracedListStyle: true +Standard: Cpp11 +IndentWidth: 4 +TabWidth: 8 +UseTab: Never +BreakBeforeBraces: Attach +SpacesInParentheses: false +SpacesInAngles: false +SpaceInEmptyParentheses: false +SpacesInCStyleCastParentheses: false +SpacesInContainerLiterals: true +SpaceBeforeAssignmentOperators: true +ContinuationIndentWidth: 4 +CommentPragmas: '*' +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +SpaceBeforeParens: ControlStatements +DisableFormat: false +AlignAfterOpenBracket: false +AlignEscapedNewlinesLeft: true +... + diff --git a/cpputil/usid/CMakeLists.txt b/cpputil/usid/CMakeLists.txt new file mode 100644 index 0000000..ecf288e --- /dev/null +++ b/cpputil/usid/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.14.5) +project(usid LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) + +add_compile_options(-Wall) + +include(FetchContent) +FetchContent_Declare(GridTools + GIT_REPOSITORY https://github.com/GridTools/gridtools.git + GIT_TAG master +) +FetchContent_MakeAvailable(GridTools) + +add_library(usid_naive_helper INTERFACE) +add_library(usid::usid_naive_helper ALIAS usid_naive_helper) +target_include_directories(usid_naive_helper INTERFACE ${PROJECT_SOURCE_DIR}/include) +target_link_libraries(usid_naive_helper INTERFACE GridTools::gridtools) + +include(CheckLanguage) +check_language(CUDA) +if(CMAKE_CUDA_COMPILER) + enable_language(CUDA) + + add_library(usid_cuda_helper INTERFACE) + add_library(usid::usid_cuda_helper ALIAS usid_cuda_helper) + target_include_directories(usid_cuda_helper INTERFACE ${PROJECT_SOURCE_DIR}/include) + target_link_libraries(usid_cuda_helper INTERFACE GridTools::gridtools GridTools::stencil_gpu) +endif() + +set(BUILD_SHARED_LIBS OFF) + +FetchContent_Declare(GoogleTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.10.0 +) +FetchContent_MakeAvailable(GoogleTest) + +find_package(eckit REQUIRED) +find_package(Atlas REQUIRED) + +add_library(fvm_nabla_driver INTERFACE) +add_library(usid::fvm_nabla_driver ALIAS fvm_nabla_driver) +target_include_directories(fvm_nabla_driver INTERFACE ${PROJECT_SOURCE_DIR}/include) +target_link_libraries(fvm_nabla_driver INTERFACE atlas eckit GridTools::gridtools gtest) +if(CMAKE_CUDA_COMPILER) + target_link_libraries(fvm_nabla_driver INTERFACE GridTools::stencil_gpu) +endif() + + +if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) + include(CTest) + if(BUILD_TESTING) + add_subdirectory(tests) + endif() +endif() diff --git a/cpputil/usid/include/gridtools/usid/atlas.hpp b/cpputil/usid/include/gridtools/usid/atlas.hpp new file mode 100644 index 0000000..ca5e7c8 --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/atlas.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include + +namespace atlas::mesh { +template +auto make_storage_producer(MaxNeighbors max_neighbors, + Connectivity const &src) { + return [&src, max_neighbors](auto traits) { + return gridtools::storage::builder + .template type() + .dimensions(src.rows(), max_neighbors) + .initializer([&src](auto row, auto col) { return col < src.cols(row) ? src(row, col) : -1; }) + .build(); + }; +} +} // namespace atlas::mesh diff --git a/cpputil/usid/include/gridtools/usid/cuda_helpers.hpp b/cpputil/usid/include/gridtools/usid/cuda_helpers.hpp new file mode 100644 index 0000000..7861453 --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/cuda_helpers.hpp @@ -0,0 +1,52 @@ +#pragma once + +#ifndef __CUDACC__ +#error Tried to compile CUDA code with a regular C++ compiler. +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "dim.hpp" +#include "helpers.hpp" + +namespace gridtools::usid::cuda { +using traits_t = storage::gpu; + +inline auto make_allocator() { + return sid::device::make_cached_allocator(&cuda_util::cuda_malloc); +} + +template +__global__ void kernel(int h_size, Ptr ptr_holder, Strides strides, + Neighbors... neighbors) { + auto h = blockIdx.x * blockDim.x + threadIdx.x; + if (h >= h_size) + return; + Kernel()()( + sid::shifted(ptr_holder(), device::at_key(strides), h), strides, + tuple_util::device::make(neighbors.first(), neighbors.second)...); +} + +template +void call_kernel(Size size, Sid &&fields, Sids &&...neighbor_fields) { + int threads_per_block = 32; + int blocks = (size + threads_per_block - 1) / threads_per_block; + kernel<<>>( + size, sid::get_origin(fields), sid::get_strides(fields), + tuple_util::make( + sid::get_origin(neighbor_fields), + at_key(sid::get_strides(neighbor_fields)))...); + GT_CUDA_CHECK(cudaGetLastError()); +} + +template +__device__ decltype(auto) field(Ptr const &ptr) { + return *device::at_key(ptr); +} +} // namespace gridtools::usid::cuda diff --git a/cpputil/usid/include/gridtools/usid/dim.hpp b/cpputil/usid/include/gridtools/usid/dim.hpp new file mode 100644 index 0000000..48214f8 --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/dim.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace gridtools::usid::dim { +using horizontal = integral_constant; +using vertical = integral_constant; +using neighbor = integral_constant; +using sparse = integral_constant; + +using h = horizontal; +using k = vertical; +using n = neighbor; +using s = sparse; +} // namespace gridtools::usid::dim diff --git a/cpputil/usid/include/gridtools/usid/domain.hpp b/cpputil/usid/include/gridtools/usid/domain.hpp new file mode 100644 index 0000000..b709e93 --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/domain.hpp @@ -0,0 +1,10 @@ +#pragma once + +namespace gridtools::usid { +struct domain { + int vertex; + int edge; + int cell; + int k; +}; +} // namespace gridtools::usid diff --git a/cpputil/usid/include/gridtools/usid/helpers.hpp b/cpputil/usid/include/gridtools/usid/helpers.hpp new file mode 100644 index 0000000..2fcd5eb --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/helpers.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dim.hpp" +#include "domain.hpp" + +namespace gridtools::usid { + template + auto make_simple_tmp_storage(HSize h_size, KSize k_size, Alloc &alloc) { + return sid::make_contiguous(alloc, hymap::keys::values(h_size, k_size)); + } + + template + struct connectivity { + using max_neighbors_t = integral_constant; + using has_skip_values_t = std::bool_constant; + }; + + template + struct sparse_field { + using connectivity_t = Connectivity; + }; + + template + struct is_sparse_field : std::false_type {}; + + template + struct is_sparse_field, T>>> + : std::true_type {}; + + template + GT_FUNCTION T fold_neighbors(F f, Init init, G g, Ptr &&ptr, Strides &&strides, Neighbors &&neighbors) { + T acc = init(meta::lazy::id()); + sid::make_loop(typename Conncectivity::max_neighbors_t())([&](auto const &ptr, auto &&) { + auto i = *host_device::at_key(ptr); + if constexpr (Conncectivity::has_skip_values_t::value) + if (i < 0) + return; + acc = f(acc, g(ptr, sid::shifted(neighbors.first, neighbors.second, i))); + })(wstd::forward(ptr), strides); + return acc; + } + + template + GT_FUNCTION T sum_neighbors(F f, Args &&...args) { + return fold_neighbors([](auto x, auto y) { return x + y; }, + [](auto z) -> typename decltype(z)::type { return 0; }, + f, + wstd::forward(args)...); + } + + template + GT_FUNCTION T product_neighbors(F f, Args &&...args) { + return fold_neighbors([](auto x, auto y) { return x * y; }, + [](auto z) -> typename decltype(z)::type { return 1; }, + f, + wstd::forward(args)...); + } + + template + GT_FUNCTION T min_neighbors(F f, Args &&...args) { + return fold_neighbors([](auto x, auto y) { return x < y ? x : y; }, + [](auto z) -> typename decltype(z)::type { + constexpr auto res = std::numeric_limits::max(); + return res; + }, + f, + wstd::forward(args)...); + } + + template + GT_FUNCTION T max_neighbors(F f, Args &&...args) { + return fold_neighbors([](auto x, auto y) { return x > y ? x : y; }, + [](auto z) -> typename decltype(z)::type { + constexpr auto res = std::numeric_limits::min(); + return res; + }, + f, + wstd::forward(args)...); + } + + template + decltype(auto) composite_item(Tag, Val &&val) { + if constexpr (is_sparse_field::value) + return sid::rename_dimensions(std::forward(val)); + else + return std::forward(val); + } + + template + struct make_composite_f { + template + auto operator()(Vals &&...vals) const { + return sid::composite::make(composite_item(Tags(), std::forward(vals))...); + } + }; + + template + constexpr make_composite_f make_composite = {}; +} // namespace gridtools::usid diff --git a/cpputil/usid/include/gridtools/usid/naive_helpers.hpp b/cpputil/usid/include/gridtools/usid/naive_helpers.hpp new file mode 100644 index 0000000..c094bb4 --- /dev/null +++ b/cpputil/usid/include/gridtools/usid/naive_helpers.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "dim.hpp" +#include "helpers.hpp" + +namespace gridtools::usid::naive { +using traits_t = storage::cpu_ifirst; + +inline auto make_allocator() { + return sid::make_cached_allocator(&std::make_unique); +} + +template +void call_kernel(Size size, Sid &&fields, Sids &&...neighbor_fields) { + sid::make_loop(size)( + [params = std::make_tuple(std::make_pair( + sid::get_origin(neighbor_fields)(), + sid::get_stride(sid::get_strides(neighbor_fields)))...)]( + auto &ptr, auto const &strides) { + std::apply(Kernel()(), + std::tuple_cat(std::forward_as_tuple(ptr, strides), params)); + })(sid::get_origin(fields)(), sid::get_strides(fields)); +} + +template decltype(auto) field(Ptr const &ptr) { + return *at_key(ptr); +} +} // namespace gridtools::usid::naive diff --git a/cpputil/usid/include/tests/fvm_nabla_driver.hpp b/cpputil/usid/include/tests/fvm_nabla_driver.hpp new file mode 100644 index 0000000..752b42b --- /dev/null +++ b/cpputil/usid/include/tests/fvm_nabla_driver.hpp @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#ifdef __CUDACC__ +#include +namespace fvm_nabla_driver_impl_ { + using storage_traits_t = gridtools::storage::gpu; +} +#else +#include +namespace fvm_nabla_driver_impl_ { + using storage_traits_t = gridtools::storage::cpu_ifirst; +} +#endif + +namespace fvm_nabla_driver_impl_ { + using namespace gridtools::literals; + + inline constexpr auto storage_builder = gridtools::storage::builder; + + inline auto make_mesh() { + using namespace atlas; + using param_t = StructuredMeshGenerator::Parameters; + auto res = StructuredMeshGenerator(param_t("triangulate", true) | param_t("angle", -1.)).generate(Grid("O32")); + functionspace::EdgeColumns(res, option::levels(1) | option::halo(1)); + functionspace::NodeColumns(res, option::levels(1) | option::halo(1)); + mesh::actions::build_edges(res); + mesh::actions::build_node_to_edge_connectivity(res); + mesh::actions::build_median_dual_mesh(res); + return res; + } + + inline const double rpi = 2 * std::asin(1); + inline const double radius = 6371.22e+03; + inline const double deg2rad = 2 * rpi / 360; + inline constexpr int MXX = 0; + inline constexpr int MYY = 1; + inline constexpr auto edges_per_node = 7_c; + + inline auto make_vol(atlas::mesh::Nodes const &nodes) { + auto dual_volumes = atlas::array::make_view(nodes.field("dual_volumes")); + auto init = [&](int n) { return dual_volumes(n) * (std::pow(deg2rad, 2) * std::pow(radius, 2)); }; + return storage_builder.type().dimensions(nodes.size()).initializer(init).build(); + } + + inline auto make_sign(atlas::Mesh const &mesh) { + auto &&edges = mesh.edges(); + auto &&nodes = mesh.nodes(); + auto &&n2e = nodes.edge_connectivity(); + auto &&e2n = edges.node_connectivity(); + auto flags = atlas::array::make_view(edges.flags()); + auto is_pole_edge = [&](auto e) { + using topology_t = atlas::mesh::Nodes::Topology; + return topology_t::check(flags(e), topology_t::POLE); + }; + auto init = [&](int n, int, int e) -> double { + if (e >= n2e.cols(n)) + return 0; + auto ee = n2e(n, e); + return n == e2n(ee, 0) || is_pole_edge(ee) ? 1 : -1; + }; + return storage_builder.type() + .selector<1, 0, 1>() + .dimensions(nodes.size(), 1, edges_per_node) + .initializer(init) + .build(); + } + + inline auto make_S_MXX(atlas::mesh::Edges const &edges) { + auto dual_normals = atlas::array::make_view(edges.field("dual_normals")); + return storage_builder.type() + .dimensions(edges.size()) + .initializer([&](int i) { return dual_normals(i, MXX) * radius * deg2rad; }) + .build(); + } + + inline auto make_S_MYY(atlas::mesh::Edges const &edges) { + auto dual_normals = atlas::array::make_view(edges.field("dual_normals")); + return storage_builder.type() + .dimensions(edges.size()) + .initializer([&](int i) { return dual_normals(i, MYY) * radius * deg2rad; }) + .build(); + } + + // TODO ask Christian for a proper name for this input data + inline auto make_pp(atlas::mesh::Nodes const &nodes) { + static const double zh0 = 2000; + static const double zrad = 3 * rpi / 4 * radius; + static const double zeta = rpi / 16 * radius; + static const double zlatc = 0; + static const double zlonc = 3 * rpi / 2; + + auto lonlat = atlas::array::make_view(nodes.field("lonlat")); + // lonlatcr is in physical space and may differ from coords later + auto rlonlatcr = [&](int n, int i) { return lonlat(n, i) * deg2rad; }; + auto init = [&](int n) { + double zlon = rlonlatcr(n, MXX); + double rcosa = std::cos(rlonlatcr(n, MYY)); + double rsina = std::sin(rlonlatcr(n, MYY)); + double zdist = std::sin(zlatc) * rsina + std::cos(zlatc) * rcosa * std::cos(zlon - zlonc); + zdist = radius * std::acos(zdist); + return zdist < zrad + ? .5 * zh0 * (1 + std::cos(rpi * zdist / zrad)) * std::pow(std::cos(rpi * zdist / zeta), 2) + : 0; + }; + return storage_builder.type().dimensions(nodes.size()).initializer(init).build(); + } + + template + auto min_max(T const &field) { + double min = std::numeric_limits::max(); + double max = std::numeric_limits::min(); + auto view = field->const_host_view(); + auto lengths = field->lengths(); + for (int i = 0; i < (int)lengths[0]; ++i) + for (int k = 0; k < (int)lengths[1]; ++k) { + double x = view(i, k); + min = std::min(min, x); + max = std::max(max, x); + } + return std::make_tuple(min, max); + } + + template + void fvm_nabla_driver(Nabla nabla) { + constexpr auto k = 1_c; + + auto mesh = make_mesh(); + auto &&edges = mesh.edges(); + auto &&nodes = mesh.nodes(); + + // output + auto make_output = storage_builder.type().dimensions(nodes.size(), k); + auto pnabla_MXX = make_output(); + auto pnabla_MYY = make_output(); + + nabla({nodes.size(), edges.size(), mesh.cells().size(), k}, + make_storage_producer(edges_per_node, nodes.edge_connectivity()), + make_storage_producer(2_c, edges.node_connectivity()))(make_S_MXX(edges), + make_S_MYY(edges), + make_pp(nodes), + pnabla_MXX, + pnabla_MYY, + make_vol(nodes), + make_sign(mesh)); + + auto [x_min, x_max] = min_max(pnabla_MXX); + auto [y_min, y_max] = min_max(pnabla_MYY); + + EXPECT_DOUBLE_EQ(-3.5455427772566003E-003, x_min); + EXPECT_DOUBLE_EQ(3.5455427772565435E-003, x_max); + EXPECT_DOUBLE_EQ(-3.3540113705465301E-003, y_min); + EXPECT_DOUBLE_EQ(3.3540113705465301E-003, y_max); + } // +} // namespace fvm_nabla_driver_impl_ +using fvm_nabla_driver_impl_::fvm_nabla_driver; diff --git a/cpputil/usid/tests/CMakeLists.txt b/cpputil/usid/tests/CMakeLists.txt new file mode 100644 index 0000000..5f0c7df --- /dev/null +++ b/cpputil/usid/tests/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(nabla_naive nabla_naive.cpp) +target_link_libraries(nabla_naive PRIVATE usid_naive_helper fvm_nabla_driver gtest_main) +add_test(NAME nabla_naive COMMAND $) + +if(CMAKE_CUDA_COMPILER) + add_executable(nabla_cuda nabla_cuda.cu) + target_link_libraries(nabla_cuda PRIVATE usid_cuda_helper fvm_nabla_driver gtest_main) + gridtools_setup_target(nabla_cuda CUDA_ARCH sm_50) + target_compile_options(nabla_cuda PRIVATE "-std=c++17") + add_test(NAME nabla_cuda COMMAND $) +endif() diff --git a/cpputil/usid/tests/nabla_cuda.cu b/cpputil/usid/tests/nabla_cuda.cu new file mode 100644 index 0000000..a5784f9 --- /dev/null +++ b/cpputil/usid/tests/nabla_cuda.cu @@ -0,0 +1,5 @@ +#include "nabla_cuda.hpp" +#include +#include + +TEST(fvm, nabla_cuda) { fvm_nabla_driver(nabla); } diff --git a/cpputil/usid/tests/nabla_cuda.hpp b/cpputil/usid/tests/nabla_cuda.hpp new file mode 100644 index 0000000..3a57185 --- /dev/null +++ b/cpputil/usid/tests/nabla_cuda.hpp @@ -0,0 +1,72 @@ +#pragma once +#include +namespace gridtools::usid::cuda::nabla_impl_ { + struct v2e_tag; + struct e2v_tag; + struct S_MXX_tag; + struct S_MYY_tag; + struct zavgS_MXX_tag; + struct zavgS_MYY_tag; + struct pnabla_MXX_tag; + struct pnabla_MYY_tag; + struct vol_tag; + struct sign_tag; + struct pp_tag; + struct kernel_0 { + GT_FUNCTION auto operator()() const { + return [](auto &&ptr, auto &&strides, auto &&neighbors) { + auto zavg = 0.5 * sum_neighbors( + [](auto &&, auto &&n) { return field(n); }, ptr, strides, neighbors); + field(ptr) = field(ptr) * zavg; + field(ptr) = field(ptr) * zavg; + }; + } + }; + struct kernel_1 { + GT_FUNCTION auto operator()() const { + return [](auto &&ptr, auto &&strides, auto &&neighbors) { + field(ptr) = sum_neighbors( + [](auto &&p, auto &&n) { return field(n) * field(p); }, + ptr, + strides, + neighbors); + field(ptr) = sum_neighbors( + [](auto &&p, auto &&n) { return field(n) * field(p); }, + ptr, + strides, + neighbors); + field(ptr) = field(ptr) / field(ptr); + field(ptr) = field(ptr) / field(ptr); + }; + } + }; + inline constexpr auto nabla = [](domain d, auto &&v2e, auto &&e2v) { + static_assert(is_sid()); + static_assert(is_sid()); + return + [d = std::move(d), + v2e = sid::rename_dimensions(std::forward(v2e)(traits_t())), + e2v = sid::rename_dimensions(std::forward(e2v)(traits_t()))]( + auto &&S_MXX, auto &&S_MYY, auto &&pp, auto &&pnabla_MXX, auto &&pnabla_MYY, auto &&vol, auto &&sign) { + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + auto alloc = make_allocator(); + auto zavgS_MXX = make_simple_tmp_storage(d.edge, d.k, alloc); + auto zavgS_MYY = make_simple_tmp_storage(d.edge, d.k, alloc); + call_kernel(d.edge, + sid::composite::make( + e2v, S_MXX, S_MYY, zavgS_MXX, zavgS_MYY), + sid::composite::make(pp)); + call_kernel(d.vertex, + sid::composite::make( + v2e, pnabla_MXX, pnabla_MYY, sid::rename_dimensions(sign), vol), + sid::composite::make(zavgS_MXX, zavgS_MYY)); + }; + }; +} // namespace gridtools::usid::cuda::nabla_impl_ +using gridtools::usid::cuda::nabla_impl_::nabla; diff --git a/cpputil/usid/tests/nabla_naive.cpp b/cpputil/usid/tests/nabla_naive.cpp new file mode 100644 index 0000000..a3cca65 --- /dev/null +++ b/cpputil/usid/tests/nabla_naive.cpp @@ -0,0 +1,5 @@ +#include "nabla_naive.hpp" +#include +#include + +TEST(fvm, nabla_naive) { fvm_nabla_driver(nabla); } diff --git a/cpputil/usid/tests/nabla_naive.hpp b/cpputil/usid/tests/nabla_naive.hpp new file mode 100644 index 0000000..8e4e115 --- /dev/null +++ b/cpputil/usid/tests/nabla_naive.hpp @@ -0,0 +1,66 @@ +#pragma once +#include +namespace gridtools::usid::naive::nabla_impl_ { + struct v2e_tag; + struct e2v_tag; + struct S_MXX_tag; + struct S_MYY_tag; + struct zavgS_MXX_tag; + struct zavgS_MYY_tag; + struct pnabla_MXX_tag; + struct pnabla_MYY_tag; + struct vol_tag; + struct sign_tag; + struct pp_tag; + struct kernel_0 { + GT_FUNCTION auto operator()() const { + return [](auto &&e, auto &&strides, auto &&v) { + auto zavg = 0.5 * sum_neighbors( + [](auto &&e, auto &&v) { return field(v); }, e, strides, v); + field(e) = field(e) * zavg; + field(e) = field(e) * zavg; + }; + } + }; + struct kernel_1 { + GT_FUNCTION auto operator()() const { + return [](auto &&v, auto &&strides, auto &&e) { + field(v) = sum_neighbors( + [](auto &&v, auto &&e) { return field(e) * field(v); }, v, strides, e); + field(v) = sum_neighbors( + [](auto &&v, auto &&e) { return field(e) * field(v); }, v, strides, e); + field(v) = field(v) / field(v); + field(v) = field(v) / field(v); + }; + } + }; + inline constexpr auto nabla = [](domain d, auto &&v2e, auto &&e2v) { + static_assert(is_sid()); + static_assert(is_sid()); + return + [d = std::move(d), + v2e = sid::rename_dimensions(std::forward(v2e)(traits_t())), + e2v = sid::rename_dimensions(std::forward(e2v)(traits_t()))]( + auto &&S_MXX, auto &&S_MYY, auto &&pp, auto &&pnabla_MXX, auto &&pnabla_MYY, auto &&vol, auto &&sign) { + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + static_assert(is_sid()); + auto alloc = make_allocator(); + auto zavgS_MXX = make_simple_tmp_storage(d.edge, d.k, alloc); + auto zavgS_MYY = make_simple_tmp_storage(d.edge, d.k, alloc); + call_kernel(d.edge, + sid::composite::make( + e2v, S_MXX, S_MYY, zavgS_MXX, zavgS_MYY), + sid::composite::make(pp)); + call_kernel(d.vertex, + sid::composite::make( + v2e, pnabla_MXX, pnabla_MYY, sid::rename_dimensions(sign), vol), + sid::composite::make(zavgS_MXX, zavgS_MYY)); + }; + }; +} // namespace gridtools::usid::naive::nabla_impl_ +using gridtools::usid::naive::nabla_impl_::nabla; diff --git a/setup.cfg b/setup.cfg index 8d13861..de4c799 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ install_requires = black>=19.10b0 boltons>=20.0 devtools>=0.5 + toolz>=0.11.1 jinja2>=2.10 lark-parser>=0.8 mako>=1.1 @@ -120,7 +121,7 @@ lines_after_imports = 2 default_section = THIRDPARTY sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER known_first_party = eve,gtc -known_third_party = atlas4py,black,boltons,cppimport,devtools,fvm_nabla_wrapper,jinja2,mako,networkx,numpy,packaging,pydantic,pytest,setuptools,sphinx_material,typing_inspect,xxhash +known_third_party = atlas4py,black,boltons,cppimport,devtools,fvm_nabla_wrapper,jinja2,mako,networkx,numpy,packaging,pydantic,pytest,setuptools,sphinx_material,toolz,typing_inspect,xxhash #-- mypy -- diff --git a/src/gt_frontend/built_in_types.py b/src/gt_frontend/built_in_types.py index 784855b..8fe5419 100644 --- a/src/gt_frontend/built_in_types.py +++ b/src/gt_frontend/built_in_types.py @@ -76,6 +76,10 @@ class Field(BuiltInType): pass +class SparseField(BuiltInType): + pass + + class TemporaryField(BuiltInType): # TODO(tehrengruber): make this a subtype of Field pass @@ -90,3 +94,7 @@ class Local(BuiltInType): """ pass + + +class Connectivity(BuiltInType): + pass diff --git a/src/gt_frontend/gtscript.py b/src/gt_frontend/gtscript.py index 1f84cce..1e9bd1a 100644 --- a/src/gt_frontend/gtscript.py +++ b/src/gt_frontend/gtscript.py @@ -18,7 +18,7 @@ import gtc.common as common -from .built_in_types import Field, Local, Location, Mesh, TemporaryField +from .built_in_types import Connectivity, Field, Local, Location, Mesh, SparseField, TemporaryField # built-in symbols @@ -47,6 +47,8 @@ "Vertex", "Edge", "Cell", + "Connectivity", + "SparseField", ] diff --git a/src/gtc/unstructured/frontend.py b/src/gtc/unstructured/frontend.py new file mode 100644 index 0000000..9ad188b --- /dev/null +++ b/src/gtc/unstructured/frontend.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from toolz.functoolz import compose + +from gtc.unstructured import gtir_to_usid2, py_to_gtir2, usid2_codegen +from gtc.unstructured.gtir_passes2 import merge_stencils + + +def _impl(codegen): + return compose( + codegen, gtir_to_usid2.transform, merge_stencils.transform, py_to_gtir2.transform + ) + + +naive = _impl(usid2_codegen.naive) +gpu = _impl(usid2_codegen.gpu) diff --git a/src/gtc/unstructured/gtir2.py b/src/gtc/unstructured/gtir2.py new file mode 100644 index 0000000..9917760 --- /dev/null +++ b/src/gtc/unstructured/gtir2.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import enum +from typing import Any, Union + +from eve import Bool, Int, Node, Str, StrEnum +from eve.typingx import FrozenList +from gtc import common + + +class Connectivity(Node): + name: Str + primary: common.LocationType + secondary: common.LocationType + max_neighbors: Int + has_skip_values: Bool + + +class Field(Node): + name: Str + location_type: common.LocationType + dtype: common.DataType + + +class SparseField(Node): + name: Str + connectivity: Str + dtype: common.DataType + + +class Expr(Node): + pass + + +class FieldAccess(Expr): + name: Str + location: Str + + +class SparseFieldAccess(Expr): + name: Str + primary: Str + secondary: Str + + +class BinaryOp(Expr): + op: common.BinaryOperator + left: Expr + right: Expr + + +class Literal(Expr): + value: Any + dtype: common.DataType + + +@enum.unique +class ReduceOperator(StrEnum): + """Reduction operator identifier.""" + + SUM = "sum" + PRODUCT = "product" + MIN = "min" + MAX = "max" + + +class SecondaryLocation(Node): + name: Str + connectivity: Str + primary: Str + + +class NeighborReduce(Expr): + op: ReduceOperator + dtype: common.DataType + location: SecondaryLocation + body: Expr + + +class Assign(Node): + left: FieldAccess + right: Expr + + +class PrimaryLocation(Node): + name: Str + location_type: common.LocationType + + +class Stencil(Node): + loop_order: common.LoopOrder + location: PrimaryLocation + body: FrozenList[Assign] + + +class Computation(Node): + name: Str + connectivities: FrozenList[Connectivity] + args: FrozenList[Union[Field, SparseField]] + temporaries: FrozenList[Field] + stencils: FrozenList[Stencil] diff --git a/src/gtc/unstructured/gtir_passes2/merge_stencils.py b/src/gtc/unstructured/gtir_passes2/merge_stencils.py new file mode 100644 index 0000000..6b0e4d5 --- /dev/null +++ b/src/gtc/unstructured/gtir_passes2/merge_stencils.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools + +import eve +from gtc.unstructured import gtir2 +from gtc.unstructured.gtir_passes2 import rename_location + + +def _merge(lhs, rhs): + t = rename_location.transform(rhs.location.name, lhs.location.name) + return gtir2.Stencil( + loop_order=lhs.loop_order, + location=lhs.location, + body=lhs.body + tuple(t(e) for e in rhs.body), + ) + + +def _folder(body, cur): + if len(body) == 0: + return (cur,) + last = body[-1] + return ( + body[0:-1] + (_merge(last, cur),) + if last.loop_order == cur.loop_order + and last.location.location_type == cur.location.location_type + else body + (cur,) + ) + + +class _Visitor(eve.NodeVisitor): + def visit_Computation(self, src: gtir2.Computation): + return gtir2.Computation( + name=src.name, + connectivities=src.connectivities, + args=src.args, + temporaries=src.temporaries, + stencils=functools.reduce(_folder, src.stencils, ()), + ) + + +transform = _Visitor().visit diff --git a/src/gtc/unstructured/gtir_passes2/rename_location.py b/src/gtc/unstructured/gtir_passes2/rename_location.py new file mode 100644 index 0000000..10e1ac8 --- /dev/null +++ b/src/gtc/unstructured/gtir_passes2/rename_location.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import eve +from gtc.unstructured import gtir2 + + +class _Translator(eve.NodeTranslator): + def visit_FieldAccess(self, src: gtir2.FieldAccess, renamer): + return gtir2.FieldAccess(name=src.name, location=renamer(src.location)) + + def visit_SparseFieldAccess(self, src: gtir2.SparseFieldAccess, renamer): + return gtir2.SparseFieldAccess( + name=src.name, primary=renamer(src.primary), secondary=renamer(src.secondary) + ) + + def visit_SecondaryLocation(self, src: gtir2.SecondaryLocation, renamer): + return gtir2.SecondaryLocation( + name=src.name, connectivity=src.connectivity, primary=renamer(src.primary), + ) + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce, **kwarg): + return gtir2.NeighborReduce( + op=src.op, + dtype=src.dtype, + location=self.visit(src.location, **kwarg), + body=self.visit(src.body, **kwarg), + ) + + def visit_BinaryOp(self, src: gtir2.BinaryOp, **kwarg): + return gtir2.BinaryOp( + op=src.op, left=self.visit(src.left, **kwarg), right=self.visit(src.right, **kwarg) + ) + + def visit_Assign(self, src: gtir2.Assign, **kwarg): + return gtir2.Assign( + left=self.visit(src.left, **kwarg), right=self.visit(src.right, **kwarg) + ) + + +def transform(old, new): + return ( + lambda x: x + if old == new + else lambda src: _Translator().visit(src, lambda x: new if x == old else x) + ) diff --git a/src/gtc/unstructured/gtir_to_usid2.py b/src/gtc/unstructured/gtir_to_usid2.py new file mode 100644 index 0000000..42ebece --- /dev/null +++ b/src/gtc/unstructured/gtir_to_usid2.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools +import itertools + +from toolz.functoolz import compose + +import eve +from gtc import common +from gtc.unstructured import gtir2, usid2 + + +_C_TYPES = { + common.DataType.FLOAT64: "double", + common.DataType.FLOAT32: "float", + common.DataType.INT32: "::std::int32_t", + common.DataType.UINT32: "::std::uint32_t", + common.DataType.BOOLEAN: "bool", +} + +_LITERAL_CONVERTERS = { + common.DataType.FLOAT64: compose(str, float), + common.DataType.FLOAT32: lambda x: str(float(x)) + "f", + common.DataType.INT32: compose(str, int), + common.DataType.UINT32: lambda x: str(int(x)) + "u", + common.DataType.BOOLEAN: lambda x: str(bool(x)).lower(), +} + + +def _loc2str(x: common.LocationType): + return x.name.lower() + + +class _PrimaryCompositeExtractor(eve.NodeVisitor): + def visit_Literal(self, src: gtir2.Literal, **kwargs): + return set() + + def visit_FieldAccess(self, src: gtir2.FieldAccess, primary): + return {src.name} if src.location == primary else set() + + def visit_SparseFieldAccess(self, src: gtir2.SparseFieldAccess, primary): + assert src.primary == primary + return {src.name} + + def visit_BinaryOp(self, src: gtir2.BinaryOp, **kwargs): + return self.visit(src.left, **kwargs) | self.visit(src.right, **kwargs) + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce, primary): + assert src.location.primary == primary + return {src.location.connectivity} | self.visit(src.body, primary=primary) + + def visit_Assign(self, src: gtir2.Assign, **kwargs): + return self.visit(src.left, **kwargs) | self.visit(src.right, **kwargs) + + def visit_Stencil(self, src: gtir2.Stencil): + return functools.reduce( + lambda acc, e: acc | self.visit(e, primary=src.location.name), src.body, set() + ) + + +_extract_primary_composite = _PrimaryCompositeExtractor().visit + + +class _SecondaryCompositeExtractor(eve.NodeVisitor): + def visit_SparseFieldAccess(self, src: gtir2.SparseFieldAccess, **kwargs): + return set() + + def visit_Literal(self, src: gtir2.Literal, **kwargs): + return set() + + def visit_FieldAccess(self, src: gtir2.FieldAccess, secondary, **kwargs): + return {src.name} if src.location == secondary else {} + + def visit_BinaryOp(self, src: gtir2.BinaryOp, **kwargs): + return self.visit(src.left, **kwargs) | self.visit(src.right, **kwargs) + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce): + return {src.location.connectivity: self.visit(src.body, secondary=src.location.name)} + + +_extract_secondary_composite = _SecondaryCompositeExtractor().visit + + +def _merge_dicts_of_sets(*srcs): + res = {} + for k in set(itertools.chain(*srcs)): + res[k] = set() + for src in srcs: + if k in src: + res[k] = res[k] | src[k] + return res + + +class _SecondaryCompositesExtractor(eve.NodeVisitor): + def visit_FieldAccess(self, src: gtir2.FieldAccess): + return {} + + def visit_Literal(self, src: gtir2.Literal): + return {} + + def visit_Assign(self, src: gtir2.Assign): + return self.visit(src.right) + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce): + return _extract_secondary_composite(src) + + def visit_BinaryOp(self, src: gtir2.BinaryOp): + return _merge_dicts_of_sets(self.visit(src.left), self.visit(src.right)) + + def visit_Stencil(self, src: gtir2.Stencil): + return _merge_dicts_of_sets(*(self.visit(e) for e in src.body)).items() + + +_extract_secondary_composites = _SecondaryCompositesExtractor().visit + + +class _Visitor(eve.NodeVisitor): + def visit_FieldAccess(self, src: gtir2.FieldAccess): + return usid2.FieldAccess(name=src.name, location=src.location) + + def visit_SparseFieldAccess(self, src: gtir2.SparseFieldAccess): + return usid2.FieldAccess(name=src.name, location=src.primary) + + def visit_Literal(self, src: gtir2.Literal): + return usid2.Literal(value=_LITERAL_CONVERTERS[src.dtype](src.value)) + + def visit_BinaryOp(self, src: gtir2.BinaryOp): + return usid2.BinaryOp(op=src.op, left=self.visit(src.left), right=self.visit(src.right)) + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce): + return usid2.NeighborReduce( + op=src.op, + dtype=_C_TYPES[src.dtype], + connectivity=src.location.connectivity, + primary=src.location.primary, + secondary=src.location.name, + body=self.visit(src.body), + ) + + def visit_Assign(self, src: gtir2.Assign): + return usid2.Assign(left=self.visit(src.left), right=self.visit(src.right)) + + def visit_Stencil(self, src: gtir2.Stencil): + return usid2.Kernel( + location_type=_loc2str(src.location.location_type), + primary=usid2.Composite(name=src.location.name, items=_extract_primary_composite(src)), + secondaries=tuple( + usid2.Composite(name=name, items=items) + for name, items in _extract_secondary_composites(src) + ), + body=tuple(self.visit(e) for e in src.body), + ) + + def visit_Connectivity(self, src: gtir2.Connectivity): + return usid2.Connectivity( + name=src.name, max_neighbors=src.max_neighbors, has_skip_values=src.has_skip_values + ) + + def visit_Field(self, src: gtir2.Field): + return usid2.Field(name=src.name) + + def visit_SparseField(self, src: gtir2.SparseField): + return usid2.SparseField(name=src.name, connectivity=src.connectivity) + + def visit_Computation(self, src: gtir2.Computation): + return usid2.Computation( + name=src.name, + connectivities=tuple(self.visit(e) for e in src.connectivities), + args=tuple(self.visit(e) for e in src.args), + temporaries=tuple( + usid2.Temporary( + name=e.name, location_type=_loc2str(e.location_type), dtype=_C_TYPES[e.dtype] + ) + for e in src.temporaries + ), + kernels=tuple(self.visit(e) for e in src.stencils), + ) + + +transform = _Visitor().visit diff --git a/src/gtc/unstructured/gtscript_ast2.py b/src/gtc/unstructured/gtscript_ast2.py new file mode 100644 index 0000000..e88bf67 --- /dev/null +++ b/src/gtc/unstructured/gtscript_ast2.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Union + +from eve import Node, Str +from eve.typingx import FrozenList + + +class Stmt(Node): + pass + + +class Expr(Node): + pass + + +class Name(Expr): + name: Str + + +class Subscript(Expr): + value: Name + index: Union[Name, FrozenList[Name]] + + +class Call(Expr): + func: Name + args: FrozenList[Expr] + + +class Comprehension(Node): + target: Union[Name, FrozenList[Name]] + iter_: Expr + + +class GeneratorExp(Expr): + elt: Expr + generators: FrozenList[Comprehension] + + +class BinOp(Expr): + left: Expr + op: Str + right: Expr + + +class Constant(Expr): + value: Any + + +class Assign(Stmt): + targets: FrozenList[Expr] + value: Expr + + +class WithItem(Node): + expr: Expr + var: Optional[Name] + + +class With(Stmt): + items: FrozenList[WithItem] + body: FrozenList[Stmt] + + +class Function(Node): + name: Str + body: FrozenList[Stmt] diff --git a/src/gtc/unstructured/py_to_gtir2.py b/src/gtc/unstructured/py_to_gtir2.py new file mode 100644 index 0000000..33c7152 --- /dev/null +++ b/src/gtc/unstructured/py_to_gtir2.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import collections +import functools +import inspect +import sys + +from gt_frontend import built_in_types + +import eve +from gtc import common +from gtc.unstructured import gtir2, gtscript_ast2, py_to_gtscript2 + + +def _is_conncetivity_annotation(src): + return hasattr(src, "__supertype__") and issubclass( + src.__supertype__, built_in_types.Connectivity + ) + + +def _is_field_annotation(src): + return issubclass(src, built_in_types.Field) + + +def _is_sparse_field_annotation(src): + return issubclass(src, built_in_types.SparseField) + + +def _is_legit_annotation(src): + return ( + _is_conncetivity_annotation(src) + or _is_field_annotation(src) + or _is_sparse_field_annotation(src) + ) + + +def _extract_params(fun): + res = tuple(p for _, p in inspect.signature(fun).parameters.items()) + for p in res: + if not _is_legit_annotation(p.annotation): + raise RuntimeError(f"unexpected type annotation: {p}") + return res + + +def _extract_connectivity(param): + args = param.annotation.__supertype__.args + assert len(args) > 3 + return gtir2.Connectivity( + name=param.name, + primary=args[0], + secondary=args[1], + max_neighbors=args[2], + has_skip_values=args[3] if len(args) > 3 else True, + ) + + +def _extract_arg(param, connectivity_type_to_name): + args = param.annotation.args + if _is_field_annotation(param.annotation): + assert len(args) == 2 + return gtir2.Field(name=param.name, location_type=args[0], dtype=args[1]) + else: + assert len(args) == 2 + return gtir2.SparseField( + name=param.name, connectivity=connectivity_type_to_name[args[0]], dtype=args[1] + ) + + +_SORTED_TYPES = ( + common.DataType.INVALID, + common.DataType.AUTO, + common.DataType.FLOAT64, + common.DataType.FLOAT32, + common.DataType.INT32, + common.DataType.UINT32, + common.DataType.BOOLEAN, +) + + +def _common_type(lhs: common.DataType, rhs: common.DataType) -> common.DataType: + for t in _SORTED_TYPES: + if lhs == t or rhs == t: + return t + raise RuntimeError(f"unsupported types {lhs}, {rhs}") + + +class _DeduceExprTypeVisotor(eve.NodeVisitor): + def visit_Literal(self, src: gtir2.Literal, tbl): + return src.dtype + + def visit_NeighborReduce(self, src: gtir2.NeighborReduce, tbl): + return src.dtype + + def visit_FieldAccess(self, src: gtir2.FieldAccess, tbl): + return tbl[src.name].dtype + + def visit_SparseFieldAccess(self, src: gtir2.SparseFieldAccess, tbl): + return tbl[src.name].dtype + + def visit_BinaryOp(self, src: gtir2.BinaryOp, tbl): + return _common_type(self.visit(src.left, tbl=tbl), self.visit(src.right, tbl=tbl)) + + +def _deduce_expr_type(src: gtir2.Expr, tbl): + return _DeduceExprTypeVisotor().visit(src, tbl=tbl) + + +_PY_TYPE_TO_DATA_TYPE = { + int: common.DataType.INT32, + float: common.DataType.FLOAT64 if sys.float_info.mant_dig >= 53 else common.DataType.FLOAT32, +} + + +class _Visitor(eve.NodeVisitor): + def visit_Constant(self, src: gtscript_ast2.Constant, **kwargs): + return gtir2.Literal(value=src.value, dtype=_PY_TYPE_TO_DATA_TYPE[type(src.value)]) + + def visit_Call(self, src: gtscript_ast2.Call, tbl, location): + assert isinstance(location, gtir2.PrimaryLocation) + assert len(src.args) == 1 + generator_exp = src.args[0] + assert isinstance(generator_exp, gtscript_ast2.GeneratorExp) + assert len(generator_exp.generators) == 1 + generator = generator_exp.generators[0] + target = generator.target + assert isinstance(target, gtscript_ast2.Name) + secondary = gtir2.SecondaryLocation( + name=target.name, + connectivity=self.visit(generator.iter_, tbl=tbl, location=location), + primary=location.name, + ) + body = self.visit(generator_exp.elt, tbl=tbl, location=secondary) + return gtir2.NeighborReduce( + op=gtir2.ReduceOperator(src.func.name), + location=secondary, + dtype=_deduce_expr_type(body, tbl), + body=body, + ) + + def visit_Name(self, src: gtscript_ast2.Name, tbl, location, is_target=False): + if src.name in tbl: + field = tbl[src.name] + if isinstance(field, gtir2.Field): + if isinstance(location, gtir2.PrimaryLocation): + assert field.location_type == location.location_type + loc = location.name + else: + assert isinstance(location, gtir2.SecondaryLocation) + if field.location_type == tbl[location.connectivity].primary: + loc = location.primary + elif field.location_type == tbl[location.connectivity].secondary: + loc = location.name + else: + raise RuntimeError(f"invalid field access {src}") + return gtir2.FieldAccess(name=field.name, location=loc) + elif isinstance(field, gtir2.SparseField): + assert isinstance(location, gtir2.SecondaryLocation) + assert field.connectivity == location.connectivity + return gtir2.SparseFieldAccess( + name=field.name, primary=location.primary, secondary=location.name + ) + elif isinstance(field, gtir2.Connectivity): + assert isinstance(location, gtir2.PrimaryLocation) + assert field.primary == location.location_type + return field.name + else: + raise RuntimeError(f"invalid access {src} to {field}") + elif is_target: + assert isinstance(location, gtir2.PrimaryLocation) + return gtir2.FieldAccess(name=src.name, location=location.name) + else: + raise RuntimeError(f"unbound name {src}") + + def visit_Subscript(self, src: gtscript_ast2.Subscript, location, **kwargs): + res = self.visit(src.value, location=location, **kwargs) + if isinstance(res, gtir2.FieldAccess): + assert isinstance(src.index, gtscript_ast2.Name) + assert src.index.name == res.location + elif isinstance(res, gtir2.SparseFieldAccess): + assert isinstance(src.index, collections.Sequence) + assert len(src.index) == 2 + assert src.index[0].name == res.primary + assert src.index[1].name == res.secondary + elif isinstance(res, str): + assert src.index.name == location.name + return res + + def visit_BinOp(self, src: gtscript_ast2.BinOp, **kwargs): + return gtir2.BinaryOp( + op=common.BinaryOperator(src.op), + left=self.visit(src.left, **kwargs), + right=self.visit(src.right, **kwargs), + ) + + def visit_Assign(self, src: gtscript_ast2.Assign, tbl, primary): + assert len(src.targets) == 1 + target = src.targets[0] + left = self.visit(target, tbl=tbl, location=primary, is_target=True) + right = self.visit(src.value, tbl=tbl, location=primary) + return ( + gtir2.Assign(left=left, right=right), + () + if left.name in tbl + else ( + gtir2.Field( + name=left.name, + location_type=primary.location_type, + dtype=_deduce_expr_type(right, tbl), + ), + ), + ) + + def visit_With(self, src: gtscript_ast2.With, tbl): + loop_order = None + location = None + for item in src.items: + assert isinstance(item.expr, gtscript_ast2.Call) + f = item.expr + if f.func.name == "computation": + assert len(f.args) == 1 + assert isinstance(f.args[0], gtscript_ast2.Name) + assert loop_order is None + loop_order = common.LoopOrder[f.args[0].name] + elif f.func.name == "location": + assert len(f.args) == 1 + assert isinstance(f.args[0], gtscript_ast2.Name) + assert location is None + assert item.var is not None + location = gtir2.PrimaryLocation( + name=item.var.name, location_type=common.LocationType[f.args[0].name] + ) + elif f.func.name == "interval": + pass + else: + raise RuntimeError(f"unexpected withitem: {item}") + assert all(isinstance(s, gtscript_ast2.Assign) for s in src.body) + + def folder(acc, stmt): + assign, temporaries = self.visit( + stmt, tbl={**tbl, **{t.name: t for t in acc[1]}}, primary=location + ) + return acc[0] + (assign,), acc[1] + temporaries + + body, temporaries = functools.reduce(folder, src.body, ((), ())) + return gtir2.Stencil(loop_order=loop_order, location=location, body=body), temporaries + + +def _transform(src: gtscript_ast2.Function, fun_params): + connectivity_params = tuple(p for p in fun_params if _is_conncetivity_annotation(p.annotation)) + connectivity_type_to_name = dict((p.annotation, p.name) for p in connectivity_params) + if len(connectivity_params) > len(connectivity_type_to_name): + raise RuntimeError( + "the types of the conncetivities within computation should be all different" + ) + connectivities = tuple(_extract_connectivity(p) for p in connectivity_params) + args = tuple( + _extract_arg(p, connectivity_type_to_name) + for p in fun_params + if not _is_conncetivity_annotation(p.annotation) + ) + assert all(isinstance(s, gtscript_ast2.With) for s in src.body) + tbl = {e.name: e for e in connectivities + args} + + def folder(acc, stmt): + stencil, temporaries = _Visitor().visit(stmt, tbl={**tbl, **{t.name: t for t in acc[1]}}) + return acc[0] + (stencil,), acc[1] + temporaries + + stencils, temporaries = functools.reduce(folder, src.body, ((), ())) + return gtir2.Computation( + name=src.name, + connectivities=connectivities, + args=args, + temporaries=temporaries, + stencils=stencils, + ) + + +def transform(src): + return _transform(py_to_gtscript2.transform(src), _extract_params(src)) diff --git a/src/gtc/unstructured/py_to_gtscript2.py b/src/gtc/unstructured/py_to_gtscript2.py new file mode 100644 index 0000000..a687cd3 --- /dev/null +++ b/src/gtc/unstructured/py_to_gtscript2.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import inspect +import textwrap +from typing import Sequence + +from toolz.functoolz import compose + +import eve +from gtc.unstructured import gtscript_ast2 + + +class _Visitor(ast.NodeVisitor): + def _visit_iterable(self, src): + return tuple(self.visit(e) for e in src) + + def visit_Tuple(self, src: ast.Tuple): + return tuple(self.visit(e) for e in src.elts) + + def visit_Name(self, src: ast.Name): + return gtscript_ast2.Name(name=src.id) + + def visit_Subscript(self, src: ast.Subscript): + return gtscript_ast2.Subscript( + value=self.visit(src.value), index=self.visit(src.slice.value) + ) + + def visit_withitem(self, src: ast.withitem): + return gtscript_ast2.WithItem( + expr=self.visit(src.context_expr), + var=self.visit(src.optional_vars) if src.optional_vars else None, + ) + + def visit_With(self, src: ast.With): + return gtscript_ast2.With( + items=self._visit_iterable(src.items), body=self._visit_iterable(src.body) + ) + + def visit_Call(self, src: ast.Call): + return gtscript_ast2.Call(func=self.visit(src.func), args=self._visit_iterable(src.args)) + + def visit_BinOp(self, src: ast.BinOp): + return gtscript_ast2.BinOp( + left=self.visit(src.left), op=self.visit(src.op), right=self.visit(src.right) + ) + + def visit_Div(self, src: ast.Div): + return "/" + + def visit_Mult(self, src: ast.Mult): + return "*" + + def visit_Add(self, src: ast.Div): + return "+" + + def visit_Sub(self, src: ast.Mult): + return "-" + + def visit_Constant(self, src: ast.Constant): + return gtscript_ast2.Constant(value=src.value) + + def visit_GeneratorExp(self, src: ast.GeneratorExp): + return gtscript_ast2.GeneratorExp( + elt=self.visit(src.elt), generators=self._visit_iterable(src.generators) + ) + + def visit_comprehension(self, src: ast.comprehension): + return gtscript_ast2.Comprehension( + target=self.visit(src.target), iter_=self.visit(src.iter) + ) + + def visit_Assign(self, src: ast.Assign): + return gtscript_ast2.Assign( + targets=self._visit_iterable(src.targets), value=self.visit(src.value) + ) + + def visit_FunctionDef(self, src: ast.FunctionDef): + return gtscript_ast2.Function(name=src.name, body=self._visit_iterable(src.body)) + + +def _flatten(ll): + return tuple(e for l in ll for e in l) + + +class _FlattenWiths(eve.NodeTranslator): + def _process_statements(self, src: Sequence[gtscript_ast2.Stmt]): + return ( + _flatten( + self.visit(s, items=()) if isinstance(s, gtscript_ast2.With) else (self.visit(s),) + for s in src + ) + if any(isinstance(s, gtscript_ast2.With) for s in src) + else tuple(self.visit(s) for s in src) + ) + + def visit_Function(self, src: gtscript_ast2.Function): + return gtscript_ast2.Function(name=src.name, body=self._process_statements(src.body)) + + def visit_With(self, src: gtscript_ast2.With, items: Sequence[gtscript_ast2.WithItem]): + items = items + src.items + return ( + _flatten(self.visit(s, items=items) for s in src.body) + if all(isinstance(s, gtscript_ast2.With) for s in src.body) + else (gtscript_ast2.With(items=items, body=self._process_statements(src.body)),) + ) + + +transform = compose( + _FlattenWiths().visit, + _Visitor().visit, + lambda x: x.body[0], + ast.parse, + textwrap.dedent, + inspect.getsource, +) diff --git a/src/gtc/unstructured/usid2.py b/src/gtc/unstructured/usid2.py new file mode 100644 index 0000000..4755fe6 --- /dev/null +++ b/src/gtc/unstructured/usid2.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from eve import Bool, Int, Node, Str +from eve.typingx import FrozenList +from gtc import common +from gtc.unstructured.gtir2 import ReduceOperator + + +class Connectivity(Node): + name: Str + max_neighbors: Int + has_skip_values: Bool + + +class Field(Node): + name: Str + # dtype is missing because it is not used in the code generation + + +class SparseField(Field): + connectivity: Str + + +class Temporary(Node): + name: Str + location_type: Str + dtype: Str + + +class Composite(Node): + name: Str + items: FrozenList[Str] + + +class Expr(Node): + pass + + +class Literal(Expr): + value: Str + + +class FieldAccess(Expr): + name: Str + location: Str + + +class BinaryOp(Expr): + op: common.BinaryOperator + left: Expr + right: Expr + + +# TODO(till): discuss it with Hannes (primary, secondary) +class NeighborReduce(Expr): + op: ReduceOperator + dtype: Str + connectivity: Str + primary: Str + secondary: Str + body: Expr + + +class Assign(Node): + left: FieldAccess + right: Expr + + +class Kernel(Node): + location_type: Str + primary: Composite + secondaries: FrozenList[Composite] + body: FrozenList[Assign] + + +class Computation(Node): + name: Str + connectivities: FrozenList[Connectivity] + args: FrozenList[Field] + temporaries: FrozenList[Temporary] + kernels: FrozenList[Kernel] diff --git a/src/gtc/unstructured/usid2_codegen.py b/src/gtc/unstructured/usid2_codegen.py new file mode 100644 index 0000000..0d9f196 --- /dev/null +++ b/src/gtc/unstructured/usid2_codegen.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from toolz.functoolz import compose + +from eve import codegen +from eve.codegen import FormatTemplate as as_fmt +from eve.codegen import MakoTemplate as as_mako +from gtc.unstructured import usid2 + + +class _KernelCallGenerator(codegen.TemplatedGenerator): + Composite = as_mako( + "make_composite<${','.join(f'{i}_tag' for i in items)}>(${','.join(items)})" + ) + Kernel = as_mako( + "call_kernel<${id_}>(d.${location_type}${''.join(f', {c}' for c in [primary] + secondaries)});" + ) + + +class _Generator(codegen.TemplatedGenerator): + def visit_Computation(self, node: usid2.Computation, **kwargs): + return self.generic_visit( + node, kernel_calls=tuple(_KernelCallGenerator.apply(k) for k in node.kernels), **kwargs + ) + + Literal = as_mako("${value}") + BinaryOp = as_fmt("({left} {op} {right})") + FieldAccess = as_fmt("field<{name}_tag>({location})") + NeighborReduce = as_mako( + "${op}_neighbors<${dtype}, ${connectivity}_tag>" + + "([](auto &&${primary}, auto &&${secondary}) { return ${body}; }, ${primary}, strides, ${connectivity})" + ) + Assign = as_fmt("{left} = {right};") + Kernel = as_mako( + """ + struct ${id_} { + GT_FUNCTION auto operator()() const { + return [](auto && ${_this_node.primary.name}, + auto &&strides${''.join(f', auto&& {s.name}' for s in _this_node.secondaries)}) { + ${''.join(body)} + }; + } + }; + """ + ) + Temporary = as_fmt( + "auto {name} = make_simple_tmp_storage<{dtype}>(d.{location_type}, d.k, alloc);" + ) + Connectivity = as_mako( + "struct ${name}_tag: connectivity<${max_neighbors}, ${has_skip_values.lower()}> {};" + ) + Field = as_mako("struct ${name}_tag {};") + SparseField = as_mako("struct ${name}_tag: sparse_field<${connectivity}_tag> {};") + Computation = as_mako( + """<% + + ts = tuple(e.name for e in _this_node.temporaries) + cs = tuple(e.name for e in _this_node.connectivities) + ps = tuple(e.name for e in _this_node.args) + + %>#pragma once + #include + namespace gridtools::usid::${backend}::${name}_impl_ { + ${''.join(connectivities)} + ${''.join(args)} + ${''.join(f'struct {t}_tag {{}};' for t in ts)} + ${''.join(kernels)} + inline constexpr auto ${name} = [](domain d${''.join(f', auto&& {c}' for c in cs)}) { + ${''.join(f'static_assert(is_sid());' for c in cs)} + return[d = std::move(d) + ${ ''.join(f', {c} = sid::rename_dimensions(std::forward({c})(traits_t()))' for c in cs) }] + (${ ','.join(f'auto&& {p}' for p in ps)}) { + ${''.join(f'static_assert(is_sid());' for p in ps)} + % if len(temporaries) > 0: + auto alloc = make_allocator(); + ${''.join(temporaries)} + % endif + ${''.join(kernel_calls)} + }; + }; + } + using gridtools::usid::${backend}::${name}_impl_::${name}; + """ + ) + + +def _impl(backend): + # TOOO(anstaf): agree on python style here + return compose( + lambda x: codegen.format_source("cpp", x, style="LLVM"), + lambda x: _Generator.apply(x, backend=backend), + ) + + +naive = _impl("naive") +gpu = _impl("gpu") diff --git a/tests/tests_gtc/regression/cpp2/CMakeLists.txt b/tests/tests_gtc/regression/cpp2/CMakeLists.txt new file mode 100644 index 0000000..875d02c --- /dev/null +++ b/tests/tests_gtc/regression/cpp2/CMakeLists.txt @@ -0,0 +1,25 @@ +cmake_minimum_required(VERSION 3.14.5) +project(gtc_regression2 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) + +include(FetchContent) +FetchContent_Declare(usid + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../cpputil/usid) +FetchContent_MakeAvailable(usid) + +include(CTest) +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# TODO(anstaf): factor out the stuff into cmake function +add_custom_command( + OUTPUT fvm_nabla_naive.hpp + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/fvm_nabla.py naive > fvm_nabla_naive.hpp) +add_custom_target(fvm_nabla_naive_header DEPENDS fvm_nabla_naive.hpp) +add_executable(fvm_nabla_naive fvm_nabla_naive.cpp) +target_include_directories(fvm_nabla_naive PRIVATE ${CMAKE_BINARY_DIR}) +add_dependencies(fvm_nabla_naive fvm_nabla_naive_header) +target_link_libraries(fvm_nabla_naive PRIVATE usid::usid_naive_helper usid::fvm_nabla_driver gtest_main) +add_test(NAME fvm_nabla_naive COMMAND $) + +# TODO(anstaf): add gpu executable diff --git a/tests/tests_gtc/regression/cpp2/fvm_nabla.py b/tests/tests_gtc/regression/cpp2/fvm_nabla.py new file mode 100644 index 0000000..12e29ed --- /dev/null +++ b/tests/tests_gtc/regression/cpp2/fvm_nabla.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import sys +import typing + +from gt_frontend.gtscript import ( + FORWARD, + Connectivity, + Edge, + Field, + SparseField, + Vertex, + computation, + location, +) + +from gtc.common import DataType +from gtc.unstructured import frontend + + +dtype = DataType.FLOAT64 +E2V = typing.NewType("E2V", Connectivity[Edge, Vertex, 2, False]) +V2E = typing.NewType("V2E", Connectivity[Vertex, Edge, 7, True]) + + +def nabla( + v2e: V2E, + e2v: E2V, + S_MXX: Field[Edge, dtype], + S_MYY: Field[Edge, dtype], + pp: Field[Vertex, dtype], + pnabla_MXX: Field[Vertex, dtype], + pnabla_MYY: Field[Vertex, dtype], + vol: Field[Vertex, dtype], + sign: SparseField[V2E, dtype], +): + with computation(FORWARD): + with location(Edge) as e: + zavg = 0.5 * sum(pp[v] for v in e2v[e]) + zavgS_MXX = S_MXX * zavg + zavgS_MYY = S_MYY * zavg + with location(Vertex) as v: + pnabla_MXX = sum(zavgS_MXX[e] * sign[v, e] for e in v2e[v]) + pnabla_MYY = sum(zavgS_MYY[e] * sign[v, e] for e in v2e[v]) + pnabla_MXX = pnabla_MXX / vol + pnabla_MYY = pnabla_MYY / vol + + +if __name__ == "__main__": + print((frontend.gpu if len(sys.argv) > 1 and sys.argv[1] == "gpu" else frontend.naive)(nabla)) diff --git a/tests/tests_gtc/regression/cpp2/fvm_nabla_naive.cpp b/tests/tests_gtc/regression/cpp2/fvm_nabla_naive.cpp new file mode 100644 index 0000000..764316d --- /dev/null +++ b/tests/tests_gtc/regression/cpp2/fvm_nabla_naive.cpp @@ -0,0 +1,5 @@ +#include "fvm_nabla_naive.hpp" +#include +#include + +TEST(fvm, nabla_naive) { fvm_nabla_driver(nabla); }