Merge pull request #181 from sony/feature/20190812-skip-all-reduce-if-zeroing

Skip allreduce if array is not updated
AkioHayakawa-sony authored Sep 11, 2019
2 parents 4e72614 + 87bbb4d commit 9a22013
Showing 2 changed files with 233 additions and 44 deletions.
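In short: before reducing, each rank checks whether an array has been written since it was last zeroed (its zeroing() flag is no longer set), the ranks agree on that with a small MPI allreduce over a boolean, and the NCCL allreduce is skipped for arrays that no rank has touched. A minimal sketch of the gating idea, with a made-up helper name (the actual helpers in the diff below are mpi_check_any / mpi_check_all):

    // Logical OR of a per-rank flag: true if at least one rank has updated the array.
    // Plain MPI only; `comm` is any MPI communicator covering the group.
    static bool updated_on_any_rank(bool locally_updated, MPI_Comm comm) {
      int local = locally_updated ? 1 : 0;
      int global = 0;
      MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LOR, comm);
      return global != 0; // callers skip the NCCL allreduce when this is false
    }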
@@ -41,6 +41,10 @@ using std::unordered_map;
using std::unordered_set;
using std::pair;

/** Forward decl. of a wrapper object of MPI_Comm
*/
class MpiCommWrapper;

/** \addtogroup NNablaCoreGrp */
/*@{*/

@@ -74,6 +78,13 @@ class NBLA_API MultiProcessDataParallelCommunicatorNccl

// Groups
unordered_map<string, ncclComm_t> comms_;
unordered_map<string, shared_ptr<MpiCommWrapper>> mpi_comms_;

bool mpi_check_any(bool condition, const string &group);
bool mpi_check_all(bool condition, const string &group);

vector<NdArrayPtr> get_modified_arrays(const vector<NdArrayPtr> &arrays,
const string &group);

public:
typedef typename CudaType<T>::type Tc;
266 changes: 222 additions & 44 deletions src/nbla/cuda/communicator/multi_process_data_parallel_communicator.cu
@@ -34,6 +34,167 @@ using std::vector;
using std::make_shared;
using std::unordered_set;

/**
MPI singleton class that manages the lifetime of MPI.
The MPI singleton object can be obtained as a shared pointer by the static
function get(). Using a shared pointer ensures the singleton stays alive as
long as it is held by any object.
@todo This is only used by the NCCL communicator so far. However, it should
be exposed to other communicators that rely on MPI.
*/
class Mpi {
public:
/**
Returns the singleton object as a shared pointer.
*/
static shared_ptr<Mpi> get() {
static std::shared_ptr<Mpi> mpi;
if (!mpi) {
mpi.reset(new Mpi);
}
return mpi;
}
/**
Returns MPI_Group of MPI_COMM_WORLD.
It is maintained in this singleton to ensure that MPI_Group_free is called
when the program ends.
*/
static MPI_Group world_group() { return Mpi::get()->world_group_; }
~Mpi() {
if (finalized()) {
/*
MPI might already have been finalized by another library.
*/
return;
}
MPI_Group_free(&world_group_);
MPI_Finalize();
}

/**
Returns whether MPI is initialized or not.
*/
static bool initialized() {
int flag = 1;
MPI_Initialized(&flag);
return bool(flag);
}

/**
Returns whether MPI is finalized or not.
*/
static bool finalized() {
int flag = 1;
MPI_Finalized(&flag);
return bool(flag);
}

/**
Returns whether MPI is active (initialized and not finalized).
*/
static bool active() { return initialized() && !finalized(); }

// Prohibit copy and move
Mpi(const Mpi &rhs) = delete;
Mpi(Mpi &&rhs) = delete;

private:
/*
The constructor is hidden to make this class a singleton.
*/
Mpi() {
if (!initialized()) {
/*
Initialize MPI if it hasn't already been initialized by another library.
*/
int argc = 0;
char **argv = nullptr;
int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
int provided;
MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
if (provided != requiredThreadLevelSupport) {
NBLA_ERROR(error_code::target_specific,
"MPI_Init_thread failed since provided (%d) is not equal to "
"requiredThreadLevelSupport (%d)",
provided, requiredThreadLevelSupport);
}
}
MPI_Comm_group(MPI_COMM_WORLD, &world_group_);
}
MPI_Group world_group_;
};
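// Illustrative usage sketch, not part of this commit (class name is made up):
// an object that needs MPI keeps the shared pointer returned by Mpi::get() as
// a member, so MPI stays initialized at least as long as that object lives.
class NeedsMpiExample {
  shared_ptr<Mpi> mpi_{Mpi::get()}; // holding the pointer keeps MPI alive
public:
  int world_size() const {
    int size = 0;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    return size;
  }
};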

/**
A wrapper class for MPI_Comm.
It holds MPI_Comm and MPI_Group instances, which are properly freed at the
end of this object's lifetime.
@todo Move it to a public header after some refactoring.
*/
class MpiCommWrapper {
private:
// Holds the MPI singleton to keep MPI alive for the lifetime of this object.
shared_ptr<Mpi> mpi_;
/*
Whether this object owns its MPI_Comm and MPI_Group. comm_ and group_ are
freed in the destructor only if they are owned.
*/
bool own_;
MPI_Comm comm_;
MPI_Group group_;

public:
/**
Create an object for MPI_COMM_WORLD.
*/
MpiCommWrapper()
: mpi_(Mpi::get()), own_(false), comm_(MPI_COMM_WORLD),
group_(Mpi::world_group()) {}

/**
Create an object for a new group given by rank integers under
MPI_COMM_WORLD.
*/
MpiCommWrapper(std::vector<int> ranks) : mpi_(Mpi::get()), own_(true) {
MPI_Group_incl(Mpi::world_group(), ranks.size(), ranks.data(), &group_);
MPI_Comm_create_group(MPI_COMM_WORLD, group_, 0, &this->comm_);
}
/*
Deletes the MPI_Group and MPI_Comm objects when this object owns them.
*/
~MpiCommWrapper() {
if (!this->own_) {
return;
}
MPI_Group_free(&group_);
if (this->comm_ == MPI_COMM_NULL) {
return;
}
MPI_Comm_free(&this->comm_);
}
/**
Returns the MPI_Comm held by this object.
It must not be freed by the caller.
*/
MPI_Comm comm() { return comm_; }
/**
Returns the MPI_Group held by this object.
It must not be freed by the caller.
*/
MPI_Group group() { return group_; }

// Prohibit copy and move
MpiCommWrapper(const MpiCommWrapper &rhs) = delete;
MpiCommWrapper(MpiCommWrapper &&rhs) = delete;
};
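// Illustrative usage sketch, not part of this commit (function name and the
// rank set {0, 2} are made up): create a communicator for a subset of world
// ranks and broadcast a value inside it. The wrapper frees its MPI_Group and
// MPI_Comm automatically when it goes out of scope.
void example_subgroup_broadcast(int my_world_rank) {
  std::vector<int> ranks = {0, 2};
  // Mirrors new_group() below: the wrapper is constructed on every rank.
  auto sub = make_shared<MpiCommWrapper>(ranks);
  if (std::find(ranks.begin(), ranks.end(), my_world_rank) == ranks.end()) {
    return; // this rank is not a member of the new group
  }
  int value = (my_world_rank == 0) ? 42 : 0;
  // Root 0 is the group-local rank 0, i.e. world rank 0 here.
  MPI_Bcast(&value, 1, MPI_INT, 0, sub->comm());
}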

template <typename T>
__global__ void kernel_divide_inplace(const int size, const int n_devices,
T *dw) {
@@ -68,12 +229,43 @@ static void get_host_name(char *hostname, int maxlen) {
}
}

template <typename T>
bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_any(
bool condition, const string &group) {
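// True iff `condition` holds on at least one rank of `group` (MPI logical OR).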
bool result;
MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LOR,
this->mpi_comms_[group]->comm());
return result;
}

template <typename T>
bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_all(
bool condition, const string &group) {
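// True iff `condition` holds on every rank of `group` (MPI logical AND).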
bool result;
MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LAND,
this->mpi_comms_[group]->comm());
return result;
}

template <typename T>
vector<NdArrayPtr>
MultiProcessDataParallelCommunicatorNccl<T>::get_modified_arrays(
const vector<NdArrayPtr> &arrays, const string &group) {
// Collect the arrays that have been modified on at least one rank since zeroing.
vector<NdArrayPtr> modified_array_list;

for (auto &a : arrays) {
if (mpi_check_any(!a->array()->zeroing(), group)) {
modified_array_list.push_back(a);
}
}
return modified_array_list;
}

template <typename T>
MultiProcessDataParallelCommunicatorNccl<
T>::MultiProcessDataParallelCommunicatorNccl(const Context &ctx)
: MultiProcessDataParallelCommunicator<T>(ctx) {
mpi_initialized_ = false;
}
: MultiProcessDataParallelCommunicator<T>(ctx) {}

template <typename T>
MultiProcessDataParallelCommunicatorNccl<
@@ -89,32 +281,12 @@ MultiProcessDataParallelCommunicatorNccl<
NBLA_CUDA_CHECK(cudaStreamDestroy(stream));
}
}
if (mpi_initialized_) {
MPI_Finalize();
}
}

template <typename T>
bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_initialized_;

template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
Communicator::init();

// MPI init
if (!mpi_initialized_) {
int argc = 0;
char **argv = nullptr;
int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
int provided;
MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
if (provided != requiredThreadLevelSupport) {
NBLA_ERROR(error_code::target_specific,
"MPI_Init_thread failed since provided (%d) is not equal to "
"requiredThreadLevelSupport (%d)",
provided, requiredThreadLevelSupport);
}
mpi_initialized_ = true;
}
Mpi::get(); // Make sure MPI singleton is initialized.
this->mpi_comms_["world"] = make_shared<MpiCommWrapper>();

// Create comm, set size, and rank
MPI_Comm_size(MPI_COMM_WORLD, &this->size_);
@@ -209,20 +381,15 @@ string MultiProcessDataParallelCommunicatorNccl<T>::new_group(
NBLA_CHECK(min >= 0, error_code::value,
"Min value of the specified ranks is greater than or equal to 0.");

// Create new group
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
MPI_Group new_group;
MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);

// Create mpi communicator
MPI_Comm mpi_comm;
MPI_Comm_create(MPI_COMM_WORLD, new_group,
&mpi_comm); // have to call in all processes
auto group_mpi_comm = make_shared<MpiCommWrapper>(ranks);

// Add group name in all ranks
this->groups_[group_name] = ranks;

// Add group MPI comm
this->mpi_comms_[group_name] = group_mpi_comm;

// Leave if self is not in ranks
auto result = std::find(ranks.begin(), ranks.end(), this->rank_);
if (result == ranks.end()) { // self is not found in ranks.
@@ -235,10 +402,9 @@ string MultiProcessDataParallelCommunicatorNccl<T>::new_group(
ncclGetUniqueId(&comm_id);
}
int rank;
MPI_Comm_rank(mpi_comm, &rank);
MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, mpi_comm);
MPI_Barrier(mpi_comm);
MPI_Comm_free(&mpi_comm);
MPI_Comm_rank(group_mpi_comm->comm(), &rank);
MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, group_mpi_comm->comm());
MPI_Barrier(group_mpi_comm->comm());

// NCCL Comm Init
cuda_set_device(device_id_);
@@ -486,14 +652,20 @@ void MultiProcessDataParallelCommunicatorNccl<T>::all_reduce(
int k = 0;
dtypes dtype = get_dtype<Tc>();
for (auto ndarray : ndarray_list) { // ndarray loop
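// Skip arrays that no rank in the group has updated since they were zeroed.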
if (mpi_check_all(ndarray->array()->zeroing(), group)) {
continue;
}
int stream_id = k % num_streams_;
all_reduce(ndarray, streams_[stream_id], division, inplace, group);
k++;
}
} else { // out-of-place. use a large array.
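// Pack only the arrays some rank has updated; skip the collective if none were.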
NdArrayPtr large_ndarray = copy_inside_device(ndarray_list);
all_reduce(large_ndarray, nullptr, division, inplace, group);
copy_back_inside_device(ndarray_list, large_ndarray);
auto modified_ndarray_list = get_modified_arrays(ndarray_list, group);
if (!modified_ndarray_list.empty()) {
NdArrayPtr large_ndarray = copy_inside_device(modified_ndarray_list);
all_reduce(large_ndarray, nullptr, division, inplace, group);
copy_back_inside_device(modified_ndarray_list, large_ndarray);
}
}
launch_kernel_null();
}
@@ -505,6 +677,10 @@ void MultiProcessDataParallelCommunicatorNccl<T>::all_reduce(
NBLA_ERROR(error_code::value, "self (rank=%d) is not included in %s.",
this->rank_, group.c_str());
}
if (mpi_check_all(ndarray->array()->zeroing(), group)) {
// Skip: this array has not been updated on any rank in the group.
return;
}
all_reduce(ndarray, nullptr, division, inplace, group);
}
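// Hedged usage sketch, not part of this commit (assumes NdArray::zero() marks
// the underlying array as "zeroing" until it is next written): a gradient that
// is cleared and then never written by backward keeps zeroing() == true on
// every rank, so the overload above returns early and no NCCL call is issued.
//
//   grad->zero();  // lazily marks the array as zeroed
//   // ... no backward pass writes to `grad` on any rank ...
//   comm->all_reduce(grad, false /* division */, true /* inplace */, "world");
//   // skipped, because mpi_check_all(zeroing) is true on all ranks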

@@ -639,7 +815,6 @@ void MultiProcessDataParallelCommunicatorNccl<T>::all_gather(
NBLA_ERROR(error_code::value, "self (rank=%d) is not included in %s.",
this->rank_, group.c_str());
}

// TODO: currently nnabla uses default stream for computation.
// The following logic relies on that, so if nnabla uses another stream for
// computation,
@@ -756,9 +931,12 @@ void MultiProcessDataParallelCommunicatorNccl<T>::AllReduceCallback::
vector<std::pair<Tc *, size_t>> device_ptr_list;
device_ptr_list.reserve(ptr->function_inputs().size());
for (auto &input : ptr->function_inputs()) {
Tc *device_ptr = input->cast_grad_and_get_pointer<Tc>(this->parent_.ctx_);

if (this->device_ptrs_.find(input->grad()) != this->device_ptrs_.end()) {
if (parent_.mpi_check_all(input->grad()->array()->zeroing(), "world")) {
// Skip: the gradient array has not been updated on any rank.
continue;
}
Tc *device_ptr = input->cast_grad_and_get_pointer<Tc>(this->parent_.ctx_);
device_ptr_list.push_back(std::make_pair(device_ptr, input->size()));
}
}
