Skip to content

Commit

Permalink
fix bugs in vllm patch
Browse files Browse the repository at this point in the history
  • Loading branch information
liyucheng09 committed Jul 2, 2024
1 parent 2c48613 commit ceba21b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
13 changes: 4 additions & 9 deletions minference/modules/minference_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,20 +784,15 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head `n_rep` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    the input goes from (seqlen, num_key_value_heads, head_dim) to
    (seqlen, num_key_value_heads * n_rep, head_dim), so grouped-query KV
    heads line up with the attention heads. Note the tensor here carries
    no batch dimension (vLLM-style flattened layout).

    Args:
        hidden_states: key or value tensor of shape
            (seqlen, num_head, head_dim).
        n_rep: number of query heads per key/value head.

    Returns:
        Tensor of shape (seqlen, num_head * n_rep, head_dim); the input
        itself (no copy) when ``n_rep == 1``.
    """
    sqlen, num_head, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat — return the original tensor unchanged.
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it to n_rep
    # (expand allocates no memory), then fold it into the head axis so
    # each head's copies are adjacent (repeat_interleave ordering).
    hidden_states = hidden_states[:, :, None, :].expand(sqlen, num_head, n_rep, head_dim)
    return hidden_states.reshape(sqlen, num_head * n_rep, head_dim)

def minference_prefill_func(
q, k, v,

):
# (seq_len, num_heads, head_size)
if q.size(-2) != k.size(-2):
Expand Down
2 changes: 1 addition & 1 deletion minference/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ def update_module(m):

llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)

print("Patched model for minference with VLLM..")
print("Patched model for minference with vLLM..")
return llm


Expand Down

0 comments on commit ceba21b

Please sign in to comment.