Skip to content

Commit

Permalink
Merge pull request #22 from yasahi-hpc/optimization
Browse files Browse the repository at this point in the history
[Bugfix] gemm in letkf-solver
  • Loading branch information
yasahi-hpc authored Jun 27, 2023
2 parents 7d833d5 + d668cbe commit 0c8020e
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 22 deletions.
4 changes: 2 additions & 2 deletions lib/cuda_linalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ namespace Impl {

thrust::device_vector<value_type> workspace(lwork);
thrust::device_vector<int> info(batchSize, 0);
auto workspace_data = (value_type *)thrust::raw_pointer_cast(workspace.data());
auto info_data = (int *)thrust::raw_pointer_cast(info.data());
value_type* workspace_data = (value_type *)thrust::raw_pointer_cast(workspace.data());
int* info_data = (int *)thrust::raw_pointer_cast(info.data());

auto status = syevjBatched(
handle,
Expand Down
19 changes: 16 additions & 3 deletions lib/stdpar/Transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,33 @@ namespace Impl {
using layout_type = InputView::layout_type;
using axes_type = std::array<int, 3>;

for(std::size_t i=0; i<axes.size(); i++) {
assert(out.extent(i) == in.extent(axes[i]));
}
assert(out.size() == in.size());
//for(std::size_t i=0; i<axes.size(); i++) {
// assert(out.extent(i) == in.extent(axes[i]));
//}

const auto n0 = in.extent(0), n1 = in.extent(1), n2 = in.extent(2);
// Not quite sure, this is a better strategy
IteratePolicy<typename InputView::layout_type, 3> policy3d({0, 0, 0}, {n0, n1, n2});

if(axes == axes_type({0, 1, 2}) ) {
const auto n = in.size();
for(std::size_t i=0; i<axes.size(); i++) {
assert(out.extent(i) == in.extent(axes[i]));
}
std::copy(std::execution::par_unseq, in.data_handle(), in.data_handle()+n, out.data_handle());
} else if(axes == axes_type({0, 2, 1}) ) {
for(std::size_t i=0; i<axes.size(); i++) {
assert(out.extent(i) == in.extent(axes[i]));
}
Impl::for_each(policy3d,
[=](const int i0, const int i1, const int i2) {
out(i0, i2, i1) = in(i0, i1, i2);
});
} else if(axes == axes_type({1, 0, 2})) {
for(std::size_t i=0; i<axes.size(); i++) {
assert(out.extent(i) == in.extent(axes[i]));
}
Impl::for_each(policy3d,
[=](const int i0, const int i1, const int i2) {
out(i1, i0, i2) = in(i0, i1, i2);
Expand All @@ -68,6 +78,9 @@ namespace Impl {
mdspan2d_type sub_out(out.data_handle(), out_shape);
transpose(sub_in, sub_out);
} else if(axes == axes_type({2, 1, 0})) {
for(std::size_t i=0; i<axes.size(); i++) {
assert(out.extent(i) == in.extent(axes[i]));
}
Impl::for_each(policy3d,
[=](const int i0, const int i1, const int i2) {
out(i2, i1, i0) = in(i0, i1, i2);
Expand Down
2 changes: 1 addition & 1 deletion mini-apps/lbm2d-letkf/executors/letkf_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class LETKFSolver {
const value_type alpha = sqrt(static_cast<int>(n_ens_) - 1);
Impl::diag(d, inv_D, -0.5); // (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(V, tmp_ee, W, "N", "T", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)

// W = W + w
// Xsol = x_mean + matmat(dX, W)
Expand Down
2 changes: 1 addition & 1 deletion mini-apps/lbm2d-letkf/stdpar/letkf_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class LETKFSolver {
const value_type alpha = sqrt(static_cast<int>(n_ens_) - 1);
Impl::diag(d, inv_D, -0.5); // (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(inv_D, V, tmp_ee, "N", "T"); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(V, tmp_ee, W, "N", "T", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)
Impl::matrix_matrix_product(V, tmp_ee, W, "N", "N", alpha); // (n_ens, n_ens, n_batch) * (n_ens, n_ens, n_batch) -> (n_ens, n_ens, n_batch)

// W = W + w
// Xsol = x_mean + matmat(dX, W)
Expand Down
10 changes: 5 additions & 5 deletions wk/letkf_256.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
"in_case_name": "nature256",
"nx": 256,
"ny": 256,
"spinup": 10000,
"nbiter": 10000,
"io_interval": 20,
"da_interval": 20,
"spinup": 200000,
"nbiter": 40000,
"io_interval": 200,
"da_interval": 200,
"obs_interval": 1,
"lyapnov": false,
"les": true,
"da_nud_rate": 0.1,
"beta": 1.0,
"beta": 1.07,
"rloc_len": 1
}
}
8 changes: 4 additions & 4 deletions wk/nature_256.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
"case_name": "nature256",
"nx": 256,
"ny": 256,
"spinup": 10000,
"nbiter": 10000,
"io_interval": 20,
"da_interval": 20,
"spinup": 200000,
"nbiter": 40000,
"io_interval": 200,
"da_interval": 200,
"obs_interval": 1,
"lyapnov": false,
"les": true,
Expand Down
12 changes: 6 additions & 6 deletions wk/sub_stdpar_lbm2d_letkf_A100.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ export UCX_MEMTYPE_CACHE=n
export UCX_IB_GPU_DIRECT_RDMA=no
export UCX_RNDV_FRAG_MEM_TYPE=cuda

#mpiexec -machinefile $PJM_O_NODEINF -np 1 -npernode 1 \
# ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename nature_256.json
mpiexec -machinefile $PJM_O_NODEINF -np 1 -npernode 1 \
../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename nature.json
../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename nature_256.json
#mpiexec -machinefile $PJM_O_NODEINF -np 1 -npernode 1 \
# ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename nature.json

#mpiexec -machinefile $PJM_O_NODEINF -np $PJM_MPI_PROC -npernode 4 \
# ./wrapper.sh ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename letkf_256.json
mpiexec -machinefile $PJM_O_NODEINF -np $PJM_MPI_PROC -npernode 4 \
./wrapper.sh ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename letkf.json
./wrapper.sh ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename letkf_256.json
#mpiexec -machinefile $PJM_O_NODEINF -np $PJM_MPI_PROC -npernode 4 \
# ./wrapper.sh ../build/mini-apps/lbm2d-letkf/stdpar/lbm2d-letkf-stdpar --filename letkf.json

0 comments on commit 0c8020e

Please sign in to comment.