diff --git a/paddle/cinn/hlir/pe/transform.cc b/paddle/cinn/hlir/pe/transform.cc index aa486d9940c62b..1acad47ed53045 100644 --- a/paddle/cinn/hlir/pe/transform.cc +++ b/paddle/cinn/hlir/pe/transform.cc @@ -69,8 +69,8 @@ std::vector> GetMatmulNewShapes( auto& new_y_shape = new_shape[1]; auto& out_shape = new_shape[2]; - int x_dim = x_shape.size(), y_dim = y_shape.size(); - int max_dim = std::max(x_shape.size(), y_shape.size()); + size_t x_dim = x_shape.size(), y_dim = y_shape.size(); + size_t max_dim = std::max(x_shape.size(), y_shape.size()); int out_dim = max_dim >= 3 ? 3 : (max_dim <= 2 ? 2 : max_dim); auto get_input_shape = [out_dim](const std::vector& old_shape) { @@ -309,8 +309,9 @@ std::vector Matmul(const Tensor& A, const std::string& name) { std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int a_dim = static_cast(shape_A.size()); + int b_dim = static_cast(shape_B.size()); PADDLE_ENFORCE_EQ( a_dim == 3U || a_dim == 2U, true, @@ -347,7 +348,8 @@ std::vector Matmul(const Tensor& A, auto temp = Compute( output_shape, [=](const std::vector& indice) { - int out_dim = indice.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int out_dim = static_cast(indice.size()); std::vector A_indice; std::vector B_indice; PADDLE_ENFORCE_EQ( @@ -458,13 +460,15 @@ ir::Tensor Concat(const std::vector& input_tensors, int axis, const std::string& name) { // input size 1 is valid for Concat - int input_size = input_tensors.size(); + // NOTE(large-tensor): input size is a small integer + int input_size = static_cast(input_tensors.size()); PADDLE_ENFORCE_GE(input_size, 1U, ::common::errors::InvalidArgument( "Concat should have at least 1 input tensors")); std::vector output_shape = input_tensors[0]->shape; - int input_dim = output_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int input_dim = static_cast(output_shape.size()); PADDLE_ENFORCE_EQ( axis >= -input_dim && axis < input_dim, true, @@ -518,8 +522,9 @@ std::vector MatmulV2(const Tensor& A, const cinn::common::Target& target) { std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int a_dim = static_cast(shape_A.size()); + int b_dim = static_cast(shape_B.size()); PADDLE_ENFORCE_EQ( a_dim == 3U || a_dim == 2U, true, @@ -566,7 +571,8 @@ std::vector MatmulV2(const Tensor& A, packedB_shape, [=](const std::vector& indice) { std::vector indice_b; - int indice_dim = indice.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int indice_dim = static_cast(indice.size()); PADDLE_ENFORCE_GE(indice_dim, 3, ::common::errors::InvalidArgument( @@ -590,7 +596,8 @@ std::vector MatmulV2(const Tensor& A, [=](const std::vector& indice) { std::vector indice_a; std::vector indice_b; - int out_dim = indice.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int out_dim = static_cast(indice.size()); PADDLE_ENFORCE_EQ( out_dim == 3U || out_dim == 2U, true, @@ -635,8 +642,9 @@ std::vector MatmulMKL(const Tensor& A, "Mkl should be used in the cpu environment.")); std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int a_dim = static_cast(shape_A.size()); + int b_dim = 
static_cast(shape_B.size()); PADDLE_ENFORCE_EQ( a_dim == 3U || a_dim == 2U, true, @@ -930,8 +938,9 @@ std::vector MulMKL(const Tensor& A, "Mkl should be used in the cpu environment.")); std::vector shape_A = A->shape; std::vector shape_B = B->shape; - int a_dim = shape_A.size(); - int b_dim = shape_B.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int a_dim = static_cast(shape_A.size()); + int b_dim = static_cast(shape_B.size()); PADDLE_ENFORCE_EQ( a_dim, 2U, diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 1ae0f1b148d1bd..22c4584c295090 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -362,7 +362,7 @@ inline void TensorToVector(const phi::DenseTensor& src, memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size, nullptr); } #endif - for (unsigned int i = 0; i < src.numel(); i++) { + for (int64_t i = 0; i < src.numel(); i++) { (*dst)[i] = static_cast(array[i]); } delete[] array; @@ -408,7 +408,7 @@ inline void TensorToVector(const phi::DenseTensor& src, memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size); - for (unsigned int i = 0; i < src.numel(); i++) { + for (int64_t i = 0; i < src.numel(); i++) { (*dst)[i] = static_cast(array[i]); } delete[] array; diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 5c3c22b2ab141b..a2b8b15a717813 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -96,7 +96,7 @@ 'LegacyInterpolateInferMeta', 'NceInferMeta', 'PyramidHashInferMeta', - 'RmsNormInferMeta', + 'FusedRmsNormQuantInferMeta', 'SigmoidCrossEntropyWithLogitsInferMeta', 'StackInferMeta', 'WeightOnlyLinearInferMeta', diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 3f12cda02446fd..feda0af0cd910e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -973,6 +973,7 @@ bool GatherOpInferSymbolicShape(pir::Operation *op, "3 when the axis is not set.")); const auto &axis_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(2)); + // NOTE(large-tensor): axis is a small integer axis = static_cast(axis_shape_or_data.data().value()[0].Get()); } @@ -983,7 +984,10 @@ bool GatherOpInferSymbolicShape(pir::Operation *op, const std::vector &index_sym_shape = index_shape_or_data.shape(); - if (axis < 0) axis += input_sym_shape.size(); + if (axis < 0) { + // NOTE(large-tensor): tensor rank is a small integer + axis += static_cast(input_sym_shape.size()); + } const auto &out_sym_shape = [&] { std::vector out_sym_shape; diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc index 9a12a83ddff9cd..4f21b70561bef8 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc @@ -83,7 +83,10 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, infer_context->GetShapeOrDataForValue(input_values[0]).shape(); size_t rank = out_dims.size(); - axis = axis >= 0 ? 
axis : std::max(int64_t(0), int64_t(axis + rank)); + // NOTE(large-tensor): axis is a small integer. + axis = axis >= 0 + ? axis + : static_cast(std::max(int64_t(0), int64_t(axis + rank))); for (size_t i = 1; i < input_size; ++i) { const auto &operand_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.cc index f8fc73584dac8c..322fd1332cb204 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.cc @@ -22,7 +22,7 @@ void SliceDfsImpl(const ExprVec &datas, int64_t start, int64_t end, int64_t cur_visit_axis, - int offset, + int64_t offset, ExprVec *result) { int64_t begin = 0; int64_t stop = shape.at(cur_visit_axis); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 3369a797f3543f..0be6c22de66d30 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -740,7 +740,8 @@ bool BroadcastTensorsOpInferSymbolicShape( // 1. Find Output rank = max(Inputs rank) int target_rank = 0; for (const auto &input_shape_or_data : input_shape_or_data_list) { - int tmp_rank = input_shape_or_data.shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int tmp_rank = static_cast(input_shape_or_data.shape().size()); target_rank = std::max(target_rank, tmp_rank); } // 2. Output dim(axis=x) = max(Inputs dim(axis=x)) @@ -749,7 +750,8 @@ bool BroadcastTensorsOpInferSymbolicShape( for (int i = 0; i < target_rank; i++) { auto tmp_dim = symbol::DimExpr{1}; for (const auto &input_shape_or_data : input_shape_or_data_list) { - int axis = input_shape_or_data.shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int axis = static_cast(input_shape_or_data.shape().size()); axis = i - target_rank + axis; if (axis >= 0) { infer_context->AddBroadcastableCstr(input_shape_or_data.shape()[axis], @@ -1143,7 +1145,10 @@ bool CrossEntropyWithSoftmaxOpInferSymbolicShape( const auto &index_dim = index_shape.shape(); const auto &attributes = op->attributes(); int axis = attributes.at("axis").dyn_cast().data(); - if (axis < 0) axis += input_shape.shape().size(); + if (axis < 0) { + // NOTE(large-tensor): tensor rank is a small integer + axis += static_cast(input_shape.shape().size()); + } bool soft_label = attributes.at("soft_label").dyn_cast().data(); PADDLE_ENFORCE(!soft_label || input_dim.size() == index_dim.size(), @@ -1197,7 +1202,8 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, const auto &shape_data_list = x_shape.dyn_cast(); - size_t rank = shape_data_list.at(0).shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(shape_data_list.at(0).shape().size()); const int64_t axis = [&] { int64_t axis = axis_expr.data()->at(0).dyn_cast(); return axis >= 0 ? 
axis : std::max(int64_t(0), int64_t(axis + rank)); @@ -1216,8 +1222,8 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, const std::vector &out_dims = [&] { std::vector out_dims = shape_data_list.at(0).shape(); - for (size_t i = 0; i < rank; ++i) { - if (i != static_cast(axis)) { + for (int i = 0; i < rank; ++i) { + if (i != axis) { details::BuildCstrEqForTensorListAlongAxis( infer_context, shape_data_list, i); continue; @@ -2751,7 +2757,7 @@ bool GroupNormOpInferSymbolicShape( infer_context->SetShapeOrDataForValue(op->result(0), x_shape); - int64_t channel_idx; + size_t channel_idx; std::string data_format = op->attribute("data_format").AsString(); if (data_format == "NHWC") { @@ -2867,9 +2873,10 @@ bool LerpOpInferSymbolicShape(pir::Operation *op, std::vector x_shape = x_shape_or_data.shape(); std::vector y_shape = y_shape_or_data.shape(); std::vector w_shape = w_shape_or_data.shape(); - int x_ndims = x_shape.size(); - int y_ndims = y_shape.size(); - int w_ndims = w_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int x_ndims = static_cast(x_shape.size()); + int y_ndims = static_cast(y_shape.size()); + int w_ndims = static_cast(w_shape.size()); std::vector out1_shape; std::vector out2_shape; int diffxy = x_ndims - y_ndims; @@ -3414,6 +3421,59 @@ bool RoiAlignOpInferSymbolicShape( return true; } +bool RmsNormOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + // Get the shapes of input tensors + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &scale_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + std::vector normalized_shape = + paddle::dialect::details::GetVectorAttr(op, "normalized_shape"); + + std::vector x_dims = x_shape_or_data.shape(); + int x_dims_size = x_dims.size(); + int normalized_shape_size = normalized_shape.size(); + int begin_norm_axis = x_dims_size - normalized_shape_size; + + // Flatten x_dims to 2D and get dim[1] + PADDLE_ENFORCE_LT(normalized_shape_size, + x_dims_size, + "normalized_shape must be less than x_dims"); + for (int i = 0; i < normalized_shape_size; i++) { + infer_context->AddEqualCstr( + x_dims[x_dims_size - i - 1], + symbol::DimExpr(normalized_shape[normalized_shape_size - i - 1])); + } + + if (!scale_shape_or_data.isa()) { + std::vector scale_dims = scale_shape_or_data.shape(); + PADDLE_ENFORCE_EQ( + scale_dims.size(), + normalized_shape_size, + "scale_dims.size() must be equal to normalized_shape_size"); + for (int i = 0; i < normalized_shape_size; i++) { + infer_context->AddEqualCstr(scale_dims[i], + symbol::DimExpr(normalized_shape[i])); + } + } + + // Set output shapes + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)}); + + // Set invvar + std::vector before_norm_dims( + x_dims.begin(), x_dims.begin() + begin_norm_axis); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(before_norm_dims)}); + + return true; +} + bool SpectralNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &weight_shape = @@ -3688,7 +3748,7 @@ bool NceOpInferSymbolicShape(pir::Operation *op, infer_context->AddEqualCstr(weight_shape[0], bias_shape[0]); } - int num_total_classes = + int64_t num_total_classes = op->attribute("num_total_classes").data(); 
infer_context->AddEqualCstr(symbol::DimExpr(num_total_classes), weight_shape[0]); @@ -3700,7 +3760,7 @@ bool NceOpInferSymbolicShape(pir::Operation *op, symbol::TensorShapeOrDataDimExprs(out_shape)}); bool is_test = op->attribute("is_test").data(); - int num_neg_samples = + int64_t num_neg_samples = op->attribute("num_neg_samples").data(); if (!is_test) { std::vector sample_out_shape = {x_shape[0]}; @@ -4072,7 +4132,7 @@ bool RnnOpInferSymbolicShape(pir::Operation *op, "The rank of PreState in RNN must be 3. But " "the received rank is %d.", pre_state_shape_or_data_list[0].shape().size())); - for (size_t i = 0; i < 3; ++i) { + for (int i = 0; i < 3; ++i) { details::BuildCstrEqForTensorListAlongAxis( infer_context, pre_state_shape_or_data_list, i); } @@ -4344,7 +4404,8 @@ bool StackOpInferSymbolicShape(pir::Operation *op, infer_context->GetShapeOrDataForValue(operand_source) .dyn_cast(); - size_t rank = shape_data_list.at(0).shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(shape_data_list.at(0).shape().size()); if (axis < 0) axis += rank + 1; const symbol::ShapeOrDataDimExprs shape_data = [&] { std::vector result_shape = {}; @@ -4378,7 +4439,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op, } else { // case 2: data is empty, eg: shape_data_list = // [[shape:{5,6,7},data:{}],...] - for (size_t i = 0; i < rank; ++i) { + for (int i = 0; i < rank; ++i) { details::BuildCstrEqForTensorListAlongAxis( infer_context, shape_data_list, i); } @@ -4759,9 +4820,10 @@ bool WhereOpInferSymbolicShape(pir::Operation *op, const std::vector &operands = { op->operand_source(0), op->operand_source(1), op->operand_source(2)}; - size_t rank = x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(x_shape.size()); - for (size_t i = 0; i < rank; ++i) { + for (int i = 0; i < rank; ++i) { paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( infer_context, operands, i); } @@ -4849,7 +4911,8 @@ bool YoloLossOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape(); const std::vector &anchors_mask = paddle::dialect::details::GetVectorAttr(op, "anchor_mask"); - int mask_num = anchors_mask.size(); + // NOTE(large-tensor): mask number is a small integer + int mask_num = static_cast(anchors_mask.size()); int class_num = op->attribute("class_num").data(); PADDLE_ENFORCE_EQ(x_shape.size(), diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index b1d193e44163ea..a9385e149f8aba 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -115,6 +115,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(RandomRouting) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RandomRouting_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedRmsNormQuant) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiPool) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rnn) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rnn_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index a05d5e3a0ea316..304264553ea0fa 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -337,7 +337,11 @@ bool EyeOpInferSymbolicShape(pir::Operation *op, symbol::DimExpr num_columns_dim; symbol::DimExpr num_rows_dim; if (op->HasAttribute("num_rows")) { - int num_rows_int = op->attribute("num_rows").data(); + int64_t num_rows_int64 = + op->attribute("num_rows").data(); + // TODO(large-tensor): num_rows may exceed INT_MAX + PADDLE_ENFORCE_LE_INT_MAX(num_rows_int64, "num_rows"); + int num_rows_int = static_cast(num_rows_int64); num_rows_dim = symbol::DimExpr(num_rows_int); } else if (op->operand_source(0)) { const auto &num_rows_shape_or_data = @@ -351,8 +355,11 @@ bool EyeOpInferSymbolicShape(pir::Operation *op, } if (op->HasAttribute("num_columns")) { - int num_columns_int = + int64_t num_columns_int64 = op->attribute("num_columns").data(); + // TODO(large-tensor): num_columns may exceed INT_MAX + PADDLE_ENFORCE_LE_INT_MAX(num_columns_int64, "num_columns"); + int num_columns_int = static_cast(num_columns_int64); if (num_columns_int == -1) { num_columns_dim = num_rows_dim; } else { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 3b24ec9e458e40..1badfa6e8a6d2e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -235,7 +235,8 @@ bool AffineGridOpInferSymbolicShape( op->attribute("output_shape") .data() .GetData(); - output_shape_size = output_shape.size(); + // NOTE(large-tensor): output shape size is a small integer + output_shape_size = static_cast(output_shape.size()); for (const auto &i : output_shape) { output_shape_data.push_back(symbol::DimExpr{i}); } @@ -245,7 +246,8 @@ bool AffineGridOpInferSymbolicShape( output_shape_data = details::GetOrCreateExprVecFromData( output_shape_or_data, infer_context); - output_shape_size = output_shape_data.size(); + // NOTE(large-tensor): output shape size is a small integer + output_shape_size = static_cast(output_shape_data.size()); } else { PADDLE_THROW(common::errors::InvalidArgument( "The input arguments must have the shape of output, please check!")); @@ -339,13 +341,15 @@ bool MinMaxOpInferSymbolicShape(pir::Operation *op, keepdims = GetBoolAttr(op, "keepdims"); const auto &axis_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(1)); + // NOTE(large-tensor): axis is a small integer. axis = static_cast( axis_shape_or_data.data().value().at(0).Get()); } const auto &input_sym_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); - int rank = input_sym_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(input_sym_shape.size()); if (axis < 0) axis += rank; const auto &out_sym_shape = [&] { @@ -475,7 +479,8 @@ bool AsStridedOpInferSymbolicShape( const std::vector &shape = paddle::dialect::details::GetVectorAttr(op, "dims"); - int rank = shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(shape.size()); std::vector out_shape; for (int i = 0; i < rank; ++i) { symbol::DimExpr out_unknown = infer_context->GetNextSymName(); @@ -1012,8 +1017,9 @@ bool DiagEmbedOpInferSymbolicShape( int offset = attributes.at("offset").dyn_cast().data(); const auto &x_dims = operand_shape_or_data.shape(); - int dim1_ = dim1 < 0 ? 
x_dims.size() + dim1 + 1 : dim1; - int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; + // NOTE(large-tensor): tensor dimensions are small integers + int dim1_ = dim1 < 0 ? static_cast(x_dims.size()) + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? static_cast(x_dims.size()) + dim2 + 1 : dim2; int64_t offset_ = static_cast(std::abs(offset)); symbol::DimExpr new_dim_len = symbol::DimExpr(offset_) + x_dims.at(x_dims.size() - 1); @@ -1041,8 +1047,9 @@ bool DiagonalOpInferSymbolicShape( int offset = attributes.at("offset").dyn_cast().data(); const auto &x_dims = operand_shape_or_data.shape(); - int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; - int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; + // NOTE(large-tensor): tensor dimensions are small integers + int axis1_ = axis1 < 0 ? static_cast(x_dims.size()) + axis1 : axis1; + int axis2_ = axis2 < 0 ? static_cast(x_dims.size()) + axis2 : axis2; auto out_dims = x_dims; auto axis1_size = out_dims.at(axis1_); @@ -1210,7 +1217,8 @@ bool FrameOpInferSymbolicShape(pir::Operation *op, if (axis == 0) { seq_length = x_shape[0]; start_axis = 1; - end_axis = x_rank - 1; + // NOTE(large-tensor): tensor rank is a small integer + end_axis = static_cast(x_rank - 1); } else { seq_length = x_shape[x_rank - 1]; start_axis = 0; @@ -1514,7 +1522,8 @@ bool FlattenOpInferSymbolicShape( const auto &x_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); - int in_dims_size = x_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int in_dims_size = static_cast(x_shape.size()); if (in_dims_size == 0) { PADDLE_ENFORCE_EQ( @@ -1776,7 +1785,8 @@ bool KthvalueOpInferSymbolicShape( bool keepdim = GetBoolAttr(op, "keepdim"); const auto &input_dims = operand_shape_or_data.shape(); - const int &dim_size = input_dims.size(); + // NOTE(large-tensor): tensor dimensions are small integers + const int dim_size = static_cast(input_dims.size()); if (axis < 0) axis += dim_size; std::vector out_dims; for (int i = 0; i < axis; i++) { @@ -1816,7 +1826,8 @@ bool InverseOpInferSymbolicShape( const auto &input_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)); std::vector input_dims = input_shape.shape(); - int input_rank = input_dims.size(); + // NOTE(large-tensor): tensor rank is a small integer + int input_rank = static_cast(input_dims.size()); infer_context->AddEqualCstr(input_dims[input_rank - 2], input_dims[input_rank - 1]); @@ -1869,7 +1880,8 @@ bool LrnOpInferSymbolicShape(pir::Operation *op, const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_shape = x_shape_or_data.shape(); - int x_size = x_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int x_size = static_cast(x_shape.size()); PADDLE_ENFORCE_EQ( x_size, 4, @@ -1901,7 +1913,8 @@ bool LuOpInferSymbolicShape(pir::Operation *op, const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_shape = x_shape_or_data.shape(); - int x_rank = x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int x_rank = static_cast(x_shape.size()); PADDLE_ENFORCE_GE( x_rank, @@ -1989,7 +2002,8 @@ bool ModeOpInferSymbolicShape(pir::Operation *op, int axis = op->attribute("axis").data(); bool keepdim = op->attribute("keepdim").data(); - int dim_size = x_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + int dim_size = static_cast(x_shape.size()); if (axis < 0) { axis += dim_size; @@ 
-2027,7 +2041,8 @@ bool MaxoutOpInferSymbolicShape(pir::Operation *op, int axis = op->attribute("axis").data(); if (axis < 0) { - axis += x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + axis += static_cast(x_shape.size()); } std::vector output_shape = x_shape; @@ -2243,7 +2258,8 @@ bool MatrixPowerOpInferSymbolicShape( const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_shape = x_shape_or_data.shape(); - const int n_dim = x_shape.size(); + // NOTE(large-tensor): tensor dimensions are small integers + const int n_dim = static_cast(x_shape.size()); PADDLE_ENFORCE_GE(n_dim, 2, @@ -2448,7 +2464,10 @@ bool NormOpInferSymbolicShape(pir::Operation *op, bool is_test = op->attribute("is_test").data(); if (!is_test) { - if (axis < 0) axis += x_shape.size(); + if (axis < 0) { + // NOTE(large-tensor): tensor rank is a small integer + axis += static_cast(x_shape.size()); + } auto norm_shape = x_shape; norm_shape[axis] = symbol::DimExpr(1); @@ -2466,7 +2485,8 @@ bool NonzeroOpInferSymbolicShape( const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); - int rank = x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(x_shape.size()); PADDLE_ENFORCE_GE( rank, @@ -2553,7 +2573,8 @@ bool OverlapAddOpInferSymbolicShape( const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_dims = x_shape_or_data.shape(); - const int x_rank = x_dims.size(); + // NOTE(large-tensor): tensor rank is a small integer + const int x_rank = static_cast(x_dims.size()); int hop_length = op->attribute("hop_length").data(); int axis = op->attribute("axis").data(); @@ -2740,7 +2761,8 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); - int x_rank = x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + int x_rank = static_cast(x_shape.size()); int axis = op->attribute("axis").data(); bool keepdim = op->attribute("keepdim").data(); @@ -2775,7 +2797,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, } } else { if (keepdim) { - for (int i = 0; i < x_rank; ++i) { + for (unsigned int i = 0; i < x_rank; ++i) { if (i == axis) { out_shape.emplace_back(symbol::DimExpr(1)); } else { @@ -2783,7 +2805,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, } } } else { - for (int i = 0; i < x_rank; ++i) { + for (unsigned int i = 0; i < x_rank; ++i) { if (i != axis) { out_shape.emplace_back(x_shape[i]); } @@ -2995,7 +3017,8 @@ bool QrOpInferSymbolicShape(pir::Operation *op, const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_shape = x_shape_or_data.shape(); - int x_rank = x_shape.size(); + // NOTE(large-tensor): tensor rank is a small integer + unsigned int x_rank = static_cast(x_shape.size()); PADDLE_ENFORCE_GE( x_rank, @@ -3064,7 +3087,8 @@ bool RepeatInterleaveOpInferSymbolicShape( // what should I do if axis is null int axis = attributes.at("axis").dyn_cast().data(); - int x_rank = operand_shape_or_data.shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int x_rank = static_cast(operand_shape_or_data.shape().size()); if (axis < 0) axis += x_rank; const auto &out_sym_shape = [&] { @@ -3574,7 +3598,8 @@ bool 
SplitWithNumOpInferSymbolicShape( const auto &x_s_or_d = infer_context->GetShapeOrDataForValue(op->operand_source(0)); - int rank = x_s_or_d.shape().size(); + // NOTE(large-tensor): tensor rank is a small integer + int rank = static_cast(x_s_or_d.shape().size()); const auto &out_s_d = [&](int64_t split_axis, int64_t res_num) { symbol::DimExpr input_axis_dim = x_s_or_d.shape().at(split_axis); @@ -4319,7 +4344,8 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op, const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_dims_sym = x_shape_or_data.shape(); - const size_t rank = x_dims_sym.size(); + // NOTE(large-tensor): tensor rank is a small integer + const int rank = static_cast(x_dims_sym.size()); const std::vector axes = paddle::dialect::details::GetVectorAttr(op, "axis"); @@ -4387,7 +4413,8 @@ bool UniqueConsecutiveOpInferSymbolicShape( const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_dims_sym = x_shape_or_data.shape(); - const size_t rank = x_dims_sym.size(); + // NOTE(large-tensor): tensor rank is a small integer + const int rank = static_cast(x_dims_sym.size()); const std::vector axes = paddle::dialect::details::GetVectorAttr(op, "axis"); const bool return_inverse = GetBoolAttr(op, "return_inverse"); diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index af1c7ba8b92157..f679e46add0b14 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -142,7 +142,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& dev_ctx, } // Number of threads per block shall be larger than 64. threads = std::max(64, threads); - int blocks = DivUp(DivUp(numel, vec_size), threads); + int64_t blocks = DivUp(DivUp(numel, vec_size), threads); int limit_blocks = dev_ctx.GetCUDAMaxGridDimSize()[0]; if (blocks > limit_blocks) { blocks = limit_blocks; diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index e9553832aa0794..3456866bbfe4c2 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -1998,6 +1998,14 @@ XPUOpMap& get_kl3_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::BFLOAT16})}, + {"rms_norm", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, + {"rms_norm_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, #ifdef PADDLE_WITH_XPU_FFT {"conj", XPUKernelSet({phi::DataType::FLOAT32, diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc index 90f193e5b077b7..ec52db685ce63a 100644 --- a/paddle/phi/core/tensor_utils.cc +++ b/paddle/phi/core/tensor_utils.cc @@ -805,7 +805,7 @@ void TensorToVector(const phi::DenseTensor& src, memory_utils::Copy(dst_place, dst_ptr, src.place(), src_ptr, size, nullptr); } #endif - for (unsigned int i = 0; i < src.numel(); i++) { + for (int64_t i = 0; i < src.numel(); i++) { (*dst)[i] = static_cast(array[i]); } delete[] array; @@ -886,7 +886,7 @@ void TensorToVector(const phi::DenseTensor& src, std::vector* dst) { memory_utils::Copy(dst_place, dst_ptr, src.place(), src_ptr, size); - for (unsigned int i = 0; i < src.numel(); i++) { + for (int64_t i = 0; i < src.numel(); i++) { (*dst)[i] = static_cast(array[i]); } delete[] array; diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc 
index 94750577a5debc..62838603f8db4e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1671,6 +1671,23 @@ void FusedRmsNormQuantGradInferMeta(const MetaTensor& x, } } +PADDLE_API void RMSNormGradInferMeta( + const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& y_grad, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad) { + if (x_grad && x) { + x_grad->share_meta(x); + } + if (scale_grad && scale) { + scale_grad->share_meta(scale); + } +} + void RnnGradInferMeta(const MetaTensor& x, const std::vector& pre_state, const std::vector& weight_list, @@ -2337,6 +2354,7 @@ PADDLE_API void FastRMSNormGradInfermeta(const MetaTensor& x, scale_grad->share_meta(scale); } } + void IndexElementwiseGetGradInferMeta( const MetaTensor& x, const std::vector& index, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d92cfb5ca9592d..7e51094f498b4f 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -629,6 +629,16 @@ PADDLE_API void FusedRmsNormQuantGradInferMeta(const MetaTensor& x, MetaTensor* norm_weight_grad, MetaTensor* norm_bias_grad); +PADDLE_API void RMSNormGradInferMeta( + const MetaTensor& x, + const MetaTensor& scale, + const MetaTensor& invvar, + const MetaTensor& y_grad, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* x_grad, + MetaTensor* scale_grad); + PADDLE_API void RnnGradInferMeta( const MetaTensor& x, const std::vector& pre_state, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index fbbc32a06c6503..940d62e672d330 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3887,6 +3887,90 @@ void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* y, + MetaTensor* invvar) { + auto x_dim = x.dims(); + // std::vector normalized_shape_data = normalized_shape.GetData(); + int normalized_shape_size = normalized_shape.size(); + int x_dims_size = x_dim.size(); + int begin_norm_axis = x_dims_size - normalized_shape_size; + + PADDLE_ENFORCE_GT(begin_norm_axis, + 0, + common::errors::InvalidArgument( + "'begin_norm_axis' in Op(LayerNorm) should be " + "greater than zero. But received [%d].", + begin_norm_axis)); + + PADDLE_ENFORCE_LT( + begin_norm_axis, + x_dims_size, + common::errors::InvalidArgument( + "'begin_norm_axis' must be less than the dimensions of X," + "But received 'begin_norm_axis' is [%d]," + "received the dimensions of X is [%d].", + begin_norm_axis, + x_dims_size)); + + for (int i = 0; i < normalized_shape_size; i++) { + PADDLE_ENFORCE_EQ(x_dim[x_dims_size - i - 1], + normalized_shape[normalized_shape_size - i - 1], + common::errors::InvalidArgument( + "The %d-th dimension of X is not equal to the %d-th " + "dimension of NormalizedShape.", + x_dims_size - i - 1, + normalized_shape_size - i - 1)); + } + + if (scale) { + auto scale_dim = scale.dims(); + PADDLE_ENFORCE_EQ(scale_dim.size(), + normalized_shape_size, + common::errors::InvalidArgument( + "The dimensions of Input(Scale) must be equal to the " + "dimensions of NormalizedShape. 
" + "But received: the dimensions of Input(Scale) is " + "[%d], the dimensions of NormalizedShape is [%d].", + scale_dim.size(), + normalized_shape_size)); + for (int i = 0; i < normalized_shape_size; i++) { + PADDLE_ENFORCE_EQ(scale_dim[i], + normalized_shape[i], + common::errors::InvalidArgument( + "The %d-th dimension of Input(Scale) is not equal " + "to the %d-th dimension of NormalizedShape.", + i, + i)); + } + } + + auto matrix_dim = common::flatten_to_2d(x_dim, begin_norm_axis); + auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis); + int64_t right = matrix_dim[1]; + + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + common::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + + DataType x_dtype = x.dtype(); + y->set_dims(x_dim); + y->set_dtype(x_dtype); + + DataType param_type = + (x_dtype == DataType::BFLOAT16 || x_dtype == DataType::FLOAT16) + ? DataType::FLOAT32 + : x_dtype; + invvar->set_dims({before_norm_dims}); + invvar->set_dtype(param_type); +} + void RowConvInferMeta(const MetaTensor& x, const MetaTensor& filter, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 35eec1f8c7cd20..f1e9464094243a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -773,6 +773,13 @@ PADDLE_API void ReduceAsInferMeta(const MetaTensor& x, const MetaTensor& target, MetaTensor* out); +PADDLE_API void RmsNormInferMeta(const MetaTensor& x, + const MetaTensor& scale, + const std::vector& normalized_shape, + double epsilon, + MetaTensor* y, + MetaTensor* invvar); + PADDLE_API void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out); diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index aa4c45fecf968e..cea06837c9a1fb 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1010,9 +1010,16 @@ void FusedAttentionInferMeta(const MetaTensor& x, "and must satisfy the limitations: " "(num_head * dim_head == dim_embed)")); } - num_heads = y_dim[1]; - dim_head = y_dim[2]; - hidden_size = y_dim[3]; + // TODO(large-tensor): num_heads, dim_head, hidden_size may exceed INT_MAX + int64_t num_heads_int64 = y_dim[1]; + int64_t dim_head_int64 = y_dim[2]; + int64_t hidden_size_int64 = y_dim[3]; + PADDLE_ENFORCE_LE_INT_MAX(num_heads_int64, "num_heads"); + PADDLE_ENFORCE_LE_INT_MAX(dim_head_int64, "dim_head"); + PADDLE_ENFORCE_LE_INT_MAX(hidden_size_int64, "hidden_size"); + num_heads = static_cast(num_heads_int64); + dim_head = static_cast(dim_head_int64); + hidden_size = static_cast(hidden_size_int64); } PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index ac171094731d2f..b071589a5b4da8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -351,13 +351,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x, vec = {}; } } else { - for (int64_t i = 0; i < int_axis; i++) - vec.emplace_back(x_dims[static_cast(i)]); + for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]); if (keepdims) { vec.emplace_back(static_cast(1)); } - for (int64_t i = int_axis + 1; i < x_rank; i++) - vec.emplace_back(x_dims[static_cast(i)]); + for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]); } out->set_dims(common::make_ddim(vec)); @@ -834,7 +832,7 @@ void CSplitInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { void DecodeJpegInferMeta(const MetaTensor& x, const 
std::string& mode, MetaTensor* out) { - std::vector out_dims; + std::vector out_dims; if (mode == "unchanged") { out_dims = {-1, -1, -1}; @@ -1117,7 +1115,7 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { x_dims[rank - 2], x_dims[rank - 1])); - std::vector batch_dims_vec{}; + std::vector batch_dims_vec{}; for (int i = 0; i < rank - 1; ++i) { batch_dims_vec.emplace_back(x_dims[i]); } @@ -1835,12 +1833,14 @@ void FoldInferMeta(const MetaTensor& x, "It is expected dilations_size equals to 2, but got size %d", dilations.size())); - int output_height = output_sizes[0]; - int output_width = output_sizes[1]; - int kernel_height = kernel_sizes[0]; - int kernel_width = kernel_sizes[1]; - int dilation_height = dilations[0]; - int dilation_width = dilations[1]; + // NOTE(large-tensor): output_sizes, kernel_sizes, dilations are small + // integers + int64_t output_height = output_sizes[0]; + int64_t output_width = output_sizes[1]; + int64_t kernel_height = kernel_sizes[0]; + int64_t kernel_width = kernel_sizes[1]; + int64_t dilation_height = dilations[0]; + int64_t dilation_width = dilations[1]; int64_t stride_height = strides[0]; int64_t stride_width = strides[1]; @@ -1879,13 +1879,13 @@ void FoldInferMeta(const MetaTensor& x, 0, common::errors::InvalidArgument( "The `output_height` should be greater than zero, " - "but received output_height: %d .", + "but received output_height: %ld .", output_height)); PADDLE_ENFORCE_GT(output_width, 0, common::errors::InvalidArgument( "The `output_width` should be greater than zero, " - "but received output_width: %d .", + "but received output_width: %ld .", output_width)); // check dilations PADDLE_ENFORCE_GT( @@ -1905,30 +1905,31 @@ void FoldInferMeta(const MetaTensor& x, dilations[0], dilations[1])); - std::vector out_dims; + std::vector out_dims; // batch_size out_dims.push_back(in_dims[0]); // NOLINT // output_plane int64_t output_channels = in_dims[1] / (kernel_width * kernel_height); out_dims.push_back(output_channels); - int blocks_height = (output_sizes[0] + 2 * paddings[0] - - (dilations[0] * (kernel_sizes[0] - 1) + 1)) / - strides[0] + - 1; - int blocks_width = (output_sizes[1] + 2 * paddings[1] - - (dilations[1] * (kernel_sizes[1] - 1) + 1)) / - strides[1] + - 1; + int64_t blocks_height = + (output_height + 2 * static_cast(paddings[0]) - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t blocks_width = (output_width + 2 * static_cast(paddings[1]) - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; // check output height and width PADDLE_ENFORCE_GT( blocks_height, 0, common::errors::InvalidArgument( - "The sliding blocks calculated from input spatial size (%d, %d), " + "The sliding blocks calculated from input spatial size (%ld, %ld), " "kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), " - "is (%d, %d), which should be a positive integer.", + "is (%ld, %ld), which should be a positive integer.", in_dims[2], in_dims[3], kernel_sizes[0], @@ -1944,9 +1945,9 @@ void FoldInferMeta(const MetaTensor& x, blocks_width, 0, common::errors::InvalidArgument( - "The sliding blocks calculated from input spatial size (%d, %d), " + "The sliding blocks calculated from input spatial size (%ld, %ld), " "kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), " - "is (%d, %d), which should be a positive integer.", + "is (%ld, %ld), which should be a positive integer.", in_dims[2], in_dims[3], kernel_sizes[0], @@ -1962,10 +1963,10 @@ void FoldInferMeta(const MetaTensor& x, 
blocks_height * blocks_width, in_dims[2], common::errors::InvalidArgument( - "Given input output_size (%d, %d), " + "Given input output_size (%ld, %ld), " "kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), " "which should be expected size of input's dimension " - "2 to match the calculated number of %d * %d = %d, but got %d", + "2 to match the calculated number of %ld * %ld = %ld, but got %ld", output_height, output_width, kernel_sizes[0], @@ -2683,8 +2684,8 @@ void LUInferMeta(const MetaTensor& x, if (x_rank == 2) { infos->set_dims(common::make_ddim({})); } else { - auto Infos_dim = - std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 2); + std::vector Infos_dim( + dims_vec.begin(), dims_vec.begin() + static_cast(x_rank - 2)); infos->set_dims(common::make_ddim(Infos_dim)); } infos->set_dtype(DataType::INT32); @@ -2692,8 +2693,8 @@ void LUInferMeta(const MetaTensor& x, PADDLE_ENFORCE_NOT_NULL(pivots, common::errors::InvalidArgument( "Output(Pivots) should not be nullptr.")); - auto Pivots_dim = - std::vector(dims_vec.begin(), dims_vec.begin() + x_rank - 1); + std::vector Pivots_dim( + dims_vec.begin(), dims_vec.begin() + static_cast(x_rank - 1)); Pivots_dim[x_rank - 2] = min_mn; pivots->set_dims(common::make_ddim(Pivots_dim)); pivots->set_dtype(DataType::INT32); @@ -3070,13 +3071,11 @@ void MinMaxWithIndexInferMeta(const MetaTensor& x, vec = {}; } } else { - for (int64_t i = 0; i < int_axis; i++) - vec.emplace_back(x_dims[static_cast(i)]); + for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]); if (keepdims) { vec.emplace_back(static_cast(1)); } - for (int64_t i = int_axis + 1; i < x_rank; i++) - vec.emplace_back(x_dims[static_cast(i)]); + for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]); } val_out->set_dims(common::make_ddim(vec)); @@ -3967,7 +3966,8 @@ void QrInferMeta(const MetaTensor& x, MetaTensor* q, MetaTensor* r) { auto x_dims = x.dims(); - int x_rank = x_dims.size(); + // NOTE(large-tensor): tensor rank is a small integer + int x_rank = static_cast(x_dims.size()); PADDLE_ENFORCE_GE( x_dims.size(), 2, @@ -5383,6 +5383,8 @@ void PartialConcatInferMeta(const std::vector& xs, MetaConfig config) { int64_t batch_size = -1; int64_t input_len = -1; + // TODO(large-tensor): change start_index to int64_t + int64_t start_index_int64 = start_index; auto inputs_num = xs.size(); PADDLE_ENFORCE_GT(inputs_num, @@ -5427,12 +5429,12 @@ void PartialConcatInferMeta(const std::vector& xs, start_index)); if (start_index < 0) { - start_index += input_len; + start_index_int64 += input_len; } if (length > 0) { PADDLE_ENFORCE_GE(input_len, - start_index + length, + start_index_int64 + length, common::errors::OutOfRange( "start_index + length is larger than input length")); } @@ -5440,7 +5442,7 @@ void PartialConcatInferMeta(const std::vector& xs, std::vector out_dims(2); out_dims[0] = batch_size; // colnum = input_num * length - out_dims[1] = (length < 0) ? input_len - start_index : length; + out_dims[1] = (length < 0) ? 
input_len - start_index_int64 : length; out_dims[1] *= inputs_num; DDim out_dim = common::make_ddim(out_dims); out->set_dims(out_dim); @@ -5448,7 +5450,7 @@ void PartialConcatInferMeta(const std::vector& xs, } void SvdvalsInferMeta(const MetaTensor& x, MetaTensor* s) { - auto SDDim = [](const DDim& x_dim, int k) { + auto SDDim = [](const DDim& x_dim, int64_t k) { auto x_vec = common::vectorize(x_dim); x_vec.erase(x_vec.end() - 2, x_vec.end()); x_vec.push_back(k); @@ -5477,21 +5479,21 @@ void SvdInferMeta(const MetaTensor& x, MetaTensor* u, MetaTensor* s, MetaTensor* vh) { - auto UDDim = [](const DDim& x_dim, int k) { + auto UDDim = [](const DDim& x_dim, int64_t k) { // get x_dim and return the ddim of U auto x_vec = common::vectorize(x_dim); x_vec[x_vec.size() - 1] = k; return common::make_ddim(x_vec); }; - auto VHDDim = [](const DDim& x_dim, int k) { + auto VHDDim = [](const DDim& x_dim, int64_t k) { // get x_dim and return the ddim of U auto x_vec = common::vectorize(x_dim); x_vec[x_vec.size() - 2] = k; return common::make_ddim(x_vec); }; - auto SDDim = [](const DDim& x_dim, int k) { + auto SDDim = [](const DDim& x_dim, int64_t k) { // get x_dim and return the ddim of U auto x_vec = common::vectorize(x_dim); x_vec[x_vec.size() - 2] = k; @@ -5500,7 +5502,8 @@ void SvdInferMeta(const MetaTensor& x, }; auto in_dims = x.dims(); - int x_rank = in_dims.size(); + // NOTE(large-tensor): tensor rank is a small integer + int x_rank = static_cast(in_dims.size()); PADDLE_ENFORCE_GE( in_dims.size(), 2, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 09dbd12f534f7f..27871579f9db57 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -125,6 +125,12 @@ if(NOT WITH_MAGMA) list(REMOVE_ITEM kernel_gpu "gpu/eig_kernel.cu" "gpu/eig_grad_kernel.cu") endif() +if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 11.8)) + OR APPLE + OR WITH_ROCM) + list(REMOVE_ITEM kernel_gpu "gpu/rms_norm_cuda_kernel.cu") +endif() + if(WITH_CUTLASS) add_custom_target( gemm_epilogue_compile_script ALL diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index af13745d27eda8..6ea956f22d39a9 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -138,13 +138,14 @@ HOSTDEVICE inline T AdaptEndIndex(T ph, T input_size, T output_size) { /* used for fractional pool to calculate start and end index of each divided * grid */ +template HOSTDEVICE inline float FractionalRationalU( - float u, float alpha, int input, int output, int pool_size = 0) { + float u, float alpha, T input, T output, T pool_size = 0) { if (pool_size > 0) { return u; } - int base = input / output; + T base = input / output; float u_max1 = static_cast(base + 2) / alpha - 1; float u_max2 = static_cast(input + 1 - base) / alpha - @@ -154,24 +155,26 @@ HOSTDEVICE inline float FractionalRationalU( return u * max_u; } -HOSTDEVICE inline int FractionalStartIndex(int idx, - float alpha, - float u, - int pool_size = 0) { +template +HOSTDEVICE inline T FractionalStartIndex(T idx, + float alpha, + float u, + T pool_size = 0) { // paper use ceil instead: static_cast(ceil(alpha * (idx + u) - 1)); - return static_cast((idx + u) * alpha) - static_cast(u * alpha); + return static_cast((idx + u) * alpha) - static_cast(u * alpha); } -HOSTDEVICE inline int FractionalEndIndex(int idx, - float alpha, - float u, - int pool_size = 0) { +template +HOSTDEVICE inline T FractionalEndIndex(T idx, + float alpha, + float u, + T pool_size = 0) { if (pool_size 
> 0) { - return static_cast((idx + u) * alpha) - static_cast(u * alpha) + + return static_cast((idx + u) * alpha) - static_cast(u * alpha) + pool_size; } // paper use ceil instead: static_cast(ceil(alpha * (idx + 1 + u) - 1)); - return static_cast((idx + 1 + u) * alpha) - static_cast(u * alpha); + return static_cast((idx + 1 + u) * alpha) - static_cast(u * alpha); } /* diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index 46b2e009f68c5b..bd0e64e046b780 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -14,8 +14,10 @@ #include "paddle/phi/kernels/lerp_grad_kernel.h" +#include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/common/amp_type_traits.h" @@ -28,6 +30,8 @@ #include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" +COMMON_DECLARE_bool(use_accuracy_compatible_kernel); + namespace phi { template @@ -56,20 +60,43 @@ __global__ void LerpGradKernelImpl(const T* weight, } template -__global__ void LerpGradScalarKernelImpl(const T* weight, +__global__ void LerpGradKernelCompatibleImpl(const T* weight, + const T* dout, + T* dx, + T* dy, + const int64_t out_size, + const int64_t x_size, + const int64_t y_size) { + CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) { + T weight_value = weight[idx]; + T remaining_weight_value = static_cast(1) - weight[idx]; + if (dx) { + if (idx < x_size) { + dx[idx] = remaining_weight_value * dout[idx]; + } + } + if (dy) { + if (idx < y_size) { + dy[idx] = weight_value * dout[idx]; + } + } + } +} + +template +__global__ void LerpGradScalarKernelImpl(const WeightT* weight, const T* dout, T* dx, T* dy, const int64_t out_size, const int64_t x_size, const int64_t y_size) { - using MPType = typename phi::dtype::MPTypeTrait::Type; - MPType weight_scalar = static_cast(weight[0]); + double weight_scalar = static_cast(weight[0]); CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) { - MPType temp_dx = weight_scalar * static_cast(dout[idx]); + double temp_dx = weight_scalar * static_cast(dout[idx]); if (dx) { if (idx < x_size) { - dx[idx] = static_cast(static_cast(dout[idx]) - temp_dx); + dx[idx] = static_cast(static_cast(dout[idx]) - temp_dx); } } if (dy) { @@ -80,6 +107,31 @@ __global__ void LerpGradScalarKernelImpl(const T* weight, } } +template +__global__ void LerpGradScalarKernelCompatibleImpl(const WeightT* weight, + const T* dout, + T* dx, + T* dy, + const int64_t out_size, + const int64_t x_size, + const int64_t y_size) { + T weight_scalar = static_cast(weight[0]); + T remaining_weight_scalar = + static_cast(1 - static_cast(weight[0])); + CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) { + if (dx) { + if (idx < x_size) { + dx[idx] = remaining_weight_scalar * dout[idx]; + } + } + if (dy) { + if (idx < y_size) { + dy[idx] = weight_scalar * dout[idx]; + } + } + } +} + bool XYNeedReduce(const DenseTensor& x, const DenseTensor& y, const DenseTensor& out) { @@ -123,23 +175,66 @@ void SwitchKernel(const Context& dev_ctx, T* y_grad_data) { if (weight.numel() == 1) { // condition when weight is a scalar - const T* weight_data = weight.data(); const T* out_grad_data = out_grad.data(); const int64_t out_size = out_grad.numel(); const int64_t weight_size = weight.numel(); auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_size); - 
LerpGradScalarKernelImpl<<>>(weight_data, - out_grad_data, - x_grad_data, - y_grad_data, - out_size, - x_grad_size, - y_grad_size); + + if (weight.dtype() == DataType::FLOAT64) { + const double* weight_data = weight.data(); + if (FLAGS_use_accuracy_compatible_kernel) { + LerpGradScalarKernelCompatibleImpl + <<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } else { + LerpGradScalarKernelImpl<<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } + } else { + const T* weight_data = weight.data(); + if (FLAGS_use_accuracy_compatible_kernel) { + LerpGradScalarKernelCompatibleImpl + <<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } else { + LerpGradScalarKernelImpl<<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } + } } else { // broadcast weight with out_grad's dimensions const std::vector in_tensors = {&weight, &out_grad}; @@ -155,16 +250,29 @@ void SwitchKernel(const Context& dev_ctx, const int64_t weight_size = weight.numel(); auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_size); - LerpGradKernelImpl<<>>(weight_data, - out_grad_data, - x_grad_data, - y_grad_data, - out_size, - x_grad_size, - y_grad_size); + if (FLAGS_use_accuracy_compatible_kernel) { + LerpGradKernelCompatibleImpl<<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } else { + LerpGradKernelImpl<<>>(weight_data, + out_grad_data, + x_grad_data, + y_grad_data, + out_size, + x_grad_size, + y_grad_size); + } } } diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index 59d6e0e2834bc0..831ab972506163 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -16,6 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" @@ -33,28 +34,29 @@ struct LerpElementWiseDirectCUDAFunctor { } }; -template +template struct LerpScalarDirectCUDAFunctor { - const T *weight_; + const WeightT* weight_; - HOSTDEVICE inline LerpScalarDirectCUDAFunctor(const T *weight) + HOSTDEVICE inline LerpScalarDirectCUDAFunctor(const WeightT* weight) : weight_(weight) {} HOSTDEVICE inline T operator()(const T x, const T y) const { + T weight_scalar = static_cast(weight_[0]); if (abs(static_cast(weight_[0])) < 0.5f) { - return x + weight_[0] * (y - x); + return x + weight_scalar * (y - x); } else { - return y - (y - x) * (static_cast(1) - weight_[0]); + return y - (y - x) * (static_cast(1) - weight_scalar); } } }; template -void LerpKernel(const Context &dev_ctx, - const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &weight, - DenseTensor *out) { +void LerpKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + DenseTensor* out) { if (out && out->numel() == 0) { dev_ctx.template Alloc(out); return; @@ -70,16 +72,22 @@ void LerpKernel(const Context &dev_ctx, rank)); dev_ctx.template Alloc(out); - std::vector outputs = {out}; + std::vector outputs = {out}; - std::vector inputs; + std::vector inputs; if (weight.numel() == 1) { - const T *weight_ptr = weight.data(); 
     inputs.reserve(2);
     inputs.emplace_back(&x);
     inputs.emplace_back(&y);
-    auto functor = LerpScalarDirectCUDAFunctor<T>(weight_ptr);
-    funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
+    if (weight.dtype() == DataType::FLOAT64) {
+      const double* weight_ptr = weight.data<double>();
+      auto functor = LerpScalarDirectCUDAFunctor<T, double>(weight_ptr);
+      funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
+    } else {
+      const T* weight_ptr = weight.data<T>();
+      auto functor = LerpScalarDirectCUDAFunctor<T, T>(weight_ptr);
+      funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
+    }
   } else {
     inputs.reserve(3);
     auto functor = LerpElementWiseDirectCUDAFunctor<T>();
diff --git a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu
new file mode 100644
index 00000000000000..82bfa0985a7e20
--- /dev/null
+++ b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.cu
@@ -0,0 +1,35 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(rms_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RMSNormFwdKernel,
+                   float,
+                   double,
+                   phi::float16,
+                   phi::bfloat16) {}
+
+PD_REGISTER_KERNEL(rms_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RMSNormBwdKernel,
+                   float,
+                   double,
+                   phi::float16,
+                   phi::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h
new file mode 100644
index 00000000000000..1583825446beb4
--- /dev/null
+++ b/paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h
@@ -0,0 +1,1111 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include + +#include "paddle/common/ddim.h" +#include "paddle/common/flags.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +COMMON_DECLARE_bool(use_accuracy_compatible_kernel); + +namespace phi { + +// ----------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------- + +static constexpr int kCUDANumThreads = 256; +static constexpr int kCUDABlockReduceNumThreads = 512; +static constexpr int kWarpSize = 32; + +// ----------------------------------------------------------------------- +// Helper Functions & Structs +// ----------------------------------------------------------------------- + +template +__device__ __forceinline__ T Rsqrt_(T x); + +template <> +__device__ __forceinline__ float Rsqrt_(float x) { + return rsqrtf(x); +} + +template <> +__device__ __forceinline__ double Rsqrt_(double x) { + return rsqrt(x); +} + +template +struct alignas(sizeof(T) * kVecSize) aligned_vector { + T val[kVecSize]; +}; + +template +struct SimplePair { + T1 first; + T2 second; + + __host__ __device__ SimplePair() {} + __host__ __device__ SimplePair(T1 f, T2 s) : first(f), second(s) {} +}; + +template +bool can_vectorize(const T* ptr, int alignment) { + uint64_t addr = reinterpret_cast(ptr); + return addr % alignment == 0; +} + +// ----------------------------------------------------------------------- +// Welford Algorithms +// ----------------------------------------------------------------------- +template +struct WelfordData { + scalar_t mean; + scalar_t m2; + index_t n; + scalar_t nf; + + __host__ __device__ WelfordData() : mean(0), m2(0), n(0), nf(0) {} + + __host__ __device__ + WelfordData(scalar_t mean, scalar_t m2, index_t n, scalar_t nf) + : mean(mean), m2(m2), n(n), nf(nf) {} +}; + +// ----------------------------------------------------------------------- +// Warp & Block Reductions +// ----------------------------------------------------------------------- + +template +__device__ __forceinline__ T WARP_SHFL_DOWN_(T value, + int delta, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_(T value, + int srcLane, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_sync(mask, value, srcLane, width); +#else + return __shfl(value, srcLane, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_XOR_(T value, + int laneMask, + int width = kWarpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +__device__ T BlockReduceSum(T val, T* shared) { + int lane = threadIdx.x % kWarpSize; + int wid = threadIdx.x / kWarpSize; + + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN_(val, offset); + } + + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + + // Assuming blockDim.x <= 1024, max 32 warps + val = (threadIdx.x < blockDim.x / kWarpSize) ? 
shared[lane] : T(0); + + if (wid == 0) { + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN_(val, offset); + } + } + return val; +} + +template +struct WelfordOps { + acc_scalar_t correction; + bool take_sqrt; + + public: + using acc_t = WelfordData; + inline __device__ acc_t reduce(acc_t acc, + scalar_t data, + index_t /*idx*/) const { + index_t new_n = acc.n + 1; + acc_scalar_t new_nf = static_cast(new_n); + acc_scalar_t delta = data - acc.mean; + acc_scalar_t new_mean = acc.mean + delta / new_nf; + acc_scalar_t new_delta = data - new_mean; + return { + new_mean, + acc.m2 + delta * new_delta, + new_n, + new_nf, + }; + } + inline __device__ acc_t combine(acc_t a, acc_t b) const { + if (a.nf == 0) { + return b; + } + if (b.nf == 0) { + return a; + } + acc_scalar_t delta = b.mean - a.mean; + acc_scalar_t new_count = a.nf + b.nf; + acc_scalar_t nb_over_n = b.nf / new_count; + return {a.mean + delta * nb_over_n, + a.m2 + b.m2 + delta * delta * a.nf * nb_over_n, + -1, + new_count}; + } + inline __device__ res_t project(acc_t acc) const { + const scalar_t mean = static_cast(acc.mean); + const acc_scalar_t divisor = acc.nf > correction ? acc.nf - correction : 0; + const acc_scalar_t var = acc.m2 / divisor; + res_t results(take_sqrt ? std::sqrt(var) : var, mean); + return results; + } + + static __device__ acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const { + return {WARP_SHFL_DOWN_(acc.mean, offset), + WARP_SHFL_DOWN_(acc.m2, offset), + WARP_SHFL_DOWN_(acc.n, offset), + WARP_SHFL_DOWN_(acc.nf, offset)}; + } +#endif + __host__ __device__ WelfordOps(acc_scalar_t correction, bool take_sqrt) + : correction(correction), take_sqrt(take_sqrt) {} +}; + +// ----------------------------------------------------------------------- +// Forward Kernels +// ----------------------------------------------------------------------- + +// Non-vectorized RowwiseMoments for RMSNorm +template +__global__ void RowwiseMomentsCUDAKernel(int64_t N, + T_ACC eps, + const T* X, + T_ACC* rstd) { + using WelfordType = WelfordData; + using WelfordOp = WelfordOps>; + + const int64_t i = blockIdx.x; + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; + WelfordType val(0, 0, 0, 0); + + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + val = welford_op.reduce(val, static_cast(X[index]), index); + } + + // Block Reduce + // 1. Warp Reduce + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + WelfordType wdB = welford_op.warp_shfl_down(val, offset); + val = welford_op.combine(val, wdB); + } + + // 2. Block Reduce (via shared memory) + __shared__ + typename std::aligned_storage::type val_shared[32]; + WelfordType* val_shared_ptr = reinterpret_cast(val_shared); + + int lane = threadIdx.x % kWarpSize; + int wid = threadIdx.x / kWarpSize; + + __syncthreads(); + if (lane == 0) { + val_shared_ptr[wid] = val; + } + __syncthreads(); + + val = (threadIdx.x < blockDim.x / kWarpSize) ? 
val_shared_ptr[lane] + : WelfordType(0, 0, 0, 0); + + // Final Warp Reduce for the first warp + if (wid == 0) { + for (int offset = kWarpSize >> 1; offset > 0; offset >>= 1) { + WelfordType wdB = welford_op.warp_shfl_down(val, offset); + val = welford_op.combine(val, wdB); + } + } + + if (threadIdx.x == 0) { + T_ACC m1; // mean + T_ACC m2; // var + SimplePair res = welford_op.project(val); + m2 = res.first; + m1 = res.second; + rstd[i] = Rsqrt_(m2 + m1 * m1 + eps); + } +} + +// Non-vectorized Forward for RMSNorm +template +__global__ void RMSNormForwardCUDAKernel( + int64_t N, const T* X, const T_ACC* rstd, const T* scale, T* Y) { + const int64_t i = blockIdx.x; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + const T_ACC scale_v = + scale == nullptr ? T_ACC(1) : static_cast(scale[j]); + Y[index] = static_cast((static_cast(X[index])) * + static_cast(rstd[i]) * scale_v); + } +} + +// Vectorized Helper +template +__device__ T_ACC compute_stats(const T* __restrict__ X, + const int N, + T_ACC* buf) { + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(X); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = N / kVecSize; + T_ACC sigma2 = 0; + + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + T_ACC val = static_cast(data.val[ii]); + sigma2 += val * val; + } + } + + // Intra-warp reduction + for (int offset = (kWarpSize >> 1); offset > 0; offset >>= 1) { + sigma2 += WARP_SHFL_DOWN_(sigma2, offset); + } + + // Inter-warp reductions + if (blockDim.y > 1) { + T_ACC* meansigmabuf = buf; + // Use simpler layout: just sigma2 + for (int offset = blockDim.y >> 1; offset > 0; offset >>= 1) { + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + meansigmabuf[wrt_y] = sigma2; + } + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y < offset) { + sigma2 += meansigmabuf[threadIdx.y]; + } + __syncthreads(); + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + meansigmabuf[0] = sigma2 / static_cast(N); + } + __syncthreads(); + return meansigmabuf[0]; + + } else { + return WARP_SHFL_(sigma2, 0) / static_cast(N); + } +} + +template +__global__ void vectorized_rms_norm_kernel(const int N, + T_ACC eps, + const T* __restrict__ X, + const T* scale, + T_ACC* rstd, + T* Y) { + extern __shared__ char s_data_raw[]; + T_ACC* s_data = reinterpret_cast(s_data_raw); + + auto i1 = blockIdx.x; + const T* block_row = X + i1 * N; + + // Compute stats + T_ACC sigma2 = compute_stats(block_row, N, s_data); + + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(block_row); + const vec_t* scale_vec = + (scale != nullptr) ? 
reinterpret_cast(scale) : nullptr; + vec_t* Y_vec = reinterpret_cast(Y + i1 * N); + + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = N / kVecSize; + + T_ACC rstd_val = Rsqrt_(sigma2 + eps); + + if (scale_vec != nullptr) { + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + out.val[ii] = + static_cast(static_cast(scale_vec[i].val[ii]) * + (rstd_val * static_cast(data.val[ii]))); + } + Y_vec[i] = out; + } + } else { + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; +#pragma unroll + for (int ii = 0; ii < kVecSize; ii++) { + out.val[ii] = + static_cast(rstd_val * static_cast(data.val[ii])); + } + Y_vec[i] = out; + } + } + + if (thrx == 0) { + rstd[i1] = rstd_val; + } +} + +template +void launch_vectorized_rms_norm_kernel_driver(int N, + int64_t M, + T_ACC eps, + const T* X_data, + const T* scale_data, + T* Y_data, + T_ACC* rstd_data, + cudaStream_t stream) { + const int num_threads = 128; + const dim3 threads(kWarpSize, num_threads / kWarpSize, 1); + dim3 blocks(M); + + // Shared memory for reduction: need size proportional to threads.y and T_ACC + int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0; + + vectorized_rms_norm_kernel + <<>>( + N, eps, X_data, scale_data, rstd_data, Y_data); +} + +// ----------------------------------------------------------------------- +// Backward Kernels +// ----------------------------------------------------------------------- + +template +__device__ __inline__ void compute_gI(const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N, + T_ACC* buf) { + const auto i1 = blockIdx.x; + const T_ACC rstd_val = rstd[i1]; + T_ACC stats_x2{0}; + constexpr int unroll = 4; + auto l = unroll * threadIdx.x; + const T* X_i = X + i1 * N; + const T* dY_i = dY + i1 * N; + T* dX_i = dX + i1 * N; + + for (; l + unroll - 1 < N; l += blockDim.x * unroll) { +#pragma unroll + for (int k = 0; k < unroll; k++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l + k]) : T_ACC(1); + const auto c_h = static_cast(X_i[l + k]); + const auto c_loss = static_cast(dY_i[l + k]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } + } + for (; l < N; l++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l]) : T_ACC(1); + const auto c_h = static_cast(X_i[l]); + const auto c_loss = static_cast(dY_i[l]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } + + stats_x2 = BlockReduceSum(stats_x2, buf); + + if (threadIdx.x == 0) { + buf[0] = stats_x2; + } + __syncthreads(); + stats_x2 = buf[0]; + + T_ACC fH = N; + T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + + for (int l = threadIdx.x; l < N; l += blockDim.x) { + const auto x = static_cast(X_i[l]); + const auto dy = static_cast(dY_i[l]); + const auto scale_val = + (scale != nullptr) ? 
static_cast(scale[l]) : T_ACC(1); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i[l] = static_cast(f_grad_input); + } +} + +template +__global__ void rms_norm_grad_input_kernel(const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* buf = reinterpret_cast(&s_data1); + compute_gI(dY, X, rstd, scale, dX, N, buf); +} + +template +__global__ void rms_norm_grad_input_kernel_vectorized( + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + const T* __restrict__ scale, + T* dX, + const int N) { + alignas(sizeof(double)) extern __shared__ char shared_data[]; + T_ACC* reduce_buf = reinterpret_cast(&shared_data); + + const auto bIdx = blockIdx.x; + const T_ACC rstd_val = rstd[bIdx]; + const T* X_i = X + bIdx * N; + const T* dY_i = dY + bIdx * N; + T* dX_i = dX + bIdx * N; + + using vec_t = aligned_vector; + const vec_t* const X_i_vec_ptr = reinterpret_cast(X_i); + const vec_t* const dY_i_vec_ptr = reinterpret_cast(dY_i); + const vec_t* const scale_vec_ptr = + (scale != nullptr) ? reinterpret_cast(scale) : nullptr; + vec_t* const dX_i_vec = reinterpret_cast(dX_i); + + vec_t X_i_vec_reg, dY_i_vec_reg, scale_vec_reg, dX_i_vec_reg; + for (int k = 0; k < kVecSize; ++k) { + scale_vec_reg.val[k] = T(1); + } + + T_ACC stats_x2{0}; + unsigned int l = threadIdx.x * kVecSize; + for (; l + kVecSize - 1 < N; l += blockDim.x * kVecSize) { + unsigned int vec_idx = l / kVecSize; + if (scale != nullptr) { + scale_vec_reg = scale_vec_ptr[vec_idx]; + } + + X_i_vec_reg = X_i_vec_ptr[vec_idx]; + dY_i_vec_reg = dY_i_vec_ptr[vec_idx]; + + for (int k = 0; k < kVecSize; ++k) { + const auto scale_val = static_cast(scale_vec_reg.val[k]); + const auto c_h = static_cast(X_i_vec_reg.val[k]); + const auto c_loss = static_cast(dY_i_vec_reg.val[k]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } + } + + // Tail Loop + for (; l < N; l++) { + const auto scale_val = + (scale != nullptr) ? static_cast(scale[l]) : T_ACC(1); + const auto c_h = static_cast(X_i[l]); + const auto c_loss = static_cast(dY_i[l]); + stats_x2 += c_loss * scale_val * (c_h)*rstd_val; + } + + stats_x2 = BlockReduceSum(stats_x2, reduce_buf); + if (threadIdx.x == 0) { + reduce_buf[0] = stats_x2; + } + __syncthreads(); + stats_x2 = reduce_buf[0]; + + T_ACC fH = N; + T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + + l = threadIdx.x * kVecSize; + for (; l + kVecSize - 1 < N; l += blockDim.x * kVecSize) { + unsigned int vec_idx = l / kVecSize; + if (scale != nullptr) { + scale_vec_reg = scale_vec_ptr[vec_idx]; + } + + X_i_vec_reg = X_i_vec_ptr[vec_idx]; + dY_i_vec_reg = dY_i_vec_ptr[vec_idx]; + + for (int k = 0; k < kVecSize; ++k) { + const auto scale_val = static_cast(scale_vec_reg.val[k]); + const auto x = static_cast(X_i_vec_reg.val[k]); + const auto dy = static_cast(dY_i_vec_reg.val[k]); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i_vec_reg.val[k] = static_cast(f_grad_input); + } + + dX_i_vec[vec_idx] = dX_i_vec_reg; + } + + // Tail Loop + for (; l < N; l += blockDim.x) { + const auto x = static_cast(X_i[l]); + const auto dy = static_cast(dY_i[l]); + const auto scale_val = + (scale != nullptr) ? 
static_cast(scale[l]) : T_ACC(1); + + T_ACC f_grad_input = fH * scale_val * dy; + f_grad_input -= (x)*rstd_val * stats_x2; + f_grad_input *= term1; + dX_i[l] = static_cast(f_grad_input); + } +} + +template +__device__ __forceinline__ void blockReduceScaleBackwardHelper( + int64_t M_start, + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale, + T_ACC* dscale_sum) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + int64_t thread_x = static_cast(blockIdx.x) * block_dim_x + + static_cast(threadIdx.x); + + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int64_t mean_index = + M_start + static_cast(threadIdx.y) * rows_per_thread_y; + T_ACC warp_rstd = 0; + if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { + warp_rstd = rstd[mean_index + lane_id]; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + __syncwarp(); +#endif + + T_ACC dY_regs[rows_per_thread_y] = {0}; + T_ACC X_regs[rows_per_thread_y] = {0}; +#pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + int64_t current_y = + M_start + static_cast(threadIdx.y) * rows_per_thread_y + i; + bool active = true; + if (check_x && thread_x >= N) { + active = false; + } + if (check_y && current_y >= M) { + active = false; + } + if (active) { + dY_regs[i] = static_cast(dY[current_y * N + thread_x]); + X_regs[i] = static_cast(X[current_y * N + thread_x]); + } + } + +#pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC rstd_reg = WARP_SHFL_(warp_rstd, i, kWarpSize); + *dscale_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } +} + +template +__device__ __forceinline__ void blockReduceScaleBackwardWithChecks( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale, + T_ACC* dscale_sum) { + for (int64_t M_start = static_cast(blockIdx.y) * rows_per_block_y; + M_start < M; + M_start += rows_per_block_y * gridDim.y) { + int64_t M_end = M_start + rows_per_block_y - 1; + if (!check_y || M_end < M) { + blockReduceScaleBackwardHelper( + M_start, M, N, dY, X, rstd, dscale, dscale_sum); + } else { + blockReduceScaleBackwardHelper( + M_start, M, N, dY, X, rstd, dscale, dscale_sum); + } + } +} + +template +__global__ void ScaleBackwardCUDAKernelTemplate(int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ rstd, + T* __restrict__ dscale) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + static_assert(rows_per_thread_y <= kWarpSize); + + T_ACC dscale_sum = 0; + + // Template : Boundary check of x and y + if (aligned_grid) { + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); + } else { + if (static_cast(blockIdx.x) * block_dim_x + block_dim_x - 1 < N) { + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); + } else { + blockReduceScaleBackwardWithChecks( + M, N, dY, X, rstd, dscale, &dscale_sum); + } + } + + int64_t thread_x = + (static_cast(blockIdx.x)) * block_dim_x + threadIdx.x; + + if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { + if (aligned_grid || thread_x < N) { + int64_t thread_y = + (static_cast(blockIdx.y)) * blockDim.y + threadIdx.y; + if (dscale) { + dscale[thread_y * N + thread_x] = static_cast(dscale_sum); + } + } + } else { + // Full reduction using shared memory + static_assert(rows_per_thread_y <= kWarpSize); + alignas(sizeof(double)) extern 
__shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dscale; + int padded_bx = (block_dim_x + 1); + s_dscale = s_data_typed; + s_dscale[threadIdx.y * padded_bx + threadIdx.x] = dscale_sum; + __syncthreads(); + + static_assert(block_dim_x * block_dim_y % kWarpSize == 0); + constexpr int warps_available_to_reduce = + block_dim_x * block_dim_y / kWarpSize; + int thread_id = threadIdx.y * block_dim_x + threadIdx.x; + int warp_id = thread_id / kWarpSize; + int lane_id = thread_id & (kWarpSize - 1); +#pragma unroll + for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { + T_ACC reg_dscale; + if (lane_id < block_dim_y) { + reg_dscale = s_dscale[lane_id * padded_bx + i]; + } +#pragma unroll + for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { + reg_dscale += WARP_SHFL_XOR_(reg_dscale, delta, kWarpSize); + } + + int64_t out_index = static_cast(blockIdx.x) * block_dim_x + i; + if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { + if (dscale) { + dscale[out_index] = static_cast(reg_dscale); + } + } + } + } +} + +template +void ConfigureAndLaunchScaleBackwardKernel(const T* dY_data, + const T* X_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + T* dscale_data, + cudaStream_t cuda_stream) { + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = 1; + size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; + + if (blocks.y == 1 && threads.y == 1) { + if (aligned_grid) { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } else { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } + } else { + if (aligned_grid) { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } else { + ScaleBackwardCUDAKernelTemplate + <<>>( + M, N, dY_data, X_data, rstd_data, dscale_data); + } + } +} + +// ----------------------------------------------------------------------- +// Host API Implementations +// ----------------------------------------------------------------------- + +template +void RMSNormFwdKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale_opt, + const std::vector& normalized_shape, + double epsilon, + DenseTensor* y, + DenseTensor* invvar) { + using T_ACC = typename phi::dtype::MPTypeTrait::Type; + + int begin_norm_axis = x.dims().size() - normalized_shape.size(); + + auto matrix_dim = common::flatten_to_2d(x.dims(), begin_norm_axis); + int64_t rows = matrix_dim[0]; + int64_t cols = matrix_dim[1]; + + auto* scale_ptr = scale_opt.get_ptr(); + const DenseTensor& scale = *scale_ptr; + + auto* x_data = x.data(); + auto* scale_data = scale_ptr ? scale.data() : nullptr; + auto* y_data = dev_ctx.template Alloc(y); + auto* rstd_data = dev_ctx.template Alloc(invvar); + + auto stream = dev_ctx.stream(); + + // When using a vectorization size of 8 in fp16 and bf16, there may be + // misalignment of accuracy and torch alignment. 
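  // The branch below is therefore a fast path using 8-element vector loads.
  // It is taken only when all of the following hold; otherwise control falls
  // through to the 4-wide vectorized path or the plain row-wise kernels:
  //   * FLAGS_use_accuracy_compatible_kernel is off;
  //   * the problem is short and wide: rows <= 1024 and cols / rows >= 32;
  //   * T is float16 or bfloat16 and cols is a multiple of 8;
  //   * cols does not exceed the numeric_limits-based bound checked below;
  //   * x, y and scale pointers satisfy can_vectorize() for the 8-element
  //     alignment, so aligned_vector loads and stores are legal.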
+ if (!FLAGS_use_accuracy_compatible_kernel && rows <= 1024 && + (cols / rows >= 32)) { + constexpr int num_vec_elems2 = 8; + constexpr int alignment2 = num_vec_elems2 * sizeof(T); + bool can_vec_X2 = can_vectorize(x_data, alignment2); + bool can_vec_Y2 = can_vectorize(y_data, alignment2); + bool can_vec_scale2 = can_vectorize(scale_data, alignment2); + bool is_supported_type2 = (std::is_same::value || + std::is_same::value); + if (is_supported_type2 && + cols <= + static_cast(1ULL << std::numeric_limits::digits) && + cols % num_vec_elems2 == 0 && can_vec_X2 && can_vec_Y2 && + can_vec_scale2) { + launch_vectorized_rms_norm_kernel_driver( + cols, + rows, + static_cast(epsilon), + x_data, + scale_data, + y_data, + rstd_data, + stream); + return; + } + } + + // Check vectorization conditions + constexpr int num_vec_elems = 4; + constexpr int alignment = num_vec_elems * sizeof(T); + bool can_vec_X = can_vectorize(x_data, alignment); + bool can_vec_Y = can_vectorize(y_data, alignment); + bool can_vec_scale = can_vectorize(scale_data, alignment); + bool is_supported_type = (std::is_same::value || + std::is_same::value || + std::is_same::value); + + if (is_supported_type && + cols <= + static_cast(1ULL << std::numeric_limits::digits) && + cols % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_scale) { + launch_vectorized_rms_norm_kernel_driver( + cols, + rows, + static_cast(epsilon), + x_data, + scale_data, + y_data, + rstd_data, + stream); + + } else { + RowwiseMomentsCUDAKernel + <<>>( + cols, static_cast(epsilon), x_data, rstd_data); + + RMSNormForwardCUDAKernel<<>>( + cols, x_data, rstd_data, scale_data, y_data); + } +} + +template +void RMSNormBwdKernel(const Context& dev_ctx, + const DenseTensor& X, + const paddle::optional& scale_opt, + const DenseTensor& invvar, + const DenseTensor& dY, + const std::vector& normalized_shape, + double epsilon, + DenseTensor* dX, + DenseTensor* dscale) { + using T_ACC = typename phi::dtype::MPTypeTrait::Type; + + int begin_norm_axis = X.dims().size() - normalized_shape.size(); + + // X, dY: [Batch, ..., Feature] -> flatten to [M, N] + // scale, dscale: [Feature] -> [N] + // invvar: [Batch, ...] -> [M] + + auto matrix_dim = common::flatten_to_2d(X.dims(), begin_norm_axis); + int64_t M = matrix_dim[0]; + int64_t N = matrix_dim[1]; + + auto* scale_ptr = scale_opt.get_ptr(); + const DenseTensor& scale = *scale_ptr; + + auto* dY_data = dY.data(); + auto* X_data = X.data(); + auto* scale_data = scale_ptr ? scale.data() : nullptr; + auto* invvar_data = invvar.data(); + + auto* dX_data = dX ? dev_ctx.template Alloc(dX) : nullptr; + auto* dscale_data = dscale ? dev_ctx.template Alloc(dscale) : nullptr; + + auto stream = dev_ctx.stream(); + + // 1. 
Compute dX + if (dX_data) { + static constexpr int kVecSize = 4; + bool bVectorSizeMultiple = (N % kVecSize == 0); + const unsigned int alignment = sizeof(T) * kVecSize; + bool bAlignedBuffers = can_vectorize(dY_data, alignment) && + can_vectorize(X_data, alignment) && + can_vectorize(scale_data, alignment) && + can_vectorize(dX_data, alignment); + bool is_supported_type = (std::is_same::value || + std::is_same::value || + std::is_same::value); + + const unsigned int alignment2 = sizeof(T) * 8; + bool bAlignedBuffers2 = can_vectorize(dY_data, alignment2) && + can_vectorize(X_data, alignment2) && + can_vectorize(scale_data, alignment2) && + can_vectorize(dX_data, alignment2); + bool is_supported_type2 = (std::is_same::value || + std::is_same::value); + + dim3 blocks(M); + constexpr int num_threads = 128; + constexpr int nshared = (num_threads / kWarpSize) * sizeof(T_ACC); + + // When using a vectorization size of 8 in fp16 and bf16, there may be + // misalignment of accuracy and torch alignment. + if (!FLAGS_use_accuracy_compatible_kernel && is_supported_type2 && + bAlignedBuffers2 && (N % 8 == 0 && M <= 1024 && (N / M >= 32))) { + rms_norm_grad_input_kernel_vectorized + <<>>( + dY_data, X_data, invvar_data, scale_data, dX_data, N); + } else if (is_supported_type && bAlignedBuffers && bVectorSizeMultiple) { + rms_norm_grad_input_kernel_vectorized + <<>>( + dY_data, X_data, invvar_data, scale_data, dX_data, N); + } else { + rms_norm_grad_input_kernel + <<>>( + dY_data, X_data, invvar_data, scale_data, dX_data, N); + } + } + + // 2. Compute dscale + if (dscale_data) { + constexpr int block_dim_x = 32; + const int sm_count = dev_ctx.GetSMCount(); + if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { + // When M>>N and N is very small. We can parallelize and accelerate + // computation by starting multiple blocks on the M-dimension (y). 
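      // Concretely, this is a two-stage reduction: the launch below uses
      // blocks.y > 1 along the row (M) dimension with a partial-reduction
      // instantiation of ScaleBackwardCUDAKernelTemplate, so each block row
      // writes its partial column sums of dY * X * rstd into dscale_blocks
      // (shape [blocks.y * threads.y, N]); phi::SumKernel then reduces those
      // partial rows over axis 0 to produce the final dscale of shape [N].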
+ constexpr int block_dim_y = 1; + constexpr int rows_per_block_y = 32; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; + constexpr int max_grid_size = 64 * 1024 / 2; + blocks.y = std::min(max_grid_size / blocks.x, blocks.y); + + DenseTensor dscale_blocks; + dscale_blocks.Resize({static_cast(blocks.y * threads.y), N}); + T* dscale_blocks_ptr = dev_ctx.template Alloc(&dscale_blocks); + + if (aligned_grid) { + ScaleBackwardCUDAKernelTemplate<<>>( + M, N, dY_data, X_data, invvar_data, dscale_blocks_ptr); + } else { + ScaleBackwardCUDAKernelTemplate<<>>( + M, N, dY_data, X_data, invvar_data, dscale_blocks_ptr); + } + + // Sum reduction along blocks.y dimension to get final dscale + phi::SumKernel( + dev_ctx, dscale_blocks, {0}, dscale->dtype(), false, dscale); + + } else { + if (M < 64) { + ConfigureAndLaunchScaleBackwardKernel( + dY_data, X_data, invvar_data, M, N, dscale_data, stream); + } else if (M < 128) { + ConfigureAndLaunchScaleBackwardKernel( + dY_data, X_data, invvar_data, M, N, dscale_data, stream); + } else if (M < 256) { + ConfigureAndLaunchScaleBackwardKernel( + dY_data, X_data, invvar_data, M, N, dscale_data, stream); + } else { + ConfigureAndLaunchScaleBackwardKernel( + dY_data, X_data, invvar_data, M, N, dscale_data, stream); + } + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc index 0b4738509037f7..beb2ad5e294c23 100644 --- a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc @@ -44,7 +44,7 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx, PD_VISIT_ALL_TYPES(x_grad->dtype(), "StridedSliceRawGradStridedKernel", ([&] { phi::StridedTensorFill(*x_grad, 0, x_grad); })); - if (x_grad->numel() == 0) return; + if (out_grad.numel() == 0) return; DenseTensor tmp; tmp.set_layout(out_grad.layout()); tmp.set_lod(out_grad.lod()); diff --git a/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc b/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc new file mode 100644 index 00000000000000..cef2c6d207f711 --- /dev/null +++ b/paddle/phi/kernels/xpu/rms_norm_xpu_kernel.cc @@ -0,0 +1,253 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
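The XPU implementation that follows supports normalization over the last axis only (it enforces begin_norm_axis == rank - 1) and flattens the input to a [rows, cols] matrix before calling xpu::rms_layer_norm. A minimal sketch of that flattening, mirroring the GetRowsCols helper defined below (the helper name and the sample shape here are illustrative only):

#include <cstdint>
#include <vector>

// Collapse all leading dimensions into rows; the last dimension becomes cols.
static void FlattenToRowsCols(const std::vector<int64_t>& shape,
                              int64_t* rows,
                              int64_t* cols) {
  *rows = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    *rows *= shape[i];
  }
  *cols = shape.back();
}

// Example: x of shape {2, 3, 64} with normalized_shape = {64} gives
// rows = 6 and cols = 64; the invvar output is then resized to {6}.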
+#include + +#include "paddle/common/exception.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +static void GetRowsCols(const std::vector &shape, + int64_t *p_rows, + int64_t *p_cols) { + int64_t rows = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + int64_t cols = shape[shape.size() - 1]; + *p_rows = rows; + *p_cols = cols; +} + +template +void RMSNormFwdKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &scale_opt, + const std::vector &normalized_shape, + double epsilon, + DenseTensor *y, + DenseTensor *invvar) { + int begin_norm_axis = x.dims().size() - normalized_shape.size(); + PADDLE_ENFORCE_EQ( + begin_norm_axis, + x.dims().size() - 1, + common::errors::InvalidArgument( + "XPU RMSNorm only supports begin_norm_axis=%d, but got %d", + x.dims().size() - 1, + begin_norm_axis)); + + auto *scale_ptr = scale_opt.get_ptr(); + if (scale_ptr == nullptr) { + PADDLE_THROW(common::errors::InvalidArgument( + "Scale must be provided for RMSNorm backward")); + } + const DenseTensor &scale = *scale_ptr; + + int64_t rows, cols; + GetRowsCols(common::vectorize(x.dims()), &rows, &cols); + + if (scale.dtype() == phi::DataType::BFLOAT16) { + dev_ctx.template Alloc(y); + } else if (scale.dtype() == phi::DataType::FLOAT16) { + dev_ctx.template Alloc(y); + } else if (scale.dtype() == phi::DataType::FLOAT32) { + dev_ctx.template Alloc(y); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "The dtype of scale must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]", + scale.dtype())); + } + invvar->Resize({rows}); + dev_ctx.template Alloc(invvar); + + /* + refer to: + - + https://github.com/NVIDIA/apex/blob/bfb500c8/csrc/layer_norm_cuda_kernel.cu#L1018 + - + https://github.com/PaddlePaddle/PaddleNLP/blob/5b9e0b33/ops/csrc/fused_ln/layer_norm_cuda.h#L1087 + + Supported Type combinations: + + input compute scale output + ======================================= + fp32 fp32 fp32 fp32 + fp16 fp32 fp16 fp16 + bf16 fp32 bf16 bf16 + + Not supported yet: + + input compute scale output + ======================================= + fp32 fp32 fp16 fp16 + fp32 fp32 bf16 bf16 + + Remarks: + Output type = Scale type + Compute always in FP32 + */ + +#define DISPATCH_FWD_CASE(scalar_t_out) \ + using XPUType = typename XPUTypeTrait::Type; \ + auto ret = xpu::rms_layer_norm( \ + dev_ctx.x_context(), \ + reinterpret_cast(x.data()), \ + reinterpret_cast(y->data()), \ + rows, \ + cols, \ + epsilon, \ + reinterpret_cast(scale.data()), \ + /*bias=*/nullptr, \ + invvar->data(), \ + /*is_rstd=*/true); \ + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm"); + // scale.dtype() same as y->dtype() + if (x.dtype() == phi::DataType::FLOAT32 && + scale.dtype() == phi::DataType::FLOAT32) { + DISPATCH_FWD_CASE(float); + } else if (x.dtype() == phi::DataType::FLOAT16 && + scale.dtype() == phi::DataType::FLOAT16) { + DISPATCH_FWD_CASE(phi::float16); + } else if (x.dtype() == phi::DataType::BFLOAT16 && + scale.dtype() == phi::DataType::BFLOAT16) { + DISPATCH_FWD_CASE(phi::bfloat16); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "Unsupported dtype combination: x [%s], scale [%s]. 
" + "Expected both to be float32, float16, or bfloat16.", + phi::DataTypeToString(x.dtype()), + phi::DataTypeToString(scale.dtype()))); + } +#undef DISPATCH_FWD_CASE +} + +template +void RMSNormBwdKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &scale_opt, + const DenseTensor &invvar, + const DenseTensor &y_grad, + const std::vector &normalized_shape, + double epsilon, + DenseTensor *x_grad, + DenseTensor *scale_grad) { + int begin_norm_axis = x.dims().size() - normalized_shape.size(); + PADDLE_ENFORCE_EQ( + begin_norm_axis, + x.dims().size() - 1, + common::errors::InvalidArgument( + "XPU RMSNorm only supports begin_norm_axis=%d, but got %d", + x.dims().size() - 1, + begin_norm_axis)); + + auto *scale_ptr = scale_opt.get_ptr(); + if (scale_ptr == nullptr) { + PADDLE_THROW(common::errors::InvalidArgument( + "Scale must be provided for RMSNorm backward")); + } + const DenseTensor &scale = *scale_ptr; + + int64_t rows, cols; + GetRowsCols(common::vectorize(x.dims()), &rows, &cols); + dev_ctx.template Alloc(x_grad); + DenseTensor actual_scale_grad; + if (scale_grad) { + if (scale.dtype() == phi::DataType::BFLOAT16) { + dev_ctx.template Alloc(scale_grad); + } else if (scale.dtype() == phi::DataType::FLOAT16) { + dev_ctx.template Alloc(scale_grad); + } else if (scale.dtype() == phi::DataType::FLOAT32) { + dev_ctx.template Alloc(scale_grad); + } else { + PADDLE_THROW( + common::errors::InvalidArgument("The dtype of scale must be FLOAT32, " + "FLOAT16 or BFLOAT16, but got [%s]", + scale.dtype())); + } + actual_scale_grad = *scale_grad; + } else { + // lora specific, scale_grad is nullptr + if (scale.dtype() == phi::DataType::BFLOAT16) { + actual_scale_grad = + phi::EmptyLike(dev_ctx, scale); + } else if (scale.dtype() == phi::DataType::FLOAT16) { + actual_scale_grad = phi::EmptyLike(dev_ctx, scale); + } else if (scale.dtype() == phi::DataType::FLOAT32) { + actual_scale_grad = phi::EmptyLike(dev_ctx, scale); + } else { + PADDLE_THROW( + common::errors::InvalidArgument("The dtype of scale must be FLOAT32, " + "FLOAT16 or BFLOAT16, but got [%s]", + scale.dtype())); + } + } + +#define DISPATCH_BWD_CASE(scalar_t_out) \ + using XPUType = typename XPUTypeTrait::Type; \ + auto ret = xpu::rms_layer_norm_grad( \ + dev_ctx.x_context(), \ + reinterpret_cast(x.data()), \ + reinterpret_cast(y_grad.data()), \ + reinterpret_cast(x_grad->data()), \ + rows, \ + cols, \ + epsilon, \ + reinterpret_cast(scale.data()), \ + invvar.data(), \ + reinterpret_cast(actual_scale_grad.data()), \ + /*bias=*/nullptr, \ + /*is_rstd=*/true); \ + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rms_layer_norm_grad"); + // scale.dtype() same as y->dtype() + if (x.dtype() == phi::DataType::FLOAT32 && + scale.dtype() == phi::DataType::FLOAT32) { + DISPATCH_BWD_CASE(float); + } else if (x.dtype() == phi::DataType::FLOAT16 && + scale.dtype() == phi::DataType::FLOAT16) { + DISPATCH_BWD_CASE(phi::float16); + } else if (x.dtype() == phi::DataType::BFLOAT16 && + scale.dtype() == phi::DataType::BFLOAT16) { + DISPATCH_BWD_CASE(phi::bfloat16); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "Unsupported dtype combination: x [%s], scale [%s]. 
" + "Expected both to be float32, float16, or bfloat16.", + phi::DataTypeToString(x.dtype()), + phi::DataTypeToString(scale.dtype()))); + } +#undef DISPATCH_BWD_CASE +} + +} // namespace phi + +PD_REGISTER_KERNEL(rms_norm, + XPU, + ALL_LAYOUT, + phi::RMSNormFwdKernel, + float, + phi::float16, + phi::bfloat16) {} + +PD_REGISTER_KERNEL(rms_norm_grad, + XPU, + ALL_LAYOUT, + phi::RMSNormBwdKernel, + float, + phi::float16, + phi::bfloat16) {} diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 9a7bd96b42b953..e4e4c16dba4b71 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -4238,6 +4238,17 @@ func: pyramid_hash_grad data_type: w +- backward_op: rms_norm_grad + forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1e-5) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1e-5) + output: Tensor(x_grad), Tensor(scale_grad) + infer_meta: + func: RMSNormGradInferMeta + kernel: + func: rms_norm_grad + data_type: x + optional : scale + - backward_op: shuffle_batch_grad forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) args: (Tensor shuffle_idx, Tensor out_grad,int startup_seed=0) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 899a86a22761a0..e52d746608e90a 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3162,6 +3162,7 @@ func : LerpInferMeta kernel : func : lerp + data_type : x inplace : (x -> out) backward : lerp_grad interfaces : paddle::dialect::InferSymbolicShapeInterface @@ -6106,3 +6107,15 @@ data_type: numbers interfaces : paddle::dialect::InferSymbolicShapeInterface traits : paddle::dialect::ForwardOnlyTrait + +- op: rms_norm + args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon= 1e-5) + output: Tensor(y), Tensor(invvar) + infer_meta: + func: RmsNormInferMeta + kernel: + func: rms_norm + data_type: x + optional : scale + backward: rms_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index f8f231306f3f7a..4e0f88c25707b2 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -463,66 +463,31 @@ def layer_norm( def rms_norm( input: Tensor, - normalized_shape: int | Sequence[int], + normalized_shape: Sequence[int], weight: Tensor | None = None, eps: float = 1e-5, name: str | None = None, ) -> tuple[Tensor, Tensor]: """ Applies Layer Normalization over the last dimension of the input tensor using CUDA implementation. + Args: input (Tensor): Input tensor of shape [rows, cols] or higher dimensions (flattened to 2D). - normalized_shape(int|list|tuple): Input shape from an expected input of + normalized_shape(list|tuple): Input shape from an expected input of size :math:`[*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]]`. If it is a single integer, this module will normalize over the last dimension which is expected to be of that specific size. weight(Tensor, optional): The weight tensor of rms_norm. Default: None. eps(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05. name (str, optional): Name of the operator. + Returns: out (Tensor): Normalized tensor of same shape as input. 
invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. """ - input_shape = list(input.shape) - input_ndim = len(input_shape) - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = [normalized_shape] - elif isinstance(normalized_shape, tuple): - normalized_shape = list(normalized_shape) - elif not isinstance(normalized_shape, list): - raise ValueError( - "`normalized_shape` should be int, list of ints or tuple of ints." - ) - - normalized_ndim = len(normalized_shape) - begin_norm_axis = input_ndim - normalized_ndim - if input_ndim < normalized_ndim or ( - not paddle.utils.is_same_shape( - input_shape[begin_norm_axis:], normalized_shape - ) - ): - str_normalized_shape = str(normalized_shape) - raise ValueError( - 'Given normalized_shape is ' - + str_normalized_shape - + ', expected input with shape [*, ' - + str_normalized_shape[1:] - + ', but got input shape ' - + str(input_shape) - ) - - if normalized_ndim != 1: - raise ValueError( - 'Given len(normalized_shape) is ' - + normalized_ndim - + ', expected len(normalized_shape) is 1.' - ) - - if weight is None: - raise ValueError("weight must not be None.") if in_dynamic_or_pir_mode(): - return _C_ops.fused_rms_norm_ext(input, weight, eps) + return _C_ops.rms_norm(input, weight, normalized_shape, eps) helper = LayerHelper('rms_norm', **locals()) from paddle.base.data_feeder import convert_dtype @@ -537,7 +502,7 @@ def rms_norm( type='rms_norm', inputs=inputs, outputs={'out': out, 'invvar': invvar}, - attrs={'eps': eps}, + attrs={"normalized_shape": normalized_shape, "eps": eps}, ) return out, invvar diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index a010918816c92d..be2db198bed323 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5169,7 +5169,10 @@ def lerp( """ if isinstance(weight, float): - weight = paddle.full(shape=[], fill_value=weight, dtype=x.dtype) + if x.is_cuda and in_dynamic_mode(): + weight = paddle.full(shape=[], fill_value=weight, dtype="float64") + else: + weight = paddle.full(shape=[], fill_value=weight, dtype=x.dtype) if in_dynamic_or_pir_mode(): return _C_ops.lerp(x, y, weight) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 7ab3c92dd2fcda..f32e0b2d6f115a 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -106,7 +106,7 @@ if(NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) list(REMOVE_ITEM TEST_OPS test_fused_transpose_spilt_quant_op) list(REMOVE_ITEM TEST_OPS test_fused_transpose_wlch_split_quant_op) - list(REMOVE_ITEM TEST_OPS test_rms_norm_op) + list(REMOVE_ITEM TEST_OPS test_fused_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op) list(REMOVE_ITEM TEST_OPS test_fused_swiglu_weighted_bwd_op) list(REMOVE_ITEM TEST_OPS test_fused_comm_buffer) @@ -202,7 +202,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) list(REMOVE_ITEM TEST_OPS test_masked_multihead_attention_op) - list(REMOVE_ITEM TEST_OPS test_rms_norm_op) + list(REMOVE_ITEM TEST_OPS test_fused_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op) list(REMOVE_ITEM TEST_OPS test_incubate_fast_ln) list(REMOVE_ITEM TEST_OPS test_fused_weighted_swiglu_act_quant_op) @@ -575,7 +575,14 @@ if(NOT WITH_GPU OR ((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 11.8)) )# Restrict the use of older versions of CUB list(REMOVE_ITEM TEST_OPS 
test_incubate_fused_rmsnorm_ext) - list(REMOVE_ITEM TEST_OPS test_rms_norm) +endif() + +if(NOT WITH_GPU + OR APPLE + OR WITH_ROCM + OR ((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 11.8)) +)# Restrict the use of older versions of CUB + list(REMOVE_ITEM TEST_OPS test_rms_norm_op) endif() if(NOT WITH_GPU diff --git a/test/legacy_test/test_fused_rms_norm_op.py b/test/legacy_test/test_fused_rms_norm_op.py new file mode 100644 index 00000000000000..ac94cce01f6ac6 --- /dev/null +++ b/test/legacy_test/test_fused_rms_norm_op.py @@ -0,0 +1,1010 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +from op_test import get_device, get_device_place, is_custom_device + +import paddle +from paddle import base +from paddle.base import core + + +def quant_helper( + x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound +): + quant_value = quant_max_bound * quant_scale * x + if quant_round_type == 0: + quant_value = paddle.to_tensor(np.rint(quant_value.numpy())) + else: + quant_value = paddle.round(quant_value) + return paddle.cast( + paddle.clip(quant_value, quant_min_bound, quant_max_bound), + paddle.int8, + ) + + +def naive_residual_bias_add(x, residual, bias): + return x + residual + bias + + +def naive_rms_norm(x, gamma, beta=None, epsilon=1e-5): + variance = x.pow(2).mean(-1, keepdim=True) + out = paddle.rsqrt(variance + epsilon) * x + out = out * gamma + if beta is not None: + out = out + beta + return out + + +def fused_rms_norm(x, gamma, beta=None, epsilon=1e-5, begin_norm_axis=1): + out = paddle.incubate.nn.functional.fused_rms_norm( + x, gamma, beta, epsilon, begin_norm_axis=begin_norm_axis + ) + return out[0] + + +def naive_rms_norm_int8( + x, + gamma, + beta, + epsilon, + in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, +): + out = naive_rms_norm(x, gamma, beta, epsilon) + out = quant_helper( + out, in_scale, quant_round_type, quant_max_bound, quant_min_bound + ) + return out + + +def naive_residual_biasadd_rms_norm(x, residual, bias, gamma, beta, epsilon): + x = x + residual + bias + variance = x.pow(2).mean(-1, keepdim=True) + out = paddle.rsqrt(variance + epsilon) * x + out = out * gamma + beta + return out + + +def naive_residual_biasadd_rms_norm_int8( + x, + residual, + bias, + gamma, + beta, + epsilon, + in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, +): + out = naive_residual_biasadd_rms_norm( + x, residual, bias, gamma, beta, epsilon + ) + out = quant_helper( + out, in_scale, quant_round_type, quant_max_bound, quant_min_bound + ) + return out + + +@unittest.skipIf( + not (core.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm(), + "core is not compiled with CUDA or ROCM", +) +class TestRMSNormOp(unittest.TestCase): + def setUp(self): + np.random.seed(20) + batch = 32 + cols = 256 + self.x_np = np.random.random([batch, cols]) + self.residual_np = np.random.random([batch, cols]) + self.bias_np = 
np.random.random([cols]) + + self.norm_weight_np = np.random.random([cols]) + self.norm_bias_np = np.random.random([cols]) + self.epsilon = 1e-6 + self.quant_scale = 0.15 + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + + def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, gamma, beta, self.epsilon, begin_norm_axis=1 + )[0] + paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, + gamma, + beta, + self.epsilon, + begin_norm_axis=1, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + )[0] + + paddle_naive_rmsnorm_out = naive_rms_norm_int8( + x, + gamma, + beta, + self.epsilon, + self.quant_scale, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def check_residual_bias_rmsnorm( + self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype + ): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, + gamma, + beta, + self.epsilon, + begin_norm_axis=1, + bias=bias, + residual=residual, + )[0] + + paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( + x, residual, bias, gamma, beta, self.epsilon + ) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def check_residual_bias_rmsnorm_int8( + self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype + ): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, + gamma, + beta, + self.epsilon, + begin_norm_axis=1, + bias=bias, + residual=residual, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + )[0] + + paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm_int8( + x, + residual, + bias, + gamma, + beta, + self.epsilon, + self.quant_scale, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def test_rmsnorm_fp16(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( + self.x_np, 
self.norm_weight_np, self.norm_bias_np, 'float16' + ) + + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_rmsnorm_int8(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' + ) + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=2, + atol=2, + ) + + def test_residual_bias_add_rmsnorm_fp16(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( + self.x_np, + self.norm_weight_np, + self.norm_bias_np, + self.residual_np, + self.bias_np, + 'float16', + ) + + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_residual_bias_add_rmsnorm_int8(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + ( + paddle_rmsnorm, + paddle_naive_rmsnorm, + ) = self.check_residual_bias_rmsnorm_int8( + self.x_np, + self.norm_weight_np, + self.norm_bias_np, + self.residual_np, + self.bias_np, + 'float16', + ) + + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=2, + atol=2, + ) + + def test_rms_norm_backward(self): + def get_paddle_tensor(shape, dtype, bound=0.5): + tmp = paddle.uniform(shape, dtype=dtype, min=-bound, max=bound) + tmp.stop_gradient = False + return tmp + + def get_forward_backward(func, seed, dtype): + paddle.disable_static() + paddle.seed(seed) + x = get_paddle_tensor([2, 256], dtype) + scale = get_paddle_tensor([256], dtype) + out_g = paddle.randn([2, 256], dtype) + out = func(x, scale) + paddle.autograd.backward([out], [out_g], True) + return out, (x.grad, scale.grad) + + dtypes = [paddle.float32] + if paddle.amp.is_bfloat16_supported(get_device()): + dtypes.append(paddle.bfloat16) + if paddle.amp.is_float16_supported(get_device()): + dtypes.append(paddle.float16) + for dtype in dtypes: + raw_out, raw_grads = get_forward_backward( + naive_rms_norm, seed=2024, dtype=dtype + ) + fused_out, fused_grads = get_forward_backward( + fused_rms_norm, seed=2024, dtype=dtype + ) + # forward rtol + rtol = 1e-5 if dtype == paddle.float32 else 1e-2 + np.testing.assert_allclose( + raw_out.astype(paddle.float32).numpy(), + fused_out.astype(paddle.float32).numpy(), + rtol=rtol, + ) + # backward rtol, only check float32 grad + rtol = 1e-3 + if dtype == paddle.float32: + raw_x_grad, raw_scale_grad = raw_grads + fused_x_grad, fused_scale_grad = fused_grads + np.testing.assert_allclose( + raw_x_grad.astype(paddle.float32).numpy(), + fused_x_grad.astype(paddle.float32).numpy(), + rtol=rtol, + ) + np.testing.assert_allclose( + raw_scale_grad.astype(paddle.float32).numpy(), + fused_scale_grad.astype(paddle.float32).numpy(), + rtol=rtol, + ) + + +@unittest.skipIf( + not (core.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm(), + "core is not compiled with CUDA or ROCM", +) +class TestRMSNormStaticOp(unittest.TestCase): + def setUp(self): + np.random.seed(20) + self.batch = 32 + self.cols = 256 + self.x_np = np.random.random([self.batch, 256]) + self.norm_weight_np = np.random.random([256]) + self.norm_bias_np = 
np.random.random([256]) + self.residual_np = np.random.random([self.batch, 256]) + self.bias_np = np.random.random([256]) + self.epsilon = 1e-6 + self.quant_scale = 0.15 + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.place = get_device_place() + + def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + gamma_static = paddle.static.data( + name="gamma_static", shape=[self.cols], dtype=dtype + ) + beta_static = paddle.static.data( + name="beta_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_rms_norm( + x_static, + gamma_static, + beta_static, + self.epsilon, + begin_norm_axis=1, + )[0] + exe = base.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "gamma_static": gamma_np.astype(dtype), + "beta_static": beta_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s[0], paddle_naive_rmsnorm_out + + def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_naive_rmsnorm_out = naive_rms_norm_int8( + x, + gamma, + beta, + self.epsilon, + self.quant_scale, + self.quant_round_type, + self.quant_max_bound, + self.quant_min_bound, + ) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + gamma_static = paddle.static.data( + name="gamma_static", shape=[self.cols], dtype=dtype + ) + beta_static = paddle.static.data( + name="beta_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_rms_norm( + x_static, + gamma_static, + beta_static, + self.epsilon, + begin_norm_axis=1, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + )[0] + exe = base.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "gamma_static": gamma_np.astype(dtype), + "beta_static": beta_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s[0], paddle_naive_rmsnorm_out + + def check_residual_bias_rmsnorm( + self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype + ): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( + x, residual, bias, gamma, beta, self.epsilon + ) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + residual_static = paddle.static.data( + name="residual_static", + shape=[self.batch, self.cols], + dtype=dtype, + ) + bias_static = paddle.static.data( + 
name="bias_static", shape=[self.cols], dtype=dtype + ) + gamma_static = paddle.static.data( + name="gamma_static", shape=[self.cols], dtype=dtype + ) + beta_static = paddle.static.data( + name="beta_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_rms_norm( + x_static, + gamma_static, + beta_static, + self.epsilon, + begin_norm_axis=1, + bias=bias_static, + residual=residual_static, + )[0] + + exe = base.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "gamma_static": gamma_np.astype(dtype), + "beta_static": beta_np.astype(dtype), + "residual_static": residual_np.astype(dtype), + "bias_static": bias_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s[0], paddle_naive_rmsnorm_out + + def test_rmsnorm_fp16(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' + ) + + np.testing.assert_allclose( + paddle_rmsnorm, + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_residual_bias_add_rmsnorm_fp16(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( + self.x_np, + self.norm_weight_np, + self.norm_bias_np, + self.residual_np, + self.bias_np, + 'float16', + ) + + np.testing.assert_allclose( + paddle_rmsnorm, + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_rmsnorm_int8(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' + ) + np.testing.assert_allclose( + paddle_rmsnorm, + paddle_naive_rmsnorm.numpy(), + rtol=2, + atol=2, + ) + + +@unittest.skipIf( + not core.supports_avx512f() or not core.is_compiled_with_avx(), + "machine is not support AVX or is not compiled with AVX", +) +class TestRMSNormOpCPU(unittest.TestCase): + def setUp(self): + import os + + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + np.random.seed(20) + batch = 32 + cols = 256 + self.x_np = np.random.random([batch, cols]) + self.residual_np = np.random.random([batch, cols]) + self.bias_np = np.random.random([cols]) + + self.norm_weight_np = np.random.random([cols]) + self.norm_bias_np = np.random.random([cols]) + self.epsilon = 1e-6 + + def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, gamma, beta, self.epsilon, begin_norm_axis=1 + )[0] + paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def check_residual_bias_rmsnorm( + self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype + ): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_rmsnorm_out = 
paddle.incubate.nn.functional.fused_rms_norm( + x, + gamma, + beta, + self.epsilon, + begin_norm_axis=1, + bias=bias, + residual=residual, + ) + + paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( + x, residual, bias, gamma, beta, self.epsilon + ) + + paddle_naive_residual_out = naive_residual_bias_add(x, residual, bias) + paddle.enable_static() + return ( + paddle_rmsnorm_out, + paddle_naive_rmsnorm_out, + paddle_naive_residual_out, + ) + + def test_rmsnorm(self): + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float32' + ) + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_residual_bias_add_rmsnorm(self): + ( + paddle_rmsnorm, + paddle_naive_rmsnorm, + paddle_naive_residual_out, + ) = self.check_residual_bias_rmsnorm( + self.x_np, + self.norm_weight_np, + self.norm_bias_np, + self.residual_np, + self.bias_np, + 'float32', + ) + + np.testing.assert_allclose( + paddle_rmsnorm[0].numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + np.testing.assert_allclose( + paddle_rmsnorm[1].numpy(), + paddle_naive_residual_out.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +@unittest.skipIf( + not core.supports_avx512f() or not core.is_compiled_with_avx(), + "machine is not support AVX or is not compiled with AVX", +) +class TestRMSNormStaticOpCPU(unittest.TestCase): + def setUp(self): + import os + + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + np.random.seed(20) + self.batch = 32 + self.cols = 256 + self.x_np = np.random.random([self.batch, 256]) + self.norm_weight_np = np.random.random([256]) + self.norm_bias_np = np.random.random([256]) + self.residual_np = np.random.random([self.batch, 256]) + self.bias_np = np.random.random([256]) + self.epsilon = 1e-6 + self.place = paddle.CPUPlace() + + def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + gamma_static = paddle.static.data( + name="gamma_static", shape=[self.cols], dtype=dtype + ) + beta_static = paddle.static.data( + name="beta_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_rms_norm( + x_static, + gamma_static, + beta_static, + self.epsilon, + begin_norm_axis=1, + )[0] + exe = base.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "gamma_static": gamma_np.astype(dtype), + "beta_static": beta_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s[0], paddle_naive_rmsnorm_out + + def check_residual_bias_rmsnorm( + self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype + ): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + residual = paddle.to_tensor(residual_np.astype(dtype)) + bias = paddle.to_tensor(bias_np.astype(dtype)) + + paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( + x, residual, bias, gamma, beta, self.epsilon + ) + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + x_static = 
paddle.static.data( + name="x_static", shape=[self.batch, self.cols], dtype=dtype + ) + residual_static = paddle.static.data( + name="residual_static", + shape=[self.batch, self.cols], + dtype=dtype, + ) + bias_static = paddle.static.data( + name="bias_static", shape=[self.cols], dtype=dtype + ) + gamma_static = paddle.static.data( + name="gamma_static", shape=[self.cols], dtype=dtype + ) + beta_static = paddle.static.data( + name="beta_static", shape=[self.cols], dtype=dtype + ) + outs = paddle.incubate.nn.functional.fused_rms_norm( + x_static, + gamma_static, + beta_static, + self.epsilon, + begin_norm_axis=1, + bias=bias_static, + residual=residual_static, + )[0] + + exe = base.Executor(self.place) + out_s = exe.run( + feed={ + "x_static": x_np.astype(dtype), + "gamma_static": gamma_np.astype(dtype), + "beta_static": beta_np.astype(dtype), + "residual_static": residual_np.astype(dtype), + "bias_static": bias_np.astype(dtype), + }, + fetch_list=[outs], + ) + return out_s[0], paddle_naive_rmsnorm_out + + def test_rmsnorm(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float32' + ) + + np.testing.assert_allclose( + paddle_rmsnorm, + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_residual_bias_add_rmsnorm(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( + self.x_np, + self.norm_weight_np, + self.norm_bias_np, + self.residual_np, + self.bias_np, + 'float32', + ) + + np.testing.assert_allclose( + paddle_rmsnorm, + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +class TestRMSNormAxisEquivalence(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.seed(123) + + # x [batch_size, seq_len, hidden_size] + self.batch_size = 1 + self.seq_len = 8 + self.hidden_size = 64 + + self.x_np = np.random.random( + [self.batch_size, self.seq_len, self.hidden_size] + ).astype('float32') + self.weight_np = np.random.random([self.hidden_size]).astype('float32') + self.bias_np = np.random.random([self.hidden_size]).astype('float32') + self.epsilon = 1e-6 + + def test_positive_negative_axis_equivalence(self): + paddle.disable_static() + + x = paddle.to_tensor(self.x_np) + weight = paddle.to_tensor(self.weight_np) + bias = paddle.to_tensor(self.bias_np) + + # positive + out_positive = paddle.incubate.nn.functional.fused_rms_norm( + x, weight, bias, self.epsilon, begin_norm_axis=2 + )[0] + + # negative + out_negative = paddle.incubate.nn.functional.fused_rms_norm( + x, weight, bias, self.epsilon, begin_norm_axis=-1 + )[0] + + # test + np.testing.assert_allclose( + out_positive.numpy(), + out_negative.numpy(), + rtol=1e-5, + atol=1e-5, + ) + + def test_out_of_range_axis(self): + paddle.disable_static() + + x = paddle.to_tensor(self.x_np) + weight = paddle.to_tensor(self.weight_np) + bias = paddle.to_tensor(self.bias_np) + + with self.assertRaises(ValueError): + paddle.incubate.nn.functional.fused_rms_norm( + x, weight, bias, self.epsilon, begin_norm_axis=3 + ) + + with self.assertRaises(ValueError): + paddle.incubate.nn.functional.fused_rms_norm( + x, weight, bias, self.epsilon, begin_norm_axis=-4 + ) + + +@unittest.skipIf( + not (core.is_compiled_with_cuda() or is_custom_device()) + and not 
paddle.is_compiled_with_rocm(), + "core is not compiled with CUDA or ROCM", +) +class TestRMSNormOp_ZeroSize(unittest.TestCase): + def setUp(self): + np.random.seed(20) + # 0-size + batch = 0 + cols = 256 + self.x_np = np.random.random([batch, cols]) + self.residual_np = np.random.random([batch, cols]) + self.bias_np = np.random.random([cols]) + + self.norm_weight_np = np.random.random([cols]) + self.norm_bias_np = np.random.random([cols]) + self.epsilon = 1e-6 + self.quant_scale = 0.15 + self.quant_round_type = 1 + self.quant_max_bound = 127 + self.quant_min_bound = -127 + + def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): + paddle.disable_static() + x = paddle.to_tensor(x_np.astype(dtype)) + gamma = paddle.to_tensor(gamma_np.astype(dtype)) + beta = paddle.to_tensor(beta_np.astype(dtype)) + + paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( + x, gamma, beta, self.epsilon, begin_norm_axis=1 + )[0] + paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) + paddle.enable_static() + return paddle_rmsnorm_out, paddle_naive_rmsnorm_out + + def test_rmsnorm_fp16(self): + if ( + not (paddle.is_compiled_with_cuda() or is_custom_device()) + and not paddle.is_compiled_with_rocm() + ): + return + paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( + self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' + ) + + np.testing.assert_allclose( + paddle_rmsnorm.numpy(), + paddle_naive_rmsnorm.numpy(), + rtol=1e-3, + atol=1e-3, + ) + + def test_rms_norm_backward(self): + def get_paddle_tensor(shape, dtype, bound=0.5): + tmp = paddle.uniform(shape, dtype=dtype, min=-bound, max=bound) + tmp.stop_gradient = False + return tmp + + def get_forward_backward(func, seed, dtype): + paddle.disable_static() + paddle.seed(seed) + # 0-size + x = get_paddle_tensor([0, 256], dtype) + scale = get_paddle_tensor([256], dtype) + out_g = paddle.randn([0, 256], dtype) + out = func(x, scale) + paddle.autograd.backward([out], [out_g], True) + return out, (x.grad, scale.grad) + + dtypes = [paddle.float32] + if paddle.amp.is_float16_supported(get_device()): + dtypes.append(paddle.float16) + for dtype in dtypes: + raw_out, raw_grads = get_forward_backward( + naive_rms_norm, seed=2024, dtype=dtype + ) + fused_out, fused_grads = get_forward_backward( + fused_rms_norm, seed=2024, dtype=dtype + ) + # forward rtol + rtol = 1e-5 if dtype == paddle.float32 else 1e-2 + np.testing.assert_allclose( + raw_out.astype(paddle.float32).numpy(), + fused_out.astype(paddle.float32).numpy(), + rtol=rtol, + ) + # backward rtol, only check float32 grad + rtol = 1e-3 + if dtype == paddle.float32: + raw_x_grad, raw_scale_grad = raw_grads + fused_x_grad, fused_scale_grad = fused_grads + np.testing.assert_allclose( + raw_x_grad.astype(paddle.float32).numpy(), + fused_x_grad.astype(paddle.float32).numpy(), + rtol=rtol, + ) + np.testing.assert_allclose( + raw_scale_grad.astype(paddle.float32).numpy(), + fused_scale_grad.astype(paddle.float32).numpy(), + rtol=rtol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_rms_norm.py b/test/legacy_test/test_rms_norm.py deleted file mode 100644 index 81b979bd768fcc..00000000000000 --- a/test/legacy_test/test_rms_norm.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np - -import paddle -from paddle.nn.functional import rms_norm - - -class TestRMSNorm(unittest.TestCase): - def setUp(self): - paddle.seed(2023) - np.random.seed(2023) - - def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): - variance = paddle.mean(paddle.square(x), axis=-1, keepdim=True) - rms = paddle.sqrt(variance + epsilon) - y = x / rms - y = y * scale.reshape([1, -1]) - if bias is not None: - y = y + bias.reshape([1, -1]) - - return y, paddle.flatten(1.0 / rms) - - def test_2d_input(self): - rows, cols = 32, 64 - x = paddle.randn([rows, cols]) - scale = paddle.randn([cols]) - - y_fused, invvar_fused = rms_norm(x, (cols,), scale) - - y_ref, invvar_ref = self.rms_norm_reference(x, scale) - - np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) - - def test_3d_input(self): - batch, rows, cols = 16, 32, 64 - x = paddle.randn([batch, rows, cols]) - scale = paddle.randn([cols]) - - y_fused, invvar_fused = rms_norm(x, (cols,), scale) - - y_ref, invvar_ref = self.rms_norm_reference(x, scale) - - np.testing.assert_allclose( - y_fused.astype("float32"), - y_ref.astype("float32"), - rtol=1e-5, - atol=1e-5, - ) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) - - def test_without_bias(self): - rows, cols = 32, 64 - x = paddle.randn([rows, cols]) - scale = paddle.randn([cols]) - - y_fused, invvar_fused = rms_norm(x, (cols,), scale) - - y_ref, invvar_ref = self.rms_norm_reference(x, scale) - - np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) - - def test_3d_backward(self): - batch, rows, cols = 8, 16, 32 - x = paddle.randn([batch, rows, cols], dtype='float32') - x.stop_gradient = False - scale = paddle.randn([cols], dtype='float32') - scale.stop_gradient = False - - y_fused, invvar = rms_norm(x, (cols,), scale) - - loss = paddle.mean(y_fused) - loss.backward() - - x_grad_fused = x.grad.clone() - scale_grad_fused = scale.grad.clone() - - x.clear_gradient() - scale.clear_gradient() - - y_ref, invvar_ref = self.rms_norm_reference(x, scale) - loss_ref = paddle.mean(y_ref) - loss_ref.backward() - - x_grad_ref = x.grad - scale_grad_ref = scale.grad - - np.testing.assert_allclose( - x_grad_fused, x_grad_ref, rtol=1e-4, atol=1e-4 - ) - np.testing.assert_allclose( - scale_grad_fused, scale_grad_ref, rtol=1e-4, atol=1e-4 - ) - - def test_backward(self): - rows, cols = 16, 32 - test_type = ['bfloat16', 'float32'] - for x_type in test_type: - for scale_type in test_type: - x = paddle.randn([rows, cols], dtype=x_type) - x.stop_gradient = False - scale = paddle.randn([cols], dtype=scale_type) - scale.stop_gradient = False - - y_fused, invvar = rms_norm(x, (cols,), scale) - - loss = paddle.mean(y_fused) - loss.backward() - - x_grad_fused = x.grad.clone() - scale_grad_fused = scale.grad.clone() - - x.clear_gradient() - scale.clear_gradient() - - y_ref, invvar_ref = self.rms_norm_reference(x, scale) - loss_ref = paddle.mean(y_ref) - 
loss_ref.backward() - - x_grad_ref = x.grad - scale_grad_ref = scale.grad - - np.testing.assert_allclose( - x_grad_fused.astype("float32"), - x_grad_ref.astype("float32"), - rtol=1e-4, - atol=1e-4, - ) - np.testing.assert_allclose( - scale_grad_fused.astype("float32"), - scale_grad_ref.astype("float32"), - rtol=1e-4, - atol=1e-4, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index ac94cce01f6ac6..43c5bbad049632 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -11,1000 +11,168 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest +from functools import reduce +from operator import mul import numpy as np -from op_test import get_device, get_device_place, is_custom_device +from op_test import OpTest import paddle -from paddle import base -from paddle.base import core - - -def quant_helper( - x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound -): - quant_value = quant_max_bound * quant_scale * x - if quant_round_type == 0: - quant_value = paddle.to_tensor(np.rint(quant_value.numpy())) - else: - quant_value = paddle.round(quant_value) - return paddle.cast( - paddle.clip(quant_value, quant_min_bound, quant_max_bound), - paddle.int8, - ) - - -def naive_residual_bias_add(x, residual, bias): - return x + residual + bias - - -def naive_rms_norm(x, gamma, beta=None, epsilon=1e-5): - variance = x.pow(2).mean(-1, keepdim=True) - out = paddle.rsqrt(variance + epsilon) * x - out = out * gamma - if beta is not None: - out = out + beta - return out - - -def fused_rms_norm(x, gamma, beta=None, epsilon=1e-5, begin_norm_axis=1): - out = paddle.incubate.nn.functional.fused_rms_norm( - x, gamma, beta, epsilon, begin_norm_axis=begin_norm_axis - ) - return out[0] - - -def naive_rms_norm_int8( - x, - gamma, - beta, - epsilon, - in_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, -): - out = naive_rms_norm(x, gamma, beta, epsilon) - out = quant_helper( - out, in_scale, quant_round_type, quant_max_bound, quant_min_bound - ) - return out - - -def naive_residual_biasadd_rms_norm(x, residual, bias, gamma, beta, epsilon): - x = x + residual + bias - variance = x.pow(2).mean(-1, keepdim=True) - out = paddle.rsqrt(variance + epsilon) * x - out = out * gamma + beta - return out - - -def naive_residual_biasadd_rms_norm_int8( - x, - residual, - bias, - gamma, - beta, - epsilon, - in_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, -): - out = naive_residual_biasadd_rms_norm( - x, residual, bias, gamma, beta, epsilon - ) - out = quant_helper( - out, in_scale, quant_round_type, quant_max_bound, quant_min_bound - ) - return out - - -@unittest.skipIf( - not (core.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm(), - "core is not compiled with CUDA or ROCM", -) -class TestRMSNormOp(unittest.TestCase): +from paddle.nn.functional import rms_norm + + +def rms_norm_reference(x, scale, bias=None, epsilon=1e-5): + x_shape = x.shape + begin_norm_axis = len(x.shape) - 1 + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1) + x.shape = [N, D] + + variance = np.mean(np.square(x), axis=-1) + rms = np.sqrt(variance + epsilon) + y = x / rms.reshape([N, 1]) + y = y * scale.reshape([1, -1]) + if bias is not None: + y = y 
+ bias.reshape([1, -1]) + + return y, 1.0 / rms + + +class TestRMSNormOp(OpTest): def setUp(self): - np.random.seed(20) - batch = 32 - cols = 256 - self.x_np = np.random.random([batch, cols]) - self.residual_np = np.random.random([batch, cols]) - self.bias_np = np.random.random([cols]) - - self.norm_weight_np = np.random.random([cols]) - self.norm_bias_np = np.random.random([cols]) - self.epsilon = 1e-6 - self.quant_scale = 0.15 - self.quant_round_type = 1 - self.quant_max_bound = 127 - self.quant_min_bound = -127 - - def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, gamma, beta, self.epsilon, begin_norm_axis=1 - )[0] - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - return paddle_rmsnorm_out, paddle_naive_rmsnorm_out - - def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, - gamma, - beta, - self.epsilon, - begin_norm_axis=1, - quant_scale=self.quant_scale, - quant_round_type=self.quant_round_type, - quant_max_bound=self.quant_max_bound, - quant_min_bound=self.quant_min_bound, - )[0] - - paddle_naive_rmsnorm_out = naive_rms_norm_int8( - x, - gamma, - beta, - self.epsilon, - self.quant_scale, - self.quant_round_type, - self.quant_max_bound, - self.quant_min_bound, - ) - paddle.enable_static() - return paddle_rmsnorm_out, paddle_naive_rmsnorm_out - - def check_residual_bias_rmsnorm( - self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype - ): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - residual = paddle.to_tensor(residual_np.astype(dtype)) - bias = paddle.to_tensor(bias_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, - gamma, - beta, - self.epsilon, - begin_norm_axis=1, - bias=bias, - residual=residual, - )[0] - - paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( - x, residual, bias, gamma, beta, self.epsilon - ) - paddle.enable_static() - return paddle_rmsnorm_out, paddle_naive_rmsnorm_out - - def check_residual_bias_rmsnorm_int8( - self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype - ): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - residual = paddle.to_tensor(residual_np.astype(dtype)) - bias = paddle.to_tensor(bias_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, - gamma, - beta, - self.epsilon, - begin_norm_axis=1, - bias=bias, - residual=residual, - quant_scale=self.quant_scale, - quant_round_type=self.quant_round_type, - quant_max_bound=self.quant_max_bound, - quant_min_bound=self.quant_min_bound, - )[0] - - paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm_int8( - x, - residual, - bias, - gamma, - beta, - self.epsilon, - self.quant_scale, - self.quant_round_type, - self.quant_max_bound, - self.quant_min_bound, - ) - paddle.enable_static() - return paddle_rmsnorm_out, 
paddle_naive_rmsnorm_out - - def test_rmsnorm_fp16(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' - ) + self.op_type = "rms_norm" + self.init_dtype() + self.init_config() - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) + np.random.seed(2023) + x = np.random.randn(*self.x_shape).astype(self.dtype) + scale = np.random.randn(self.x_shape[-1]).astype(self.dtype) + normalized_shape = [self.x_shape[-1]] - def test_rmsnorm_int8(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' - ) - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=2, - atol=2, - ) + self.inputs = {'x': x, 'scale': scale} + self.attrs = { + 'normalized_shape': normalized_shape, + 'epsilon': self.epsilon, + } + y_ref, invvar_ref = rms_norm_reference(x, scale, epsilon=self.epsilon) + self.outputs = {'y': y_ref, 'invvar': invvar_ref} - def test_residual_bias_add_rmsnorm_fp16(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( - self.x_np, - self.norm_weight_np, - self.norm_bias_np, - self.residual_np, - self.bias_np, - 'float16', - ) + def rms_norm_wrapper(x, scale): + return rms_norm(x, scale.shape, scale, eps=self.epsilon) - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) + self.python_api = rms_norm_wrapper - def test_residual_bias_add_rmsnorm_int8(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - ( - paddle_rmsnorm, - paddle_naive_rmsnorm, - ) = self.check_residual_bias_rmsnorm_int8( - self.x_np, - self.norm_weight_np, - self.norm_bias_np, - self.residual_np, - self.bias_np, - 'float16', - ) + def init_dtype(self): + self.dtype = np.float32 - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=2, - atol=2, - ) + def init_config(self): + self.epsilon = 1e-5 + self.x_shape = (32, 64) - def test_rms_norm_backward(self): - def get_paddle_tensor(shape, dtype, bound=0.5): - tmp = paddle.uniform(shape, dtype=dtype, min=-bound, max=bound) - tmp.stop_gradient = False - return tmp - - def get_forward_backward(func, seed, dtype): - paddle.disable_static() - paddle.seed(seed) - x = get_paddle_tensor([2, 256], dtype) - scale = get_paddle_tensor([256], dtype) - out_g = paddle.randn([2, 256], dtype) - out = func(x, scale) - paddle.autograd.backward([out], [out_g], True) - return out, (x.grad, scale.grad) - - dtypes = [paddle.float32] - if paddle.amp.is_bfloat16_supported(get_device()): - dtypes.append(paddle.bfloat16) - if paddle.amp.is_float16_supported(get_device()): - dtypes.append(paddle.float16) - for dtype in dtypes: - raw_out, raw_grads = get_forward_backward( - naive_rms_norm, seed=2024, dtype=dtype - ) - fused_out, fused_grads = get_forward_backward( - fused_rms_norm, seed=2024, dtype=dtype - ) - # forward rtol - rtol = 1e-5 if dtype == paddle.float32 else 1e-2 - 
np.testing.assert_allclose( - raw_out.astype(paddle.float32).numpy(), - fused_out.astype(paddle.float32).numpy(), - rtol=rtol, - ) - # backward rtol, only check float32 grad - rtol = 1e-3 - if dtype == paddle.float32: - raw_x_grad, raw_scale_grad = raw_grads - fused_x_grad, fused_scale_grad = fused_grads - np.testing.assert_allclose( - raw_x_grad.astype(paddle.float32).numpy(), - fused_x_grad.astype(paddle.float32).numpy(), - rtol=rtol, - ) - np.testing.assert_allclose( - raw_scale_grad.astype(paddle.float32).numpy(), - fused_scale_grad.astype(paddle.float32).numpy(), - rtol=rtol, - ) - - -@unittest.skipIf( - not (core.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm(), - "core is not compiled with CUDA or ROCM", -) -class TestRMSNormStaticOp(unittest.TestCase): - def setUp(self): - np.random.seed(20) - self.batch = 32 - self.cols = 256 - self.x_np = np.random.random([self.batch, 256]) - self.norm_weight_np = np.random.random([256]) - self.norm_bias_np = np.random.random([256]) - self.residual_np = np.random.random([self.batch, 256]) - self.bias_np = np.random.random([256]) - self.epsilon = 1e-6 - self.quant_scale = 0.15 - self.quant_round_type = 1 - self.quant_max_bound = 127 - self.quant_min_bound = -127 - self.place = get_device_place() - - def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - - with paddle.static.program_guard(paddle.static.Program()): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype=dtype - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype=dtype - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype=dtype - ) - outs = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - )[0] - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": x_np.astype(dtype), - "gamma_static": gamma_np.astype(dtype), - "beta_static": beta_np.astype(dtype), - }, - fetch_list=[outs], - ) - return out_s[0], paddle_naive_rmsnorm_out - - def check_rmsnorm_int8(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_naive_rmsnorm_out = naive_rms_norm_int8( - x, - gamma, - beta, - self.epsilon, - self.quant_scale, - self.quant_round_type, - self.quant_max_bound, - self.quant_min_bound, - ) - paddle.enable_static() - - with paddle.static.program_guard(paddle.static.Program()): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype=dtype - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype=dtype - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype=dtype - ) - outs = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - quant_scale=self.quant_scale, - quant_round_type=self.quant_round_type, - quant_max_bound=self.quant_max_bound, - quant_min_bound=self.quant_min_bound, - )[0] - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": 
x_np.astype(dtype), - "gamma_static": gamma_np.astype(dtype), - "beta_static": beta_np.astype(dtype), - }, - fetch_list=[outs], - ) - return out_s[0], paddle_naive_rmsnorm_out - - def check_residual_bias_rmsnorm( - self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype - ): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - residual = paddle.to_tensor(residual_np.astype(dtype)) - bias = paddle.to_tensor(bias_np.astype(dtype)) - - paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( - x, residual, bias, gamma, beta, self.epsilon - ) - paddle.enable_static() - - with paddle.static.program_guard(paddle.static.Program()): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype=dtype - ) - residual_static = paddle.static.data( - name="residual_static", - shape=[self.batch, self.cols], - dtype=dtype, - ) - bias_static = paddle.static.data( - name="bias_static", shape=[self.cols], dtype=dtype - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype=dtype - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype=dtype - ) - outs = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - bias=bias_static, - residual=residual_static, - )[0] - - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": x_np.astype(dtype), - "gamma_static": gamma_np.astype(dtype), - "beta_static": beta_np.astype(dtype), - "residual_static": residual_np.astype(dtype), - "bias_static": bias_np.astype(dtype), - }, - fetch_list=[outs], - ) - return out_s[0], paddle_naive_rmsnorm_out - - def test_rmsnorm_fp16(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' - ) + def test_check_output(self): + self.check_output(check_pir=True) - np.testing.assert_allclose( - paddle_rmsnorm, - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) + def test_check_grad(self): + self.check_grad(['x', 'scale'], ['y'], check_pir=True) - def test_residual_bias_add_rmsnorm_fp16(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( - self.x_np, - self.norm_weight_np, - self.norm_bias_np, - self.residual_np, - self.bias_np, - 'float16', - ) + @classmethod + def tearDownClass(cls): + # Avoid AssertionError: This test of rms_norm op needs check_grad with fp64 precision. 
+ pass - np.testing.assert_allclose( - paddle_rmsnorm, - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) - def test_rmsnorm_int8(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm_int8( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' - ) - np.testing.assert_allclose( - paddle_rmsnorm, - paddle_naive_rmsnorm.numpy(), - rtol=2, - atol=2, - ) +class TestRMSNormOp3D(TestRMSNormOp): + def init_config(self): + self.epsilon = 1e-5 + self.x_shape = (16, 32, 64) + def test_check_output(self): + self.check_output(check_pir=True) -@unittest.skipIf( - not core.supports_avx512f() or not core.is_compiled_with_avx(), - "machine is not support AVX or is not compiled with AVX", -) -class TestRMSNormOpCPU(unittest.TestCase): + +class TestRMSNormOpEpsilon(TestRMSNormOp): + def init_config(self): + self.epsilon = 1e-4 + self.x_shape = (32, 64) + + +class TestRMSNormAPI(unittest.TestCase): def setUp(self): - import os - - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - np.random.seed(20) - batch = 32 - cols = 256 - self.x_np = np.random.random([batch, cols]) - self.residual_np = np.random.random([batch, cols]) - self.bias_np = np.random.random([cols]) - - self.norm_weight_np = np.random.random([cols]) - self.norm_bias_np = np.random.random([cols]) - self.epsilon = 1e-6 - - def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, gamma, beta, self.epsilon, begin_norm_axis=1 - )[0] - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - return paddle_rmsnorm_out, paddle_naive_rmsnorm_out - - def check_residual_bias_rmsnorm( - self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype - ): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - residual = paddle.to_tensor(residual_np.astype(dtype)) - bias = paddle.to_tensor(bias_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, - gamma, - beta, - self.epsilon, - begin_norm_axis=1, - bias=bias, - residual=residual, - ) + paddle.seed(2023) + np.random.seed(2023) - paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( - x, residual, bias, gamma, beta, self.epsilon - ) + def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): + variance = paddle.mean(paddle.square(x), axis=-1, keepdim=True) + rms = paddle.sqrt(variance + epsilon) + y = x / rms + y = y * scale.reshape([1, -1]) + if bias is not None: + y = y + bias.reshape([1, -1]) - paddle_naive_residual_out = naive_residual_bias_add(x, residual, bias) - paddle.enable_static() - return ( - paddle_rmsnorm_out, - paddle_naive_rmsnorm_out, - paddle_naive_residual_out, - ) + return y, paddle.flatten(1.0 / rms) - def test_rmsnorm(self): - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float32' - ) - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) + def test_api_dygraph(self): + rows, cols = 32, 64 + x_np = np.random.randn(rows, cols).astype("float32") + scale_np = 
np.random.randn(cols).astype("float32") - def test_residual_bias_add_rmsnorm(self): - ( - paddle_rmsnorm, - paddle_naive_rmsnorm, - paddle_naive_residual_out, - ) = self.check_residual_bias_rmsnorm( - self.x_np, - self.norm_weight_np, - self.norm_bias_np, - self.residual_np, - self.bias_np, - 'float32', - ) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + scale = paddle.to_tensor(scale_np) + scale.stop_gradient = False + + # Test forward + y_fused, invvar_fused = rms_norm(x, (cols,), scale) + y_ref, invvar_ref = self.rms_norm_reference(x, scale) np.testing.assert_allclose( - paddle_rmsnorm[0].numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, + y_fused.numpy(), y_ref.numpy(), rtol=1e-5, atol=1e-5 ) np.testing.assert_allclose( - paddle_rmsnorm[1].numpy(), - paddle_naive_residual_out.numpy(), - rtol=1e-3, - atol=1e-3, + invvar_fused.numpy(), invvar_ref.numpy(), rtol=1e-5, atol=1e-5 ) + # Test backward + loss = paddle.mean(y_fused) + loss.backward() -@unittest.skipIf( - not core.supports_avx512f() or not core.is_compiled_with_avx(), - "machine is not support AVX or is not compiled with AVX", -) -class TestRMSNormStaticOpCPU(unittest.TestCase): - def setUp(self): - import os - - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - np.random.seed(20) - self.batch = 32 - self.cols = 256 - self.x_np = np.random.random([self.batch, 256]) - self.norm_weight_np = np.random.random([256]) - self.norm_bias_np = np.random.random([256]) - self.residual_np = np.random.random([self.batch, 256]) - self.bias_np = np.random.random([256]) - self.epsilon = 1e-6 - self.place = paddle.CPUPlace() - - def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - - with paddle.static.program_guard(paddle.static.Program()): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype=dtype - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype=dtype - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype=dtype - ) - outs = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - )[0] - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": x_np.astype(dtype), - "gamma_static": gamma_np.astype(dtype), - "beta_static": beta_np.astype(dtype), - }, - fetch_list=[outs], - ) - return out_s[0], paddle_naive_rmsnorm_out - - def check_residual_bias_rmsnorm( - self, x_np, gamma_np, beta_np, residual_np, bias_np, dtype - ): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - residual = paddle.to_tensor(residual_np.astype(dtype)) - bias = paddle.to_tensor(bias_np.astype(dtype)) - - paddle_naive_rmsnorm_out = naive_residual_biasadd_rms_norm( - x, residual, bias, gamma, beta, self.epsilon - ) - paddle.enable_static() - - with paddle.static.program_guard(paddle.static.Program()): - x_static = paddle.static.data( - name="x_static", shape=[self.batch, self.cols], dtype=dtype - ) - residual_static = paddle.static.data( - name="residual_static", - shape=[self.batch, self.cols], - dtype=dtype, - ) - bias_static = paddle.static.data( - name="bias_static", 
shape=[self.cols], dtype=dtype - ) - gamma_static = paddle.static.data( - name="gamma_static", shape=[self.cols], dtype=dtype - ) - beta_static = paddle.static.data( - name="beta_static", shape=[self.cols], dtype=dtype - ) - outs = paddle.incubate.nn.functional.fused_rms_norm( - x_static, - gamma_static, - beta_static, - self.epsilon, - begin_norm_axis=1, - bias=bias_static, - residual=residual_static, - )[0] - - exe = base.Executor(self.place) - out_s = exe.run( - feed={ - "x_static": x_np.astype(dtype), - "gamma_static": gamma_np.astype(dtype), - "beta_static": beta_np.astype(dtype), - "residual_static": residual_np.astype(dtype), - "bias_static": bias_np.astype(dtype), - }, - fetch_list=[outs], - ) - return out_s[0], paddle_naive_rmsnorm_out - - def test_rmsnorm(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float32' - ) + x_grad_fused = x.grad.numpy() + scale_grad_fused = scale.grad.numpy() - np.testing.assert_allclose( - paddle_rmsnorm, - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) + x.clear_gradient() + scale.clear_gradient() - def test_residual_bias_add_rmsnorm(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_residual_bias_rmsnorm( - self.x_np, - self.norm_weight_np, - self.norm_bias_np, - self.residual_np, - self.bias_np, - 'float32', - ) + y_ref, invvar_ref = self.rms_norm_reference(x, scale) + loss_ref = paddle.mean(y_ref) + loss_ref.backward() np.testing.assert_allclose( - paddle_rmsnorm, - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, + x_grad_fused, x.grad.numpy(), rtol=1e-5, atol=1e-5 ) - - -class TestRMSNormAxisEquivalence(unittest.TestCase): - def setUp(self): - np.random.seed(123) - paddle.seed(123) - - # x [batch_size, seq_len, hidden_size] - self.batch_size = 1 - self.seq_len = 8 - self.hidden_size = 64 - - self.x_np = np.random.random( - [self.batch_size, self.seq_len, self.hidden_size] - ).astype('float32') - self.weight_np = np.random.random([self.hidden_size]).astype('float32') - self.bias_np = np.random.random([self.hidden_size]).astype('float32') - self.epsilon = 1e-6 - - def test_positive_negative_axis_equivalence(self): - paddle.disable_static() - - x = paddle.to_tensor(self.x_np) - weight = paddle.to_tensor(self.weight_np) - bias = paddle.to_tensor(self.bias_np) - - # positive - out_positive = paddle.incubate.nn.functional.fused_rms_norm( - x, weight, bias, self.epsilon, begin_norm_axis=2 - )[0] - - # negative - out_negative = paddle.incubate.nn.functional.fused_rms_norm( - x, weight, bias, self.epsilon, begin_norm_axis=-1 - )[0] - - # test np.testing.assert_allclose( - out_positive.numpy(), - out_negative.numpy(), - rtol=1e-5, - atol=1e-5, + scale_grad_fused, scale.grad.numpy(), rtol=1e-5, atol=1e-5 ) - def test_out_of_range_axis(self): - paddle.disable_static() - x = paddle.to_tensor(self.x_np) - weight = paddle.to_tensor(self.weight_np) - bias = paddle.to_tensor(self.bias_np) +class TestRMSNormValueError(unittest.TestCase): + def test_normalized_shape_type_error(self): + x = paddle.randn([2, 3]) + with self.assertRaises(TypeError): + rms_norm(x, "invalid_shape") + def test_input_shape_mismatch(self): + x = paddle.randn([2, 3]) with self.assertRaises(ValueError): - paddle.incubate.nn.functional.fused_rms_norm( 
- x, weight, bias, self.epsilon, begin_norm_axis=3 - ) + rms_norm(x, [4]) + def test_weight_shape_mismatch(self): + x = paddle.randn([2, 3]) + weight = paddle.randn([4]) with self.assertRaises(ValueError): - paddle.incubate.nn.functional.fused_rms_norm( - x, weight, bias, self.epsilon, begin_norm_axis=-4 - ) - + rms_norm(x, [3], weight=weight) -@unittest.skipIf( - not (core.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm(), - "core is not compiled with CUDA or ROCM", -) -class TestRMSNormOp_ZeroSize(unittest.TestCase): - def setUp(self): - np.random.seed(20) - # 0-size - batch = 0 - cols = 256 - self.x_np = np.random.random([batch, cols]) - self.residual_np = np.random.random([batch, cols]) - self.bias_np = np.random.random([cols]) - - self.norm_weight_np = np.random.random([cols]) - self.norm_bias_np = np.random.random([cols]) - self.epsilon = 1e-6 - self.quant_scale = 0.15 - self.quant_round_type = 1 - self.quant_max_bound = 127 - self.quant_min_bound = -127 - - def check_rmsnorm(self, x_np, gamma_np, beta_np, dtype): - paddle.disable_static() - x = paddle.to_tensor(x_np.astype(dtype)) - gamma = paddle.to_tensor(gamma_np.astype(dtype)) - beta = paddle.to_tensor(beta_np.astype(dtype)) - - paddle_rmsnorm_out = paddle.incubate.nn.functional.fused_rms_norm( - x, gamma, beta, self.epsilon, begin_norm_axis=1 - )[0] - paddle_naive_rmsnorm_out = naive_rms_norm(x, gamma, beta, self.epsilon) - paddle.enable_static() - return paddle_rmsnorm_out, paddle_naive_rmsnorm_out - - def test_rmsnorm_fp16(self): - if ( - not (paddle.is_compiled_with_cuda() or is_custom_device()) - and not paddle.is_compiled_with_rocm() - ): - return - paddle_rmsnorm, paddle_naive_rmsnorm = self.check_rmsnorm( - self.x_np, self.norm_weight_np, self.norm_bias_np, 'float16' - ) - - np.testing.assert_allclose( - paddle_rmsnorm.numpy(), - paddle_naive_rmsnorm.numpy(), - rtol=1e-3, - atol=1e-3, - ) - def test_rms_norm_backward(self): - def get_paddle_tensor(shape, dtype, bound=0.5): - tmp = paddle.uniform(shape, dtype=dtype, min=-bound, max=bound) - tmp.stop_gradient = False - return tmp - - def get_forward_backward(func, seed, dtype): - paddle.disable_static() - paddle.seed(seed) - # 0-size - x = get_paddle_tensor([0, 256], dtype) - scale = get_paddle_tensor([256], dtype) - out_g = paddle.randn([0, 256], dtype) - out = func(x, scale) - paddle.autograd.backward([out], [out_g], True) - return out, (x.grad, scale.grad) - - dtypes = [paddle.float32] - if paddle.amp.is_float16_supported(get_device()): - dtypes.append(paddle.float16) - for dtype in dtypes: - raw_out, raw_grads = get_forward_backward( - naive_rms_norm, seed=2024, dtype=dtype - ) - fused_out, fused_grads = get_forward_backward( - fused_rms_norm, seed=2024, dtype=dtype - ) - # forward rtol - rtol = 1e-5 if dtype == paddle.float32 else 1e-2 - np.testing.assert_allclose( - raw_out.astype(paddle.float32).numpy(), - fused_out.astype(paddle.float32).numpy(), - rtol=rtol, - ) - # backward rtol, only check float32 grad - rtol = 1e-3 - if dtype == paddle.float32: - raw_x_grad, raw_scale_grad = raw_grads - fused_x_grad, fused_scale_grad = fused_grads - np.testing.assert_allclose( - raw_x_grad.astype(paddle.float32).numpy(), - fused_x_grad.astype(paddle.float32).numpy(), - rtol=rtol, - ) - np.testing.assert_allclose( - raw_scale_grad.astype(paddle.float32).numpy(), - fused_scale_grad.astype(paddle.float32).numpy(), - rtol=rtol, - ) - - -if __name__ == "__main__": +if __name__ == '__main__': unittest.main()
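Not part of the patch: a minimal standalone sketch of the `paddle.nn.functional.rms_norm` call pattern that the rewritten `test_rms_norm_op.py` exercises, checked against the same NumPy reference the new tests use. The two-output `(y, invvar)` form, the `normalized_shape` positional argument, and the `eps` keyword are taken from the tests above; the exact public signature in a given Paddle build is an assumption of this sketch.

import numpy as np
import paddle
from paddle.nn.functional import rms_norm

paddle.seed(2023)
rows, cols, eps = 32, 64, 1e-5
x = paddle.randn([rows, cols], dtype="float32")
scale = paddle.randn([cols], dtype="float32")

# NumPy reference mirroring rms_norm_reference() in the patch:
# y = x / sqrt(mean(x^2, -1) + eps) * scale, invvar = 1 / rms
x_np, scale_np = x.numpy(), scale.numpy()
rms = np.sqrt(np.mean(np.square(x_np), axis=-1, keepdims=True) + eps)
y_ref = x_np / rms * scale_np.reshape([1, -1])
invvar_ref = (1.0 / rms).flatten()

# Two-output form used by the new tests: normalized output plus inverse RMS.
y, invvar = rms_norm(x, (cols,), scale, eps=eps)

np.testing.assert_allclose(y.numpy(), y_ref, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(invvar.numpy(), invvar_ref, rtol=1e-5, atol=1e-5)

This is the same forward check `TestRMSNormAPI.test_api_dygraph` performs in the patch, pulled out as a self-contained snippet for readers who want to reproduce it outside the unittest harness.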