|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "orttraining/core/optimizer/cast_sce_loss_fusion.h" |
| 5 | + |
| 6 | +#include "core/graph/graph_utils.h" |
| 7 | +#include "core/optimizer/initializer.h" |
| 8 | +#include "core/optimizer/utils.h" |
| 9 | + |
| 10 | +namespace onnxruntime { |
| 11 | + |
| 12 | +Status CastSceLossFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, |
| 13 | + const logging::Logger& logger) const { |
| 14 | + GraphViewer graph_viewer(graph); |
| 15 | + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); |
| 16 | + |
| 17 | + for (auto node_index : node_topology_list) { |
| 18 | + auto* node_ptr = graph.GetNode(node_index); |
| 19 | + if (!node_ptr) continue; // Node was removed. |
| 20 | + |
| 21 | + auto& node = *node_ptr; |
| 22 | + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); |
| 23 | + |
| 24 | + bool is_internal_sce = graph_utils::IsSupportedOptypeVersionAndDomain(node, "SoftmaxCrossEntropyLossInternal", {1}, |
| 25 | + kMSDomain); |
| 26 | + |
| 27 | + if (!is_internal_sce || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { |
| 28 | + continue; |
| 29 | + } |
| 30 | + |
| 31 | + Node* input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[0]->Name()); |
| 32 | + |
| 33 | + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(*input_node, "Cast", {9, 13, 19}))) { |
| 34 | + continue; |
| 35 | + } |
| 36 | + |
| 37 | + if (input_node->GetOutputEdgesCount() != 1 || graph.IsOutput(input_node->OutputDefs()[0])) { |
| 38 | + continue; |
| 39 | + } |
| 40 | + |
| 41 | + if (input_node->MutableInputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT16 && |
| 42 | + input_node->MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type() == onnx::TensorProto_DataType_FLOAT) { |
| 43 | + std::vector<graph_utils::GraphEdge> input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node, 0); |
| 44 | + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges); |
| 45 | + node.MutableInputDefs()[0] = input_node->MutableInputDefs()[0]; |
| 46 | + graph_utils::MoveAllNodeInputEdges(graph, *input_node, node); |
| 47 | + graph.RemoveNode(input_node->Index()); |
| 48 | + |
| 49 | + if (node.GetAttributes().count("output_type") == 0) { |
| 50 | + node.AddAttribute("output_type", static_cast<int64_t>(onnx::TensorProto_DataType_FLOAT)); |
| 51 | + } |
| 52 | + modified = true; |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + return Status::OK(); |
| 57 | +} |
| 58 | + |
| 59 | +} // namespace onnxruntime |
0 commit comments