From 8e5f16f18112d9940d33f06a0022e60a6271cc32 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Sat, 1 Jul 2023 22:24:25 +0900 Subject: [PATCH 1/2] suppress init costs for cublas and cusolver --- lib/cuda_linalg.hpp | 199 ++++++++++++++++++++++++------------ lib/executors/Transpose.hpp | 18 +++- lib/openmp_linalg.hpp | 50 +++++++++ lib/stdpar/Transpose.hpp | 21 ++-- 4 files changed, 215 insertions(+), 73 deletions(-) diff --git a/lib/cuda_linalg.hpp b/lib/cuda_linalg.hpp index fe38de9..466fd37 100644 --- a/lib/cuda_linalg.hpp +++ b/lib/cuda_linalg.hpp @@ -220,23 +220,70 @@ namespace Impl { return cusolverDnDsyevjBatched(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchSize); } - /* - * Batched matrix matrix product - * Matrix shape - * A (n, m, l), B (m, k, l), C (n, k, l) - * */ + struct blasHandle_t { + cublasHandle_t handle_; + + public: + void create() { + cublasCreate(&handle_); + } + + void destroy() { + cublasDestroy(handle_); + } + }; + + template + struct syevjHandle_t { + cusolverDnHandle_t handle_; + thrust::device_vector workspace_; + thrust::device_vector info_; + syevjInfo_t params_; + cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_VECTOR; + cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_LOWER; + + public: + template = nullptr> + void create(MatrixView& a, VectorView& v, T tol=1.0e-7, int max_sweeps=100, int sort_eig=0) { + cusolverDnCreate(&handle_); + + cusolverDnCreateSyevjInfo(¶ms_); + cusolverDnXsyevjSetTolerance(params_, tol); + cusolverDnXsyevjSetMaxSweeps(params_, max_sweeps); + cusolverDnXsyevjSetSortEig(params_, sort_eig); + + const int batchSize = v.extent(1); + int lwork = 0; + syevjBatched_bufferSize( + handle_, + jobz_, + uplo_, + a.extent(0), a.data_handle(), + v.extent(0), v.data_handle(), + &lwork, + params_, + batchSize + ); + workspace_.resize(lwork); + info_.resize(batchSize, 0); + } + + void destroy() { + cusolverDnDestroy(handle_); + } + }; + template = nullptr> - void matrix_matrix_product(const ViewA& A, + void matrix_matrix_product(const blasHandle_t& blas_handle, + const ViewA& A, const ViewB& B, ViewC& C, std::string _transa, std::string _transb, typename ViewA::value_type alpha = 1, typename ViewA::value_type beta = 0) { - cublasHandle_t handle; - cublasCreate(&handle); - cublasOperation_t transa = _transa == "N" ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t transb = _transb == "N" ? CUBLAS_OP_N : CUBLAS_OP_T; @@ -252,7 +299,7 @@ namespace Impl { const auto Bk = _transb == "N" ? B.extent(0) : B.extent(1); assert(Ak == Bk); - auto status = gemmStridedBatched(handle, + auto status = gemmStridedBatched(blas_handle.handle_, transa, transb, Cm, @@ -271,9 +318,28 @@ namespace Impl { C.extent(0) * C.extent(1), C.extent(2) ); - cublasDestroy(handle); } - + + /* + * Batched matrix matrix product + * Matrix shape + * A (n, m, l), B (m, k, l), C (n, k, l) + * */ + template = nullptr> + void matrix_matrix_product(const ViewA& A, + const ViewB& B, + ViewC& C, + std::string _transa, + std::string _transb, + typename ViewA::value_type alpha = 1, + typename ViewA::value_type beta = 0) { + Impl::blasHandle_t blas_handle; + blas_handle.create(); + matrix_matrix_product(blas_handle, A, B, C, _transa, _transb, alpha, beta); + blas_handle.destroy(); + } + /* * Batched matrix vector product * Matrix shape @@ -282,15 +348,13 @@ namespace Impl { * */ template = nullptr> - void matrix_vector_product(const ViewA& A, + void matrix_vector_product(const blasHandle_t& blas_handle, + const ViewA& A, const ViewB& B, ViewC& C, std::string _transa, typename ViewA::value_type alpha = 1 ) { - cublasHandle_t handle; - cublasCreate(&handle); - cublasOperation_t transa = _transa == "N" ? CUBLAS_OP_N : CUBLAS_OP_T; const auto Cm = C.extent(0); @@ -303,7 +367,7 @@ namespace Impl { using value_type = ViewA::value_type; const value_type beta = 0; - auto status = gemmStridedBatched(handle, + auto status = gemmStridedBatched(blas_handle.handle_, transa, CUBLAS_OP_N, Cm, @@ -322,7 +386,26 @@ namespace Impl { C.extent(0), C.extent(1) ); - cublasDestroy(handle); + } + + /* + * Batched matrix vector product + * Matrix shape + * A (n, m, l), B (m, l), C (n, l) + * C = A * B + * */ + template = nullptr> + void matrix_vector_product(const ViewA& A, + const ViewB& B, + ViewC& C, + std::string _transa, + typename ViewA::value_type alpha = 1 + ) { + Impl::blasHandle_t blas_handle; + blas_handle.create(); + matrix_vector_product(blas_handle, A, B, C, _transa, alpha); + blas_handle.destroy(); } /* @@ -332,9 +415,9 @@ namespace Impl { * v (m, l) * w (m, m, l) * */ - template = nullptr> - void eig(MatrixView& a, VectorView& v) { + void eig(const Handle& syevj_handle, MatrixView& a, VectorView& v) { static_assert( std::is_same_v ); static_assert( std::is_same_v ); @@ -342,60 +425,42 @@ namespace Impl { assert(a.extent(0) == v.extent(0)); assert(a.extent(0) == a.extent(1)); // Square array assert(a.extent(2) == v.extent(1)); // batch size - - cusolverDnHandle_t handle; - cusolverDnCreate(&handle); - - syevjInfo_t params; - constexpr value_type tol = 1.0e-7; - constexpr int max_sweeps = 100; - constexpr int sort_eig = 0; - const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; - const cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER; - - cusolverDnCreateSyevjInfo(¶ms); - cusolverDnXsyevjSetTolerance(params, tol); - cusolverDnXsyevjSetMaxSweeps(params, max_sweeps); - cusolverDnXsyevjSetSortEig(params, sort_eig); const int batchSize = v.extent(1); - int lwork = 0; - syevjBatched_bufferSize( - handle, - jobz, - uplo, - a.extent(0), a.data_handle(), - v.extent(0), v.data_handle(), - &lwork, - params, - batchSize - ); - - thrust::device_vector workspace(lwork); - thrust::device_vector info(batchSize, 0); - value_type* workspace_data = (value_type *)thrust::raw_pointer_cast(workspace.data()); - int* info_data = (int *)thrust::raw_pointer_cast(info.data()); + value_type* workspace_data = (value_type *)thrust::raw_pointer_cast(syevj_handle.workspace_.data()); + int* info_data = (int *)thrust::raw_pointer_cast(syevj_handle.info_.data()); auto status = syevjBatched( - handle, - jobz, - uplo, + syevj_handle.handle_, + syevj_handle.jobz_, + syevj_handle.uplo_, a.extent(0), a.data_handle(), v.extent(0), v.data_handle(), - workspace_data, workspace.size(), + workspace_data, syevj_handle.workspace_.size(), info_data, - params, + syevj_handle.params_, batchSize ); - cudaDeviceSynchronize(); - cusolverDnDestroy(handle); + } + + template = nullptr> + void eig(MatrixView& a, VectorView& v) { + static_assert( std::is_same_v ); + static_assert( std::is_same_v ); + + using value_type = MatrixView::value_type; + Impl::syevjHandle_t syevj_handle; + syevj_handle.create(a, v); + eig(syevj_handle, a, v); + syevj_handle.destroy(); } // 2D transpose template = nullptr> - void transpose(const InputView& in, OutputView& out) { + void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out) { static_assert( std::is_same_v ); static_assert( std::is_same_v ); static_assert( std::is_same_v ); @@ -406,11 +471,9 @@ namespace Impl { using value_type = InputView::value_type; constexpr value_type alpha = 1; constexpr value_type beta = 0; - // transpose by cublas - cublasHandle_t handle; - cublasCreate(&handle); - geam(handle, + // transpose by cublas + geam(blas_handle.handle_, CUBLAS_OP_T, CUBLAS_OP_T, in.extent(1), @@ -424,8 +487,16 @@ namespace Impl { out.data_handle(), out.extent(0) ); + } - cublasDestroy(handle); + // 2D transpose + template = nullptr> + void transpose(const InputView& in, OutputView& out) { + Impl::blasHandle_t blas_handle; + blas_handle.create(); + transpose(blas_handle, in, out); + blas_handle.destroy(); } }; diff --git a/lib/executors/Transpose.hpp b/lib/executors/Transpose.hpp index 6f266b6..a9aed5d 100644 --- a/lib/executors/Transpose.hpp +++ b/lib/executors/Transpose.hpp @@ -20,7 +20,7 @@ namespace Impl { /* Transpose batched matrix */ template = nullptr> - void transpose(const InputView& in, OutputView& out, const std::array& axes) { + void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out, const std::array& axes) { static_assert( std::is_same_v ); static_assert( std::is_same_v ); using value_type = InputView::value_type; @@ -57,7 +57,7 @@ namespace Impl { const mdspan2d_type sub_in(in.data_handle(), in_shape); mdspan2d_type sub_out(out.data_handle(), out_shape); - transpose(sub_in, sub_out); + transpose(blas_handle, sub_in, sub_out); } else if(axes == axes_type({2, 0, 1})) { using mdspan2d_type = stdex::mdspan, layout_type>; using extent2d_type = std::array; @@ -66,7 +66,7 @@ namespace Impl { const mdspan2d_type sub_in(in.data_handle(), in_shape); mdspan2d_type sub_out(out.data_handle(), out_shape); - transpose(sub_in, sub_out); + transpose(blas_handle, sub_in, sub_out); } else if(axes == axes_type({2, 1, 0})) { Impl::for_each(policy3d, [=](const int i0, const int i1, const int i2) { @@ -76,6 +76,18 @@ namespace Impl { std::runtime_error("Invalid axes specified."); } } + + /* Transpose batched matrix */ + template = nullptr> + void transpose(const InputView& in, OutputView& out, const std::array& axes) { + static_assert( std::is_same_v ); + static_assert( std::is_same_v ); + Impl::blasHandle_t blas_handle; + blas_handle.create(); + transpose(blas_handle, in, out, axes); + blas_handle.destroy(); + } }; #endif diff --git a/lib/openmp_linalg.hpp b/lib/openmp_linalg.hpp index 61aaf67..6b5077c 100644 --- a/lib/openmp_linalg.hpp +++ b/lib/openmp_linalg.hpp @@ -8,6 +8,20 @@ namespace stdex = std::experimental; namespace Impl { + struct blasHandle_t { + public: + void create() {} + void destroy() {} + }; + + template + struct syevjHandle_t { + public: + template = nullptr> + void create(MatrixView& a, VectorView& v, T tol=1.0e-7, int max_sweeps=100, int sort_eig=0) {} + void destroy() {} + }; /* * Batched matrix matrix product @@ -69,6 +83,19 @@ namespace Impl { } } } + + template = nullptr> + void matrix_matrix_product(const blasHandle_t& blas_handle, + const ViewA& A, + const ViewB& B, + ViewC& C, + std::string _transa, + std::string _transb, + typename ViewA::value_type alpha = 1, + typename ViewA::value_type beta = 0) { + matrix_matrix_product(A, B, C, _transa, _transb, alpha, beta); + } /* * Batched matrix vector product @@ -120,7 +147,18 @@ namespace Impl { vecC.noalias() = alpha * matA.transpose() * vecB; } } + } + template = nullptr> + void matrix_vector_product(const blasHandle_t& blas_handle, + const ViewA& A, + const ViewB& B, + ViewC& C, + std::string _transa, + typename ViewA::value_type alpha = 1 + ) { + matrix_vector_product(A, B, C, _transa, alpha); } /* @@ -160,6 +198,12 @@ namespace Impl { } } + template = nullptr> + void eig(const Handle& syevj_handle, MatrixView& a, VectorView& v) { + eig(a, v); + } + // 2D transpose template = nullptr> @@ -192,6 +236,12 @@ namespace Impl { } } } + + template = nullptr> + void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out) { + transpose(in, out); + } }; #endif diff --git a/lib/stdpar/Transpose.hpp b/lib/stdpar/Transpose.hpp index 4e16f71..26f97ea 100644 --- a/lib/stdpar/Transpose.hpp +++ b/lib/stdpar/Transpose.hpp @@ -20,7 +20,7 @@ namespace Impl { /* Transpose batched matrix */ template = nullptr> - void transpose(const InputView& in, OutputView& out, const std::array& axes) { + void transpose(const blasHandle_t& blas_handle, const InputView& in, OutputView& out, const std::array& axes) { static_assert( std::is_same_v ); static_assert( std::is_same_v ); using value_type = InputView::value_type; @@ -29,9 +29,6 @@ namespace Impl { using axes_type = std::array; assert(out.size() == in.size()); - //for(std::size_t i=0; i, layout_type>; using extent2d_type = std::array; @@ -76,7 +73,7 @@ namespace Impl { const mdspan2d_type sub_in(in.data_handle(), in_shape); mdspan2d_type sub_out(out.data_handle(), out_shape); - transpose(sub_in, sub_out); + transpose(blas_handle, sub_in, sub_out); } else if(axes == axes_type({2, 1, 0})) { for(std::size_t i=0; i = nullptr> + void transpose(const InputView& in, OutputView& out, const std::array& axes) { + static_assert( std::is_same_v ); + static_assert( std::is_same_v ); + Impl::blasHandle_t blas_handle; + blas_handle.create(); + transpose(blas_handle, in, out, axes); + blas_handle.destroy(); + } }; #endif From b0c9d750b526347c33ac31170c3ad449d9eb02d2 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Sat, 1 Jul 2023 22:26:54 +0900 Subject: [PATCH 2/2] suppress init costs for cublas and cusolver in letkf-lbm2d examples --- mini-apps/lbm2d-letkf/executors/letkf.hpp | 21 ++++++----- .../lbm2d-letkf/executors/letkf_solver.hpp | 35 ++++++++++++------- mini-apps/lbm2d-letkf/stdpar/letkf.hpp | 15 ++++---- mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp | 35 ++++++++++++------- 4 files changed, 67 insertions(+), 39 deletions(-) diff --git a/mini-apps/lbm2d-letkf/executors/letkf.hpp b/mini-apps/lbm2d-letkf/executors/letkf.hpp index 8c004f9..a0f25ec 100644 --- a/mini-apps/lbm2d-letkf/executors/letkf.hpp +++ b/mini-apps/lbm2d-letkf/executors/letkf.hpp @@ -24,6 +24,7 @@ class LETKF : public DA_Model { using value_type = RealView2D::value_type; MPIConfig mpi_conf_; + Impl::blasHandle_t blas_handle_; std::unique_ptr letkf_solver_; /* Views before transpose */ @@ -43,7 +44,7 @@ class LETKF : public DA_Model { LETKF(Config& conf, IOConfig& io_conf)=delete; LETKF(Config& conf, IOConfig& io_conf, MPIConfig& mpi_conf) : DA_Model(conf, io_conf), mpi_conf_(mpi_conf) {} - virtual ~LETKF(){} + virtual ~LETKF(){ blas_handle_.destroy(); } void initialize() { setFileInfo(); @@ -78,6 +79,8 @@ class LETKF : public DA_Model { Iterate_policy<3> policy3d({0, 0, 0}, {n_obs_x_, n_obs_x_, n_batch}); Impl::for_each(policy3d, initialize_rR_functor(conf_, y_offset, rR)); + + blas_handle_.create(); } void apply(std::unique_ptr& data_vars, const int it, std::vector& timers){ @@ -138,7 +141,7 @@ class LETKF : public DA_Model { auto xk = xk_.mdspan(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(f, xk, {2, 0, 1}); + Impl::transpose(blas_handle_, f, xk, {2, 0, 1}); timers[DA_Set_Matrix]->end(); } @@ -148,7 +151,7 @@ class LETKF : public DA_Model { auto X = letkf_solver_->X().mdspan(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(xk_buffer, X, {0, 2, 1}); + Impl::transpose(blas_handle_, xk_buffer, X, {0, 2, 1}); timers[DA_Set_Matrix]->end(); } @@ -158,7 +161,7 @@ class LETKF : public DA_Model { auto Y = letkf_solver_->Y().mdspan(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) + Impl::transpose(blas_handle_, yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) timers[DA_Set_Matrix]->end(); } @@ -273,7 +276,7 @@ class LETKF : public DA_Model { auto X = letkf_solver_->X().mdspan(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(f, xk, {2, 0, 1}); // (nx, ny, Q) -> (Q, nx*ny) + Impl::transpose(blas_handle_, f, xk, {2, 0, 1}); // (nx, ny, Q) -> (Q, nx*ny) timers[DA_Set_Matrix]->end(); timers[DA_All2All]->begin(); @@ -281,7 +284,7 @@ class LETKF : public DA_Model { timers[DA_All2All]->end(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(xk_buffer, X, {0, 2, 1}); + Impl::transpose(blas_handle_, xk_buffer, X, {0, 2, 1}); timers[DA_Set_Matrix]->end(); // set Y @@ -306,7 +309,7 @@ class LETKF : public DA_Model { timers[DA_All2All]->end(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) + Impl::transpose(blas_handle_, yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) timers[DA_Set_Matrix]->end(); // set yo @@ -335,9 +338,9 @@ class LETKF : public DA_Model { auto xk = xk_.mdspan(); auto xk_buffer = xk_buffer_.mdspan(); auto f = data_vars->f().mdspan(); - Impl::transpose(X, xk_buffer, {0, 2, 1}); // X (n_stt, n_ens, n_batch) -> xk_buffer (n_stt, n_batch, n_ens) + Impl::transpose(blas_handle_, X, xk_buffer, {0, 2, 1}); // X (n_stt, n_ens, n_batch) -> xk_buffer (n_stt, n_batch, n_ens) all2all(xk_buffer, xk); // xk_buffer (n_stt, n_batch, n_ens) -> xk(n_stt, n_batch, n_ens) - Impl::transpose(xk, f, {1, 2, 0}); // (Q, nx*ny) -> (nx, ny, Q) + Impl::transpose(blas_handle_, xk, f, {1, 2, 0}); // (Q, nx*ny) -> (nx, ny, Q) auto [nx, ny] = conf_.settings_.n_; auto rho = data_vars->rho().mdspan(); diff --git a/mini-apps/lbm2d-letkf/executors/letkf_solver.hpp b/mini-apps/lbm2d-letkf/executors/letkf_solver.hpp index f84a47b..e3949cb 100644 --- a/mini-apps/lbm2d-letkf/executors/letkf_solver.hpp +++ b/mini-apps/lbm2d-letkf/executors/letkf_solver.hpp @@ -13,6 +13,8 @@ using letkf_config_type = std::tuple syevj_handle_; RealView3D X_, dX_; // (n_stt, n_ens, n_batch) RealView3D Y_, dY_; // (n_obs, n_ens, n_batch) @@ -57,7 +59,11 @@ class LETKFSolver { // Allocate views initialize(); } - ~LETKFSolver(){} + + ~LETKFSolver(){ + blas_handle_.destroy(); + syevj_handle_.destroy(); + } public: // Getters @@ -94,14 +100,14 @@ class LETKFSolver { auto tmp_oe = tmp_oe_.mdspan(); const value_type beta = (static_cast(n_ens_) - 1) / beta_; Impl::deep_copy(I, Q); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(rR, dY, tmp_oe, "N", "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_obs, n_ens, n_batch) - Impl::matrix_matrix_product(dY, tmp_oe, Q, "T", "N", 1, beta); // (n_ens, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, rR, dY, tmp_oe, "N", "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_obs, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, dY, tmp_oe, Q, "T", "N", 1, beta); // (n_ens, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // Q = V * diag(d) * V^T auto d = d_.mdspan(); auto V = V_.mdspan(); Impl::deep_copy(Q, V); - Impl::eig(V, d); // (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch), (n_ens, n_batch) + Impl::eig(syevj_handle_, V, d); // (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch), (n_ens, n_batch) // P = V * inv(d) * V^T // P: (n_ens, n_ens, n_batch) @@ -109,8 +115,8 @@ class LETKFSolver { auto tmp_ee = tmp_ee_.mdspan(); auto P = P_.mdspan(); Impl::diag(d, inv_D, -1); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(V, tmp_ee, P, "N", "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, V, tmp_ee, P, "N", "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // w = P * (dY^T * inv(R) * dyo) auto w = w_.mdspan(); @@ -118,21 +124,21 @@ class LETKFSolver { auto tmp_e = tmp_e_.mdspan(); auto dyo = Impl::squeeze(yo, 1); auto _w = Impl::squeeze(w, 1); - Impl::matrix_vector_product(rR, dyo, tmp_o, "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_batch) -> (n_obs, n_batch) - Impl::matrix_vector_product(dY, tmp_o, tmp_e, "T"); // (n_ens, n_obs, n_batch) * (n_obs, n_batch) -> (n_ens, n_batch) - Impl::matrix_vector_product(P, tmp_e, _w, "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_batch) -> (n_ens, n_batch) + Impl::matrix_vector_product(blas_handle_, rR, dyo, tmp_o, "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_batch) -> (n_obs, n_batch) + Impl::matrix_vector_product(blas_handle_, dY, tmp_o, tmp_e, "T"); // (n_ens, n_obs, n_batch) * (n_obs, n_batch) -> (n_ens, n_batch) + Impl::matrix_vector_product(blas_handle_, P, tmp_e, _w, "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_batch) -> (n_ens, n_batch) // W = sqrt(Ne-1) * V * inv(sqrt(D)) * V^T auto W = W_.mdspan(); const value_type alpha = sqrt(static_cast(n_ens_) - 1); Impl::diag(d, inv_D, -0.5); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // W = W + w // Xsol = x_mean + matmat(dX, W) Impl::axpy(W, w); // (n_ens, n_ens, n_batch) + (n_ens, 1, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(dX, W, X, "N", "N"); // (n_stt, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_stt, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, dX, W, X, "N", "N"); // (n_stt, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_stt, n_ens, n_batch) Impl::axpy(X, x_mean); // (n_stt, n_ens, n_batch) + (n_stt, 1, n_batch) -> (n_stt, n_ens, n_batch) } @@ -169,6 +175,11 @@ class LETKFSolver { auto I = I_.mdspan(); Impl::identity(rR); Impl::identity(I); + + auto d = d_.mdspan(); + auto V = V_.mdspan(); + blas_handle_.create(); + syevj_handle_.create(V, d); } }; diff --git a/mini-apps/lbm2d-letkf/stdpar/letkf.hpp b/mini-apps/lbm2d-letkf/stdpar/letkf.hpp index 3b00db4..376d1d4 100644 --- a/mini-apps/lbm2d-letkf/stdpar/letkf.hpp +++ b/mini-apps/lbm2d-letkf/stdpar/letkf.hpp @@ -14,6 +14,7 @@ class LETKF : public DA_Model { using value_type = RealView2D::value_type; MPIConfig mpi_conf_; + Impl::blasHandle_t blas_handle_; std::unique_ptr letkf_solver_; /* Views before transpose */ @@ -32,7 +33,7 @@ class LETKF : public DA_Model { LETKF(Config& conf, IOConfig& io_conf)=delete; LETKF(Config& conf, IOConfig& io_conf, MPIConfig& mpi_conf) : DA_Model(conf, io_conf), mpi_conf_(mpi_conf) {} - virtual ~LETKF(){} + virtual ~LETKF(){ blas_handle_.destroy(); } void initialize() { setFileInfo(); @@ -66,6 +67,8 @@ class LETKF : public DA_Model { Iterate_policy<3> policy3d({0, 0, 0}, {n_obs_x_, n_obs_x_, n_batch}); Impl::for_each(policy3d, initialize_rR_functor(conf_, y_offset, rR)); + + blas_handle_.create(); } void apply(std::unique_ptr& data_vars, const int it, std::vector& timers){ @@ -105,7 +108,7 @@ class LETKF : public DA_Model { auto X = letkf_solver_->X().mdspan(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(f, xk, {2, 0, 1}); // (nx, ny, Q) -> (Q, nx*ny) + Impl::transpose(blas_handle_, f, xk, {2, 0, 1}); // (nx, ny, Q) -> (Q, nx*ny) timers[DA_Set_Matrix]->end(); timers[DA_All2All]->begin(); @@ -113,7 +116,7 @@ class LETKF : public DA_Model { timers[DA_All2All]->end(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(xk_buffer, X, {0, 2, 1}); + Impl::transpose(blas_handle_, xk_buffer, X, {0, 2, 1}); timers[DA_Set_Matrix]->end(); // set Y @@ -138,7 +141,7 @@ class LETKF : public DA_Model { timers[DA_All2All]->end(); timers[DA_Set_Matrix]->begin(); - Impl::transpose(yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) + Impl::transpose(blas_handle_, yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch) timers[DA_Set_Matrix]->end(); // set yo @@ -167,9 +170,9 @@ class LETKF : public DA_Model { auto xk = xk_.mdspan(); auto xk_buffer = xk_buffer_.mdspan(); auto f = data_vars->f().mdspan(); - Impl::transpose(X, xk_buffer, {0, 2, 1}); // X (n_stt, n_ens, n_batch) -> xk_buffer (n_stt, n_batch, n_ens) + Impl::transpose(blas_handle_, X, xk_buffer, {0, 2, 1}); // X (n_stt, n_ens, n_batch) -> xk_buffer (n_stt, n_batch, n_ens) all2all(xk_buffer, xk); // xk_buffer (n_stt, n_batch, n_ens) -> xk(n_stt, n_batch, n_ens) - Impl::transpose(xk, f, {1, 2, 0}); // (Q, nx*ny) -> (nx, ny, Q) + Impl::transpose(blas_handle_, xk, f, {1, 2, 0}); // (Q, nx*ny) -> (nx, ny, Q) auto [nx, ny] = conf_.settings_.n_; auto rho = data_vars->rho().mdspan(); diff --git a/mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp b/mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp index 148b6bd..52bf1a4 100644 --- a/mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp +++ b/mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp @@ -13,6 +13,8 @@ using letkf_config_type = std::tuple syevj_handle_; RealView3D X_, dX_; // (n_stt, n_ens, n_batch) RealView3D Y_, dY_; // (n_obs, n_ens, n_batch) @@ -57,7 +59,11 @@ class LETKFSolver { // Allocate views initialize(); } - ~LETKFSolver(){} + + ~LETKFSolver(){ + blas_handle_.destroy(); + syevj_handle_.destroy(); + } public: // Getters @@ -94,14 +100,14 @@ class LETKFSolver { auto tmp_oe = tmp_oe_.mdspan(); const value_type beta = (static_cast(n_ens_) - 1) / beta_; Impl::deep_copy(I, Q); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(rR, dY, tmp_oe, "N", "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_obs, n_ens, n_batch) - Impl::matrix_matrix_product(dY, tmp_oe, Q, "T", "N", 1, beta); // (n_ens, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, rR, dY, tmp_oe, "N", "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_obs, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, dY, tmp_oe, Q, "T", "N", 1, beta); // (n_ens, n_obs, n_batch) * (n_obs, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // Q = V * diag(d) * V^T auto d = d_.mdspan(); auto V = V_.mdspan(); Impl::deep_copy(Q, V); - Impl::eig(V, d); // (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch), (n_ens, n_batch) + Impl::eig(syevj_handle_, V, d); // (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch), (n_ens, n_batch) // P = V * inv(d) * V^T // P: (n_ens, n_ens, n_batch) @@ -109,8 +115,8 @@ class LETKFSolver { auto tmp_ee = tmp_ee_.mdspan(); auto P = P_.mdspan(); Impl::diag(d, inv_D, -1); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(V, tmp_ee, P, "N", "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, V, tmp_ee, P, "N", "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // w = P * (dY^T * inv(R) * dyo) auto w = w_.mdspan(); @@ -118,21 +124,21 @@ class LETKFSolver { auto tmp_e = tmp_e_.mdspan(); auto dyo = Impl::squeeze(yo, 1); auto _w = Impl::squeeze(w, 1); - Impl::matrix_vector_product(rR, dyo, tmp_o, "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_batch) -> (n_obs, n_batch) - Impl::matrix_vector_product(dY, tmp_o, tmp_e, "T"); // (n_ens, n_obs, n_batch) * (n_obs, n_batch) -> (n_ens, n_batch) - Impl::matrix_vector_product(P, tmp_e, _w, "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_batch) -> (n_ens, n_batch) + Impl::matrix_vector_product(blas_handle_, rR, dyo, tmp_o, "N"); // (n_obs, n_obs, n_batch) * (n_obs, n_batch) -> (n_obs, n_batch) + Impl::matrix_vector_product(blas_handle_, dY, tmp_o, tmp_e, "T"); // (n_ens, n_obs, n_batch) * (n_obs, n_batch) -> (n_ens, n_batch) + Impl::matrix_vector_product(blas_handle_, P, tmp_e, _w, "N"); // (n_ens, n_ens, n_batch) * (n_ens, n_batch) -> (n_ens, n_batch) // W = sqrt(Ne-1) * V * inv(sqrt(D)) * V^T auto W = W_.mdspan(); const value_type alpha = sqrt(static_cast(n_ens_) - 1); Impl::diag(d, inv_D, -0.5); // (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch) // W = W + w // Xsol = x_mean + matmat(dX, W) Impl::axpy(W, w); // (n_ens, n_ens, n_batch) + (n_ens, 1, n_batch) -> (n_ens, n_ens, n_batch) - Impl::matrix_matrix_product(dX, W, X, "N", "N"); // (n_stt, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_stt, n_ens, n_batch) + Impl::matrix_matrix_product(blas_handle_, dX, W, X, "N", "N"); // (n_stt, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_stt, n_ens, n_batch) Impl::axpy(X, x_mean); // (n_stt, n_ens, n_batch) + (n_stt, 1, n_batch) -> (n_stt, n_ens, n_batch) } @@ -169,6 +175,11 @@ class LETKFSolver { auto I = I_.mdspan(); Impl::identity(rR); Impl::identity(I); + + auto d = d_.mdspan(); + auto V = V_.mdspan(); + blas_handle_.create(); + syevj_handle_.create(V, d); } };