Skip to content

Commit

Permalink
coll/acoll: Remove use of cid as array index
Browse files Browse the repository at this point in the history
The use of cid as array index is removed. Instead, now
coll_acoll_subcomms_t is dynamically allocated for each new
communicator.

Signed-off-by: Nithya V S <[email protected]>
  • Loading branch information
amd-nithyavs committed Aug 29, 2024
1 parent 39a8583 commit 3392b94
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 134 deletions.
5 changes: 3 additions & 2 deletions ompi/mca/coll/acoll/coll_acoll.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base

END_C_DECLS

#define MCA_COLL_ACOLL_MAX_CID 100
#define MCA_COLL_ACOLL_MAX_SUBC 10
#define MCA_COLL_ACOLL_ROOT_CHANGE_THRESH 10

typedef enum MCA_COLL_ACOLL_SG_SIZES {
Expand Down Expand Up @@ -208,8 +208,9 @@ struct mca_coll_acoll_module_t {
int mnode_log2_sg_size;
int allg_lin;
int allg_ring;
coll_acoll_subcomms_t subc[MCA_COLL_ACOLL_MAX_CID];
coll_acoll_subcomms_t *subc[MCA_COLL_ACOLL_MAX_SUBC];
coll_acoll_reserve_mem_t reserve_mem_s;
int num_subc;
};

#ifdef HAVE_XPMEM_H
Expand Down
12 changes: 6 additions & 6 deletions ompi/mca/coll/acoll/coll_acoll_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -481,21 +481,21 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty
int brank, last_brank;
int use_rd_base;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
coll_acoll_subcomms_t *subc = NULL;
char *local_rbuf;
ompi_communicator_t *intra_comm;

/* Fallback to ring if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);
/* Fallback to ring if subc is not obtained */
if (subc == NULL) {
return ompi_coll_base_allgather_intra_ring(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
module);
}

subc = &acoll_module->subc[cid];
size = ompi_comm_size(comm);
if (!subc->initialized && size > 2) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
60 changes: 25 additions & 35 deletions ompi/mca/coll/acoll/coll_acoll_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module, int intra);
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc, int intra);


static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
Expand All @@ -52,16 +53,13 @@ static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc)
{
int size;
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);
coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand Down Expand Up @@ -188,16 +186,13 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc)
{
int size;
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);
coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand Down Expand Up @@ -361,15 +356,13 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module, int intra)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc, int intra)
{
size_t dsize;
int err = MPI_SUCCESS;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);

coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand All @@ -385,7 +378,6 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c

int l1_local_rank = data->l1_local_rank;
int l2_local_rank = data->l2_local_rank;
int comm_id = ompi_comm_get_local_cid(comm);

int offset1 = data->offset[0];
int offset2 = data->offset[1];
Expand Down Expand Up @@ -441,8 +433,8 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
}
}

if (intra && (ompi_comm_size(acoll_module->subc[comm_id].numa_comm) > 1)) {
err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, acoll_module->subc[comm_id].numa_comm, module);
if (intra && (ompi_comm_size(subc->numa_comm) > 1)) {
err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, subc->numa_comm, module);
}
return err;
}
Expand All @@ -466,25 +458,23 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
return MPI_SUCCESS;
}

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];

/* Falling back to recursivedoubling for non-commutative operators to be safe */
if (!ompi_op_is_commute(op)) {
return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm,
module);
}

/* Fallback to knomial if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
coll_acoll_subcomms_t *subc = NULL;
err = check_and_create_subc(comm, acoll_module, &subc);

/* Fallback to knomial if subc is not obtained */
if (subc == NULL) {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm,
module);
}

subc = &acoll_module->subc[cid];
if (!subc->initialized) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err)
return err;
}
Expand All @@ -499,7 +489,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
comm, module);
} else if (total_dsize < 512) {
return mca_coll_acoll_allreduce_small_msgs_h(sbuf, rbuf, count, dtype, op, comm, module,
1);
subc, 1);
} else if (total_dsize <= 2048) {
return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op,
comm, module);
Expand All @@ -517,7 +507,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else if (total_dsize < 4194304) {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module);
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module, subc);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
op, comm, module);
Expand All @@ -529,7 +519,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else if (total_dsize <= 16777216) {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module);
mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module, subc);
return mca_coll_acoll_bcast(rbuf, count, dtype, 0, comm, module);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
Expand All @@ -542,7 +532,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module);
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module, subc);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
op, comm, module);
Expand Down
13 changes: 7 additions & 6 deletions ompi/mca/coll/acoll/coll_acoll_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,22 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base
ompi_request_t **reqs;
int num_nodes;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
coll_acoll_subcomms_t *subc = NULL;

