From 46a1eb77abe56cae21d1e0ac84bca175c5ec7b11 Mon Sep 17 00:00:00 2001 From: Your Date: Wed, 22 May 2024 21:33:07 +0000 Subject: [PATCH] update --- .../contrib_ops/cuda/collective/custom_reduce_impl.cu | 2 +- onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h | 4 ++-- onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc | 5 ++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu index 2b857d31f43c9..a8d2eb7903fca 100644 --- a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu +++ b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu @@ -522,7 +522,7 @@ void AllReduceDispatchType(AllReduceParams& param, AllReduceStrategyType strateg } } -AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tp_size, size_t tp_rank, uint32_t flag) { +AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tp_size, size_t tp_rank, uint32_t flag) { void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer); AllReduceParams params; diff --git a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h index 7d41c57a8a37e..1d74bc2946f8c 100644 --- a/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h +++ b/onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h @@ -55,9 +55,9 @@ struct AllReduceParams { uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; void* local_output_buffer_ptr; - void const* local_input_buffer_ptr; + const void* local_input_buffer_ptr; - static AllReduceParams deserialize(int32_t const* buffer, size_t tp_size, size_t tp_rank, uint32_t flag); + static AllReduceParams deserialize(const int32_t* buffer, size_t tp_size, size_t tp_rank, uint32_t flag); }; bool ConfigurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t 
world_size, diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc index e69d987cbaa78..5a63e3e61ac55 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -452,14 +452,17 @@ Status FuncCustomAllReduce( ort_trtllm::AllReduceStrategyConfig m_config = ort_trtllm::AllReduceStrategyConfig::USE_MEMCPY; + static std::mutex s_mutex; + std::unique_lock lock(s_mutex); ORT_RETURN_IF_ERROR(ort_trtllm::GetCustomAllReduceWorkspace(rank, world_size, input_count * data_type->Size(), ipc_mem_res_pack)); ort_trtllm::AllReduceParams params = ort_trtllm::AllReduceParams::deserialize( - reinterpret_cast<int32_t const*>(ipc_mem_res_pack.m_comm_ptrs.data()), + reinterpret_cast<const int32_t*>(ipc_mem_res_pack.m_comm_ptrs.data()), world_size, rank, ++ipc_mem_res_pack.counter); + lock.unlock(); CUDA_RETURN_IF_ERROR(cudaGetLastError());