56 changes: 56 additions & 0 deletions doc/graph/operations/RMSNorm.md
@@ -0,0 +1,56 @@
RMSNorm {#dev_guide_op_rmsnorm}
===============================

## General

The RMSNorm (Root Mean Square Layer Normalization) operation normalizes the
input tensor using the root-mean-square statistic.

It applies the following transformation to the input tensor:

\f[
y = \gamma \cdot \frac{x}{\sqrt{\text{RMS}(x)^2 + \epsilon}},
\f]

where

\f[
\text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2}
\f]
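
For intuition, here is a minimal scalar sketch of this transformation over a
single normalization axis (an illustrative reference only, not the library
implementation):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// RMS-normalize a 1-D slice `x`, scaling by `gamma` (same length as `x`).
// `rms_sq` is the mean of squares, i.e. RMS(x)^2 from the formula above.
std::vector<float> rms_norm_ref(const std::vector<float> &x,
        const std::vector<float> &gamma, float epsilon = 1e-5f) {
    float rms_sq = 0.f;
    for (float v : x)
        rms_sq += v * v;
    rms_sq /= static_cast<float>(x.size());
    const float inv = 1.f / std::sqrt(rms_sq + epsilon);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = gamma[i] * x[i] * inv;
    return y;
}
```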

## Operation attributes

| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
|:---------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------|:----------------------------------------------|:---------------------|
| [epsilon](@ref dnnl::graph::op::attr::epsilon) | A constant added to the denominator for numerical stability. | f32 | Arbitrary positive f32 value, `1e-5` (default) | Optional |
| [begin_norm_axis](@ref dnnl::graph::op::attr::begin_norm_axis) | The axis at which RMS normalization starts; normalization spans from `begin_norm_axis` to the last dimension. Negative values index from right to left. By default, the op normalizes over the last dimension only, e.g. C in TNC for 3D and C in LDNC for 4D. | s64 | [-r, r-1], where r = rank(src); -1 (default) | Optional |
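
For example, for a 3D `src` in `TNC` format, `begin_norm_axis = 1` would
normalize over both `N` and `C`, while the default `-1` normalizes over `C`
only.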

## Execution arguments

When constructing an operation, the inputs and outputs must be provided in
the index order below.

### Inputs

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `src` | Required |
| 1 | `gamma` | Optional |

@note `gamma` is the scale applied to the normalized value. Its shape must be broadcastable to the shape of `src`.

### Outputs

| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `dst` | Required |

## Supported data types

The RMSNorm operation supports the following data type combinations.

| Src / Dst | Gamma |
|:----------|:-------------|
| f32 | f32 |
| bf16 | f32, bf16 |
| f16 | f32, f16 |
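
As an illustration, the following sketch builds a single RMSNorm op through
the oneDNN Graph API (the tensor shapes and IDs here are arbitrary):

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

void build_rmsnorm_graph() {
    using namespace dnnl::graph;

    // f32 src/dst of shape {2, 16, 64}; gamma is broadcastable to src.
    logical_tensor src {0, logical_tensor::data_type::f32, {2, 16, 64},
            logical_tensor::layout_type::strided};
    logical_tensor gamma {1, logical_tensor::data_type::f32, {64},
            logical_tensor::layout_type::strided};
    logical_tensor dst {2, logical_tensor::data_type::f32, {2, 16, 64},
            logical_tensor::layout_type::strided};

    // Attach src and the optional gamma as inputs, dst as the single output.
    op rmsnorm {0, op::kind::RMSNorm, {src, gamma}, {dst}, "rmsnorm"};
    rmsnorm.set_attr<float>(op::attr::epsilon, 1e-5f);
    rmsnorm.set_attr<int64_t>(op::attr::begin_norm_axis, -1);

    graph g {engine::kind::cpu};
    g.add_op(rmsnorm);
    g.finalize();
}
```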
1 change: 1 addition & 0 deletions doc/graph/rst/graph_supported_operations.rst
@@ -71,6 +71,7 @@ Supported Operations
dev_guide_op_relu
dev_guide_op_relubackward
dev_guide_op_reorder
dev_guide_op_rmsnorm
dev_guide_op_round
dev_guide_op_select
dev_guide_op_sigmoid
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl_graph.hpp
@@ -842,6 +842,7 @@ class op : public op_handle {
ReLUBackward = dnnl_graph_op_relu_backward,
Reorder = dnnl_graph_op_reorder,
Round = dnnl_graph_op_round,
RMSNorm = dnnl_graph_op_rms_norm,
Select = dnnl_graph_op_select,
Sigmoid = dnnl_graph_op_sigmoid,
SigmoidBackward = dnnl_graph_op_sigmoid_backward,
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl_graph_types.h
@@ -261,6 +261,7 @@ typedef enum {
dnnl_graph_op_group_norm,
dnnl_graph_op_gen_index,
dnnl_graph_op_greater_equal,
dnnl_graph_op_rms_norm,
dnnl_graph_op_last_symbol,
} dnnl_graph_op_kind_t;

5 changes: 3 additions & 2 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1054,12 +1054,13 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_layernorm, 1,
int64_t(-1))
.set_attr(op_attr::use_affine, false, attribute_kind::b, true)
.set_attr(op_attr::epsilon, false, attribute_kind::f, 1e-5f)
.set_attr(op_attr::is_rms, false, attribute_kind::b, false)
.set_attr(op_attr::fusion_info, false,
attribute_kind::fusion_info)
// New added attributes
.SET_ATTR_IS_CONSTANT // used for constant prop and cache
// Analysis rules
.set_shape_inference_function(infer_norm_output_shape)
.set_shape_inference_function(infer_dnnl_layernorm_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_layernorm)
.SET_EXECUTABLE_CREATOR(
executable_creator<layernorm_executable_t>)
17 changes: 16 additions & 1 deletion src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -612,6 +612,21 @@ status_t infer_dnnl_host_scalar_output_shape(op_t *n,
return status::success;
}

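// Shape inference for dnnl_layernorm, which covers both LayerNorm and the
// RMS variant (is_rms). The RMS variant has no mean/variance outputs, so
// only the identity dst shape needs to be inferred.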
status_t infer_dnnl_layernorm_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
const auto is_rms = n->has_attr(op_attr::is_rms)
&& n->get_attr<bool>(op_attr::is_rms);
if (is_rms) {
auto status = infer_identity_output_shape(n, inputs, outputs);
if (status != status::success) return status;
} else {
auto status = infer_norm_output_shape(n, inputs, outputs);
if (status != status::success) return status;
}
return status::success;
}

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
5 changes: 4 additions & 1 deletion src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -115,6 +115,9 @@ status_t infer_dnnl_host_scalar_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_dnnl_layernorm_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);
} // namespace dnnl_impl
} // namespace graph
} // namespace impl
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/internal_attrs.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@ const op_attr_t keep_dst_layout = 0x1000f;
const op_attr_t with_scale = 0x10010;
const op_attr_t is_invert_scale = 0x10011;
const op_attr_t mask_type = 0x10012;
const op_attr_t is_rms = 0x10013;

// int64_t
const op_attr_t alg_kind = 0x10100;
@@ -108,6 +109,7 @@ static inline std::string internal_attr2str(op_attr_t attr) {
CASE(fusion_info);
CASE(qk_acc_mode);
CASE(vs_acc_mode);
CASE(is_rms);
default: return "undefined_attr";
}
#undef CASE
79 changes: 54 additions & 25 deletions src/graph/backend/dnnl/op_executable.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -718,12 +718,17 @@ layernorm_executable_t::desc_t layernorm_executable_t::create_desc(
bool use_affine = true;
if (op->has_attr(op_attr::use_affine))
use_affine = op->get_attr<bool>(op_attr::use_affine);
bool is_rms = false;
if (op->has_attr(op_attr::is_rms))
is_rms = op->get_attr<bool>(op_attr::is_rms);

auto flags = dnnl::normalization_flags::none;
if (use_affine)
flags |= (dnnl::normalization_flags::use_scale
| dnnl::normalization_flags::use_shift);

if (is_rms) flags |= dnnl::normalization_flags::rms_norm;
if (use_affine) {
flags |= dnnl::normalization_flags::use_scale;
// no shift for rms norm
if (!is_rms) flags |= dnnl::normalization_flags::use_shift;
}
prop_kind pkind = keep_stats ? prop_kind::forward_training
: prop_kind::forward_inference;

@@ -735,8 +740,17 @@
auto dst = make_dnnl_memory_desc(
op->get_output_value(0)->get_logical_tensor());
dst = to_format_any(dst);
dnnl::layer_normalization_forward::primitive_desc pd(
p_engine, pkind, src, dst, epsilon, flags, prm_attr);
dnnl::layer_normalization_forward::primitive_desc pd;
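// With affine, query gamma's data type from input 1 and pass it to the
// primitive descriptor so that, e.g., f32 scales can be combined with a
// bf16/f16 src.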
if (use_affine) {
memory::data_type scale_shift_data_type
= static_cast<memory::data_type>(
op->get_input_value(1)->get_logical_tensor().data_type);
pd = dnnl::layer_normalization_forward::primitive_desc(p_engine, pkind,
src, dst, scale_shift_data_type, epsilon, flags, prm_attr);
} else {
pd = dnnl::layer_normalization_forward::primitive_desc(
p_engine, pkind, src, dst, epsilon, flags, prm_attr);
}

pd_cache.insert({op.get(), pd});
return {pd, false};
@@ -2192,44 +2206,59 @@ arg_indices_t batchnorm_bwd_executable_t::get_arg_indices(const op_t *op) {
return arg_indices;
}

static arg_indices_t get_arg_indices_for_lnorm_and_gnorm(const op_t *op) {
arg_indices_t arg_indices;

size_t in_index = 0;
arg_indices.insert({DNNL_ARG_SRC, indices_t {input, in_index++}});
static arg_indices_t get_arg_indices_for_norm(const op_t *op) {
arg_indices_t args;
size_t in_idx = 0;
const bool is_rms = op->has_attr(op_attr::is_rms)
? op->get_attr<bool>(op_attr::is_rms)
: false;
args.insert({DNNL_ARG_SRC, {indices_t::type_t::input, in_idx++}});
if (!op->has_attr(op_attr::use_affine)
|| op->get_attr<bool>(op_attr::use_affine)) {
arg_indices.insert({DNNL_ARG_SCALE, indices_t {input, in_index++}});
arg_indices.insert({DNNL_ARG_SHIFT, indices_t {input, in_index++}});
// RMSNorm does not take a shift input
if (!is_rms) {
args.insert({DNNL_ARG_SCALE, {indices_t::type_t::input, in_idx++}});
args.insert({DNNL_ARG_SHIFT, {indices_t::type_t::input, in_idx++}});
} else {
if (op->has_attr(op_attr::use_affine)
&& op->get_attr<bool>(op_attr::use_affine)) {
args.insert(
{DNNL_ARG_SCALE, {indices_t::type_t::input, in_idx++}});
}
}
}

const fusion_info_t &fusion_info = op->has_attr(op_attr::fusion_info)
? op->get_attr<fusion_info_t>(op_attr::fusion_info)
: fusion_info_t();

get_arg_indices_for_post_ops(op, arg_indices, in_index);
get_arg_indices_for_post_ops(op, args, in_idx);

if (fusion_info.with_runtime_scales(false, 0)) {
arg_indices.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
indices_t {input, in_index++}});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
indices_t {input, in_idx++}});
}

size_t out_index = 0;
arg_indices.insert({DNNL_ARG_DST, indices_t {output, out_index++}});
args.insert({DNNL_ARG_DST, indices_t {output, out_index++}});
if (!op->has_attr(op_attr::keep_stats)
|| op->get_attr<bool>(op_attr::keep_stats)) {
arg_indices.insert({DNNL_ARG_MEAN, indices_t {output, out_index++}});
arg_indices.insert(
{DNNL_ARG_VARIANCE, indices_t {output, out_index++}});
// the RMSNorm op in the oneDNN Graph API has only one output (no mean/variance)
if (!is_rms) {
args.insert(
{DNNL_ARG_MEAN, {indices_t::type_t::output, out_index++}});
args.insert({DNNL_ARG_VARIANCE,
{indices_t::type_t::output, out_index++}});
}
}

arg_indices.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, out_index++}});
args.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, out_index++}});

return arg_indices;
return args;
}

arg_indices_t layernorm_executable_t::get_arg_indices(const op_t *op) {
return get_arg_indices_for_lnorm_and_gnorm(op);
return get_arg_indices_for_norm(op);
}

arg_indices_t layernorm_bwd_executable_t::get_arg_indices(const op_t *op) {
@@ -2357,7 +2386,7 @@ arg_indices_t eltwise_bwd_executable_t::get_arg_indices(const op_t *op) {
}

arg_indices_t groupnorm_executable_t::get_arg_indices(const op_t *op) {
return get_arg_indices_for_lnorm_and_gnorm(op);
return get_arg_indices_for_norm(op);
}

arg_indices_t genindex_executable_t::get_arg_indices(const op_t *op) {
20 changes: 19 additions & 1 deletion src/graph/backend/dnnl/passes/lower.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -917,6 +917,23 @@ static status_t gen_index_handler(
return status::success;
}

static status_t rmsnorm_handler(
const std::shared_ptr<op_t> &op, subgraph_rewriter_t &rewriter) {
auto new_op = std::make_shared<op_t>(op_kind::dnnl_layernorm);
new_op->set_attr<bool>(op_attr::is_rms, true);
if (op->get_input_values().size() == 2) {
new_op->set_attr<bool>(op_attr::use_affine, true);
} else {
new_op->set_attr<bool>(op_attr::use_affine, false);
}
// the RMSNorm op in the oneDNN Graph API has only one output, so stats are never kept
new_op->set_attr<bool>(op_attr::keep_stats, false);
new_op->merge_attributes(op->get_attributes());
rewriter.replace_op(op, new_op);
insert_empty_scratchpad(new_op);
return status::success;
}

#define ITEM(kind, func) \
{ \
graph::op_kind::kind, handler_func { (func) } \
@@ -1000,6 +1017,7 @@ static const std::unordered_map<graph::op_kind_t, handler_func> handler_table {
ITEM(ReduceMin, reduction_handler),
ITEM(ReduceProd, reduction_handler),
ITEM(ReduceSum, reduction_handler),
// softplus
ITEM(SoftPlus, softplus_handler),
ITEM(SoftPlusBackward, softplus_handler),
@@ -1010,6 +1028,7 @@ static const std::unordered_map<graph::op_kind_t, handler_func> handler_table {
// layernorm
ITEM(LayerNorm, common_handler<op_kind::kDnnl_layernorm>),
ITEM(LayerNormBackward, common_handler<op_kind::kDnnl_layernorm_bwd>),
ITEM(RMSNorm, rmsnorm_handler),
// groupnorm
ITEM(GroupNorm, common_handler<op_kind::kDnnl_groupnorm>),
// quantization
19 changes: 18 additions & 1 deletion src/graph/backend/dnnl/patterns/single_op_pattern.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020 Intel Corporation
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -435,6 +435,23 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, greater_equal_pass)
.set_attr<FCreateKernel>("FCreateKernel",
[]() -> kernel_ptr { return std::make_shared<binary_t>(); });

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, rmsn_pass)
.set_priority(DEFAULT_P)
.set_kind(partition_kind_t::misc_post_ops)
.set_attr<FCreatePattern>("FCreatePattern",
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
graph::utils::pm::pb_op_t *p_rmsnorm
= pgraph->append_op(graph::op_kind::RMSNorm);
p_rmsnorm->append_decision_function(
check_begin_norm_axis_attr);
// the primitive supports only 2-5D data tensors for RMSNorm
p_rmsnorm->append_decision_function(
check_input_ndim_from_offset<0, 2, 5>);
})
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
return std::make_shared<layer_norm_fwd_t>();
});

#undef DNNL_BACKEND_SINGLE_OP_TRANSFORM
#undef DEFAULT_P

1 change: 1 addition & 0 deletions src/graph/interface/c_types_map.hpp
@@ -193,6 +193,7 @@ const op_kind_t ReduceProd = dnnl_graph_op_reduce_prod;
const op_kind_t ReduceSum = dnnl_graph_op_reduce_sum;
const op_kind_t ReLU = dnnl_graph_op_relu;
const op_kind_t ReLUBackward = dnnl_graph_op_relu_backward;
const op_kind_t RMSNorm = dnnl_graph_op_rms_norm;
const op_kind_t Reorder = dnnl_graph_op_reorder;
const op_kind_t Round = dnnl_graph_op_round;
const op_kind_t Select = dnnl_graph_op_select;