Skip to content

Commit

Permalink
Merge pull request #39 from yasahi-hpc/use-iall2all
Browse files Browse the repository at this point in the history
use asynchronous MPI functions for overlapping
  • Loading branch information
yasahi-hpc authored Sep 25, 2023
2 parents b4f25bd + b7552dd commit e53de1e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 42 deletions.
76 changes: 61 additions & 15 deletions mini-apps/lbm2d-letkf/executors/letkf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,25 +243,32 @@ class LETKF {
timers[DA_Load_H2D_u]->end();
}

MPI_Request requests[2];
auto _broadcast = stdexec::just() |
stdexec::then([&]{
auto rho_obs = data_vars->rho_obs().mdspan();
auto u_obs = data_vars->u_obs().mdspan();
timers[DA_Broadcast_rho]->begin();
broadcast(rho_obs);
timers[DA_Broadcast_rho]->end();
broadcast(rho_obs, &requests[0]);

timers[DA_Broadcast_u]->begin();
broadcast(u_obs);
timers[DA_Broadcast_u]->end();
broadcast(u_obs, &requests[1]);
});

auto _axpy = letkf_solver_->solve_axpy_sender(scheduler);
auto _axpy_and_braodcast = stdexec::when_all(
auto _evd = letkf_solver_->solve_evd_sender(scheduler);
auto _evd_and_broadcast = stdexec::when_all(
std::move(_broadcast),
std::move(_axpy)
);
stdexec::sync_wait( std::move(_axpy_and_braodcast) );
std::move(_evd)
) | stdexec::then(
[&](){
MPI_Waitall(2, requests, MPI_STATUSES_IGNORE);
timers[DA_Broadcast_rho]->end();
timers[DA_Broadcast_u]->end();
});

timers[DA_LETKF_EVD_and_Broadcast]->begin();
stdexec::sync_wait( std::move(_evd_and_broadcast) );
timers[DA_LETKF_EVD_and_Broadcast]->end();

// set yo
stdexec::sync_wait( scope2.on_empty() ); // complete load v
Expand All @@ -274,20 +281,29 @@ class LETKF {
timers[DA_Load_H2D_v]->end();
}

MPI_Request request;
auto _gemm = letkf_solver_->solve_gemm_sender(scheduler);
auto _broadcast_v = stdexec::just() |
stdexec::then([&]{
auto v_obs = data_vars->v_obs().mdspan();
auto v_obs = data_vars->v_obs().mdspan();
timers[DA_Broadcast_v]->begin();
broadcast(v_obs);
timers[DA_Broadcast_v]->end();
broadcast(v_obs, &request);
});

auto _gemm_and_braodcast = stdexec::when_all(
auto _gemm_and_broadcast = stdexec::when_all(
std::move(_broadcast_v),
std::move(_gemm)
);
stdexec::sync_wait( std::move(_gemm_and_braodcast) );
) | stdexec::then(
[&](){
MPI_Status status;
MPI_Wait(&request, &status);
timers[DA_Broadcast_v]->end();
}
);

timers[DA_LETKF_GEMM_and_Broadcast]->begin();
stdexec::sync_wait( std::move(_gemm_and_broadcast) );
timers[DA_LETKF_GEMM_and_Broadcast]->end();

setyo(data_vars, timers);

Expand Down Expand Up @@ -456,6 +472,23 @@ class LETKF {
mpi_conf_.comm());
}

template <class ViewType,
std::enable_if_t<ViewType::rank()==3, std::nullptr_t> = nullptr>
void all2all(const ViewType& a, ViewType& b, MPI_Request* request) {
assert( a.extents() == b.extents() );
MPI_Datatype mpi_datatype = Impl::getMPIDataType<ViewType::value_type>();

const std::size_t size = a.extent(0) * a.extent(1);
MPI_Ialltoall(a.data_handle(),
size,
mpi_datatype,
b.data_handle(),
size,
mpi_datatype,
mpi_conf_.comm(),
request);
}

template <class ViewType>
void broadcast(ViewType& a) {
MPI_Datatype mpi_datatype = Impl::getMPIDataType<ViewType::value_type>();
Expand All @@ -468,6 +501,19 @@ class LETKF {
mpi_conf_.comm());
}

template <class ViewType>
void broadcast(ViewType& a, MPI_Request* request) {
MPI_Datatype mpi_datatype = Impl::getMPIDataType<ViewType::value_type>();

const std::size_t size = a.size();
MPI_Ibcast(a.data_handle(),
size,
mpi_datatype,
0,
mpi_conf_.comm(),
request);
}

void load(std::unique_ptr<DataVars>& data_vars, const int it) {
from_file(data_vars->rho_obs(), it);
from_file(data_vars->u_obs(), it);
Expand Down
2 changes: 1 addition & 1 deletion mini-apps/lbm2d-letkf/executors/letkf_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class LETKFSolver {
}

public:
stdexec::sender auto solve_axpy_sender(stdexec::scheduler auto&& scheduler) {
stdexec::sender auto solve_evd_sender(stdexec::scheduler auto&& scheduler) {
auto X = X_.mdspan();
auto Y = Y_.mdspan();
auto dX = dX_.mdspan();
Expand Down
58 changes: 32 additions & 26 deletions mini-apps/lbm2d-letkf/timer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ enum TimerEnum : int {Total,
DA_Broadcast_u,
DA_Broadcast_v,
DA_LETKF,
DA_LETKF_EVD_and_Broadcast,
DA_LETKF_GEMM_and_Broadcast,
DA_LETKF_Update,
DA_Update,
Diag,
LBMSolver,
Expand All @@ -101,32 +104,35 @@ enum TimerEnum : int {Total,
static void defineTimers(std::vector<Timer*> &timers, bool use_time_stamps=false) {
// Set timers
timers.resize(Nb_timers);
timers[TimerEnum::Total] = new Timer("total");
timers[TimerEnum::MainLoop] = new Timer("MainLoop");
timers[TimerEnum::DA] = new Timer("DA", use_time_stamps);
timers[TimerEnum::DA_Load] = new Timer("DA_Load", use_time_stamps);
timers[TimerEnum::DA_Load_H2D] = new Timer("DA_Load_H2D", use_time_stamps);
timers[TimerEnum::DA_Load_rho] = new Timer("DA_Load_rho", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_rho] = new Timer("DA_Load_H2D_rho", use_time_stamps);
timers[TimerEnum::DA_Load_u] = new Timer("DA_Load_u", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_u] = new Timer("DA_Load_H2D_u", use_time_stamps);
timers[TimerEnum::DA_Load_v] = new Timer("DA_Load_v", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_v] = new Timer("DA_Load_H2D_v", use_time_stamps);
timers[TimerEnum::DA_Pack_X] = new Timer("DA_Pack_X", use_time_stamps);
timers[TimerEnum::DA_All2All_X] = new Timer("DA_All2All_X", use_time_stamps);
timers[TimerEnum::DA_Unpack_X] = new Timer("DA_Unpack_X", use_time_stamps);
timers[TimerEnum::DA_Pack_Y] = new Timer("DA_Pack_Y", use_time_stamps);
timers[TimerEnum::DA_All2All_Y] = new Timer("DA_All2All_Y", use_time_stamps);
timers[TimerEnum::DA_Unpack_Y] = new Timer("DA_Unpack_Y", use_time_stamps);
timers[TimerEnum::DA_Pack_Obs] = new Timer("DA_Pack_Obs", use_time_stamps);
timers[TimerEnum::DA_Broadcast] = new Timer("DA_Broadcast", use_time_stamps);
timers[TimerEnum::DA_Broadcast_rho] = new Timer("DA_Broadcast_rho", use_time_stamps);
timers[TimerEnum::DA_Broadcast_u] = new Timer("DA_Broadcast_u", use_time_stamps);
timers[TimerEnum::DA_Broadcast_v] = new Timer("DA_Broadcast_v", use_time_stamps);
timers[TimerEnum::DA_LETKF] = new Timer("DA_LETKF", use_time_stamps);
timers[TimerEnum::DA_Update] = new Timer("DA_Update", use_time_stamps);
timers[TimerEnum::Diag] = new Timer("diag");
timers[TimerEnum::LBMSolver] = new Timer("lbm");
timers[TimerEnum::Total] = new Timer("total");
timers[TimerEnum::MainLoop] = new Timer("MainLoop");
timers[TimerEnum::DA] = new Timer("DA", use_time_stamps);
timers[TimerEnum::DA_Load] = new Timer("DA_Load", use_time_stamps);
timers[TimerEnum::DA_Load_H2D] = new Timer("DA_Load_H2D", use_time_stamps);
timers[TimerEnum::DA_Load_rho] = new Timer("DA_Load_rho", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_rho] = new Timer("DA_Load_H2D_rho", use_time_stamps);
timers[TimerEnum::DA_Load_u] = new Timer("DA_Load_u", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_u] = new Timer("DA_Load_H2D_u", use_time_stamps);
timers[TimerEnum::DA_Load_v] = new Timer("DA_Load_v", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_v] = new Timer("DA_Load_H2D_v", use_time_stamps);
timers[TimerEnum::DA_Pack_X] = new Timer("DA_Pack_X", use_time_stamps);
timers[TimerEnum::DA_All2All_X] = new Timer("DA_All2All_X", use_time_stamps);
timers[TimerEnum::DA_Unpack_X] = new Timer("DA_Unpack_X", use_time_stamps);
timers[TimerEnum::DA_Pack_Y] = new Timer("DA_Pack_Y", use_time_stamps);
timers[TimerEnum::DA_All2All_Y] = new Timer("DA_All2All_Y", use_time_stamps);
timers[TimerEnum::DA_Unpack_Y] = new Timer("DA_Unpack_Y", use_time_stamps);
timers[TimerEnum::DA_Pack_Obs] = new Timer("DA_Pack_Obs", use_time_stamps);
timers[TimerEnum::DA_Broadcast] = new Timer("DA_Broadcast", use_time_stamps);
timers[TimerEnum::DA_Broadcast_rho] = new Timer("DA_Broadcast_rho", use_time_stamps);
timers[TimerEnum::DA_Broadcast_u] = new Timer("DA_Broadcast_u", use_time_stamps);
timers[TimerEnum::DA_Broadcast_v] = new Timer("DA_Broadcast_v", use_time_stamps);
timers[TimerEnum::DA_LETKF] = new Timer("DA_LETKF", use_time_stamps);
timers[TimerEnum::DA_LETKF_EVD_and_Broadcast] = new Timer("DA_LETKF_EVD_and_Broadcast", use_time_stamps);
timers[TimerEnum::DA_LETKF_GEMM_and_Broadcast] = new Timer("DA_LETKF_GEMM_and_Broadcast", use_time_stamps);
timers[TimerEnum::DA_LETKF_Update] = new Timer("DA_LETKF_Update", use_time_stamps);
timers[TimerEnum::DA_Update] = new Timer("DA_Update", use_time_stamps);
timers[TimerEnum::Diag] = new Timer("diag");
timers[TimerEnum::LBMSolver] = new Timer("lbm");
}

static void printTimers(std::vector<Timer*> &timers) {
Expand Down

0 comments on commit e53de1e

Please sign in to comment.