Skip to content

Commit

Permalink
Fixed Convolution fusion (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat authored Mar 17, 2021
1 parent 05422c8 commit 8190eee
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<pass::ConvertConv1D>();
manager.register_pass<pass::ConvertGroupConv1D>();
manager.register_pass<pass::ConvertGroupConvolution>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<pass::ConvBiasActivationFusion>();
manager.register_pass<pass::ConvertMatMulToFC>();
manager.register_pass<pass::ConvertEltwise>();
Expand Down
22 changes: 11 additions & 11 deletions modules/arm_plugin/src/transformations/conv_bias_activ_fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ ngraph::matcher_pass_callback ArmPlugin::pass::ConvBiasFusionBase::fuse_conv_wit
}

if (!std::dynamic_pointer_cast<opset::Constant>(eltwise->input_value(1 - conv_idx).get_node_shared_ptr())) {
THROW_IE_EXCEPTION << "Unsupported Convolution with inconstant weights.";
return false; // Unsupported Convolution with inconstant bias
}

auto bias = eltwise->input_value(1 - conv_idx);
Expand All @@ -160,10 +160,6 @@ ngraph::matcher_pass_callback ArmPlugin::pass::ConvBiasFusionBase::fuse_conv_wit
opset::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {channel_dim}), true);
}

if (m_conv->output(0).get_target_inputs().size() != 1) {
return false;
}

if (m_conv->inputs().size() == 3) {
new_bias = std::make_shared<opset::Add>(new_bias, m_conv->input_value(Inputs::Bias));
}
Expand Down Expand Up @@ -194,6 +190,10 @@ ngraph::matcher_pass_callback ArmPlugin::pass::ConvertConvBase::convert_conv_to_
return false;
}

if (!std::dynamic_pointer_cast<opset::Constant>(m_conv->input_value(Inputs::Weights).get_node_shared_ptr())) {
THROW_IE_EXCEPTION << "Unsupported Convolution with inconstant weights.";
}

auto conv_arm = std::make_shared<ArmConv>(
m_conv->input_value(Inputs::Data),
m_conv->input_value(Inputs::Weights),
Expand Down Expand Up @@ -233,17 +233,17 @@ ArmPlugin::pass::ConvertGroupConvolutionToArm::ConvertGroupConvolutionToArm() {

ArmPlugin::pass::ConvBiasFusion::ConvBiasFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(
ngraph::pattern::wrap_type<opset::ArmConvolution>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
ngraph::pattern::has_static_shape()), "ConvBiasFusion");
ngraph::pattern::wrap_type<opset::Add>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
ngraph::pattern::has_static_shape()), "ConvBiasFusion");
register_matcher(m, fuse_conv_with_bias<opset::ArmConvolution>());
}

ArmPlugin::pass::GroupConvBiasFusion::GroupConvBiasFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(
ngraph::pattern::wrap_type<opset::ArmGroupConvolution>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
ngraph::pattern::has_static_shape()), "GroupConvBiasFusion");
ngraph::pattern::wrap_type<opset::Add>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()),
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
ngraph::pattern::has_static_shape()), "GroupConvBiasFusion");
register_matcher(m, fuse_conv_with_bias<opset::ArmGroupConvolution>());
}

Expand Down

0 comments on commit 8190eee

Please sign in to comment.