From 5126ffb30914e66c6e93bd965070891db7fc63fa Mon Sep 17 00:00:00 2001 From: Shiyi Zou Date: Thu, 24 Oct 2024 14:47:27 +0800 Subject: [PATCH] rename IsNodeArgSupported --- onnxruntime/core/providers/webnn/builders/helper.cc | 6 +++--- onnxruntime/core/providers/webnn/builders/helper.h | 2 +- .../core/providers/webnn/builders/impl/base_op_builder.cc | 4 ++-- onnxruntime/core/providers/webnn/builders/model_builder.cc | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 7cf6d600f1234..dc488f0409418 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,7 +69,7 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsNodeArgSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { const auto& node_arg_name = node_arg.Name(); const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. @@ -102,12 +102,12 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v std::vector> supported_node_groups; for (const auto* input : graph_viewer.GetInputs()) { - if (!IsNodeArgSupported(*input, "graph", logger)) { + if (!IsTensorShapeSupported(*input, "graph", logger)) { return supported_node_groups; } } for (const auto* output : graph_viewer.GetOutputs()) { - if (!IsNodeArgSupported(*output, "graph", logger)) { + if (!IsTensorShapeSupported(*output, "graph", logger)) { return supported_node_groups; } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 48bcaa2dbe9ad..6d2e7533750be 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsNodeArgSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 0e7a8f53bae49..1e641017f36b6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsNodeArgSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger)) { return false; } } @@ -72,7 +72,7 @@ bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* output : node.OutputDefs()) { - if (!IsNodeArgSupported(*output, node_name, logger)) { + if (!IsTensorShapeSupported(*output, node_name, logger)) { return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 8a7fea0cde431..ccf6c7911638b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -214,7 +214,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { - // dim_param free dimensions should have already been excluded by IsInputSupported(). + // dim_param free dimensions should have already been excluded by IsTensorShapeSupported(). assert(dim.has_dim_value()); dims.push_back(SafeInt(dim.dim_value())); }