Commit 9ba3637

update

1 parent 435f1d3 commit 9ba3637
8 files changed, +129 -117 lines changed


onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.cu (+46 -1)

@@ -16,6 +16,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/shared_library/provider_api.h"
 #include "custom_reduce_impl.h"
 #include <algorithm>
@@ -27,6 +28,9 @@ namespace ort_trtllm {
 
 #if defined(USE_MPI) || defined(USE_NCCL)
 
+using namespace onnxruntime;
+using namespace onnxruntime::cuda;
+
 // Calculates ceil(a / b). User must be careful to ensure that there
 // is no overflow or underflow in the calculation.
 template <typename T>
@@ -559,13 +563,54 @@ size_t GetMaxRequiredWorkspaceSize(int world_size) {
   return 8 * 1000 * 1000;
 }
 
-AllReduceStrategyType SelectImplementation(size_t message_size, int world_size, onnxruntime::MLDataType type) {
+Status SetPeerAccess(int rank, int world_size, bool enable, int& can_access_peer) {
+  const int src_node = rank;
+
+  for (int dst_node = 0; dst_node < world_size; dst_node++) {
+    if (dst_node == src_node) {
+      continue;
+    }
+
+    CUDA_RETURN_IF_ERROR(cudaDeviceCanAccessPeer(&can_access_peer, src_node, dst_node));
+
+    if (!can_access_peer) {
+      return Status::OK();
+    }
+
+    if (enable) {
+      cudaDeviceEnablePeerAccess(dst_node, 0);
+    } else {
+      cudaDeviceDisablePeerAccess(dst_node);
+    }
+
+    auto const error = cudaGetLastError();
+    if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) {
+      CUDA_RETURN_IF_ERROR(error);
+    }
+  }
+
+  return Status::OK();
+}
+
+AllReduceStrategyType SelectImplementation(size_t message_size, int rank, int world_size,
+                                           onnxruntime::MLDataType type) {
   AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;
   if (type != onnxruntime::DataTypeImpl::GetType<float>() &&
       type != onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>()) {
     return strategy;
   }
 
+  if (world_size != 2 && world_size != 4 && world_size != 6 && world_size != 8) {
+    return strategy;
+  }
+
+  int can_access_peer = 0;
+  ORT_ENFORCE(SetPeerAccess(rank, world_size, true, can_access_peer) == Status::OK());
+  // If P2P is not enabled, we cannot use the custom allreduce, so default to NCCL.
+  if (!can_access_peer) {
+    return strategy;
+  }
+
   const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
   const size_t message_size_bytes = message_size * type->Size();

onnxruntime/contrib_ops/cuda/collective/custom_reduce_impl.h (+3 -1)

@@ -68,7 +68,9 @@ void CustomAllReduce(AllReduceParams& params, onnxruntime::MLDataType data_type,
 
 size_t GetMaxRequiredWorkspaceSize(int world_size);
 
-AllReduceStrategyType SelectImplementation(size_t message_size, int world_size, onnxruntime::MLDataType type);
+Status SetPeerAccess(int rank, int world_size, bool enable, int& can_access_peer);
+
+AllReduceStrategyType SelectImplementation(size_t message_size, int rank, int world_size, onnxruntime::MLDataType type);
 
 #endif
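Taken together, the .cu and .h changes fold the peer-access probe into strategy selection: SetPeerAccess now reports reachability through the can_access_peer out-parameter rather than assuming P2P works, and SelectImplementation consults it (along with the supported world sizes) before opting into the custom allreduce. A minimal sketch of the resulting call path, assuming call-site variables (input_count, rank, world_size, data_type) like those in FuncCustomAllReduce in nccl_kernels.cc below:

  // Sketch only, not part of this commit: selecting a strategy after the change.
  ort_trtllm::AllReduceStrategyType strategy =
      ort_trtllm::SelectImplementation(input_count, rank, world_size, data_type);
  if (strategy == ort_trtllm::AllReduceStrategyType::NCCL) {
    // Fallback cases: unsupported dtype, world size not in {2, 4, 6, 8},
    // or peer-to-peer access unavailable between some pair of devices.
  }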

onnxruntime/contrib_ops/cuda/collective/ipc_utils.cc (-28)

@@ -25,31 +25,6 @@ namespace ort_trtllm {
 
 using namespace onnxruntime;
 
-Status SetPeerAccess(int rank, int world_size, bool enable) {
-  const int src_node = rank;
-
-  for (int dst_node = 0; dst_node < world_size; dst_node++) {
-    if (dst_node == src_node) {
-      continue;
-    }
-
-    int can_access_peer;
-    CUDA_RETURN_IF_ERROR(cudaDeviceCanAccessPeer(&can_access_peer, src_node, dst_node));
-
-    if (enable) {
-      cudaDeviceEnablePeerAccess(dst_node, 0);
-    } else {
-      cudaDeviceDisablePeerAccess(dst_node);
-    }
-    auto const error = cudaGetLastError();
-    if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) {
-      CUDA_RETURN_IF_ERROR(error);
-    }
-  }
-
-  return Status::OK();
-}
-
 IpcMemory::IpcMemory(int rank, int world_size, std::size_t buffer_size)
     : rank_(rank), world_size_(world_size), m_comm_ptrs_(world_size), mbuffer_size_(buffer_size) {
   ORT_ENFORCE(AllocateIpcMemory() == Status::OK());
@@ -113,9 +88,6 @@ Status GetCustomAllReduceWorkspace(int rank, int world_size, size_t input_size,
     return Status::OK();
   }
 
-  ORT_ENFORCE(SetPeerAccess(rank, world_size, true) == Status::OK());
-  CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
   const std::size_t buffer_size = world_size * input_size;
 
   std::vector<std::shared_ptr<IpcMemory>>& m_ipc_memory_handles = ipc_mem_res_pack.m_ipc_momery_handles;

onnxruntime/contrib_ops/cuda/collective/ipc_utils.h (-2)

@@ -25,8 +25,6 @@ namespace ort_trtllm {
 
 #if defined(USE_MPI) || defined(USE_NCCL)
 
-Status SetPeerAccess(int rank, int world_size, bool enable = true);
-
 class IpcMemory {
  public:
  size_t static constexpr FLAGS_SIZE = (MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t);

onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc (+1 -1)

@@ -441,7 +441,7 @@ Status FuncCustomAllReduce(
   int world_size = nccl->Size();
 
   ort_trtllm::AllReduceStrategyType runtime_strategy =
-      ort_trtllm::SelectImplementation(input_count, world_size, data_type);
+      ort_trtllm::SelectImplementation(input_count, rank, world_size, data_type);
 
   if (runtime_strategy == ort_trtllm::AllReduceStrategyType::NCCL) {
     ncclDataType_t dtype = GetNcclDataType(data_type);

onnxruntime/core/providers/js/operators/conv.h (+46 -59)

@@ -17,78 +17,65 @@ class ConvBase : public JsKernel {
   ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info),
                                                                                   conv_attrs_(info),
                                                                                   w_is_const_(false) {
-    TensorShapeVector kernel_shape;
     const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size();
     std::vector<int32_t> local_pads(pads_vec_size, 0);
     for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) {
       local_pads[i] = gsl::narrow_cast<int32_t>(conv_attrs_.pads[i]);
     }
 
+    TensorShapeVector kernel_shape;
     if (conv_attrs_.kernel_shape_specified) {
       ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
     }
+    std::vector<int32_t> kernel_shapes(kernel_shape.size(), 0);
+    if (conv_attrs_.kernel_shape_specified) {
+      for (size_t i = 0; i < kernel_shape.size(); ++i) {
+        kernel_shapes[i] = gsl::narrow_cast<int32_t>(kernel_shape[i]);
+      }
+    }
+
+    std::vector<int32_t> strides(conv_attrs_.strides.size(), 0);
+    for (size_t i = 0; i < conv_attrs_.strides.size(); ++i) {
+      strides[i] = gsl::narrow_cast<int32_t>(conv_attrs_.strides[i]);
+    }
+
+    std::vector<int32_t> dilations(conv_attrs_.dilations.size(), 0);
+    for (size_t i = 0; i < conv_attrs_.dilations.size(); ++i) {
+      dilations[i] = gsl::narrow_cast<int32_t>(conv_attrs_.dilations[i]);
+    }
+
     conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
     std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
     int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
-    auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;
-    auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0;
+
     // currently only support Conv 1D/2D. TODO: support Conv3D and other
-    if (conv_attrs_.dilations.size() == 1 ||
-        (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) ||
-        conv_attrs_.strides.size() == 1) {
-      JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
-                                   "format" : $8 ? "NHWC" : "NCHW",
-                                   "auto_pad" : $1,
-                                   "dilations" : [$2],
-                                   "group" : $3,
-                                   "kernel_shape" : [$4],
-                                   "pads" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
-                                   "strides" : [$7],
-                                   "w_is_const" : () JS_ARROW(!!HEAP8[$9]),
-                                   "activation" : UTF8ToString($10),
-                                   "activation_params" : $11 ? Array.from(HEAPF32.subarray($11, $12)) : []
-                                 }),
-                                 static_cast<int32_t>(conv_attrs_.auto_pad),
-                                 static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
-                                 static_cast<int32_t>(conv_attrs_.group),
-                                 static_cast<int32_t>(kernel_shape_0),
-                                 JSEP_HEAP32_INDEX_START(local_pads),
-                                 JSEP_HEAP32_INDEX_END(local_pads),
-                                 static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
-                                 static_cast<int32_t>(channels_last),
-                                 JSEP_HEAP8_INDEX(&w_is_const_),
-                                 conv_attrs_.activation.c_str(),
-                                 JSEP_HEAP32_INDEX_START(activation_params),
-                                 JSEP_HEAP32_INDEX_END(activation_params));
-    } else {
-      JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
-                                   "format" : $11 ? "NHWC" : "NCHW",
-                                   "auto_pad" : $1,
-                                   "dilations" : [ $2, $3 ],
-                                   "group" : $4,
-                                   "kernel_shape" : [ $5, $6 ],
-                                   "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
-                                   "strides" : [ $9, $10 ],
-                                   "w_is_const" : () JS_ARROW(!!HEAP8[$12]),
-                                   "activation" : UTF8ToString($13),
-                                   "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
-                                 }),
-                                 static_cast<int32_t>(conv_attrs_.auto_pad),
-                                 static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
-                                 static_cast<int32_t>(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0),
-                                 static_cast<int32_t>(conv_attrs_.group),
-                                 static_cast<int32_t>(kernel_shape_0),
-                                 static_cast<int32_t>(kernel_shape_1),
-                                 JSEP_HEAP32_INDEX_START(local_pads),
-                                 JSEP_HEAP32_INDEX_END(local_pads),
-                                 static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
-                                 static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
-                                 static_cast<int32_t>(channels_last),
-                                 JSEP_HEAP8_INDEX(&w_is_const_),
-                                 conv_attrs_.activation.c_str(),
-                                 JSEP_HEAP32_INDEX_START(activation_params),
-                                 JSEP_HEAP32_INDEX_END(activation_params));
-    }
+    JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
+                                 "format" : $11 ? "NHWC" : "NCHW",
+                                 "auto_pad" : $1,
+                                 "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [],
+                                 "group" : $4,
+                                 "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
+                                 "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
+                                 "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [],
+                                 "w_is_const" : () JS_ARROW(!!HEAP8[$12]),
+                                 "activation" : UTF8ToString($13),
+                                 "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
+                               }),
+                               static_cast<int32_t>(conv_attrs_.auto_pad),
+                               JSEP_HEAP32_INDEX_START(dilations),
+                               JSEP_HEAP32_INDEX_END(dilations),
+                               static_cast<int32_t>(conv_attrs_.group),
+                               JSEP_HEAP32_INDEX_START(kernel_shapes),
+                               JSEP_HEAP32_INDEX_END(kernel_shapes),
+                               JSEP_HEAP32_INDEX_START(local_pads),
+                               JSEP_HEAP32_INDEX_END(local_pads),
+                               JSEP_HEAP32_INDEX_START(strides),
+                               JSEP_HEAP32_INDEX_END(strides),
+                               static_cast<int32_t>(channels_last),
+                               JSEP_HEAP8_INDEX(&w_is_const_),
+                               conv_attrs_.activation.c_str(),
+                               JSEP_HEAP32_INDEX_START(activation_params),
+                               JSEP_HEAP32_INDEX_END(activation_params));
   }
 
   Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
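The two JSEP_INIT_KERNEL_ATTRIBUTE branches for 1-D and 2-D convolutions could be collapsed because every rank-dependent attribute (dilations, kernel_shape, pads, strides) is now staged in an int32 vector and handed to the embedded JS as a [start, end) index pair into the WASM heap, which the JS side expands with Array.from(HEAP32.subarray(start, end)). A minimal sketch of that staging pattern outside the macro; attr_i64 and attr_i32 are illustrative names, not from this commit:

  // Copy an int64 attribute vector into int32 storage so the JS side can view
  // it through HEAP32; gsl::narrow_cast traps values that do not fit.
  std::vector<int64_t> attr_i64 = {3, 3};  // e.g. a 2-D kernel_shape
  std::vector<int32_t> attr_i32(attr_i64.size(), 0);
  for (size_t i = 0; i < attr_i64.size(); ++i) {
    attr_i32[i] = gsl::narrow_cast<int32_t>(attr_i64[i]);
  }
  // JSEP_HEAP32_INDEX_START/END(attr_i32) then supply the [start, end) indices
  // consumed by Array.from(HEAP32.subarray(start, end)), so a single code path
  // serves both ranks (and can extend to Conv3D later, per the TODO above).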

orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc (+31 -23)

@@ -359,6 +359,30 @@ void IterateSubgraphFromNode(Graph& graph,
 }
 }  // namespace
 
+void RemovePrintDensityFlag(Graph& graph,
+                            const std::vector<NodeIndex>& node_topology_list,
+                            bool& modified,
+                            const logging::Logger& logger) {
+  for (auto node_index : node_topology_list) {
+    Node* node = graph.GetNode(node_index);
+    if (node == nullptr) {
+      continue;
+    }
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*node, "PythonOp", {1}, kMSDomain) &&
+        static_cast<std::string>(node->GetAttributes().at("func_name").s()) == kFlagAndPrintDensityFuncName) {
+      if (graph_utils::CanRemoveNode(graph, *node, logger)) {
+        if (graph_utils::RemoveNode(graph, *node)) {
+          modified = true;
+        } else {
+          LOG_DEBUG_INFO(logger, "Failed to remove node " + node->Name() + "(" + node->OpType() + ")");
+        }
+      } else {
+        LOG_DEBUG_INFO(logger, "Can not remove node " + node->Name() + "(" + node->OpType() + ")");
+      }
+    }
+  }
+}
+
 Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
   LOG_DEBUG_INFO(logger, "Enter PaddingElimination");
 
@@ -392,10 +416,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
         node.InputDefs()[1]->Exists() &&
         node.InputDefs()[1]->Shape() &&
         node.InputDefs()[1]->Shape()->dim_size() >= 2) {
-      const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
-      if (outputNodeCount != 1) {
-        continue;
-      }
       Node* embedding_input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[1]->Name());
       if (embedding_input_node == nullptr ||
           !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_input_node, "PythonOp", {1}, kMSDomain) ||
@@ -404,21 +424,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
         LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
         continue;
       }
-      if (!print_density_) {
-        if (graph_utils::CanRemoveNode(graph, *embedding_input_node, logger)) {
-          if (graph_utils::RemoveNode(graph, *embedding_input_node)) {
-            modified = true;
-          } else {
-            LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_input_node->Name() +
-                                       "(" + embedding_input_node->OpType() + ")");
-            continue;
-          }
-        } else {
-          LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_input_node->Name() +
-                                     "(" + embedding_input_node->OpType() + ")");
-          continue;
-        }
-      }
       const ONNX_NAMESPACE::TensorProto* padding_initializer =
           graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
       if (padding_initializer != nullptr &&
@@ -430,19 +435,22 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
          continue;
        }
        embedding_node = &node;
-        input_ids_arg = embedding_node->MutableInputDefs()[1];
-        for (auto output_defs : embedding_node->MutableOutputDefs()) {
-          subgraph.insert(output_defs);
-        }
        break;
      }
    }
  }
 
+  if (!print_density_) {
+    RemovePrintDensityFlag(graph, node_topology_list, modified, logger);
+  }
   if (!embedding_node) {
     LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding any valid embedding node.");
     return Status::OK();
   }
+  input_ids_arg = embedding_node->MutableInputDefs()[1];
+  for (auto output_defs : embedding_node->MutableOutputDefs()) {
+    subgraph.insert(output_defs);
+  }
 
   if (!input_ids_arg->Shape()) {
     LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding shape of input_ids.");
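Net effect of this restructuring: the density-flag PythonOp removal no longer happens inside the embedding-node search loop (where it was gated on the node having exactly one output edge), but in a single pass over the whole topological order, and the input_ids_arg/subgraph bookkeeping is deferred until an embedding node has actually been found. A condensed sketch of the new ApplyImpl flow, using only names that appear in the diff, with surrounding code elided:

  // 1. Scan node_topology_list for a qualifying embedding node.
  // 2. Strip every kFlagAndPrintDensityFuncName PythonOp in one pass:
  if (!print_density_) {
    RemovePrintDensityFlag(graph, node_topology_list, modified, logger);
  }
  // 3. Bail out if no embedding node was found; otherwise record its input
  //    ids and seed the candidate subgraph from its outputs:
  input_ids_arg = embedding_node->MutableInputDefs()[1];
  for (auto output_defs : embedding_node->MutableOutputDefs()) {
    subgraph.insert(output_defs);
  }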

tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml (+2 -2)

@@ -122,8 +122,8 @@ stages:
             --volume $(Build.BinariesDirectory):/build \
             --volume /mnist:/mnist \
             onnxruntime_ortmodule_distributed_tests_image \
-            bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py onnxscript && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_distributed.py" \
-          displayName: 'Run onnxruntime_test_collective.py'
+            bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py onnxscript && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_distributed.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py" \
+          displayName: 'Run onnxruntime_test_collective.py, onnxruntime_test_distributed.py and test_sharded_moe.py'
   condition: succeededOrFailed()
   timeoutInMinutes: 30
129129
