[Feature] Support fp8 e5m2 kv cache with flashinfer #1204
Conversation
Nice work! I'll review it ASAP. Could we also support FP8 E4M3?
FP8 E4M3 needs a scale factor and calibration. We may add it in the future.
```python
if self.server_args.kv_cache_dtype == "auto":
    self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
    if self.server_args.disable_flashinfer or self.server_args.enable_mla:
```
Currently, only FlashInfer is supported, not Triton, due to insufficient shared memory (smem). This needs to be fixed in another PR.
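For context, a sketch of how the full branch might look once the FlashInfer check is included, written in the same style as the snippet above (the error messages and the final `else` branch are assumptions, not the PR's exact code):

```python
import torch

if self.server_args.kv_cache_dtype == "auto":
    # Fall back to the model dtype when no explicit KV cache dtype is given.
    self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
    if self.server_args.disable_flashinfer or self.server_args.enable_mla:
        # Only the FlashInfer backend supports the fp8 e5m2 KV cache for now.
        raise ValueError(
            "fp8_e5m2 KV cache requires FlashInfer (Triton is not supported yet)"
        )
    self.kv_cache_dtype = torch.float8_e5m2
else:
    raise ValueError(
        f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}"
    )
```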
```python
if cache_v.dtype != self.dtype:
    cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
    self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
```
This is a workaround for float8_e5m2: the buffer is stored as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2.
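A minimal, self-contained illustration of the workaround (buffer shapes and variable names are made up; it assumes a PyTorch build with `torch.float8_e5m2` support):

```python
import torch

kv_dtype = torch.float8_e5m2   # logical dtype of the KV cache
store_dtype = torch.uint8      # physical storage dtype (same 1-byte element size)

# The KV buffer is allocated in the storage dtype.
k_buffer = torch.zeros(16, 8, dtype=store_dtype)

# Incoming keys are produced in the model dtype and quantized to fp8.
cache_k = torch.randn(4, 8, dtype=torch.float16).to(kv_dtype)
loc = torch.tensor([0, 2, 5, 7])

# Write path: reinterpret the fp8 bytes as uint8 so index assignment works.
k_buffer[loc] = cache_k.view(store_dtype)

# Read path: reinterpret the stored bytes back as fp8 before attention.
k_fp8 = k_buffer[loc].view(kv_dtype)
```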
Co-authored-by: Lianmin Zheng <[email protected]>
Sorry to dig this up, but are we suggesting that the fp8 KV cache increased accuracy on both MMLU and GSM8K? Are we sure the values in the table aren't reversed?
@qeternity In the previous evaluation, I tested GSM8K with 200 questions (the default setting in the benchmark script), so the result may not have been reliable enough. I have now tested all the datasets and updated the results in the table.
Motivation
Support fp8 e5m2 kv cache with flashinfer.
Usage
Add `--kv-cache-dtype fp8_e5m2` to the server launch command to enable this feature. Currently it only works when FlashInfer is not disabled.
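For example, a typical server launch might look like this (the model path and port are illustrative, not taken from the PR):

```bash
python -m sglang.launch_server \
    --model-path meta-llama/Llama-2-13b-chat-hf \
    --port 30000 \
    --kv-cache-dtype fp8_e5m2
```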
Performance & Accuracy
Tested with llama2-13b-chat on an A100, throughput increased by 17.8% with no accuracy degradation.
Reproduce
The performance boost is model-dependent; llama3-8b was also tested, but its performance did not improve.
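The exact benchmark commands are not shown above; one way to reproduce the comparison is to launch the server with and without the flag and run the same serving benchmark against each (the model path and port below are assumptions):

```bash
# Baseline: default KV cache dtype (follows the model dtype)
python -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf --port 30000

# fp8 e5m2 KV cache
python -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf --port 30000 \
    --kv-cache-dtype fp8_e5m2
```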