Skip to content

Commit

Permalink
rename IsNodeArgSupported
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 committed Oct 24, 2024
1 parent d51f01d commit 5126ffb
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}

bool IsNodeArgSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
Expand Down Expand Up @@ -102,12 +102,12 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
std::vector<std::vector<size_t>> supported_node_groups;

for (const auto* input : graph_viewer.GetInputs()) {
if (!IsNodeArgSupported(*input, "graph", logger)) {
if (!IsTensorShapeSupported(*input, "graph", logger)) {
return supported_node_groups;
}
}
for (const auto* output : graph_viewer.GetOutputs()) {
if (!IsNodeArgSupported(*output, "graph", logger)) {
if (!IsTensorShapeSupported(*output, "graph", logger)) {
return supported_node_groups;
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsNodeArgSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsNodeArgSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
return false;
}
}
Expand All @@ -72,7 +72,7 @@ bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* output : node.OutputDefs()) {
if (!IsNodeArgSupported(*output, node_name, logger)) {
if (!IsTensorShapeSupported(*output, node_name, logger)) {
return false;
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
if (!shape.empty()) {
dims.reserve(shape.size());
for (const auto& dim : shape) {
// dim_param free dimensions should have already been excluded by IsInputSupported().
// dim_param free dimensions should have already been excluded by IsTensorShapeSupported().
assert(dim.has_dim_value());
dims.push_back(SafeInt<int32_t>(dim.dim_value()));
}
Expand Down

0 comments on commit 5126ffb

Please sign in to comment.