Skip to content

Commit 3e4db2c

Browse files
authored
Fuse Cast + SoftmaxCrossEntropyLossInternal (microsoft#20334)
### Description
Fuse a Cast node feeding SoftmaxCrossEntropyLossInternal into the SoftmaxCrossEntropyLossInternal node itself.
1 parent 923b0ef commit 3e4db2c

File tree

10 files changed

+132
-6
lines changed

10 files changed

+132
-6
lines changed

Diff for: onnxruntime/core/graph/graph_utils.cc

+13-4
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,7 @@ static bool RemoveNodeWithSingleNodeInSingleUsedOutput(Graph& graph, Node& node)
172172
return true;
173173
}
174174

175-
/** Move the input edges that src_node has to target_node.
176-
After the move is complete src_node will have no input edges.
177-
*/
178-
static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
175+
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node) {
179176
auto target_idx = target_node.Index();
180177
auto input_edges = GraphEdge::GetNodeInputEdges(src_node);
181178

@@ -387,6 +384,18 @@ std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node) {
387384
return input_edges;
388385
}
389386

387+
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
388+
std::vector<GraphEdge> GraphEdge::GetNodeInputEdges(const Node& node, size_t index) {
389+
std::vector<GraphEdge> input_edges;
390+
for (auto it = node.InputEdgesBegin(), end = node.InputEdgesEnd(); it != end; ++it) {
391+
if (static_cast<size_t>(it->GetDstArgIndex()) == index) {
392+
input_edges.push_back(GraphEdge::CreateGraphEdge(node, *it, true));
393+
}
394+
}
395+
396+
return input_edges;
397+
}
398+
390399
/** Returns a vector of the output GraphEdges of a node. */
391400
std::vector<GraphEdge> GraphEdge::GetNodeOutputEdges(const Node& node) {
392401
std::vector<GraphEdge> output_edges;

Diff for: onnxruntime/core/graph/graph_utils.h

+8
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ const std::string& GetNodeOutputName(const Node& node, int index);
5959
*/
6060
const Node::EdgeEnd* GetInputEdge(const Node& node, int arg_index);
6161

62+
/** Move the input edges that src_node has to target_node.
63+
After the move is complete src_node will have no input edges.
64+
*/
65+
void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_node);
66+
6267
/** Removes all output edges from the given Node of the Graph.
6368
This should probably be elevated to the Graph API eventually. */
6469
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
@@ -89,6 +94,9 @@ struct GraphEdge {
8994
/** Returns a vector of the input GraphEdges of a node. */
9095
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node);
9196

97+
/** Returns a vector of the input GraphEdges of a node for the provided input index. */
98+
static std::vector<GraphEdge> GetNodeInputEdges(const Node& node, size_t index);
99+
92100
/** Returns a vector of the output GraphEdges of a node. */
93101
static std::vector<GraphEdge> GetNodeOutputEdges(const Node& node);
94102

Diff for: onnxruntime/core/optimizer/gemm_transpose_fusion.cc

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
8787
new_gemm_node.AddAttribute("alpha", gemm_node.GetAttributes().at("alpha").f());
8888
new_gemm_node.AddAttribute("beta", gemm_node.GetAttributes().at("beta").f());
8989

90+
new_gemm_node.SetExecutionProviderType(gemm_node.GetExecutionProviderType());
91+
9092
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, new_gemm_node);
9193

9294
modified = RewriteRuleEffect::kRemovedCurrentNode;

Diff for: onnxruntime/core/optimizer/graph_transformer_utils.cc

+6
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
138138
break;
139139

140140
case TransformerLevel::Level2:
141+
rules.push_back(std::make_unique<GemmTransposeFusion>());
141142
// Level2 rewrite rules (GemmTransposeFusion registered above)
142143
break;
143144

@@ -253,6 +254,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
253254
} break;
254255

