Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 committed Sep 18, 2024
1 parent 3d7777a commit cb19d7b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"LessOrEqual", "lesserOrEqual"},
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"Lstm", "lstm"},
{"LSTM", "lstm"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
Expand Down
66 changes: 59 additions & 7 deletions onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

void LstmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
Expand All @@ -49,12 +51,14 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name());

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
options.set("layout", emscripten::val("iofg"));

if (input_defs.size() > 3 && input_defs[3]->Exists()) {
emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val split_options = emscripten::val::object();
split_options.set("axis", 1);
split_options.set("label", node.Name() + "_split");
// Split it to bias and recurrentBias.
emscripten::val splitted_biases =
model_builder.GetBuilder().call<emscripten::val>("split", bias, /*splits*/ 2, split_options);
Expand Down Expand Up @@ -84,22 +88,19 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists();
bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists();
bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists();
if (has_Y) {
options.set("returnSequence", true);
}
options.set("returnSequence", has_Y);

if (helper.HasAttr("activations")) {
const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh", "Tanh"});

emscripten::val opt_activations = emscripten::val::array();
for (size_t i = 0; i < 3; ++i) {
const std::string& activation = activations[i];
if (activation == "Relu") {
opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("relu"));
opt_activations.call<void>("push", emscripten::val("relu"));
} else if (activation == "Sigmoid") {
opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("sigmoid"));
opt_activations.call<void>("push", emscripten::val("sigmoid"));
} else if (activation == "Tanh") {
opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("tanh"));
opt_activations.call<void>("push", emscripten::val("tanh"));
}
}

Expand All @@ -125,6 +126,10 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 3) {
LOGS(logger, ERROR) << "LSTM: input size must be greater than or equal to 3";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
Expand Down Expand Up @@ -191,6 +196,53 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type = 0; // input data type
int32_t input1_type = 0; // weight data type
int32_t input2_type = 0; // recurrentWeight data type
int32_t input3_type = 0; // bias data type
// input4 sequence_lens is skipped.
int32_t input5_type = 0; // initialHiddenState data type
int32_t input6_type = 0; // initialCellState data type
int32_t input7_type = 0; // peepholeWeight data type
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists();
bool has_input6 = input_defs.size() > 6 && input_defs[6]->Exists();
bool has_input7 = input_defs.size() > 7 && input_defs[7]->Exists();

if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
!GetType(*input_defs[2], input2_type, logger) ||
(has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
(has_input5 && !GetType(*input_defs[5], input5_type, logger)) ||
(has_input6 && !GetType(*input_defs[6], input6_type, logger)) ||
(has_input7 && !GetType(*input_defs[7], input7_type, logger))) {
return false;
}

InlinedVector<int32_t, 7> input_types = {input0_type, input1_type, input2_type};
if (has_input3) {
input_types.push_back(input3_type);
}
if (has_input5) {
input_types.push_back(input5_type);
}
if (has_input6) {
input_types.push_back(input6_type);
}
if (has_input7) {
input_types.push_back(input7_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}

void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<LstmOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down

0 comments on commit cb19d7b

Please sign in to comment.