Skip to content

Commit

Permalink
[PT FE]: update layer norm for support non-static normalized shape (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#25365)

### Details:
- *aten::layer_norm support extended for support normalized shapes
represented as non constant and multiple dims*
 tested with stable-audio model

### Tickets:
 - TBD
  • Loading branch information
eaidova authored Jul 4, 2024
1 parent ea7921c commit 977b426
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/frontends/pytorch/src/op/layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"

Expand All @@ -20,12 +23,14 @@ using namespace ov::op;
OutputVector translate_layer_norm(const NodeContext& context) {
num_inputs_check(context, 5, 6);
auto eps = context.const_input<float>(4);
auto normalized_shape = context.const_input<Shape>(1);
PYTORCH_OP_CONVERSION_CHECK(normalized_shape.size() == 1,
"Translation for aten::layer_norm supports only single normalized_shape value, "
"which means normalizing over the last dimension.");
// TODO: support any dimension
auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto normalized_shape = context.get_input(1);
auto num_axes = context.mark_node(std::make_shared<v3::ShapeOf>(normalized_shape, element::i32));
num_axes = context.mark_node(std::make_shared<v0::Squeeze>(num_axes));
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto axes_range = context.mark_node(std::make_shared<v4::Range>(num_axes, zero, minus_one, element::i32));

auto axes = context.mark_node(std::make_shared<v1::Multiply>(axes_range, minus_one));
auto out_node =
context.mark_node(std::make_shared<v6::MVN>(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT));
if (!context.input_is_none(2)) {
Expand Down
53 changes: 53 additions & 0 deletions tests/layer_tests/pytorch_tests/test_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest
import numpy as np


class TestLayerNorm(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(20, 5, 10, 10).astype(np.float32),)

def create_model(self, normalized_shape, weight, bias, eps):
import torch
import torch.nn.functional as F

if weight == "ones":
weight = torch.ones(normalized_shape)

if weight == "random":
weight = torch.randn(normalized_shape)

if bias == "zeros":
bias = torch.zeros(normalized_shape)

if bias == "random":
bias = torch.randn(normalized_shape)

class aten_layer_norm(torch.nn.Module):
def __init__(self, normalized_shape, weight, bias, eps):
super(aten_layer_norm, self).__init__()
self.normalized_shape = normalized_shape
self.weight = weight
self.bias = bias
self.eps = eps

def forward(self, x):
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

ref_net = None

return aten_layer_norm(normalized_shape, weight, bias, eps), ref_net, "aten::layer_norm"

@pytest.mark.parametrize("normalized_shape", [[10,], [10, 10], [5, 10, 10]])
@pytest.mark.parametrize("weight", [None, "ones", "random"])
@pytest.mark.parametrize("bias", [None, "zeros", "random"])
@pytest.mark.parametrize("eps", [1e-5, 0.005])
@pytest.mark.nightly
@pytest.mark.precommit
def test_layer_norm(self, normalized_shape, weight, bias, eps, ie_device, precision, ir_version):
self._test(*self.create_model(normalized_shape, weight, bias, eps), ie_device, precision, ir_version)

0 comments on commit 977b426

Please sign in to comment.