Skip to content

Commit

Permalink
declare schedulers outside the mainloop
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Aug 12, 2023
1 parent d8f79c8 commit c9fc2ad
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 36 deletions.
151 changes: 119 additions & 32 deletions mini-apps/lbm2d-letkf/executors/letkf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,29 @@

#include <executors/Parallel_For.hpp>
#include <executors/Transpose.hpp>
#include <utils/string_utils.hpp>
#include <utils/mpi_utils.hpp>
#include <utils/file_utils.hpp>
#include <utils/io_utils.hpp>
#include "letkf_solver.hpp"
#include "../functors.hpp"
#include "../da_functors.hpp"
#include "da_models.hpp"

namespace stdex = std::experimental;

#include "nvexec/stream_context.cuh"
#include <exec/static_thread_pool.hpp>
#include <stdexec/execution.hpp>
#include <exec/async_scope.hpp>
#include <exec/on.hpp>

class LETKF : public DA_Model {
namespace stdex = std::experimental;

class LETKF {
private:
using value_type = RealView2D::value_type;
Config conf_;
IOConfig io_conf_;
MPIConfig mpi_conf_;
std::string base_dir_name_;
bool load_to_device_ = true;

Impl::blasHandle_t blas_handle_;
std::unique_ptr<LETKFSolver> letkf_solver_;
Expand All @@ -40,9 +45,27 @@ class LETKF : public DA_Model {

public:
LETKF(Config& conf, IOConfig& io_conf)=delete;
LETKF(Config& conf, IOConfig& io_conf, MPIConfig& mpi_conf) : DA_Model(conf, io_conf), mpi_conf_(mpi_conf) {}
LETKF(Config& conf, IOConfig& io_conf, MPIConfig& mpi_conf)
: conf_(conf), io_conf_(io_conf), mpi_conf_(mpi_conf) {
base_dir_name_ = io_conf_.base_dir_ + "/" + io_conf_.in_case_name_ + "/observed/ens0000";
}

virtual ~LETKF(){ blas_handle_.destroy(); }

void setFileInfo() {
int nb_expected_files = conf_.settings_.nbiter_ / conf_.settings_.io_interval_;
std::string variables[3] = {"rho", "u", "v"};
for(int it=0; it<nb_expected_files; it++) {
for(const auto& variable: variables) {
auto step = it * conf_.settings_.io_interval_;
auto file_name = base_dir_name_ + "/" + variable + "_obs_step" + Impl::zfill(step, 10) + ".dat";
if(!Impl::isFileExists(file_name)) {
std::runtime_error("Expected observation file does not exist." + file_name);
}
}
}
}

void initialize() {
setFileInfo();

Expand Down Expand Up @@ -88,45 +111,63 @@ class LETKF : public DA_Model {
blas_handle_.create();
}

void apply(std::unique_ptr<DataVars>& data_vars, const int it, std::vector<Timer*>& timers){
template <class Scheduler, class IO_Scheduler>
void apply(Scheduler&& scheduler,
IO_Scheduler&& io_scheduler,
std::unique_ptr<DataVars>& data_vars,
const int it,
std::vector<Timer*>& timers){
if(it == 0 || it % conf_.settings_.da_interval_ != 0) return;

if(mpi_conf_.is_master()) {
std::cout << __PRETTY_FUNCTION__ << ": t=" << it << std::endl;
}

if(is_async_) {
apply_async(data_vars, it, timers);
apply_async(scheduler, io_scheduler, data_vars, it, timers);
} else {
apply_sync(data_vars, it, timers);
}
}

private:
// Asynchronous implementation with senders/receivers
void apply_async(std::unique_ptr<DataVars>& data_vars, const int it, std::vector<Timer*>& timers) {
#if defined(ENABLE_OPENMP)
exec::static_thread_pool pool{std::thread::hardware_concurrency()};
auto scheduler = pool.get_scheduler();
#else
nvexec::stream_context stream_ctx{};
auto scheduler = stream_ctx.get_scheduler();
#endif

template <class Scheduler, class IO_Scheduler>
void apply_async(Scheduler&& scheduler,
IO_Scheduler&& io_scheduler,
std::unique_ptr<DataVars>& data_vars,
const int it,
std::vector<Timer*>& timers) {
exec::async_scope scope;
exec::static_thread_pool io_thread_pool{std::thread::hardware_concurrency()};
auto io_scheduler = io_thread_pool.get_scheduler();
auto _load = stdexec::just() |
auto _load_rho = stdexec::just() |
stdexec::then([&]{
timers[DA_Load_rho]->begin();
if(mpi_conf_.is_master()) {
load(data_vars, "rho", it);
}
timers[DA_Load_rho]->end();
});

auto _load_u = stdexec::just() |
stdexec::then([&]{
timers[DA_Load_u]->begin();
if(mpi_conf_.is_master()) {
load(data_vars, "u", it);
}
timers[DA_Load_u]->end();
});

auto _load_v = stdexec::just() |
stdexec::then([&]{
timers[DA_Load]->begin();
timers[DA_Load_v]->begin();
if(mpi_conf_.is_master()) {
load(data_vars, it);
load(data_vars, "v", it);
}
timers[DA_Load]->end();
timers[DA_Load_v]->end();
});

timers[TimerEnum::DA]->begin();
scope.spawn(stdexec::on(io_scheduler, std::move(_load)));
scope.spawn(stdexec::on(io_scheduler, std::move(_load_rho)));

// set X
const auto f = data_vars->f().mdspan();
Expand Down Expand Up @@ -171,30 +212,51 @@ class LETKF : public DA_Model {
Impl::transpose(blas_handle_, yk_buffer, Y, {0, 2, 1}); // (n_obs, n_batch, n_ens) -> (n_obs, n_ens, n_batch)
timers[DA_Unpack_Y]->end();

stdexec::sync_wait( scope.on_empty() );
stdexec::sync_wait( scope.on_empty() ); // load rho only
scope.spawn(stdexec::on(io_scheduler, std::move(_load_u)));
scope.spawn(stdexec::on(io_scheduler, std::move(_load_v)));

auto _axpy = letkf_solver_->solve_axpy_sender(scheduler);
if(!load_to_device_) {
timers[DA_Load_H2D]->begin();
timers[DA_Load_H2D_rho]->begin();
if(mpi_conf_.is_master()) {
data_vars->rho_obs().updateDevice();
}
timers[DA_Load_H2D_rho]->end();
}
auto rho_obs = data_vars->rho_obs().mdspan();
timers[DA_Broadcast_rho]->begin();
broadcast(rho_obs);
timers[DA_Broadcast_rho]->end();

stdexec::sync_wait( scope.on_empty() ); // load u and v
if(!load_to_device_) {
timers[DA_Load_H2D_u]->begin();
if(mpi_conf_.is_master()) {
data_vars->u_obs().updateDevice();
}
timers[DA_Load_H2D_u]->end();

timers[DA_Load_H2D_v]->begin();
if(mpi_conf_.is_master()) {
data_vars->v_obs().updateDevice();
}
timers[DA_Load_H2D]->end();
timers[DA_Load_H2D_v]->end();
}

auto _axpy = letkf_solver_->solve_axpy_sender(scheduler);

// set yo
auto _broadcast = stdexec::just() |
stdexec::then([&]{
auto rho_obs = data_vars->rho_obs().mdspan();
auto u_obs = data_vars->u_obs().mdspan();
auto v_obs = data_vars->v_obs().mdspan();
timers[DA_Broadcast]->begin();
broadcast(rho_obs);
timers[DA_Broadcast_u]->begin();
broadcast(u_obs);
timers[DA_Broadcast_u]->end();

timers[DA_Broadcast_v]->begin();
broadcast(v_obs);
timers[DA_Broadcast]->end();
timers[DA_Broadcast_v]->end();
});

auto _axpy_and_braodcast = stdexec::when_all(
Expand Down Expand Up @@ -382,6 +444,31 @@ class LETKF : public DA_Model {
mpi_conf_.comm());
}

void load(std::unique_ptr<DataVars>& data_vars, const int it) {
from_file(data_vars->rho_obs(), it);
from_file(data_vars->u_obs(), it);
from_file(data_vars->v_obs(), it);
}

void load(std::unique_ptr<DataVars>& data_vars, const std::string variable, const int it) {
if(variable == "rho") {
from_file(data_vars->rho_obs(), it);
} else if(variable == "u") {
from_file(data_vars->u_obs(), it);
} else if(variable == "v") {
from_file(data_vars->v_obs(), it);
}
}

template <class ViewType>
void from_file(ViewType& value, const int step) {
auto file_name = base_dir_name_ + "/" + value.name() + "_step" + Impl::zfill(step, 10) + ".dat";
auto mdspan = value.host_mdspan();
Impl::from_binary(file_name, mdspan);
if(load_to_device_) {
value.updateDevice();
}
}
};

#endif
3 changes: 0 additions & 3 deletions mini-apps/lbm2d-letkf/executors/model_factories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "../io_config.hpp"
#include "../mpi_config.hpp"
#include "nudging.hpp"
#include "letkf.hpp"
#include "lbm2d.hpp"

static std::unique_ptr<Model> model_factory(std::string model, Config& conf, IOConfig& io_conf) {
Expand All @@ -20,8 +19,6 @@ static std::unique_ptr<Model> model_factory(std::string model, Config& conf, IOC
static std::unique_ptr<DA_Model> da_model_factory(std::string da_model, Config& conf, IOConfig& io_conf, MPIConfig& mpi_conf) {
if(da_model == "nudging") {
return std::unique_ptr<Nudging>(new Nudging(conf, io_conf));
} else if(da_model == "letkf") {
return std::unique_ptr<LETKF>(new LETKF(conf, io_conf, mpi_conf));
}
return std::unique_ptr<NonDA>(new NonDA(conf, io_conf));
};
Expand Down
27 changes: 26 additions & 1 deletion mini-apps/lbm2d-letkf/executors/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <utils/string_utils.hpp>
#include <utils/commandline_utils.hpp>
#include "../timer.hpp"
#include "../config.hpp"
#include "../io_config.hpp"
#include "../mpi_config.hpp"
#include "models.hpp"
#include "letkf.hpp"
#include "model_factories.hpp"
#include "data_vars.hpp"

Expand All @@ -24,6 +26,7 @@ class Solver {
std::string sim_type_;
std::unique_ptr<Model> model_;
std::unique_ptr<DA_Model> da_model_;
std::unique_ptr<LETKF> letkf_;
std::unique_ptr<DataVars> data_vars_;
std::vector<Timer*> timers_;

Expand All @@ -48,6 +51,10 @@ class Solver {
data_vars_ = std::move( std::unique_ptr<DataVars>(new DataVars(conf_)) );
model_ = std::move( model_factory(sim_type_, conf_, io_conf_) );
da_model_ = std::move( da_model_factory(sim_type_, conf_, io_conf_, mpi_conf_) );
if(sim_type_ == "letkf") {
letkf_ = std::move( std::unique_ptr<LETKF>(new LETKF(conf_, io_conf_, mpi_conf_)) );
letkf_->initialize();
}

model_->initialize(data_vars_);
da_model_->initialize();
Expand All @@ -68,11 +75,27 @@ class Solver {
};

void run(){
#if defined(ENABLE_OPENMP)
exec::static_thread_pool pool{std::thread::hardware_concurrency()};
auto scheduler = pool.get_scheduler();
#else
nvexec::stream_context stream_ctx{};
auto scheduler = stream_ctx.get_scheduler();
#endif

exec::static_thread_pool io_thread_pool{std::thread::hardware_concurrency()};
auto io_scheduler = io_thread_pool.get_scheduler();

timers_[TimerEnum::Total]->begin();
for(int it=0; it<conf_.settings_.nbiter_; it++) {
timers_[TimerEnum::MainLoop]->begin();

da_model_->apply(data_vars_, it, timers_);
if(sim_type_ == "letkf") {
letkf_->apply(scheduler, io_scheduler, data_vars_, it, timers_);
} else {
da_model_->apply(data_vars_, it, timers_);
}

if(!conf_.settings_.disable_output_) {
model_->diag(data_vars_, it, timers_);
}
Expand Down Expand Up @@ -188,6 +211,8 @@ class Solver {
}

// Saving json file to output directory
const int n_ens = mpi_conf_.size();
io_conf_.case_name_ = sim_type_ == "letkf" ? io_conf_.case_name_ + "_ens" + Impl::zfill(n_ens, 3) : io_conf_.case_name_;
if(mpi_conf_.is_master()) {
const std::string out_dir = io_conf_.base_dir_ + "/" + io_conf_.case_name_;
const std::string performance_dir = out_dir + "/" + "performance";
Expand Down
2 changes: 2 additions & 0 deletions mini-apps/lbm2d-letkf/stdpar/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ class Solver {
}

// Saving json file to output directory
const int n_ens = mpi_conf_.size();
io_conf_.case_name_ = sim_type_ == "letkf" ? io_conf_.case_name_ + "_ens" + Impl::zfill(n_ens, 3) : io_conf_.case_name_;
if(mpi_conf_.is_master()) {
const std::string out_dir = io_conf_.base_dir_ + "/" + io_conf_.case_name_;
const std::string performance_dir = out_dir + "/" + "performance";
Expand Down
3 changes: 3 additions & 0 deletions mini-apps/lbm2d-letkf/thrust/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <utils/string_utils.hpp>
#include <utils/commandline_utils.hpp>
#include <utils/device_utils.hpp>
#include "../timer.hpp"
Expand Down Expand Up @@ -182,6 +183,8 @@ class Solver {
}

// Saving json file to output directory
const int n_ens = mpi_conf_.size();
io_conf_.case_name_ = sim_type_ == "letkf" ? io_conf_.case_name_ + "_ens" + Impl::zfill(n_ens, 3) : io_conf_.case_name_;
if(mpi_conf_.is_master()) {
const std::string out_dir = io_conf_.base_dir_ + "/" + io_conf_.case_name_;
const std::string performance_dir = out_dir + "/" + "performance";
Expand Down
18 changes: 18 additions & 0 deletions mini-apps/lbm2d-letkf/timer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ enum TimerEnum : int {Total,
DA,
DA_Load,
DA_Load_H2D,
DA_Load_rho,
DA_Load_u,
DA_Load_v,
DA_Load_H2D_rho,
DA_Load_H2D_u,
DA_Load_H2D_v,
DA_Pack_X,
DA_All2All_X,
DA_Unpack_X,
Expand All @@ -83,6 +89,9 @@ enum TimerEnum : int {Total,
DA_Unpack_Y,
DA_Pack_Obs,
DA_Broadcast,
DA_Broadcast_rho,
DA_Broadcast_u,
DA_Broadcast_v,
DA_LETKF,
DA_Update,
Diag,
Expand All @@ -97,6 +106,12 @@ static void defineTimers(std::vector<Timer*> &timers, bool use_time_stamps=false
timers[TimerEnum::DA] = new Timer("DA", use_time_stamps);
timers[TimerEnum::DA_Load] = new Timer("DA_Load", use_time_stamps);
timers[TimerEnum::DA_Load_H2D] = new Timer("DA_Load_H2D", use_time_stamps);
timers[TimerEnum::DA_Load_rho] = new Timer("DA_Load_rho", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_rho] = new Timer("DA_Load_H2D_rho", use_time_stamps);
timers[TimerEnum::DA_Load_u] = new Timer("DA_Load_u", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_u] = new Timer("DA_Load_H2D_u", use_time_stamps);
timers[TimerEnum::DA_Load_v] = new Timer("DA_Load_v", use_time_stamps);
timers[TimerEnum::DA_Load_H2D_v] = new Timer("DA_Load_H2D_v", use_time_stamps);
timers[TimerEnum::DA_Pack_X] = new Timer("DA_Pack_X", use_time_stamps);
timers[TimerEnum::DA_All2All_X] = new Timer("DA_All2All_X", use_time_stamps);
timers[TimerEnum::DA_Unpack_X] = new Timer("DA_Unpack_X", use_time_stamps);
Expand All @@ -105,6 +120,9 @@ static void defineTimers(std::vector<Timer*> &timers, bool use_time_stamps=false
timers[TimerEnum::DA_Unpack_Y] = new Timer("DA_Unpack_Y", use_time_stamps);
timers[TimerEnum::DA_Pack_Obs] = new Timer("DA_Pack_Obs", use_time_stamps);
timers[TimerEnum::DA_Broadcast] = new Timer("DA_Broadcast", use_time_stamps);
timers[TimerEnum::DA_Broadcast_rho] = new Timer("DA_Broadcast_rho", use_time_stamps);
timers[TimerEnum::DA_Broadcast_u] = new Timer("DA_Broadcast_u", use_time_stamps);
timers[TimerEnum::DA_Broadcast_v] = new Timer("DA_Broadcast_v", use_time_stamps);
timers[TimerEnum::DA_LETKF] = new Timer("DA_LETKF", use_time_stamps);
timers[TimerEnum::DA_Update] = new Timer("DA_Update", use_time_stamps);
timers[TimerEnum::Diag] = new Timer("diag");
Expand Down

0 comments on commit c9fc2ad

Please sign in to comment.