From 977b42612c5128640e97c950726b7eaa74b1ae9e Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Thu, 4 Jul 2024 11:52:01 +0400 Subject: [PATCH] [PT FE]: update layer norm for support non-static normalized shape (#25365) ### Details: - *aten::layer_norm support extended for support normalized shapes represented as non constant and multiple dims* tested with stable-audio model ### Tickets: - TBD --- src/frontends/pytorch/src/op/layer_norm.cpp | 17 +++--- .../pytorch_tests/test_layer_norm.py | 53 +++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 tests/layer_tests/pytorch_tests/test_layer_norm.py diff --git a/src/frontends/pytorch/src/op/layer_norm.cpp b/src/frontends/pytorch/src/op/layer_norm.cpp index 8775d30440ca2f..79464fa2d6d609 100644 --- a/src/frontends/pytorch/src/op/layer_norm.cpp +++ b/src/frontends/pytorch/src/op/layer_norm.cpp @@ -7,6 +7,9 @@ #include "openvino/op/constant.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/mvn.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/squeeze.hpp" #include "openvino/op/util/framework_node.hpp" #include "utils.hpp" @@ -20,12 +23,14 @@ using namespace ov::op; OutputVector translate_layer_norm(const NodeContext& context) { num_inputs_check(context, 5, 6); auto eps = context.const_input(4); - auto normalized_shape = context.const_input(1); - PYTORCH_OP_CONVERSION_CHECK(normalized_shape.size() == 1, - "Translation for aten::layer_norm supports only single normalized_shape value, " - "which means normalizing over the last dimension."); - // TODO: support any dimension - auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + auto normalized_shape = context.get_input(1); + auto num_axes = context.mark_node(std::make_shared(normalized_shape, element::i32)); + num_axes = context.mark_node(std::make_shared(num_axes)); + auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + auto axes_range = context.mark_node(std::make_shared(num_axes, zero, minus_one, element::i32)); + + auto axes = context.mark_node(std::make_shared(axes_range, minus_one)); auto out_node = context.mark_node(std::make_shared(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT)); if (!context.input_is_none(2)) { diff --git a/tests/layer_tests/pytorch_tests/test_layer_norm.py b/tests/layer_tests/pytorch_tests/test_layer_norm.py new file mode 100644 index 00000000000000..3bba4a31dab0a4 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_layer_norm.py @@ -0,0 +1,53 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest +import numpy as np + + +class TestLayerNorm(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(20, 5, 10, 10).astype(np.float32),) + + def create_model(self, normalized_shape, weight, bias, eps): + import torch + import torch.nn.functional as F + + if weight == "ones": + weight = torch.ones(normalized_shape) + + if weight == "random": + weight = torch.randn(normalized_shape) + + if bias == "zeros": + bias = torch.zeros(normalized_shape) + + if bias == "random": + bias = torch.randn(normalized_shape) + + class aten_layer_norm(torch.nn.Module): + def __init__(self, normalized_shape, weight, bias, eps): + super(aten_layer_norm, self).__init__() + self.normalized_shape = normalized_shape + self.weight = weight + self.bias = bias + self.eps = eps + + def forward(self, x): + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + + ref_net = None + + return aten_layer_norm(normalized_shape, weight, bias, eps), ref_net, "aten::layer_norm" + + @pytest.mark.parametrize("normalized_shape", [[10,], [10, 10], [5, 10, 10]]) + @pytest.mark.parametrize("weight", [None, "ones", "random"]) + @pytest.mark.parametrize("bias", [None, "zeros", "random"]) + @pytest.mark.parametrize("eps", [1e-5, 0.005]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_layer_norm(self, normalized_shape, weight, bias, eps, ie_device, precision, ir_version): + self._test(*self.create_model(normalized_shape, weight, bias, eps), ie_device, precision, ir_version) \ No newline at end of file