255256
case TransformerLevel::Level2: {
257+
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
258+
if (rule_transformer != nullptr) {
259+
transformers.emplace_back(std::move(rule_transformer));
260+
}
261+
256262
// we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be
257263
// applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2).
258264
transformers.emplace_back(std::make_unique<TransposeOptimizer>(std::move(cpu_allocator), kCpuExecutionProvider));

Diff for: onnxruntime/core/optimizer/propagate_cast_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ static bool IsFP16Allow(const Node* node, size_t level, const FP16AllowOps& fp16
171171

172172
using OpsSetType = InlinedHashSet<std::string_view>;
173173
static const OpsSetType level1_fp16_allow_set =
174-
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
174+
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu", "Slice", "PadAndUnflatten"};
175175
static const OpsSetType level2_fp16_allow_set = {
176176
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
177177

Diff for: onnxruntime/core/optimizer/utils.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
281281
// (plus ShrunkenGather for training) are considered deterministic.
282282
#ifdef ENABLE_TRAINING_OPS
283283
constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
284-
"ConcatTraining"};
284+
"ConcatTraining", "PadAndUnflatten"};
285285
#else
286286
constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
287287
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"

#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"

namespace onnxruntime {

// Fuses Cast(fp16 -> fp32) feeding SoftmaxCrossEntropyLossInternal into the SCE node
// itself: the SCE node consumes the fp16 scores directly and records output_type=float
// so its outputs keep the original fp32 type.
Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                    const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  for (auto node_index : node_topology_list) {
    auto* node_ptr = graph.GetNode(node_index);
    if (!node_ptr) continue;  // Node was removed by an earlier transform.

    auto& node = *node_ptr;
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    const bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(
        node, "SoftmaxCrossEntropyLossInternal", {1}, kMSDomain);

    if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
      continue;
    }

    // BUGFIX: the scores input may come from a graph input or an initializer, in which
    // case there is no producer node and GetMutableProducerNode returns nullptr;
    // dereferencing it unconditionally would crash.
    Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name());
    if (input_node == nullptr ||
        !graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19})) {
      continue;
    }

    // The Cast output must feed only this SCE node and must not itself be a graph output,
    // otherwise removing the Cast would change observable graph behavior.
    if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) {
      continue;
    }

    // Only fuse the fp16 -> fp32 pattern; guard against missing type info.
    const auto* cast_input_type = input_node->MutableInputDefs()[0]->TypeAsProto();
    const auto* cast_output_type = input_node->MutableOutputDefs()[0]->TypeAsProto();
    if (cast_input_type == nullptr || cast_output_type == nullptr) {
      continue;
    }

    if (cast_input_type->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 &&
        cast_output_type->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) {
      // Detach the Cast->SCE edge, rewire the SCE scores input to the Cast's fp16
      // source, move the Cast's incoming edges onto the SCE node, then drop the Cast.
      std::vector<graph_utils::GraphEdge> input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0);
      graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
      node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0];
      graph_utils::MoveAllNodeInputEdges(graph, *input_node, node);
      graph.RemoveNode(input_node->Index());

      // Preserve the pre-fusion fp32 output type now that the input is fp16.
      if (node.GetAttributes().count("output_type") == 0) {
        node.AddAttribute("output_type", static_cast<int64_t>(onnx::TensorProto_DataType_FLOAT));
      }
      modified = true;
    }
  }

  return Status::OK();
}

}  // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
@Class CastSceLossFusion

Graph transformer that fuses a Cast node feeding SoftmaxCrossEntropyLossInternal
into the SoftmaxCrossEntropyLossInternal node itself.
*/
class CastSceLossFusion : public GraphTransformer {
 public:
  // The transformer only fires on nodes assigned to one of the given execution providers
  // (all providers when the set is empty).
  explicit CastSceLossFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
      : GraphTransformer("CastSceLossFusion", compatible_execution_providers) {
  }

  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

}  // namespace onnxruntime

Diff for: orttraining/orttraining/core/optimizer/graph_transformer_utils.cc

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "orttraining/core/framework/distributed_run_context.h"
5353
#include "orttraining/core/optimizer/batchnorm_replacement.h"
5454
#include "orttraining/core/optimizer/bitmask_dropout_replacement.h"
55+
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
5556
#include "orttraining/core/optimizer/concat_replacement.h"
5657
#include "orttraining/core/optimizer/graph_transformer_registry.h"
5758
#include "orttraining/core/optimizer/gru_replacement.h"
@@ -188,6 +189,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
188189
config.propagate_cast_ops_config.allow,
189190
cuda_execution_provider));
190191
}
192+
transformers.emplace_back(std::make_unique<CastSceLossFusion>(compatible_eps));
191193

192194
if (config.enable_compute_optimizer) {
193195
transformers.emplace_back(std::make_unique<UpStreamGatherGraphTransformer>(compatible_eps));

Diff for: orttraining/orttraining/test/optimizer/graph_transform_test.cc

+17
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "test/util/include/asserts.h"
2626
#include "orttraining/test/optimizer/horizontal_parallel_test_utils.h"
2727
#include "orttraining/core/session/training_session.h"
28+
#include "orttraining/core/optimizer/cast_sce_loss_fusion.h"
2829
#include "orttraining/core/optimizer/loss_rewriter.h"
2930
#include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h"
3031
#include "orttraining/core/optimizer/qdq_fusion.h"
@@ -518,6 +519,22 @@ TEST_F(GraphTransformationTests, SceLossGradBiasFusion_Invalid) {
518519
}
519520
}
520521

522+
TEST_F(GraphTransformationTests, CastSceLossFusion) {
523+
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "computation_reduction/reshape/mlm_bert_e2e.onnx";
524+
std::shared_ptr<Model> model;
525+
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
526+
Graph& graph = model->MainGraph();
527+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
528+
ASSERT_EQ(op_to_count["Cast"], 10);
529+
530+
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
531+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CastSceLossFusion>(), TransformerLevel::Level2));
532+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
533+
534+
op_to_count = CountOpsInGraph(graph);
535+
ASSERT_EQ(op_to_count["Cast"], 9);
536+
}
537+
521538
Node* GetNodeByName(Graph& graph, std::string node_name) {
522539
GraphViewer graph_viewer(graph);
523540
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

0 commit comments

Comments
 (0)