[Feature] Support fp8 e5m2 kv cache with flashinfer #1204

Merged
7 commits merged into sgl-project:main on Aug 26, 2024

Conversation

@ispobock (Collaborator) commented Aug 25, 2024

Motivation

Support fp8 e5m2 kv cache with flashinfer.

Usage

Add --kv-cache-dtype fp8_e5m2 to the server launch command to enable this feature. Currently it only works when FlashInfer is not disabled.

Performance & Accuracy

In tests with Llama-2-13b-chat on an A100, throughput increased by 17.8% with no accuracy degradation.

| Enable fp8_e5m2 kv cache | Throughput | MMLU (nsub=10) avg accuracy | GSM8K accuracy |
| --- | --- | --- | --- |
| N | 7.04 req/s | 0.487 | 0.340 |
| Y | 8.30 req/s | 0.488 | 0.334 |
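For intuition about where the throughput gain comes from, here is a back-of-the-envelope KV-cache footprint calculation (a rough sketch; the layer/head/dim numbers are Llama-2-13B's standard config and are not taken from this PR). Halving the per-token cache size lets roughly twice as many tokens fit in the same memory pool, which allows larger running batches.

```python
# Rough per-token KV-cache footprint, assuming Llama-2-13B's standard config:
# 40 layers, 40 KV heads (no GQA), head_dim 128; both K and V are cached.
num_layers, num_kv_heads, head_dim = 40, 40, 128
elems_per_token = 2 * num_layers * num_kv_heads * head_dim

for name, bytes_per_elem in [("fp16", 2), ("fp8_e5m2", 1)]:
    print(f"{name}: {elems_per_token * bytes_per_elem / 1024:.0f} KiB per cached token")
# fp16: 800 KiB per cached token
# fp8_e5m2: 400 KiB per cached token
```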
Reproduce
# w/o fp8_e5m2 kv cache
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf --port 30000 --trust-remote-code --disable-radix-cache --tp=1
# w/ fp8_e5m2 kv cache
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf --port 30000 --trust-remote-code --disable-radix-cache --tp=1 --kv-cache-dtype fp8_e5m2

# benchmark
python3 -m sglang.bench_serving --backend sglang --tokenizer meta-llama/Llama-2-13b-chat-hf  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000 --request-rate 128

# evaluation
python3 benchmark/mmlu/bench_sglang.py --nsub 10
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319

The performance boost is model-dependent. Llama-3-8B was also tested, but its throughput did not improve.

@zhyncs zhyncs self-assigned this Aug 25, 2024
@zhyncs (Member) commented Aug 25, 2024

Nice work! I'll review it asap. May we also support FP8 E4M3?

@zhyncs zhyncs requested a review from yzh119 August 25, 2024 08:25
@zhyncs zhyncs added the feature label Aug 25, 2024
@zhyncs zhyncs mentioned this pull request Aug 25, 2024
@ispobock (Collaborator, Author) replied:

> May we also support FP8 E4M3?

FP8 E4M3 needs a scale factor and calibration. We may add it in the future.
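For context on why E4M3 is more involved, here is a minimal sketch of per-tensor scaling, assuming PyTorch's torch.float8_e4m3fn dtype; the function names are hypothetical and this is not part of the PR:

```python
import torch

# E4M3 has a much smaller dynamic range (max ~448) than E5M2 (max ~57344),
# so values must be scaled into range before casting, and the scale must be
# stored and reapplied at dequantization time.
def quantize_kv_e4m3(kv: torch.Tensor):
    amax = kv.abs().max().clamp(min=1e-12)  # per-tensor calibration statistic
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (kv * scale).to(torch.float8_e4m3fn), scale

def dequantize_kv_e4m3(kv_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return kv_fp8.to(torch.float32) / scale
```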

python/sglang/srt/mem_cache/memory_pool.py (review thread, outdated and resolved)
if self.server_args.kv_cache_dtype == "auto":
    self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
    if self.server_args.disable_flashinfer or self.server_args.enable_mla:

Review comment from a Member:

Currently only FlashInfer is supported, not Triton, due to an insufficient shared memory (smem) issue. This needs to be fixed in another PR.
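For readers following the diff, here is a hedged sketch of the dtype-selection logic rewritten as a standalone function (the function name, the raised errors, and the torch.float8_e5m2 assignment are illustrative assumptions, not necessarily the PR's exact code):

```python
import torch

def select_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype,
                          disable_flashinfer: bool, enable_mla: bool) -> torch.dtype:
    """Illustrative standalone version of the branch shown in the hunk above."""
    if kv_cache_dtype == "auto":
        return model_dtype  # keep the model's compute dtype for the KV cache
    if kv_cache_dtype == "fp8_e5m2":
        if disable_flashinfer or enable_mla:
            # fp8_e5m2 currently requires the FlashInfer backend and does not support MLA.
            raise ValueError("fp8_e5m2 kv cache requires FlashInfer and does not support MLA")
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")

# e.g. select_kv_cache_dtype("fp8_e5m2", torch.float16, False, False) -> torch.float8_e5m2
```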

if cache_v.dtype != self.dtype:
    cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
    self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)

Review comment from a Member:

Workaround for float8_e5m2: store as torch.uint8, because Tensor index_put is not implemented for torch.float8_e5m2.
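A minimal, self-contained sketch of the uint8 view trick described above (the buffer shapes and variable names are arbitrary; only the .view() round-trip is the point):

```python
import torch

# The KV pool is allocated with a uint8 storage dtype; float8_e5m2 values are
# written and read through .view(), which reinterprets the same 1-byte elements.
k_buffer = torch.zeros(1024, 8, 128, dtype=torch.uint8)  # storage buffer
cache_k = torch.randn(2, 8, 128).to(torch.float8_e5m2)   # incoming keys
loc = torch.tensor([3, 7])                                # destination slots

# Indexed assignment of a float8_e5m2 tensor would need an index_put kernel for
# that dtype (not implemented at the time), so write the raw bytes via a uint8 view.
k_buffer[loc] = cache_k.view(torch.uint8)

# Read back by reinterpreting the stored bytes as float8_e5m2 again.
k_read = k_buffer[loc].view(torch.float8_e5m2)
```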

Co-authored-by: Lianmin Zheng <[email protected]>
@zhyncs zhyncs enabled auto-merge (squash) August 25, 2024 17:37
@merrymercy merrymercy changed the title Support fp8 e5m2 kv cache with flashinfer [Feature] Support fp8 e5m2 kv cache with flashinfer Aug 25, 2024
@merrymercy merrymercy merged commit 2c615d1 into sgl-project:main Aug 26, 2024
4 of 5 checks passed
@ispobock ispobock mentioned this pull request Sep 1, 2024
@qeternity (Contributor) commented:

Sorry to dig this up, but are we suggesting that the fp8 KV cache increased accuracy on both MMLU and GSM8K? Are we sure we don't have those values in the table reversed?

@ispobock (Collaborator, Author) commented Oct 12, 2024

@qeternity In the previous evaluation, I tested GSM8K with only 200 questions (the default in the benchmark script), so that result may not have been reliable enough. I have now evaluated on the full datasets and updated the results in the table.
