Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2130,6 +2130,7 @@ void Fp8QuantBlockwiseInferMeta(const MetaTensor& X,
bool return_transpose_only,
bool using_e5m2,
bool using_pow2_scale,
bool using_ue8m0_scale,
MetaTensor* out,
MetaTensor* scale,
MetaTensor* out_transposed,
Expand Down Expand Up @@ -2222,6 +2223,69 @@ void Fp8QuantBlockwiseInferMeta(const MetaTensor& X,
scale_transposed_inner_dim = scale_outer_dim;
}

if (using_ue8m0_scale) {
    // 1. The using_ue8m0_scale parameter enables UE8M0 quantization scales.
    // 2. When using_ue8m0_scale is set, the FP32 scales that would normally
    //    be produced are converted to E8M0 and packed four per int32, and the
    //    scale tensor is emitted as int32 — so the packed dimension of the
    //    output shape shrinks to 1/4 of the original.
    // 3. Special case for 128x128 quantization: the original scale shape is
    //    [M/128, N/128], but with using_ue8m0_scale it becomes [M, N/128/4].
    //    Each group of 128 rows explicitly replicates its shared scale, which
    //    is why M is not divided by 128.
// Upgrade 128x128 to expanded shape
if (!using_1x128_vec_quant) {
if (output_scale_transpose) {
scale_inner_dim = rows;
scale_transposed_inner_dim = cols;
} else {
scale_outer_dim = rows;
scale_transposed_outer_dim = cols;
}
}

if (output_scale_transpose) {
PADDLE_ENFORCE_EQ(
scale_outer_dim % 4,
0,
common::errors::InvalidArgument(
"When use_ue8m0 is true, the outer dimension of scale "
"must be divisible by 4, but got %d",
scale_outer_dim));
scale_outer_dim /= 4;
if (input_transpose) {
PADDLE_ENFORCE_EQ(scale_transposed_outer_dim % 4,
0,
common::errors::InvalidArgument(
"When use_ue8m0 is true, the outer dimension of "
"transposed scale "
"must be divisible by 4, but got %d",
scale_transposed_outer_dim));
scale_transposed_outer_dim /= 4;
}

} else {
PADDLE_ENFORCE_EQ(
scale_inner_dim % 4,
0,
common::errors::InvalidArgument(
"When use_ue8m0 is true, the inner dimension of scale "
"must be divisible by 4, but got %d",
scale_inner_dim));
scale_inner_dim /= 4;
if (input_transpose) {
PADDLE_ENFORCE_EQ(scale_transposed_inner_dim % 4,
0,
common::errors::InvalidArgument(
"When use_ue8m0 is true, the inner dimension of "
"transposed scale "
"must be divisible by 4, but got %d",
scale_transposed_inner_dim));
scale_transposed_inner_dim /= 4;
}
}
}

PADDLE_ENFORCE_GT(output_outer_dim,
0,
common::errors::InvalidArgument(
Expand All @@ -2244,19 +2308,20 @@ void Fp8QuantBlockwiseInferMeta(const MetaTensor& X,
out->set_dims(common::make_ddim({output_outer_dim, output_inner_dim}));
out->set_dtype(DataType::FLOAT8_E4M3FN);
scale->set_dims(common::make_ddim({scale_outer_dim, scale_inner_dim}));
scale->set_dtype(DataType::FLOAT32);
scale->set_dtype(using_ue8m0_scale ? DataType::INT32 : DataType::FLOAT32);
} else {
out->set_dims(common::make_ddim({0}));
out->set_dtype(DataType::FLOAT8_E4M3FN);
scale->set_dims(common::make_ddim({0}));
scale->set_dtype(DataType::FLOAT32);
scale->set_dtype(using_ue8m0_scale ? DataType::INT32 : DataType::FLOAT32);
}
if (input_transpose) {
out_transposed->set_dims(make_ddim({output_inner_dim, output_outer_dim}));
out_transposed->set_dtype(DataType::FLOAT8_E4M3FN);
scale_transposed->set_dims(common::make_ddim(
{scale_transposed_outer_dim, scale_transposed_inner_dim}));
scale_transposed->set_dtype(DataType::FLOAT32);
scale_transposed->set_dtype(using_ue8m0_scale ? DataType::INT32
: DataType::FLOAT32);
}
} else {
PADDLE_THROW(
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ PADDLE_API void Fp8QuantBlockwiseInferMeta(const MetaTensor& X,
bool return_transpose_only,
bool using_e5m2,
bool using_pow2_scale,
bool using_ue8m0_scale,
MetaTensor* out,
MetaTensor* scale,
MetaTensor* out_transposed,
Expand Down
Loading
Loading