From 6fe1a53c9b76e0f177eb7068f08bd00f78542c98 Mon Sep 17 00:00:00 2001 From: Alexey Smirnov Date: Mon, 26 Aug 2024 12:16:04 +0100 Subject: [PATCH] [NPUW] Add a new i4 pattern (#26212) 1) CWAI3 have been renamed to Reshape3 2) Added Reshape4 pattern 3) Supported 3-dim scale shapes in i4f16 unpack Co-authored-by: Dmitry Matveev --- .../plugin/npuw/partitioning/partitioning.cpp | 7 +- .../npuw/partitioning/patterns/dcoff.cpp | 90 ++++++++++++++++++- .../npuw/partitioning/patterns/dcoff.hpp | 9 +- .../intel_npu/src/plugin/npuw/util.cpp | 20 ++++- 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp index 1325fc1f1d1dd0..6fef5d8b6fdf94 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp @@ -1624,8 +1624,11 @@ void Partitioner::decompressionCutOff(const std::string& func_name) { // LLaMaGPTQ rewr.add_matcher(dcoff_mode, dcoff_type, std::ref(params_to)); - // Phi-3 4SymW16A/GPTQ - rewr.add_matcher(dcoff_mode, dcoff_type, std::ref(params_to)); + // Phi-3 4SymW16A + rewr.add_matcher(dcoff_mode, dcoff_type, std::ref(params_to)); + + // Phi-3 i4 4SymW16A + rewr.add_matcher(dcoff_mode, dcoff_type, std::ref(params_to)); // Asymmetric zeropoints rewr.add_matcher(dcoff_mode, dcoff_type, std::ref(params_to)); diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp index 857bcd9c93ba56..ffbece94b04176 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp @@ -536,7 +536,7 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco register_matcher(std::make_shared(reshpe, "TagDCOFFReshape2"), std::move(callback)); } -// Pattern: Phi-3 4SymW16A/GPTQ +// Pattern: Phi-3 4SymW16A // // // "tensor" "scale" > "tensor" @@ -555,7 +555,7 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco // V > // Convert -DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { +DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { auto paramA = opp::wrap_type(); auto paramC = opp::wrap_type(); auto cvtA = opp::wrap_type({paramA}); @@ -616,7 +616,91 @@ DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_typ return false; // root node hasn't changed }; - register_matcher(std::make_shared(cvt, "TagDCOFFPassCWAI3"), std::move(callback)); + register_matcher(std::make_shared(cvt, "TagDCOFFPassReshape3"), std::move(callback)); +} + +// Pattern: i4 Phi-3 4SymW16A +// +// +// "tensor" "scale" > "tensor" +// Param:A Param:C > Param:A +// i4 f16|f32 > f16 +// : : > : +// V : > V +// Convert : > Convert +// f16|f32 : > f32 +// : : > +// V V > +// Multiply > +// f16|f32 > +// : > +// : > +// Reshape > +// f16|f32 > + +DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { + auto paramA = opp::wrap_type(); + auto paramC = opp::wrap_type(); + auto cvtA = opp::wrap_type({paramA}); + auto mulply = opp::wrap_type({cvtA, paramC}); + auto scalar = opp::wrap_type(); + auto reshape = opp::wrap_type({mulply, scalar}); + + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr(); + auto matched_nodeC = node_to_output.at(paramC).get_node_shared_ptr(); + + NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeA)); + NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeC)); + + auto matched_paramA = std::static_pointer_cast(matched_nodeA); + auto matched_paramC = std::static_pointer_cast(matched_nodeC); + + if (ov::element::i4 == matched_paramA->get_element_type() && + (ov::element::f16 == matched_paramC->get_element_type() || + ov::element::f32 == matched_paramC->get_element_type())) { + LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << dcoff_type); + matched_paramA->set_element_type(dcoff_type); + + if (dcoff_mode == DCOffMode::CAST_SCALE) { + NPUW_ASSERT(dcoff_type == ov::element::f16); + + LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove..."); + LOG_BLOCK(); + + // Extra transformation here: + // - remove Multiply + Intermediate Convert + // - mark paramC for removal. + // Convert will be reconnected to paramA directly. + + // Record mapping from the Scale coeff parameter to the Real weight parameter + pref.get().scales[matched_paramC] = matched_paramA; + + // Disconnect Multiply and Convert from their outputs + auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr(); + auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr(); + auto drop_outputs = [](std::shared_ptr node) { + for (auto&& node_outputs : node->outputs()) { + for (auto&& node_reader_port : node_outputs.get_target_inputs()) { + node_outputs.remove_target_input(node_reader_port); + } + } + }; + LOG_DEBUG("Dropping the connections..."); + drop_outputs(matched_mulply); + drop_outputs(matched_convrt); + + LOG_DEBUG("Reconnecting the Root..."); + auto matched_reshape = node_to_output.at(reshape).get_node_shared_ptr(); + matched_reshape->input(0).replace_source_output(matched_paramA); + } + LOG_DEBUG("Done"); + } + return false; // root node hasn't changed + }; + + register_matcher(std::make_shared(reshape, "TagDCOFFPassReshape4"), std::move(callback)); } //------------------------------------------------------------------------------ diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp index c0b394616c6ed5..9bb3c132fa9c5d 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp @@ -131,9 +131,14 @@ class DCOFFPassReshape2 : public ov::pass::MatcherPass { DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref); }; -class DCOFFPassCWAI3 : public ov::pass::MatcherPass { +class DCOFFPassReshape3 : public ov::pass::MatcherPass { public: - DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref); + DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref); +}; + +class DCOFFPassReshape4 : public ov::pass::MatcherPass { +public: + DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref); }; class CWAI1 : public ov::pass::MatcherPass { diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.cpp b/src/plugins/intel_npu/src/plugin/npuw/util.cpp index fbfbcc5d35eb19..a29c0ab454357a 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/util.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/util.cpp @@ -491,11 +491,23 @@ void unpack_i4f16(const ov::SoPtr& from, NPUW_ASSERT(to->is_continuous()); NPUW_ASSERT(from->get_size() == to->get_size()); - // TODO: force 2d shapes for now - NPUW_ASSERT(scale->get_shape().size() == 2); + const auto& from_shape = from->get_shape(); + NPUW_ASSERT(from_shape.back() % 64 == 0); - NPUW_ASSERT(scale->get_shape()[0] == from->get_shape()[0]); - NPUW_ASSERT(scale->get_shape()[1] == 1); + // 2-channel (Symmetric) and 3-channel (group-wise) + // scale factors are supported. The scale/value loop + // iteration is based on stotal, so should work for + // both cases. + const auto& scale_shape = scale->get_shape(); + NPUW_ASSERT(scale_shape.size() == 3 || scale_shape.size() == 2); + if (scale_shape.size() == 3) { + NPUW_ASSERT(scale_shape[0] == from_shape[0]); + NPUW_ASSERT(scale_shape[1] == from_shape[1]); + NPUW_ASSERT(scale_shape[2] == 1); + } else { + NPUW_ASSERT(scale_shape[0] == from_shape[0]); + NPUW_ASSERT(scale_shape[1] == 1); + } const auto scale_elem_type = scale->get_element_type(); NPUW_ASSERT(scale_elem_type == ov::element::f32 || scale_elem_type == ov::element::f16);