From b2d603abdaf580872d3249455241f53d64d5a2a9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 16 Aug 2024 13:59:51 +0800 Subject: [PATCH] [WebNN EP] Remove workaround for scalar (#21704) Currently Chromium has supported scalar with dims = {}, remove legacy workaround for supporting scalar. --- .../core/providers/webnn/builders/model.cc | 4 ---- .../core/providers/webnn/builders/model.h | 8 ------- .../providers/webnn/builders/model_builder.cc | 22 +++++-------------- .../providers/webnn/builders/model_builder.h | 5 +---- .../webnn/webnn_execution_provider.cc | 10 --------- 5 files changed, 6 insertions(+), 43 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index ef807a8c4fa26..8cd2e8d0ffad3 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -142,10 +142,6 @@ Status Model::Predict(const InlinedHashMap& inputs, return Status::OK(); } -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { return input_output_info_.at(name); } diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h index 4af82a2675691..5119dbbbc9858 100644 --- a/onnxruntime/core/providers/webnn/builders/model.h +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -34,8 +34,6 @@ class Model { onnxruntime::common::Status Predict(const InlinedHashMap& inputs, const InlinedHashMap& outputs); - bool IsScalarOutput(const std::string& output_name) const; - // Mutex for exclusive lock to this model object. OrtMutex& GetMutex() { return mutex_; } @@ -65,8 +63,6 @@ class Model { emscripten::val wnn_inputs_ = emscripten::val::object(); emscripten::val wnn_outputs_ = emscripten::val::object(); - InlinedHashSet scalar_outputs_; - std::vector inputs_; std::vector outputs_; @@ -83,10 +79,6 @@ class Model { input_output_info_ = std::move(input_output_info); } - void SetScalarOutputs(InlinedHashSet&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - void AllocateInputOutputBuffers(); }; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index b21f717eedc7a..44bec1fb6fd48 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -104,13 +104,15 @@ Status ModelBuilder::RegisterInitializers() { emscripten::val operand = emscripten::val::object(); if (IsSupportedDataType(data_type, webnn_supported_data_types)) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); - auto num_elements = SafeInt(Product(tensor.dims())); + auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); std::byte* tensor_ptr = nullptr; if (tensor.has_raw_data()) { tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); } else { - std::vector unpacked_tensor; + // Store temporary unpacked_tensor. + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); tensor_ptr = reinterpret_cast(unpacked_tensor.data()); } @@ -187,16 +189,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i ORT_RETURN_IF(shape_proto == nullptr, "shape_proto cannot be null for ", input_output_type, ": ", name); const auto& shape = shape_proto->dim(); - if (shape.empty()) { - // If we have an empty shape, this is a scalar input. - dims.push_back(1); - - // We need to change the shapes of these scalar outputs back to {} - // when WebNN EP returns these values to ORT. - if (!is_input) { - AddScalarOutput(name); - } - } else { + if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { // dim_param free dimensions should have already been excluded by IsInputSupported(). @@ -343,7 +336,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews // for inputs and outputs because they will be transferred after compute() done. @@ -352,10 +344,6 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -void ModelBuilder::AddScalarOutput(const std::string& output_name) { - scalar_outputs_.insert(output_name); -} - void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) { wnn_operands_.insert(std::make_pair(name, operand)); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index b1561f009aa25..2d686070cdcc1 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -69,8 +69,8 @@ class ModelBuilder { InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; + std::vector> unpacked_tensors_; - InlinedHashSet scalar_outputs_; InlinedHashMap input_output_info_; InlinedHashSet skipped_initializers_; @@ -92,9 +92,6 @@ class ModelBuilder { Status RegisterModelOutputs() ORT_MUST_USE_RESULT; Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT; - // Record the onnx scalar output names. - void AddScalarOutput(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); }; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 1cd382c1e75e9..b918daf838c99 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -272,10 +272,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector(input_tensor.GetTensorRawData()); inputs.emplace( input_name, @@ -297,12 +293,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vectorGetInputOutputInfo(output_name); auto output_shape = output_info.shape; auto output_type = output_info.data_type; - - // Since WebNN EP use {1} tensor as scalar, if the model output should have empty shape. - // We are going to replace the {1} shape of the output back to {}. - if (model->IsScalarOutput(output_name)) - output_shape.clear(); - auto output_tensor = ctx.GetOutput(i, output_shape.data(), output_shape.size());