Skip to content

Commit

Permalink
[NPUW] Add a new i4 pattern (openvinotoolkit#26212)
Browse files Browse the repository at this point in the history
1) CWAI3 have been renamed to Reshape3
2) Added Reshape4 pattern
3) Supported 3-dim scale shapes in i4f16 unpack

Co-authored-by: Dmitry Matveev <[email protected]>
  • Loading branch information
smirnov-alexey and dmatveev committed Aug 26, 2024
1 parent 0ed0948 commit 6fe1a53
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1624,8 +1624,11 @@ void Partitioner::decompressionCutOff(const std::string& func_name) {
// LLaMaGPTQ
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to));

// Phi-3 4SymW16A/GPTQ
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to));
// Phi-3 4SymW16A
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape3>(dcoff_mode, dcoff_type, std::ref(params_to));

// Phi-3 i4 4SymW16A
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape4>(dcoff_mode, dcoff_type, std::ref(params_to));

// Asymmetric zeropoints
rewr.add_matcher<ov::npuw::patterns::AsymmZP::DCOFFPassReshape>(dcoff_mode, dcoff_type, std::ref(params_to));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
register_matcher(std::make_shared<opp::Matcher>(reshpe, "TagDCOFFReshape2"), std::move(callback));
}

// Pattern: Phi-3 4SymW16A/GPTQ
// Pattern: Phi-3 4SymW16A
//
//
// "tensor" "scale" > "tensor"
Expand All @@ -555,7 +555,7 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
// V >
// Convert

DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
Expand Down Expand Up @@ -616,7 +616,91 @@ DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_typ
return false; // root node hasn't changed
};

register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassCWAI3"), std::move(callback));
register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassReshape3"), std::move(callback));
}

// Pattern: i4 Phi-3 4SymW16A
//
//
// "tensor" "scale" > "tensor"
// Param:A Param:C > Param:A
// i4 f16|f32 > f16
// : : > :
// V : > V
// Convert : > Convert
// f16|f32 : > f32
// : : >
// V V >
// Multiply >
// f16|f32 >
// : >
// : >
// Reshape >
// f16|f32 >

DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC});
auto scalar = opp::wrap_type<ov::op::v0::Constant>();
auto reshape = opp::wrap_type<ov::op::v1::Reshape>({mulply, scalar});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr();
auto matched_nodeC = node_to_output.at(paramC).get_node_shared_ptr();

NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeA));
NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeC));

auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);

if (ov::element::i4 == matched_paramA->get_element_type() &&
(ov::element::f16 == matched_paramC->get_element_type() ||
ov::element::f32 == matched_paramC->get_element_type())) {
LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << dcoff_type);
matched_paramA->set_element_type(dcoff_type);

if (dcoff_mode == DCOffMode::CAST_SCALE) {
NPUW_ASSERT(dcoff_type == ov::element::f16);

LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
LOG_BLOCK();

// Extra transformation here:
// - remove Multiply + Intermediate Convert
// - mark paramC for removal.
// Convert will be reconnected to paramA directly.

// Record mapping from the Scale coeff parameter to the Real weight parameter
pref.get().scales[matched_paramC] = matched_paramA;

// Disconnect Multiply and Convert from their outputs
auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
for (auto&& node_outputs : node->outputs()) {
for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
node_outputs.remove_target_input(node_reader_port);
}
}
};
LOG_DEBUG("Dropping the connections...");
drop_outputs(matched_mulply);
drop_outputs(matched_convrt);

LOG_DEBUG("Reconnecting the Root...");
auto matched_reshape = node_to_output.at(reshape).get_node_shared_ptr();
matched_reshape->input(0).replace_source_output(matched_paramA);
}
LOG_DEBUG("Done");
}
return false; // root node hasn't changed
};

register_matcher(std::make_shared<opp::Matcher>(reshape, "TagDCOFFPassReshape4"), std::move(callback));
}

//------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ class DCOFFPassReshape2 : public ov::pass::MatcherPass {
DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

class DCOFFPassCWAI3 : public ov::pass::MatcherPass {
class DCOFFPassReshape3 : public ov::pass::MatcherPass {
public:
DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

class DCOFFPassReshape4 : public ov::pass::MatcherPass {
public:
DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

class CWAI1 : public ov::pass::MatcherPass {
Expand Down
20 changes: 16 additions & 4 deletions src/plugins/intel_npu/src/plugin/npuw/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,23 @@ void unpack_i4f16(const ov::SoPtr<ov::ITensor>& from,
NPUW_ASSERT(to->is_continuous());
NPUW_ASSERT(from->get_size() == to->get_size());

// TODO: force 2d shapes for now
NPUW_ASSERT(scale->get_shape().size() == 2);
const auto& from_shape = from->get_shape();
NPUW_ASSERT(from_shape.back() % 64 == 0);

NPUW_ASSERT(scale->get_shape()[0] == from->get_shape()[0]);
NPUW_ASSERT(scale->get_shape()[1] == 1);
// 2-channel (Symmetric) and 3-channel (group-wise)
// scale factors are supported. The scale/value loop
// iteration is based on stotal, so should work for
// both cases.
const auto& scale_shape = scale->get_shape();
NPUW_ASSERT(scale_shape.size() == 3 || scale_shape.size() == 2);
if (scale_shape.size() == 3) {
NPUW_ASSERT(scale_shape[0] == from_shape[0]);
NPUW_ASSERT(scale_shape[1] == from_shape[1]);
NPUW_ASSERT(scale_shape[2] == 1);
} else {
NPUW_ASSERT(scale_shape[0] == from_shape[0]);
NPUW_ASSERT(scale_shape[1] == 1);
}

const auto scale_elem_type = scale->get_element_type();
NPUW_ASSERT(scale_elem_type == ov::element::f32 || scale_elem_type == ov::element::f16);
Expand Down

0 comments on commit 6fe1a53

Please sign in to comment.