From a82f5fe319b32fc87f6facc31600e0fe7ccf1a61 Mon Sep 17 00:00:00 2001
From: liujinnan <1823192871@qq.com>
Date: Wed, 31 Dec 2025 13:13:27 +0800
Subject: [PATCH 1/2] add stack_quant e8m0

---
 paddle/phi/infermeta/fusion.cc                |  40 ++-
 paddle/phi/infermeta/fusion.h                 |   6 +
 .../gpu/fused_stack_transpose_quant_kernel.cu | 253 ++++++++++---
 .../gpu/fused_stack_transpose_quant_kernel.h  |   6 +
 paddle/phi/ops/yaml/fused_ops.yaml            |   4 +-
 python/paddle/incubate/nn/functional/fp8.py   |  20 +-
 test_fused_stack_transpose_quant.py           | 338 ++++++++++++++++++
 7 files changed, 618 insertions(+), 49 deletions(-)
 create mode 100644 test_fused_stack_transpose_quant.py

diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 0fed71afe8e5c3..606d424551aca4 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -4287,17 +4287,33 @@ std::tuple<int64_t, int64_t, int64_t> FusedStackQuantCommonCheck(
 }
 
 void FusedStackTransposeQuantInferMeta(const std::vector<const MetaTensor*>& x,
+                                       bool using_pow2_scaling,
+                                       bool using_ue8m0_scale,
+                                       bool output_scale_transpose,
                                        MetaTensor* out,
                                        MetaTensor* scale) {
   int64_t N, M, K;
   std::tie(N, M, K) = FusedStackQuantCommonCheck(x);
 
   std::vector<int64_t> out_shape = {N * K, M};
-  std::vector<int64_t> scale_shape = {N * K / 128, M / 128};
+  std::vector<int64_t> scale_shape;
+  if (using_ue8m0_scale) {
+    if (output_scale_transpose) {
+      scale_shape = {M / 128 / 4, N * K};
+    } else {
+      scale_shape = {N * K, M / 128 / 4};
+    }
+  } else {
+    if (output_scale_transpose) {
+      scale_shape = {M / 128, N * K / 128};
+    } else {
+      scale_shape = {N * K / 128, M / 128};
+    }
+  }
 
   out->set_dims(common::make_ddim(out_shape));
   scale->set_dims(common::make_ddim(scale_shape));
   out->set_dtype(DataType::FLOAT8_E4M3FN);
-  scale->set_dtype(DataType::FLOAT32);
+  scale->set_dtype(using_ue8m0_scale ? DataType::INT32 : DataType::FLOAT32);
   out->share_lod(*x.at(0));
   scale->share_lod(*x.at(0));
   out->set_layout(x.at(0)->layout());
@@ -4305,17 +4321,33 @@ void FusedStackTransposeQuantInferMeta(const std::vector<const MetaTensor*>& x,
 }
 
 void FusedStackQuantInferMeta(const std::vector<const MetaTensor*>& x,
+                              bool using_pow2_scaling,
+                              bool using_ue8m0_scale,
+                              bool output_scale_transpose,
                               MetaTensor* out,
                               MetaTensor* scale) {
   int64_t N, M, K;
   std::tie(N, M, K) = FusedStackQuantCommonCheck(x);
 
   std::vector<int64_t> out_shape = {N * M, K};
-  std::vector<int64_t> scale_shape = {N * M / 128, K / 128};
+  std::vector<int64_t> scale_shape;
+  if (using_ue8m0_scale) {
+    if (output_scale_transpose) {
+      scale_shape = {K / 128 / 4, N * M};
+    } else {
+      scale_shape = {N * M, K / 128 / 4};
+    }
+  } else {
+    if (output_scale_transpose) {
+      scale_shape = {K / 128, N * M / 128};
+    } else {
+      scale_shape = {N * M / 128, K / 128};
+    }
+  }
 
   out->set_dims(common::make_ddim(out_shape));
   scale->set_dims(common::make_ddim(scale_shape));
   out->set_dtype(DataType::FLOAT8_E4M3FN);
-  scale->set_dtype(DataType::FLOAT32);
+  scale->set_dtype(using_ue8m0_scale ? DataType::INT32 : DataType::FLOAT32);
   out->share_lod(*x.at(0));
   scale->share_lod(*x.at(0));
   out->set_layout(x.at(0)->layout());
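In the UE8M0 branch the scale metadata changes granularity as well as dtype: the FP32 path keeps one float32 scale per 128x128 block, while the UE8M0 path keeps one 8-bit exponent per output row per 128-column block, with four neighbouring column-block exponents packed into a single int32 (hence the extra / 4). A minimal sketch of the resulting shapes, mirroring the InferMeta logic above (illustration only, not part of the patch):

    def expected_scale_meta(n, m, k, transpose, using_ue8m0_scale, output_scale_transpose):
        # Quantized output is [N*K, M] when transposed, [N*M, K] otherwise.
        rows, cols = (n * k, m) if transpose else (n * m, k)
        if using_ue8m0_scale:
            # One exponent byte per row and 128-column block, 4 bytes per int32.
            blocks = cols // 128 // 4
            return ([blocks, rows] if output_scale_transpose else [rows, blocks]), "int32"
        # One float32 scale per 128x128 block.
        if output_scale_transpose:
            return [cols // 128, rows // 128], "float32"
        return [rows // 128, cols // 128], "float32"

    # e.g. 3 stacked [4096, 7168] inputs, transposed, packed UE8M0 scales:
    assert expected_scale_meta(3, 4096, 7168, True, True, False) == ([3 * 7168, 8], "int32")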
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index 8b954c89433aab..adbaf5cb27d665 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -943,11 +943,17 @@ PADDLE_API void FusionSeqExpandConcatFCInferMeta(
 
 PADDLE_API void FusedStackTransposeQuantInferMeta(
     const std::vector<const MetaTensor*>& x,
+    bool using_pow2_scaling,
+    bool using_ue8m0_scale,
+    bool output_scale_transpose,
     MetaTensor* out,
     MetaTensor* scale);
 
 PADDLE_API void FusedStackQuantInferMeta(
     const std::vector<const MetaTensor*>& x,
+    bool using_pow2_scaling,
+    bool using_ue8m0_scale,
+    bool output_scale_transpose,
     MetaTensor* out,
     MetaTensor* scale);
 
diff --git a/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.cu
index 75f0ab872c99e9..62b46bb71a48e9 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.cu
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/fast_divmod.h"
@@ -26,6 +27,16 @@ namespace fusion {
 
 using FastDivMod = funcs::FastDivMod;
 
+template <bool using_ue8m0_scale, typename ScaleT>
+__device__ __forceinline__ void StoreScale(ScaleT* ptr, size_t idx, float val) {
+  if constexpr (using_ue8m0_scale) {
+    // UE8M0: keep only the 8-bit biased exponent of the (power-of-two) scale.
+    int exp = (__float_as_int(val) >> 23) & 0xFF;
+    reinterpret_cast<uint8_t*>(ptr)[idx] = static_cast<uint8_t>(exp);
+  } else {
+    ptr[idx] = val;
+  }
+}
+
 template <typename ArrayT>
 __device__ void BlockLoad(ArrayT input_array,
                           __nv_bfloat16 x[8][4],
@@ -58,8 +69,8 @@ __device__ __nv_bfloat16 WarpReduceMax(__nv_bfloat16 x) {
   return x;
 }
 
-template <typename OutT>
-__device__ float BlockReduceScale(__nv_bfloat16 x[8][4]) {
+template <typename OutT, bool Power2Scaling>
+__device__ float BlockReduceScale(__nv_bfloat16 x[8][4], float eps = 1e-10f) {
   // [(8), 16, 32, (4)] => [16, 32]
   __nv_bfloat16 local_max;
   for (uint32_t i = 0; i < 8; i++) {
@@ -82,7 +93,8 @@ __device__ float BlockReduceScale(__nv_bfloat16 x[8][4]) {
   if (threadIdx.y == 0 && threadIdx.x < 16) {
     warp_max = WarpReduceMax<16>(block_max[threadIdx.x]);
     if (threadIdx.x == 0) {
-      block_scale = ComputeScale<__nv_bfloat16, OutT>(warp_max, 0.0f);
+      block_scale =
+          ComputeScale<__nv_bfloat16, OutT, Power2Scaling>(warp_max, eps);
     }
   }
   __syncthreads();
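StoreScale writes the UE8M0 code by taking the biased exponent field of the float32 reciprocal scale, which is lossless because the power-of-two path (enabled whenever using_ue8m0_scale is set, via using_pow2_scaling || using_ue8m0_scale below) rounds every scale to a power of two. A small host-side sketch of the same encoding, for illustration only (not part of the patch):

    import struct

    def ue8m0_encode(scale: float) -> int:
        # Biased exponent field of an IEEE-754 float32; exact for powers of two.
        bits = struct.unpack("<I", struct.pack("<f", scale))[0]
        return (bits >> 23) & 0xFF

    def ue8m0_decode(code: int) -> float:
        return 2.0 ** (code - 127)

    assert ue8m0_encode(1.0) == 127
    assert ue8m0_decode(ue8m0_encode(0.25)) == 0.25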
@@ -90,11 +102,16 @@
   return block_scale;
 }
 
-template <typename OutT, typename ArrayT>
+template <typename OutT,
+          typename ArrayT,
+          typename ScaleT,
+          bool using_pow2_scaling,
+          bool using_ue8m0_scale,
+          bool output_scale_transpose>
 __global__ void __launch_bounds__(512)
     FusedStackQuantGPUKernel(ArrayT input_array,
                              OutT* __restrict__ out,
-                             float* __restrict__ scale,
+                             ScaleT* __restrict__ scale,
                              size_t M,
                              size_t K,
                              FastDivMod K_div_128) {
@@ -106,15 +123,44 @@
   BlockLoad(input_array, x, K, block_y, block_x);
 
   // Find the scale of all elements
-  float block_scale = BlockReduceScale<OutT>(x);
+  float block_scale =
+      BlockReduceScale<OutT, using_pow2_scaling || using_ue8m0_scale>(x);
 
   // Compute scale and store back
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    size_t idx_n = blockIdx.z;
-    size_t idx_m = block_y;
-    size_t idx_k = block_x;
-    size_t idx = (idx_n * (M / 128) + idx_m) * (K / 128) + idx_k;
-    scale[idx] = __frcp_rn(block_scale);
+  // For FusedStackQuant the logical layout is rows = N * M, cols = K:
+  // block_y -> row block in M/128, block_x -> col block in K/128,
+  // blockIdx.z -> stack index n.
+  int tid = threadIdx.y * 32 + threadIdx.x;
+  if constexpr (using_ue8m0_scale) {
+    if (tid < 128) {
+      size_t r = tid;
+      size_t global_row =
+          (static_cast<size_t>(blockIdx.z) * (M / 128) + block_y) * 128 + r;
+      size_t idx;
+      if constexpr (output_scale_transpose) {
+        // Packed [K/128/4, N*M] int32, i.e. byte layout [K/128, N*M] with
+        // 4 col-block exponents interleaved per int32:
+        // idx = block_x * (static_cast<size_t>(gridDim.z) * M) + global_row;
+        size_t total_cols = static_cast<size_t>(gridDim.z) * M;
+        idx = (block_x / 4) * (total_cols * 4) + global_row * 4 + (block_x % 4);
+      } else {
+        // [N*M, K/128] bytes, i.e. packed [N*M, K/128/4] int32
+        idx = global_row * (K / 128) + block_x;
+      }
+      StoreScale<using_ue8m0_scale>(scale, idx, __frcp_rn(block_scale));
+    }
+  } else {
+    if (tid == 0) {
+      size_t idx;
+      if constexpr (output_scale_transpose) {
+        // [K/128, N*M/128]
+        idx = block_x * (static_cast<size_t>(gridDim.z) * (M / 128)) +
+              (blockIdx.z * (M / 128) + block_y);
+      } else {
+        // [N*M/128, K/128]
+        idx = (blockIdx.z * (M / 128) + block_y) * (K / 128) + block_x;
+      }
+      StoreScale<using_ue8m0_scale>(scale, idx, __frcp_rn(block_scale));
+    }
   }
 
   // Scale X and store to out
@@ -135,11 +181,16 @@
   }
 }
 
-template <typename OutT, typename ArrayT>
+template <typename OutT,
+          typename ArrayT,
+          typename ScaleT,
+          bool using_pow2_scaling,
+          bool using_ue8m0_scale,
+          bool output_scale_transpose>
 __global__ void __launch_bounds__(512)
     FusedStackTransposeQuantGPUKernel(ArrayT input_array,
                                       OutT* __restrict__ out,
-                                      float* __restrict__ scale,
+                                      ScaleT* __restrict__ scale,
                                       size_t M,
                                       size_t K,
                                       FastDivMod K_div_128) {
@@ -151,15 +202,44 @@
   BlockLoad(input_array, x, K, block_y, block_x);
 
   // Find the scale of all elements
-  float block_scale = BlockReduceScale<OutT>(x);
+  float block_scale =
+      BlockReduceScale<OutT, using_pow2_scaling || using_ue8m0_scale>(x);
 
   // Compute scale and store back
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    size_t idx_n = blockIdx.z;
-    size_t idx_k = block_x;
-    size_t idx_m = block_y;
-    size_t idx = (idx_n * (K / 128) + idx_k) * (M / 128) + idx_m;
-    scale[idx] = __frcp_rn(block_scale);
+  // For FusedStackTransposeQuant the logical layout is rows = N * K, cols = M:
+  // block_y -> col block in M/128, block_x -> row block in K/128,
+  // blockIdx.z -> stack index n.
+  int tid = threadIdx.y * 32 + threadIdx.x;
+  if constexpr (using_ue8m0_scale) {
+    if (tid < 128) {
+      size_t r = tid;
+      size_t global_row =
+          (static_cast<size_t>(blockIdx.z) * (K / 128) + block_x) * 128 + r;
+      size_t idx;
+      if constexpr (output_scale_transpose) {
+        // Packed [M/128/4, N*K] int32, i.e. byte layout [M/128, N*K] with
+        // 4 col-block exponents interleaved per int32:
+        // idx = block_y * (static_cast<size_t>(gridDim.z) * K) + global_row;
+        size_t total_rows = static_cast<size_t>(gridDim.z) * K;
+        idx = (block_y / 4) * (total_rows * 4) + global_row * 4 + (block_y % 4);
+      } else {
+        // [N*K, M/128] bytes, i.e. packed [N*K, M/128/4] int32
+        idx = global_row * (M / 128) + block_y;
+      }
+      StoreScale<using_ue8m0_scale>(scale, idx, __frcp_rn(block_scale));
+    }
+  } else {
+    if (tid == 0) {
+      size_t idx;
+      if constexpr (output_scale_transpose) {
+        // [M/128, N*K/128]
+        idx = block_y * (static_cast<size_t>(gridDim.z) * (K / 128)) +
+              (blockIdx.z * (K / 128) + block_x);
+      } else {
+        // [N*K/128, M/128]
+        idx = (blockIdx.z * (K / 128) + block_x) * (M / 128) + block_y;
+      }
+      StoreScale<using_ue8m0_scale>(scale, idx, __frcp_rn(block_scale));
+    }
   }
 
   // Scale X and transpose in shared memory
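Both kernels use the same packed indexing when the scale is emitted in the transposed UE8M0 layout: the byte index addresses an int32 tensor of shape [num_col_blocks / 4, total_rows], and the low two bits of the column-block index select the byte lane inside the int32. The arithmetic can be checked on the host with a short sketch (illustration only, not part of the patch):

    def packed_scale_byte_index(col_block: int, global_row: int, total_rows: int) -> int:
        # Byte offset of (col_block, global_row) inside an int32 tensor of shape
        # [num_col_blocks // 4, total_rows]: element (col_block // 4, global_row),
        # byte lane col_block % 4.
        return (col_block // 4) * (total_rows * 4) + global_row * 4 + (col_block % 4)

    def unpack(idx: int, total_rows: int):
        elem, lane = divmod(idx, 4)
        col_block_group, row = divmod(elem, total_rows)
        return col_block_group, row, lane

    assert unpack(packed_scale_byte_index(6, 10, 256), 256) == (1, 10, 2)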
@@ -194,6 +274,9 @@
 template <typename Context>
 void FusedStackTransposeQuantImpl(const Context& dev_ctx,
                                   const std::vector<const DenseTensor*>& x,
                                   bool transpose,
+                                  bool using_pow2_scaling,
+                                  bool using_ue8m0_scale,
+                                  bool output_scale_transpose,
                                   DenseTensor* out,
                                   DenseTensor* scale) {
   int N = static_cast<int>(x.size());
@@ -211,39 +294,129 @@ void FusedStackTransposeQuantImpl(const Context& dev_ctx,
   dim3 grid((M / 128) * (K / 128), 1, N);
   dim3 block(32, 16);
   auto* out_data = dev_ctx.template Alloc<phi::float8_e4m3fn>(out);
-  auto* scale_data = dev_ctx.template Alloc<float>(scale);
-
-  FastDivMod K_div_128(K / 128);
-  switch (funcs::CalcArraySize(N)) {
-    SEGMENTED_ARRAY_KERNEL_HELPER({
-      funcs::ConstPointerArraySetter<Context, phi::bfloat16> setter(dev_ctx, x);
-      if (transpose) {
-        FusedStackTransposeQuantGPUKernel<phi::float8_e4m3fn,
-                                          decltype(setter.array)>
-            <<<grid, block, 0, dev_ctx.stream()>>>(
-                setter.array, out_data, scale_data, M, K, K_div_128);
-      } else {
-        FusedStackQuantGPUKernel<phi::float8_e4m3fn, decltype(setter.array)>
-            <<<grid, block, 0, dev_ctx.stream()>>>(
-                setter.array, out_data, scale_data, M, K, K_div_128);
-      }
-    });
-  }
+  if (using_ue8m0_scale) {
+    dev_ctx.template Alloc<int>(scale);
+  } else {
+    dev_ctx.template Alloc<float>(scale);
+  }
+
+  FastDivMod K_div_128(K / 128);
+
+  DISPATCH_BOOL(
+      using_pow2_scaling,
+      k_using_pow2_scaling,
+      DISPATCH_BOOL(
+          using_ue8m0_scale,
+          k_using_ue8m0_scale,
+          DISPATCH_BOOL(
+              output_scale_transpose,
+              k_output_scale_transpose,
+              switch (funcs::CalcArraySize(N)) {
+                SEGMENTED_ARRAY_KERNEL_HELPER({
+                  funcs::ConstPointerArraySetter<Context, phi::bfloat16> setter(
+                      dev_ctx, x);
+                  if (transpose) {
+                    if (k_using_ue8m0_scale) {
+                      FusedStackTransposeQuantGPUKernel<
+                          phi::float8_e4m3fn,
+                          decltype(setter.array),
+                          int,
+                          k_using_pow2_scaling,
+                          k_using_ue8m0_scale,
+                          k_output_scale_transpose>
+                          <<<grid, block, 0, dev_ctx.stream()>>>(
+                              setter.array,
+                              out_data,
+                              reinterpret_cast<int*>(scale->data()),
+                              M,
+                              K,
+                              K_div_128);
+                    } else {
+                      FusedStackTransposeQuantGPUKernel<
+                          phi::float8_e4m3fn,
+                          decltype(setter.array),
+                          float,
+                          k_using_pow2_scaling,
+                          k_using_ue8m0_scale,
+                          k_output_scale_transpose>
+                          <<<grid, block, 0, dev_ctx.stream()>>>(
+                              setter.array,
+                              out_data,
+                              reinterpret_cast<float*>(scale->data()),
+                              M,
+                              K,
+                              K_div_128);
+                    }
+                  } else {
+                    if (k_using_ue8m0_scale) {
+                      FusedStackQuantGPUKernel<phi::float8_e4m3fn,
+                                               decltype(setter.array),
+                                               int,
+                                               k_using_pow2_scaling,
+                                               k_using_ue8m0_scale,
+                                               k_output_scale_transpose>
+                          <<<grid, block, 0, dev_ctx.stream()>>>(
+                              setter.array,
+                              out_data,
+                              reinterpret_cast<int*>(scale->data()),
+                              M,
+                              K,
+                              K_div_128);
+                    } else {
+                      FusedStackQuantGPUKernel<phi::float8_e4m3fn,
+                                               decltype(setter.array),
+                                               float,
+                                               k_using_pow2_scaling,
+                                               k_using_ue8m0_scale,
+                                               k_output_scale_transpose>
+                          <<<grid, block, 0, dev_ctx.stream()>>>(
+                              setter.array,
+                              out_data,
+                              reinterpret_cast<float*>(scale->data()),
+                              M,
+                              K,
+                              K_div_128);
+                    }
+                  }
+                });
+              })));
 }
 
 template <typename T, typename Context>
 void FusedStackQuantKernel(const Context& dev_ctx,
                            const std::vector<const DenseTensor*>& x,
+                           bool using_pow2_scaling,
+                           bool using_ue8m0_scale,
+                           bool output_scale_transpose,
                            DenseTensor* out,
                            DenseTensor* scale) {
-  FusedStackTransposeQuantImpl(dev_ctx, x, false, out, scale);
+  FusedStackTransposeQuantImpl(dev_ctx,
+                               x,
+                               false,
+                               using_pow2_scaling,
+                               using_ue8m0_scale,
+                               output_scale_transpose,
+                               out,
+                               scale);
 }
 
 template <typename T, typename Context>
 void FusedStackTransposeQuantKernel(const Context& dev_ctx,
                                     const std::vector<const DenseTensor*>& x,
+                                    bool using_pow2_scaling,
+                                    bool using_ue8m0_scale,
+                                    bool output_scale_transpose,
                                     DenseTensor* out,
                                     DenseTensor* scale) {
-  FusedStackTransposeQuantImpl(dev_ctx, x, true, out, scale);
+  FusedStackTransposeQuantImpl(dev_ctx,
+                               x,
+                               true,
+                               using_pow2_scaling,
+                               using_ue8m0_scale,
+                               output_scale_transpose,
+                               out,
+                               scale);
 }
 
 }  // namespace fusion
@@ -255,7 +428,7 @@ PD_REGISTER_KERNEL(fused_stack_quant,
                    phi::fusion::FusedStackQuantKernel,
                    phi::bfloat16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
-  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  // kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
 }
 
 PD_REGISTER_KERNEL(fused_stack_transpose_quant,
@@ -264,5 +437,5 @@ PD_REGISTER_KERNEL(fused_stack_transpose_quant,
                    phi::fusion::FusedStackTransposeQuantKernel,
                    phi::bfloat16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
-  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  // kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
 }
diff --git a/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.h b/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.h
index 0dd685305c74a7..292382502624eb 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_stack_transpose_quant_kernel.h
@@ -23,12 +23,18 @@ namespace fusion {
 template <typename T, typename Context>
 void FusedStackQuantKernel(const Context& dev_ctx,
                            const std::vector<const DenseTensor*>& x,
+                           bool using_pow2_scaling,
+                           bool using_ue8m0_scale,
+                           bool output_scale_transpose,
                            DenseTensor* out,
                            DenseTensor* scale);
 
 template <typename T, typename Context>
 void FusedStackTransposeQuantKernel(const Context& dev_ctx,
                                     const std::vector<const DenseTensor*>& x,
+                                    bool using_pow2_scaling,
+                                    bool using_ue8m0_scale,
+                                    bool output_scale_transpose,
                                     DenseTensor* out,
                                     DenseTensor* scale);
 
diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml
index 0b22345aa1733a..e00051ff7dacfa 100644
--- a/paddle/phi/ops/yaml/fused_ops.yaml
+++ b/paddle/phi/ops/yaml/fused_ops.yaml
@@ -485,7 +485,7 @@
   backward: fused_seqpool_cvm_grad
 
 - op : fused_stack_quant
-  args : (Tensor[] x)
+  args : (Tensor[] x, bool using_pow2_scaling=false, bool using_ue8m0_scale=false, bool output_scale_transpose=true)
   output : Tensor(out), Tensor(scale)
   infer_meta :
     func : FusedStackQuantInferMeta
@@ -495,7 +495,7 @@
   support_dygraph_mode : true
 
 - op : fused_stack_transpose_quant
-  args : (Tensor[] x)
+  args : (Tensor[] x, bool using_pow2_scaling=false, bool using_ue8m0_scale=false, bool output_scale_transpose=true)
   output : Tensor(out), Tensor(scale)
   infer_meta :
     func : FusedStackTransposeQuantInferMeta
diff --git a/python/paddle/incubate/nn/functional/fp8.py b/python/paddle/incubate/nn/functional/fp8.py
index 918ed80ce1aeb3..c8c6aed6ec950c 100644
--- a/python/paddle/incubate/nn/functional/fp8.py
+++ b/python/paddle/incubate/nn/functional/fp8.py
@@ -33,7 +33,11 @@ def _empty_tensor() -> Tensor:
 
 
 def fused_stack_transpose_quant(
-    x: Sequence[Tensor], transpose: bool = True
+    x: Sequence[Tensor],
+    transpose: bool = True,
+    using_pow2_scaling: bool = False,
+    using_ue8m0_scale: bool = False,
+    output_scale_transpose: bool = True,
 ) -> tuple[Tensor, Tensor]:
     """
     Fused operation that performs stacking, optional transposition, and quantization
@@ -44,6 +48,12 @@ def fused_stack_transpose_quant(
             has shape `[M, K]`. All tensors should have the same shape and dtype.
         transpose (bool, optional): If True, applies a transpose before quantization.
             Default is True.
+        using_pow2_scaling (bool, optional): Whether to use power-of-2 quantization
+            scaling for hardware efficiency. Default: False.
+        using_ue8m0_scale (bool, optional): Whether to return the quantization scale
+            packed in UE8M0 (8-bit power-of-2 exponent) format. Default: False.
+        output_scale_transpose (bool, optional): Whether to return the scale in its
+            transposed layout. Default: True.
 
     Returns:
         tuple:
@@ -75,9 +85,13 @@ def fused_stack_transpose_quant(
     """
     if in_dynamic_or_pir_mode():
         if transpose:
-            return _C_ops.fused_stack_transpose_quant(x)
+            return _C_ops.fused_stack_transpose_quant(
+                x, using_pow2_scaling, using_ue8m0_scale, output_scale_transpose
+            )
         else:
-            return _C_ops.fused_stack_quant(x)
+            return _C_ops.fused_stack_quant(
+                x, using_pow2_scaling, using_ue8m0_scale, output_scale_transpose
+            )
 
 
 def fused_act_dequant(
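With the Python wrapper extended as above, a typical call that requests packed UE8M0 scales looks like the following. This is an illustrative sketch, not part of the patch; shapes follow the InferMeta rules, and a CUDA build of Paddle with this patch applied is assumed:

    import paddle

    x = [paddle.randn([4096, 7168], dtype=paddle.bfloat16) for _ in range(2)]

    out, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
        x,
        transpose=True,
        using_pow2_scaling=False,
        using_ue8m0_scale=True,
        output_scale_transpose=True,
    )
    # out:   [2 * 7168, 4096], float8_e4m3fn
    # scale: [4096 // 128 // 4, 2 * 7168] = [8, 14336], int32 (packed UE8M0)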
diff --git a/test_fused_stack_transpose_quant.py b/test_fused_stack_transpose_quant.py
new file mode 100644
index 00000000000000..d4a02f771f1ec7
--- /dev/null
+++ b/test_fused_stack_transpose_quant.py
@@ -0,0 +1,338 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.base import core
+
+M, K, N = 4096, 7168, 4096
+DTYPE_PD = paddle.bfloat16
+
+
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Align x to TMA-required size.
+    Args:
+        x: size in elements
+        element_size: size of each element in bytes
+    Returns:
+        Aligned size in elements
+    """
+    kNumTMAAlignmentBytes = 16
+    assert kNumTMAAlignmentBytes % element_size == 0
+    return align(x, kNumTMAAlignmentBytes // element_size)
+
+
+def ceil_to_ue8m0_paddle(x: paddle.Tensor):
+    """
+    x > 0
+    return 2 ^ ceil(log2(x))
+    """
+    # log2(x)
+    log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype))
+    # ceil
+    ceil_log2_x = paddle.ceil(log2_x)
+    # 2^k
+    return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x)
+
+
+def _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(
+    x: paddle.Tensor,
+):
+    assert x.dtype == paddle.float32 and x.dim() in (2, 3)
+
+    # UE8M0 code of each float32 scale: its 8-bit biased exponent field.
+    ue8m0_tensor = (x.view(paddle.int32) >> 23).astype(paddle.uint8)
+
+    mn, k = x.shape[-2], x.shape[-1]
+    remove_dim = False
+
+    if x.dim() == 2:
+        x, remove_dim = x.unsqueeze(0), True
+    b = x.shape[0]
+
+    aligned_mn = get_tma_aligned_size(mn, 4)
+    aligned_k = align(k, 4)
+
+    padded = paddle.zeros((b, aligned_mn, aligned_k), dtype=paddle.uint8)
+    padded[:, :mn, :k] = ue8m0_tensor
+
+    # Pack each group of 4 exponent bytes along K into one int32.
+    padded = (
+        padded.view([-1])
+        .view(paddle.int32)
+        .view([b, aligned_mn, aligned_k // 4])
+    )
+
+    # Ported from the torch implementation; in Paddle this keeps the values
+    # but not the TMA-aligned strides of the original.
+    transposed = paddle.transpose(
+        paddle.zeros((b, aligned_k // 4, aligned_mn), dtype=paddle.int32),
+        [0, 2, 1],
+    )
+    transposed[:, :, :] = padded
+
+    aligned_x = transposed[:, :mn, :]
+
+    return aligned_x.squeeze(0) if remove_dim else aligned_x
+
+
+def transform_scale_ue8m0(sf, mn, weight_block_size=None):
+    get_mn_major_tma_aligned_packed_ue8m0_tensor = (
+        _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl
+    )
+    if weight_block_size:
+        assert weight_block_size == [128, 128]
+        # Expand each 128-row block scale to one scale per row.
+        sf = sf.index_select(paddle.arange(mn) // 128, axis=-2)
+    sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
+
+
+def quant_ref(x_scale_fp32, mn, weight_block_size=None):
+    # x_scale_fp32_ = ceil_to_ue8m0_paddle(x_scale_fp32)
+    ref_e8m0_scale = transform_scale_ue8m0(
+        x_scale_fp32, mn=mn, weight_block_size=weight_block_size
+    )
+    return ref_e8m0_scale
+
+
+class TestFusedStackTransposeQuant(unittest.TestCase):
+    def run_op(
+        self,
+        x_list,
+        transpose,
+        using_pow2_scaling,
+        use_ue8m0_scale,
+        output_scale_transpose,
+    ):
+        inputs = x_list
+
+        out, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+            inputs,
+            transpose,
+            using_pow2_scaling,
+            use_ue8m0_scale,
+            output_scale_transpose,
+        )
+        return out, scale
+
+    def test_transpose_input_output_consistency(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        np.random.seed(0)
+        w_paddle_list = []
+
+        for _ in range(3):
+            w = paddle.randn([N, K], dtype=DTYPE_PD)
+            w_paddle_list.append(w)
+
+        # Case 1: use_ue8m0_scale = True, output_scale_transpose = False
+        out_false, scale_false = self.run_op(
+            w_paddle_list,
+            transpose=True,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=False,
+        )
+
+        # Case 2: use_ue8m0_scale = True, output_scale_transpose = True
+        out_true, scale_true = self.run_op(
+            w_paddle_list,
+            transpose=True,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=True,
+        )
+
+        # Case 3: using_pow2_scaling = True, use_ue8m0_scale = False,
+        #         output_scale_transpose = False (float32 scale)
+        out_32_false, scale_32_false = self.run_op(
+            w_paddle_list,
+            transpose=True,
+            using_pow2_scaling=True,
+            use_ue8m0_scale=False,
+            output_scale_transpose=False,
+        )
+
+        np.testing.assert_allclose(
+            out_false.numpy(), out_true.numpy(), atol=0, rtol=0
+        )
+        np.testing.assert_allclose(
+            out_false.numpy(), out_32_false.numpy(), atol=0, rtol=0
+        )
+
+        scale_false_np = scale_false.numpy()
+        scale_true_np = scale_true.numpy()
+        scale_32_false_np = scale_32_false.numpy()
+
+        print(f"Scale False shape: {scale_false_np.shape}")
+        print(f"Scale True shape: {scale_true_np.shape}")
+        print(f"Scale 32 False shape: {scale_32_false_np.shape}")
+
+        scale_false_T = scale_false_np.T
+
+        scale_32_ref = quant_ref(
+            scale_32_false, out_32_false.shape[-2], [128, 128]
+        )
+
+        np.testing.assert_allclose(
+            scale_32_ref.numpy(), scale_true_np.T, atol=0, rtol=0
+        )
+        np.testing.assert_allclose(scale_false_T, scale_true_np, atol=0, rtol=0)
+
+    def test_output_consistency(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        np.random.seed(0)
+        w_paddle_list = []
+
+        for _ in range(3):
+            w = paddle.randn([N, K], dtype=DTYPE_PD)
+            w_paddle_list.append(w)
+
+        # Case 1: use_ue8m0_scale = True, output_scale_transpose = False
+        out_false, scale_false = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=False,
+        )
+
+        # Case 2: use_ue8m0_scale = True, output_scale_transpose = True
+        out_true, scale_true = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=True,
+        )
+
+        # Case 3: using_pow2_scaling = True, use_ue8m0_scale = False,
+        #         output_scale_transpose = False (float32 scale)
+        out_32_false, scale_32_false = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=True,
+            use_ue8m0_scale=False,
+            output_scale_transpose=False,
+        )
+
+        np.testing.assert_allclose(
+            out_false.numpy(), out_true.numpy(), atol=0, rtol=0
+        )
+        np.testing.assert_allclose(
+            out_false.numpy(), out_32_false.numpy(), atol=0, rtol=0
+        )
+
+        scale_false_np = scale_false.numpy()
+        scale_true_np = scale_true.numpy()
+        scale_32_false_np = scale_32_false.numpy()
+
+        print(f"Scale False shape: {scale_false_np.shape}")
+        print(f"Scale True shape: {scale_true_np.shape}")
+        print(f"Scale 32 False shape: {scale_32_false_np.shape}")
+
+        scale_false_T = scale_false_np.T
+
+        scale_32_ref = quant_ref(
+            scale_32_false, out_32_false.shape[-2], [128, 128]
+        )
+
+        np.testing.assert_allclose(
+            scale_32_ref.numpy(), scale_true_np.T, atol=0, rtol=0
+        )
+        np.testing.assert_allclose(scale_false_T, scale_true_np, atol=0, rtol=0)
+
+    def test_gemm_out(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        np.random.seed(0)
+        w_paddle_list = []
+
+        for _ in range(3):
+            w = paddle.randn([N, K], dtype=DTYPE_PD)
+            w_paddle_list.append(w)
+
+        # Case 1: use_ue8m0_scale = True, output_scale_transpose = False
+        out_false, scale_false = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=False,
+        )
+
+        # Case 2: use_ue8m0_scale = True, output_scale_transpose = True
+        out_true, scale_true = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=False,
+            use_ue8m0_scale=True,
+            output_scale_transpose=True,
+        )
+
+        # Case 3: using_pow2_scaling = True, use_ue8m0_scale = False,
+        #         output_scale_transpose = False (float32 scale)
+        out_32_false, scale_32_false = self.run_op(
+            w_paddle_list,
+            transpose=False,
+            using_pow2_scaling=True,
+            use_ue8m0_scale=False,
+            output_scale_transpose=False,
+        )
+
+        np.testing.assert_allclose(
+            out_false.numpy(), out_true.numpy(), atol=0, rtol=0
+        )
+        np.testing.assert_allclose(
+            out_false.numpy(), out_32_false.numpy(), atol=0, rtol=0
+        )
+
+        scale_false_np = scale_false.numpy()
+        scale_true_np = scale_true.numpy()
+        scale_32_false_np = scale_32_false.numpy()
+
+        print(f"Scale False shape: {scale_false_np.shape}")
+        print(f"Scale True shape: {scale_true_np.shape}")
+        print(f"Scale 32 False shape: {scale_32_false_np.shape}")
+
+        scale_false_T = scale_false_np.T
+
+        scale_32_ref = quant_ref(
+            scale_32_false, out_32_false.shape[-2], [128, 128]
+        )
+
+        np.testing.assert_allclose(
+            scale_32_ref.numpy(), scale_true_np.T, atol=0, rtol=0
+        )
+        np.testing.assert_allclose(scale_false_T, scale_true_np, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+    unittest.main()

From e7a42a74623969b6aa135c1ed0a0989885bebd15 Mon Sep 17 00:00:00 2001
From: liujinnan <1823192871@qq.com>
Date: Sat, 3 Jan 2026 11:30:29 +0800
Subject: [PATCH 2/2] fix ci

---
 python/paddle/incubate/nn/functional/fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/incubate/nn/functional/fp8.py b/python/paddle/incubate/nn/functional/fp8.py
index c8c6aed6ec950c..1d0c03963b2b01 100644
--- a/python/paddle/incubate/nn/functional/fp8.py
+++ b/python/paddle/incubate/nn/functional/fp8.py
@@ -37,7 +37,7 @@ def fused_stack_transpose_quant(
     transpose: bool = True,
     using_pow2_scaling: bool = False,
     using_ue8m0_scale: bool = False,
-    output_scale_transpose: bool = True,
+    output_scale_transpose: bool = False,
 ) -> tuple[Tensor, Tensor]:
     """
     Fused operation that performs stacking, optional transposition, and quantization
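After the PATCH 2/2 default change, callers get the row-major scale layout unless they request the transposed one explicitly. An illustrative sketch of the difference (assumes a build with both patches applied; the relation between the two layouts is the transpose that the unit tests above assert):

    import paddle
    from paddle.incubate.nn.functional import fused_stack_transpose_quant

    x = [paddle.randn([4096, 7168], dtype=paddle.bfloat16)]

    # The old default behaviour must now be requested explicitly.
    _, scale_t = fused_stack_transpose_quant(
        x, using_ue8m0_scale=True, output_scale_transpose=True
    )

    # New default: row-major packed UE8M0 scale.
    _, scale = fused_stack_transpose_quant(x, using_ue8m0_scale=True)

    assert (scale.numpy().T == scale_t.numpy()).all()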