How to cast 16/32-bit to FP8? #965

Open
mxjmtxrm opened this issue Jun 25, 2024 · 3 comments
Labels: question (Further information is requested)

Comments

mxjmtxrm commented Jun 25, 2024

Hi, how do I cast a float/bfloat16 tensor to FP8? I want to do W8A8 (FP8) quantization, but I couldn't find an example of quantizing activations to the FP8 format.

timmoon10 added the question label Jun 25, 2024

timmoon10 (Collaborator) commented Jun 25, 2024

The easiest approach is to use native PyTorch FP8 dtypes:

import torch  # requires a PyTorch build with native float8 dtypes

x = torch.randn(128, device="cuda", dtype=torch.float32)
y = x.to(dtype=torch.float8_e4m3fn)  # or torch.float8_e5m2

You could also use transformer_engine.pytorch.Float8Tensor / float8_experimental.Float8Tensor:

import transformer_engine.pytorch as te
import float8_experimental

scale = torch.ones(1, device="cuda", dtype=torch.float32)
y1 = te.Float8Tensor.to_float8(x)
y2 = float8_experimental.Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

These classes are closely related and have some nice convenience features: support for scaling factors, casting to higher precision for ops that don't support FP8, and torch.compile support in float8_experimental.
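
For illustration, here is a minimal hand-rolled sketch of what per-tensor scaling looks like with the native PyTorch dtypes (the amax-based scale and the 448 E4M3 maximum are illustrative choices, not the exact recipe either library uses):

import torch

x = torch.randn(128, device="cuda", dtype=torch.bfloat16)

# Map the tensor's absolute max onto the E4M3 max normal value (448),
# quantize, then dequantize for ops that don't support FP8.
amax = x.abs().amax().float().clamp(min=1e-12)
scale = 448.0 / amax

x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)  # quantized values
x_hp = x_fp8.float() / scale                         # back to higher precision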

Finally, you could directly use the FP8 kernels from Transformer Engine:

y = te.cpp_extensions.cast_to_fp8(
    x,
    fp8_meta,  # TE-internal FP8 metadata (scaling factors / amax history)
    0,         # index selecting which set of scaling factors in fp8_meta to use
    transformer_engine_torch.DType.kFloat8E4M3,
)

I strongly advise against using these internal functions though. Their APIs are unstable, messy, and tightly integrated with TE's logic for computing FP8 scaling factors.

mxjmtxrm (Author) commented:

Thanks @timmoon10.
How do I do mixed-precision calculations, e.g. a matrix multiplication of an FP8 tensor and an FP16 tensor that produces FP16 output?

timmoon10 (Collaborator) commented Jun 26, 2024

If you just want the performance benefit of FP8 matmuls, I recommend using Transformer Engine modules (like te.Linear) in your model (see this FP8 tutorial). They will internally handle the FP8 casts and FP8 scaling factors.
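
For example, a minimal sketch of that usage (the layer sizes and recipe settings here are illustrative; see the tutorial for the full options):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Run a TE Linear layer with FP8 GEMMs under fp8_autocast.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

model = te.Linear(768, 768, bias=True).cuda()
x = torch.randn(16, 768, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x)  # the GEMM runs in FP8 internally; x and y stay in higher precision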

If you want more control, you'll have to get a bit into the weeds. I'm not sure whether native PyTorch FP8 tensors support matmuls (even if they did, there would be numerical issues without FP8 scaling factors), but I see that float8_experimental.Float8Tensor does support matmuls with scaling factors (see addmm_float8_unwrapped). As far as I can tell, this ends up calling cuBLAS (see scaled_gemm).

Be advised that cuBLAS FP8 GEMMs only accept FP8 inputs (see the FP8 support matrix for cublasLtMatmul), so a mixed FP8 x FP16 matmul can't be done directly there. Implementing a custom matmul kernel that supports mixed FP8 and FP16 inputs may be possible with CUTLASS, but it would get quite involved (and would probably still be slower than TE for end-to-end training).
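
For concreteness, a minimal sketch of the simple fallback: dequantize the FP8 operand and do the matmul in FP16 (the tensor names and the per-tensor scale are hypothetical):

import torch

# cuBLAS FP8 GEMMs need both operands in FP8, so for a mixed FP8 x FP16
# product the easy option is to dequantize the FP8 side first.
w_fp8 = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
w_scale = torch.tensor(2.0, device="cuda")      # hypothetical per-tensor scale
a_fp16 = torch.randn(16, 256, device="cuda", dtype=torch.float16)

w_fp16 = w_fp8.to(torch.float16) / w_scale      # dequantize the weight
y = a_fp16 @ w_fp16.t()                         # FP16 matmul, FP16 output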