/* Fallback to linear if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);

/* Fallback to linear if subcomms structure is not obtained */
if (subc == NULL) {
return ompi_coll_base_barrier_intra_basic_linear(comm, module);
}

subc = &acoll_module->subc[cid];
size = ompi_comm_size(comm);
if (size == 1) {
return err;
}
if (!subc->initialized && size > 1) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
15 changes: 8 additions & 7 deletions ompi/mca/coll/acoll/coll_acoll_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,24 +444,25 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
bcast_subc_func bcast_func[2] = {&bcast_binomial, &bcast_flat_tree};
coll_acoll_subcomms_t *subc;
coll_acoll_subcomms_t *subc = NULL;
struct ompi_communicator_t *subcomms[MCA_COLL_ACOLL_NUM_SC] = {NULL};
int subc_roots[MCA_COLL_ACOLL_NUM_SC] = {-1};
int cid = ompi_comm_get_local_cid(comm);

/* Fallback to knomial if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);
/* Fallback to knomial if subcomms is not obtained */
if (subc == NULL) {
return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
}

subc = &acoll_module->subc[cid];
/* Fallback to knomial if no. of root changes is beyond a threshold */
if (subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH) {
if ((subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH)
&& (root != subc->prev_init_root)) {
return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
}
size = ompi_comm_size(comm);
if ((!subc->initialized || (root != subc->prev_init_root)) && size > 2) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, root);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, root);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
50 changes: 8 additions & 42 deletions ompi/mca/coll/acoll/coll_acoll_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -186,47 +186,9 @@ static int acoll_register(void)
*/
static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module)
{
for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) {
coll_acoll_subcomms_t *subc = &module->subc[i];
subc->initialized = 0;
subc->is_root_node = 0;
subc->is_root_sg = 0;
subc->is_root_numa = 0;
subc->outer_grp_root = -1;
subc->subgrp_root = 0;
subc->num_nodes = 1;
subc->prev_init_root = -1;
subc->num_root_change = 0;
subc->numa_root = 0;
subc->socket_ldr_root = -1;
subc->local_comm = NULL;
subc->local_r_comm = NULL;
subc->leader_comm = NULL;
subc->subgrp_comm = NULL;
subc->socket_comm = NULL;
subc->socket_ldr_comm = NULL;
for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) {
for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) {
subc->base_comm[k][j] = NULL;
subc->base_root[k][j] = -1;
}
subc->local_root[j] = 0;
}

subc->numa_comm = NULL;
subc->numa_comm_ldrs = NULL;
subc->node_comm = NULL;
subc->inter_comm = NULL;
subc->cid = -1;
subc->initialized_data = false;
subc->initialized_shm_data = false;
subc->data = NULL;
#ifdef HAVE_XPMEM_H
subc->xpmem_buf_size = mca_coll_acoll_xpmem_buffer_size;
subc->without_xpmem = mca_coll_acoll_without_xpmem;
subc->xpmem_use_sr_buf = mca_coll_acoll_xpmem_use_sr_buf;
#endif
}
/* Set number of subcomms to 0 */
module->num_subc = 0;

/* Reserve memory init. Lazy allocation of memory when needed. */
(module->reserve_mem_s).reserve_mem = NULL;
Expand All @@ -247,8 +209,8 @@ static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module)
static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
{

for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) {
coll_acoll_subcomms_t *subc = &module->subc[i];
for (int i = 0; i < module->num_subc; i++) {
coll_acoll_subcomms_t *subc = module->subc[i];
if (subc->initialized_data) {
if (subc->initialized_shm_data) {
if (subc->orig_comm != NULL) {
Expand Down Expand Up @@ -334,8 +296,12 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
}
}
subc->initialized = 0;
free(subc);
module->subc[i] = NULL;
}

module->num_subc = 0;

if ((true == (module->reserve_mem_s).reserve_mem_allocate)
&& (NULL != (module->reserve_mem_s).reserve_mem)) {
free((module->reserve_mem_s).reserve_mem);
Expand Down
Loading

0 comments on commit 3392b94

Please sign in to comment.