Skip to content

Commit

Permalink
use node name in transpose optimizer when adding nodes rather than op…
Browse files Browse the repository at this point in the history
…type (#22084)

patch from @john-dance

"The main change is simple: Use the original node name rather than the
original node op_type when creating new nodes. Here are my comments on
the change:
------
The onnx runtime uses the op_type as the basis for a new node name, so a
node claimed by QNN EP might be named
Conv_token_1 with no relation to the original /conv1/Conv. This patch:
1. Adds Name as a virtual function in NodeRef and implements it in
ApiNode.
2. AddNode now takes a node name and an op_type and passes them both to
CreateNodeHelper.
3. CreateNodeHelper uses the op_name rather than the op_type in
GenerateNodeName.
4. Direct calls to AddNode are modified to either use the NodeRef's name
if available, or just repeat the op_type if not available.
The result is that the new nodes are named something like
/conv1/Conv_token_1, allowing a straightforward mapping back to the
original model node (if they exist in the original graph)."
  • Loading branch information
jywu-msft authored Sep 16, 2024
1 parent 6d7235b commit 1a1669f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ static std::unique_ptr<api::NodeRef> MakeNode1Attr(api::GraphRef& graph, std::st
std::string_view input, std::string_view attr_name,
const std::vector<int64_t>& attr_val) {
std::vector<std::string_view> inputs{input};
std::unique_ptr<api::NodeRef> node = graph.AddNode(op_type, inputs, /*num_outputs*/ 1);
std::unique_ptr<api::NodeRef> node = graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1);
node->SetAttributeInts(attr_name, attr_val);
return node;
}
Expand All @@ -102,7 +102,7 @@ static std::unique_ptr<api::NodeRef> MakeSqueezeOrUnsqueeze(int64_t opset, api::

std::vector<std::string_view> inputs{input, axes_initializer};

return graph.AddNode(op_type, inputs, /*num_outputs*/ 1);
return graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1);
}

/// <summary>
Expand Down Expand Up @@ -136,7 +136,7 @@ static std::unique_ptr<api::NodeRef> MakeQuantizeOp(api::GraphRef& graph, std::s
std::optional<int64_t> block_size,
std::optional<int64_t> output_dtype,
std::optional<int64_t> saturate) {
std::unique_ptr<api::NodeRef> node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain);
std::unique_ptr<api::NodeRef> node = graph.AddNode("QuantizeLinear", "QuantizeLinear", inputs, /* num_outputs */ 1, domain);

SetAttrIfNotDefault(*node, "axis", axis, 1);

Expand Down Expand Up @@ -170,7 +170,7 @@ static std::unique_ptr<api::NodeRef> MakeDequantizeOp(api::GraphRef& graph, std:
std::vector<std::string_view> inputs,
std::optional<int64_t> axis,
std::optional<int64_t> block_size) {
std::unique_ptr<api::NodeRef> node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain);
std::unique_ptr<api::NodeRef> node = graph.AddNode("DequantizeLinear", "DequantizeLinear", inputs, /* num_outputs */ 1, domain);

SetAttrIfNotDefault(*node, "axis", axis, 1);

Expand Down Expand Up @@ -1724,7 +1724,7 @@ static bool HandleShape(HandlerArgs& args) {

// X -> Shape -> Y, Gather
std::vector<std::string_view> gather_inputs{"", perm_const};
auto gather_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1);
auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1);
api::NodeRef& gather = *gather_ptr;
gather.SetAttributeInt("axis", 0);

