
Performance Degradation when Using MInference with Qwen2-7B-Instruct Model #71

Open
yumingfan-0219 opened this issue Aug 26, 2024 · 1 comment

yumingfan-0219 commented Aug 26, 2024

Describe the issue

Hello,

I am encountering an unexpected performance issue while using the MInference library with the Qwen/Qwen2-7B-Instruct model. I have followed the example provided in run_hf.py and made minimal changes to adapt it to the Qwen model. Here is the modified code snippet:

import time

from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

prompt = "Hello, my name is"
model_name = "Qwen/Qwen2-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cuda",
)

# Apply the MInference patch to the model.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

start_time = time.time()
batch_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**batch_inputs, max_length=10)
end_time = time.time()

# Note: this counts prompt tokens as well as generated tokens.
elapsed_time = end_time - start_time
tokens_per_second = len(outputs[0]) / elapsed_time
print(f"Tokens per second with MInference: {tokens_per_second}")

With MInference enabled, I measure approximately 10 tokens/s. When I disable the MInference patch (commenting out the line model = minference_patch(model)), the rate increases to about 26 tokens/s. I would like to understand why there is such a significant performance drop when using MInference, and whether this is a mistake in my usage or a potential bug in the library.
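For reference, the figure above counts prompt tokens as well as generated ones, and max_length=10 leaves very few decode steps. A rough sketch of a measurement that warms up first and times only the generated tokens (reusing the model, tokenizer, and prompt from the snippet above):

import time
import torch

batch_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
prompt_len = batch_inputs["input_ids"].shape[1]

# Warm-up run so one-time setup (kernel compilation, caches) is excluded.
model.generate(**batch_inputs, max_new_tokens=8)
torch.cuda.synchronize()

start_time = time.time()
outputs = model.generate(**batch_inputs, max_new_tokens=128)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time

# Count only tokens produced after the prompt.
generated = outputs[0].shape[0] - prompt_len
print(f"Decode throughput: {generated / elapsed_time:.1f} tokens/s")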

I would appreciate any guidance or insights you can provide to help me resolve this issue.

Thank you for your time and assistance.

yumingfan-0219 added the question label Aug 26, 2024
iofu728 self-assigned this Aug 27, 2024
iofu728 (Contributor) commented Aug 27, 2024

Hi @yumingfan-0219, thanks for your interest in MInference.

MInference performs less efficiently than full attention in short contexts for two main reasons:

  1. In short contexts, the cost of dynamically building the sparse index is high relative to the cost of full attention itself.
  2. MInference computes the attention heads sequentially, and the use of Triton in the kernels can reduce throughput and increase latency.

You can find more latency benchmark results at https://github.com/microsoft/MInference/tree/main/experiments#minference-benchmark-experiments.
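One way to see where the crossover happens on a given GPU is to sweep the context length and compare the patched and unpatched models. A rough sketch of such a sweep (it reuses the MInference patching call from the snippet above, builds long prompts by repetition, and times the prefill-dominated first token; the exact lengths you can fit depend on your hardware and memory):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "Qwen/Qwen2-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def time_to_first_token(model, n_tokens):
    # Build a prompt of roughly n_tokens tokens by repetition.
    text = "hello " * n_tokens
    inputs = tokenizer(text, return_tensors="pt",
                       truncation=True, max_length=n_tokens).to("cuda")
    torch.cuda.synchronize()
    start = time.time()
    model.generate(**inputs, max_new_tokens=1)  # prefill-dominated
    torch.cuda.synchronize()
    return time.time() - start

for patched in (False, True):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="cuda"
    )
    if patched:
        model = MInference("minference", model_name)(model)
    for n in (1_000, 8_000, 32_000):
        print(f"patched={patched} ctx={n}: {time_to_first_token(model, n):.2f}s")
    del model
    torch.cuda.empty_cache()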
