- [Saving a quantized checkpoint](#saving-a-quantized-checkpoint)
  - [Add the scales to Linear layers](#add-the-scales-to-linear-layers)
  - [Update model config](#update-model-config)

## Why?

tl;dr:

Starting with the NVIDIA H100 GPU, GPUs have *hardware support* for 8 bit floating point.

3. Depending on the GPU, fp8 FLOPS are just higher than `bf16` FLOPS. E.g. see the [H100 specifications](https://www.nvidia.com/en-us/data-center/h100/): bfloat16 has ~2k teraFLOPS and fp8 has ~4k teraFLOPS.

## How?

### Note on executing fp8 models

When we talk about `fp8` models, we are typically only talking about the **weights being `fp8`**. The actual execution of the model is still done in `bf16`, so all the **intermediate tensors are still in `bf16`**, and it's the underlying CUDA kernels that take in `bf16` tensors and `fp8` weights.
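
To make that concrete, here is a tiny sketch of what such a kernel is conceptually doing (illustrative only; it assumes a per-tensor scale and a PyTorch build with `float8` support):

```python
import torch

# bf16 activation, fp8 weight plus its scale (a single per-tensor scale is assumed here)
x = torch.randn(4, 64, dtype=torch.bfloat16)
w_fp8 = torch.randn(128, 64).to(torch.float8_e4m3fn)
scale = torch.tensor(0.02)

# The kernel dequantizes the weight on the fly and does the matmul in bf16,
# so the output (an intermediate tensor) is still bf16.
w_bf16 = w_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)
y = x @ w_bf16.t()
assert y.dtype == torch.bfloat16
```
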
**fp8 models still use `bf16` kv cache by default** (since the kv cache stores kv values, which are intermediate tensors).

### fp8 bit format

There are a number of different `fp8` formats; the most common is `float8_e4m3fn`. Here are some facts about it:
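
One quick way to check the key numbers yourself (an illustrative snippet, assuming a PyTorch build with `float8` support):

```python
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.bits)  # 8
print(info.max)   # 448.0 -> largest representable value
print(info.min)   # -448.0
print(info.eps)   # 0.125 -> only 3 mantissa bits, so precision is coarse
```
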

So this leaves us with two questions for quantization:

1. `bf16` can store values between `[-3.38953e+38, +3.38953e+38]`; how do we fit that into the `fp8` range of `[-448, +448]`?
2. How do we take advantage of the distribution of values in `fp8`?

### Quantization - scaling to lower precision loss & handle large values

Since `bf16` and `fp8` have different ranges, we need to scale the values to fit into the `fp8` range. This scale is based on the max value of the data in `bf16`, and is roughly computed like:
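
A minimal sketch of that computation, assuming a single per-tensor scale and `fp8`'s max representable value of 448 (not necessarily the exact code used here):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize(x: torch.Tensor):
    # One scale for the whole tensor, chosen so the largest |value| maps to 448
    scale = x.abs().max().float() / FP8_MAX
    x_fp8 = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale
```
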

And to dequantize (which is essentially done on the fly at runtime inside the CUDA kernel):

```python
x_dequantized = x.to(torch.bfloat16) * scale
```

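Continuing the sketch above, a quick round trip shows the kind of error this introduces (illustrative only; it reuses the `quantize` helper from the earlier snippet):

```python
x = torch.randn(256, 256, dtype=torch.bfloat16) * 100
x_fp8, scale = quantize(x)
x_dequantized = x_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)

# Worst-case absolute error; compare it to the values in x, which are on the order of a few hundred
print((x.float() - x_dequantized.float()).abs().max())
```
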

### Finer grained scale - weight block size

Above I showed the scale being a single value, but you can also have it be a tensor. If you look at some popular open source `fp8` models, they typically use this option.

```python
# The weight has shape [N, K] and is quantized in (n, k) blocks, so there is one scale per block:
assert scale.shape == torch.Size([N // n, K // k])
```
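
For example, block-wise scales for an `[N, K]` weight with `(n, k)` blocks can be produced roughly like this (a sketch, not the exact code those models use):

```python
import torch

N, K, n, k = 512, 1024, 128, 128
weight = torch.randn(N, K, dtype=torch.bfloat16)

# View the weight as a grid of (n, k) blocks and take one max per block
blocks = weight.reshape(N // n, n, K // k, k)
scale = blocks.abs().amax(dim=(1, 3)).float() / torch.finfo(torch.float8_e4m3fn).max

assert scale.shape == torch.Size([N // n, K // k])
```
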

## Saving a quantized checkpoint

For compatibility with things like vLLM, there are a couple of things we need to do:

### Add the scales to Linear layers

We need to add the previously computed `weight_scale` as a parameter to each of the `Linear` layers. This basically means replacing each `Linear` layer with a custom `PackedLinear` class, where `weight` is the `fp8` tensor and `weight_scale` is the scale from the previous sections.
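
A minimal sketch of such a class (the repo's actual `PackedLinear` is not shown in this excerpt, so the `forward` below is just a naive dequantize-and-matmul reference path that assumes a per-tensor scale; serving engines like vLLM use their own fp8 kernels instead):

```python
import torch

class PackedLinear(torch.nn.Module):
    def __init__(self, weight_fp8: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor | None = None):
        super().__init__()
        # Registering both as parameters means they end up in the saved state dict
        self.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)          # fp8 tensor
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)  # scale(s) from above
        self.bias = torch.nn.Parameter(bias, requires_grad=False) if bias is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Naive reference path: dequantize the weight, then do a normal matmul in the activation dtype
        w = (self.weight.to(torch.float32) * self.weight_scale).to(x.dtype)
        return torch.nn.functional.linear(x, w, self.bias)
```
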
### Update model config

This part is really easy: just add a `quantization_config` to the model's config. This will also appear in the `config.json` file in the Hugging Face repo of the model.
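
As an illustration, the entry for a block-quantized `fp8` checkpoint often looks roughly like this (the field names follow the convention used by popular open source `fp8` checkpoints and are an assumption here, not taken from this repo):

```python
# `model` is the quantized transformers model from the previous steps.
# Attaching the config before save_pretrained means it lands in config.json.
model.config.quantization_config = {
    "quant_method": "fp8",
    "fmt": "e4m3",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
}
model.save_pretrained("my-model-fp8")
```
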