Expand Down Expand Up @@ -1767,7 +1767,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con
// inputs that would never be quantized.
std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm);
std::vector<std::string_view> gather_inputs{input_name, gather_indices_const};
auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1);
auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1);
api::NodeRef& gather = *gather_ptr;
std::string_view gather_output = gather.Outputs()[0];
graph.CopyValueInfo(input_name, gather_output);
Expand Down Expand Up @@ -2215,7 +2215,7 @@ static bool HandleTile(HandlerArgs& args) {
// Case 2: Repeats is computed. Insert Gather node.
std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv);
std::vector<std::string_view> gather_inputs{repeats_inp, perm_inv_const};
auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1);
auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1);
api::NodeRef& gather_node = *gather_node_ptr;
std::string_view gather_output = gather_node.Outputs()[0];
args.ctx.graph.CopyValueInfo(repeats_inp, gather_output);
Expand Down Expand Up @@ -2265,7 +2265,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) {
// Worst-case scenario: Both parent output and 2nd transpose/reshape output cannot be removed (both graph outputs)
// despite computing the same value. Use an Identity op instead.
std::vector<std::string_view> single_empty_input{""};
auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1);
auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1);
api::NodeRef& identity = *identity_ptr;
args.ctx.graph.MoveOutput(args.node, 0, identity, 0);
identity.SetInput(0, transpose_input);
Expand Down Expand Up @@ -2297,7 +2297,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector<int64_t>& n
// replace Reshape with Transpose to simplify the logic.
// use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node,
// and remove the Reshape node.
new_node = args.ctx.graph.AddNode("Transpose", {args.transpose.Inputs()[0]}, 1);
new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1);
args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0);
args.ctx.graph.RemoveNode(args.node);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class ValueInfoRef {
/// </summary>
class NodeRef {
public:
/// <returns>Node name</returns>
virtual std::string_view Name() const = 0;

/// <returns>Op computed by the node</returns>
virtual std::string_view OpType() const = 0;

Expand Down Expand Up @@ -361,14 +364,15 @@ class GraphRef {
/// generated. Outputs of created node have unspecified shapes/dtypes. They will be populated afterwards using
/// CopyValueInfo.
/// </summary>
/// <param name="name">The new node's name</param>
/// <param name="op_type">The new node's op type</param>
/// <param name="inputs">Inputs for the node. "" for missing optional inputs.</param>
/// <param name="num_outputs">
/// Number of outputs for the node. Names automatically generated. Optional outputs not supported.
/// </param>
/// <param name="domain">The new node's domain. Empty string signifies default onnx domain.</param>
/// <returns>The new node</returns>
virtual std::unique_ptr<NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
virtual std::unique_ptr<NodeRef> AddNode(std::string_view name, std::string_view op_type, const std::vector<std::string_view>& inputs,
size_t num_outputs, std::string_view domain = /*kOnnxDomain*/ "") = 0;

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class ApiNode final : public api::NodeRef {
return node_;
}

// Implements api::NodeRef::Name by forwarding to the wrapped node_'s Name().
// The result is exposed as a std::string_view.
// NOTE(review): presumably the view is only valid while the underlying node exists — confirm.
std::string_view Name() const override {
return node_.Name();
}

// Implements api::NodeRef::OpType by forwarding to the wrapped node_'s OpType().
// The result is exposed as a std::string_view.
// NOTE(review): presumably the view is only valid while the underlying node exists — confirm.
std::string_view OpType() const override {
return node_.OpType();
}
Expand Down Expand Up @@ -134,7 +138,7 @@ class ApiGraph final : public api::GraphRef {
std::unique_ptr<api::NodeRef> GetNodeProducingOutput(std::string_view name) const override;
void TransposeInitializer(std::string_view name, const std::vector<int64_t>& perm) override;
void ReshapeInitializer(std::string_view name, const std::vector<int64_t>& shape) override;
std::unique_ptr<api::NodeRef> AddNode(std::string_view op_type, const std::vector<std::string_view>& inputs,
std::unique_ptr<api::NodeRef> AddNode(std::string_view name, std::string_view op_type, const std::vector<std::string_view>& inputs,
size_t num_outputs = 1, std::string_view domain = "") override;

std::unique_ptr<api::NodeRef> CopyNode(const api::NodeRef& source_node, std::string_view op_type,
Expand Down Expand Up @@ -621,11 +625,12 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vector<int64
node_arg->SetShape(new_shape);
}

static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type,
static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_name, std::string_view op_type,
const std::vector<std::string_view>& inputs, size_t num_outputs,
std::string_view domain, int since_version, std::string_view node_ep) {
const std::string op_type_str(op_type);
std::string name = graph.GenerateNodeName(op_type_str);
const std::string op_name_str(op_name);
std::string name = graph.GenerateNodeName(op_name_str);
std::vector<NodeArg*> input_args;
std::vector<NodeArg*> output_args;

Expand Down Expand Up @@ -731,11 +736,11 @@ static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view do
return *since_version;
}

std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view name, std::string_view op_type,
const std::vector<std::string_view>& inputs, size_t num_outputs,
std::string_view domain) {
int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap());
Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs,
Node& node = CreateNodeHelper(graph_, name, op_type, inputs, num_outputs,
domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : "");

return std::make_unique<ApiNode>(node, graph_);
Expand All @@ -744,7 +749,7 @@ std::unique_ptr<api::NodeRef> ApiGraph::AddNode(std::string_view op_type,
std::unique_ptr<api::NodeRef> ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type,
std::string_view domain, std::optional<int> since_version) {
const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion();
Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(),
Node& node = CreateNodeHelper(graph_, source_node.Name(), op_type, source_node.Inputs(),
source_node.Outputs().size(), domain, new_node_since_version,
source_node.GetExecutionProviderType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {

// Error message should come from the Conv implementation with the statically registered kernel
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
"Non-zero status code returned while running Conv node. Name:'Conv' "
"Non-zero status code returned while running Conv node. Name:'_token_2' "
"Status Message: TODO: add NHWC implementation here.");
}

Expand Down Expand Up @@ -242,7 +242,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
std::vector<OrtValue> fetches;

ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
"Non-zero status code returned while running Conv node. Name:'Conv' "
"Non-zero status code returned while running Conv node. Name:'_token_2' "
"Status Message: TODO: add NHWC implementation here.");
};

Expand Down

0 comments on commit 1a1669f

Please sign in to comment.