[WebNN EP] Support QuantizeLinear and DequantizeLinear ops (#22097)
Honry authored Sep 17, 2024
1 parent afd642a commit 9786909
Showing 10 changed files with 199 additions and 103 deletions.
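
For context, these two ops follow the standard ONNX quantization semantics: QuantizeLinear computes q = saturate(round(x / scale) + zero_point) (rounding half to even) and DequantizeLinear computes x = (q - zero_point) * scale. Below is a minimal, self-contained per-tensor int8 sketch of those formulas, for illustration only; it is not code from this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// QuantizeLinear: q = saturate(round(x / scale) + zero_point)
int8_t QuantizeLinear(float x, float scale, int8_t zero_point) {
  // std::nearbyint uses the default rounding mode, round-to-nearest-even,
  // matching the rounding the ONNX spec requires.
  float q = std::nearbyint(x / scale) + zero_point;
  return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));  // saturate to int8 range
}

// DequantizeLinear: x = (q - zero_point) * scale
float DequantizeLinear(int8_t q, float scale, int8_t zero_point) {
  return (static_cast<int32_t>(q) - zero_point) * scale;
}

int main() {
  float scale = 0.5f;
  int8_t zero_point = 10;
  int8_t q = QuantizeLinear(3.2f, scale, zero_point);  // round(6.4) + 10 = 16
  std::cout << static_cast<int>(q) << ' '
            << DequantizeLinear(q, scale, zero_point) << '\n';  // prints: 16 3
}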
2 changes: 2 additions & 0 deletions js/web/docs/webnn-operators.md
@@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
 | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d ||| Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
 | Cos | ai.onnx(7+) | cos ||| |
 | Div | ai.onnx(7-12, 13, 14+) | div ||| |
+| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear ||| |
 | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity ||| Only supports test mode |
 | Elu | ai.onnx(7+) | elu ||| WebNN CPU backend only supports 'alpha' value is 1.0 |
 | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal ||| |
@@ -62,6 +63,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
 | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad ||| modes == 'wrap' is not supported |
 | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow ||| |
 | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu ||| WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) |
+| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear ||| |
 | Reciprocal | ai.onnx(7-12, 13+) | reciprocal ||| |
 | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 ||| Input 'axes' if present should be a constant |
 | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 ||| Input 'axes' if present should be a constant |
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -206,6 +206,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Pad", "pad"},
{"Pow", "pow"},
{"PRelu", "prelu"},
{"QuantizeLinear", "quantizeLinear"},
{"Reciprocal", "reciprocal"},
{"ReduceL1", "reduceL1"},
{"ReduceL2", "reduceL2"},
onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
@@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
       if (input_defs.size() >= 3) {
         x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
       } else {
-        x_zero_point = model_builder.GetZeroConstant("uint8");
+        x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
       }
       if (input_defs.size() >= 4) {
         w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
       } else {
-        w_zero_point = model_builder.GetZeroConstant("uint8");
+        w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
       }
       output = model_builder.GetBuilder().call<emscripten::val>("conv2dInteger",
                                                                 input, x_zero_point, filter, w_zero_point, options);
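(When the optional x_zero_point/w_zero_point inputs of ConvInteger are omitted, the ONNX spec defines them to be zero, hence the scalar uint8 zero constants here and in the Gemm builder below; the change in this commit is only that GetZeroConstant now takes the ONNX data type enum instead of a string.)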

This file was deleted. (It held the previous standalone DequantizeLinear builder, superseded by the new shared qdq_op_builder.cc below; the matching CreateDequantizeLinearOpBuilder call is removed from op_builder_factory.cc at the end of this diff.)

onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
@@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
       if (input_defs.size() >= 3) {
         a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
       } else {
-        a_zero_point = model_builder.GetZeroConstant("uint8");
+        a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
       }
       if (input_defs.size() >= 4) {
         b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name());
       } else {
-        b_zero_point = model_builder.GetZeroConstant("uint8");
+        b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
       }
       output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger",
                                                                 a,
152 changes: 152 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc
@@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "core/providers/webnn/builders/impl/base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class QDQOpBuilder : public BaseOpBuilder {
  // Add operator related.
 private:
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

  // Operator support related.
  bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                              const logging::Logger& logger) const override;
};

Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                           const Node& node,
                                           const logging::Logger& logger) const {
  const auto& op_type = node.OpType();
  const auto& input_defs = node.InputDefs();
  const auto& output_defs = node.OutputDefs();

  std::vector<int64_t> input_shape;
  std::vector<int64_t> scale_shape;
  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
  ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");
  int32_t input_type = 0;
  int32_t output_type = 0;
  int32_t zero_point_type = 0;
  ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type");
  ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type");
  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());

  emscripten::val zero_point = emscripten::val::null();
  if (input_defs.size() == 3 && input_defs[2]->Exists()) {
    zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
  } else {
    // DequantizeLinear: x_zero_point's data type equals to input data type
    // QuantizeLinear: x_zero_point's data type equals to output data type
    zero_point_type = op_type == "DequantizeLinear" ? input_type : output_type;
    zero_point = model_builder.GetZeroConstant(zero_point_type);
  }

  emscripten::val output;
  NodeAttrHelper helper(node);
  int32_t axis = helper.Get("axis", 1);
  int32_t block_size = helper.Get("block_size", 0);
  // axis is valid for input shape greater than 1D.
  if (input_shape.size() > 1) {
    axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_shape.size()));
  }
  // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor.
  if (1 == scale_shape.size() && 1 < input_shape.size()) {
    std::vector<int32_t> target_shape{static_cast<int>(input_shape[axis])};
    target_shape.insert(target_shape.begin(), axis, 1);
    target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);
    emscripten::val reshape_scale_options = emscripten::val::object();
    reshape_scale_options.set("label", node.Name() + "_reshape_scale");
    scale = model_builder.GetBuilder().call<emscripten::val>("reshape",
                                                             scale,
                                                             emscripten::val::array(target_shape),
                                                             reshape_scale_options);
    emscripten::val reshape_zero_point_options = emscripten::val::object();
    reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point");
    zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
                                                                  zero_point,
                                                                  emscripten::val::array(target_shape),
                                                                  reshape_zero_point_options);
  }

  // If block_size is specified, we need to expand the scale and zero_point tensors.
  if (block_size > 1) {
    emscripten::val concat_scale_inputs = emscripten::val::array();
    emscripten::val concat_zero_point_inputs = emscripten::val::array();
    for (int i = 0; i < block_size; i++) {
      concat_scale_inputs.call<void>("push", scale);
      concat_zero_point_inputs.call<void>("push", zero_point);
    }

    emscripten::val concat_scale_options = emscripten::val::object();
    concat_scale_options.set("label", node.Name() + "_concat_scale");
    scale = model_builder.GetBuilder().call<emscripten::val>("concat", concat_scale_inputs, axis, concat_scale_options);

    emscripten::val concat_zero_point_options = emscripten::val::object();
    concat_zero_point_options.set("label", node.Name() + "_concat_zero_point");
    zero_point = model_builder.GetBuilder().call<emscripten::val>(
        "concat", concat_zero_point_inputs, axis, concat_zero_point_options);
  }

  emscripten::val options = emscripten::val::object();
  options.set("label", node.Name());
  std::string webnn_op_type;
  ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type");
  output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input, scale, zero_point, options);

  model_builder.AddOperand(output_defs[0]->Name(), std::move(output));

  return Status::OK();
}

bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                                          const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  const auto& op_type = node.OpType();
  int32_t input0_type = 0;  // input data type
  int32_t input1_type = 0;  // x_scale data type
  int32_t input2_type = 0;  // x_zero_point data type
  bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();

  if (!GetType(*input_defs[0], input0_type, logger) ||
      !GetType(*input_defs[1], input1_type, logger) ||
      (has_input2 && !GetType(*input_defs[2], input2_type, logger))) {
    return false;
  }

  return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) &&
         IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) &&
         (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger));
}

void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  if (op_registrations.op_builder_map.count(op_type) > 0)
    return;

  static std::vector<std::string> op_types =
      {
          "DequantizeLinear",
          "QuantizeLinear",
      };

  op_registrations.builders.push_back(std::make_unique<QDQOpBuilder>());
  for (const auto& type : op_types) {
    op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get());
  }
}

}  // namespace webnn
}  // namespace onnxruntime
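
To make the reshape step above concrete: for a per-axis quantization on an input of shape [N, C, H, W] with axis = 1 and a 1-D scale of shape [C], the two insert calls pad the scale shape to [1, C, 1, 1] so it broadcasts against the input. A minimal, self-contained sketch of that shape computation (BroadcastableScaleShape is a hypothetical helper written for illustration, not part of this commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Pad a 1-D scale of length input_shape[axis] with ones on both sides so it
// broadcasts against the full input shape, mirroring the logic in QDQOpBuilder.
std::vector<int32_t> BroadcastableScaleShape(const std::vector<int64_t>& input_shape, int32_t axis) {
  std::vector<int32_t> target_shape{static_cast<int32_t>(input_shape[axis])};
  target_shape.insert(target_shape.begin(), axis, 1);                         // ones before axis
  target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);  // ones after axis
  return target_shape;
}

int main() {
  // Per-channel dequantization of an NCHW tensor, axis = 1.
  for (int32_t d : BroadcastableScaleShape({2, 16, 8, 8}, 1)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 16 1 1
}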
49 changes: 35 additions & 14 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -354,27 +354,48 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) {
 // https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type
 // BTW, the spec is discussing if the builer.constant(value, type) should be dropped at
 // https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision.
-const emscripten::val& ModelBuilder::GetZeroConstant(const std::string& data_type) {
-  std::string name = "webnn_zero_constant_" + data_type;
+const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) {
+  std::string name = "webnn_zero_constant_" + std::to_string(data_type);
   // If the operand does not exist, create it.
   if (wnn_operands_.find(name) == wnn_operands_.end()) {
     emscripten::val desc = emscripten::val::object();
     emscripten::val dims = emscripten::val::array();
     desc.set("dimensions", dims);
     emscripten::val zero_buffer = emscripten::val::undefined();
-    if (data_type == "uint8") {
-      if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8)) {
-        ORT_THROW("Unsupported data type: " + data_type);
-      }
-      zero_buffer = emscripten::val::global("Uint8Array").new_(1);
-    } else if (data_type == "float32") {
-      if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
-        ORT_THROW("Unsupported data type: " + data_type);
-      }
-      zero_buffer = emscripten::val::global("Float32Array").new_(1);
-    } else {
-      ORT_THROW("Unsupported data type: " + data_type);
+    if (!SetWebnnDataType(desc, data_type)) {
+      ORT_THROW("Unsupported data type: " + std::to_string(data_type));
     }
+
+    switch (data_type) {
+      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
+        zero_buffer = emscripten::val::global("Uint8Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+        zero_buffer = emscripten::val::global("Int8Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+        zero_buffer = emscripten::val::global("Uint16Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+        zero_buffer = emscripten::val::global("Float32Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+        zero_buffer = emscripten::val::global("Int32Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+        zero_buffer = emscripten::val::global("BigInt64Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
+        zero_buffer = emscripten::val::global("Uint32Array").new_(1);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
+        zero_buffer = emscripten::val::global("BigUint64Array").new_(1);
+        break;
+      default:
+        break;
+    }
+
     emscripten::val zero_constant = wnn_builder_.call<emscripten::val>("constant", desc, zero_buffer);
     wnn_operands_.insert(std::make_pair(name, zero_constant));
   }
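GetZeroConstant caches each scalar zero in wnn_operands_ under a name derived from the ONNX data type enum, so every builder that needs, for example, an int8 zero shares a single WebNN constant. Note that the FLOAT16 case is backed by a Uint16Array because JavaScript has no Float16Array; an all-zero 16-bit pattern is also the float16 zero. An illustrative call site (hypothetical, not from this diff):

emscripten::val zp = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_INT8);  // created once, then reused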
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -38,7 +38,7 @@ class ModelBuilder {
   const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }
 
   void AddOperand(const std::string& name, const emscripten::val& operand);
-  const emscripten::val& GetZeroConstant(const std::string& data_type);
+  const emscripten::val& GetZeroConstant(const int32_t& data_type);
   // Use the buffers to persist WebNN allocated data like transposed weight.
   // It ensures the validity during inference session.
   std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
@@ -84,9 +84,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
     CreateDropoutOpBuilder("Dropout", op_registrations);
   }
 
-  {  // Quantize/Dequantize
+  {  // DequantizeLinear/QuantizeLinear/DynamicQuantizeLinear
+    CreateQDQOpBuilder("DequantizeLinear", op_registrations);
+    CreateQDQOpBuilder("QuantizeLinear", op_registrations);
     CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations);
-    CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations);
   }
 
   {  // Expand