Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
wangyems committed May 22, 2024
1 parent 9ba3637 commit 46a1eb7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -522,7 +522,7 @@ void AllReduceDispatchType(AllReduceParams& param, AllReduceStrategyType strateg
}
}

AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tp_size, size_t tp_rank, uint32_t flag) {
AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tp_size, size_t tp_rank, uint32_t flag) {
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
AllReduceParams params;

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,9 +55,9 @@ struct AllReduceParams {
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;
void const* local_input_buffer_ptr;
const void* local_input_buffer_ptr;

static AllReduceParams deserialize(int32_t const* buffer, size_t tp_size, size_t tp_rank, uint32_t flag);
static AllReduceParams deserialize(const int32_t* buffer, size_t tp_size, size_t tp_rank, uint32_t flag);
};

bool ConfigurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t world_size,
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -452,14 +452,17 @@ Status FuncCustomAllReduce(

ort_trtllm::AllReduceStrategyConfig m_config = ort_trtllm::AllReduceStrategyConfig::USE_MEMCPY;

static std::mutex s_mutex;
std::unique_lock<std::mutex> lock(s_mutex);
ORT_RETURN_IF_ERROR(ort_trtllm::GetCustomAllReduceWorkspace(rank, world_size, input_count * data_type->Size(),
ipc_mem_res_pack));

ort_trtllm::AllReduceParams params = ort_trtllm::AllReduceParams::deserialize(
reinterpret_cast<int32_t const*>(ipc_mem_res_pack.m_comm_ptrs.data()),
reinterpret_cast<const int32_t*>(ipc_mem_res_pack.m_comm_ptrs.data()),
world_size,
rank,
++ipc_mem_res_pack.counter);
lock.unlock();

CUDA_RETURN_IF_ERROR(cudaGetLastError());

Expand Down

0 comments on commit 46a1eb7

Please sign in to comment.