
Add complicated distributed shared array test
hmenke committed Oct 4, 2023
1 parent 48f5ce8 commit 1738e29
Showing 2 changed files with 81 additions and 4 deletions.
2 changes: 2 additions & 0 deletions c++/mpi/mpi.hpp
@@ -79,6 +79,8 @@ namespace mpi {

[[nodiscard]] MPI_Comm get() const noexcept { return _com; }

[[nodiscard]] bool is_null() const noexcept { return _com == MPI_COMM_NULL; }

[[nodiscard]] int rank() const {
if (has_env) {
int num = 0;
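The new is_null() member reports whether the wrapped handle is MPI_COMM_NULL, which is what split() hands back on ranks that pass MPI_UNDEFINED as the color. A minimal usage sketch (the helper function and variable names are illustrative, not part of this commit):

#include <mpi/mpi.hpp>

// Sketch: do work only on the rank-0 "head" of each shared memory island.
// Ranks that receive MPI_COMM_NULL should not call rank()/size() or collectives
// on the resulting communicator, so guard with is_null().
void head_only_work(mpi::communicator world) {
  auto shm  = world.split_shared();
  auto head = world.split(shm.rank() == 0 ? 0 : MPI_UNDEFINED);
  if (!head.is_null()) {
    int head_rank = head.rank(); // safe: the communicator is valid here
    (void)head_rank;
  }
}

The DistributedSharedArray test below uses exactly this pattern to restrict the island-level bookkeeping to the head ranks.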
83 changes: 79 additions & 4 deletions test/c++/mpi_window.cpp
@@ -14,7 +14,8 @@
//
// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell

#include "mpi/mpi.hpp"
#include <mpi/mpi.hpp>
#include <mpi/vector.hpp>
#include <gtest/gtest.h>
#include <numeric>

@@ -95,10 +96,10 @@ TEST(MPI_Window, SharedArray) {
auto shm = world.split_shared();
int const rank_shm = shm.rank();

- constexpr int const size = 20;
+ constexpr int const array_size = 20;
constexpr int const magic = 21;

- mpi::shared_window<int> win{shm, rank_shm == 0 ? size : 0};
+ mpi::shared_window<int> win{shm, rank_shm == 0 ? array_size : 0};
std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};

win.fence();
@@ -110,7 +111,81 @@
win.fence();

int sum = std::accumulate(array_view.begin(), array_view.end(), int{0});
- EXPECT_EQ(sum, size * magic);
+ EXPECT_EQ(sum, array_size * magic);
}

TEST(MPI_Window, DistributedSharedArray) {
mpi::communicator world;
auto shm = world.split_shared();

// Total number of array elements (a prime number to make it a bit more exciting)
constexpr int const array_size_total = 197;

// Create a communicator containing rank 0 of each shared memory island ("head node")
auto head = world.split(shm.rank() == 0 ? 0 : MPI_UNDEFINED);

// Determine number of shared memory islands and broadcast to everyone
int head_size = (world.rank() == 0 ? head.size() : -1);
mpi::broadcast(head_size, world);

// Determine rank in head node communicator and broadcast to all other ranks
// on the same shared memory island
int head_rank = (head.get() != MPI_COMM_NULL ? head.rank() : -1);
mpi::broadcast(head_rank, shm);

// Determine number of ranks on each shared memory island and broadcast to everyone
std::vector<int> shm_sizes(head_size, 0);
if (!head.is_null()) {
shm_sizes.at(head_rank) = shm.size();
shm_sizes = mpi::all_reduce(shm_sizes, head);
}
mpi::broadcast(shm_sizes, world);

// Chunk the total array such that each rank has approximately the same number
// of array elements
std::vector<int> array_sizes(head_size, 0);
for (auto &&[shm_size, array_size]: itertools::zip(shm_sizes, array_sizes)) {
array_size = array_size_total / world.size() * shm_size;
}
// Distribute the remainder evenly over the islands to reduce load imbalance
for (auto i: itertools::range(array_size_total % world.size())) {
array_sizes.at(i % array_sizes.size()) += 1;
}
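// Worked example (hypothetical layout, not part of the test): with 8 ranks in
// two islands of 4, each island first gets 197/8*4 = 96 elements and the
// remainder 197%8 = 5 is dealt out round-robin, yielding array_sizes = {99, 98}.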

EXPECT_EQ(array_size_total, std::accumulate(array_sizes.begin(), array_sizes.end(), int{0}));

// Determine the global index offset on the current shared memory island
auto begin = array_sizes.begin();
std::advance(begin, head_rank);
std::ptrdiff_t offset = std::accumulate(array_sizes.begin(), begin, std::ptrdiff_t{0});

// Allocate memory
mpi::shared_window<int> win{shm, shm.rank() == 0 ? array_sizes.at(head_rank) : 0};
std::span array_view{win.base(0), static_cast<std::size_t>(win.size(0))};

// Fill array with global index (= local index + global offset)
// We do this in parallel on each shared memory island by chunking the total range
win.fence();
auto slice = itertools::chunk_range(0, array_view.size(), shm.size(), shm.rank());
for (auto i = slice.first; i < slice.second; ++i) {
array_view[i] = i + offset;
}
win.fence();

// Calculate the partial sum on the head node of each shared memory island,
// all_reduce the partial sums into a total sum over the head node
// communicator, and broadcast the result to everyone
std::vector<int> partial_sum(head_size, 0);
int sum = 0;
if (!head.is_null()) {
partial_sum[head_rank] = std::accumulate(array_view.begin(), array_view.end(), int{0});
partial_sum = mpi::all_reduce(partial_sum, head);
sum = std::accumulate(partial_sum.begin(), partial_sum.end(), int{0});
}
mpi::broadcast(sum, world);

// Total sum is just sum of numbers in interval [0, array_size_total)
EXPECT_EQ(sum, (array_size_total * (array_size_total - 1)) / 2);
}

MPI_TEST_MAIN;
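
For reference, a standalone sketch of the chunking and checksum arithmetic exercised by DistributedSharedArray, using a hypothetical layout of 8 ranks in two shared memory islands of 4 ranks each (no MPI is needed to check the bookkeeping):

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  constexpr int array_size_total = 197;
  constexpr int world_size       = 8;      // hypothetical total rank count
  std::vector<int> shm_sizes     = {4, 4}; // hypothetical island sizes

  // Proportional share per island, as in the test ...
  std::vector<int> array_sizes(shm_sizes.size(), 0);
  for (std::size_t i = 0; i < array_sizes.size(); ++i)
    array_sizes[i] = array_size_total / world_size * shm_sizes[i];
  // ... then the remainder is dealt out round-robin over the islands.
  for (int i = 0; i < array_size_total % world_size; ++i)
    array_sizes[i % array_sizes.size()] += 1;

  // Every element is assigned to exactly one island: {99, 98} here.
  assert(std::accumulate(array_sizes.begin(), array_sizes.end(), 0) == array_size_total);

  // The shared arrays hold the values 0 .. array_size_total-1, so the expected
  // checksum is the triangular number 197 * 196 / 2 = 19306.
  assert(array_size_total * (array_size_total - 1) / 2 == 19306);
}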
