diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index c0c99eb..eb0d612 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -784,20 +784,15 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}
 
-        def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-            """
-            This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-            num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-            """
-            slen, num_key_value_heads, head_dim = hidden_states.shape
+        def repeat_kv(hidden_states, n_rep):
+            sqlen, num_head, head_dim = hidden_states.shape
             if n_rep == 1:
                 return hidden_states
-            hidden_states = hidden_states[:, None, :, :].expand(slen, n_rep, num_key_value_heads, head_dim)
-            return hidden_states.reshape(slen, num_key_value_heads * n_rep, head_dim)
+            hidden_states = hidden_states[:, :, None, :].expand(sqlen, num_head, n_rep, head_dim)
+            return hidden_states.reshape(sqlen, num_head * n_rep, head_dim)
 
         def minference_prefill_func(
             q, k, v,
-        ): # (seq_len, num_heads, head_size)
             if q.size(-2) != k.size(-2):
diff --git a/minference/patch.py b/minference/patch.py
index 5bb7fbe..211b166 100644
--- a/minference/patch.py
+++ b/minference/patch.py
@@ -1090,7 +1090,7 @@ def update_module(m):
     llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)
 
-    print("Patched model for minference with VLLM..")
+    print("Patched model for minference with vLLM..")
     return llm
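For context on the first hunk: vLLM hands the attention backend flattened tensors of shape (seq_len, num_kv_heads, head_dim) (the "(seq_len, num_heads, head_size)" layout noted in the diff's own comments), while the removed helper was copied from Hugging Face code that assumes a 4-D (batch, num_kv_heads, seq_len, head_dim) layout. With the old expand along axis 1, the block of KV heads was tiled n_rep times ([h0, h1, ..., h0, h1, ...]) instead of each head being repeated consecutively, so under grouped-query attention query heads were paired with the wrong KV heads. Below is a minimal standalone sketch, not part of the patch, that reproduces the new helper and checks it against torch.repeat_interleave; the shapes and variable names are illustrative only.

    import torch

    def repeat_kv(hidden_states, n_rep):
        # Mirrors the helper added in the patch: input is (seq_len, num_kv_heads, head_dim),
        # output is (seq_len, num_kv_heads * n_rep, head_dim) with each KV head repeated
        # n_rep times so it lines up with the query heads of a GQA model.
        sqlen, num_head, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :].expand(sqlen, num_head, n_rep, head_dim)
        return hidden_states.reshape(sqlen, num_head * n_rep, head_dim)

    # Hypothetical shapes for illustration: 16 tokens, 4 KV heads, 32 query heads (n_rep = 8).
    k = torch.randn(16, 4, 128)
    out = repeat_kv(k, n_rep=8)
    assert out.shape == (16, 32, 128)
    # The expand/reshape is equivalent to repeat_interleave along the head dimension,
    # i.e. head i is repeated 8 times consecutively rather than the head block being tiled.
    assert torch.equal(out, torch.repeat_interleave(k, repeats=8, dim=1))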