From e91ff9438bd368300b7d6d95aabfffe8cb9547b6 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Wed, 11 Sep 2024 09:54:15 -0700 Subject: [PATCH 01/26] Enable Pad->Conv(no pads) fusion (#22001) ### Description ### Motivation and Context For some model has pattern Pad -> Conv. If the Conv doesn't have pads attributes, the Pad can be fused into Conv. --- onnxruntime/core/optimizer/pad_fusion.cc | 12 ++--- .../test/optimizer/graph_transform_test.cc | 48 ++++++++++++++++++ .../transform/fusion/fuse-pad-nopadsconv.onnx | Bin 0 -> 397 bytes 3 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index 3391e20cf0bb7..25afed52403c4 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -31,15 +31,15 @@ bool VerifyNotCastChild(const Node& child_node) { return false; } - // This pass currently assumed that this attribute already exists on the child node - if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { - return false; - } - return true; } void UpdatePaddingAttribute(Node& child_node, const std::vector& pads_values, const uint32_t pads_size) { + if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { + std::vector pads(pads_size - 4, 0); + child_node.AddAttribute("pads", pads); + } + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); uint32_t child_pads_size = static_cast(child_pads->size()); @@ -162,4 +162,4 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6ae66e35e7853..3aec0d5a67e94 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1469,6 +1469,54 @@ TEST_F(GraphTransformationTests, FusePadWithConv) { } } +TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-nopadsconv.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + gsl::span pads_values = pads.DataAsSpan(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["Conv"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Conv") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the Conv node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + TEST_F(GraphTransformationTests, FusePadWithMaxPool) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx new file mode 100644 index 0000000000000000000000000000000000000000..145847cdc47fadd7a2feaaefadf5eeef5c1be270 GIT binary patch literal 397 zcmduN`%;;npvEIflD0$fZSj8M!9q*Q9XgFj2WnmIZ1&F>OCw5u>)0b LI Date: Wed, 11 Sep 2024 19:41:04 +0200 Subject: [PATCH 02/26] Improve hash_function used by TreeEnsemble (#22043) ### Description unordered_map are implemented in a different way on VisualStudio and gcc. It seems that inserting consecutive keys has a poor performance on Windows. ### Motivation and Context Improve the performance of onnxruntime when initializing trees. --- onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index b9f3050e59c5b..34c6db61982b5 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -23,7 +23,9 @@ struct TreeNodeElementId { } struct hash_fn { std::size_t operator()(const TreeNodeElementId& key) const { - return static_cast(static_cast(key.tree_id) << 32 | static_cast(key.node_id)); + // unordered_map has poor performance on Windows when inserting consecutive keys. + // keys are usually inserted with key.node_id being incremented at each iteration. + return static_cast(static_cast(key.tree_id) | static_cast(key.node_id) << 32); } }; }; From 4d824045444756ba70223c32ae11693a252adde6 Mon Sep 17 00:00:00 2001 From: Bin Miao Date: Thu, 12 Sep 2024 05:16:36 +0800 Subject: [PATCH 03/26] [WebNN EP] Support GRU operator (#20405) This PR support Gru operator for WebNN EP. @Honry , @fdwr thanks! --- js/web/docs/webnn-operators.md | 1 + js/web/test/suite-test-list.jsonc | 6 +- .../core/providers/shared/utils/utils.cc | 9 + .../core/providers/shared/utils/utils.h | 1 + .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/gru_op_builder.cc | 250 ++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 8 files changed, 270 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 48b06b780dfc7..164096b4fda9a 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -41,6 +41,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | GlobalLpPool| ai.onnx(7+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 'p' value is 2 | | Greater | ai.onnx(7-8, 9-12, 13+) | greater | ✓ | ✓ | | | GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | ✓ | ✓ | | +| GRU | ai.onnx(7-13, 14-21, 22+) | gru | ✓ | ✓ | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | HardSigmoid | ai.onnx(7+) | hardSigmoid | ✓ | ✓ | | | HardSwish | ai.onnx(14+) | hardSwish | ✓ | ✓ | | | Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 7f0c1cc3e420c..5c1e2e27a6eff 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1812,9 +1812,9 @@ // // "test_gridsample_zeros_padding", // // "test_gridsample", // // "test_gru_batchwise", - // // "test_gru_defaults", - // // "test_gru_seq_length", - // // "test_gru_with_initial_bias", + "test_gru_defaults", + "test_gru_seq_length", + "test_gru_with_initial_bias", // // "test_hammingwindow_expanded", // // "test_hammingwindow_symmetric_expanded", // // "test_hammingwindow_symmetric", diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 2088618538de5..5b2f2c1fa1b2e 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -192,6 +192,15 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vect return def_val; } +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.strings(); + return std::vector{values.cbegin(), values.cend()}; + } + + return def_val; +} + std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { const auto& values = entry->second.floats(); diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 5813dcc48d72b..ddbae42534711 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -57,6 +57,7 @@ class NodeAttrHelper { std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 4d723a3c59ee2..b51092619db22 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -183,6 +183,7 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"}, + {"Gru", "gru"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, {"Identity", "identity"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc new file mode 100644 index 0000000000000..23cc7f1b11459 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GruOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; +}; + +void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens + model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); + } +} + +Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + uint32_t hidden_size = helper.Get("hidden_size", 1); + + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape"); + uint32_t steps = static_cast(input_shape[0]); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val weight = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name()); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + options.set("layout", emscripten::val("zrn")); + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); + emscripten::val split_options = emscripten::val::object(); + split_options.set("label", node.Name() + "_split"); + split_options.set("axis", 1); + // Split it to bias and recurrentBias. + emscripten::val splitted_biases = + model_builder.GetBuilder().call("split", bias, /*splits*/ 2, split_options); + options.set("bias", splitted_biases[0]); + options.set("recurrentBias", splitted_biases[1]); + } + + if (input_defs.size() > 5 && input_defs[5]->Exists()) { + options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name())); + } + + bool linear_before_reset = !!helper.Get("linear_before_reset ", 0); + options.set("resetAfter", linear_before_reset); + + const auto& output_defs = node.OutputDefs(); + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + options.set("returnSequence", has_Y); + + std::string direction = helper.Get("direction", "forward"); + if (direction == "forward") { + options.set("direction", emscripten::val("forward")); + } else if (direction == "reverse") { + options.set("direction", emscripten::val("backward")); + } else if (direction == "bidirectional") { + options.set("direction", emscripten::val("both")); + } + + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh"}); + emscripten::val recurrent_network_activations = emscripten::val::array(); + for (size_t i = 0; i < 2; ++i) { + const std::string& activation = activations[i]; + if (activation == "Relu") { + recurrent_network_activations.call("push", emscripten::val("relu")); + } else if (activation == "Sigmoid") { + recurrent_network_activations.call("push", emscripten::val("sigmoid")); + } else if (activation == "Tanh") { + recurrent_network_activations.call("push", emscripten::val("tanh")); + } + } + + options.set("activations", recurrent_network_activations); + } + + emscripten::val outputs = model_builder.GetBuilder().call("gru", input, weight, recurrent_weight, + steps, hidden_size, options); + + if (has_Y) { + model_builder.AddOperand(output_defs[0]->Name(), outputs[1]); + } + if (has_Y_h) { + model_builder.AddOperand(output_defs[1]->Name(), outputs[0]); + } + + return Status::OK(); +} + +bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) { + LOGS(logger, ERROR) << "GRU: input size must greater than or equal to 3"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger) || input_shape.empty()) { + LOGS(logger, ERROR) << "Cannot get input's shape"; + return false; + } + int32_t steps = static_cast(input_shape[0]); + + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (!Contains(initializers, input_defs[4]->Name())) { + LOGS(logger, ERROR) << "GRU: sequence_lens must be constant"; + return false; + } + + const auto& sequence_lens_tensor = *initializers.at(input_defs[4]->Name()); + std::vector sequence_lens; + if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) { + LOGS(logger, ERROR) << "Cannot read sequence lens tensor"; + return false; + } + if (!std::all_of(sequence_lens.begin(), sequence_lens.end(), + [steps](int32_t lens) -> bool { return steps == lens; })) { + LOGS(logger, ERROR) << "GRU: every sequence length must be equal to input shape[0]"; + return false; + } + } + + NodeAttrHelper helper(node); + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh"}); + + if (activations.size() >= 4) { + if (activations[0] != activations[2] || activations[1] != activations[3]) { + LOGS(logger, ERROR) << "GRU: forward and reverse directions must have the same activations"; + return false; + } + } + + const InlinedHashSet supported_activations = {"Relu", "Tanh", "Sigmoid"}; + if (!std::all_of(activations.begin(), activations.end(), + [&supported_activations](const std::string& activation) -> bool { + return supported_activations.contains(activation); + })) { + LOGS(logger, ERROR) << "GRU: activations must be one of Relu, Tanh, Sigmoid"; + return false; + } + } + + if (helper.Get("clip", std::numeric_limits::max()) != std::numeric_limits::max()) { + LOGS(logger, ERROR) << "GRU: clip is not supported"; + return false; + } + + if (helper.Get("layout", 0) != 0) { + LOGS(logger, ERROR) << "GRU: batchwise (layout == 1) is not supported"; + return false; + } + + return true; +} + +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type = 0; // input data type + int32_t input1_type = 0; // weight data type + int32_t input2_type = 0; // recurrentWeight data type + int32_t input3_type = 0; // bias data type + int32_t input4_type = 0; // recurrentBias data type + int32_t input5_type = 0; // initialHiddenState data type + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists(); + bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + !GetType(*input_defs[2], input2_type, logger) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input4 && !GetType(*input_defs[4], input4_type, logger)) || + (has_input5 && !GetType(*input_defs[5], input5_type, logger))) { + return false; + } + + std::unordered_set supported_data_types; + if (device_type == WebnnDeviceType::CPU) { + // WebNN CPU backend only support float32 input data type. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + }; + } else if (device_type == WebnnDeviceType::GPU) { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } + + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type || + input0_type != input2_type || + (has_input3 && input0_type != input3_type) || + (has_input4 && input0_type != input4_type) || + (has_input5 && input0_type != input5_type)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + +void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 862cf5ded15bc..01761290f07e3 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -108,6 +108,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGemmOpBuilder("MatMulInteger", op_registrations); } + { // GRU + CreateGruOpBuilder("GRU", op_registrations); + } + { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index e11938d8fa406..b66218cc9a902 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -33,6 +33,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); From b80032862800f516b9810ac60a4a30f0a565b8e4 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 11 Sep 2024 14:52:18 -0700 Subject: [PATCH 04/26] [ROCm EP/ MIGraphx EP] matmul_nbits: Use GPU_WARP_SIZE_HOST for host side code (#22045) ### Description For ROCm device, the host side code needs to call GPU_WARP_SIZE_HOST to query warpSize of the underlying GPU device. ### Motivation and Context Fixes MatMulNBits tests on gfx1100/01 which has warpSize of 32. Signed-off-by: Jagadish Krishnamoorthy --- onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index af9e87eaf225d..ce6c07fbed2bc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -289,7 +289,7 @@ bool TryMatMul4Bits( return false; } dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); - dim3 threads(kWarpSize, kColsPerThreadBlock); + dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); From 0309c5f02fa8ffc11c3ab74c582e88c3997969e0 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Thu, 12 Sep 2024 03:25:40 +0530 Subject: [PATCH 05/26] Ovep release lnl 1.2.1 (#22027) Error Codes are added to catch compilation error and signal recompile. Remote Tensors are added to ensure direct memory access for NPU inferencing. UMD Bypass cache enabled with 2024.4 will eliminate need to disk caching ### Motivation and Context The changes are needed to ensure backward compatibility UMD Bypass caching eliminates driver caching Remote Tensors lead to performance improvement with inferencing on NPU --------- Co-authored-by: Preetha Veeramalai Co-authored-by: Srirammaswamy Co-authored-by: saurabh Co-authored-by: Javier E. Martinez Co-authored-by: Eric Crawford Co-authored-by: jatinwadhwa921 --- cmake/onnxruntime_providers_openvino.cmake | 4 + .../onnxruntime/core/framework/allocator.h | 2 + onnxruntime/core/framework/allocator.cc | 4 + .../providers/openvino/backend_manager.cc | 38 ++++- .../openvino/backends/basic_backend.cc | 161 +++++++++++++++--- .../openvino/backends/basic_backend.h | 9 + .../openvino/openvino_execution_provider.cc | 17 ++ .../openvino/openvino_execution_provider.h | 4 +- .../core/providers/openvino/ov_allocator.cc | 55 ++++++ .../core/providers/openvino/ov_allocator.h | 24 +++ .../core/providers/openvino/ov_interface.h | 1 + onnxruntime/test/perftest/ort_test_session.cc | 59 ++++++- onnxruntime/test/perftest/ort_test_session.h | 3 + 13 files changed, 338 insertions(+), 43 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_allocator.cc create mode 100644 onnxruntime/core/providers/openvino/ov_allocator.h diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index e559583fae8f5..2eb3611bae902 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -21,6 +21,10 @@ message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release") endif() + if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) + add_definitions(-DUSE_OVEP_NPU_MEMORY=1) + endif() + if (WIN32) unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) endif() diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 097873c5e3653..abab118efd04f 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -50,6 +50,8 @@ constexpr const char* HIP = "Hip"; constexpr const char* HIP_PINNED = "HipPinned"; constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; +constexpr const char* OpenVINO_RT = "OpenVINO_RT"; +constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr size_t kAllocAlignment = 256; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index c3e96e450c59b..5e66f2b99fded 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -145,6 +145,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be41b125e4440..4fca4037301fb 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -107,12 +108,15 @@ BackendManager::BackendManager(const GlobalContext& global_context, subgraph_context_, ep_ctx_handle_); } catch (const OnnxRuntimeException& ex) { + std::string exception_str = ex.what(); + bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback && + !ep_ctx_handle_.IsValidOVEPCtxGraph(); #if defined(OPENVINO_DISABLE_NPU_FALLBACK) - ORT_THROW(ex.what()); + eligible_for_cpu_fallback = false; #else - if (device_type.find("NPU") != std::string::npos && - !GetGlobalContext().disable_cpu_fallback) { - LOGS_DEFAULT(WARNING) << ex.what(); + if (eligible_for_cpu_fallback) { + LOGS_DEFAULT(VERBOSE) << exception_str; LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." << "Falling back to OV CPU for execution"; GetGlobalContext().device_type = "CPU"; @@ -125,10 +129,32 @@ BackendManager::BackendManager(const GlobalContext& global_context, } catch (std::string const& msg) { ORT_THROW(msg); } - } else { - ORT_THROW(ex.what()); } #endif + if (!eligible_for_cpu_fallback) { + if (device_type.find("NPU") != std::string::npos && + exception_str.find("intel_npu") != std::string::npos) { + // Handle NPU device related errors +#ifndef NDEBUG + ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); +#else + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; + } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; + } + throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); +#endif + } else { + ORT_THROW(exception_str); + } + } } } if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8d340e2daf4b5..1f9c61780f27a 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -48,14 +48,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Set the inference_num_threads property of the CPU SetNumThreads(device_config); -#ifndef NDEBUG - if (IsDebugEnabled()) { - std::string file_name = subgraph_context.subgraph_name + "_static.onnx"; - std::fstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(outfile); - } -#endif - try { std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str; @@ -180,6 +172,11 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type); } device_config.emplace(ov::device::properties("NPU", device_property)); +#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3) + if (global_context_.export_ep_ctx_blob) { + global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true)); + } +#endif } } @@ -295,16 +292,104 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque ORT_THROW(msg); } } else { - OVTensorPtr graph_input_blob; - try { - graph_input_blob = infer_request->GetTensor(input_name); - } catch (const char* msg) { - ORT_THROW(msg); + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + OVTensorPtr graph_input_blob; + try { + graph_input_blob = infer_request->GetTensor(input_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); + } else { + auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_key; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_key = it->second; + } else { + // Does this make sense for both types of allocators? + auto input = graph_input_info.at(input_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_key.copy_needed = false; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_key.copy_needed = true; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_key); + + if (ov_tensor_key.copy_needed) { + const char* ort_tensor_data = tensor.GetTensorData(); + size_t tensor_data_size = ov_tensor_key.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data + tensor_data_size * batch_slice_idx; + std::memcpy(ov_tensor_key.tensor_ptr->data(), ort_batch_memory_offset, tensor_data_size); + } + + try { + infer_request->SetTensor(input_name, ov_tensor_key.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); } input_idx++; } + if (global_context_.device_type.find("NPU") != std::string::npos) { + // Set the output blob as remote blob + auto graph_output_info = exe_network_.Get().outputs(); + auto output_idx = 0; + for (auto output_info_iter = graph_output_info.begin(); + output_info_iter != graph_output_info.end(); ++output_info_iter) { + auto output_names = output_info_iter->get_names(); + std::string onnx_output_name; + std::string output_name; + // using the output name retrieved from ONNX original to match with the output names returned by OV tensors + for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { + onnx_output_name = it->first; + if (output_names.find(onnx_output_name) != output_names.end()) { + // Assigning the output_name + output_name = it->first; + break; + } + } + size_t batch_size = 1; + Ort::UnownedValue tensor = GetOutputTensor(context, + batch_size, + infer_request, + output_name, + subgraph_context_.output_names); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + auto output = graph_output_info.at(output_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_data.copy_needed = false; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_data.copy_needed = true; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_data); + + try { + infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } + output_idx++; + } + } + // Start Async inference infer_request->StartAsync(); } catch (const char* msg) { @@ -454,20 +539,42 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe " doesn't exist in the " "list of OpenVINO output tensor names"); } - try { - graph_output_blob = infer_request->GetTensor(output_name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + try { + graph_output_blob = infer_request->GetTensor(output_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + return; + } else { + size_t batch_slice = 0; + FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + } } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto allocator_name = output_tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{output_tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + ORT_THROW(log_tag + "Expected all outputs to have associated OV::Tensor's"); + } + + if (ov_tensor_data.copy_needed) { + auto ort_tensor_data = output_tensor.GetTensorMutableData(); + size_t tensor_data_size = ov_tensor_data.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data /*+ tensor_data_size * batch_size*/; + std::memcpy(ort_batch_memory_offset, ov_tensor_data.tensor_ptr->data(), tensor_data_size); + } } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index cd242a06b27d4..cd69e88f994b9 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/openvino/contexts.h" @@ -20,6 +21,11 @@ namespace onnxruntime { namespace openvino_ep { +struct ov_tensor_data_t { + OVTensorPtr tensor_ptr; + bool copy_needed; +}; + class InferRequestsQueue; class BasicBackend : public IBackend { public: @@ -60,6 +66,9 @@ class BasicBackend : public IBackend { #if defined IO_BUFFER_ENABLED OVRemoteContextPtr remote_context_; #endif + + using ort_tensor_key_t = std::pair; + std::map ort_ov_tensor_map; }; class InferRequestsQueue { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 29c45916795d3..08144651319cf 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -10,6 +10,9 @@ #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" #include "openvino/core/version.hpp" +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#endif #define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz)) @@ -180,4 +183,18 @@ common::Status OpenVINOExecutionProvider::Compile( return Status::OK(); } +#ifdef USE_OVEP_NPU_MEMORY +std::vector OpenVINOExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo npu_allocator_info{ + [this](OrtDevice::DeviceId device_id) { + return std::make_unique(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU); + }, + 0, + }; + + // fill in allocator + return std::vector{CreateAllocator(npu_allocator_info)}; +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 030e5bba71b67..8b1c62c607f6e 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -189,7 +189,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { const void* GetExecutionHandle() const noexcept override { return nullptr; } - +#ifdef USE_OVEP_NPU_MEMORY + std::vector CreatePreferredAllocators() override; +#endif private: std::unique_ptr global_context_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc new file mode 100644 index 0000000000000..6700244b754d8 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -0,0 +1,55 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#include "core/providers/openvino/ov_interface.h" +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" + +namespace onnxruntime { + +using namespace openvino_ep; + +constexpr size_t default_alignment = 4096; + +static inline size_t align_up(size_t size, size_t pow2_alignment) { + return (size + pow2_alignment - 1) & ~(pow2_alignment - 1); +} + +OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) { + if (device_type == OrtDevice::NPU) { + remote_ctx_ = core_.get_default_context("NPU").as(); + } else { + ORT_THROW("Invalid device type"); + } +} + +void* OVRTAllocator::Alloc(size_t size) { + try { + size_t alloc_size = align_up(size + sizeof(ov::Tensor*) + default_alignment, default_alignment); + ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, + {alloc_size})); + uintptr_t data_ptr = reinterpret_cast(tensor->data()); + + ov::Tensor** ptr = reinterpret_cast(align_up(data_ptr + sizeof(ov::Tensor*), default_alignment)); + ptr[-1] = tensor; + + return reinterpret_cast(ptr); + + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Alloc failed: ") + e.what()); + } + return nullptr; +} + +void OVRTAllocator::Free(void* p) { + try { + ov::Tensor** ptr = reinterpret_cast(p); + delete ptr[-1]; + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Free failed: ") + e.what()); + } +} + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_allocator.h b/onnxruntime/core/providers/openvino/ov_allocator.h new file mode 100644 index 0000000000000..083cfc4d5aed3 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_allocator.h @@ -0,0 +1,24 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include "openvino/runtime/remote_context.hpp" + +namespace onnxruntime { + +class OVRTAllocator : public IAllocator { + public: + OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name); + void* Alloc(size_t size) override; + void Free(void* p) override; + + private: + ov::Core& core_; + ov::RemoteContext remote_ctx_; +}; + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index fa22e0f3cb03d..f4da4ea3e3244 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -10,6 +10,7 @@ #include #include "openvino/openvino.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 837aeb3c37acd..ae7680571ced1 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -34,10 +34,18 @@ std::chrono::duration OnnxRuntimeTestSession::Run() { // Randomly pick one OrtValueArray from test_inputs_. (NOT ThreadSafe) const std::uniform_int_distribution::param_type p(0, static_cast(test_inputs_.size() - 1)); const size_t id = static_cast(dist_(rand_engine_, p)); + auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); - auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), - output_names_raw_ptr.data(), output_names_raw_ptr.size()); + + if (!use_device_mem) { + auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), output_names_raw_ptr.size()); + } else { + session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size()); + } + auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration_seconds = end - start; return duration_seconds; @@ -815,6 +823,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " "should be a boolean i.e. true or false. Default value is false.\n"); } + } else if (key == "use_device_mem") { + if (value == "true" || value == "True") { + use_device_mem = true; + } } else { ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); } @@ -858,6 +870,27 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); input_names_str_[i] = m.GetInputName(i); input_names_[i] = input_names_str_[i].c_str(); } + + if (use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + custom_allocator_ = std::make_unique(session_, memory_info); + for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { + Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + + std::vector output_shape = tensor_info.GetShape(); + + // free dimensions are treated as 1 if not overridden + for (int64_t& dim : output_shape) { + if (dim == -1) { + dim = 1; + } + } + + outputs_.push_back(Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)output_shape.data(), + output_shape.size(), tensor_info.GetElementType())); + } + } } template @@ -944,9 +977,11 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + if (!use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + } std::vector input_node_dim = tensor_info.GetShape(); // free dimensions are treated as 1 if not overridden @@ -955,12 +990,18 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { dim = 1; } } - - auto allocator = Ort::AllocatorWithDefaultOptions(); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); + if (use_device_mem) { + Ort::Value input_tensor = Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } else { + auto allocator = Ort::AllocatorWithDefaultOptions(); + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } } } return true; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index f1a4220ab325e..e33041a2a0958 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -38,6 +38,8 @@ class OnnxRuntimeTestSession : public TestSession { std::mt19937 rand_engine_; std::uniform_int_distribution dist_; std::vector> test_inputs_; + std::unique_ptr custom_allocator_; + std::vector outputs_; std::vector output_names_; // The same size with output_names_. // TODO: implement a customized allocator, then we can remove output_names_ to simplify this code @@ -46,6 +48,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_str_; const int input_length_; std::string provider_name_; + bool use_device_mem = false; }; } // namespace perftest From d8e64bb529c1d0f18efd47710d179205c96ffbca Mon Sep 17 00:00:00 2001 From: Lennart Hannink Date: Thu, 12 Sep 2024 01:05:37 +0200 Subject: [PATCH 06/26] Refactor CoreMLExecution to C++ bridge class (#21857) Refactor Objective-C++ class `CoreMLExecution` into existing C++ bridge class `onnxruntime::coreml::Execution`. --- .../core/providers/coreml/model/model.mm | 398 ++++++++---------- 1 file changed, 171 insertions(+), 227 deletions(-) diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 4d20061820e71..68460ff7c9b31 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -30,8 +30,8 @@ // to manually do this asm(".linker_option \"-framework\", \"CoreML\""); -using namespace onnxruntime; -using namespace onnxruntime::coreml; +namespace onnxruntime { +namespace coreml { namespace { /** @@ -247,213 +247,6 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff } } // namespace -NS_ASSUME_NONNULL_BEGIN - -// Execution for a CoreML model, it performs -// 1. Compile the model by given path for execution -// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT -// 3. The compiled model will be removed in dealloc or removed using cleanup function -@interface CoreMLExecution : NSObject { - NSString* coreml_model_path_; - NSString* compiled_model_path_; - const logging::Logger* logger_; - uint32_t coreml_flags_; -} - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags; -- (void)cleanup; -- (void)dealloc; -- (Status)loadModel API_AVAILABLE_COREML3; -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_COREML3; - -@property(nullable) MLModel* model API_AVAILABLE_COREML3; - -@end - -@implementation CoreMLExecution - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags { - if (self = [super init]) { - coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); - logger_ = &logger; - coreml_flags_ = coreml_flags; - } - return self; -} - -- (void)cleanup { - NSError* error = nil; - if (compiled_model_path_ != nil) { - [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - compiled_model_path_ = nil; - } - -#if !defined(NDEBUG) - std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); - if (!path_override.empty()) { - // don't cleanup - coreml_model_path_ = nil; - } -#endif - - if (coreml_model_path_ != nil) { - error = nil; - [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - coreml_model_path_ = nil; - } -} - -- (void)dealloc { - [self cleanup]; -} - -- (Status)loadModel { - NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; - if (modelUrl == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); - } - - // TODO: Update this to version with callback handler as the API used here is deprecated. - // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl - // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the - // background. We will have to check for completion in `predict` and block until it is done. - NSError* error = nil; - NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", - [[error localizedDescription] UTF8String]); - } - - compiled_model_path_ = [compileUrl path]; - - MLModelConfiguration* config = [MLModelConfiguration alloc]; - config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) - ? MLComputeUnitsCPUOnly - : MLComputeUnitsAll; - _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; - - if (error != nil || _model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", - (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); - } - - return Status::OK(); -} - -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn { - Status status = Status::OK(); - ORT_TRY { - if (_model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); - } - - id input_features; - InlinedVector> conversion_buffers; - ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, *logger_, &input_features, conversion_buffers)); - - MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; - NSError* error = nil; - id output_features = [_model predictionFromFeatures:input_features - options:options - error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", - [[error localizedDescription] UTF8String]); - } - - for (const auto& [output_name, output_tensor_info] : outputs) { - MLFeatureValue* output_value = - [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; - - if (output_value == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); - } - - MLMultiArray* data = [output_value multiArrayValue]; - - const auto coreml_static_output_shape = [data]() { - InlinedVector result; - result.reserve(data.shape.count); - for (NSNumber* dim in data.shape) { - const auto dim_value = dim.longLongValue; - result.push_back(dim_value); - } - return result; - }(); - - const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, - *logger_); - - void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, - static_output_shape); - - if (const size_t num_elements = data.count; num_elements > 0) { - if (const auto shape_size = ShapeSize(static_output_shape); - shape_size < 0 || num_elements != static_cast(shape_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, - ") do not match"); - } - - // support a non-contiguous array, provided only one dimension is not contiguous - int64_t num_blocks = 0; - int64_t block_size = 0; - int64_t stride = 0; - - ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); - - __block Status copy_status; - const auto* tensor_info = &output_tensor_info; - // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions - if (@available(macOS 12.3, iOS 15.4, *)) { - [data getBytesWithHandler:^(const void* bytes, NSInteger size) { - copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - }]; - } else { - copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - } - - ORT_RETURN_IF_ERROR(copy_status); - } - } - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); - }); - } - - return status; -} - -@end - -NS_ASSUME_NONNULL_END - -namespace onnxruntime { -namespace coreml { - Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, int64_t& num_blocks, int64_t& block_size, int64_t& stride) { const auto* shape = array.shape; @@ -498,11 +291,14 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, } // Internal Execution class -// This class will bridge Model (c++) with CoreMLExecution (objective c++) +// This class is part of the model class and handles the calls into CoreML. Specifically, it performs +// 1. Compile the model by given path for execution +// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT +// 3. The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - ~Execution() {}; + ~Execution(); Status LoadModel(); Status Predict(const std::unordered_map& inputs, @@ -510,30 +306,97 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); private: - bool model_loaded{false}; - CoreMLExecution* execution_; + void cleanup(); + NSString* coreml_model_path_{nil}; + NSString* compiled_model_path_{nil}; + const logging::Logger& logger_; + uint32_t coreml_flags_{0}; + MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) { +Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) + : logger_(logger), + coreml_flags_(coreml_flags) { @autoreleasepool { - execution_ = [[CoreMLExecution alloc] initWithPath:path - logger:logger - coreml_flags:coreml_flags]; + coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); + } +} + +Execution::~Execution() { + @autoreleasepool { + cleanup(); + } +} + +void Execution::cleanup() { + NSError* error = nil; + if (compiled_model_path_ != nil) { + [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + compiled_model_path_ = nil; + } + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + + if (coreml_model_path_ != nil) { + error = nil; + [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + coreml_model_path_ = nil; } } Status Execution::LoadModel() { - if (model_loaded) { + if (model_ != nil) { return Status::OK(); } if (HAS_COREML3_OR_LATER) { - Status status{}; @autoreleasepool { - status = [execution_ loadModel]; + NSError* error = nil; + + NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; + if (modelUrl == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); + } + + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. + NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", + [[error localizedDescription] UTF8String]); + } + + compiled_model_path_ = [compileUrl path]; + + MLModelConfiguration* config = [MLModelConfiguration alloc]; + config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) + ? MLComputeUnitsCPUOnly + : MLComputeUnitsAll; + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; + + if (error != nil || model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", + (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); + } + + return Status::OK(); } - model_loaded = status.IsOK(); - return status; } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+"); @@ -542,13 +405,94 @@ Status Predict(const std::unordered_map& inputs, Status Execution::Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { - ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_COREML3_OR_LATER) { @autoreleasepool { - return [execution_ predict:inputs - outputs:outputs - getOutputTensorDataFn:get_output_tensor_mutable_raw_data_fn]; + Status status = Status::OK(); + ORT_TRY { + if (model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); + } + + id input_features; + InlinedVector> conversion_buffers; + ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, logger_, &input_features, conversion_buffers)); + + MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; + NSError* error = nil; + id output_features = [model_ predictionFromFeatures:input_features + options:options + error:&error]; + + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", + [[error localizedDescription] UTF8String]); + } + + for (const auto& [output_name, output_tensor_info] : outputs) { + MLFeatureValue* output_value = + [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; + + if (output_value == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); + } + + MLMultiArray* data = [output_value multiArrayValue]; + + const auto coreml_static_output_shape = [data]() { + InlinedVector result; + result.reserve(data.shape.count); + for (NSNumber* dim in data.shape) { + const auto dim_value = dim.longLongValue; + result.push_back(dim_value); + } + return result; + }(); + + const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, + logger_); + + void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, + static_output_shape); + + if (const size_t num_elements = data.count; num_elements > 0) { + if (const auto shape_size = ShapeSize(static_output_shape); + shape_size < 0 || num_elements != static_cast(shape_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, + ") do not match"); + } + + // support a non-contiguous array, provided only one dimension is not contiguous + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + + ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); + + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + }]; + } else { + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + } + + ORT_RETURN_IF_ERROR(copy_status); + } + } + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); + }); + } + + return status; } } From d495e6cf1c477098255511c4136bb7ea43a7c0dc Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 11 Sep 2024 22:02:30 -0700 Subject: [PATCH 07/26] adds support for Uint8ClampedArray (#21985) Fixes https://github.com/microsoft/onnxruntime/issues/21753 --- js/common/lib/tensor-impl.ts | 19 ++++++++++++++++--- js/common/lib/tensor.ts | 17 +++++++++++++++++ .../type-tests/tensor/create-new-uint8.ts | 19 +++++++++++++++++++ .../unit-tests/tensor/constructor-type.ts | 8 ++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 js/common/test/type-tests/tensor/create-new-uint8.ts diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 4e0ef821dde57..342f5e3a467eb 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface { */ constructor( type: TensorType, - data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[], + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[], dims?: readonly number[], ); /** * Construct a new CPU tensor object from the given data and dims. Type is inferred from data. */ - constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]); + constructor( + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[], + dims?: readonly number[], + ); /** * Construct a new tensor object from the pinned CPU data with the given type and dims. * @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface { arg0: | TensorType | TensorDataType + | Uint8ClampedArray | readonly string[] | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters | GpuBufferConstructorParameters, - arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], + arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { // perform one-time check for BigInt/Float16Array support @@ -216,6 +220,12 @@ export class Tensor implements TensorInterface { } } else if (arg1 instanceof typedArrayConstructor) { data = arg1; + } else if (arg1 instanceof Uint8ClampedArray) { + if (arg0 === 'uint8') { + data = Uint8Array.from(arg1); + } else { + throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); + } } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } @@ -243,6 +253,9 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`Invalid element type of data array: ${firstElementType}.`); } + } else if (arg0 instanceof Uint8ClampedArray) { + type = 'uint8'; + data = Uint8Array.from(arg0); } else { // get tensor type from TypedArray const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get( diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 70396bbe1e9a3..8a1197994393b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory { dims?: readonly number[], ): TypedTensor<'bool'>; + /** + * Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims. + * + * @param type - Specify the element type. + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new 64-bit integer typed tensor object from the given type, data and dims. * @@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory { */ new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; + /** + * Construct a new uint8 tensor object from the given data and dims. + * + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new uint16 tensor object from the given data and dims. * diff --git a/js/common/test/type-tests/tensor/create-new-uint8.ts b/js/common/test/type-tests/tensor/create-new-uint8.ts new file mode 100644 index 0000000000000..46438f97ca2e7 --- /dev/null +++ b/js/common/test/type-tests/tensor/create-new-uint8.ts @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as ort from 'onnxruntime-common'; + +// construct from Uint8Array +// +// {type-tests}|pass +new ort.Tensor(new Uint8Array(1)); + +// construct from Uint8ClampedArray +// +// {type-tests}|pass +new ort.Tensor(new Uint8ClampedArray(1)); + +// construct from type (bool), data (Uint8ClampedArray) and shape (number array) +// +// {type-tests}|fail|1|2769 +new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index def711684d7f5..02390800e8611 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => { assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); + it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => { + const uint8ClampedArray = new Uint8ClampedArray(2); + uint8ClampedArray[0] = 0; + uint8ClampedArray[1] = 256; // clamped + const tensor = new Tensor('uint8', uint8ClampedArray, [2]); + assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'"); + }); + it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => { const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]); assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); From ae39c40e5b65874735cd07aca692287aa1cf1b62 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 12 Sep 2024 19:07:42 +0800 Subject: [PATCH 08/26] fix typo in iOS pipeline (#22067) ### Description ### Motivation and Context The parameter isn't correct. Maybe it hasn't negative impact by chance so far. https://github.com/microsoft/onnxruntime/blob/d8e64bb529c1d0f18efd47710d179205c96ffbca/cmake/CMakeLists.txt#L1712-L1717 --- tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index 48d48156fe913..74211bc5dbd7c 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -53,7 +53,7 @@ jobs: python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ --skip_submodule_sync \ --build_dir $(Build.BinariesDirectory)/iOS \ - --build_shared \ + --build_shared_lib \ --use_coreml \ --use_xnnpack \ --ios \ From 951b1b7160b0efc21d97ead1051777410e2ca775 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 13 Sep 2024 01:54:32 +0900 Subject: [PATCH 09/26] [CI] Linux ROCm CI Pipeline: fix error, set trigger rules. (#22069) ### Description * Correct the wrong EP name for ROCm, fix CI error. * Update `set-trigger-rules.py`. * Modify the .yml via `set-trigger-rules.py` --- onnxruntime/test/python/onnxruntime_test_python.py | 2 +- .../ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml | 1 + tools/ci_build/set-trigger-rules.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index feabd648f8385..24151932a6681 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1694,7 +1694,7 @@ def test_register_custom_e_ps_library(self): available_eps = C.get_available_providers() # skip amd gpu build - if "RocmExecutionProvider" in available_eps: + if "ROCMExecutionProvider" in available_eps: return if sys.platform.startswith("win"): diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index 7b77281b0efe2..50f3862761320 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -1,4 +1,5 @@ ##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### trigger: branches: include: diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 583e5b05ed6d8..fb6aa44cdf31a 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -24,6 +24,7 @@ "linux-migraphx-ci-pipeline.yml", "linux-openvino-ci-pipeline.yml", "linux-qnn-ci-pipeline.yml", + "linux-rocm-ci-pipeline.yml", "mac-ci-pipeline.yml", "mac-coreml-ci-pipeline.yml", "mac-ios-ci-pipeline.yml", From 84f73327f55b3dadbf20b69bc1a12cc2811986ed Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:33:37 -0700 Subject: [PATCH 10/26] allow scalar axes for Unsqueeze for WebGPU (#22054) ### Description Align with CPU behavior. https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62 --- onnxruntime/core/providers/js/operators/unsqueeze.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index 7cbfdc38b742d..f15a3008895aa 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { if (num_inputs == 2) { // axes is an input const Tensor* axes_tensor = context->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); axes.assign(data, data + nDims); From 10883d7997ed4b53f989a49bd4387c5769fbd12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20P=C3=A9ron?= Date: Thu, 12 Sep 2024 18:46:27 +0100 Subject: [PATCH 11/26] Suppress GCC warning in TreeEnsembleAggregator (#22062) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description When building with GCC 14.2.1, I got the following warning: onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h:329:59: error: template-id not allowed for constructor in C++20 [-Werror=template-id-cdtor] Remove template parameters from the constructor: The constructor TreeAggregatorMax has been simplified to TreeAggregatorMax, because the compiler already knows the template parameters from the class definition. ### Motivation and Context Fix the build issue Signed-off-by: Clément Péron --- .../core/providers/cpu/ml/tree_ensemble_aggregator.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index 34c6db61982b5..b031a6f0cefa3 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -328,11 +328,10 @@ class TreeAggregatorMin : public TreeAggregator class TreeAggregatorMax : public TreeAggregator { public: - TreeAggregatorMax(size_t n_trees, - const int64_t& n_targets_or_classes, - POST_EVAL_TRANSFORM post_transform, - const std::vector& base_values) : TreeAggregator(n_trees, n_targets_or_classes, - post_transform, base_values) {} + TreeAggregatorMax(size_t n_trees, + const int64_t& n_targets_or_classes, + POST_EVAL_TRANSFORM post_transform, + const std::vector& base_values) : TreeAggregator(n_trees, n_targets_or_classes, post_transform, base_values) {} // 1 output From d539c27de82b9d1631b743b941f9c3ade49e7a05 Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Fri, 13 Sep 2024 02:42:17 +0800 Subject: [PATCH 12/26] Fix version check for using -mavxvnni (#21616) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Change the `CMAKE_CXX_COMPILER_VERSION` greater than `11` for using '-mavxvnni'. ### Motivation and Context `CMakeFiles/onnxruntime_mlas.dir/root/Git.d/onnxruntime/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S.o cc: error: unrecognized command-line option ‘-mavxvnni’; did you mean ‘-mavx512vnni’?` using `gcc (GCC) 10.3.1`. `-mavxnni` is supported since [GCC 11 Release](https://gcc.gnu.org/gcc-11/changes.html), this PR change the version check. --- cmake/onnxruntime_mlas.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cf23416943c1f..b612b3ead4658 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -578,7 +578,7 @@ else() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") else() From 5c361106e61b94213784c7a6953e8a099235c7e4 Mon Sep 17 00:00:00 2001 From: 0xdr3dd Date: Fri, 13 Sep 2024 00:20:34 +0530 Subject: [PATCH 13/26] [Fuzzer] Add two new ORT libfuzzer (Linux clang support for now) (#22055) ### Description This PR adds two new libfuzzer in fuzzer project. 1. Binary libfuzzer 2. libprotobuf-fuzzer To compile run below cmd on linux: ``` LLVM_PROFILE_FILE="%p.profraw" CFLAGS="-g -fsanitize=address,fuzzer-no-link -shared-libasan -fprofile-instr-generate -fcoverage-mapping" CXXFLAGS="-g -shared-libasan -fsanitize=address,fuzzer-no-link -fprofile-instr-generate -fcoverage-mapping" CC=clang CXX=clang++ ./build.sh --update --build --config Debug --compile_no_warning_as_error --build_shared_lib --skip_submodule_sync --use_full_protobuf --parallel --fuzz_testing --build_dir build/ ``` Run fuzzer: ``` LD_PRELOAD=$(clang -print-file-name=libclang_rt.asan-x86_64.so) build/Debug/onnxruntime_libfuzzer_fuzz testinput -rss_limit_mb=8196 -max_total_time=472800 -fork=2 -jobs=4 -workers=4 -ignore_crashes=1 -max_len=2097152 2>&1 | grep -v "\[libprotobuf ERROR" ``` ### Motivation and Context The existing custom fuzzer is not coverage guided and it's slow and it will work on one model mutation at a time. The new fuzzers are coverage guided, and we can use more models' files as a corpus to increase the coverage. --- cmake/onnxruntime_fuzz_test.cmake | 145 ++++++++++++------ .../fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp | 42 +++++ .../ort_libfuzzer/OrtProtoLibfuzzer.cpp | 94 ++++++++++++ 3 files changed, 236 insertions(+), 45 deletions(-) create mode 100644 onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp create mode 100644 onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp diff --git a/cmake/onnxruntime_fuzz_test.cmake b/cmake/onnxruntime_fuzz_test.cmake index 26d41e98687d4..eea411d938176 100644 --- a/cmake/onnxruntime_fuzz_test.cmake +++ b/cmake/onnxruntime_fuzz_test.cmake @@ -4,23 +4,24 @@ # Check that the options are properly set for # the fuzzing project if (onnxruntime_FUZZ_ENABLED) - message(STATUS "Building dependency protobuf-mutator and libfuzzer") - - # set the options used to control the protobuf-mutator build - set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) - set(LIB_PROTO_MUTATOR_TESTING OFF) - - # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes - # with google test - add_subdirectory("external/libprotobuf-mutator/src") - - # add the appropriate include directory and compilation flags - # needed by the protobuf-mutator target and the libfuzzer - set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") - onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) - onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) - target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) - target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + message(STATUS "Building dependency protobuf-mutator and libfuzzer") + + # set the options used to control the protobuf-mutator build + set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) + set(LIB_PROTO_MUTATOR_TESTING OFF) + + # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes + # with google test + add_subdirectory("external/libprotobuf-mutator/src") + + # add the appropriate include directory and compilation flags + # needed by the protobuf-mutator target and the libfuzzer + set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") + onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) + onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # MSVC-specific compiler options target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456") @@ -44,42 +45,96 @@ if (onnxruntime_FUZZ_ENABLED) ) endif() - # add Fuzzing Engine Build Configuration - message(STATUS "Building Fuzzing engine") + # add Fuzzing Engine Build Configuration + message(STATUS "Building Fuzzing engine") + + # set Fuzz root directory + set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + + # Security fuzzing engine src file reference + set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" + "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/src/test.cpp") + + # compile the executables + onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + + # compile with c++17 + target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) - # set Fuzz root directory - set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) - # Security fuzzing engine src file reference - set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" - "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" - "${SEC_FUZZ_ROOT}/src/testlog.cpp" - "${SEC_FUZZ_ROOT}/src/test.cpp") + # Assign all include to one variable + set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") + set(INCLUDE_FILES ${SEC_FUZ_INC} "$") - # compile the executables - onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + # add all these include directory to the Fuzzing engine + target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) - # compile with c++17 - target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) + # add link libraries to the project + target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Security fuzzing engine header file reference - onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) + # add the dependencies + add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Assign all include to one variable - set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") - set(INCLUDE_FILES ${SEC_FUZ_INC} "$") + # copy the shared libraries (DLLs on Windows, SOs on Linux) to the execution directory + add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) - # add all these include directory to the Fuzzing engine - target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libProtoBufFuzzer-based fuzzer") - # add link libraries the project - target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtProtoLibfuzzer.cpp") - # add the dependencies - add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_proto_libfuzzer ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_proto_libfuzzer onnx onnxruntime) + # Set include directories for libFuzzer + target_include_directories(onnxruntime_proto_libfuzzer PRIVATE ${INCLUDE_FILES}) - # copy the dlls to the execution directory - add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_proto_libfuzzer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libBufFuzzer-based fuzzer") + + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtLibfuzzer.cpp") + + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_libfuzzer_fuzz ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_libfuzzer_fuzz onnx onnxruntime) + # Set include directories for libFuzzer + target_compile_definitions(onnxruntime_libfuzzer_fuzz PRIVATE GOOGLE_PROTOBUF_NO_LOGGING=1) + target_include_directories(onnxruntime_libfuzzer_fuzz PRIVATE ${INCLUDE_FILES}) + + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_libfuzzer_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + endif() endif() diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp new file mode 100644 index 0000000000000..406aca722bb67 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "fuzzer/FuzzedDataProvider.h" + +Ort::Env env; + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider data_provider(data, size); + onnx::ModelProto msg; + try { + if (!msg.ParseFromArray(data, static_cast(size))) { + return 0; // Ignore invalid inputs + } + predict(msg, data_provider.ConsumeIntegral(), env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." << std::endl; + } + return 0; +} diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp new file mode 100644 index 0000000000000..607d9cfd9c755 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "src/mutator.h" +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "onnx/onnx_pb.h" + +#include + +Ort::Env env; + +std::string wstring_to_string(const std::wstring& wstr) { + std::wstring_convert> converter; + return converter.to_bytes(wstr); +} + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); + + // View the output + // + predict.PrintOutputValues(); +} + +template +using PostProcessor = + protobuf_mutator::libfuzzer::PostProcessorRegistration; + +// Helper function to generate random strings +std::string generate_random_string(size_t length, std::mt19937& rng) { + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::uniform_int_distribution<> dist(0, characters.size() - 1); + std::string result; + for (size_t i = 0; i < length; ++i) { + result += characters[dist(rng)]; + } + return result; +} + +// Helper function to generate random float +float generate_random_float(std::mt19937& rng) { + std::uniform_real_distribution dist(0.0f, 1.0f); + return dist(rng); +} + +// PostProcessor for ONNX ModelProto with random values +static PostProcessor reg1 = { + [](onnx::ModelProto* model_proto, unsigned int seed) { + std::mt19937 rng(seed); + + // Set model's IR version + model_proto->set_ir_version(7); + + model_proto->set_producer_name("onnx"); + model_proto->set_producer_version("7.0"); + model_proto->set_domain("example.com"); + + // Add a dummy opset import + auto* opset_import = model_proto->add_opset_import(); + opset_import->set_version(10); + + // Access the graph from the model + auto* graph = model_proto->mutable_graph(); + + // Set a random name for the graph + graph->set_name(generate_random_string(10, rng)); + }}; + +DEFINE_PROTO_FUZZER(const onnx::ModelProto& msg) { + try { + auto seed = static_cast(std::chrono::system_clock::now().time_since_epoch().count()); + onnx::ModelProto msg_proto = msg; + predict(msg_proto, seed, env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." << std::endl; + } +} From 55ab13e7ca8c5147ff5d7e82da5b6bde01720f7d Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:23:09 -0400 Subject: [PATCH 14/26] [VitisAI] support memory buffer contains the TensorProto external data (#22042) ### Description Extend VitisAI EP `tensor_proto_as_raw` API to support memory buffer containing the TensorProto external data ### Motivation and Context For reduce peak memory usage, VitisAI EP need support ORT format model and setting session option `session.use_ort_model_bytes_for_initializers` for enable directly use the model bytes for initializers. Co-authored-by: mingyue --- .../providers/vitisai/imp/tensor_proto.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 4b2b7610cf7ea..872d022e85264 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -9,9 +9,44 @@ #include "core/providers/shared_library/provider_api.h" namespace vaip { using namespace onnxruntime; + +static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { + auto tensor_proto = const_cast(&tensor); + auto file = std::string(); + uintptr_t offset = 0; + size_t size = 0; + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + auto external_data_size = external_data->size(); + for (auto i = 0; i < external_data_size; ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + if (*data.mutable_key() == "location") { + file = *data.mutable_value(); + } else if (*data.mutable_key() == "offset") { + offset = (uintptr_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "length") { + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "checksum") { + // checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + if (file == "*/_ORT_MEM_ADDR_/*") { + auto addr = reinterpret_cast(offset); + return {addr, size}; + } + } + return {}; +} + gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { + auto maybe_external_memory_address = process_ext_address(tensor); + if (!maybe_external_memory_address.empty()) { + return maybe_external_memory_address; + } + std::vector unpacked_tensor; auto path = graph.ModelPath(); auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor); From f7bf5a19baf0a7caa9cca7dc08bf192e392a14e4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 12 Sep 2024 17:18:50 -0700 Subject: [PATCH 15/26] [QNN EP] Ensure QNN EP rejects nodes with I/O of dynamic shape (#22066) ### Description Updates QNN EP to properly reject nodes that have inputs or outputs with dynamic shapes. ### Motivation and Context Currently, QNN EP does not properly offload subgraphs with dynamic shapes to the CPU EP. This PR ensures that QNN EP rejects nodes that consume or generate I/O with dynamic shapes. --- .../qnn/builder/qnn_model_wrapper.cc | 4 +- .../test/providers/qnn/qnn_basic_test.cc | 57 +++++++++++++++++++ .../test/providers/qnn/qnn_test_utils.cc | 4 +- .../test/providers/qnn/qnn_test_utils.h | 6 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3c029fda9cd52..2c7f3c8b22ddd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -308,8 +308,10 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vectordim()) { + if (!dim.has_dim_value()) { + return false; // Do not support dynamic shapes. + } shape.push_back(SafeInt(dim.dim_value())); } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9d19c36dc94b2..c4367aeb52edc 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -948,6 +948,63 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +// Test that QNN EP only handles nodes with static shapes and rejects nodes with dynamic shape I/O. +TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { + // Local function that builds a model in which the last two nodes use dynamic shapes. + auto model_build_fn = [](ModelTestBuilder& builder) { + NodeArg* input1 = builder.MakeInput(std::vector{1, 2, 8, 8}, + GetFloatDataInRange(0.0f, 1.0f, 128)); + NodeArg* input2 = builder.MakeInput(std::vector{3}, std::vector{1, 2, 49}); + + // Add a Conv with known shapes. QNN EP should support it. + NodeArg* weight = builder.MakeInitializer(std::vector{2, 2, 2, 2}, + GetFloatDataInRange(-0.3f, 0.3f, 16)); + NodeArg* bias = builder.MakeInitializer(std::vector{2}, {0.0f, 1.0f}); + + auto* conv_output = builder.MakeIntermediate(); + builder.AddNode("Conv", {input1, weight, bias}, {conv_output}); + + // Add a Reshape to a dynamic shape. QNN EP should reject this node. + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {conv_output, input2}, {reshape_output}); + + // Add a Softmax. QNN EP should reject this node because its input has a dynamic shape. + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Softmax", {reshape_output}, {output}); + }; + + // Local function that checks that the nodes with dynamic shape I/O were assigned to CPU EP. + std::function ep_graph_checker = [](const Graph& graph) { + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + if (op_type == "Reshape" || op_type == "Softmax") { + EXPECT_EQ(ep_name, kCpuExecutionProvider); + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + } + } + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_fp16_precision"] = "1"; // QNN EP will use fp16 precision. + // CPU EP will use fp32, so we can relax accuracy requirements. + + RunQnnModelTest(model_build_fn, + provider_options, + /*opset*/ 19, + ExpectedEPNodeAssignment::Some, + /*abs_err*/ 1e-4f, + logging::Severity::kERROR, + /*verify_output*/ true, + &ep_graph_checker); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index afaa5a341d5e9..8a4f7f2a1f6b5 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -98,10 +98,12 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) { void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) { + float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, + std::function* ep_graph_checker) { EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; + verification_params.graph_verifier = ep_graph_checker; // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 3a6753e9b6131..bb77c92668853 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1033,12 +1033,16 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. + * \param verify_outputs True to verify that the outputs match (within tolerance). + * \param ep_graph_checker Function called on the Graph generated for the EP's session. Used to check node + * EP assignment. */ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR, - bool verify_outputs = true); + bool verify_outputs = true, + std::function* ep_graph_checker = nullptr); enum class BackendSupport { SUPPORT_UNKNOWN, From 22437b581b8559702fe9f5a5fe2309a495bd9e15 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 12 Sep 2024 22:38:17 -0400 Subject: [PATCH 16/26] [java] Fix for OnnxTensor creation when passing in a ByteBuffer containing elements of a different type (#21774) ### Description Fixes a bug where the buffer offset and position was incorrectly computed if the user supplied a `ByteBuffer` to `createTensor` but set the type of the tensor to something other than `INT8`. This would be more common if the user was trying to load the initializers from a serialized representation and didn't want to bother with the type information (which is the case in #21321). ### Motivation and Context Partial fix for #21321. The remainder of the fix is to add a helper which allows users to load initializers out of an `onnx_data` file, but that will require adding protobuf as a dependency for the Java API to allow the parsing of an ONNX file separately from the native code. It might be nicer to put that functionality into ORT's C API so it can return the lengths & offsets of the initializers when provided with an ONNX file containing external initializers. We hit this kind of thing in Java more often than other languages as in Java models can be supplied as classpath resources which we can easily read, but not materialize on disk for the ORT native library to read. --- .../src/main/java/ai/onnxruntime/OrtUtil.java | 13 ++++++---- .../java/ai/onnxruntime/OnnxTensorTest.java | 26 ++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 5b2e9b2efac4c..4f3dee3c00b91 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -483,9 +483,12 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) { throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer"); } + // This buffer could be a ByteBuffer which is being used to carry data of another type, if so, + // it's type.size should be 1 to compute the correct buffer size and offset. + int elementSize = data instanceof ByteBuffer ? 1 : type.size; int bufferPos; - long bufferSizeLong = data.remaining() * (long) type.size; - if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { + long bufferSizeLong = data.remaining() * (long) elementSize; + if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) { // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending // on the JVM, so we check for something 8 elements below the maximum size which // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. @@ -496,11 +499,11 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { + type); } // Now we know we're in range - int bufferSize = data.remaining() * type.size; + int bufferSize = data.remaining() * elementSize; Buffer tmp; if (data.isDirect()) { tmp = data; - bufferPos = data.position() * type.size; + bufferPos = data.position() * elementSize; } else { // Copy the data to a new direct buffer, then restore the state of the input. int origPosition = data.position(); diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index c060cf73ecf14..ea210d96c1507 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -218,6 +218,30 @@ public void testUint8Creation() throws OrtException { } } + @Test + public void testByteBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder()); + FloatBuffer floatBuf = byteBuf.asFloatBuffer(); + floatBuf.put(1.0f); + floatBuf.put(2.0f); + floatBuf.put(3.0f); + floatBuf.put(4.0f); + floatBuf.put(5.0f); + floatBuf.position(1); + float[] expected = new float[floatBuf.remaining()]; + floatBuf.get(expected); + floatBuf.position(1); + byteBuf.position(4); + try (OnnxTensor t = + OnnxTensor.createTensor( + env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) { + Assertions.assertNotNull(t); + float[] actual = (float[]) t.getValue(); + Assertions.assertArrayEquals(expected, actual); + } + } + @Test public void testEmptyTensor() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); From 904b850b445ccfb3dc935e39b96cfb3dbfb52673 Mon Sep 17 00:00:00 2001 From: Michael Tyler <67695629+MichaelTylerArm@users.noreply.github.com> Date: Fri, 13 Sep 2024 04:51:59 +0100 Subject: [PATCH 17/26] Update Arm Compute Library Execution Provider (#22032) ### Description This PR makes the following updates to the Arm Compute Library execution provider: - Target Arm Compute Library 24.07 - Add support for the following operators: - Conv (FP16) - NhwcConv - QLinearConv - MatMul - FusedMatMul - MatMulIntegerToFloat - Optimize memory usage and performance - Expose the enable_fast_math setting - Use the main runtime thread pool ### Motivation and Context These updates improve performance and memory usage, and enable use of a more recent version of Arm Compute Library. @microsoft-github-policy-service agree company="Arm Ltd" --------- Signed-off-by: Michael Tyler --- cmake/CMakeLists.txt | 39 +- .../core/providers/acl/acl_provider_factory.h | 4 +- .../main/java/ai/onnxruntime/OrtSession.java | 10 +- ...ai_onnxruntime_OrtSession_SessionOptions.c | 8 +- .../core/optimizer/graph_transformer_utils.cc | 33 +- .../core/optimizer/nhwc_transformer.cc | 4 +- .../qdq_selector_action_transformer.cc | 5 +- onnxruntime/core/providers/acl/acl_common.cc | 155 ++++- onnxruntime/core/providers/acl/acl_common.h | 27 +- .../providers/acl/acl_execution_provider.cc | 110 +++- .../providers/acl/acl_execution_provider.h | 30 +- .../providers/acl/acl_provider_factory.cc | 16 +- .../acl/acl_provider_factory_creator.h | 3 +- onnxruntime/core/providers/acl/math/gemm.h | 33 +- onnxruntime/core/providers/acl/math/matmul.cc | 404 ++++++++++++ onnxruntime/core/providers/acl/math/matmul.h | 64 ++ .../core/providers/acl/nn/batch_norm.cc | 5 +- onnxruntime/core/providers/acl/nn/conv.cc | 605 +++++++++++------- onnxruntime/core/providers/acl/nn/conv.h | 57 +- .../core/providers/acl/nn/fused_conv.cc | 9 +- onnxruntime/core/providers/acl/nn/pool.cc | 21 +- onnxruntime/core/providers/acl/scheduler.cc | 44 ++ onnxruntime/core/providers/acl/scheduler.h | 33 + .../core/providers/acl/tensor/concat.cc | 9 +- .../python/onnxruntime_pybind_schema.cc | 3 +- .../python/onnxruntime_pybind_state.cc | 22 +- .../python/onnxruntime_pybind_state_common.h | 3 +- onnxruntime/test/onnx/main.cc | 3 +- .../test/perftest/command_args_parser.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 38 +- onnxruntime/test/providers/cpu/model_tests.cc | 3 +- onnxruntime/test/util/default_providers.cc | 7 +- .../test/util/include/default_providers.h | 3 +- tools/ci_build/build.py | 10 +- 34 files changed, 1396 insertions(+), 426 deletions(-) create mode 100644 onnxruntime/core/providers/acl/math/matmul.cc create mode 100644 onnxruntime/core/providers/acl/math/matmul.h create mode 100644 onnxruntime/core/providers/acl/scheduler.cc create mode 100644 onnxruntime/core/providers/acl/scheduler.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fb3b75fda4eaf..3d4f055bb6f53 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. # Minimum CMake required @@ -132,11 +133,6 @@ option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) option(onnxruntime_USE_ACL "Build with ACL support" OFF) -option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) -option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) -option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) -option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) -option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) @@ -1207,25 +1203,8 @@ function(onnxruntime_add_include_to_target dst_target) endfunction() # ACL -if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) +if (onnxruntime_USE_ACL) set(onnxruntime_USE_ACL ON) - if (onnxruntime_USE_ACL_1902) - add_definitions(-DACL_1902=1) - else() - if (onnxruntime_USE_ACL_1908) - add_definitions(-DACL_1908=1) - else() - if (onnxruntime_USE_ACL_2002) - add_definitions(-DACL_2002=1) - else() - if (onnxruntime_USE_ACL_2308) - add_definitions(-DACL_2308=1) - else() - add_definitions(-DACL_1905=1) - endif() - endif() - endif() - endif() if (NOT ${onnxruntime_ACL_LIBS} STREQUAL "") add_library(arm_compute SHARED IMPORTED) @@ -1233,18 +1212,13 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_graph.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_graph) endif() @@ -1263,11 +1237,6 @@ if (onnxruntime_USE_ARMNN) IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 @@ -1281,7 +1250,7 @@ if (onnxruntime_USE_ARMNN) IMPORTED_LOCATION "${onnxruntime_ARMNN_LIBS}/libarmnn.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_graph) endif() if (onnxruntime_USE_DNNL) diff --git a/include/onnxruntime/core/providers/acl/acl_provider_factory.h b/include/onnxruntime/core/providers/acl/acl_provider_factory.h index 0dc0ec27ff345..8875a83a39f54 100644 --- a/include/onnxruntime/core/providers/acl/acl_provider_factory.h +++ b/include/onnxruntime/core/providers/acl/acl_provider_factory.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "onnxruntime_c_api.h" @@ -10,7 +11,8 @@ extern "C" { /** * \param use_arena zero: false. non-zero: true. */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) ORT_ALL_ARGS_NONNULL; #ifdef __cplusplus diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8ab4a1cb26bb1..8fe73ff69e169 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1181,12 +1182,12 @@ public void addDirectML(int deviceId) throws OrtException { /** * Adds the ARM Compute Library as an execution backend. * - * @param useArena If true use the arena memory allocator. + * @param enableFastMath Enable fast math mode in ACL. * @throws OrtException If there was an error in native code. */ - public void addACL(boolean useArena) throws OrtException { + public void addACL(boolean enableFastMath) throws OrtException { checkClosed(); - addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); + addACL(OnnxRuntime.ortApiHandle, nativeHandle, enableFastMath); } /** @@ -1354,7 +1355,8 @@ private native void addTvm(long apiHandle, long nativeHandle, String settings) private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) throws OrtException; - private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; + private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) + throws OrtException; private native void addArmNN(long apiHandle, long nativeHandle, int useArena) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 337f4c1921c6e..ff9348c299e90 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ #include @@ -644,12 +645,13 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDir * Signature: (JJI)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addACL - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint useArena) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jboolean enableFastMath) { (void)jobj; #ifdef USE_ACL - checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle,useArena)); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle, + OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle, enableFastMath)); #else - (void)apiHandle;(void)handle;(void)useArena; // Parameters used when ACL is defined. + (void)apiHandle;(void)handle;(void)enableFastMath; // Parameters used when ACL is defined. throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with ACL support."); #endif } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 0530ab771e0be..997d99441d36d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/optimizer/graph_transformer_utils.h" @@ -196,6 +197,8 @@ InlinedVector> GenerateTransformers( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; #ifndef DISABLE_CONTRIB_OPS const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; + const InlinedHashSet cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = std::make_shared(); @@ -285,6 +288,11 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_acl_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kDmlExecutionProvider}; const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kAclExecutionProvider, @@ -296,8 +304,9 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider, + onnxruntime::kAclExecutionProvider}; const int64_t qdq_matmulnbits_accuracy_level = ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, @@ -323,26 +332,26 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_dml_eps)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_eps)); transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); // GeluApproximation has side effects which may change results. It needs to be manually enabled, // or alternatively the model can be updated offline using a model conversion script @@ -367,7 +376,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); #endif - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(dml_ep)); #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index e67557dcf9391..ee79fa620374e 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -183,7 +184,8 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, modified = false; for (std::unique_ptr& node : api_graph->Nodes()) { // If the node is not supported in the CPU EP, skip it - if (node->GetExecutionProviderType() != kCpuExecutionProvider) { + const auto ep = node->GetExecutionProviderType(); + if ((ep != kCpuExecutionProvider) && (ep != kAclExecutionProvider)) { continue; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index adfa680878945..1c506bafd1d14 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -381,9 +382,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool, p_buffered_tensors), apply_context, - // this transformer is compatible with CPU, DML and CUDA EP. + // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. - {kCpuExecutionProvider, kDmlExecutionProvider, kCudaExecutionProvider}} { + {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_common.cc b/onnxruntime/core/providers/acl/acl_common.cc index f1ab6682a8259..c8d878a81bd1a 100644 --- a/onnxruntime/core/providers/acl/acl_common.cc +++ b/onnxruntime/core/providers/acl/acl_common.cc @@ -1,5 +1,6 @@ // Copyright(C) 2018 Intel Corporation // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License #ifdef _WIN32 @@ -8,14 +9,45 @@ #include "core/providers/acl/acl_common.h" -#include "arm_compute/runtime/PoolManager.h" -#include "arm_compute/runtime/BlobLifetimeManager.h" - -#undef ACL_1902 - namespace onnxruntime { namespace acl { +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack) { + for (const arm_compute::experimental::MemoryInfo& req : reqs) { + if (req.size == 0) { + continue; + } + + arm_compute::Tensor* aux_tensor; + if (req.lifetime == arm_compute::experimental::MemoryLifetime::Temporary) { + workspace.temporary_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.temporary_tensors.back().get(); + + memory_group.manage(aux_tensor); + } else if (req.lifetime == arm_compute::experimental::MemoryLifetime::Prepare) { + workspace.prepare_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.prepare_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } else { + workspace.persistent_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.persistent_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } + run_pack.add_tensor(req.slot, aux_tensor); + + const auto aux_info = arm_compute::TensorInfo{arm_compute::TensorShape(req.size), 1, arm_compute::DataType::U8}; + aux_tensor->allocator()->init(aux_info, req.alignment); + } + + for (const std::unique_ptr& tensor : workspace.temporary_tensors) { + tensor->allocator()->allocate(); + } +} + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim) { arm_compute::TensorShape shape; unsigned int inDim = tensorShape.NumDimensions(); @@ -36,27 +68,112 @@ arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned return shape; } +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape) { + const auto& inShape = tensor->Shape(); + TensorShapeVector shapeVec; + + for (int i = 0; i < inShape->dim_size(); i++) { + const auto& dim = inShape->dim(i); + ORT_RETURN_IF_NOT(dim.has_dim_value(), "ACL does not support unknown tensor shapes: ", tensor->Name()); + shapeVec.push_back(dim.dim_value()); + } + + outShape = TensorShape(shapeVec); + return Status::OK(); +} + void ACLPrintTensorShape(const char* s, arm_compute::Tensor& t) { for (unsigned int i = 0; i < t.info()->tensor_shape().num_dimensions(); i++) LOGS_DEFAULT(VERBOSE) << "ACL " << s << " " << t.info()->tensor_shape()[i]; LOGS_DEFAULT(VERBOSE) << std::endl; } -std::shared_ptr ACLCreateMemoryManager() { - auto lifetime_mgr = std::make_shared(); - auto pool_mgr = std::make_shared(); - auto mm = std::make_shared(lifetime_mgr, pool_mgr); +arm_compute::DataType ACLDataType(const std::string& dtype) { + if (dtype == "tensor(float)") { + return arm_compute::DataType::F32; + } + if (dtype == "tensor(float16)") { + return arm_compute::DataType::F16; + } + if (dtype == "tensor(bfloat16)") { + return arm_compute::DataType::BFLOAT16; + } + if (dtype == "tensor(uint8)") { + return arm_compute::DataType::QASYMM8; + } + if (dtype == "tensor(int8)") { + return arm_compute::DataType::QASYMM8_SIGNED; + } + if (dtype == "tensor(int32)") { + return arm_compute::DataType::S32; + } + ORT_THROW("ACL execution provider does not support data type ", dtype); +} - return mm; +int GetIntScalar(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor is not a scalar"); + if (tensor->IsDataType()) { + return *tensor->Data(); + } + if (tensor->IsDataType()) { + return *tensor->Data(); + } + ORT_THROW("Unsupported int type: ", DataTypeImpl::ToString(tensor->DataType())); } -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { -#ifdef ACL_1902 - return allocator->import_memory(memory, size); -#else +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint) { + const Tensor* scaleTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(scaleIdx, &scaleTensor), "Scale must be constant"); + + const Tensor* zeroPointTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(zpIdx, &zeroPointTensor), "Zero point must be constant"); + + const float* scale = scaleTensor->Data(); + const int zeroPoint = GetIntScalar(zeroPointTensor); + tensor->info()->set_quantization_info(arm_compute::QuantizationInfo(*scale, flipZeroPoint ? -zeroPoint : zeroPoint)); + + return Status::OK(); +} + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment) { + alignment = 0; + for (auto& tensor : state) { + alignment = std::max(alignment, tensor->allocator()->alignment()); + } + + packedSize = 0; + for (auto& tensor : state) { + const size_t size = tensor->info()->total_size(); + packedSize += ((size - 1) / alignment + 1) * alignment; + } +} + +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment) { + auto buffSize = packedSize + alignment; + uint8_t* alignedPtr = (uint8_t*)(alignment == 0 ? packed : std::align(alignment, packedSize, packed, buffSize)); + + uint8_t* currentPtr = alignedPtr; + for (auto& tensor : state) { + ORT_RETURN_IF_ERROR(ACLImportMemory(tensor->allocator(), currentPtr, 0)); + + const size_t size = tensor->info()->total_size(); + currentPtr += ((size - 1) / alignment + 1) * alignment; + } + + return Status::OK(); +} + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { ORT_UNUSED_PARAMETER(size); - return allocator->import_memory(memory); -#endif + arm_compute::Status status = allocator->import_memory(memory); + + if (status) { + return Status::OK(); + } else { + return Status(common::ONNXRUNTIME, common::FAIL, status.error_description()); + } } template @@ -71,12 +188,13 @@ void importDataToTensor(arm_compute::Tensor* tensor, const T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - *reinterpret_cast(aclInputIt.ptr()) = data[index]; + *reinterpret_cast(aclInputIt.ptr()) = data[index]; index++; }, aclInputIt); } template void importDataToTensor(arm_compute::Tensor*, const float*); +template void importDataToTensor(arm_compute::Tensor*, const MLFloat16*); template void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { @@ -89,12 +207,13 @@ void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - data[index] = *reinterpret_cast(aclInputIt.ptr()); + data[index] = *reinterpret_cast(aclInputIt.ptr()); index++; }, aclInputIt); } template void importDataFromTensor(arm_compute::Tensor*, float*); +template void importDataFromTensor(arm_compute::Tensor*, MLFloat16*); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_common.h b/onnxruntime/core/providers/acl/acl_common.h index 899736c477165..f2e89de15efd9 100644 --- a/onnxruntime/core/providers/acl/acl_common.h +++ b/onnxruntime/core/providers/acl/acl_common.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -7,17 +8,37 @@ #include "core/framework/op_kernel.h" // ACL +#include "arm_compute/core/experimental/Types.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" -#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { namespace acl { +struct Workspace { + std::vector> temporary_tensors; + std::vector> prepare_tensors; + std::vector> persistent_tensors; +}; + +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack); + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim = 0); +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape); void ACLPrintTensorShape(const char*, arm_compute::Tensor& t); -std::shared_ptr ACLCreateMemoryManager(); -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); +arm_compute::DataType ACLDataType(const std::string& dtype); + +int GetIntScalar(const Tensor* tensor); +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint); + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment); +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment); + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); template void importDataToTensor(arm_compute::Tensor* tensor, const T* data); template diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index d19dc15e17f6d..8d34e36fe7cd6 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "acl_execution_provider.h" @@ -7,13 +8,19 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/nn/conv.h" +#include "core/session/inference_session.h" #include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "acl_fwd.h" +#include "scheduler.h" -namespace onnxruntime { +#include "arm_compute/runtime/Scheduler.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/Allocator.h" -constexpr const char* ACL = "Acl"; -constexpr const char* ACL_CPU = "AclCpu"; +namespace onnxruntime { namespace acl { @@ -22,7 +29,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 6, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Gemm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 8, 11, float, MaxPool); @@ -39,6 +47,22 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Concat); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, NhwcConv); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, MatMulIntegerToFloat); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, int8_t, QLinearConv); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, QLinearConv); Status RegisterACLKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -48,7 +72,8 @@ Status RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -67,6 +92,22 @@ Status RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -85,10 +126,22 @@ std::shared_ptr GetAclKernelRegistry() { return kernel_registry; } +std::shared_ptr ACLCreateMemoryManager() { + auto lifetime_mgr = std::make_shared(); + auto pool_mgr = std::make_shared(); + auto mm = std::make_shared(lifetime_mgr, pool_mgr); + + return mm; +} + } // namespace acl -ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo&) - : IExecutionProvider{onnxruntime::kAclExecutionProvider} {} +ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kAclExecutionProvider}, + info(info), + memory_manager(onnxruntime::acl::ACLCreateMemoryManager()) { + arm_compute::Scheduler::set(std::make_shared(this)); +} ACLExecutionProvider::~ACLExecutionProvider() {} @@ -97,4 +150,47 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const return kernel_registry; } +std::vector> +ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + std::vector> result; + for (const auto& node : graph.Nodes()) { + if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); + kernel_create_info != nullptr) { + Status support_status = Status::OK(); + const std::string op_name = kernel_create_info->kernel_def->OpName(); + + if (op_name == "Conv" || op_name == "NhwcConv" || op_name == "QLinearConv") { + support_status = onnxruntime::acl::ValidateConv(node); + } + if (op_name == "MatMul" || op_name == "FusedMatMul" || op_name == "MatMulIntegerToFloat") { + support_status = onnxruntime::acl::ValidateMatMul(node); + } + + if (support_status.IsOK()) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node.Index()); + result.push_back(std::make_unique(std::move(sub_graph))); + } else { + LOGS_DEFAULT(WARNING) << "ACL supports operator " << op_name + << ", but not with these parameters. Using fallback for node: " << node.Name() + << " Reason: " << support_status.ErrorMessage(); + } + } + } + + return result; +} + +Status ACLExecutionProvider::OnRunStart(const onnxruntime::RunOptions&) { + arm_compute::Allocator alloc{}; + memory_manager->populate(alloc, 1); + return Status::OK(); +}; + +Status ACLExecutionProvider::OnRunEnd(bool, const onnxruntime::RunOptions&) { + memory_manager->clear(); + return Status::OK(); +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index 126656e0956bb..1c267d8713673 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -1,20 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/framework/execution_provider.h" #include "core/graph/constants.h" +#include "core/platform/threadpool.h" + +#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { // Information needed to construct ACL execution providers. struct ACLExecutionProviderInfo { - bool create_arena{true}; + bool enable_fast_math{false}; - explicit ACLExecutionProviderInfo(bool use_arena) - : create_arena(use_arena) {} + explicit ACLExecutionProviderInfo(bool enable_fast_math) + : enable_fast_math(enable_fast_math) {} ACLExecutionProviderInfo() = default; }; @@ -31,6 +35,26 @@ class ACLExecutionProvider : public IExecutionProvider { } std::shared_ptr GetKernelRegistry() const override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const override; + + Status OnRunStart(const onnxruntime::RunOptions&) override; + + Status OnRunEnd(bool, const onnxruntime::RunOptions&) override; + + void SetThreadPool(concurrency::ThreadPool* thread_pool) { + thread_pool_ = thread_pool; + } + + concurrency::ThreadPool* GetThreadPool() const { + return thread_pool_; + } + + const ACLExecutionProviderInfo info; + const std::shared_ptr memory_manager; + concurrency::ThreadPool* thread_pool_ = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_provider_factory.cc b/onnxruntime/core/providers/acl/acl_provider_factory.cc index 4eb11b222e576..26a41afeeee36 100755 --- a/onnxruntime/core/providers/acl/acl_provider_factory.cc +++ b/onnxruntime/core/providers/acl/acl_provider_factory.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/providers/acl/acl_provider_factory.h" @@ -11,27 +12,28 @@ namespace onnxruntime { struct ACLProviderFactory : IExecutionProviderFactory { - ACLProviderFactory(bool create_arena) : create_arena_(create_arena) {} + ACLProviderFactory(bool enable_fast_math) : enable_fast_math_(enable_fast_math) {} ~ACLProviderFactory() override {} std::unique_ptr CreateProvider() override; private: - bool create_arena_; + bool enable_fast_math_; }; std::unique_ptr ACLProviderFactory::CreateProvider() { ACLExecutionProviderInfo info; - info.create_arena = create_arena_; + info.enable_fast_math = enable_fast_math_; return std::make_unique(info); } -std::shared_ptr ACLProviderFactoryCreator::Create(int use_arena) { - return std::make_shared(use_arena != 0); +std::shared_ptr ACLProviderFactoryCreator::Create(bool enable_fast_math) { + return std::make_shared(enable_fast_math); } } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) { - options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(use_arena)); +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) { + options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)); return nullptr; } diff --git a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h index 2eee50ee710da..31a596f2d4bbc 100644 --- a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h +++ b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -10,7 +11,7 @@ namespace onnxruntime { struct ACLProviderFactoryCreator { - static std::shared_ptr Create(int use_arena); + static std::shared_ptr Create(bool enable_fast_math); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h index f5288d7f231b0..5db2372705184 100644 --- a/onnxruntime/core/providers/acl/math/gemm.h +++ b/onnxruntime/core/providers/acl/math/gemm.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -28,7 +29,6 @@ namespace acl { typedef struct { std::shared_ptr layer; std::shared_ptr a, b, c, d; - std::shared_ptr mm_layer; } ACLNEGEMM; typedef std::map::iterator GEMMLayersIterator; @@ -37,6 +37,9 @@ template class Gemm : public onnxruntime::Gemm { public: Gemm(const OpKernelInfo& info) : onnxruntime::Gemm(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + int64_t temp; ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); @@ -49,12 +52,11 @@ class Gemm : public onnxruntime::Gemm { } Status Compute(OpKernelContext* context) const override { -#ifdef ACL_2308 if (this->packed_b_) { // Prepacked RHS not supported, defaulting to cpu execution provider return onnxruntime::Gemm::Compute(context); } -#endif + const auto A = context->Input(0); const auto B = context->Input(1); const auto C = context->Input(2); @@ -96,19 +98,20 @@ class Gemm : public onnxruntime::Gemm { (cShape[1] == 1 && cShape[0] != (long unsigned int)N)) { return onnxruntime::Gemm::Compute(context); } -#ifdef ACL_2308 cShape = arm_compute::TensorShape(N); LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}"; -#else - cShape = arm_compute::TensorShape(1, N); - LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}"; -#endif } int64_t K = helper.K(); - if (A) LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); - if (B) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - if (C) LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + if (A) { + LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); + } + if (B) { + LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + } + if (C) { + LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + } LOGS_DEFAULT(VERBOSE) << "D " << D->Shape().ToString().c_str(); LOGS_DEFAULT(VERBOSE) << "M " << (int)M << ", N " << (int)N << ", K " << (int)K; LOGS_DEFAULT(VERBOSE) << "Alfa " << alpha_ << ", Beta " << beta_; @@ -131,10 +134,8 @@ class Gemm : public onnxruntime::Gemm { // dimensions are stored in the opposite order to ACL's tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), arm_compute::Format::F32)); - tGEMM.mm_layer = ACLCreateMemoryManager(); - if (FC) { - auto layer = std::make_shared(tGEMM.mm_layer); + auto layer = std::make_shared(provider_->memory_manager); arm_compute::FullyConnectedLayerInfo fc_info; fc_info.transpose_weights = trans_B_ == CblasTrans; layer->configure(tGEMM.a.get(), tGEMM.b.get(), useC ? tGEMM.c.get() : nullptr, tGEMM.d.get(), fc_info); @@ -173,10 +174,7 @@ class Gemm : public onnxruntime::Gemm { ACLPrintTensorShape("c", *pGEMM->c); ACLPrintTensorShape("d", *pGEMM->d); - arm_compute::Allocator alloc_mm{}; - pGEMM->mm_layer->populate(alloc_mm, 1); pGEMM->layer->run(); - pGEMM->mm_layer->clear(); if (D->Shape().Size() != 0 && pGEMM->d->info()->has_padding()) { importDataFromTensor(pGEMM->d.get(), d_data); @@ -195,6 +193,7 @@ class Gemm : public onnxruntime::Gemm { } private: + ACLExecutionProvider* provider_; static thread_local std::map gemmLayers; CBLAS_TRANSPOSE trans_A_; diff --git a/onnxruntime/core/providers/acl/math/matmul.cc b/onnxruntime/core/providers/acl/math/matmul.cc new file mode 100644 index 0000000000000..468b394471c13 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.cc @@ -0,0 +1,404 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include +#include +#include "core/common/status.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/framework/tensor_shape.h" +#ifdef _WIN32 +#pragma warning(disable : 4244) +#endif +#include +#include + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" + +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_fwd.h" + +// ACL +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/NEON/functions/NEMatMul.h" +#include "src/cpu/operators/CpuGemm.h" +#include "src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h" +#include "src/cpu/operators/CpuMatMul.h" + +namespace onnxruntime { + +namespace acl { + +TensorShape BroadcastInput(const TensorShape& shape, bool prependDim) { + const auto nd = shape.NumDimensions(); + if (nd == 0) { + ORT_THROW("MatMul by scalar not allowed"); + } + + int64_t batchSize = 1; + if (nd == 1) { + if (prependDim) { + return {1, 1, shape[0]}; + } else { + return {1, shape[0], 1}; + } + } + + for (size_t i = 0; i < nd - 2; i++) { + batchSize *= shape[i]; + } + + return {batchSize, shape[nd - 2], shape[nd - 1]}; +} + +struct MatMulConfig { + bool isQuantized; + float alpha; + bool transA; + bool transB; + TensorShape aShapeBroadcast; + TensorShape bShapeBroadcast; +}; + +Status ParseMatMul(const onnxruntime::Node& node, MatMulConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "MatMulIntegerToFloat"; + + config.alpha = 1; + attrs.GetAttr("alpha", &config.alpha); + + int64_t transA = 0; + attrs.GetAttr("transA", &transA); + int64_t transB = 0; + attrs.GetAttr("transB", &transB); + + config.transA = transA; + config.transB = transB; + + const int64_t transBatchA = attrs.GetAttrOrDefault("transBatchA", 0); + const int64_t transBatchB = attrs.GetAttrOrDefault("transBatchB", 0); + + ORT_RETURN_IF(transBatchA, "transBatchA not supported by ACL"); + ORT_RETURN_IF(transBatchB, "transBatchB not supported by ACL"); + + ORT_RETURN_IF(config.isQuantized && inputDefs.size() >= 7, "ACL MatMulIntegerToFloat does not support bias"); + + TensorShape aShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], aShapeIn)); + + TensorShape bShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[1], bShapeIn)); + + config.aShapeBroadcast = BroadcastInput(aShapeIn, !config.transA); + config.bShapeBroadcast = BroadcastInput(bShapeIn, config.transB); + + ORT_RETURN_IF(!(config.bShapeBroadcast[0] == 1 || (config.aShapeBroadcast[0] == config.bShapeBroadcast[0])), + "ACL does not support broadcasting"); + + ORT_RETURN_IF(config.alpha != 1 && config.bShapeBroadcast[0] > 1, + "ACL does not support alpha scaling with batched B"); + + return Status::OK(); +} + +Status ValidateMatMul(const onnxruntime::Node& node) { + MatMulConfig config; + return ParseMatMul(node, config); +} + +MatMul::MatMul(const OpKernelInfo& info) : onnxruntime::OpKernel(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + const auto inputDefs = OpKernel::Node().InputDefs(); + const auto outputDefs = OpKernel::Node().OutputDefs(); + + const Tensor* tmp = nullptr; + const bool aIsConst = info.TryGetConstantInput(0, &tmp); + const bool bIsConst = info.TryGetConstantInput(1, &tmp); + + MatMulConfig config; + ORT_THROW_IF_ERROR(ParseMatMul(OpKernel::Node(), config)); + + ORT_THROW_IF_ERROR(GetArgShape(outputDefs[0], outShape)); + if (outShape.Size() == 0) { + return; + } + + const TensorShape aShape{ + config.aShapeBroadcast[0], + config.aShapeBroadcast[config.transA ? 2 : 1], + config.aShapeBroadcast[config.transA ? 1 : 2]}; + + const TensorShape bShape{ + config.bShapeBroadcast[0], + config.bShapeBroadcast[config.transB ? 2 : 1], + config.bShapeBroadcast[config.transB ? 1 : 2]}; + + const TensorShape outShapeBroadcast{aShape[0], aShape[1], bShape[2]}; + + ORT_ENFORCE(outShape.Size() == outShapeBroadcast.Size(), "Output sizes do not match"); + + arm_compute::DataType aType = ACLDataType(*inputDefs[0]->Type()); + arm_compute::DataType bType = ACLDataType(*inputDefs[1]->Type()); + arm_compute::DataType outType = ACLDataType(*outputDefs[0]->Type()); + + arm_compute::GEMMInfo gemmInfo(false, false, bIsConst); + gemmInfo.set_fast_math(provider_->info.enable_fast_math); + + a = std::make_shared(); + b = std::make_shared(); + out = std::make_shared(); + + a->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.aShapeBroadcast), 1, aType)); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.bShapeBroadcast), 1, bType)); + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeBroadcast), 1, outType)); + + if (config.isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, a.get(), 2, 4, true)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, b.get(), 3, 5, true)); + } + + arm_compute::ITensor* a_to_use = a.get(); + if (config.transA) { + a_transposed = std::make_shared(); + a_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(aShape), 1, aType)); + a_to_use = a_transposed.get(); + + a_permute = std::make_shared(); + a_permute->configure(a.get(), a_transposed.get(), {1, 0, 2}); + } + + arm_compute::ITensor* b_to_use = b.get(); + if (config.transB) { + if (bIsConst) { + workspace.persistent_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.persistent_tensors.back().get(); + } else { + workspace.temporary_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.temporary_tensors.back().get(); + } + + b_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType), 128); + b_to_use = b_transposed; + + b_permute = std::make_shared(); + b_permute->configure(b.get(), b_transposed, {1, 0, 2}); + } + + a_to_use->info()->set_are_values_constant(aIsConst); + b_to_use->info()->set_are_values_constant(bIsConst); + + if (config.bShapeBroadcast[0] > 1) { + arm_compute::CpuMatMulSettings settings; + settings.fast_math(provider_->info.enable_fast_math); + + a_to_use->info()->set_are_values_constant(false); + b_to_use->info()->set_are_values_constant(false); + + const auto matmul = std::make_shared(); + matmul->configure(a_to_use->info(), b_to_use->info(), out->info(), {}, settings, {}); + layer = std::move(matmul); + } else if (config.isQuantized) { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), gemmInfo); + layer = std::move(gemm); + } else { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), config.alpha, 0.f, gemmInfo); + layer = std::move(gemm); + } + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, a_to_use}, {arm_compute::ACL_SRC_1, b_to_use}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, b_to_use}}; + + PopulateWorkspace(layer->workspace(), workspace, memory_group, run_pack, prep_pack); +} + +Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != 1 || outShape.Size() == 0) { + return Status::OK(); + } + + const uint8_t* data = (uint8_t*)tensor.DataRaw(); + + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)data, 0)); + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pbRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pbRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pbRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); + } + + is_packed = true; + } + + if (b_transposed) { + b_permute->run(); + } + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); + } + + layer->prepare(prep_pack); + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); + } + + return Status::OK(); +} + +Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != 1) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), packedSize, alignment)); + + used_shared_buffers = true; + } + + return Status::OK(); +} + +Status MatMul::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); + + const Tensor* A = context->Input(0); + const Tensor* B = pbRaw ? nullptr : context->Input(1); + + Tensor* outOrt = context->Output(0, outShape); + + if (outShape.Size() == 0) { + return Status::OK(); + } + + const void* a_data = A->DataRaw(); + const void* b_data = B == nullptr ? nullptr : B->DataRaw(); + void* out_data = outOrt->MutableDataRaw(); + + ORT_RETURN_IF(A->Shape().Size() != 0 && a->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(a->allocator(), (void*)a_data, 0)); + + if (b_data != nullptr) { + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_RETURN_IF(outOrt->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)out_data, 0)); + + ORT_RETURN_IF(B != nullptr && workspace.persistent_tensors.size(), "Persistent state requires pre-packing"); + + if (a_transposed) { + a_transposed->allocator()->allocate(); + a_permute->run(); + } + + { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + if (b_transposed && B) { + b_permute->run(); + } + + layer->run(const_cast(run_pack)); + } + + a->allocator()->free(); + if (B != nullptr) + b->allocator()->free(); + out->allocator()->free(); + + if (a_transposed) { + a_transposed->allocator()->free(); + } + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/matmul.h b/onnxruntime/core/providers/acl/math/matmul.h new file mode 100644 index 0000000000000..b137e33833de9 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.h @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once +#include "core/framework/op_kernel.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_execution_provider.h" + +// ACL +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/runtime/TensorAllocator.h" +#include "arm_compute/runtime/Allocator.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/MemoryManagerOnDemand.h" + +// NEON +#include "arm_compute/runtime/NEON/functions/NEGEMM.h" +#include "arm_compute/runtime/NEON/functions/NEPermute.h" + +namespace onnxruntime { +namespace acl { + +Status ValidateMatMul(const onnxruntime::Node& node); + +class MatMul : public OpKernel { + public: + explicit MatMul(const OpKernelInfo& info); + + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; + + Status Compute(OpKernelContext* context) const override; + + protected: + ACLExecutionProvider* provider_; + std::shared_ptr a_permute; + std::shared_ptr b_permute; + std::shared_ptr layer; + + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr a; + std::shared_ptr b; + std::shared_ptr a_transposed; + arm_compute::Tensor* b_transposed = nullptr; + std::shared_ptr out; + arm_compute::Tensor* pb; + + IAllocatorUniquePtr pbRaw; + TensorShape outShape; +}; +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc index be0e57c5c0543..192bc34556eef 100755 --- a/onnxruntime/core/providers/acl/nn/batch_norm.cc +++ b/onnxruntime/core/providers/acl/nn/batch_norm.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/common/common.h" @@ -80,7 +81,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { auto layer = std::make_shared(); -#ifdef ACL_2308 arm_compute::TensorShape in_x_shape; const TensorShape& x_shape = X->Shape(); const auto& dims_vec = x_shape.GetDims(); @@ -94,9 +94,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { in_x_shape.set(2, onnxruntime::narrow(dims_vec[1])); // C tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32)); -#else - tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32)); -#endif tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32)); tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32)); diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index 85bd0cfe96279..a62158f1c26ee 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #ifdef _WIN32 @@ -19,31 +20,81 @@ // ACL #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/ITensorPack.h" +#include "src/cpu/operators/CpuConv2d.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" -#ifdef ACL_1902 -#include "arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h" -#endif -#if defined(ACL_1905) || defined(ACL_1908) -#include "arm_compute/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.h" -#endif - #define CONV_ACL #undef DEPTHWISE_CPU #define PREF_DIM 4 namespace onnxruntime { + namespace acl { -template -thread_local std::map Conv::convLayers; +struct ConvConfig { + bool isQuantized; + bool is_channels_last; + bool isDepthwise; + TensorShape inShapeIn; + TensorShape kShapeIn; + const std::string* inType; + const std::string* kType; +}; + +Status ParseConv(const onnxruntime::Node& node, ConvConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "QLinearConv"; + + if (config.isQuantized) { + TensorShape scaleShape; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[4], scaleShape)); + ORT_RETURN_IF(scaleShape.Size() > 1, "ACL execution provider does not support per-channel quantization"); + } + + config.is_channels_last = node.OpType() == "NhwcConv"; + if (!config.is_channels_last) { + int64_t cl_ret = 0; + attrs.GetAttr("channels_last", &cl_ret); + config.is_channels_last = (bool)cl_ret; + } -template -arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { + int64_t group = 1; + attrs.GetAttr("group", &group); + + const NodeArg* kDef = inputDefs[config.isQuantized ? 3 : 1]; + + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], config.inShapeIn)); + ORT_RETURN_IF_ERROR(GetArgShape(kDef, config.kShapeIn)); + + ORT_RETURN_IF(config.kShapeIn.NumDimensions() > 4, "ACL execution provider supports 1D and 2D Conv only"); + + config.inType = inputDefs[0]->Type(); + config.kType = kDef->Type(); + const bool mixedType = config.inType != config.kType; + + config.isDepthwise = group > 1; + if (config.isDepthwise) { + const size_t channels = config.inShapeIn[config.is_channels_last ? config.inShapeIn.NumDimensions() - 1 : 1]; + ORT_RETURN_IF(group != channels, "ACL does not support grouping unless group == channels"); + ORT_RETURN_IF(mixedType, "ACL does not support mixed input types for depthwise Conv"); + } + + return Status::OK(); +} + +Status ValidateConv(const onnxruntime::Node& node) { + ConvConfig config; + return ParseConv(node, config); +} + +arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { arm_compute::TensorShape shape = arm_compute::TensorShape(kernel->info()->tensor_shape()); shape[2] = shape[2] * shape[3]; shape[3] = 1; @@ -51,43 +102,89 @@ arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor return shape; } -#ifdef CONV_ACL -template -Status Conv::Compute(OpKernelContext* context) const { +Conv::Conv(const OpKernelInfo& info) : onnxruntime::OpKernel(info), conv_attrs_(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + ConvConfig config; + ORT_THROW_IF_ERROR(ParseConv(OpKernel::Node(), config)); + isQuantized = config.isQuantized; + is_channels_last = config.is_channels_last; + size_t num_inputs = OpKernel::Node().InputDefs().size(); + has_bias = isQuantized ? (num_inputs == 9) : (num_inputs == 3); - ACLNEConv* pConv; - ConvLayersIterator it = Conv::convLayers.find((OpKernel*)this); - if (it != Conv::convLayers.end()) { - pConv = &it->second; - if (pConv->isDepthwiseCPU == true) { - Status s = onnxruntime::Conv::Compute(context); - return s; - } + const Tensor* tmp = nullptr; + const bool kIsConst = info.TryGetConstantInput(1, &tmp); + ORT_ENFORCE(kIsConst, "ACL does not support Conv with mutable weights"); + + in = std::make_shared(); + k = std::make_shared(); + if (has_bias) + b = std::make_shared(); + out = std::make_shared(); + + const arm_compute::DataLayout data_layout = is_channels_last ? arm_compute::DataLayout::NHWC : arm_compute::DataLayout::NCHW; + + TensorShape inShape = config.inShapeIn; + if (is_channels_last && config.inShapeIn.NumDimensions() < 4) { + inShape = TensorShape({config.inShapeIn[0], config.inShapeIn[1], 1, config.inShapeIn[2]}); } - const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; + arm_compute::DataType inType = ACLDataType(*config.inType); + in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(inShape, PREF_DIM), 1, inType, data_layout)); + + arm_compute::DataType kType = ACLDataType(*config.kType); + + TensorShapeVector kShapeVec = config.kShapeIn.AsShapeVector(); + while (kShapeVec.size() < 4) { + kShapeVec.push_back(1); + } + + const TensorShape kShape = is_channels_last ? TensorShape({kShapeVec[0], kShapeVec[2], kShapeVec[3], kShapeVec[1]}) : TensorShape(kShapeVec); + + k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(kShape), 1, kType, data_layout)); + + TensorShape bShape; + if (has_bias) { + const Tensor* bias = nullptr; + const bool biasIsConst = info.TryGetConstantInput(isQuantized ? 8 : 2, &bias); + ORT_ENFORCE(biasIsConst, "ACL does not support Conv with mutable bias"); + + const auto bDef = OpKernel::Node().InputDefs()[isQuantized ? 8 : 2]; + ORT_THROW_IF_ERROR(GetArgShape(bDef, bShape)); + arm_compute::DataType bType = ACLDataType(*bDef->Type()); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType, data_layout)); + + const void* b_data = bias->DataRaw(); + ORT_THROW_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_THROW_IF_ERROR(GetArgShape(OpKernel::Node().OutputDefs()[0], outShape)); + TensorShape outShapeACL = outShape; + if (is_channels_last && outShape.NumDimensions() < 4) { + outShapeACL = TensorShape({outShape[0], outShape[1], 1, outShape[2]}); + } + + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeACL, PREF_DIM), 1, inType, data_layout)); - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; + if (isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, in.get(), 1, 2, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, k.get(), 4, 5, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, out.get(), 6, 7, false)); + } LOGS_DEFAULT(VERBOSE) << "Conv ACL:"; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - - if (X->Shape().NumDimensions() != PREF_DIM) { - LOGS_DEFAULT(WARNING) << "ACL does not have support for tensors with 4 or more dimensions; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; + LOGS_DEFAULT(VERBOSE) << "X " << inShape.ToString().c_str(); + LOGS_DEFAULT(VERBOSE) << "W " << config.kShapeIn.ToString().c_str(); + if (has_bias) { + LOGS_DEFAULT(VERBOSE) << "B " << bShape.ToString().c_str(); } - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); + ORT_THROW_IF_ERROR(conv_attrs_.ValidateInputShape(config.inShapeIn, config.kShapeIn, config.is_channels_last)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_THROW_IF_ERROR(conv_attrs_.ComputeKernelShape(config.kShapeIn, kernel_shape)); ConvAttributes::ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { @@ -102,16 +199,13 @@ Status Conv::Compute(OpKernelContext* context) const { strides.resize(kernel_shape.size(), 1); } - TensorShapeVector Y_dims; - Y_dims.insert(Y_dims.begin(), {N, M}); - TensorShape input_shape = X->Shape().Slice(2); -#ifdef ACL_2308 - ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#else - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#endif - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str(); + TensorShape input_shape = config.inShapeIn.Slice(2); + TensorShapeVector out_shape; + ORT_THROW_IF_ERROR(conv_attrs_.InferPadsAndOutputShape( + input_shape, kernel_shape, strides, dilations, + pads, out_shape)); + + LOGS_DEFAULT(VERBOSE) << "Y " << outShape.ToString().c_str(); arm_compute::ActivationLayerInfo::ActivationFunction acl_activ_func; bool acl_activ_enabled = false; @@ -136,243 +230,274 @@ Status Conv::Compute(OpKernelContext* context) const { ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_type); } - if (it == Conv::convLayers.end()) { - auto mm_layer = ACLCreateMemoryManager(); - - ACLNEConv tconv; - tconv.mm_layer = std::move(mm_layer); + const size_t idx_channel = arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); + isDepthwiseCPU = config.isDepthwise; - tconv.in = std::make_shared(); - tconv.k = std::make_shared(); - if (B != nullptr) - tconv.b = std::make_shared(); - tconv.out = std::make_shared(); + std::vector aclStrides(2); + aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; + aclStrides[1] = strides[0]; - tconv.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape(), PREF_DIM), arm_compute::Format::F32)); - tconv.k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - if (B != nullptr) { - tconv.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(B->Shape()), arm_compute::Format::F32)); - } - tconv.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32)); - - const arm_compute::DataLayout data_layout = tconv.in->info()->data_layout(); - const int idx_channel = arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); - bool isDepthwise = (conv_attrs_.group > 1 && conv_attrs_.group == tconv.in->info()->tensor_shape()[idx_channel]); - tconv.isDepthwiseCPU = isDepthwise; - - std::vector aclStrides(2); - aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; - aclStrides[1] = strides[0]; - - std::vector aclPads(4); - // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom - if (pads.size() == 2) { - if (strides.size() == 1) { - aclPads[0] = 0; - aclPads[1] = 0; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } else { - aclPads[0] = pads[1]; - aclPads[1] = pads[0]; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } + std::vector aclPads(4); + // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom + if (pads.size() == 2) { + if (strides.size() == 1) { + aclPads[0] = 0; + aclPads[1] = 0; + aclPads[2] = pads[0]; + aclPads[3] = pads[1]; } else { aclPads[0] = pads[1]; - aclPads[1] = pads[3]; - aclPads[2] = pads[0]; - aclPads[3] = pads[2]; + aclPads[1] = pads[0]; + aclPads[2] = pads[1]; + aclPads[3] = pads[0]; } + } else { + aclPads[0] = pads[1]; + aclPads[1] = pads[3]; + aclPads[2] = pads[0]; + aclPads[3] = pads[2]; + } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); - unsigned int aclDilation0 = (dilations.size() == 2) ? dilations[1] : 1; - - LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; - LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; - - if (isDepthwise) { - LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; -#ifdef DEPTHWISE_CPU - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; -#else - tconv.k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(tconv.k.get())); - - // in the configure function for NEDepthwiseConvolutionLayer3x3, there is a separation based on the optimization -#ifdef ACL_1902 - bool optimizable = - arm_compute::NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(tconv.in->info()->tensor_shape(), - aclPadStride, - tconv.in->info()->data_type(), - 1 /* depth multiplier */, - tconv.in->info()->data_layout()); -#elif defined(ACL_1905) || defined(ACL_1908) - bool optimizable = - arm_compute::NEDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(tconv.in->info(), - tconv.k->info(), - aclPadStride, - 1 /* depth multiplier */, - arm_compute::Size2D(aclDilation0, dilations[0])); -#elif defined(ACL_2002) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#elif defined(ACL_2308) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#endif + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + size_t aclDilation0 = (dilations.size() == 2) ? dilations[1] : 1; + + LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; + LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; + + if (config.isDepthwise) { + LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; + k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(k.get())); + auto dl = std::make_shared(); + dl->configure(in.get(), k.get(), (has_bias) ? b.get() : nullptr, out.get(), + aclPadStride, 1 /* depth multiplier */, + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + arm_compute::Size2D(aclDilation0, dilations[0])); + depthwise_layer = std::move(dl); + isDepthwiseCPU = false; + } else { + LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; + auto cl = std::make_shared(); + cl->configure(in->info(), k->info(), (has_bias) ? b->info() : nullptr, out->info(), + aclPadStride, + arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + provider_->info.enable_fast_math, (unsigned int)conv_attrs_.group); + conv_layer = std::move(cl); + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, in.get()}, {arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}}; + + PopulateWorkspace(conv_layer->workspace(), workspace, memory_group, run_pack, prep_pack); + } - if (optimizable) { - LOGS_DEFAULT(VERBOSE) << "ACL optimized depthwise convolution"; -#if defined(ACL_1902) || defined(ACL_1905) - auto layer = std::make_shared(); -#elif defined(ACL_1908) - auto layer = std::make_shared(); -#elif defined(ACL_2002) || defined(ACL_2308) - auto layer = std::make_shared(); -#endif + ACLPrintTensorShape("X", *in.get()); + ACLPrintTensorShape("Y", *out.get()); +} -#ifdef ACL_1902 - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo()); -#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308) - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0])); -#endif - tconv.layer = std::move(layer); - tconv.isDepthwiseCPU = false; - } else { - LOGS_DEFAULT(VERBOSE) << "CPU depthwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; - } -#endif // DEPTHWISE_CPU - } else { - if (tconv.k->info()->tensor_shape()[0] == 1 && tconv.k->info()->tensor_shape()[1] == 1) { - LOGS_DEFAULT(VERBOSE) << "CPU pointwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } else { - if (tconv.k->info()->tensor_shape()[0] == 9 && tconv.k->info()->tensor_shape()[1] == 9) { - LOGS_DEFAULT(WARNING) << "9x9 DirectConvolution does not have an implementation in NCHW layout; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } - LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; - auto layer = std::make_shared(mm_layer); - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, - arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - false, conv_attrs_.group); - tconv.layer = std::move(layer); - } +#ifdef CONV_ACL +Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (isQuantized ? (input_idx != 3) : (input_idx != 1)) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pkRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pkRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pkRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); } - tconv.out->info()->set_format(tconv.in->info()->format()); + is_packed = true; + } - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - pConv = &ret.first->second; + bool free_k = false; + const void* k_data = tensor.DataRaw(); + if (is_channels_last) { + TensorShape shape = tensor.Shape(); + if (shape.NumDimensions() < 4) { + shape = TensorShape({shape[0], shape[1], shape[2], 1}); + } - ACLPrintTensorShape("X", *tconv.in.get()); - ACLPrintTensorShape("Y", *tconv.out.get()); + arm_compute::Tensor kIn; + kIn.allocator()->init(arm_compute::TensorInfo(ACLTensorShape(shape), 1, + k->info()->data_type(), arm_compute::DataLayout::NCHW)); + kIn.info()->set_quantization_info(k->info()->quantization_info()); + ORT_RETURN_IF_ERROR(ACLImportMemory(kIn.allocator(), (void*)k_data, 0)); + k->allocator()->allocate(); + free_k = is_packed; + is_packed = true; + + arm_compute::NEPermute perm_layer; + perm_layer.configure(&kIn, k.get(), {2, 0, 1, 3}); + perm_layer.run(); } else { - // TODO: valildate shapes - pConv = &it->second; + ORT_RETURN_IF_ERROR(ACLImportMemory(k->allocator(), (void*)k_data, 0)); } - const T* x_data = X->Data(); - if (X->Shape().Size() != 0 && pConv->in->info()->has_padding()) { - pConv->in->allocator()->allocate(); - importDataToTensor(pConv->in.get(), x_data); - } else { - ACLImportMemory(pConv->in->allocator(), (void*)x_data, X->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); } - const T* k_data = W->Data(); - ACLImportMemory(pConv->k->allocator(), (void*)k_data, W->Shape().Size() * 4); + if (conv_layer) { + conv_layer->prepare(prep_pack); + } else { + depthwise_layer->prepare(); + } - if (B != nullptr) { - const T* b_data = B->Data(); - ACLImportMemory(pConv->b->allocator(), (void*)b_data, B->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); } - T* y_data = Y->MutableData(); - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - pConv->out->allocator()->allocate(); - } else { - ACLImportMemory(pConv->out->allocator(), (void*)y_data, Y->Shape().Size() * 4); + if (free_k) { + k->allocator()->free(); } - arm_compute::Allocator alloc_mm{}; - pConv->mm_layer->populate(alloc_mm, 1); - pConv->layer->run(); - pConv->mm_layer->clear(); + return Status::OK(); +} - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - importDataFromTensor(pConv->out.get(), y_data); +Status Conv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (isQuantized ? (input_idx != 3) : (input_idx != 1)) { + return Status::OK(); } - pConv->in->allocator()->free(); - pConv->k->allocator()->free(); - if (B != nullptr) - pConv->b->allocator()->free(); - pConv->out->allocator()->free(); + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); - LOGS_DEFAULT(VERBOSE) << std::endl; + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), + packedSize, alignment)); + + used_shared_buffers = true; + } return Status::OK(); } -#else -template -Status Conv::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + +Status Conv::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) - LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + Tensor* Y = context->Output(0, outShape); + + const void* x_data = X->DataRaw(); + ORT_RETURN_IF(X->Shape().Size() != 0 && in->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(in->allocator(), (void*)x_data, 0)); + + void* y_data = Y->MutableDataRaw(); + ORT_RETURN_IF(Y->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)y_data, 0)); + + if (conv_layer) { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + conv_layer->run(const_cast(run_pack)); + } else { + depthwise_layer->run(); + } + + in->allocator()->free(); + k->allocator()->free(); + out->allocator()->free(); LOGS_DEFAULT(VERBOSE) << std::endl; - Status s = onnxruntime::Conv::Compute(context); - return s; + return Status::OK(); } #endif ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, + 11, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_KERNEL_EX( + NhwcConv, + kMSDomain, 1, kAclExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Conv); + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index 660d47b4172df..b05ba5363542f 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv.h" +#include "core/providers/acl/acl_common.h" #include "core/providers/acl/acl_execution_provider.h" // ACL -#ifdef ACL_2308 #include "arm_compute/runtime/Tensor.h" -#endif #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" #include "arm_compute/runtime/Allocator.h" #include "arm_compute/runtime/PoolManager.h" @@ -19,45 +21,50 @@ #include "arm_compute/runtime/MemoryManagerOnDemand.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" namespace onnxruntime { namespace acl { -typedef struct -{ - std::shared_ptr layer; - std::shared_ptr mm_layer; - std::shared_ptr in; - std::shared_ptr k; - std::shared_ptr b; - std::shared_ptr out; - bool isDepthwiseCPU; -} ACLNEConv; - -typedef std::map::iterator ConvLayersIterator; +Status ValidateConv(const onnxruntime::Node& node); -template -class Conv : public onnxruntime::Conv { +class Conv : public onnxruntime::OpKernel { public: - explicit Conv(const OpKernelInfo& info) : onnxruntime::Conv(info), conv_attrs_(info) { - provider_ = (const_cast( - static_cast(info.GetExecutionProvider()))); - } + explicit Conv(const OpKernelInfo& info); - ~Conv() { - Conv::convLayers.erase(this); - } + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; Status Compute(OpKernelContext* context) const override; protected: - static thread_local std::map convLayers; ConvAttributes conv_attrs_; ACLExecutionProvider* provider_; std::string activation_type; + std::shared_ptr depthwise_layer; + + std::shared_ptr conv_layer; + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr in; + std::shared_ptr k; + IAllocatorUniquePtr pkRaw; + std::shared_ptr b; + std::shared_ptr out; + TensorShape outShape; + bool is_channels_last; + bool isQuantized; + bool isDepthwiseCPU; + bool has_bias; + arm_compute::TensorShape ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const; }; } // namespace acl diff --git a/onnxruntime/core/providers/acl/nn/fused_conv.cc b/onnxruntime/core/providers/acl/nn/fused_conv.cc index 3cf18394b5c4c..34e50ebdf6921 100644 --- a/onnxruntime/core/providers/acl/nn/fused_conv.cc +++ b/onnxruntime/core/providers/acl/nn/fused_conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #ifdef _WIN32 @@ -17,11 +18,13 @@ namespace onnxruntime { namespace acl { -class FusedConv final : public acl::Conv { +class FusedConv final : public acl::Conv { public: - explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { + explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { ORT_ENFORCE(info.GetAttr("activation", &(this->activation_type)).IsOK()); - ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + ORT_ENFORCE(GetFusedActivationAttr(info, activation).IsOK()); } }; diff --git a/onnxruntime/core/providers/acl/nn/pool.cc b/onnxruntime/core/providers/acl/nn/pool.cc index 01d9bc0302c3a..cbbecef6bbfac 100644 --- a/onnxruntime/core/providers/acl/nn/pool.cc +++ b/onnxruntime/core/providers/acl/nn/pool.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -63,12 +64,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, if (pool_attrs.global_pooling) { layer->configure(tpool.in.get(), tpool.out.get(), - arm_compute::PoolingLayerInfo(pool_type -#ifdef ACL_2308 - , - arm_compute::DataLayout::NCHW -#endif - )); + arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NCHW)); } else { TensorShapeVector aclStrides(2); aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; @@ -95,8 +91,11 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclPads[3] = pads[2]; } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], + arm_compute::DimensionRoundingType::FLOOR); TensorShapeVector aclKernelShape(2); aclKernelShape[0] = (kernel_shape.size() > 1) ? kernel_shape[1] : 1; @@ -113,9 +112,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, -#ifdef ACL_2308 arm_compute::DataLayout::NCHW, -#endif aclPadStride, excludePadding); layer->configure(tpool.in.get(), tpool.out.get(), pool_info); @@ -133,8 +130,8 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclInpuWindow.use_tensor_dimensions(tpool.in->info()->tensor_shape()); arm_compute::Iterator aclInputIt(tpool.in.get(), aclInpuWindow); - const unsigned int aclWidth = tpool.in->info()->dimension(0); - const unsigned int aclHeight = tpool.in->info()->dimension(1); + const size_t aclWidth = tpool.in->info()->dimension(0); + const size_t aclHeight = tpool.in->info()->dimension(1); // copy input tensor into the larger buffer arm_compute::execute_window_loop( diff --git a/onnxruntime/core/providers/acl/scheduler.cc b/onnxruntime/core/providers/acl/scheduler.cc new file mode 100644 index 0000000000000..e1bab6adb5a1f --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.cc @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "core/common/common.h" +#include "scheduler.h" + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace acl { + +void ORTScheduler::set_num_threads(unsigned int num_threads) { + ORT_THROW("Not supported"); +} + +unsigned int ORTScheduler::num_threads() const { + // We can't check the size of the thread pool during kernel initialization, + // as required by ACL. Therefore we have to choose a fixed thread count and + // let some cores run multiple workloads if there are fewer than 32 cores. + // This doesn't seem to cause performance issues with fewer cores in practice. + return 32; +} + +void ORTScheduler::schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) { + arm_compute::ITensorPack tensors; + schedule_op(kernel, hints, kernel->window(), tensors); +} + +void ORTScheduler::schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) { + schedule_common(kernel, hints, window, tensors); +} + +void ORTScheduler::run_workloads(std::vector& workloads) { + ThreadPool::TrySimpleParallelFor(_provider->GetThreadPool(), workloads.size(), + [&](std::ptrdiff_t id) { + const arm_compute::ThreadInfo info{ + (int)id, (int)workloads.size(), &cpu_info()}; + workloads[id](info); + }); +} + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/scheduler.h b/onnxruntime/core/providers/acl/scheduler.h new file mode 100644 index 0000000000000..c66700a48f3d5 --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.h @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "acl_execution_provider.h" + +#include "arm_compute/runtime/IScheduler.h" +#include "arm_compute/core/CPP/ICPPKernel.h" + +namespace onnxruntime { +namespace acl { + +class ORTScheduler : public arm_compute::IScheduler { + public: + ORTScheduler(ACLExecutionProvider* provider) : _provider(provider) { + } + + void set_num_threads(unsigned int num_threads) override; + + unsigned int num_threads() const override; + + void schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) override; + + void schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) override; + + void run_workloads(std::vector& workloads) override; + + private: + ACLExecutionProvider* _provider = nullptr; +}; + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/tensor/concat.cc b/onnxruntime/core/providers/acl/tensor/concat.cc index 75eedaac80aea..0cf02ab8762b9 100644 --- a/onnxruntime/core/providers/acl/tensor/concat.cc +++ b/onnxruntime/core/providers/acl/tensor/concat.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/providers/acl/tensor/concat.h" @@ -76,11 +77,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { LOGS_DEFAULT(VERBOSE) << "Concat ACL:"; arm_compute::Tensor output; -#ifdef ACL_2308 std::vector inputs_vector; -#else - std::vector inputs_vector; -#endif for (int i = 0; i < input_count; i++) { arm_compute::Tensor* input = new arm_compute::Tensor(); auto X = input_tensors[i]; @@ -101,11 +98,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { for (int i = 0; i < input_count; i++) { auto X = input_tensors[i]; const T* x_data = X->Data(); -#ifdef ACL_2308 arm_compute::Tensor* in = const_cast(static_cast(inputs_vector[i])); -#else - arm_compute::Tensor* in = static_cast(inputs_vector[i]); -#endif if (X->Shape().Size() != 0 && in->info()->has_padding()) { in->allocator()->allocate(); diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index c5757095e2e1e..1319e8f6fe959 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "python/onnxruntime_pybind_state_common.h" @@ -54,7 +55,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}), #endif #ifdef USE_ACL - onnxruntime::ACLProviderFactoryCreator::Create(0), + onnxruntime::ACLProviderFactoryCreator::Create(false), #endif #ifdef USE_ARMNN onnxruntime::ArmNNProviderFactoryCreator::Create(0), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22aea..e8bf61612c89b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "python/onnxruntime_pybind_exceptions.h" @@ -1141,8 +1142,25 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kAclExecutionProvider) { #ifdef USE_ACL - return onnxruntime::ACLProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) + bool enable_fast_math = false; + auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + for (auto option : it->second) { + if (option.first == "enable_fast_math") { + std::set supported_values = {"true", "True", "false", "False"}; + if (supported_values.find(option.second) != supported_values.end()) { + enable_fast_math = (option.second == "true") || (option.second == "True"); + } else { + ORT_THROW( + "Invalid value for enable_fast_math. " + "Select from 'true' or 'false'\n"); + } + } else { + ORT_THROW("Unrecognized option: ", option.first); + } + } + } + return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math) ->CreateProvider(); #endif } else if (type == kArmNNExecutionProvider) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 4d6e411defae3..08e5e4f7b18fa 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -439,7 +440,7 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_ACL(bool enable_fast_math); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0397bba90438b..924616f49ab25 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -655,7 +656,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } if (enable_acl) { #ifdef USE_ACL - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, enable_cpu_mem_arena ? 1 : 0)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, false)); #else fprintf(stderr, "ACL is not supported in this build"); return -1; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d06bbadbd645..c1c48d4945a4d 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "command_args_parser.h" @@ -66,6 +67,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -e -i '| |'\n" "\n" + "\t [ACL only] [enable_fast_math]: Options: 'true', 'false', default: 'false', \n" "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ae7680571ced1..3ed5eaee5a5f7 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "ort_test_session.h" @@ -519,9 +520,42 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL +#if defined(_MSC_VER) + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif // defined(_MSC_VER) + std::istringstream ss(ov_string); + std::string token; + bool enable_fast_math = false; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "enable_fast_math") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + enable_fast_math = (value == "true") || (value == "True"); + } else { + ORT_THROW( + "[ERROR] [ACL] You have selcted an invalid value for the key 'enable_fast_math'. " + "Select from 'true' or 'false' \n"); + } + } else { + ORT_THROW( + "[ERROR] [ACL] Unrecognized option: ", key); + } + } Ort::ThrowOnError( - OrtSessionOptionsAppendExecutionProvider_ACL(session_options, - performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0)); + OrtSessionOptionsAppendExecutionProvider_ACL(session_options, enable_fast_math)); #else ORT_THROW("Acl is not supported in this build\n"); #endif diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index a5015c18cee63..177647ab5be6b 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -254,7 +255,7 @@ TEST_P(ModelTest, Run) { #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, 0)); + ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..6451f8ec6dce8 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -207,11 +208,11 @@ std::unique_ptr DefaultRknpuExecutionProvider() { #endif } -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena) { +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math) { #ifdef USE_ACL - return ACLProviderFactoryCreator::Create(enable_arena)->CreateProvider(); + return ACLProviderFactoryCreator::Create(enable_fast_math)->CreateProvider(); #else - ORT_UNUSED_PARAMETER(enable_arena); + ORT_UNUSED_PARAMETER(enable_fast_math); return nullptr; #endif } diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 606dfc068d399..b3a619022f79b 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/common/optional.h" @@ -53,7 +54,7 @@ std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math = false); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op = false); std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram = false); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 902d15e8122b4..8535f1e8c85a0 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. import argparse @@ -651,9 +652,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--enable_transformers_tool_test", action="store_true", help="Enable transformers tool test") parser.add_argument( "--use_acl", - nargs="?", - const="ACL_1905", - choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"], + action="store_true", help="Build with ACL for ARM architectures.", ) parser.add_argument("--acl_home", help="Path to ACL home dir") @@ -1052,11 +1051,6 @@ def generate_build_tree( "-Donnxruntime_USE_TELEMETRY=" + ("ON" if args.use_telemetry else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), - "-Donnxruntime_USE_ACL_1902=" + ("ON" if args.use_acl == "ACL_1902" else "OFF"), - "-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"), - "-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"), - "-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"), - "-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"), "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"), "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), From 59b7b6bb7cbb7bcc86dab590f1b4d5ed50d53dec Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 13 Sep 2024 16:52:49 +0000 Subject: [PATCH 18/26] Remove training from web ci pipeline (#22082) ### Description Remove training from web ci pipeline ### Motivation and Context --- .../templates/linux-wasm-ci.yml | 21 ------------------- .../azure-pipelines/templates/win-web-ci.yml | 6 +----- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index a56eb37faef84..2ab432e94fcbd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -31,10 +31,6 @@ parameters: type: boolean default: false -- name: BuildTraining - type: boolean - default: true - - name: WithCache type: boolean default: false @@ -116,19 +112,6 @@ jobs: DisplayName: 'Build and test (browser) (simd + threads)' WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildTraining, true) }}: - - template: build-linux-wasm-step.yml - parameters: - Today: $(Today) - ${{ if eq(parameters.BuildStaticLib, true)}}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} | static - ${{ else }}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} - CacheDir: $(ORT_CACHE_DIR)/wasm_training - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_training --enable_training_apis --target onnxruntime_webassembly --skip_tests' - DisplayName: 'Build (training + simd + threads)' - WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildJsep, true) }}: - template: build-linux-wasm-step.yml parameters: @@ -150,10 +133,6 @@ jobs: cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory) cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory) fi - if [ -d $(Build.BinariesDirectory)/wasm_training ]; then - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.wasm $(Build.ArtifactStagingDirectory) - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.mjs $(Build.ArtifactStagingDirectory) - fi displayName: 'Create Artifacts' - ${{ if eq(parameters.SkipPublish, false) }}: - task: PublishPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index c1fde93d8e640..0e8a7eb94379b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -214,11 +214,7 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'E2E package consuming test' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - - script: | - npm run test:training:e2e - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'E2E training package test' - condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) + - task: CopyFiles@2 inputs: sourceFolder: $(Build.SourcesDirectory)\js\common From 7e2c722459a7a7015a238379acc8705d9ce5b8dc Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:21:11 -0700 Subject: [PATCH 19/26] Add Continuous Decoding support in GQA (#21523) ### Description This PR will add support for Continuous Decoding for batch_size = 1 input. From now on, GQA can take arbitrary length input using seqlens_k as total_sequence_length - 1 and the sequence length of qkv as new_sequence_length. **This change will not affect the default behavior of GQA** ### Motivation and Context Prior to this change it was impossible to support sequence_length > 1 inputs when past context was given. This use case is essential to making continuous decoding work, which is one of our current efforts in ORT-GenAI. --- docs/ContribOperators.md | 6 +- .../contrib_ops/cpu/bert/attention_common.h | 3 +- .../contrib_ops/cpu/bert/attention_helper.h | 11 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 177 +++++------ .../cpu/bert/group_query_attention.cc | 31 +- .../cpu/bert/group_query_attention_helper.h | 36 ++- .../cpu/sparse/sparse_attention_base.h | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 1 - .../cuda/bert/group_query_attention.cc | 5 +- .../cuda/bert/group_query_attention_helper.h | 298 ------------------ .../cuda/bert/group_query_attention_impl.cu | 149 ++++++--- .../cuda/bert/group_query_attention_impl.h | 4 +- .../rocm/bert/group_query_attention.cu | 10 +- .../core/graph/contrib_ops/bert_defs.cc | 8 +- .../transformers/test_flash_attn_cuda.py | 171 +++++++++- .../test/python/transformers/test_gqa_cpu.py | 79 ++++- .../transformers/test_sparse_attention.py | 7 +- 17 files changed, 498 insertions(+), 502 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index aadf4ebe2f488..09a7e47fc9913 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.micr Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. + Supports continuous decoding for batch_size == 1 for CPU and CUDA. + #### Version @@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_value (optional) : T
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
seqlens_k : M
-
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+
1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).
total_sequence_length : M
-
Scalar tensor of total sequence length (past + new).
+
Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for checking inputs and determining prompt vs token generation case.
cos_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 45acb90ba68b0..e0fa581c8071d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters { int local_window_size; bool kv_share_buffer; bool is_packed_qkv; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool do_rotary; bool rotary_interleaved; bool use_smooth_softmax; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index e6c948acb0d6c..4d435f71cc195 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past, size_t past_buff_chunk_length, size_t past_chunk_length, size_t new_chunk_length, - bool is_prompt, bool past_present_share_buffer, std::ptrdiff_t i) { T* start = present + i * present_buff_chunk_length; T* p = start; - if (!is_prompt) { - if (!past_present_share_buffer) { - const T* src_past = past + i * past_buff_chunk_length; - memcpy(p, src_past, past_chunk_length * sizeof(T)); - } - p += past_chunk_length; + if (!past_present_share_buffer && past_chunk_length > 0) { + const T* src_past = past + i * past_buff_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); } + p += past_chunk_length; memcpy(p, chunk, new_chunk_length * sizeof(T)); return start; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 2bf0aa0915c2d..bfec9aef56727 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -59,6 +59,7 @@ class GQAAttentionBase { GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int head_size = parameters.head_size; @@ -88,14 +89,14 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - tp); + is_prompt, tp); return Status::OK(); } @@ -105,35 +106,35 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int past_buffer_sequence_length, // sequence length of past state - int present_buffer_sequence_length, // sequence length of present state - int head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed - ThreadPool* tp) const { // thread pool - const bool is_prompt = sequence_length != 1; + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp) const { // thread pool const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H - const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = sequence_length * head_size; // S x H + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } - const int loop_len = batch_size * num_heads_; + const size_t loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; @@ -156,12 +157,11 @@ class GQAAttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i) / num_heads_; - const int head_index = static_cast(i) % num_heads_; - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; T* output = attention_probs + output_offset; @@ -174,7 +174,7 @@ class GQAAttentionBase { } if (nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); } @@ -189,16 +189,17 @@ class GQAAttentionBase { } else { q = Q + q_input_chunk_length * i; } + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length, - nullptr); + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, + static_cast(present_buffer_sequence_length), nullptr); // compute Softmax T* output_softmax = output; - for (int seq = 0; seq < sequence_length; seq++) { - int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1; - if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) { - for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { + for (size_t seq = 0; seq < sequence_length; seq++) { + size_t seq_causal_length = past_seqlen + seq + 1; + if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } if (softcap_ > 0.f) { @@ -214,17 +215,17 @@ class GQAAttentionBase { } } else { if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_); + ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), softcap_); } if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } } // set causal [seq_causal_length, total_seqlen) to 0.f - for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { + for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } @@ -235,34 +236,36 @@ class GQAAttentionBase { } template - void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT - const T* V, // V value with size BxN_kvxSxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size - int sequence_length, // sequence length - int past_buffer_sequence_length, // sequence length in past state - int present_buffer_sequence_length, // sequence length in past state - int head_size, // head size of Q, K, V - int hidden_size, // hidden size of Output - const T* past_value, // past value only - T* present_value, // present value only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxN_kvxSxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size + const size_t sequence_length, // sequence length + const size_t past_buffer_sequence_length, // sequence length in past state + const size_t present_buffer_sequence_length, // sequence length in past state + const size_t head_size, // head size of Q, K, V + const size_t hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt ThreadPool* tp) const { - const bool is_prompt = sequence_length != 1; const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const int kv_input_chunk_length = sequence_length * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } + const size_t loop_len = batch_size * num_heads_; + // The cost of Gemm TensorOpCost unit_cost; unit_cost.compute_cycles = @@ -282,37 +285,35 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor( - tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; - - const T* v; - if (packed_qkv) { - v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); - } else { - v = V + kv_input_chunk_length * (i / kv_num_heads_factor); - } - if (nullptr != present_value) { - v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, - i / kv_num_heads_factor); - } + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, + i / kv_num_heads_factor); + } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, present_buffer_sequence_length, v, - head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); - } - }); + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ + attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, static_cast(head_size), + 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + } + }); } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 87675255f5ba4..2a38e4a1ac636 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -45,7 +45,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); - const Tensor* total_seqlen = context->Input(6); + const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); @@ -61,7 +61,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { num_heads_, kv_num_heads_, seqlens_k, - total_seqlen, + total_seqlen_tensor, scale_, softcap_)); @@ -103,6 +103,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } if (do_rotary_) { + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; rotary_params.sequence_length = sequence_length; @@ -114,17 +115,29 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.seq_stride = head_size; rotary_params.head_stride = sequence_length * rotary_params.seq_stride; rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride; - rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.position_ids_format = !parameters.is_first_prompt ? 1 : 0; rotary_params.transposed = true; auto* tp = context->GetOperatorThreadPool(); - std::vector pos_ids(sequence_length == 1 ? batch_size : 1); - if (sequence_length == 1) { + // Generate position ids + const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length; + std::vector pos_ids(pos_ids_size); + if (parameters.is_first_prompt) { + pos_ids[0] = static_cast(0); + } else { + // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { - pos_ids[b] = static_cast(seqlens_k->Data()[b]); + const int total_seqlen = seqlens_k->Data()[b] + 1; + const int past_seqlen = total_seqlen - sequence_length; + for (int s = 0; s < sequence_length; s++) { + if (past_seqlen + s < total_seqlen) { + pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + } else { + pos_ids[b * sequence_length + s] = static_cast(1); + } + } } - } else { - pos_ids[0] = static_cast(0); } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; T* q_rotary; @@ -149,6 +162,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Q = RotaryQ; K = RotaryK; } + // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); @@ -161,6 +175,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); + // Pack V into rotary QKV buffer if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 3342052260ff9..0bdee151d2173 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); + if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "seqlens_k must be shape (batch_size)."); } - // Set present sequence length and kv_share_buffer from input total_seqlen tensor + // Set present sequence length from input total_seqlen tensor if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "total_sequence_length tensor must be of one element."); @@ -195,11 +194,11 @@ Status CheckInputs(const Tensor* query, } if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); + "cos_cache dimension 0 shall not be less than total_sequence_length."); } if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); + "sin_cache dimension 0 shall not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -219,7 +218,26 @@ Status CheckInputs(const Tensor* query, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); } - bool is_prompt = sequence_length != 1; + bool is_subsequent_prompt = false; + if (sequence_length > 1 && sequence_length != total_sequence_length) { + if (batch_size != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "batch_size must be 1 when sequence_length > 1 and past context is given."); + } + is_subsequent_prompt = true; + } + + bool is_first_prompt; + if (is_subsequent_prompt) { + is_first_prompt = false; // irrelevant for interactive decoding + } else { + // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt + is_first_prompt = (sequence_length == total_sequence_length); + if (!is_first_prompt && sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sequence_length shall be 1 when it is not prompt."); + } + } if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); @@ -227,6 +245,7 @@ Status CheckInputs(const Tensor* query, output_parameters->sequence_length = sequence_length; // sequence length of Q output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->total_sequence_length = total_sequence_length; // total sequence length output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = head_size; @@ -235,7 +254,8 @@ Status CheckInputs(const Tensor* query, output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; - output_parameters->is_prompt = is_prompt; + output_parameters->is_subsequent_prompt = is_subsequent_prompt; + output_parameters->is_first_prompt = is_first_prompt; output_parameters->scale = scale; output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index cf66bd8407126..37172074e5d86 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -184,7 +184,7 @@ class SparseAttentionBase { // Concatenate past_k + k -> present_k // TODO: avoid copying mutiple times for a group. k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); // Compute Q*K' + AttentionMask @@ -365,7 +365,7 @@ class SparseAttentionBase { // Concatenate past_v + v -> present_v v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a10d2548fa7b8..7f1c3786858c8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -42,7 +42,6 @@ struct RightPaddingBatchHook { auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; - // Advance to current batch - in case of different sequence lengths if (p.seqlen_k_ptr) { p.num_keys = p.seqlen_k_ptr[batch_id]; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index d0ae812bb4fa2..6eff584cec5da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -5,7 +5,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" -#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" @@ -95,7 +95,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { kv_num_heads_, seqlens_k, total_seqlen, - is_past_bsnh_, scale_, softcap_, device_prop.maxThreadsPerBlock)); @@ -253,7 +252,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } if (seqlens_k_buffer != nullptr) { - data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + data.seqlens_k_buff = reinterpret_cast(seqlens_k_buffer.get()); } // Memory Efficient Buffers if (k_buffer != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h deleted file mode 100644 index e65827e4ccdd5..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/common.h" -#include "contrib_ops/cpu/bert/attention_common.h" - -namespace onnxruntime { -namespace contrib { -namespace group_query_attention_helper { - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap) { - // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // no packing for q/k/v: - // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) - // key (K) : (B, S, D_kv) or nullptr - // value (V) : (B, S, D_kv) or nullptr - AttentionQkvFormat qkv_format = Q_K_V_BSNH; - AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - const auto& query_dims = query->Shape().GetDims(); - - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - int kv_hidden_size = 0; - // Check key and value when not packed - if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; - } - - // Check past-present KV - int32_t past_sequence_length = 0; - if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - // BNSH - if (!is_past_bsnh) { - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - // BSNH - } else { - if (past_key_dims[1] != past_value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[1]); - } - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - } else if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent."); - } - - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)."); - } - - // Set present sequence length and kv_share_buffer from input total_seqlen tensor - if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length tensor must be of one element."); - } - int total_sequence_length = *((*total_seqlen).template Data()); - int present_sequence_length = std::max(total_sequence_length, past_sequence_length); - - int rotary_dim = 0; - if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); - } else if (cos_cache != nullptr || sin_cache != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); - } - - bool is_prompt = (sequence_length == total_sequence_length); - if (!is_prompt && sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sequence_length shall be 1 when it is not prompt."); - } - - if (parameters != nullptr) { - GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = batch_size; - output_parameters->sequence_length = sequence_length; // sequence length of Q - output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors - output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors - output_parameters->total_sequence_length = total_sequence_length; // total sequence length - output_parameters->hidden_size = q_hidden_size; - output_parameters->num_heads = num_heads; - output_parameters->head_size = head_size; - output_parameters->kv_hidden_size = kv_hidden_size; - output_parameters->kv_num_heads = kv_num_heads; - output_parameters->rotary_dim = rotary_dim; - output_parameters->is_packed_qkv = is_packed_qkv; - output_parameters->is_prompt = is_prompt; - output_parameters->scale = scale; - output_parameters->softcap = softcap; - output_parameters->qkv_format = qkv_format; - output_parameters->past_kv_format = past_kv_format; - } - - return Status::OK(); -} - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap, - int max_threads_per_block) { - if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); - } - - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale, softcap); -} - -} // namespace group_query_attention_helper -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index be94f26ec298f..8bf9848245ec7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -71,6 +71,8 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, + // const int* seqlens_q, const bool is_bsnh) { // refers to past; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; @@ -88,7 +90,9 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -96,7 +100,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -116,6 +120,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, const bool is_bsnh) { int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -132,7 +137,9 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -140,7 +147,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; const int new_head_stride = H; @@ -160,13 +167,12 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter const int max_threads_per_block, const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); - + const int* seqlens_k = parameters.is_first_prompt ? nullptr : reinterpret_cast(data.seqlens_k); AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -180,6 +186,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, @@ -187,6 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = (H * kv_num_heads + 255) / 256; @@ -200,6 +208,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKVLarge<<>>(kv_sequence_length, past_sequence_length, @@ -209,6 +218,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -219,7 +229,7 @@ template __global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { @@ -234,7 +244,7 @@ __global__ void ConcatKVInPlace(const int max_seqlen, const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -253,7 +263,7 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int kv_num_heads, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh @@ -264,9 +274,10 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int s = blockIdx.y; const int b = blockIdx.z; const int new_seqlen = gridDim.y; + const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -286,15 +297,15 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const T* new_key, const T* new_value, T* present_key, T* present_value, - bool is_past_kv_bnsh_format, - bool is_new_kv_bnsh_format, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format, cudaStream_t stream, const int max_threads_per_block) { static_assert(sizeof(T) == 2); @@ -307,14 +318,14 @@ Status LaunchConcatKVInPlace(int batch_size, ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -327,7 +338,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -336,7 +347,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -354,7 +365,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; - const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? nullptr + : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -364,8 +376,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, parameters.kv_num_heads, parameters.head_size, max_sequence_length, - past_seqlens_k, - nullptr, // total_seqlens_k is not available + seqlens_k, + nullptr, // total_seqlens_k would be wrong to use here parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), @@ -495,23 +507,33 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; } -// Convert Past to Total sequence length tensor -Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, - int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int /*threads_per_block*/) { - if (parameters.is_prompt) { - return Status::OK(); - } - const int batch_size = parameters.batch_size; - const int add_seqlen = is_total ? parameters.sequence_length : 0; - +// Calculate total sequence length from seqlens_k +Status LaunchGetSeqlensTotal(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int batch_size, cudaStream_t stream, + const int /*threads_per_block*/) { const dim3 grid(1, 1, 1); // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads const dim3 block(batch_size, 1, 1); + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, 1); + return CUDA_CALL(cudaGetLastError()); +} - // TODO(aciddelgado): small version - PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); +// Currently, interactive decoding only works for batch_size 1 +__global__ void GetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + seqlens_k_buff[tid] = seqlens_k[tid] + 1 - sequence_length; + } +} +// Calculate past sequence length for each batch entry for flash attention kernel +Status LaunchGetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length, cudaStream_t stream, + const int max_threads_per_block) { + const int threads = std::min(batch_size, max_threads_per_block); + const int blocks = (threads / max_threads_per_block) + 1; + GetSeqlensInteractive<<>>(seqlens_k, seqlens_k_buff, batch_size, + sequence_length); return CUDA_CALL(cudaGetLastError()); } @@ -576,7 +598,22 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsInteractive(const int32_t* seqlens_k, int64_t* position_ids, + const int seqlen, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + const int total_seqlen = seqlens_k[b] + 1; + const int past_seqlen = total_seqlen - seqlen; + if (past_seqlen + s < total_seqlen) { + position_ids[tid] = past_seqlen + s; + } else { + position_ids[tid] = 1; + } + } +} + __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -591,7 +628,6 @@ __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* positio } } -// Kernel to convert seqlens_k to position_ids __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { @@ -601,12 +637,15 @@ __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position // Convert seqlens_k to position_ids Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + int64_t* position_ids, cudaStream_t stream, + const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + SeqlensToPosIdsInteractive<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -650,7 +689,12 @@ Status FlashAttention( } void* seqlens_k = reinterpret_cast(data.seqlens_k); - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensInteractive(reinterpret_cast(data.seqlens_k), + reinterpret_cast(data.seqlens_k_buff), batch_size, + sequence_length, stream, max_threads_per_block)); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); + } else if (parameters.is_first_prompt) { // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value // user should use seqlens_k to index into output to get new tokens if (batch_size <= parameters.zeros_count) { @@ -659,10 +703,12 @@ Status FlashAttention( // Launch kernel to create larger seqlen tensor when batch_size > 256 constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); - seqlens_k = data.seqlens_k_total; + repeat_seqlen<<>>(data.seqlens_k_buff, 0, batch_size); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); } - } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + } + + if (!parameters.kv_share_buffer || parameters.is_first_prompt) { // copy past kv to present kv ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, true)); } @@ -682,7 +728,7 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); - // if (parameters.left_padding && parameters.is_prompt) { + // if (parameters.left_padding && parameters.is_first_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } @@ -766,15 +812,16 @@ Status EfficientAttention( key = reinterpret_cast(k_buffer); } - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt || !parameters.is_first_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensTotal(data.seqlens_k, data.seqlens_k_buff, batch_size, stream, 256)); + } else { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + repeat_seqlen<<>>(data.seqlens_k_buff, parameters.sequence_length, batch_size); - } else { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } + int* seqlens_k = data.seqlens_k_buff; if (parameters.kv_share_buffer) { // Share buffer case @@ -815,7 +862,7 @@ Status EfficientAttention( } DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + DUMP_TENSOR("seqlens_k", seqlens_k, batch_size, 1); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -823,14 +870,14 @@ Status EfficientAttention( p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; - p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.kv_sequence_length = present_sequence_length; // maybe remove p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; p.causal = true; p.scale = scale; p.softcap = parameters.softcap; - p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqlen_k_ptr = seqlens_k; // Note: seqlens_k is total sequence length for efficient p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; p.query = query; @@ -912,7 +959,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const half* new_key, @@ -928,7 +975,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const BFloat16* new_key, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index e8dc69188b95f..8593ecede2bab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -27,7 +27,7 @@ struct GroupQueryAttentionData { T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k_total = nullptr; + int* seqlens_k_buff = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* unpacked_qkv_buffer = nullptr; @@ -61,7 +61,7 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, // max sequence length of present_key or present_value. - const int* past_seqlens_k, // it is not used when total_seqlens_k is available. + const int* seqlens_k, // it is not used when total_seqlens_k is available. const int* total_seqlens_k, // optional, nullptr means it is not available. int new_seq_len, const T* new_key, diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 7a16eb38181aa..e644b7e903138 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/rocm/bert/rotary_embedding_impl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" @@ -115,7 +115,7 @@ Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -325,7 +325,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { // copy prompt kv to present kv ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); @@ -383,7 +383,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { return ret; } - if (parameters.is_prompt && is_unidirectional_) { + if (parameters.is_first_prompt && is_unidirectional_) { return mask_info::decode("t", sequence_length, kv_sequence_length); } @@ -496,7 +496,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, parameters.head_size, // v head size GetCkFmhaDataTypeString(), - !parameters.is_prompt, // true, // is_group_mode + !parameters.is_first_prompt, // true, // is_group_mode true, // is_v_rowmajor ? dim is fastest : seq is fastest mask.type, bias_type, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5185205f1dde1..c706c6fc5ff5f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1049,6 +1049,8 @@ Supports different number of heads for q and kv for CPU and CUDA. Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. +Supports continuous decoding for batch_size == 1 for CPU and CUDA. + )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1110,12 +1112,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(5, "seqlens_k", - // For prompt, the value is number of tokens (excluding padding) - 1. - "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).", "M") .Input(6, "total_sequence_length", - "Scalar tensor of total sequence length (past + new).", + "Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for " + "checking inputs and determining prompt vs token generation case.", "M") .Input(7, "cos_cache", diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index c04929a3b603e..46ab905977f48 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -223,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + interactive=False, softcap=0.0, use_smooth_softmax=False, ): @@ -1224,7 +1225,7 @@ def parity_check_gqa_prompt( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1422,7 +1423,7 @@ def parity_check_gqa_prompt_no_buff( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1597,7 +1598,7 @@ def parity_check_gqa_past( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1667,7 +1668,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1696,7 +1696,6 @@ def parity_check_gqa_past( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1730,6 +1729,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1783,15 +1784,14 @@ def parity_check_gqa_past( numpy.testing.assert_allclose( present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) def parity_check_gqa_past_no_buff( config, - causal=False, + causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1864,7 +1864,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1896,7 +1895,6 @@ def parity_check_gqa_past_no_buff( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1930,6 +1928,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1976,6 +1976,23 @@ def parity_check_gqa_past_no_buff( f" with {config}, causal={causal}, local={local}, past_format={past_format}," f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) + for b in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[b, :, : (cache_seqlens + 1)[b]], + k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[b, :, : (cache_seqlens + 1)[b]], + v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) @@ -2229,6 +2246,86 @@ def gqa_past_flash_attention_test_cases(): ) +def gqa_interactive_one_batch_flash_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + ) + + class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): @@ -2350,6 +2447,60 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + if not has_flash_attention(): + return + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) + def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): + if not has_memory_efficient(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") + + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index cc9d7ff51a5c6..dc21d4e4a5890 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -121,8 +121,12 @@ def rotate_tensor( else: x_rot = torch.cat((real, imag), dim=-1) else: - cos_x = cos[:, 0:seq_len, :, :] - sin_x = sin[:, 0:seq_len, :, :] + batch_size = x.shape[0] + cos_x = torch.zeros((batch_size, seq_len, 1, cos.shape[3]), device=x.device) + sin_x = torch.zeros((batch_size, seq_len, 1, sin.shape[3]), device=x.device) + for b in range(x.shape[0]): + cos_x[b] = cos[0, pos[b] : pos[b] + seq_len, :, :] + sin_x[b] = sin[0, pos[b] : pos[b] + seq_len, :, :] real = cos_x * x1 - sin_x * x2 imag = sin_x * x1 + cos_x * x2 if interleaved: @@ -716,7 +720,6 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - # TODO: do we need io binding for cpu input? io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -788,6 +791,7 @@ def gqa_past_func( softcap=0.0, use_smooth_softmax=False, ): + assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, past_kv_format, @@ -819,12 +823,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -867,12 +871,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -1518,7 +1522,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1576,6 +1579,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1739,7 +1744,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1800,6 +1804,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -2000,6 +2006,61 @@ def test_gqa_past(self): ) self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): + print("-------- TEST GQA INTERACTIVE ---------") + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 6a08d2101b100..5dbb9a277e45a 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -890,7 +890,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, ) yield config @@ -929,7 +929,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. ) yield config @@ -940,7 +940,6 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention_cuda(self): major, minor = torch.cuda.get_device_capability() @@ -1056,7 +1055,7 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool): vert_stride=4, softmax_scale=None, do_rotary=do_rotary, - rotary_interleaved=(past_seq_len % 2 == 1), + rotary_interleaved=do_rotary and (past_seq_len % 2 == 1), device=device, is_packed_qkv=packed_qkv, max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer. From a89bddd5c224c045510d09537a95d32602e021cc Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Fri, 13 Sep 2024 14:55:08 -0700 Subject: [PATCH 20/26] Matmul_nbits kernel for mlas sqnbits to support Fp16 inputs (#21807) --- cmake/onnxruntime_mlas.cmake | 4 +- docs/OperatorKernels.md | 2 +- .../cpu/quantization/matmul_nbits.cc | 246 +++++++++++++----- .../cpu/quantization/matmul_nbits_impl.cc | 11 +- onnxruntime/core/mlas/inc/mlas.h | 36 ++- onnxruntime/core/mlas/lib/cast.cpp | 42 ++- onnxruntime/core/mlas/lib/mlasi.h | 11 +- onnxruntime/core/mlas/lib/platform.cpp | 4 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 45 ++++ .../core/providers/cpu/tensor/cast_op.cc | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 54 +++- 11 files changed, 341 insertions(+), 116 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b612b3ead4658..e35c83ba45952 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") else() message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d57394b3e7b97..121240e6e18f9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -488,7 +488,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index bf43aca73ef3a..ccb779721d006 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -146,8 +146,15 @@ class MatMulNBits final : public OpKernel { bool all_constant_{false}; #endif // defined(ORT_NEURAL_SPEED) + + template + Status ComputeTyped(OpKernelContext* ctx) const; }; +bool IsATypeFloat16(const Tensor& tensor) { + return tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +} + Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -211,10 +218,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) ORT_UNUSED_PARAMETER(prepacked_weights); const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } if (input_idx == InputIndex::B) { - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { - return Status::OK(); - } packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); if (packed_b_size_ == 0) { return Status::OK(); @@ -226,8 +233,15 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } else if (compute_type == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + if (IsATypeFloat16(tensor)) { + auto sptr = tensor.Data(); + std::vector scales_v(static_cast(tensor.Shape().Size())); + MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), &scales_v[0], has_zp_input_, nullptr, nullptr); + } else { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + } is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); @@ -274,9 +288,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(InputIndex::A); + + if (IsATypeFloat16(*a)) { + return ComputeTyped(ctx); + } else { + return ComputeTyped(ctx); + } +} + +template +Status MatMulNBits::ComputeTyped(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); - const auto* a_data = a->Data(); + const auto* a_data = a->Data(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -289,7 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } - auto* y_data = y->MutableData(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -297,9 +322,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), - helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + // clang-format off + const bool has_single_b_matrix = std::all_of( + helper.RightOffsets().begin(), + helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + // clang-format on #if defined(ORT_NEURAL_SPEED) @@ -336,9 +364,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* bias = ctx->Input(InputIndex::bias); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( @@ -349,26 +377,64 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; + if constexpr (std::is_same::value) { + InlinedVector data(batch_count); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + + auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); + MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); + + std::vector bias_data_v; + if (bias_data != nullptr) { + bias_data_v.resize((const unsigned int)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + } + std::vector C_v((const unsigned int)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type == CompInt8) { - data[i].QuantBDataWorkspace = packed_b_.get(); + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data != nullptr ? &bias_data_v[0] : nullptr; + data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].ldc = N; } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + return Status::OK(); + } else { + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } #endif - data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + return Status::OK(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); } } @@ -380,7 +446,17 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); + const float* scales_data_; + std::vector scales_data_v; + if constexpr (std::is_same::value) { + scales_data_v.resize((const unsigned int)scales->Shape().Size()); + MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); + scales_data_ = &scales_data_v[0]; + } else { + scales_data_ = scales_data; + } + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); @@ -391,12 +467,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -406,12 +482,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! - if ((zero_points && zero_points->IsDataType())) { - DequantizeBlockwise( + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points + scales_data_, // quantization scales + static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -422,7 +498,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -436,40 +512,80 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif + if constexpr (std::is_same::value) { + std::vector data(batch_count); - std::vector data(batch_count); - for (size_t i = 0; i < batch_count; i++) { - data[i].BIsPacked = false; - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = 1.f; - data[i].beta = 0.0f; - } + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - // if there is a bias input, copy bias values into C and set beta to 1.0f - if (const Tensor* bias = ctx->Input(InputIndex::bias); - bias != nullptr) { - gsl::span bias_span = bias->DataAsSpan(); - for (size_t i = 0; i < batch_count; ++i) { - float* C_row = data[i].C; - const size_t ldc = data[i].ldc; - for (size_t m = 0; m < M; ++m) { - memcpy(C_row, bias_span.data(), bias_span.size_bytes()); - C_row += ldc; + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = tmp_c_ptr.get() + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias->Data(), tmp_bias_data_ptr.get(), static_cast(bias->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + C_row += ldc; + } + data[i].beta = 1.0f; } + } - data[i].beta = 1.0f; + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + return Status::OK(); + } else { + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; } - } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + gsl::span bias_span = bias->DataAsSpan(); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + memcpy(C_row, bias_span.data(), bias_span.size_bytes()); + C_row += ldc; + } - return Status::OK(); + data[i].beta = 1.0f; + } + } + + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + + return Status::OK(); + } } ONNX_OPERATOR_KERNEL_EX( @@ -478,9 +594,9 @@ ONNX_OPERATOR_KERNEL_EX( 1, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b28f3758f89b5..6a19a741c3028 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder( T scale = *(scale_data + n_idx * scales_shape_x + rid); float zp_f = 8; if (zero_points) { - if constexpr (std::is_same_v) { - zp_f = *(zero_points + n_idx * scales_shape_x + rid); - } else { + if constexpr (std::is_same_v) { uint8_t zp = 8; zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } else { + zp_f = *(zero_points + static_cast(n_idx) * static_cast(scales_shape_x) + static_cast(rid)); } } @@ -112,5 +112,10 @@ template void DequantizeBlockwise( const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 8b3156d77e57c..28ae64c4d5b3e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -20,6 +20,7 @@ Module Name: #include #include #include +#include // // Define the calling convention for Windows targets. @@ -1025,18 +1026,6 @@ MlasComputeTanh( size_t N ); -// -// Half-precision floating-point routines. -// - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const unsigned short* Source, - float* Destination, - size_t Count -); - // // Transpose routines. // @@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16; constexpr size_t FP16_SIZE = sizeof(uint16_t); -/** +// +// Half-precision floating-point routines. +// + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const MLAS_FP16* Source, + float* Destination, + size_t Count +); + +void +MLASCALL +MlasConvertFloatToHalfBuffer( +const float* Source, +MLAS_FP16* Destination, +size_t Count +); + + /** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL @@ -1787,6 +1796,7 @@ MlasTranspose( M, N); } + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED /** * @brief Max Pooling for fp16 NHWC diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp index 24af4064bbd9b..a6138e29bd796 100644 --- a/onnxruntime/core/mlas/lib/cast.cpp +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -23,37 +23,35 @@ union fp32_bits { void MLASCALL MlasConvertHalfToFloatBuffer( - const unsigned short* Source, + const MLAS_FP16* Source, float* Destination, size_t Count ) { - if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { - // If there is no kernel use the reference implementation, adapted from mlas_float16.h. - constexpr fp32_bits magic = {113 << 23}; - constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + for (size_t i = 0; i < Count; ++i) { + Destination[i] = Source[i].ToFloat(); + } + } else { + // If the kernel is available, use it to perform the conversion. + GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast(Source), Destination, Count); + } +} +void +MLASCALL +MlasConvertFloatToHalfBuffer( + const float* Source, + MLAS_FP16* Destination, + size_t Count +) +{ + if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) { for (size_t i = 0; i < Count; ++i) { - fp32_bits o; - o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits - uint32_t exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize - } - - o.u |= (Source[i] & 0x8000) << 16; // sign bit - Destination[i] = o.f; + Destination[i] = MLAS_FP16(Source[i]); } - } else { // If the kernel is available, use it to perform the conversion. - GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast(Destination), Count); } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6f5db766b7def..8e8f46b8a102e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,13 +610,19 @@ void size_t N ); -typedef +typedef void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( const unsigned short* Source, float* Destination, size_t Count ); +typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( + const float* Source, + unsigned short* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -880,6 +886,8 @@ extern "C" { #if defined(MLAS_TARGET_AMD64) MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif } @@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM { const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; + MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4cd7faaa9e6ff..2b4d99800c546 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -245,6 +245,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; this->CastF16ToF32Kernel = nullptr; + this->CastF32ToF16Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -387,6 +388,9 @@ Return Value: this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + // // Check if the processor supports Hybrid core architecture. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 55d86bb9cc18e..baaa4ba1a3b1f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,51 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +void +MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) +{ + size_t i = 0; + + // Process 16 elements at a time using AVX2 + for (; i + 15 < size; i += 16) { + // Load 16 FP16 values into an AVX2 register + __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(src_fp16 + i)); + + // Convert FP16 values to FP32 + __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); + __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); + + // Store the converted FP32 values into the output vector + _mm256_storeu_ps(dst_fp32 + i, fp32_values1); + _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2); + } + + // Process any remaining elements + const MLAS_FP16* fp16 = reinterpret_cast(src_fp16); + for (; i < size; ++i) { + dst_fp32[i] = fp16[i].ToFloat(); + } +} + +void +MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size) +{ + size_t i = 0; + + // Process 8 elements at a time using AVX2 + for (; i + 8 <= size; i += 8) { + __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]); + __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk); + } + + // Process any remaining elements + for (; i < size; ++i) { + MLAS_FP16 fp16(src_fp32[i]); + dst_fp16[i] = fp16.val; + } +} + MLAS_FORCEINLINE __m256 load_float_n_avx2(const float* data, int n) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index f2aaa75cadd8d..35f3b12aeba35 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -258,7 +258,7 @@ struct TensorCaster { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size); } }; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 548f24e8ac69e..fa7c6bce7c23e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -262,8 +262,8 @@ void RunTest(const TestOptions& opts, } // namespace -TEST(MatMulNBits, Float32) { - // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); +template +void TestMatMulNBitsTyped() { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */ 1, 2, 32, 288}) { for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { @@ -276,30 +276,53 @@ TEST(MatMulNBits, Float32) { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; + } else { + if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.01f; + } } { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) { TestOptions opts = base_opts; opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); + } + + { + TestOptions opts = base_opts; + opts.has_g_idx = true; + opts.has_bias = true; + if constexpr (std::is_same::value) { + if (opts.accuracy_level == 0 || opts.accuracy_level == 1) { + // CI failure (not able to repro on either local machines): + // M:100, N:288, K:1234, block_size:16, accuracy_level:0, has_zero_point:0, zp_is_4bit:1, has_g_idx:1, has_bias:1 + // The difference between cur_expected[i] and cur_actual[i] is 1.0401010513305664e-05, which exceeds tolerance, + // tolerance evaluates to 1.006456386676291e-05. + opts.output_abs_error = 0.0001f; + } + } + // only enabled for CPU EP for now + std::vector> explicit_eps; + explicit_eps.emplace_back(DefaultCpuExecutionProvider()); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) @@ -311,7 +334,7 @@ TEST(MatMulNBits, Float32) { std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } } } @@ -320,6 +343,21 @@ TEST(MatMulNBits, Float32) { } } +TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + TestMatMulNBitsTyped(); +} + +#ifdef MLAS_TARGET_AMD64_IX86 +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// Actual and expected difference is over 0.01 with DmlExecutionProvider. +// Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16) { + TestMatMulNBitsTyped(); +} +#endif +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) namespace { @@ -367,7 +405,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura } } // namespace -TEST(MatMulNBits, Float16) { +TEST(MatMulNBits, Float16Cuda) { #if defined(USE_CUDA) || defined(USE_ROCM) auto has_gidx_options = {true, false}; #else From c63dd0234b4e0236b24fabdca005bbeb75ff4eb9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 14 Sep 2024 12:36:20 +0800 Subject: [PATCH 21/26] [WebNN EP] Use opSupportLimits to dynamically check data type support (#22025) - Remove hard code data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoid the inconsistent input names in `unary_op_builder.cc`. --- .../core/providers/webnn/builders/helper.cc | 61 ++++++++++++++++--- .../core/providers/webnn/builders/helper.h | 43 +++++++++---- .../builders/impl/activation_op_builder.cc | 40 ------------ .../builders/impl/argmax_min_op_builder.cc | 27 -------- .../webnn/builders/impl/base_op_builder.cc | 52 ++++++++-------- .../webnn/builders/impl/base_op_builder.h | 9 ++- .../webnn/builders/impl/binary_op_builder.cc | 36 +++-------- .../webnn/builders/impl/cast_op_builder.cc | 32 +++++----- .../webnn/builders/impl/clip_op_builder.cc | 29 --------- .../webnn/builders/impl/concat_op_builder.cc | 28 +++++++++ .../webnn/builders/impl/conv_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gather_op_builder.cc | 26 +++----- .../webnn/builders/impl/gemm_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gru_op_builder.cc | 40 ++++-------- .../webnn/builders/impl/logical_op_builder.cc | 42 +++++++------ .../webnn/builders/impl/max_min_op_builder.cc | 29 ++++----- .../builders/impl/normalization_op_builder.cc | 35 ++++------- .../webnn/builders/impl/pad_op_builder.cc | 27 -------- .../builders/impl/reduction_op_builder.cc | 52 ---------------- .../webnn/builders/impl/resize_op_builder.cc | 26 -------- .../webnn/builders/impl/shape_op_builder.cc | 27 -------- .../webnn/builders/impl/slice_op_builder.cc | 26 -------- .../webnn/builders/impl/softmax_op_builder.cc | 26 -------- .../webnn/builders/impl/ternary_op_builder.cc | 23 ++----- .../builders/impl/transpose_op_builder.cc | 27 -------- .../webnn/builders/impl/unary_op_builder.cc | 43 ------------- .../providers/webnn/builders/model_builder.cc | 7 ++- .../providers/webnn/builders/model_builder.h | 5 +- .../providers/webnn/builders/op_builder.h | 3 +- .../webnn/builders/op_builder_factory.cc | 2 +- .../webnn/webnn_execution_provider.cc | 22 ++++--- .../webnn/webnn_execution_provider.h | 1 + 32 files changed, 281 insertions(+), 635 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..c4a633fcc92bb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector& shape, const loggin return true; } -bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, - const WebnnDeviceType device_type, const logging::Logger& logger) { +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { const auto& op_builders = GetOpBuilders(); if (Contains(op_builders, node.OpType())) { const auto* op_builder = op_builders.at(node.OpType()); - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger); } else { return false; } @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -105,7 +106,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v // Firstly check if platform supports the WebNN op. if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; - supported = IsNodeSupported(*node, graph_viewer, device_type, logger); + supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() @@ -130,10 +131,54 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types) { - return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != - supported_data_types.end(); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger) { + for (size_t i = 1; i < input_types.size(); i++) { + if (input_types[0] != input_types[i]) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same, but [" + << input_types[0] << "] does not match " + << input_types[i] << "]."; + return false; + } + } + return true; +} + +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { + auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); + if (it == onnx_to_webnn_data_type_map.end()) + return false; + + std::string webnn_data_type = it->second; + + // Check if WebNN supports the data type. + emscripten::val is_supported = webnn_supported_data_types.call("includes", + emscripten::val(webnn_data_type)); + return is_supported.as(); +} + +// Check if the input or output data type of ONNX node is supported by the WebNN operator. +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + std::string webnn_op_type; + if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) + return false; + + if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type + << "] " << onnx_input_output_name + << " type: [" << onnx_data_type + << "] is not supported for now"; + return false; + } + + return true; } bool GetBidirectionalBroadcastShape(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index b51092619db22..257fcff9ef50c 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, @@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -static const std::unordered_set webnn_supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, +inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) { + auto it = op_map.find(op_type); + // Returns false if the op_type is not listed in the op_map. + if (it == op_map.end()) { + return false; + } + webnn_op_type = it->second; + return true; +} + +static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, }; -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger); +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 626aaf5c71b74..781ddcb896155 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } -bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - // WebNN relu op supports float32, float16, int32, int8 input data types. - if (op_type == "Relu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend does not support int32 data type for relu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { // Others only support float32 and float16. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 05f3a742a3775..d61ae1a1f6be7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index fa535889299ea..8da255a288f17 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { ORT_RETURN_IF_NOT( - IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger), - "Unsupported operator ", - node.OpType()); + IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), + model_builder.GetOpSupportLimits(), logger), + "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& // Operator support related. bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, device_type, logger)) + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, wnn_limits, logger)) + return false; + + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; // We do not support external initializers for now. @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } - // WebNN CPU backend (TFLite) will enable float16 input data type soon, - // temporarily fallback float16 input data type for WebNN CPU. - if (device_type == WebnnDeviceType::CPU) { - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) - return false; - } - - return HasSupportedInputsImpl(node, device_type, logger); + return HasSupportedInputsImpl(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType /* device_type */, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. const auto& input = *node.InputDefs()[0]; - + const auto& op_type = node.OpType(); int32_t input_type; if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); +} + +bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + // We only check the type of output 0 by default, specific op builder can override this. + const auto& output = *node.OutputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) return false; - } - return true; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 85e38b668cee4..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related. public: bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; protected: virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, @@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const; // ONNX Runtime only *guarantees* support for models stamped // with opset version 7 or above for opset domain 'ai.onnx'. @@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 555de68cd60fe..af82a01b14de5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice !GetType(*input_defs[1], input1_type, logger)) return false; - std::unordered_set supported_data_types; - // WebNN prelu op only supports float32, float16, int32, int8 input data types. - if (op_type == "Prelu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend doesn't support int32 for prelu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { - supported_data_types = webnn_supported_data_types; - } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; - } - - return true; + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a08e1681a8464..3c4fc822f3d01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input_type; -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType device_type, - const logging::Logger& logger) const { - NodeAttrHelper helper(node); - // Check cast output type. - const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - - // WebNN CPU backend doesn't support casting to uint64 data type. - if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) { - LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend."; + if (!GetType(*input_defs[0], input_type, logger)) return false; - } - if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << "."; + + if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger)) return false; - } - return true; + NodeAttrHelper helper(node); + // Check cast to type. + const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); + return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger); } void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index b5c3206072d50..374143c886849 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, }; } -bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index dedc76b80e978..48dd6f3beb020 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + + if (!GetType(*input_defs[0], input0_type, logger)) + return false; + + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } + + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); +} + void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..35498c2e9b8b7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Conv" || op_type == "ConvTranspose") { - // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "ConvInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 23233539d34c7..ae9fe3e3f3bd1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t indices_type; + if (!GetType(input, input_type, logger) || + !GetType(indices, indices_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index bd452b118fe3e..30e024792ed42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Gemm" || op_type == "MatMul") { - // WebNN gemm and matmul only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "MatMulInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 23cc7f1b11459..c92fe7366d494 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp return false; } - std::unordered_set supported_data_types; - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend only support float32 input data type. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - }; - } else if (device_type == WebnnDeviceType::GPU) { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; + InlinedVector input_types = {input0_type, input1_type, input2_type}; + if (has_input3) { + input_types.push_back(input3_type); } - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input4) { + input_types.push_back(input4_type); } - - if (input0_type != input1_type || - input0_type != input2_type || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type) || - (has_input5 && input0_type != input5_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input5) { + input_types.push_back(input5_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 23f3a938fee5e..ea7f70b4598e6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder { Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input1 = emscripten::val::undefined(); + if (input_defs.size() > 1) { + input1 = model_builder.GetOperand(input_defs[1]->Name()); + } + emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); @@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); + } else if (op_type == "Not") { + output = model_builder.GetBuilder().call("logicalNot", input0, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { + if (input_defs.size() < 2 && op_type != "Not") { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " << input_defs.size(); return false; @@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) - return false; - - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + if (!GetType(*input_defs[0], input0_type, logger)) return false; - } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + if (op_type != "Not") { + if (!GetType(*input_defs[1], input1_type, logger)) + return false; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + std::string onnx_input_name = op_type == "Not" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { @@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& "GreaterOrEqual", "Less", "LessOrEqual", + "Not", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 5d88afda7b6a7..e111ca412c6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; - int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) + if (!GetType(*input_defs[0], input0_type, logger)) return false; - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; - } + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..a3c6b8fdcea9b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn return false; } - // WebNN batchNormalization, instanceNormalization, layerNormalization - // only support float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 071155a2fb372..d8373a45e4423 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn -bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 3e6d4d9820e9a..93ad933d71c34 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } -bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "ReduceL1" || op_type == "ReduceProd" || - op_type == "ReduceSum" || op_type == "ReduceSumSquare") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, - }; - - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend doesn't support uint32 and uint64 for reduceL1, - // reduceProd, reduceSum and reduceSumSquare. - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || - op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else { // ReduceMax and ReduceMin - supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..9dc79f4f52f46 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Helper functions @@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN resample2d op only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 0eb7dafdffe4d..6b56d2c740f40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. - -bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Output type: [" << output_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index bef13841c646c..3f0d633ac888b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 input data type for slice. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 798cfabae65db..b1b737b114998 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN softmax only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 2ed8330bf25be..4b6cf312074ba 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic !GetType(*input_defs[2], input2_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 X, Y data type for where. - if (device_type == WebnnDeviceType::CPU && op_type == "Where") { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } // ONNX's condition data type is bool which is same as WebNN. // Only need to check X, Y data types. - if (!IsSupportedDataType(input1_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input1_type - << "] is not supported for now"; - return false; - } - - if (input1_type != input2_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input X, Y data types should be the same."; + std::array input_types{input1_type, input2_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 03c88ad9db88a..3a5e39f7f7a56 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 061404c8a9ce0..8e64e98445f03 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Add operator related. @@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { @@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "Identity") { - supported_data_types = webnn_supported_data_types; - } else if (op_type == "Abs" || op_type == "Neg") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - } else if (op_type == "Not") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - }; - } else { // Others only support float32, float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Identity", "Log", "Neg", - "Not", "Reciprocal", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..b58bf8233692e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -21,12 +21,13 @@ namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) { + wnn_device_type_(wnn_device_type), + wnn_limits_(wnn_limits) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); @@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, webnn_supported_data_types)) { + if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..256337baeba7e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -23,7 +23,7 @@ class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -35,6 +35,8 @@ class ModelBuilder { const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } + void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. @@ -66,6 +68,7 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; + emscripten::val wnn_limits_ = emscripten::val::undefined(); InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h index 6ecc5d1068963..bb69a6a545597 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -29,7 +29,8 @@ class IOpBuilder { public: // Check if an operator is supported. virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const = 0; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const = 0; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 01761290f07e3..3dc1c7966ae41 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Identity", op_registrations); CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); - CreateUnaryOpBuilder("Not", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); @@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); + CreateLogicalOpBuilder("Not", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..b729623c5d3d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } + + // Retrieve the level of support for different WebNN operators. + // This varies across implementations and is obtained via the WebNN's opSupportLimits() function. + // https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits + wnn_limits_ = wnn_context_.call("opSupportLimits"); + + if (wnn_limits_["preferredInputLayout"].as().compare("nhwc") == 0) { + preferred_layout_ = DataLayout::NHWC; + } else { + preferred_layout_ = DataLayout::NCHW; + } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { @@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); @@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector Date: Sun, 15 Sep 2024 18:31:55 -0400 Subject: [PATCH 22/26] [java] Adding ability to load a model from a memory mapped byte buffer (#20062) ### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory mapped from files which means there doesn't need to be copies of the model protobuf held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage on model construction by not requiring as many copies on the Java side. Should help with #19599. --- .../java/ai/onnxruntime/OrtEnvironment.java | 49 ++++++++++++++++++- .../main/java/ai/onnxruntime/OrtSession.java | 35 +++++++++++++ .../main/native/ai_onnxruntime_OrtSession.c | 25 +++++++++- .../java/ai/onnxruntime/InferenceTest.java | 31 ++++++++++++ 4 files changed, 138 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8fe73ff69e169..f87cbc76ef141 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -11,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + *

