diff --git a/doc/graph/operations/RMSNorm.md b/doc/graph/operations/RMSNorm.md
new file mode 100644
index 00000000000..df7871e6be4
--- /dev/null
+++ b/doc/graph/operations/RMSNorm.md
@@ -0,0 +1,56 @@
+RMSNorm {#dev_guide_op_rmsnorm}
+===============================
+
+## General
+
+The RMSNorm (Root Mean Square Layer Normalization) operation normalizes the
+input tensor using its root mean square statistic.
+
+The RMSNorm operation applies the following transformation to the input tensor:
+
+\f[
+    y = \gamma \cdot \frac{x}{\text{RMS}(x)},
+\f]
+
+where
+
+\f[
+    \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}
+\f]
+
+## Operation attributes
+
+| Attribute Name | Description | Value Type | Supported Values | Required or Optional |
+|:---------------|:------------|:-----------|:-----------------|:---------------------|
+| [epsilon](@ref dnnl::graph::op::attr::epsilon) | The constant to improve numerical stability. | f32 | Arbitrary positive f32 value, `1e-5` (default) | Optional |
+| [begin_norm_axis](@ref dnnl::graph::op::attr::begin_norm_axis) | `begin_norm_axis` indicates the axis at which RMS normalization starts; the normalization spans from `begin_norm_axis` to the last dimension. Negative values mean indexing from right to left. By default the operation normalizes over the last dimension only, e.g. C in TNC for 3D input and C in LDNC for 4D input. | s64 | [-r, r-1], where r = rank(src); `-1` (default) | Optional |
+
+## Execution arguments
+
+The inputs and outputs must be provided according to the index order below
+when constructing an operation.
+
+### Inputs
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0     | `src`         | Required             |
+| 1     | `gamma`       | Optional             |
+
+@note `gamma` is the scale applied to the normalized value. Its shape must be broadcastable to the shape of `src`.
+
+### Outputs
+
+| Index | Argument Name | Required or Optional |
+|:------|:--------------|:---------------------|
+| 0     | `dst`         | Required             |
+
+## Supported data types
+
+The RMSNorm operation supports the following data type combinations.
+
+| Src / Dst | Gamma     |
+|:----------|:----------|
+| f32       | f32       |
+| bf16      | f32, bf16 |
+| f16       | f32, f16  |
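The snippet below is a minimal usage sketch, not part of the patch: it builds a single-op graph with the new `RMSNorm` kind through the existing oneDNN Graph C++ API, relying only on the op kind, attributes, and argument order documented above. Tensor ids, shapes, and attribute values are illustrative.

```cpp
// Illustrative sketch (not part of the patch): create an RMSNorm op with an
// optional gamma input and ask the graph for supported partitions.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;

    // src and dst share one shape; gamma must be broadcastable to src.
    logical_tensor src {0, dt::f32, {8, 128, 768}, lt::strided};
    logical_tensor gamma {1, dt::f32, {768}, lt::strided};
    logical_tensor dst {2, dt::f32, {8, 128, 768}, lt::strided};

    op rmsnorm {0, op::kind::RMSNorm, "rmsnorm"};
    rmsnorm.set_attr<float>(op::attr::epsilon, 1e-5f); // default value
    rmsnorm.set_attr<int64_t>(op::attr::begin_norm_axis, -1); // last dim only
    rmsnorm.add_inputs({src, gamma});
    rmsnorm.add_output(dst);

    graph g {engine::kind::cpu};
    g.add_op(rmsnorm);
    g.finalize();

    // With the single-op pattern registered by this patch, the op is expected
    // to come back as one supported partition.
    auto partitions = g.get_partitions();
    return partitions.empty() ? 1 : 0;
}
```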
diff --git a/doc/graph/rst/graph_supported_operations.rst index fbd3146b113..1bc79d62b75 100644 --- a/doc/graph/rst/graph_supported_operations.rst +++ b/doc/graph/rst/graph_supported_operations.rst @@ -71,6 +71,7 @@ Supported Operations dev_guide_op_relu dev_guide_op_relubackward dev_guide_op_reorder + dev_guide_op_rmsnorm dev_guide_op_round dev_guide_op_select dev_guide_op_sigmoid diff --git a/include/oneapi/dnnl/dnnl_graph.hpp index b9ad7832663..4ab3fcb48b2 100644 --- a/include/oneapi/dnnl/dnnl_graph.hpp +++ b/include/oneapi/dnnl/dnnl_graph.hpp @@ -842,6 +842,7 @@ class op : public op_handle { ReLUBackward = dnnl_graph_op_relu_backward, Reorder = dnnl_graph_op_reorder, Round = dnnl_graph_op_round, + RMSNorm = dnnl_graph_op_rms_norm, Select = dnnl_graph_op_select, Sigmoid = dnnl_graph_op_sigmoid, SigmoidBackward = dnnl_graph_op_sigmoid_backward, diff --git a/include/oneapi/dnnl/dnnl_graph_types.h index 3368cf67288..fcf045d1907 100644 --- a/include/oneapi/dnnl/dnnl_graph_types.h +++ b/include/oneapi/dnnl/dnnl_graph_types.h @@ -261,6 +261,7 @@ typedef enum { dnnl_graph_op_group_norm, dnnl_graph_op_gen_index, dnnl_graph_op_greater_equal, + dnnl_graph_op_rms_norm, dnnl_graph_op_last_symbol, } dnnl_graph_op_kind_t; diff --git a/src/graph/backend/dnnl/dnnl_op_def.hpp index c408219a6d0..bd4fe4b8b71 100644 --- a/src/graph/backend/dnnl/dnnl_op_def.hpp +++ b/src/graph/backend/dnnl/dnnl_op_def.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2025 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1054,12 +1054,13 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_layernorm, 1, int64_t(-1)) .set_attr(op_attr::use_affine, false, attribute_kind::b, true) .set_attr(op_attr::epsilon, false, attribute_kind::f, 1e-5f) + .set_attr(op_attr::is_rms, false, attribute_kind::b, false) .set_attr(op_attr::fusion_info, false, attribute_kind::fusion_info) // New added attributes .SET_ATTR_IS_CONSTANT // used for constant prop and cache // Analysis rules - .set_shape_inference_function(infer_norm_output_shape) + .set_shape_inference_function(infer_dnnl_layernorm_output_shape) .SET_LAYOUT_PROPAGATOR(layout_propagator_for_layernorm) .SET_EXECUTABLE_CREATOR( executable_creator) diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.cpp index dea3dd9aced..36ae5127dc9 100644 --- a/src/graph/backend/dnnl/dnnl_shape_infer.cpp +++ b/src/graph/backend/dnnl/dnnl_shape_infer.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2025 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -612,6 +612,21 @@ status_t infer_dnnl_host_scalar_output_shape(op_t *n, return status::success; } +status_t infer_dnnl_layernorm_output_shape(op_t *n, + std::vector<logical_tensor_t *> &inputs, + std::vector<logical_tensor_t *> &outputs) { + const auto is_rms = n->has_attr(op_attr::is_rms) + && n->get_attr<bool>(op_attr::is_rms); + if (is_rms) { + auto status = infer_identity_output_shape(n, inputs, outputs); + if (status != status::success) return status; + } else { + auto status = infer_norm_output_shape(n, inputs, outputs); + if (status != status::success) return status; + } + return status::success; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.hpp index 6dfe4be71e7..94bd8cfe8f7 100644 --- a/src/graph/backend/dnnl/dnnl_shape_infer.hpp +++ b/src/graph/backend/dnnl/dnnl_shape_infer.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2025 Intel Corporation +* Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -115,6 +115,9 @@ status_t infer_dnnl_host_scalar_output_shape(op_t *n, std::vector<logical_tensor_t *> &inputs, std::vector<logical_tensor_t *> &outputs); +status_t infer_dnnl_layernorm_output_shape(op_t *n, + std::vector<logical_tensor_t *> &inputs, + std::vector<logical_tensor_t *> &outputs); } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/internal_attrs.hpp index c17266f3c85..2ba365c734b 100644 --- a/src/graph/backend/dnnl/internal_attrs.hpp +++ b/src/graph/backend/dnnl/internal_attrs.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2025 Intel Corporation +* Copyright 2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,7 @@ const op_attr_t keep_dst_layout = 0x1000f; const op_attr_t with_scale = 0x10010; const op_attr_t is_invert_scale = 0x10011; const op_attr_t mask_type = 0x10012; +const op_attr_t is_rms = 0x10013; // int64_t const op_attr_t alg_kind = 0x10100; @@ -108,6 +109,7 @@ static inline std::string internal_attr2str(op_attr_t attr) { CASE(fusion_info); CASE(qk_acc_mode); CASE(vs_acc_mode); + CASE(is_rms); default: return "undefined_attr"; } #undef CASE diff --git a/src/graph/backend/dnnl/op_executable.cpp index 32c7a101789..0d370fc2a80 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2022-2025 Intel Corporation + * Copyright 2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -718,12 +718,17 @@ layernorm_executable_t::desc_t layernorm_executable_t::create_desc( bool use_affine = true; if (op->has_attr(op_attr::use_affine)) use_affine = op->get_attr<bool>(op_attr::use_affine); + bool is_rms = false; + if (op->has_attr(op_attr::is_rms)) + is_rms = op->get_attr<bool>(op_attr::is_rms); auto flags = dnnl::normalization_flags::none; - if (use_affine) - flags |= (dnnl::normalization_flags::use_scale - | dnnl::normalization_flags::use_shift); - + if (is_rms) flags |= dnnl::normalization_flags::rms_norm; + if (use_affine) { + flags |= dnnl::normalization_flags::use_scale; + // no shift for rms norm + if (!is_rms) flags |= dnnl::normalization_flags::use_shift; + } prop_kind pkind = keep_stats ? prop_kind::forward_training : prop_kind::forward_inference; @@ -735,8 +740,17 @@ layernorm_executable_t::desc_t layernorm_executable_t::create_desc( auto dst = make_dnnl_memory_desc( op->get_output_value(0)->get_logical_tensor()); dst = to_format_any(dst); - dnnl::layer_normalization_forward::primitive_desc pd( - p_engine, pkind, src, dst, epsilon, flags, prm_attr); + dnnl::layer_normalization_forward::primitive_desc pd; + if (use_affine) { + memory::data_type scale_shift_data_type + = static_cast<memory::data_type>( + op->get_input_value(1)->get_logical_tensor().data_type); + pd = dnnl::layer_normalization_forward::primitive_desc(p_engine, pkind, + src, dst, scale_shift_data_type, epsilon, flags, prm_attr); + } else { + pd = dnnl::layer_normalization_forward::primitive_desc( + p_engine, pkind, src, dst, epsilon, flags, prm_attr); + } pd_cache.insert({op.get(), pd}); return {pd, false}; @@ -2192,44 +2206,59 @@ arg_indices_t batchnorm_bwd_executable_t::get_arg_indices(const op_t *op) { return arg_indices; } -static arg_indices_t get_arg_indices_for_lnorm_and_gnorm(const op_t *op) { - arg_indices_t arg_indices; - - size_t in_index = 0; - arg_indices.insert({DNNL_ARG_SRC, indices_t {input, in_index++}}); +static arg_indices_t get_arg_indices_for_norm(const op_t *op) { + arg_indices_t args; + size_t in_idx = 0; + const bool is_rms = op->has_attr(op_attr::is_rms) + ? op->get_attr<bool>(op_attr::is_rms) + : false; + args.insert({DNNL_ARG_SRC, {indices_t::type_t::input, in_idx++}}); if (!op->has_attr(op_attr::use_affine) || op->get_attr<bool>(op_attr::use_affine)) { - arg_indices.insert({DNNL_ARG_SCALE, indices_t {input, in_index++}}); - arg_indices.insert({DNNL_ARG_SHIFT, indices_t {input, in_index++}}); + // rms doesn't support shift + if (!is_rms) { + args.insert({DNNL_ARG_SCALE, {indices_t::type_t::input, in_idx++}}); + args.insert({DNNL_ARG_SHIFT, {indices_t::type_t::input, in_idx++}}); + } else { + if (op->has_attr(op_attr::use_affine) + && op->get_attr<bool>(op_attr::use_affine)) { + args.insert( + {DNNL_ARG_SCALE, {indices_t::type_t::input, in_idx++}}); + } + } } const fusion_info_t &fusion_info = op->has_attr(op_attr::fusion_info) ?
op->get_attr<fusion_info_t>(op_attr::fusion_info) : fusion_info_t(); - get_arg_indices_for_post_ops(op, arg_indices, in_index); + get_arg_indices_for_post_ops(op, args, in_idx); if (fusion_info.with_runtime_scales(false, 0)) { - arg_indices.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, - indices_t {input, in_index++}}); + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, + indices_t {input, in_idx++}}); } size_t out_index = 0; - arg_indices.insert({DNNL_ARG_DST, indices_t {output, out_index++}}); + args.insert({DNNL_ARG_DST, indices_t {output, out_index++}}); if (!op->has_attr(op_attr::keep_stats) || op->get_attr<bool>(op_attr::keep_stats)) { - arg_indices.insert({DNNL_ARG_MEAN, indices_t {output, out_index++}}); - arg_indices.insert( - {DNNL_ARG_VARIANCE, indices_t {output, out_index++}}); + // RMSNorm OP in oneDNN Graph API only has 1 output + if (!is_rms) { + args.insert( + {DNNL_ARG_MEAN, {indices_t::type_t::output, out_index++}}); + args.insert({DNNL_ARG_VARIANCE, + {indices_t::type_t::output, out_index++}}); + } } - arg_indices.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, out_index++}}); + args.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, out_index++}}); - return arg_indices; + return args; } arg_indices_t layernorm_executable_t::get_arg_indices(const op_t *op) { - return get_arg_indices_for_lnorm_and_gnorm(op); + return get_arg_indices_for_norm(op); } arg_indices_t layernorm_bwd_executable_t::get_arg_indices(const op_t *op) { @@ -2357,7 +2386,7 @@ arg_indices_t eltwise_bwd_executable_t::get_arg_indices(const op_t *op) { } arg_indices_t groupnorm_executable_t::get_arg_indices(const op_t *op) { - return get_arg_indices_for_lnorm_and_gnorm(op); + return get_arg_indices_for_norm(op); } arg_indices_t genindex_executable_t::get_arg_indices(const op_t *op) {
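For context, this is roughly the primitive-level construction that `layernorm_executable_t` ends up performing for an RMSNorm op with a `gamma` input: a layer normalization primitive with the `rms_norm` flag, `use_scale`, and no shift. The sketch below is illustrative only, not part of the patch; it assumes a oneDNN build where `dnnl::normalization_flags::rms_norm` is available (as this patch requires), and the shapes and engine kind are made up.

```cpp
// Illustrative sketch (not part of the patch): standalone equivalent of the
// primitive descriptor created by layernorm_executable_t when is_rms is set.
#include "oneapi/dnnl/dnnl.hpp"

using namespace dnnl;

int main() {
    engine eng {engine::kind::cpu, 0};
    stream strm {eng};

    memory::desc src_md {{8, 128, 768}, memory::data_type::f32,
            memory::format_tag::abc};
    memory::desc scale_md {{768}, memory::data_type::f32,
            memory::format_tag::a};

    // rms_norm: normalize by the root mean square; use_scale: apply gamma.
    auto flags = normalization_flags::rms_norm | normalization_flags::use_scale;

    // Constructor variant that takes the scale/shift data type, as used above.
    layer_normalization_forward::primitive_desc pd {eng,
            prop_kind::forward_inference, src_md, src_md,
            memory::data_type::f32, /*epsilon=*/1e-5f, flags};

    memory src_m {src_md, eng}, dst_m {src_md, eng}, scale_m {scale_md, eng};

    layer_normalization_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_SCALE, scale_m},
                    {DNNL_ARG_DST, dst_m}});
    strm.wait();
    return 0;
}
```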
diff --git a/src/graph/backend/dnnl/passes/lower.cpp index 5db117aafbd..176e9187190 100644 --- a/src/graph/backend/dnnl/passes/lower.cpp +++ b/src/graph/backend/dnnl/passes/lower.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2022-2025 Intel Corporation + * Copyright 2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -917,6 +917,23 @@ static status_t gen_index_handler( return status::success; } +static status_t rmsnorm_handler( + const std::shared_ptr<op_t> &op, subgraph_rewriter_t &rewriter) { + auto new_op = std::make_shared<op_t>(op_kind::dnnl_layernorm); + new_op->set_attr<bool>(op_attr::is_rms, true); + if (op->get_input_values().size() == 2) { + new_op->set_attr<bool>(op_attr::use_affine, true); + } else { + new_op->set_attr<bool>(op_attr::use_affine, false); + } + // RMSNorm OP in oneDNN Graph API only has 1 output + new_op->set_attr<bool>(op_attr::keep_stats, false); + new_op->merge_attributes(op->get_attributes()); + rewriter.replace_op(op, new_op); + insert_empty_scratchpad(new_op); + return status::success; +} + #define ITEM(kind, func) \ { \ graph::op_kind::kind, handler_func { (func) } \ } @@ -1000,6 +1017,7 @@ static const std::unordered_map handler_table { ITEM(ReduceMin, reduction_handler), ITEM(ReduceProd, reduction_handler), ITEM(ReduceSum, reduction_handler), + ITEM(RMSNorm, rmsnorm_handler), // softplus ITEM(SoftPlus, softplus_handler), ITEM(SoftPlusBackward, softplus_handler), diff --git a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp index 2d898040c44..7329d295f20 100644 --- a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp +++ b/src/graph/backend/dnnl/patterns/single_op_pattern.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2020-2025 Intel Corporation + * Copyright 2020 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -435,6 +435,23 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, greater_equal_pass) .set_attr("FCreateKernel", []() -> kernel_ptr { return std::make_shared(); }); +DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, rmsn_pass) + .set_priority(DEFAULT_P) + .set_kind(partition_kind_t::misc_post_ops) + .set_attr<FCreatePattern>("FCreatePattern", + [](const std::shared_ptr<pb_graph_t> &pgraph) -> void { + graph::utils::pm::pb_op_t *p_rmsnorm + = pgraph->append_op(graph::op_kind::RMSNorm); + p_rmsnorm->append_decision_function( + check_begin_norm_axis_attr); + // the primitive only supports 2-5D data tensors for rmsnorm + p_rmsnorm->append_decision_function( + check_input_ndim_from_offset<0, 2, 5>); + }) + .set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr { + return std::make_shared(); + }); + #undef DNNL_BACKEND_SINGLE_OP_TRANSFORM #undef DEFAULT_P diff --git a/src/graph/interface/c_types_map.hpp index 0bc6c244520..1cd79f7a711 100644 --- a/src/graph/interface/c_types_map.hpp +++ b/src/graph/interface/c_types_map.hpp @@ -193,6 +193,7 @@ const op_kind_t ReduceProd = dnnl_graph_op_reduce_prod; const op_kind_t ReduceSum = dnnl_graph_op_reduce_sum; const op_kind_t ReLU = dnnl_graph_op_relu; const op_kind_t ReLUBackward = dnnl_graph_op_relu_backward; +const op_kind_t RMSNorm = dnnl_graph_op_rms_norm; const op_kind_t Reorder = dnnl_graph_op_reorder; const op_kind_t Round = dnnl_graph_op_round; const op_kind_t Select = dnnl_graph_op_select; diff --git a/src/graph/interface/op.hpp index cd2cfdfa0a8..c4817ff315b 100644 --- a/src/graph/interface/op.hpp +++ b/src/graph/interface/op.hpp @@ -434,6 +434,7 @@ struct dnnl_graph_op : public std::enable_shared_from_this { CASE(ReLUBackward); CASE(Reorder); CASE(Round); + CASE(RMSNorm); CASE(Select); CASE(Sigmoid); CASE(SigmoidBackward); diff --git a/src/graph/interface/op_def.hpp index 19a5c7a30b3..8e08763b461 100644 --- a/src/graph/interface/op_def.hpp +++ b/src/graph/interface/op_def.hpp @@ -497,7 +497,7 @@ DNNL_GRAPH_OP_SCHEMA(GroupNorm, 1, "T1", {data_type::f32, data_type::bf16, data_type::f16}) .set_type_constraints("T2", {data_type::f32, data_type::bf16}) .set_shape_inference_function(infer_groupnorm_output_shape) - .set_op_def_constraint_function(check_ln_gn_data_type) + .set_op_def_constraint_function(check_norm_data_type) .set_op_def_constraint_function(check_ln_gn_fwd_outputs_num)) DNNL_GRAPH_OP_SCHEMA(HardSigmoid, 1, @@ -614,7 +614,7 @@ DNNL_GRAPH_OP_SCHEMA(LayerNorm, 1, "T1", {data_type::f32, data_type::bf16, data_type::f16}) .set_type_constraints("T2", {data_type::f32, data_type::bf16}) .set_shape_inference_function(infer_norm_output_shape) - .set_op_def_constraint_function(check_ln_gn_data_type) + .set_op_def_constraint_function(check_norm_data_type) .set_op_def_constraint_function(check_ln_gn_fwd_outputs_num)) DNNL_GRAPH_OP_SCHEMA(LayerNormBackward, 1, @@ -640,7 +640,7 @@ DNNL_GRAPH_OP_SCHEMA(LayerNormBackward, 1, "T1", {data_type::f32, data_type::bf16, data_type::f16}) .set_type_constraints("T2", {data_type::f32, data_type::bf16}) .set_shape_inference_function(infer_norm_bprop_output_shape) - .set_op_def_constraint_function(check_ln_gn_data_type) + .set_op_def_constraint_function(check_norm_data_type) .set_op_def_constraint_function(check_ln_bwd_use_affine)) DNNL_GRAPH_OP_SCHEMA(LeakyReLU, 1, @@ -1336,6 +1336,24 @@ DNNL_GRAPH_OP_SCHEMA(Reciprocal, 1, "T", {data_type::f32, data_type::bf16, data_type::f16}) .set_shape_inference_function(infer_identity_output_shape))
+DNNL_GRAPH_OP_SCHEMA(RMSNorm, 1, + op_schema_t() + .set_inputs_option(op_schema_t::param_num_option::optional) + .set_num_inputs(std::set({1, 2})) + .set_num_outputs(1) + .set_input(0, "src", "T1") + .set_input(1, "gamma", "T2") + .set_output(0, "dst", "T1") + .set_attr(op_attr::begin_norm_axis, false, attribute_kind::i, + int64_t(-1)) + .set_attr(op_attr::epsilon, false, attribute_kind::f, 1e-5f) + .set_type_constraints( + "T1", {data_type::f32, data_type::bf16, data_type::f16}) + .set_type_constraints( + "T2", {data_type::f32, data_type::bf16, data_type::f16}) + .set_shape_inference_function(infer_norm_output_shape) + .set_op_def_constraint_function(check_norm_data_type)) + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/op_def_constraint.cpp b/src/graph/interface/op_def_constraint.cpp index 9c535c79428..4a5ca33d6d8 100644 --- a/src/graph/interface/op_def_constraint.cpp +++ b/src/graph/interface/op_def_constraint.cpp @@ -145,10 +145,10 @@ bool check_softmax_bwd_output_dtype(const op_t *n) { return true; } -// check function for data_type of LayerNorm and GroupNorm. +// check function for data_type of LayerNorm, GroupNorm and RMSNorm. // only when data is bf16, gamma/beta/mean/var can be bf16. // If data is bf16, gamma/beta/mean/var can be f32 or bf16. -bool check_ln_gn_data_type(const op_t *n) { +bool check_norm_data_type(const op_t *n) { const auto &input_values = n->get_input_values(); const auto &output_values = n->get_output_values(); @@ -158,8 +158,9 @@ bool check_ln_gn_data_type(const op_t *n) { if (input_values.size() == 1 && output_values.size() == 1) { return true; } else { - if (input_values.size() > 2) { - aux_lt = input_values[2]->get_logical_tensor(); + // RMSNorm uses only one aux tensor + if (input_values.size() > 1) { + aux_lt = input_values[1]->get_logical_tensor(); } else { aux_lt = output_values[1]->get_logical_tensor(); } @@ -433,7 +434,6 @@ bool check_dyn_quant_dequant_scales_zps(const op_t *n) { } return true; } - } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/op_def_constraint.hpp b/src/graph/interface/op_def_constraint.hpp index 973c59641f0..73fe34abbc2 100644 --- a/src/graph/interface/op_def_constraint.hpp +++ b/src/graph/interface/op_def_constraint.hpp @@ -34,7 +34,7 @@ bool check_softmax_dtype(const op_t *n); bool check_softmax_bwd_output_dtype(const op_t *n); -bool check_ln_gn_data_type(const op_t *n); +bool check_norm_data_type(const op_t *n); bool check_typecast_data_type(const op_t *n); diff --git a/src/graph/interface/opset.hpp b/src/graph/interface/opset.hpp index 99588c4c2ad..392519e28c5 100644 --- a/src/graph/interface/opset.hpp +++ b/src/graph/interface/opset.hpp @@ -108,6 +108,7 @@ class opset_v1_t { fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); + fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); diff --git a/src/graph/interface/shape_infer.cpp b/src/graph/interface/shape_infer.cpp index 89201367869..10fbb6bb35e 100644 --- a/src/graph/interface/shape_infer.cpp +++ b/src/graph/interface/shape_infer.cpp @@ -1156,6 +1156,9 @@ status_t infer_norm_output_shape(op_t *n, auto status = infer_identity_output_shape(n, inputs, outputs); if (status != status::success) return status; + const auto is_rms = n->get_kind() == op_kind::RMSNorm; + if (is_rms) return status::success; + const bool keep_stats = n->has_attr(op_attr::keep_stats) ? 
n->get_attr(op_attr::keep_stats) // Keep default value as which in op_schema diff --git a/tests/benchdnn/graph/deserialize.cpp index 2adce5ff698..47e14e41531 100644 --- a/tests/benchdnn/graph/deserialize.cpp +++ b/tests/benchdnn/graph/deserialize.cpp @@ -291,6 +291,7 @@ dnnl_driver_t deserialized_op_t::opkind2driver() const { {dnnl::graph::op::kind::ReLUBackward, dnnl_driver_t::eltwise}, {dnnl::graph::op::kind::Reorder, dnnl_driver_t::reorder}, + {dnnl::graph::op::kind::RMSNorm, dnnl_driver_t::lnorm}, {dnnl::graph::op::kind::Round, dnnl_driver_t::eltwise}, {dnnl::graph::op::kind::Select, dnnl_driver_t::binary}, {dnnl::graph::op::kind::Sigmoid, dnnl_driver_t::eltwise}, diff --git a/tests/benchdnn/graph/flex_rewrite.cpp index 2354c6c48c9..4c92f47ddd5 100644 --- a/tests/benchdnn/graph/flex_rewrite.cpp +++ b/tests/benchdnn/graph/flex_rewrite.cpp @@ -502,6 +502,7 @@ int flex_rewrite_t::infer_output_shape( case dnnl::graph::op::kind::ReLU: case dnnl::graph::op::kind::ReLUBackward: case dnnl::graph::op::kind::Reorder: + case dnnl::graph::op::kind::RMSNorm: case dnnl::graph::op::kind::Round: case dnnl::graph::op::kind::Sigmoid: case dnnl::graph::op::kind::SigmoidBackward: diff --git a/tests/benchdnn/graph/setting_handler.cpp index 8f58e5885a1..ebece7559db 100644 --- a/tests/benchdnn/graph/setting_handler.cpp +++ b/tests/benchdnn/graph/setting_handler.cpp @@ -1210,6 +1210,14 @@ bool get_lnorm_dir(const deserialized_op_t &base_op_ref, dir_t &dir) { } else { return false; } + } else if (op_kind == "RMSNorm") { + // RMSNorm OP in oneDNN Graph API only has 1 output + const size_t out_size = base_op_ref.out_lts_.size(); + if (out_size == 1) { + dir = dir_t::FWD_I; + } else { + return false; + } } else if (op_kind == "LayerNormBackward") { dir = dir_t::BWD_DW; } else { @@ -1225,7 +1233,7 @@ bool get_lnorm_dt(const deserialized_op_t &base_op_ref, dnnl_data_type_t &dt) { } bool get_lnorm_flags( - const deserialized_op_t &base_op_ref, ::bnorm::flags_t &flags) { + const deserialized_op_t &base_op_ref, ::lnorm::flags_t &flags) { bool use_affine = false; base_op_ref.get_attr_bool(use_affine, "use_affine"); const auto &op_kind = base_op_ref.kind_; @@ -1245,6 +1253,19 @@ bool get_lnorm_flags( return false; } } + } else if (op_kind == "RMSNorm") { + flags = ::lnorm::USE_RMS_NORM; + // RMSNorm input: src, gamma(opt) + // no beta/shift parameter for RMSNorm + if (in_size == 2) { + // has gamma (scale only) + flags |= ::lnorm::USE_SCALE; + } else if (in_size == 1) { + // no gamma + flags |= ::lnorm::NONE; + } else { + return false; + } } else if (op_kind == "LayerNormBackward") { // input: src, diff_dst, mean, var, gamma(opt), beta(opt) if (use_affine) { @@ -1278,6 +1299,8 @@ ::lnorm::settings_t get_setting( lnorm::get_lnorm_dt(base_op_ref, op_setting.dt[0].front()), res); DNN_GRAPH_CHECK_SETTINGS( get_driver_tag(base_op_ref, op_setting.tag[0].front()), res); + DNN_GRAPH_CHECK_SETTINGS( + get_driver_tag(base_op_ref, op_setting.tag[0].back(), true), res); DNN_GRAPH_CHECK_SETTINGS( lnorm::get_lnorm_flags(base_op_ref, op_setting.flags.front()), res); DNN_GRAPH_CHECK_SETTINGS( diff --git a/tests/benchdnn/graph/utils.cpp index 61f72ee684c..daad0bd298e 100644 --- a/tests/benchdnn/graph/utils.cpp +++ b/tests/benchdnn/graph/utils.cpp @@ -332,6 +332,7 @@ dnnl::graph::op::kind opstr2kind(const std::string &kind) { {"ReLU", dnnl::graph::op::kind::ReLU}, {"ReLUBackward",
dnnl::graph::op::kind::ReLUBackward}, {"Reorder", dnnl::graph::op::kind::Reorder}, + {"RMSNorm", dnnl::graph::op::kind::RMSNorm}, {"Round", dnnl::graph::op::kind::Round}, {"Select", dnnl::graph::op::kind::Select}, {"Sigmoid", dnnl::graph::op::kind::Sigmoid}, @@ -778,6 +779,17 @@ int get_prim_arg_name_from_graph_op_output_offset( } } break; + case dnnl::graph::op::kind::RMSNorm: { + // RMSNorm OP in oneDNN Graph API only has 1 output + if (output_offset == 0) + return DNNL_ARG_DST; + else { + BENCHDNN_PRINT(0, "Error: no matching ARG for offset %d", + static_cast<int>(output_offset)); + assert(false); + return -1; + } + } break; case dnnl::graph::op::kind::SoftMax: { if (output_offset == 0) return DNNL_ARG_DST; @@ -1150,6 +1162,18 @@ int get_prim_arg_name_from_graph_op_input_offset( return -1; } } break; + case dnnl::graph::op::kind::RMSNorm: { + if (input_offset == 0) + return DNNL_ARG_SRC; + else if (input_offset == 1) + return DNNL_ARG_SCALE; // gamma + else { + BENCHDNN_PRINT(0, "Error: no matching ARG for offset %zu", + input_offset); + assert(false); + return -1; + } + } break; default: { return DNNL_ARG_SRC; } break; diff --git a/tests/benchdnn/inputs/graph/op/f32/rmsnorm.json b/tests/benchdnn/inputs/graph/op/f32/rmsnorm.json new file mode 100644 index 00000000000..3ddc1051ffc --- /dev/null +++ b/tests/benchdnn/inputs/graph/op/f32/rmsnorm.json @@ -0,0 +1,81 @@ +{ + "version": "3.11.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", + "input_ports": [ + 0 + ], + "output_ports": [ + 1 + ], + "graph": [ + { + "id": 0, + "name": "RMSNORM_0", + "kind": "RMSNorm", + "attrs": { + "begin_norm_axis": { + "type": "s64", + "value": -1 + }, + "epsilon": { + "type": "f32", + "value": 0.0625 + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f32", + "shape": [ + 64, + 128, + 28, + 28 + ], + "stride": [ + 100352, + 784, + 28, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 2, + "dtype": "f32", + "shape": [ + 28 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 1, + "dtype": "f32", + "shape": [ + 64, + 128, + 28, + 28 + ], + "stride": [ + 100352, + 784, + 28, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} diff --git a/tests/benchdnn/inputs/graph/op/harness_bf16_all b/tests/benchdnn/inputs/graph/op/harness_bf16_all index 59f05350b17..7fcc6399dd2 100644 --- a/tests/benchdnn/inputs/graph/op/harness_bf16_all +++ b/tests/benchdnn/inputs/graph/op/harness_bf16_all @@ -135,7 +135,9 @@ --reset --dt=bf16 --in-shapes=1:1x1x1x1 --case=op/f32/greaterequal.json --reset --dt=bf16 --in-shapes=1:1 --case=op/f32/greaterequal.json --reset --dt=bf16 --op-attrs=0:mode:gelu_erf,0:mode:gelu_tanh --case=op/f32/gelu.json - +--reset --dt=bf16 --case=op/f32/rmsnorm.json +# bf16 src/dst with f32 scale +--reset --dt=0:bf16+1:bf16 --case=op/f32/rmsnorm.json # select --reset --dt=bf16 --in-shapes=2:1x1x1x128 --case=op/f32/select.json # concat diff --git a/tests/benchdnn/inputs/graph/op/harness_f16_all b/tests/benchdnn/inputs/graph/op/harness_f16_all index 8cbdc32e912..230b214f79f 100644 --- a/tests/benchdnn/inputs/graph/op/harness_f16_all +++ b/tests/benchdnn/inputs/graph/op/harness_f16_all @@ -134,6 +134,9 @@ --reset --dt=f16 --in-shapes=1:1x1x1x1 --case=op/f32/greaterequal.json --reset --dt=f16 --in-shapes=1:1 --case=op/f32/greaterequal.json --reset --dt=f16 --op-attrs=0:mode:gelu_erf,0:mode:gelu_tanh --case=op/f32/gelu.json +--reset --dt=f16
--case=op/f32/rmsnorm.json +# f16 src/dst with f32 scale +--reset --dt=0:f16+1:f16 --case=op/f32/rmsnorm.json # select --reset --dt=bf16 --in-shapes=2:1x1x1x128 --case=op/f32/select.json diff --git a/tests/benchdnn/inputs/graph/op/harness_f32_all b/tests/benchdnn/inputs/graph/op/harness_f32_all index badece55537..93c29c1f609 100644 --- a/tests/benchdnn/inputs/graph/op/harness_f32_all +++ b/tests/benchdnn/inputs/graph/op/harness_f32_all @@ -854,3 +854,4 @@ --reset --in-shapes=1:1x1x1x1 --case=op/f32/greaterequal.json --reset --in-shapes=1:1 --case=op/f32/greaterequal.json --reset --op-attrs=0:mode:gelu_erf,0:mode:gelu_tanh --case=op/f32/gelu.json +--reset --case=op/f32/rmsnorm.json diff --git a/tests/gtests/graph/api/test_cpp_api_op.cpp b/tests/gtests/graph/api/test_cpp_api_op.cpp index a8b80217012..c9907a41601 100644 --- a/tests/gtests/graph/api/test_cpp_api_op.cpp +++ b/tests/gtests/graph/api/test_cpp_api_op.cpp @@ -116,6 +116,7 @@ TEST(APIOp, CreateAllOps) { op::kind::GroupNorm, op::kind::GenIndex, op::kind::GreaterEqual, + op::kind::RMSNorm, }; // clang-format on diff --git a/tests/gtests/graph/unit/interface/test_op_def_constraint_cpu.cpp b/tests/gtests/graph/unit/interface/test_op_def_constraint_cpu.cpp index d93aef06975..491daa88b77 100644 --- a/tests/gtests/graph/unit/interface/test_op_def_constraint_cpu.cpp +++ b/tests/gtests/graph/unit/interface/test_op_def_constraint_cpu.cpp @@ -354,13 +354,13 @@ INSTANTIATE_TEST_SUITE_P(test_interface_op_def_constraint, ::testing::Values( // test function of CheckLayerNormDataType dnnl_graph_ln_params_t {LayerNorm, f32, f32, true, false, true, - graph::check_ln_gn_data_type, true}, + graph::check_norm_data_type, true}, dnnl_graph_ln_params_t {LayerNorm, bf16, f32, true, false, true, - graph::check_ln_gn_data_type, true}, + graph::check_norm_data_type, true}, dnnl_graph_ln_params_t {LayerNorm, f32, bf16, true, false, true, - graph::check_ln_gn_data_type, false}, + graph::check_norm_data_type, false}, dnnl_graph_ln_params_t {LayerNorm, bf16, bf16, true, false, - true, graph::check_ln_gn_data_type, true}, + true, graph::check_norm_data_type, true}, // test function of CheckLayerNormFwdOutputsNum dnnl_graph_ln_params_t {LayerNorm, f32, f32, true, false, true, graph::check_ln_gn_fwd_outputs_num, true},