Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Surpport kv cache int8/int4 for triton backend #1644

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

yuguo-Jack
Copy link

Motivation

surpport kv cache int8/int4 for triton backend

Modifications

use c8 cmd:
--kv-cache-dtype int8
use c4 cmd:
--kv-cache-dtype int4 --kvint4-groupsize 32

Copy link
Contributor

@merrymercy merrymercy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! They look good and i have a few comments.

python/sglang/srt/mem_cache/memory_pool.py Show resolved Hide resolved
python/sglang/srt/mem_cache/memory_pool.py Outdated Show resolved Hide resolved
@merrymercy
Copy link
Contributor

Hi @yuguo-Jack Can you add an end-to-end accuracy test, similar to this one?

def test_mmlu(self):

self.k_buffer = [
torch.empty(
(size + 1, head_num, head_dim // 2), dtype=torch.int8, device="cuda"
)
Copy link
Contributor

@liangan1 liangan1 Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except for the 'cuda', 'xpu' is also supported in the main branch. Change to use device=device as the original code? Also apply to other codes.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

python/sglang/srt/mem_cache/memory_pool.py Outdated Show resolved Hide resolved
@merrymercy
Copy link
Contributor

@yuguo-Jack Can you follow up on this? This is a high priority item and we would like to merge this as soon as possible once an accuracy unit test is added

@@ -273,6 +281,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"gptq_marlin",
"awq_marlin",
"bitsandbytes",
"compressed-tensors",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this feature only comes with compressed-tensors?
Can we decouple a bit, and add torchao's INT4/INT8 support too?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can reuse a8w8 linear in vllm

@merrymercy
Copy link
Contributor

@yuguo-Jack Can you resolve the conflicts and add some correctness tests?

  1. Kernel-level unit tests. Make sure to compare it against a reference implementation, which can be a pytorch implementation or a triton implementation in fp16.
  2. End-to-end unit tests. Test the MMLU score.

@merrymercy
Copy link
Contributor

Let me know when the tests are added. I will review it again.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants