Replies: 16 comments 51 replies
-
https://github.com/IntelLabs/FP8-Emulation-Toolkit looks like a good starting point I found a few days ago... I'll have a look.
-
https://arxiv.org/pdf/2209.05433v2 I'll have to read it more carefully, but it looks like we need to apply a scale factor too. Possibly it is common to all weights, or per weight.
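To make the scale idea concrete (my reading of the scheme, not a quote from the paper): with a single per-tensor scale you store s = max_i |w_i| / FP8_MAX (448 for E4M3), encode q_i = fp8(w_i / s), and at compute time use w_i ≈ s · q_i. A per-channel or per-block variant just keeps one such s per row or per group of weights instead of one for the whole tensor.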
-
static float from_fp8(uint8_t fp8) {
    union {
        float f;
        uint32_t i;
    } u;
    const uint32_t t = fp8;
    if ((fp8 & 0x7F) == 0x7F) return (fp8==0xFF)?(-NAN):(+NAN); // E4M3: 0x7F / 0xFF encode NaN (no infinities)
    int exp = ((fp8 >> 3) & 0x0F) - 7;
    if (fp8 & 0x78) { // normal: exponent field non-zero
        u.i = (t & 7) << 20; // mantissa: bit 2-0 -> 22-20
    } else if (fp8 & 0x04) { // subnormal: renormalise, leading 1 at mantissa bit 2
        u.i = (t & 3) << 21;
        exp = -7;
    } else if (fp8 & 0x02) { // leading 1 at mantissa bit 1
        u.i = (t & 1) << 22;
        exp = -8;
    } else if (fp8 & 0x01) { // leading 1 at mantissa bit 0
        u.i = 0;
        exp = -9;
    } else { // zero
        u.i = 0;
        exp = -127;
    }
    u.i |= (exp + 127) << 23; // exponent
    u.i |= (t & 128) << 24; // sign
    return u.f;
}

It could also be written with more arithmetic and fewer branches; I am not sure that would be better. At least the calculation is correct (or at least the same as the current one).
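For what it is worth, here is a quick sanity check against a few hand-computed E4M3 values (just a sketch; it assumes the from_fp8 above is in scope):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // E4M3 layout: sign 1 bit | exponent 4 bits (bias 7) | mantissa 3 bits
    assert(from_fp8(0x00) ==  0.0f);    // +0
    assert(from_fp8(0x38) ==  1.0f);    // e=7-7=0, m=0
    assert(from_fp8(0x3C) ==  1.5f);    // e=0, m=4/8
    assert(from_fp8(0xC0) == -2.0f);    // sign set, e=1, m=0
    assert(from_fp8(0x7E) ==  448.0f);  // largest finite E4M3: 1.75 * 2^8
    assert(from_fp8(0x01) ==  0x1p-9f); // smallest subnormal: 2^-9
    printf("from_fp8: ok\n");
    return 0;
}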
-
I do a few statistical calculations on weights (mean, standard deviation, min, max)...
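Something along these lines, for example (a rough sketch of the idea, assuming a non-empty weight vector, not the actual code used):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// per-tensor statistics, to see where the FP8 range should be placed
static void weight_stats(const std::vector<float> & w) {
    double sum = 0.0, sum2 = 0.0;
    float mn = w[0], mx = w[0];
    for (float x : w) {
        sum  += x;
        sum2 += (double) x * x;
        mn = std::min(mn, x);
        mx = std::max(mx, x);
    }
    const double mean = sum / w.size();
    const double var  = sum2 / w.size() - mean * mean;
    printf("mean=%g std=%g min=%g max=%g\n", mean, std::sqrt(std::max(var, 0.0)), mn, mx);
}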
-
I wrote an avx512 vectorized fp8 loader. On my znver4 Threadripper I managed to get prompt processing up from 10 tok/sec to 56 tok/sec on my workstation for Mistral Nemo. See the commit to my fp8 branch here: c10a65c

On the other hand, using BF16 gives me 385.92 tok/sec for prompt processing. The K quants probably go just as fast, and token generation goes even faster. So unless someone can come up with a significantly faster way of dequantizing FP8, I really don't think it has anything to offer us until native CPU hardware support becomes available. Except it is worth supporting, because NVIDIA has hardware support for it. However, I don't own an NVIDIA graphics card that has FP8 yet. If someone is willing to devote the eng resources to implementing CUDA support, then I will merge that and ensure it at least works on CPU too. Although realistically, anyone who chooses FP8 will only be interested in doing it on GPU. There's also the question of what you'd do for AMD GPUs.

As I mentioned, as long as it can be made to work I'm happy, even if it's primarily only good for NVIDIA GPU owners. That's my judgement for now. If anyone wants to volunteer, send me pull requests to my FP8 branch and we'll keep working on it.
-
OK, some benchmarks with this fp8 branch.
I am now using this form for the fp8 -> float conversion:

static float from_fp8_8(uint8_t fp8) {
    union {
        float f;
        uint32_t i;
    } u;
    const uint32_t t = fp8;
    // if ((fp8 & 0x7F) == 0x7F) return (fp8==0xFF)?(-NAN):(+NAN); // not needed here: there is no NaN after quantisation
    auto exp_8 = t & 0x78;
    auto exp_32 = (exp_8+(120<<3))<<20;
    u.i = (t & 7) << 20; // mantissa: bit 2-0 -> 22-20
    u.i |= exp_8 ? exp_32 : (-6 + 127) << 23; // exponent
    if (!exp_8) { u.f -= 1.0/64; } // subnormal correction: subtract the implicit 1.0 * 2^-6
    u.i |= (t & 0x80) << 24; // sign
    return u.f;
}

With it I get this vectorised implementation:

#if defined(__AVX512F__) && defined(__AVX512VL__) && defined(__AVX512BW__)
#include <immintrin.h>
static __m512 llamafile_from_fp8_e4m3_avx512(__m128i fp8_vec) {
    // extract components:
    __m128i expo_8 = _mm_and_si128(fp8_vec, _mm_set1_epi8(0x78));
    __m128i mant_8 = _mm_and_si128(fp8_vec, _mm_set1_epi8(0x07));
    __m128i sign_8 = _mm_and_si128(fp8_vec, _mm_set1_epi8(0x80));
    // denorm mask
    __mmask16 is_denorm = _mm_cmpeq_epi8_mask(expo_8, _mm_set1_epi8(0));
    // convert to 32 bits
    __m512i expo_32 = _mm512_cvtepu8_epi32(expo_8);
    __m512i mant_32 = _mm512_cvtepu8_epi32(mant_8);
    __m512i sign_32 = _mm512_cvtepu8_epi32(sign_8);
    // shift into fp32 bit positions
    expo_32 = _mm512_slli_epi32(_mm512_add_epi32(expo_32, _mm512_set1_epi32(120<<3)), 20);
    mant_32 = _mm512_slli_epi32(mant_32, 20);
    sign_32 = _mm512_slli_epi32(sign_32, 24);
    // correction of the denorm exponent:
    expo_32 = _mm512_mask_blend_epi32(is_denorm, expo_32, _mm512_set1_epi32((-6 + 127) << 23));
    // merge mantissa + exponent
    __m512 result = _mm512_castsi512_ps(_mm512_or_si512(expo_32, mant_32));
    // correction of the denorm mantissa:
    result = _mm512_mask_add_ps(result, is_denorm, result, _mm512_set1_ps(-1.0/64));
    // add sign
    return _mm512_castsi512_ps(_mm512_or_si512(sign_32, _mm512_castps_si512(result)));
}
#endif
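For context, this is roughly how such a kernel could be used to dequantize a row (a sketch only; the row layout and the assumption that n is a multiple of 16 are mine):

// must live under the same AVX512 guard as the kernel above
static void dequant_row_fp8_e4m3_avx512(const uint8_t * x, float * y, int n) {
    for (int i = 0; i < n; i += 16) {
        __m128i v8  = _mm_loadu_si128((const __m128i *) (x + i)); // 16 fp8 bytes
        __m512  v32 = llamafile_from_fp8_e4m3_avx512(v8);         // 16 floats
        _mm512_storeu_ps(y + i, v32);
    }
}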
Now, some benchmarks:
Token generation is now good, and prompt processing is not that bad considering we do not use bf16. I am pretty sure we can get a 2x if we convert fp8 to bf16 and use it for the compute, as is done with the bf16 quant. I may study this branch closely and see if I can add fp8/bf16 in tinyBLAS; it would be a good exercise. Or update my test on blas_bf16 (https://github.com/Djip007/llama.cpp/tree/poc/bf16), but that may take some time with all the refactoring that has been done on llama.cpp 😉. It may be faster to rewrite it on top of the fp8 branch...
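For the fp8 -> bf16 idea, here is a minimal sketch of the narrowing step (my own illustration; the decode itself would still use from_fp8_8 or the AVX512 kernel above):

#include <cstdint>
#include <cstring>

// fp32 -> bf16 by keeping the top 16 bits, with round-to-nearest-even.
// Values decoded from E4M3 have at most 3 mantissa bits set, so for them the
// low 16 bits are already zero and this reduces to a plain truncation.
static uint16_t fp32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x7FFF + ((u >> 16) & 1); // round to nearest, ties to even (finite inputs)
    return (uint16_t) (u >> 16);
}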
-
We may want to take some time to read that: it is how FP8 quantisation is performed in vllm (I think...). It looks like they have a per-tensor scale factor.
-
19/10/2024: obsolete!!! It can be "much" better! OK, some more work: I created my own backend to get more control over the compute... (more on that later)
Note: the 3 V2 FP8 decodings have the same perplexity, so we can go with the fastest (V2.2 ...)
-
Some benchmarks.
The llamafile column is the current 0.8.13 release with BF16/Q8/Q6/Q5 compute, for reference.
-
@Djip007 Am I understanding correctly that you got FP8 E4M3 to:
If so, that's outstanding; I am interested in including your work in llamafile. What especially interests me is that you figured out how to make it work with flush to zero. Were those perplexity scores measured while flushing subnormals to zero? What about the benchmarks?

For your kernel shape, what you want to do is pick whatever shape uses the most vector registers without spilling to the stack. For BF16 on AVX512 and NEON, which have 32 registers, that shape was either 5x5 or 8x3. Also, can you share your code?
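(As a rough sketch of that register counting, assuming one accumulator per output element of the tile plus the vectors being streamed in: an RM x RN micro-kernel needs about RM*RN accumulators + RN loads from B + 1 broadcast from A. So 5x5 -> 25 + 5 + 1 = 31 <= 32 registers and 8x3 -> 24 + 3 + 1 = 28 <= 32, while something like 6x5 -> 30 + 5 + 1 = 36 would spill.)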
-
OK, I have published it; it is not finished and is more of a POC.
The load for fp8 can be found here: load_fp8
But it is really experimental, and I want to do much more work on it...
-
Some more news: I have added more FP8 formats to my backend. FP8: (19/10/2024: obsolete!!)
For reference:
G => global scale; A[k,m] => scale (float)
-
I was so focused on speed that I forgot to look at the quality of the quantization. With a single scale per block of 256 weights:
FP8_E3M4_K2 : PPL = 6.346961 ± 0.039010
In this case I have even added fp8 quantisation of "output.weight". Now I have to re-compute all the perplexities 😎
For comparison, Q8_0 : PPL = 6.3445 +/- 0.03898
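For illustration, a minimal sketch of what a per-block scale can look like (my own sketch for E4M3 using the from_fp8 decoder from earlier in the thread, not the actual FP8_E3M4_K2 layout; the brute-force encoder is just for clarity):

#include <cmath>
#include <cstdint>

#define QK_FP8 256
#define FP8_E4M3_MAX 448.0f

struct block_fp8 {
    float   scale;      // one scale per block of QK_FP8 weights
    uint8_t q[QK_FP8];  // E4M3 codes
};

// nearest E4M3 code by exhaustive search over the 256 possible bytes
static uint8_t to_fp8_nearest(float x) {
    uint8_t best = 0;
    float best_err = INFINITY;
    for (int c = 0; c < 256; ++c) {
        const float v = from_fp8((uint8_t) c);
        if (std::isnan(v)) continue;
        const float err = std::fabs(v - x);
        if (err < best_err) { best_err = err; best = (uint8_t) c; }
    }
    return best;
}

static void quantize_block_fp8(const float * x, block_fp8 * b) {
    float amax = 0.0f;
    for (int i = 0; i < QK_FP8; ++i) amax = std::fmax(amax, std::fabs(x[i]));
    const float scale = amax > 0.0f ? amax / FP8_E4M3_MAX : 1.0f;
    b->scale = scale;
    for (int i = 0; i < QK_FP8; ++i) b->q[i] = to_fp8_nearest(x[i] / scale);
}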
-
OK, this is better with "correct" rounding: (zen4 / BF16 / wiki.test.raw / rounding / Mistral-Nemo-Instruct)
Most results are now not far from Q6_K and Q8_0 (some more results in progress). But as you can see, I get the best results with the E3M4 format... which may not be a good point for AMD/Nvidia, which implement E4M3 in hardware (though to be fair, only the weights are quantized; the compute is done after converting to BF16).

The list of tensors quantized to FP8 is:

static constexpr std::list<std::string> LIST_WEIGHT_CONVERT() { return {
    "ffn_down.weight",
    "ffn_gate.weight",
    "ffn_up.weight",
    "attn_k.weight",
    "attn_q.weight",
    "attn_v.weight",
    "attn_output.weight",
    "output.weight",
};}

(Before, I did not quantize "output.weight".)
-
Some more... Meta-Llama-3.1-8B-Instruct (zen4 / BF16 / wiki.test.raw / rounding / output-weight)
Is this what we get when a model is not trained to be FP8-quantization aware?
Mistral-7B-Instruct-v0.3 (zen4 / BF16 / wiki.test.raw / rounding / output-weight)
This one is better!
-
OK, for now I have created a branch with base FP8 support. It is not for speed; I only added the base parts (quantize / convert / dot): Djip007@cbd3abd
I added 4 FP8 formats:
Now I need to make it fast...
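For reference, the dot part in its simplest form could look like this (a scalar reference sketch using the hypothetical block_fp8 layout from my earlier sketch, dequantizing on the fly; a real implementation would of course be vectorized):

static float vec_dot_fp8_f32(const block_fp8 * bx, const float * y, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        const block_fp8 * b = bx + i / QK_FP8;
        sum += b->scale * from_fp8(b->q[i % QK_FP8]) * y[i]; // dequantize, then accumulate
    }
    return sum;
}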
-
As the discussion started here: #543 (reply in thread):
There are 2 formats (or 3; E3M4 is given in some papers...).
I would really like to help. Let's have a look at what we can get.
I think the E5M2 format can be converted to FP16 the same way we do BF16 <-> FP32. I'll have a look and get back 😎
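Concretely, E5M2 has the same sign/exponent layout as an IEEE half (only 2 of the 10 mantissa bits are kept), so the widening is just a byte shift, analogous to the BF16 <-> FP32 trick (a sketch):

#include <cstdint>

static uint16_t fp8_e5m2_to_fp16_bits(uint8_t fp8) {
    return (uint16_t) fp8 << 8;   // exact: E5M2 is the top byte of a half
}

static uint8_t fp16_bits_to_fp8_e5m2(uint16_t h) {
    return (uint8_t) (h >> 8);    // truncation; proper rounding would add 0x7F + lsb first
}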