Must be a direct byte buffer. + * + * @param env The environment. + * @param modelBuffer The model protobuf as a byte buffer. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. + */ + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); + } + /** * Private constructor to build the Java object wrapped around a native session. * @@ -514,6 +540,15 @@ private static native long createSession( private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index f4d5ab080cd31..ee8cdee659296 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la return (jlong)session; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: createSession + * Signature: (JJLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the session + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session)); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: createSession diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 3340a2e5e9f3a..f76e1b3b20e19 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,10 +20,14 @@ import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; import java.io.File; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException { } } + @Test + public void createSessionFromByteBuffer() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r"); + FileChannel channel = file.getChannel()) { + MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size()); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBuffer, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } + } + } + } + @Test public void createSessionFromByteArray() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); From 6d7235ba5ab995e42a0e251874e65e9d7eaa2997 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 15 Sep 2024 21:55:38 -0400 Subject: [PATCH 23/26] [Java] Exposing SessionOptions.SetDeterministicCompute (#18998) ### Description Exposes `SetDeterministicCompute` in Java, added to the C API by #18944. ### Motivation and Context Parity between C and Java APIs. --- .../main/java/ai/onnxruntime/OrtSession.java | 17 +++++++++++++++++ .../ai_onnxruntime_OrtSession_SessionOptions.c | 13 +++++++++++++ .../test/java/ai/onnxruntime/InferenceTest.java | 1 + 3 files changed, 31 insertions(+) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index f87cbc76ef141..6d146d5857d3c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -942,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } + /** + * Set whether to use deterministic compute. + * + *

Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. + * + * @param value Should the compute be deterministic? + * @throws OrtException If there was an error in native code. + */ + public void setDeterministicCompute(boolean value) throws OrtException { + checkClosed(); + setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value); + } + /** * Disables the per session thread pools. Must be used in conjunction with an environment * containing global thread pools. @@ -1327,6 +1341,9 @@ private native void registerCustomOpsUsingFunction( private native void closeOptions(long apiHandle, long nativeHandle); + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; + private native void addFreeDimensionOverrideByName( long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff9348c299e90..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setDeterministicCompute + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: registerCustomOpLibrary diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f76e1b3b20e19..11141a3a65a3e 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1263,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException { options.setLoggerId("monkeys"); options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); options.setSessionLogVerbosityLevel(5); + options.setDeterministicCompute(true); Map configEntries = options.getConfigEntries(); assertTrue(configEntries.isEmpty()); options.addConfigEntry("key", "value"); From 1a1669fe817232e7d19c6459da0fc610e0c74b0a Mon Sep 17 00:00:00 2001 From: George Wu Date: Mon, 16 Sep 2024 09:12:13 -0700 Subject: [PATCH 24/26] use node name in transpose optimizer when adding nodes rather than optype (#22084) patch from @john-dance "The main change is simple: Use the original node name rather than the original node op_type when creating new nodes. Here are my comments on the change: ------ The onnx runtime uses the op_type as the basis for a new node name, so a node claimed by QNN EP might be named Conv_token_1 with no relation to the original /conv1/Conv. This patch: 1. Adds OpName as a virtual function in NodeRef and implements it in ApiNode. 2. AddNode now takes an op_name and op_type and passes them both to CreateNodeHelper. 3. CreateNodeHelper uses the op_name rather than the op_type in GenerateNodeName 4. Direct calls to AddNode are modified to either use the NodeRef if available, or just repeat the op_type if not available. The result is that the new nodes are named something like /conv1/Conv_token_1, allowing a straight forward mapping back to the original model node (if they exist in the original graph)." --- .../onnx_transpose_optimization.cc | 18 +++++++++--------- .../transpose_optimization/optimizer_api.h | 6 +++++- .../ort_optimizer_api_impl.cc | 17 +++++++++++------ .../internal_testing/internal_testing_tests.cc | 4 ++-- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index df81367c5bbee..5d689a9d933e8 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -78,7 +78,7 @@ static std::unique_ptr MakeNode1Attr(api::GraphRef& graph, std::st std::string_view input, std::string_view attr_name, const std::vector& attr_val) { std::vector inputs{input}; - std::unique_ptr node = graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + std::unique_ptr node = graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); node->SetAttributeInts(attr_name, attr_val); return node; } @@ -102,7 +102,7 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: std::vector inputs{input, axes_initializer}; - return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + return graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); } ///

@@ -136,7 +136,7 @@ static std::unique_ptr MakeQuantizeOp(api::GraphRef& graph, std::s std::optional block_size, std::optional output_dtype, std::optional saturate) { - std::unique_ptr node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("QuantizeLinear", "QuantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -170,7 +170,7 @@ static std::unique_ptr MakeDequantizeOp(api::GraphRef& graph, std: std::vector inputs, std::optional axis, std::optional block_size) { - std::unique_ptr node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("DequantizeLinear", "DequantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -1724,7 +1724,7 @@ static bool HandleShape(HandlerArgs& args) { // X -> Shape -> Y, Gather std::vector gather_inputs{"", perm_const}; - auto gather_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; gather.SetAttributeInt("axis", 0); @@ -1767,7 +1767,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con // inputs that would never be quantized. std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); std::vector gather_inputs{input_name, gather_indices_const}; - auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; graph.CopyValueInfo(input_name, gather_output); @@ -2215,7 +2215,7 @@ static bool HandleTile(HandlerArgs& args) { // Case 2: Repeats is computed. Insert Gather node. std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); std::vector gather_inputs{repeats_inp, perm_inv_const}; - auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; args.ctx.graph.CopyValueInfo(repeats_inp, gather_output); @@ -2265,7 +2265,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) { // Worst-case scenario: Both parent output and 2nd transpose/reshape output cannot be removed (both graph outputs) // despite computing the same value. Use an Identity op instead. std::vector single_empty_input{""}; - auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1); + auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); @@ -2297,7 +2297,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector& n // replace Reshape with Transpose to simplify the logic. // use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node, // and remove the Reshape node. - new_node = args.ctx.graph.AddNode("Transpose", {args.transpose.Inputs()[0]}, 1); + new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1); args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0); args.ctx.graph.RemoveNode(args.node); } else { diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 211734f4bacc8..7122aec45e61a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -146,6 +146,9 @@ class ValueInfoRef { /// class NodeRef { public: + /// Node name + virtual std::string_view Name() const = 0; + /// Op computed by the node virtual std::string_view OpType() const = 0; @@ -361,6 +364,7 @@ class GraphRef { /// generated. Outputs of created node have unspecified shapes/dtypes. They will be populated afterwards using /// CopyValueInfo. /// + /// The new node's name /// The new node's op type /// Inputs for the node. "" for missing optional inputs. /// @@ -368,7 +372,7 @@ class GraphRef { /// /// The new node's domain. Empty string signifies default onnx domain. /// The new node - virtual std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + virtual std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain = /*kOnnxDomain*/ "") = 0; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 33408474f92a6..f87df746234fa 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -80,6 +80,10 @@ class ApiNode final : public api::NodeRef { return node_; } + std::string_view Name() const override { + return node_.Name(); + } + std::string_view OpType() const override { return node_.OpType(); } @@ -134,7 +138,7 @@ class ApiGraph final : public api::GraphRef { std::unique_ptr GetNodeProducingOutput(std::string_view name) const override; void TransposeInitializer(std::string_view name, const std::vector& perm) override; void ReshapeInitializer(std::string_view name, const std::vector& shape) override; - std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs = 1, std::string_view domain = "") override; std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, @@ -621,11 +625,12 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectorSetShape(new_shape); } -static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type, +static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain, int since_version, std::string_view node_ep) { const std::string op_type_str(op_type); - std::string name = graph.GenerateNodeName(op_type_str); + const std::string op_name_str(op_name); + std::string name = graph.GenerateNodeName(op_name_str); std::vector input_args; std::vector output_args; @@ -731,11 +736,11 @@ static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view do return *since_version; } -std::unique_ptr ApiGraph::AddNode(std::string_view op_type, +std::unique_ptr ApiGraph::AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain) { int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap()); - Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs, + Node& node = CreateNodeHelper(graph_, name, op_type, inputs, num_outputs, domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : ""); return std::make_unique(node, graph_); @@ -744,7 +749,7 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain, std::optional since_version) { const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion(); - Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), + Node& node = CreateNodeHelper(graph_, source_node.Name(), op_type, source_node.Inputs(), source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 9f7be524daa34..67fb35d26e6dc 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -196,7 +196,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { // Error message should come from the Conv implementation with the statically registered kernel ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); } @@ -242,7 +242,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { std::vector fetches; ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); }; From e93f14e00d09b0c62ba0869bc87f14ee5f1cf4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Mu=C3=B1oz?= Date: Mon, 16 Sep 2024 10:20:06 -0600 Subject: [PATCH 25/26] Check partial conversion on FP16 to FP32 AVX Cast kernel (#22091) ### Description Added checks to convert partial vectors in the early stages of the FP16 to FP32 cast using AVX NE CONVERT ISA. ### Motivation and Context Avoid storing data in sections outside of the output buffer, these checks are missing on the [original PR](https://github.com/microsoft/onnxruntime/pull/21183). This fix prevents memory corruption when the output buffer has a size [n*16 + 1, n*16 + 7] with 0< n --- onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm | 4 +++- onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm index c7f6342c527bf..800863c77a230 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT - test r8, r8 ; Check if we have any elements to convert + test r8, r8 ; Check if we have any elements to convert jz ExitRoutine cmp r8, 8 jb ConvertMaskedVectors @@ -80,6 +80,8 @@ Convert256Vectors: jz ExitRoutine ; If we are done, exit cmp r8, 16 ; If the vector is big enough, we go again jae Convert256Vectors + cmp r8, 8 ; Check if we have enough elements to convert + jb ConvertMaskedVectors diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S index 1a70061460e50..a4d730fa513ab 100644 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx test rdx, rdx // Check if we have any elements to convert jz ExitRoutine - -AVX_NE_CONVERT: cmp rdx, 8 jb ConvertMaskedVectors cmp rdx, 16 @@ -75,6 +73,8 @@ Convert256Vectors: jz ExitRoutine // If we are done, exit cmp rdx, 16 // If the vector is big enough, we go again jae Convert256Vectors + cmp rdx, 8 // Check if we have enough elements to convert + jb ConvertMaskedVectors From 291a5352b27ded5714e5748b381f2efb88f28fb9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:56:22 -0700 Subject: [PATCH 26/26] [js/web] remove training release (#22103) ### Description Remove training from onnxruntime-web Following up of #22082 --- js/web/lib/backend-wasm-inference.ts | 5 - js/web/lib/backend-wasm-training.ts | 29 - js/web/lib/backend-wasm.ts | 2 + js/web/lib/index.ts | 4 +- js/web/lib/wasm/session-handler-training.ts | 198 ------ js/web/lib/wasm/wasm-core-impl.ts | 9 +- js/web/lib/wasm/wasm-training-core-impl.ts | 631 ------------------ js/web/lib/wasm/wasm-types.ts | 76 +-- js/web/lib/wasm/wasm-utils-import.ts | 16 +- js/web/package.json | 7 - js/web/script/build.ts | 13 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 2 - js/web/test/training/e2e/browser-test-wasm.js | 21 - js/web/test/training/e2e/common.js | 248 ------- js/web/test/training/e2e/data/model.onnx | 16 - js/web/test/training/e2e/karma.conf.js | 54 -- js/web/test/training/e2e/package.json | 14 - js/web/test/training/e2e/run.js | 143 ---- .../test/training/e2e/simple-http-server.js | 67 -- js/web/types.d.ts | 4 - 20 files changed, 15 insertions(+), 1544 deletions(-) delete mode 100644 js/web/lib/backend-wasm-inference.ts delete mode 100644 js/web/lib/backend-wasm-training.ts delete mode 100644 js/web/lib/wasm/session-handler-training.ts delete mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts delete mode 100644 js/web/test/training/e2e/browser-test-wasm.js delete mode 100644 js/web/test/training/e2e/common.js delete mode 100644 js/web/test/training/e2e/data/model.onnx delete mode 100644 js/web/test/training/e2e/karma.conf.js delete mode 100644 js/web/test/training/e2e/package.json delete mode 100644 js/web/test/training/e2e/run.js delete mode 100644 js/web/test/training/e2e/simple-http-server.js diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts deleted file mode 100644 index 7dfe7ee05a1d3..0000000000000 --- a/js/web/lib/backend-wasm-inference.ts +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts deleted file mode 100644 index 7332b3f97eba0..0000000000000 --- a/js/web/lib/backend-wasm-training.ts +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; - -class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); - await handler.createTrainingSession( - checkpointStateUriOrBuffer, - trainModelUriOrBuffer, - evalModelUriOrBuffer, - optimizerModelUriOrBuffer, - options, - ); - return Promise.resolve(handler); - } -} - -export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 7bef538b26063..766937dc4c4cf 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } + +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 321394466b365..776c0d026bc97 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING - ? require('./backend-wasm-inference').wasmBackend - : require('./backend-wasm-training').wasmBackend; + const wasmBackend = require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts deleted file mode 100644 index 8bbfb9cf06668..0000000000000 --- a/js/web/lib/wasm/session-handler-training.ts +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; -import { copyFromExternalBuffer } from './wasm-core-impl'; -import { - createCheckpointHandle, - createTrainingSessionHandle, - getContiguousParameters, - getModelInputOutputNames, - getParametersSize, - lazyResetGrad, - loadParametersBuffer, - releaseTrainingSessionAndCheckpoint, - runEvalStep, - runOptimizerStep, - runTrainStep, -} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - evalInputNames: string[] = []; - evalOutputNames: string[] = []; - - async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return copyFromExternalBuffer(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ) { - const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableInternalBuffer = [0, 0]; - let optimizerModelData: SerializableInternalBuffer = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = createTrainingSessionHandle( - this.checkpointId, - trainModelData, - evalModelData, - optimizerModelData, - options, - ); - [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); - if (evalModelUriOrBuffer !== '') { - [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); - } - } - - /** - * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the - * corresponding name as a number referring to the index in the list of names provided. - * - * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType - * @param names either inputNames or outputNames - * @returns a tuple of a list of values and a list of indices. - */ - convertMapIntoValuesArrayAndIndicesArray( - feeds: { [name: string]: T }, - names: string[], - mapFunc: (val: T, index: number) => U, - ): [T[], number[], U[]] { - const values: T[] = []; - const indices: number[] = []; - Object.entries(feeds).forEach((kvp) => { - const name = kvp[0]; - const tensor = kvp[1]; - const index = names.indexOf(name); - if (index === -1) { - throw new Error(`invalid input '${name}`); - } - values.push(tensor); - indices.push(index); - }); - - const uList = values.map(mapFunc); - return [values, indices, uList]; - } - - /** - * Helper method that converts the TensorMetadata that the wasm-core functions return to the - * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the - * corresponding result. - * - * @param results used to populate the resultMap if there is no value for that outputName already - * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results - * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. - * @returns a map of output names and OnnxValues. - */ - convertTensorMetadataToReturnType( - results: TensorMetadata[], - outputArray: Array, - outputIndices: number[], - ): SessionHandler.ReturnType { - const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results.length; i++) { - resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); - } - return resultMap; - } - - async lazyResetGrad(): Promise { - await lazyResetGrad(this.sessionId); - } - - async runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.outputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async runOptimizerStep(options: InferenceSession.RunOptions): Promise { - await runOptimizerStep(this.sessionId, options); - } - - async runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async getParametersSize(trainableOnly: boolean): Promise { - return getParametersSize(this.sessionId, trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { - await loadParametersBuffer(this.sessionId, array, trainableOnly); - } - async getContiguousParameters(trainableOnly: boolean): Promise { - const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); - return decodeTensorMetadata(tensorResult); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); - } -} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ed001cfa90f59 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file'; * Refer to web/lib/index.ts for the backend registration. * * 2. WebAssembly artifact initialization. - * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or - * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings: + * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is + * called). In this step, onnxruntime-web does the followings: * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated * JavaScript code to initialize the WebAssembly runtime. @@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file'; * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. * * 4. Session initialization. - * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 - * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the - * followings: + * This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once), + * this step will be done for each session. In this step, onnxruntime-web does the followings: * If the parameter is a URL: * - download the model data from the URL. * - copy the model data to the WASM heap. (proxy: 'copy-from') diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts deleted file mode 100644 index 22cd6ec30732c..0000000000000 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, Tensor } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { setRunOptions } from './run-options'; -import { setSessionOptions } from './session-options'; -import { - dataLocationStringToEnum, - tensorDataTypeEnumToString, - tensorDataTypeStringToEnum, - tensorTypeToTypedArrayConstructor, -} from './wasm-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; -import { getInstance } from './wasm-factory'; -import { checkLastError } from './wasm-utils'; - -const NO_TRAIN_FUNCS_MSG = - "Built without training API's enabled. Use the onnxruntime-web/training import for training " + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; - -/** - * Runs the checkLastError function which will throw an error, if the provided error code matches the specified - * pattern for an error code. - * @param errCode number to evaluated for if it's an error - * @param message message to pass into checkLastError - * @param checkNeqZero when true, treats not equal to zero as an error. - * When false, treats equal to zero as an error. - */ -const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { - if (checkNeqZero && errCode !== 0) { - checkLastError(message); - } else if (!checkNeqZero && errCode === 0) { - checkLastError(message); - } -}; - -export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { - const wasm = getInstance(); - - const [checkpointDataOffset, checkpointDataLength] = checkpointData; - let checkpointHandle = 0; - - try { - if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); - return checkpointHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { - wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); - } - throw e; - } finally { - // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); - } -}; - -const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = wasm._OrtTrainingGetModelInputOutputCount( - trainingSessionId, - dataOffset, - dataOffset + 4, - isEvalModel, - ); - ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getModelInputOutputNamesLoop = ( - trainingSessionId: number, - count: number, - isInput: boolean, - isEvalModel: boolean, -): string[] => { - const names = []; - const wasm = getInstance(); - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - - names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; -}; - -export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { - let inputNames: string[] = []; - let outputNames: string[] = []; - - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - - inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); - outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - - return [inputNames, outputNames]; -}; - -export const createTrainingSessionHandle = ( - checkpointHandle: number, - trainModelData: SerializableInternalBuffer, - evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, - options: InferenceSession.SessionOptions, -): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, - checkpointHandle, - trainModelData[0], - trainModelData[1], - evalModelData[0], - evalModelData[1], - optimizerModelData[0], - optimizerModelData[1], - ); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach((alloc) => wasm._free(alloc)); - } -}; - -/** - * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the - * WASM tensors. - * - * @param trainingSessionId - * @param indices for each tensor, the index of the input or output name that the tensor corresponds with - * @param tensors list of TensorMetaData - * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting - * handles of the allocated tensors on the heap - * @param inputOutputAllocs modified in-place by this method - * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor - */ -const createAndAllocateTensors = ( - trainingSessionId: number, - indices: number[], - tensors: Array, - tensorHandles: number[], - inputOutputAllocs: number[], - indexAdd: number, -) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } - - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } - - return valuesOffset; -}; - -/** - * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information - * associated with the tensor handle. - * - * @param outputValuesOffset - * @param outputCount - * @returns list of TensorMetadata retrieved from the output handles. - */ -const moveOutputToTensorMetadataArr = ( - outputValuesOffset: number, - outputCount: number, - outputTensorHandles: number[], - outputTensors: Array, -) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type | undefined, - dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, - tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, - ); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), - ); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); - } - } - - return output; -}; - -export const lazyResetGrad = async (trainingSessionId: number): Promise => { - const wasm = getInstance(); - - if (wasm._OrtTrainingLazyResetGrad) { - const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } -}; - -export const runTrainStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingRunTrainStep) { - const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runOptimizerStep = async ( - trainingSessionId: number, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - if (wasm._OrtTrainingOptimizerStep) { - const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); - ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runEvalStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingEvalStep) { - const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -export const getContiguousParameters = async ( - trainingSessionId: number, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - let tensor = 0; - - // allocates a buffer of the correct size on the WASM heap - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm._malloc(paramsByteLength); - - // handles the dimensions-related createTensor parameters - const dims = [parametersSize]; - - const dimsOffset = wasm.stackAlloc(4); - const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - - try { - // wraps allocated array in a tensor - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - paramsOffset, - paramsByteLength, - dimsOffset, - dims.length, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError( - tensor, - `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, - false, - ); - - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), - ); - output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length !== 1) { - throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of - one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm.stackRestore(stack); - } -}; - -export const loadParametersBuffer = async ( - trainingSessionId: number, - buffer: Uint8Array, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - // allocates & copies JavaScript buffer to WASM heap - const bufferByteLength = buffer.length; - const bufferCount = bufferByteLength / 4; - const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(buffer, bufferOffset); - - // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor = 0; - - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - bufferOffset, - bufferByteLength, - dimsOffset, - dimsLength, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferOffset); - wasm._free(dimsOffset); - } -}; - -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } -}; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..828cd3cfd94fa 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -213,84 +213,10 @@ export interface OrtInferenceAPIs { _OrtEndProfiling(sessionHandle: number): number; } -export interface OrtTrainingAPIs { - _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - - _OrtTrainingCreateSession( - sessionOptionsHandle: number, - checkpointHandle: number, - trainOffset: number, - trainLength: number, - evalOffset: number, - evalLength: number, - optimizerOffset: number, - optimizerLength: number, - ): number; - - _OrtTrainingLazyResetGrad(trainingHandle: number): number; - - _OrtTrainingRunTrainStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - - _OrtTrainingEvalStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - - _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, - inputCount: number, - outputCount: number, - isEvalModel: boolean, - ): number; - _OrtTrainingGetModelInputOutputName( - trainingHandle: number, - index: number, - isInput: boolean, - isEvalModel: boolean, - ): number; - - _OrtTrainingReleaseSession(trainingHandle: number): void; -} - /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 008b9b41b1592..bd9e0ce083ef0 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires require( - !BUILD_DEFS.DISABLE_TRAINING - ? '../../dist/ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -163,11 +161,9 @@ export const importWasmModule = async ( if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING - ? 'ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/package.json b/js/web/package.json index 94dd047915b05..d770499adada4 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -23,7 +23,6 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -101,12 +100,6 @@ "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" - }, - "./training": { - "node": null, - "import": "./dist/ort.training.wasm.min.mjs", - "require": "./dist/ort.training.wasm.min.js", - "types": "./types.d.ts" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6d1b3bdb65068..408f9e00a5cbd 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -56,7 +56,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', - 'BUILD_DEFS.DISABLE_TRAINING': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false', 'BUILD_DEFS.IS_ESM': 'false', @@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort[-training]-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep].mjs */ async function buildOrt({ isProduction = false, @@ -630,16 +629,6 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, }); - // ort.training.wasm[.min].[m]js - await addAllWebBuildTasks({ - outputName: 'ort.training.wasm', - define: { - ...DEFAULT_DEFINE, - 'BUILD_DEFS.DISABLE_TRAINING': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'true', - 'BUILD_DEFS.DISABLE_WEBGL': 'true', - }, - }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index b1b2fa26b2351..5b8b0d27c88db 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -149,11 +149,9 @@ downloadJson( void jszip.loadAsync(buffer).then((zip) => { extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); }); }, diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js deleted file mode 100644 index 05750ed149303..0000000000000 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -describe('Browser E2E testing for training package', function () { - it('Check that training package encompasses inference', async function () { - ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, all options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, minimum options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); - }); -}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js deleted file mode 100644 index 0574ae85aabd1..0000000000000 --- a/js/web/test/training/e2e/common.js +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const DATA_FOLDER = 'data/'; -const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; -const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; -const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; -const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; - -const trainingSessionAllOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, - evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, -}; - -const trainingSessionMinOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, -}; - -// ASSERT METHODS - -function assert(cond) { - if (!cond) throw new Error(); -} - -function assertStrictEquals(actual, expected) { - if (actual !== expected) { - let strRep = actual; - if (typeof actual === 'object') { - strRep = JSON.stringify(actual); - } - throw new Error(`expected: ${expected}; got: ${strRep}`); - } -} - -function assertTwoListsUnequal(list1, list2) { - if (list1.length !== list2.length) { - return; - } - for (let i = 0; i < list1.length; i++) { - if (list1[i] !== list2[i]) { - return; - } - } - throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); -} - -// HELPER METHODS FOR TESTS - -function generateGaussianRandom(mean = 0, scale = 1) { - const u = 1 - Math.random(); - const v = Math.random(); - const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); - return z * scale + mean; -} - -function generateGaussianFloatArray(length) { - const array = new Float32Array(length); - - for (let i = 0; i < length; i++) { - array[i] = generateGaussianRandom(); - } - - return array; -} - -/** - * creates the TrainingSession and verifies that the input and output names of the training model loaded into the - * training session are correct. - * @param {} ort - * @param {*} createOptions - * @param {*} options - * @returns - */ -async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { - const trainingSession = await ort.TrainingSession.create(createOptions, options); - - assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); - assertStrictEquals(trainingSession.trainingInputNames.length, 2); - assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.trainingOutputNames.length, 1); - return trainingSession; -} - -/** - * verifies that the eval input and output names associated with the eval model loaded into the given training session - * are correct. - */ -function checkEvalModel(trainingSession) { - assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); - assertStrictEquals(trainingSession.evalInputNames.length, 2); - assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.evalOutputNames.length, 1); -} - -/** - * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if - * accessed - * @param {} trainingSession - */ -function checkNoEvalModel(trainingSession) { - try { - assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } - try { - assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } -} - -/** - * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length - * of 1 for the loss. - * @param {} trainingSession - * @param {*} feeds - * @returns - */ -var runTrainStepAndCheck = async function (trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); - assertStrictEquals(Object.keys(results).length, 1); - assertStrictEquals(results['onnx::loss::21273'].data.length, 1); - assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); - return results; -}; - -var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { - // make a float32 array that is filled with the constant - const newParams = new Float32Array(paramsLength); - for (let i = 0; i < paramsLength; i++) { - newParams[i] = constant; - } - - const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); - - await trainingSession.loadParametersBuffer(newParamsUint8); - const paramsAfterLoad = await trainingSession.getContiguousParameters(); - - // check that the parameters have changed - assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); - assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); - - // check that the parameters have changed to what they should be - for (let i = 0; i < paramsLength; i++) { - // round to the same number of digits (4 decimal places) - assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); - } - - return paramsAfterLoad; -}; - -// TESTS - -var testInferenceFunction = async function (ort, options) { - const session = await ort.InferenceSession.create('data/model.onnx', options || {}); - - const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - - const fetches = await session.run({ - a: new ort.Tensor('float32', dataA, [3, 4]), - b: new ort.Tensor('float32', dataB, [4, 3]), - }); - - const c = fetches.c; - - assert(c instanceof ort.Tensor); - assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); - assert(c.data[0] === 700); - assert(c.data[1] === 800); - assert(c.data[2] === 900); - assert(c.data[3] === 1580); - assert(c.data[4] === 1840); - assert(c.data[5] === 2100); - assert(c.data[6] === 2460); - assert(c.data[7] === 2880); - assert(c.data[8] === 3300); -}; - -var testTrainingFunctionMin = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); - checkNoEvalModel(trainingSession); - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - await runTrainStepAndCheck(trainingSession, feeds); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -}; - -var testTrainingFunctionAll = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); - checkEvalModel(trainingSession); - - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - const results = await runTrainStepAndCheck(trainingSession, feeds); - - await trainingSession.runOptimizerStep(feeds); - feeds = { 'input-0': input0, labels: labels }; - // check getContiguousParameters after optimizerStep -- that the parameters have been updated - const optimizedParams = await trainingSession.getContiguousParameters(); - assertTwoListsUnequal(originalParams.data, optimizedParams.data); - - const results2 = await runTrainStepAndCheck(trainingSession, feeds); - - // check that loss decreased after optimizer step and training again - assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -}; - -if (typeof module === 'object') { - module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; -} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx deleted file mode 100644 index 088124bd48624..0000000000000 --- a/js/web/test/training/e2e/data/model.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:b - -a -bc"MatMultest_matmul_2dZ -a -  - -Z -b -  - -b -c -  - -B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js deleted file mode 100644 index 74662b67676f7..0000000000000 --- a/js/web/test/training/e2e/karma.conf.js +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const args = require('minimist')(process.argv.slice(2)); -const SELF_HOST = !!args['self-host']; -const ORT_MAIN = args['ort-main']; -const TEST_MAIN = args['test-main']; -if (typeof TEST_MAIN !== 'string') { - throw new Error('flag --test-main= is required'); -} -const USER_DATA = args['user-data']; -if (typeof USER_DATA !== 'string') { - throw new Error('flag --user-data= is required'); -} - -module.exports = function (config) { - const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; - config.set({ - frameworks: ['mocha'], - files: [ - { pattern: distPrefix + ORT_MAIN }, - { pattern: './common.js' }, - { pattern: TEST_MAIN }, - { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, - { pattern: './data/*', included: false }, - ], - plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], - proxies: { - '/model.onnx': '/base/model.onnx', - '/data/': '/base/data/', - }, - client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, - reporters: ['mocha'], - captureTimeout: 120000, - reportSlowerThan: 100, - browserDisconnectTimeout: 600000, - browserNoActivityTimeout: 300000, - browserDisconnectTolerance: 0, - browserSocketTimeout: 60000, - hostname: 'localhost', - browsers: [], - customLaunchers: { - Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, - Chrome_no_threads: { - base: 'ChromeHeadless', - chromeDataDir: USER_DATA, - // TODO: no-thread flags - }, - Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, - }, - }); -}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json deleted file mode 100644 index 5f11a27de6dfc..0000000000000 --- a/js/web/test/training/e2e/package.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "devDependencies": { - "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", - "fs-extra": "^11.1.0", - "globby": "^13.1.3", - "karma": "^6.4.1", - "karma-chrome-launcher": "^3.1.1", - "karma-mocha": "^2.0.1", - "karma-mocha-reporter": "^2.2.5", - "light-server": "^2.9.1", - "minimist": "^1.2.7", - "mocha": "^10.2.0" - } -} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js deleted file mode 100644 index d12bcc7aa66ed..0000000000000 --- a/js/web/test/training/e2e/run.js +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const fs = require('fs-extra'); -const { spawn } = require('child_process'); -const startServer = require('./simple-http-server'); -const minimist = require('minimist'); - -// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file -// exists in its parent folder. -// here we use /build/js/e2e-training/ for the test - -const TEST_E2E_SRC_FOLDER = __dirname; -const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); -const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); -const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); -const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); -fs.emptyDirSync(TEST_E2E_RUN_FOLDER); -fs.emptyDirSync(NPM_CACHE_FOLDER); -fs.emptyDirSync(CHROME_USER_DATA_FOLDER); -fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); - -// training data to copy -const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); -const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); -const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); - -// always use a new folder as user-data-dir -let nextUserDataDirId = 0; -function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); - nextUserDataDirId++; - fs.emptyDirSync(dir); - return dir; -} - -// commandline arguments -const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; - -async function main() { - // find packed package - const { globbySync } = await import('globby'); - - const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); - - const PACKAGES_TO_INSTALL = []; - - if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { - PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); - } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { - throw new Error('multiple packages found for onnxruntime-common.'); - } - - const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); - if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { - throw new Error('cannot find exactly single package for onnxruntime-web.'); - } - PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); - - // we start here: - - // install dev dependencies - await runInShell(`npm install`); - - // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); - - // prepare training data - prepareTrainingDataByCopying(); - - console.log('==============================================================='); - console.log('Running self-hosted tests'); - console.log('==============================================================='); - // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({ hostInKarma: true }); - - console.log('==============================================================='); - console.log('Running not self-hosted tests'); - console.log('==============================================================='); - // test cases without self-host (ort hosted in cross origin) - const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); - try { - await testAllBrowserCases({ hostInKarma: false }); - } finally { - // close the server after all tests - await server.close(); - } -} - -async function testAllBrowserCases({ hostInKarma }) { - await runKarma({ hostInKarma, main: './browser-test-wasm.js' }); -} - -async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { - console.log('==============================================================='); - console.log(`Running karma with the following binary: ${ortMain}`); - console.log('==============================================================='); - const selfHostFlag = hostInKarma ? '--self-host' : ''; - await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain - } --test-main=${main} --user-data=${getNextUserDataDir()}`, - ); -} - -async function runInShell(cmd) { - console.log('==============================================================='); - console.log(' Running command in shell:'); - console.log(' > ' + cmd); - console.log('==============================================================='); - let complete = false; - const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); - childProcess.on('close', function (code) { - if (code !== 0) { - process.exit(code); - } else { - complete = true; - } - }); - while (!complete) { - await delay(100); - } -} - -async function delay(ms) { - return new Promise(function (resolve) { - setTimeout(function () { - resolve(); - }, ms); - }); -} - -function prepareTrainingDataByCopying() { - fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); - console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); -} - -main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js deleted file mode 100644 index ef9cced681cc8..0000000000000 --- a/js/web/test/training/e2e/simple-http-server.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -// this is a simple HTTP server that enables CORS. -// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework - -const http = require('http'); -const fs = require('fs'); -const path = require('path'); - -const getRequestData = (url, dir) => { - const pathname = new URL(url, 'http://localhost').pathname; - - let filepath; - let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) { - filepath = path.resolve(dir, pathname.substring(1)); - } else { - return null; - } - - if (filepath.endsWith('.wasm')) { - mimeType = 'application/wasm'; - } else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) { - mimeType = 'text/javascript'; - } else { - return null; - } - - return [filepath, mimeType]; -}; - -module.exports = function (dir, port) { - const server = http - .createServer(function (request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); - - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function (error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, { 'Content-Type': contentType }); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); - console.log(`Server running at http://localhost:${port}/`); - return server; -}; diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 735b6a89a2a86..b82248c0c83b8 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } - -declare module 'onnxruntime-web/training' { - export * from 'onnxruntime-web'; -}