[WebNN EP] Remove workaround for scalar (#21704)
Chromium now supports scalars with dims = {}, so remove the legacy
workaround for handling scalars.
Honry authored Aug 16, 2024
1 parent c97cc5c commit b2d603a
Showing 5 changed files with 6 additions and 43 deletions.
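For readers skimming the diff: the removed workaround padded rank-0 (scalar) tensors to shape {1} before handing them to WebNN and recorded the affected output names so their shapes could be restored to {} after inference. With Chromium now accepting dims = {}, both steps drop out. The standalone C++ sketch below (hypothetical helper names, not the EP code itself) contrasts the legacy mapping with the pass-through that remains:

// Standalone sketch (hypothetical helpers, not the actual EP code) of the
// shape mapping this commit removes: scalars used to be promoted to {1} for
// WebNN and their output names recorded so the shape could be restored to {}.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Legacy behavior: empty ONNX shape -> {1}; output names are tracked so the
// execution provider can change the shape back to {} when returning results.
std::vector<int64_t> ToWebnnDimsLegacy(const std::vector<int64_t>& onnx_shape,
                                       bool is_input, const std::string& name,
                                       std::unordered_set<std::string>& scalar_outputs) {
  if (onnx_shape.empty()) {
    if (!is_input) scalar_outputs.insert(name);
    return {1};
  }
  return onnx_shape;
}

// New behavior: pass the shape through unchanged; WebNN accepts dims = {}.
std::vector<int64_t> ToWebnnDims(const std::vector<int64_t>& onnx_shape) {
  return onnx_shape;
}

int main() {
  std::unordered_set<std::string> scalar_outputs;
  auto legacy = ToWebnnDimsLegacy({}, /*is_input=*/false, "out", scalar_outputs);
  std::cout << "legacy rank: " << legacy.size()          // 1
            << ", new rank: " << ToWebnnDims({}).size()  // 0
            << ", tracked scalar outputs: " << scalar_outputs.size() << "\n";  // 1
}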
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -142,10 +142,6 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
return Status::OK();
}

bool Model::IsScalarOutput(const std::string& output_name) const {
return Contains(scalar_outputs_, output_name);
}

const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const {
return input_output_info_.at(name);
}
8 changes: 0 additions & 8 deletions onnxruntime/core/providers/webnn/builders/model.h
@@ -34,8 +34,6 @@ class Model {
onnxruntime::common::Status Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
const InlinedHashMap<std::string, OnnxTensorData>& outputs);

bool IsScalarOutput(const std::string& output_name) const;

// Mutex for exclusive lock to this model object.
OrtMutex& GetMutex() { return mutex_; }

@@ -65,8 +63,6 @@
emscripten::val wnn_inputs_ = emscripten::val::object();
emscripten::val wnn_outputs_ = emscripten::val::object();

InlinedHashSet<std::string> scalar_outputs_;

std::vector<std::string> inputs_;
std::vector<std::string> outputs_;

@@ -83,10 +79,6 @@
input_output_info_ = std::move(input_output_info);
}

void SetScalarOutputs(InlinedHashSet<std::string>&& scalar_outputs) {
scalar_outputs_ = std::move(scalar_outputs);
}

void AllocateInputOutputBuffers();
};

22 changes: 5 additions & 17 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -104,13 +104,15 @@ Status ModelBuilder::RegisterInitializers() {
emscripten::val operand = emscripten::val::object();
if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
auto num_elements = SafeInt<size_t>(Product(shape));
emscripten::val view = emscripten::val::undefined();
std::byte* tensor_ptr = nullptr;
if (tensor.has_raw_data()) {
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
} else {
std::vector<uint8_t> unpacked_tensor;
// Store temporary unpacked_tensor.
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
@@ -187,16 +189,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
ORT_RETURN_IF(shape_proto == nullptr,
"shape_proto cannot be null for ", input_output_type, ": ", name);
const auto& shape = shape_proto->dim();
if (shape.empty()) {
// If we have an empty shape, this is a scalar input.
dims.push_back(1);

// We need to change the shapes of these scalar outputs back to {}
// when WebNN EP returns these values to ORT.
if (!is_input) {
AddScalarOutput(name);
}
} else {
if (!shape.empty()) {
dims.reserve(shape.size());
for (const auto& dim : shape) {
// dim_param free dimensions should have already been excluded by IsInputSupported().
@@ -343,7 +336,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
model->SetInputs(std::move(input_names_));
model->SetOutputs(std::move(output_names_));
model->SetScalarOutputs(std::move(scalar_outputs_));
model->SetInputOutputInfo(std::move(input_output_info_));
// Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
// for inputs and outputs because they will be transferred after compute() done.
@@ -352,10 +344,6 @@
return Status::OK();
}

void ModelBuilder::AddScalarOutput(const std::string& output_name) {
scalar_outputs_.insert(output_name);
}

void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) {
wnn_operands_.insert(std::make_pair(name, operand));
}
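A side note on the RegisterInitializers hunk above: the unpacked initializer bytes are now kept in the builder-owned unpacked_tensors_ container instead of a local variable, so the raw pointer handed to WebNN remains valid until graph building finishes. A minimal sketch of that lifetime pattern (simplified types, not the actual ModelBuilder):

// Minimal sketch of the buffer-lifetime pattern: storing each unpacked tensor
// in a member container keeps its bytes alive, so a pointer taken from it does
// not dangle once the unpacking scope ends (unlike a local std::vector).
#include <cstdint>
#include <iostream>
#include <vector>

struct BuilderSketch {
  std::vector<std::vector<uint8_t>> unpacked_tensors_;  // owns the bytes

  const uint8_t* Unpack(const std::vector<uint8_t>& packed) {
    unpacked_tensors_.push_back(packed);     // the copy outlives this call
    return unpacked_tensors_.back().data();  // pointer stays valid afterwards
  }
};

int main() {
  BuilderSketch builder;
  const uint8_t* bytes = builder.Unpack({1, 2, 3});
  std::cout << static_cast<int>(bytes[2]) << "\n";  // prints 3
}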
5 changes: 1 addition & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -69,8 +69,8 @@ class ModelBuilder {
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
std::vector<std::vector<uint8_t>> unpacked_tensors_;

InlinedHashSet<std::string> scalar_outputs_;
InlinedHashMap<std::string, OnnxTensorInfo> input_output_info_;

InlinedHashSet<std::string> skipped_initializers_;
@@ -92,9 +92,6 @@
Status RegisterModelOutputs() ORT_MUST_USE_RESULT;
Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT;

// Record the onnx scalar output names.
void AddScalarOutput(const std::string& output_name);

static const IOpBuilder* GetOpBuilder(const Node& node);
};

10 changes: 0 additions & 10 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -272,10 +272,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
auto input_tensor = ctx.GetInput(input_idx);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
auto shape = tensor_info.GetShape();
// If we have an empty shape, this is a scalar input,
// Since all the input output of WebNN EP is MultiArray, we will make the scalar input as a {1} MultiArray.
if (shape.empty())
shape.push_back(1);
const void* inputBuffer = const_cast<void*>(input_tensor.GetTensorRawData());
inputs.emplace(
input_name,
@@ -297,12 +293,6 @@
const auto& output_info = model->GetInputOutputInfo(output_name);
auto output_shape = output_info.shape;
auto output_type = output_info.data_type;

// Since WebNN EP use {1} tensor as scalar, if the model output should have empty shape.
// We are going to replace the {1} shape of the output back to {}.
if (model->IsScalarOutput(output_name))
output_shape.clear();

auto output_tensor =
ctx.GetOutput(i, output_shape.data(), output_shape.size());

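The execution-provider changes mirror this at inference time: the input path no longer promotes an empty shape to {1} before binding, and the output path no longer clears a {1} shape back to {} before requesting the output tensor. A trivial sketch of the shapes involved (plain vectors only, no ORT API calls):

// Trivial sketch of the inference-time shape handling being simplified:
// inputs are no longer padded from {} to {1}, so no restoration is needed
// for scalar outputs either.
#include <cstdint>
#include <iostream>
#include <vector>

// Legacy input path: an empty shape was padded to {1} before binding.
std::vector<int64_t> LegacyInputShape(std::vector<int64_t> shape) {
  if (shape.empty()) shape.push_back(1);
  return shape;
}

int main() {
  const std::vector<int64_t> scalar_shape;  // rank-0 tensor: shape is already {}
  std::cout << "legacy input rank: " << LegacyInputShape(scalar_shape).size()  // 1
            << ", new input rank: " << scalar_shape.size() << "\n";            // 0
}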
