From 337f67ae9a6b9fbbd08536992bd960cfe00d57d1 Mon Sep 17 00:00:00 2001 From: Erez Schwartz Date: Mon, 18 Mar 2024 12:48:08 +0000 Subject: [PATCH 001/110] Merged in jurassic-2.5 (pull request #1) BA-78554: Jurassic 2.5 * worked on jurasic2.5 configuration file, updated jurassic2_5 modeling file to support alternating experts/attn layers * finished working the forward pass of jurassic3.py * finished working the forward pass of jurassic3.py * finished working the forward pass of jurassic3.py * jurassic_3 modeling file works, uses dummy weights initialized by "dummy" flag. Tokenizer raises issues, for now copying the mixtral tokenizer * changed default tokenizer vocab values, loading of custom .pt weight files works. * removed notebook * merging master to jurassic-2.5 to reset head * Merge branch 'master' into jurassic-2.5 * align to master Approved-by: Tomer Asida Approved-by: Mor Zusman --- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/jurassic3.py | 478 +++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 3 + vllm/transformers_utils/configs/jurassic3.py | 131 +++++ 5 files changed, 614 insertions(+) create mode 100644 vllm/model_executor/models/jurassic3.py create mode 100644 vllm/transformers_utils/configs/jurassic3.py diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17fc9705680..3b79ffa3b72 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -54,6 +54,7 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "Jurassic3ForCausalLM": ("jurassic3", "Jurassic3ForCausalLM") } # Architecture -> type. 
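The registry entry added above maps the architecture string reported in a checkpoint's config.json ("architectures") to a (module name, class name) pair that vLLM imports lazily. A simplified, standalone sketch of that resolution (illustrative only; the real lookup is ModelRegistry.load_model_cls in this same file):

    import importlib
    from typing import Optional, Type

    from torch import nn

    _MODELS = {"Jurassic3ForCausalLM": ("jurassic3", "Jurassic3ForCausalLM")}

    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in _MODELS:
            return None
        module_name, model_cls_name = _MODELS[model_arch]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, model_cls_name, None)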
diff --git a/vllm/model_executor/models/jurassic3.py b/vllm/model_executor/models/jurassic3.py new file mode 100644 index 00000000000..5c8920bd854 --- /dev/null +++ b/vllm/model_executor/models/jurassic3.py @@ -0,0 +1,478 @@ +# coding=utf-8 + +"""Inference-only Jurassic model.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from vllm.transformers_utils.configs.jurassic3 import Jurassic3Config +from vllm.config import LoRAConfig +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Jurassic3MoE(nn.Module): + """A tensor-parallel MoE implementation for Jurassic3 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
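    Concretely, each expert's w1 and w3 projections are stacked into a single
    `ws` tensor of shape (num_experts, 2 * intermediate_size // tp_size,
    hidden_size), and w2 goes into `w2s` of shape
    (num_experts, hidden_size, intermediate_size // tp_size); every rank holds
    only its own intermediate-size shard (see `weight_loader` below). For
    example, with intermediate_size=14336 and tp_size=2, each rank stores a
    (num_experts, 14336, hidden_size) slice of `ws`.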
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if self.num_total_experts > 1: + # init expert router iff this layer has multiple experts + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + if self.num_total_experts > 1: + router_logits, _ = self.gate(hidden_states) + else: + router_logits = torch.ones([hidden_states.shape[0], 1], device=hidden_states.device, + dtype=hidden_states.dtype) + + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=False, # Mixtral normalize the expert probs to 1. We don't! + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_size) + + +class Jurassic3Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + use_positional_embeddings: bool = False, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + self.use_positional_embeddings = use_positional_embeddings + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + if self.use_positional_embeddings: + # define positional embeddings conditioned on flag + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # TODO - add embedding flag + if self.use_positional_embeddings: + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Jurassic3DecoderLayer(nn.Module): + + def __init__( + self, + config: Jurassic3Config, + is_attn_layer: bool, + is_expert_layer: bool, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + + self.is_attn_layer = is_attn_layer + self.is_expert_layer = is_expert_layer + + if self.is_attn_layer: + self.self_attn = Jurassic3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + linear_method=linear_method) + else: + # TODO - Mor - add mamba implementation here + raise NotImplementedError + + actual_num_experts = config.num_experts if self.is_expert_layer else 1 + actual_num_experts_per_tok = config.num_experts_per_tok if self.is_expert_layer else 1 + + self.block_sparse_moe = Jurassic3MoE( + num_experts=actual_num_experts, + top_k=actual_num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + 
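        # First decoder layer: there is no incoming residual yet, so the raw
        # hidden states seed the residual stream. Later layers take the else
        # branch, where the fused RMSNorm returns both the normalized
        # activations and the updated residual.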
else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class Jurassic3Model(nn.Module): + + def __init__( + self, + config: Jurassic3Config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + # init each model layer, decide if it's mamba/attention and has experts and pass it down + + module_list = [] + for i in range(config.num_hidden_layers): + is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False + is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False + + module_list.append( + Jurassic3DecoderLayer( + config, + is_attn_layer=is_attn, + is_expert_layer=is_expert, + linear_method=linear_method + ) + ) + + self.layers = nn.ModuleList(module_list) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Jurassic3ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: Jurassic3Config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = Jurassic3Model(config, + linear_method, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = 
self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=True): # erez - might need to change later to False + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
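                    # This final branch covers weights that are neither stacked
                    # QKV shards nor expert weights (norms, embeddings, o_proj,
                    # and so on); they go through each parameter's own
                    # weight_loader or default_weight_loader. As an illustration
                    # of the expert mapping above, a checkpoint name like
                    # "...experts.3.w1.weight" (hypothetical) matches
                    # ("ws", "experts.3.w1.weight", 3) and lands in the first
                    # shard_size rows of ws[3] via Jurassic3MoE.weight_loader.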
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1756c91a612..bd40ec9b906 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -12,6 +12,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, + "jurassic3": Jurassic3Config } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 0e486928824..045c24cbd01 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,10 +7,13 @@ from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.jurassic3 import Jurassic3Config + __all__ = [ "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig", + "Jurassic3Config", ] diff --git a/vllm/transformers_utils/configs/jurassic3.py b/vllm/transformers_utils/configs/jurassic3.py new file mode 100644 index 00000000000..87140e7d4ee --- /dev/null +++ b/vllm/transformers_utils/configs/jurassic3.py @@ -0,0 +1,131 @@ +""" Jurassic3 model configuration""" +from transformers.configuration_utils import PretrainedConfig +from transformers import AutoConfig + + +class Jurassic3Config(PretrainedConfig): + r""" + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the Jurassic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`JurassicModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + # max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + # The maximum sequence length that this model might ever be used with. Jurassic's sliding window attention + # allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + use_positional_embeddings (`bool`, *optional, default False) + flag indicating whether to use positional embeddings or not + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_experts (`int`, *optional*, defaults to 16): + Number of experts per Sparse MLP layer. + expert_layer_period (`int`, *optional*, defaults to 2) + Once in this many layers, we will have an expert layer + expert_layer_offset(`int`, *optional*, defaults to 1) + The first layer index that contains an expert mlp layer + attn_layer_period (`int`, *optional*, defaults to 8) + Once in this many layers, we will have a vanilla attention layer + attn_layer_offset(`int`, *optional*, defaults to 4) + The first layer index that contains a vanilla attention mlp layer + """ + + model_type = "jurassic3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + use_positional_embeddings=False, + rope_theta=1e6, + sliding_window=None, + num_experts_per_tok=2, + num_experts=16, + expert_layer_offset=1, + expert_layer_period=2, + attn_layer_period=8, + attn_layer_offset=4, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_positional_embeddings = use_positional_embeddings + self.rope_theta = rope_theta + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.expert_layer_period = expert_layer_period + self.expert_layer_offset = expert_layer_offset + self.attn_layer_period = attn_layer_period + self.attn_layer_offset = attn_layer_offset + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +AutoConfig.register('jurassic3', Jurassic3Config) From 0330e14253a8286b76420b80ef44c321deb4f046 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 10:14:30 +0300 Subject: [PATCH 002/110] Merged in jamba-3 (pull request #4) BA-78760: Jamba * Add support 
for n concat and splitting * change naming * input_metadata is a dict list now in order to pass "n" * clean up code from unecessary changes and prints * Remove kv cache allocation in case of mamba layer * Add the considerations of mamba layer cache into the num of blocks calculation * Delete mamba cache after profile * Remove prints * Cleaning * - and not _ for requirements Approved-by: Tomer Asida --- pyproject.toml | 2 + requirements-common.txt | 8 +- vllm/engine/llm_engine.py | 12 +++ vllm/model_executor/input_metadata.py | 55 +++++++++++++ vllm/model_executor/models/jurassic3.py | 103 +++++++++++++++++------- vllm/worker/cache_engine.py | 4 + vllm/worker/model_runner.py | 31 ++++++- vllm/worker/worker.py | 7 ++ 8 files changed, 190 insertions(+), 32 deletions(-) create mode 100644 vllm/model_executor/input_metadata.py diff --git a/pyproject.toml b/pyproject.toml index b870a4b8589..342011d0d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,8 @@ requires = [ "setuptools >= 49.4.0", "torch == 2.2.1", "wheel", + "mamba-ssm", + "causal-conv1d" ] build-backend = "setuptools.build_meta" diff --git a/requirements-common.txt b/requirements-common.txt index c1614d2537b..da08df3721e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -10,8 +10,12 @@ fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 -tiktoken == 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.9.3 -outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +pynvml == 11.5.0 +triton >= 2.1.0 +outlines == 0.0.34 +tiktoken == 0.6.0 # Required for DBRX tokenizer +mamba-ssm +causal-conv1d diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f06c1d18ace..17421c85ed7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -651,8 +651,20 @@ def _process_model_outputs( self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. + finished_seq_groups_req_ids = [ + seq_group.request_id + for seq_group in self.scheduler.running + if seq_group.is_finished() + ] + + if len(finished_seq_groups_req_ids) > 0: + self._run_workers( + "release_mamba_cache", + finished_seq_groups_req_ids= finished_seq_groups_req_ids, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) self.scheduler.free_finished_seq_groups() + # Create the outputs. request_outputs: List[RequestOutput] = [] for scheduled_seq_group in scheduled_seq_groups: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py new file mode 100644 index 00000000000..3a7b7f88241 --- /dev/null +++ b/vllm/model_executor/input_metadata.py @@ -0,0 +1,55 @@ +from typing import Optional + +import torch + + +class InputMetadata: + """Metadata for input sequences. Used in PagedAttention. + + Args: + prompt_lens: Lengths of prompts. + slot_mapping: The address to write the new KV to of each token. + max_context_len: The maximum context length. + context_lens: the length of attention context for each sequence. + block_tables: The block tables. (Seq id -> list of physical block) + kv_cache_dtype: Data type to store kv cache. 
+ """ + + def __init__( + self, + is_prompt: bool, + slot_mapping: torch.Tensor, + prompt_lens: Optional[torch.Tensor], + max_seq_len: Optional[int], + start_loc: Optional[torch.Tensor], + max_context_len: Optional[int], + context_lens: Optional[torch.Tensor], + block_tables: Optional[torch.Tensor], + use_cuda_graph: bool, + kv_cache_dtype: str, + ) -> None: + self.is_prompt = is_prompt + self.prompt_lens = prompt_lens + self.max_seq_len = max_seq_len + self.start_loc = start_loc + self.max_context_len = max_context_len + self.slot_mapping = slot_mapping + self.context_lens = context_lens + self.block_tables = block_tables + self.use_cuda_graph = use_cuda_graph + self.kv_cache_dtype = kv_cache_dtype + + # Set during the execution of the first attention op. + # FIXME(woosuk): This is a hack. + self.attn_bias = None + self.mamba_metadata = None + + def __repr__(self) -> str: + return ("InputMetadata(" + f"is_prompt={self.is_prompt}, " + f"max_context_len={self.max_context_len}, " + f"slot_mapping={self.slot_mapping}, " + f"context_lens={self.context_lens}, " + f"block_tables={self.block_tables}, " + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/model_executor/models/jurassic3.py b/vllm/model_executor/models/jurassic3.py index 5c8920bd854..6d2ec714c59 100644 --- a/vllm/model_executor/models/jurassic3.py +++ b/vllm/model_executor/models/jurassic3.py @@ -5,6 +5,7 @@ import torch from torch import nn +import os from vllm.transformers_utils.configs.jurassic3 import Jurassic3Config from vllm.config import LoRAConfig @@ -29,6 +30,8 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.utils.generation import InferenceParams KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -130,17 +133,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_size) -class Jurassic3Attention(nn.Module): +class Jurassic3Mamba(nn.Module): + def __init__(self, hidden_size: int, layer_idx: int) -> None: + super().__init__() + self.layer_idx = layer_idx + self.mamba = Mamba(d_model=hidden_size, layer_idx=layer_idx) + + def forward(self, hidden_states: torch.Tensor, cache = None): + max_seqlen = int(os.environ.get("MAMBA_MAX_SEQLEN", "2048")) + inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=hidden_states.shape[0]) + if cache is not None: + inference_params.key_value_memory_dict[self.layer_idx] = cache + res = self.mamba(hidden_states, inference_params=inference_params) + return res, inference_params.key_value_memory_dict - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - use_positional_embeddings: bool = False, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: +class Jurassic3Attention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + use_positional_embeddings: bool = False, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -217,18 +235,19 @@ def forward( class Jurassic3DecoderLayer(nn.Module): - def __init__( - self, - config: Jurassic3Config, - is_attn_layer: bool, - 
is_expert_layer: bool, - linear_method: Optional[LinearMethodBase] = None, + self, + config: Jurassic3Config, + is_attn_layer: bool, + is_expert_layer: bool, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None ) -> None: super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) + self.layer_idx = layer_idx self.is_attn_layer = is_attn_layer self.is_expert_layer = is_expert_layer @@ -241,10 +260,10 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method) + linear_method=linear_method, + ) else: - # TODO - Mor - add mamba implementation here - raise NotImplementedError + self.mamba = Jurassic3Mamba(hidden_size=self.hidden_size,layer_idx=layer_idx) actual_num_experts = config.num_experts if self.is_expert_layer else 1 actual_num_experts_per_tok = config.num_experts_per_tok if self.is_expert_layer else 1 @@ -272,14 +291,40 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + if self.is_attn_layer: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + else: + cache = None + if not input_metadata.is_prompt: + for mamba_metadata in input_metadata.mamba_metadata: + # check if batch size of cache fits "n" + if mamba_metadata["cache"][self.layer_idx][0].shape[0] < mamba_metadata["n"]: + k_cache = mamba_metadata["cache"][self.layer_idx][0].repeat_interleave(mamba_metadata["n"],dim=0) + v_cache = mamba_metadata["cache"][self.layer_idx][1].repeat_interleave(mamba_metadata["n"],dim=0) + mamba_metadata["cache"][self.layer_idx] = (k_cache,v_cache) + + # mamba requires concatenated cache + if len(input_metadata.mamba_metadata) > 1: + k_cache = torch.concat([req["cache"][self.layer_idx][0] for req in input_metadata.mamba_metadata],dim=0) + v_cache = torch.concat([req["cache"][self.layer_idx][1] for req in input_metadata.mamba_metadata],dim=0) + cache = (k_cache,v_cache) + + hidden_states ,cache = self.mamba(hidden_states, cache=cache) + + sample_id = 0 + # split cache back to individual requests + for req_mamba_metadata in input_metadata.mamba_metadata: + n = req_mamba_metadata["n"] if not input_metadata.is_prompt else 1 + req_mamba_metadata["cache"][self.layer_idx] = (cache[self.layer_idx][0][sample_id:sample_id+n] + ,cache[self.layer_idx][1][sample_id:sample_id+n]) + sample_id += n + # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -289,7 +334,6 @@ def forward( class Jurassic3Model(nn.Module): - def __init__( self, config: Jurassic3Config, @@ -322,7 +366,8 @@ def __init__( config, is_attn_layer=is_attn, is_expert_layer=is_expert, - linear_method=linear_method + layer_idx=i, + linear_method=linear_method, ) ) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c34ee064862..9d877a7c4fc 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -89,6 +89,10 @@ def get_cache_block_size( head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = 
model_config.get_num_layers(parallel_config) + is_mamba = model_config.hf_config.model_type == "jurassic3" + if is_mamba: + attention_period = model_config.hf_config.attn_layer_period + num_layers = num_layers // attention_period key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7dbe14ead09..2e5413a9d64 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,6 +6,7 @@ import numpy as np import torch import torch.nn as nn +from collections import defaultdict from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) @@ -149,6 +150,7 @@ def __init__( self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config + self.mamba_cache = defaultdict(lambda: {}) self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) @@ -811,7 +813,7 @@ def prepare_input_tensors( def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[torch.Tensor], + kv_caches: List[torch.Tensor] ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input @@ -845,6 +847,21 @@ def execute_model( if not sampling_metadata.perform_sampling: return None + mamba_metadata = self._get_mamba_caches_by_seq_group(seq_group_metadata_list) + input_metadata.mamba_metadata = mamba_metadata # list of caches + + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata + ) + + if self.is_driver_worker: + for idx, seq_group_metadata in enumerate(seq_group_metadata_list): + request_id = seq_group_metadata.request_id + self.mamba_cache[request_id] = input_metadata.mamba_metadata[idx]["cache"] + # Sample the next token. output = self.model.sample( logits=logits, @@ -852,6 +869,17 @@ def execute_model( ) return output + def _get_mamba_caches_by_seq_group( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ): + if seq_group_metadata_list is None: + return [] + return [{ + "cache":self.mamba_cache[seq.request_id], + "n":seq.sampling_params.n, + } for seq in seq_group_metadata_list] + @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. 
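The hunks above and below keep Mamba state per request rather than in the paged KV cache: ModelRunner.mamba_cache maps request_id to a per-layer cache, _get_mamba_caches_by_seq_group packs it (plus the sampling parameter n) into the input metadata, the driver worker writes the updated caches back after the forward pass, and Worker.release_mamba_cache drops entries for finished sequence groups. A minimal standalone sketch of that bookkeeping (names here are illustrative, not vLLM's API):

    from collections import defaultdict
    from typing import Dict, Iterable, List, Tuple

    import torch

    # layer_idx -> (conv_state, ssm_state)
    LayerCache = Dict[int, Tuple[torch.Tensor, torch.Tensor]]

    class MambaCacheBook:
        def __init__(self) -> None:
            # request_id -> per-layer Mamba cache; empty until prefill fills it
            self._store: Dict[str, LayerCache] = defaultdict(dict)

        def gather(self, request_ids: List[str], ns: List[int]) -> List[dict]:
            # analogous to _get_mamba_caches_by_seq_group: one entry per group
            return [{"cache": self._store[rid], "n": n}
                    for rid, n in zip(request_ids, ns)]

        def scatter(self, request_ids: List[str], batch: List[dict]) -> None:
            # analogous to the driver-worker write-back after execute_model
            for rid, entry in zip(request_ids, batch):
                self._store[rid] = entry["cache"]

        def release(self, finished_request_ids: Iterable[str]) -> None:
            # analogous to Worker.release_mamba_cache for finished groups
            for rid in finished_request_ids:
                self._store.pop(rid, None)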
@@ -917,6 +945,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() + self.mamba_cache = defaultdict(lambda: {}) return def remove_all_loras(self) -> bool: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82491c6df66..afa57b244a6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -203,6 +203,13 @@ def cache_swap( if blocks_to_copy: self.cache_engine.copy(blocks_to_copy) + + def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.model_runner.mamba_cache: + del self.model_runner.mamba_cache[req_id] + + @torch.inference_mode() def execute_model( self, From 07cc899c7c0ff2292cf3ac17aff0d056ecd68f26 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Apr 2024 15:47:21 +0300 Subject: [PATCH 003/110] Jamba mamba (#3) * Remove assertion * adapting jamba vllm to changes after hf release, working on weight loading in modeling file * splitting the JambaDecoderLayer to JambaMambaDecoderLayer and JambaAttentionDecoderLayer * weight loading from hf checkpoint supposedly works, might be a mixup in the MoE between the gated and non-gated weights * Add mamba from jamba modeling file * Remove slow forward * Modifications to mamba_mixer * Save changes, WIP * Fix cache placement * Debugging * Additions and logging * Jamba with mamba cache handling * Clean up * Another cleanup * Use vllm's RMSNorm instead of JambaRMSNorm, Thier implementation is with fused kernel * Clean up and orginization of the objects to handle the mamba cache * Shorten the code for kv cache mem * Move cache handling inside the Mixer * Add mamba to the wheel requirements * Add mamba to the requirements script * Add mamba_metadata * Add to __init__ __all__ * Revert 2 commits ad1a3db 'Add mamba to the requirements script' 75ed2c8 'Add mamba to the wheel requirements' * Clean up * Naming * Apply whitespace suggestions from code review * pass tie_word_embeddings to PretrainedConfig init * Replace repeat with expand as expand doesn't require more mem * Allocate really small cache if needed , don't use meta * Fix for expanded --------- Co-authored-by: Mor Zusman Co-authored-by: Erez Schwartz Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/__init__.py | 4 + vllm/model_executor/input_metadata.py | 8 +- vllm/model_executor/mamba_metadata.py | 30 + vllm/model_executor/models/__init__.py | 24 +- vllm/model_executor/models/jamba.py | 674 ++++++++++++++++++ vllm/model_executor/models/jurassic3.py | 523 -------------- vllm/transformers_utils/config.py | 2 +- vllm/transformers_utils/configs/__init__.py | 4 +- .../configs/{jurassic3.py => jamba.py} | 46 +- vllm/worker/cache_engine.py | 3 +- vllm/worker/model_runner.py | 71 +- vllm/worker/worker.py | 9 +- 12 files changed, 825 insertions(+), 573 deletions(-) create mode 100644 vllm/model_executor/mamba_metadata.py create mode 100644 vllm/model_executor/models/jamba.py delete mode 100644 vllm/model_executor/models/jurassic3.py rename vllm/transformers_utils/configs/{jurassic3.py => jamba.py} (81%) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index fb98f4a6b46..1eedd593570 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,7 +1,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.mamba_metadata import 
MambaCacheParams, RequestInfo, MambaCache __all__ = [ "SamplingMetadata", "set_random_seed", + "MambaCacheParams", + "RequestInfo", + "MambaCache", ] diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 3a7b7f88241..c6621864e8a 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,7 +1,9 @@ -from typing import Optional +from typing import Dict, List, Optional import torch +from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo + class InputMetadata: """Metadata for input sequences. Used in PagedAttention. @@ -27,6 +29,7 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, + requests_info: Optional[List[RequestInfo]] = None ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -42,7 +45,8 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None - self.mamba_metadata = None + self.mamba_cache_batch: List[MambaCache] = [] + self.requests_info = requests_info def __repr__(self) -> str: return ("InputMetadata(" diff --git a/vllm/model_executor/mamba_metadata.py b/vllm/model_executor/mamba_metadata.py new file mode 100644 index 00000000000..225f6016ece --- /dev/null +++ b/vllm/model_executor/mamba_metadata.py @@ -0,0 +1,30 @@ +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Optional, Tuple +import torch + +@dataclass +class MambaCacheParams: + seqlen_offset: int = 0 + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +@dataclass +class RequestInfo: + request_id: str = '' + n: int = 1 + + +class MambaCache: + def __init__( + self, + request_info: RequestInfo, + layer_idx2mamba_cache: Optional[Dict[int, MambaCacheParams]] = None + ) -> None: + self.request_info = request_info + if layer_idx2mamba_cache is None: + self.layer_idx2mamba_cache = defaultdict(MambaCacheParams) + else: + self.layer_idx2mamba_cache = layer_idx2mamba_cache + diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3b79ffa3b72..aa7dc4b775e 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -31,8 +31,7 @@ "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - "LlavaForConditionalGeneration": - ("llava", "LlavaForConditionalGeneration"), + "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), @@ -54,7 +53,7 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - "Jurassic3ForCausalLM": ("jurassic3", "Jurassic3ForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM") } # Architecture -> type. @@ -67,17 +66,13 @@ # Models partially supported by ROCm. # Architecture -> Reason. 
_ROCM_PARTIALLY_SUPPORTED_MODELS = { - "Qwen2ForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", - "MistralForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", - "MixtralForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", + "Qwen2ForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", + "MistralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", + "MixtralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", } class ModelRegistry: - @staticmethod def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _OOT_MODELS: @@ -88,15 +83,16 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( f"Model architecture {model_arch} is not supported by " - "ROCm for now.") + "ROCm for now." + ) if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: logger.warning( f"Model architecture {model_arch} is partially supported " - "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] + ) module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") + module = importlib.import_module(f"vllm.model_executor.models.{module_name}") return getattr(module, model_cls_name, None) @staticmethod diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py new file mode 100644 index 00000000000..17a09f18dd3 --- /dev/null +++ b/vllm/model_executor/models/jamba.py @@ -0,0 +1,674 @@ +# coding=utf-8 + +"""Inference-only Jurassic model.""" +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import torch +from torch import conv_transpose3d, nn +import os +from vllm.model_executor.mamba_metadata import MambaCacheParams + +from vllm.transformers_utils.configs.jamba import JambaConfig +from transformers.activations import ACT2FN +from vllm.config import LoRAConfig +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + +KVCache = 
Tuple[torch.Tensor, torch.Tensor] + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class JambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: JambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.intermediate_size, + padding=self.conv_kernel_size - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.apply_inner_layernorms = config.mamba_inner_layernorms + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + # selective projection used to make dt, B and C input dependant + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if self.apply_inner_layernorms: + self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.C_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + else: + self.dt_layernorm = None + self.B_layernorm = None + self.C_layernorm = None + + def _apply_layernorms(self, dt, B, C): + if self.dt_layernorm is not None: + dt = self.dt_layernorm.forward(dt.contiguous()) + if self.B_layernorm is not None: + B = self.B_layernorm.forward(B.contiguous()) + if self.C_layernorm is not None: + C = self.C_layernorm.forward(C.contiguous()) + return dt, B, C + + def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_params.seqlen_offset > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_state.copy_(conv_states) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + time_step, B, C = self._apply_layernorms(time_step, B, C) + + # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed + # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized + # linear layers, and requires to call the forward pass directly. + # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)``` + dt_proj_bias = self.dt_proj.bias + self.dt_proj.bias = None + discrete_time_step = self.dt_proj(time_step).transpose(1, 2) + self.dt_proj.bias = dt_proj_bias + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_params.seqlen_offset > 0: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata): + if input_metadata.is_prompt: + batch_size = hidden_states.shape[0] + conv_cache = torch.zeros( + batch_size, + self.config.mamba_expand * self.config.hidden_size, + self.config.mamba_d_conv, + device=hidden_states.device, + dtype=hidden_states.dtype + ) + ssm_cache = torch.zeros( + batch_size, + self.config.mamba_expand * self.config.hidden_size, + self.config.mamba_d_state, + device=hidden_states.device, + dtype=hidden_states.dtype + ) + cache = MambaCacheParams(0, conv_cache, ssm_cache) + else: + for mamba_cache_request in input_metadata.mamba_cache_batch: + # check if batch size of cache fits "n" + n = mamba_cache_request.request_info.n + if mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[0] < n: + expanded_dims_conv = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[1:]) + conv_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.expand(*expanded_dims_conv) + expanded_dims_ssm = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.shape[1:]) + ssm_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.expand(*expanded_dims_ssm) + mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state = conv_state + mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state = ssm_state + + # mamba requires concatenated cache + conv_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].conv_state for req in input_metadata.mamba_cache_batch], dim=0) + ssm_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].ssm_state for req in input_metadata.mamba_cache_batch], dim=0) + cache = MambaCacheParams(1, conv_state, ssm_state) + hidden_states = self.mamba_forward(hidden_states, cache_params=cache) + + # split cache back to individual requests + sample_id = 0 + for req_mamba_metadata in input_metadata.mamba_cache_batch: + n = 1 if input_metadata.is_prompt else req_mamba_metadata.request_info.n + req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].conv_state=cache.conv_state[sample_id:sample_id + n] + req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].ssm_state=cache.ssm_state[sample_id:sample_id + n] + sample_id += n + + return hidden_states + + + + +class JambaMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
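    In this Jamba variant the per-expert projections appear in the checkpoint as
    gate_proj, up_proj and down_proj: gate_proj and up_proj are stacked into
    `ws`, down_proj maps to `w2s`, and the routing linear is named `router` here
    rather than `gate` (see `weight_loader` below).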
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if self.num_total_experts > 1: + # init expert router iff this layer has multiple experts + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("gate_proj.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("up_proj.weight"): + param_data[expert_id, shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("down_proj.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + if self.num_total_experts > 1: + router_logits, _ = self.router(hidden_states) + else: + router_logits = torch.ones([hidden_states.shape[0], 1], device=hidden_states.device, + dtype=hidden_states.dtype) + + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=False, # Mixtral normalize the expert probs to 1. We don't! 
+ inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_size) + + +class JambaMambaDecoderLayer(nn.Module): + def __init__(self, config: JambaConfig, actual_num_experts: int, actual_num_experts_per_tok: int ,layer_idx: int) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mamba = JambaMambaMixer(config, layer_idx) + self.moe = JambaMoE( + num_experts=actual_num_experts, + top_k=actual_num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward(self, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + **kwargs): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.mamba(hidden_states, input_metadata) + # Fully Connected + hidden_states, residual = self.pre_moe_layernorm( + hidden_states, residual) + hidden_states = self.moe(hidden_states) + return hidden_states, residual + + +class JambaAttentionDecoderLayer(nn.Module): + def __init__( + self, config: JambaConfig, actual_num_experts: int, actual_num_experts_per_tok: int ,layer_idx: int, linear_method: Optional[LinearMethodBase] = None + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
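Note on the head-count arithmetic the comments above rely on: with grouped-query attention under tensor parallelism, KV heads are either partitioned across ranks or replicated onto them, depending on which count is larger. The following is an illustrative sketch only (the helper name is invented, not part of the patch), mirroring the assertions in this constructor:

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Illustrative only: per-rank KV head count under tensor parallelism.
    if total_num_kv_heads >= tp_size:
        # Partition: each rank owns a disjoint subset of the KV heads.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Replicate: groups of ranks share copies of the same KV head.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# With the config default of 8 KV heads:
assert kv_heads_per_rank(8, tp_size=2) == 4    # partitioned, 4 heads per rank
assert kv_heads_per_rank(8, tp_size=16) == 1   # replicated, 2 ranks per head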
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + self.use_positional_embeddings = False + self.sliding_window = config.sliding_window + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + + self.moe = JambaMoE( + num_experts=actual_num_experts, + top_k=actual_num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention(self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # TODO - add embedding flag + if self.use_positional_embeddings: + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor]): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_moe_layernorm( + hidden_states, residual) + hidden_states = self.moe(hidden_states) + return hidden_states, residual + + +class JambaModel(nn.Module): + def __init__( + self, + config: JambaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + # init each model layer, decide if it's mamba/attention and has experts and pass it down + + module_list = [] + for i in range(config.num_hidden_layers): + is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False + is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False + + actual_num_experts = config.num_experts if is_expert else 1 + actual_num_experts_per_tok = config.num_experts_per_tok if is_expert 
else 1 + + if is_attn: + module_list.append(JambaAttentionDecoderLayer(config, + actual_num_experts=actual_num_experts, + actual_num_experts_per_tok=actual_num_experts_per_tok, + layer_idx=i, + linear_method=linear_method + )) + else: + module_list.append(JambaMambaDecoderLayer(config, + actual_num_experts=actual_num_experts, + actual_num_experts_per_tok=actual_num_experts_per_tok, + layer_idx=i + )) + + self.layers = nn.ModuleList(module_list) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + + hidden_states, residual = layer(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_caches[i], + input_metadata=input_metadata, + residual=residual) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class JambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: JambaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = JambaModel(config, + linear_method, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_experts) + for weight_name in ["down_proj", "up_proj", "gate_proj"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + 
load_format, + revision, + fall_back_to_pt=True): # erez - might need to change later to False + if "rotary_emb.inv_freq" in name: + continue + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/jurassic3.py b/vllm/model_executor/models/jurassic3.py deleted file mode 100644 index 6d2ec714c59..00000000000 --- a/vllm/model_executor/models/jurassic3.py +++ /dev/null @@ -1,523 +0,0 @@ -# coding=utf-8 - -"""Inference-only Jurassic model.""" -from typing import List, Optional, Tuple - -import torch -from torch import nn -import os - -from vllm.transformers_utils.configs.jurassic3 import Jurassic3Config -from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput -from mamba_ssm.modules.mamba_simple import Mamba -from mamba_ssm.utils.generation import InferenceParams - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class Jurassic3MoE(nn.Module): - """A tensor-parallel MoE implementation for Jurassic3 that shards each expert - across all ranks. - - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. 
- """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - ): - super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - if self.num_total_experts > 1: - # init expert router iff this layer has multiple experts - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - linear_method=None) - - self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_size = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - # router_logits: (batch * sequence_length, n_experts) - if self.num_total_experts > 1: - router_logits, _ = self.gate(hidden_states) - else: - router_logits = torch.ones([hidden_states.shape[0], 1], device=hidden_states.device, - dtype=hidden_states.dtype) - - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=False, # Mixtral normalize the expert probs to 1. We don't! 
- inplace=True) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(batch_size, sequence_length, - hidden_size) - - -class Jurassic3Mamba(nn.Module): - def __init__(self, hidden_size: int, layer_idx: int) -> None: - super().__init__() - self.layer_idx = layer_idx - self.mamba = Mamba(d_model=hidden_size, layer_idx=layer_idx) - - def forward(self, hidden_states: torch.Tensor, cache = None): - max_seqlen = int(os.environ.get("MAMBA_MAX_SEQLEN", "2048")) - inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=hidden_states.shape[0]) - if cache is not None: - inference_params.key_value_memory_dict[self.layer_idx] = cache - res = self.mamba(hidden_states, inference_params=inference_params) - return res, inference_params.key_value_memory_dict - -class Jurassic3Attention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - use_positional_embeddings: bool = False, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 - self.use_positional_embeddings = use_positional_embeddings - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - - if self.use_positional_embeddings: - # define positional embeddings conditioned on flag - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - - self.attn = PagedAttention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # TODO - add embedding flag - if self.use_positional_embeddings: - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class Jurassic3DecoderLayer(nn.Module): - def __init__( - self, - config: Jurassic3Config, - is_attn_layer: bool, - is_expert_layer: bool, - layer_idx: int, - linear_method: Optional[LinearMethodBase] = None - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.layer_idx = layer_idx - - self.is_attn_layer = is_attn_layer - self.is_expert_layer = is_expert_layer - - if self.is_attn_layer: - self.self_attn = Jurassic3Attention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - sliding_window=config.sliding_window, - linear_method=linear_method, - ) - else: - self.mamba = Jurassic3Mamba(hidden_size=self.hidden_size,layer_idx=layer_idx) - - actual_num_experts = config.num_experts if self.is_expert_layer else 1 - actual_num_experts_per_tok = config.num_experts_per_tok if self.is_expert_layer else 1 - - self.block_sparse_moe = Jurassic3MoE( - num_experts=actual_num_experts, - top_k=actual_num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - if self.is_attn_layer: - hidden_states = 
self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - else: - cache = None - if not input_metadata.is_prompt: - for mamba_metadata in input_metadata.mamba_metadata: - # check if batch size of cache fits "n" - if mamba_metadata["cache"][self.layer_idx][0].shape[0] < mamba_metadata["n"]: - k_cache = mamba_metadata["cache"][self.layer_idx][0].repeat_interleave(mamba_metadata["n"],dim=0) - v_cache = mamba_metadata["cache"][self.layer_idx][1].repeat_interleave(mamba_metadata["n"],dim=0) - mamba_metadata["cache"][self.layer_idx] = (k_cache,v_cache) - - # mamba requires concatenated cache - if len(input_metadata.mamba_metadata) > 1: - k_cache = torch.concat([req["cache"][self.layer_idx][0] for req in input_metadata.mamba_metadata],dim=0) - v_cache = torch.concat([req["cache"][self.layer_idx][1] for req in input_metadata.mamba_metadata],dim=0) - cache = (k_cache,v_cache) - - hidden_states ,cache = self.mamba(hidden_states, cache=cache) - - sample_id = 0 - # split cache back to individual requests - for req_mamba_metadata in input_metadata.mamba_metadata: - n = req_mamba_metadata["n"] if not input_metadata.is_prompt else 1 - req_mamba_metadata["cache"][self.layer_idx] = (cache[self.layer_idx][0][sample_id:sample_id+n] - ,cache[self.layer_idx][1][sample_id:sample_id+n]) - sample_id += n - - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class Jurassic3Model(nn.Module): - def __init__( - self, - config: Jurassic3Config, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - # init each model layer, decide if it's mamba/attention and has experts and pass it down - - module_list = [] - for i in range(config.num_hidden_layers): - is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False - is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False - - module_list.append( - Jurassic3DecoderLayer( - config, - is_attn_layer=is_attn, - is_expert_layer=is_expert, - layer_idx=i, - linear_method=linear_method, - ) - ) - - self.layers = nn.ModuleList(module_list) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], input_metadata, - residual) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class Jurassic3ForCausalLM(nn.Module): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - 
"lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: Jurassic3Config, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = Jurassic3Model(config, - linear_method, - lora_config=lora_config) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.config.num_experts) - for weight_name in ["w1", "w2", "w3"] - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=True): # erez - might need to change later to False - if "rotary_emb.inv_freq" in name: - continue - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index bd40ec9b906..36edbbfce30 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -12,7 +12,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, - "jurassic3": Jurassic3Config + "jamba": JambaConfig } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 045c24cbd01..42084ac0067 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,7 +7,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.mpt import MPTConfig -from vllm.transformers_utils.configs.jurassic3 import Jurassic3Config +from vllm.transformers_utils.configs.jamba import JambaConfig __all__ = [ "ChatGLMConfig", @@ -15,5 +15,5 @@ "MPTConfig", "RWConfig", "JAISConfig", - "Jurassic3Config", + "JambaConfig" ] diff --git a/vllm/transformers_utils/configs/jurassic3.py b/vllm/transformers_utils/configs/jamba.py similarity index 81% rename from vllm/transformers_utils/configs/jurassic3.py rename to vllm/transformers_utils/configs/jamba.py index 87140e7d4ee..7c58fe35a87 100644 --- a/vllm/transformers_utils/configs/jurassic3.py +++ b/vllm/transformers_utils/configs/jamba.py @@ -1,9 +1,10 @@ -""" Jurassic3 model configuration""" +""" Jamba model configuration""" +import math from transformers.configuration_utils import PretrainedConfig from transformers import AutoConfig -class Jurassic3Config(PretrainedConfig): +class JambaConfig(PretrainedConfig): r""" Args: vocab_size (`int`, *optional*, defaults to 65536): @@ -63,43 +64,53 @@ class Jurassic3Config(PretrainedConfig): The first layer index that contains a vanilla attention mlp layer """ - model_type = "jurassic3" + model_type = "jamba" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, - vocab_size=32000, + vocab_size=65536, + tie_word_embeddings=False, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act="silu", - max_position_embeddings=4096 * 32, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, - pad_token_id=None, + output_router_logits=False, + router_aux_loss_coef=0.001, + pad_token_id=0, bos_token_id=1, eos_token_id=2, - use_positional_embeddings=False, - rope_theta=1e6, sliding_window=None, + attention_dropout=0.0, num_experts_per_tok=2, num_experts=16, expert_layer_offset=1, expert_layer_period=2, attn_layer_period=8, attn_layer_offset=4, + use_mamba_kernels=True, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_dt_rank="auto", + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_inner_layernorms=True, **kwargs, ): self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings + self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.sliding_window = sliding_window + self.attention_dropout = attention_dropout # for backward compatibility if num_key_value_heads is None: @@ -109,9 +120,10 @@ def __init__( 
self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache - self.use_positional_embeddings = use_positional_embeddings - self.rope_theta = rope_theta + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts @@ -120,12 +132,22 @@ def __init__( self.attn_layer_period = attn_layer_period self.attn_layer_offset = attn_layer_offset + self.use_mamba_kernels = use_mamba_kernels + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.mamba_inner_layernorms = mamba_inner_layernorms + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, **kwargs, ) -AutoConfig.register('jurassic3', Jurassic3Config) +AutoConfig.register('jamba', JambaConfig) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 9d877a7c4fc..7fea1bdceb8 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -89,7 +89,8 @@ def get_cache_block_size( head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) - is_mamba = model_config.hf_config.model_type == "jurassic3" + is_mamba = model_config.hf_config.model_type == "jamba" + if is_mamba: attention_period = model_config.hf_config.attn_layer_period num_layers = num_layers // attention_period diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2e5413a9d64..badca69520f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,6 +21,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata +from vllm.model_executor.mamba_metadata import RequestInfo from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, @@ -151,6 +152,7 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config self.mamba_cache = defaultdict(lambda: {}) + self.request_id2mamba_cache: Dict[str, MambaCache] = {} self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) @@ -403,6 +405,14 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, + requests_info=[ + RequestInfo( + request_id=req.request_id, + n=req.sampling_params.n + ) + for req in seq_group_metadata_list + ] ) return PreparePromptMetadata( @@ -533,6 +543,14 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, + requests_info=[ + RequestInfo( + request_id=req.request_id, + n=req.sampling_params.n + ) + for req in seq_group_metadata_list] + ) return PrepareDecodeMetadata( input_tokens=input_tokens, @@ -740,6 +758,13 @@ def prepare_input_tensors( "slot_mapping": slot_mapping, "num_prefills": num_prefills, "batch_type": batch_type, + "requests_info": [ 
+ RequestInfo( + request_id=req.request_id, + n=req.sampling_params.n + ) + for req in seq_group_metadata_list + ] } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -778,6 +803,24 @@ def prepare_input_tensors( else: decode_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + input_tokens = metadata_dict["input_tokens"] + input_positions = metadata_dict["input_positions"] + lora_mapping = metadata_dict["lora_mapping"] + lora_requests = metadata_dict["lora_requests"] + input_metadata = InputMetadata( + is_prompt=metadata_dict["is_prompt"], + slot_mapping=metadata_dict["slot_mapping"], + prompt_lens=metadata_dict["prompt_lens"], + max_seq_len=metadata_dict["max_seq_len"], + start_loc=metadata_dict["start_loc"], + max_context_len=metadata_dict["max_context_len"], + context_lens=metadata_dict["context_lens"], + block_tables=metadata_dict["block_tables"], + use_cuda_graph=metadata_dict["use_cuda_graph"], + kv_cache_dtype=metadata_dict["kv_cache_dtype"], + requests_info=metadata_dict["requests_info"] + ) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -847,8 +890,8 @@ def execute_model( if not sampling_metadata.perform_sampling: return None - mamba_metadata = self._get_mamba_caches_by_seq_group(seq_group_metadata_list) - input_metadata.mamba_metadata = mamba_metadata # list of caches + batch_mamba_cache = self._prepare_mamba_cache(input_metadata) + input_metadata.mamba_cache_batch = batch_mamba_cache # list of caches hidden_states = model_executable( input_ids=input_tokens, @@ -857,10 +900,8 @@ def execute_model( input_metadata=input_metadata ) - if self.is_driver_worker: - for idx, seq_group_metadata in enumerate(seq_group_metadata_list): - request_id = seq_group_metadata.request_id - self.mamba_cache[request_id] = input_metadata.mamba_metadata[idx]["cache"] + for request_mamba_cache in input_metadata.mamba_cache_batch: + self.request_id2mamba_cache[request_mamba_cache.request_info.request_id] = request_mamba_cache # Sample the next token. 
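The cache writes just above, together with release_mamba_cache further down in worker.py, form a simple request-scoped state store: state is created on the prompt step, written back after every decode step, and dropped when the request finishes. A minimal sketch of that lifecycle, independent of vLLM's actual classes:

from typing import Any, Callable, Dict

class RequestStateStore:
    # Illustrative only: recurrent state keyed by request id.
    def __init__(self) -> None:
        self._states: Dict[str, Any] = {}

    def get_or_create(self, request_id: str, factory: Callable[[], Any]) -> Any:
        if request_id not in self._states:
            self._states[request_id] = factory()   # fresh state on the prompt step
        return self._states[request_id]

    def put(self, request_id: str, state: Any) -> None:
        self._states[request_id] = state           # written back after each step

    def release(self, request_id: str) -> None:
        self._states.pop(request_id, None)         # dropped when the request finishes

store = RequestStateStore()
state = store.get_or_create("req-0", factory=dict)
store.put("req-0", state)
store.release("req-0")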
output = self.model.sample( @@ -869,16 +910,14 @@ def execute_model( ) return output - def _get_mamba_caches_by_seq_group( + def _prepare_mamba_cache( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ): - if seq_group_metadata_list is None: - return [] - return [{ - "cache":self.mamba_cache[seq.request_id], - "n":seq.sampling_params.n, - } for seq in seq_group_metadata_list] + input_metadata: InputMetadata + ) -> List[MambaCache]: + return [self.request_id2mamba_cache.get( + request_info.request_id, + MambaCache(request_info) + ) for request_info in input_metadata.requests_info] @torch.inference_mode() def profile_run(self) -> None: @@ -945,7 +984,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() - self.mamba_cache = defaultdict(lambda: {}) + self.request_id2mamba_cache = {} return def remove_all_loras(self) -> bool: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index afa57b244a6..140c00ea130 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,7 +21,11 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase +from vllm.lora.request import LoRARequest +from vllm.utils import is_hip +from vllm.logger import init_logger +logger = init_logger(__name__) class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -206,8 +210,9 @@ def cache_swap( def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: - if req_id in self.model_runner.mamba_cache: - del self.model_runner.mamba_cache[req_id] + if req_id in self.model_runner.request_id2mamba_cache: + del self.model_runner.request_id2mamba_cache[req_id] + logger.info(f"deleted { req_id } from mamba_cache") @torch.inference_mode() From 6d336f6eb0eddadf4b33a9c63632eec90d69aa18 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 8 Apr 2024 13:16:23 +0300 Subject: [PATCH 004/110] Cuda graph (#5) * Drop indecies when finish * min 1 attention layer * CG is working on forward pass passing * Remove comments * cosmetics - rename indecies -> indices, organize some whitespaces * Add some TODOs * Adding mamba cache for cg * Remove useless vars from input_metadata * Remove unused import * Set the seqlen offset to boolean * Return only hidden state * Return only hidden states * Add padding to match forward pass bs * Is prompt instead of seqlen offset * Remove mamba cache class (not used) * Another remove * Remove * Use mamba4gc * Fix mamba forward, run update only on non prompt * Use 1 index after the maximal index * Remove import * Remove import * typo * typo * place holder * Padding and empty token takes it from the first empty place * reformat * Apply suggestions from code review Whitespaces --------- Co-authored-by: Mor Zusman Co-authored-by: Tomer Asida Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/__init__.py | 2 + vllm/model_executor/input_metadata.py | 3 +- vllm/model_executor/mamba_metadata.py | 14 +--- vllm/model_executor/models/jamba.py | 90 ++++++++++------------- vllm/worker/cache_engine.py | 2 +- vllm/worker/model_runner.py | 101 ++++++++++++++++++++++---- vllm/worker/worker.py | 7 +- 7 files changed, 132 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 1eedd593570..8fbbdf06526 100644 --- 
a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,6 +1,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo, MambaCache +from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", @@ -8,4 +9,5 @@ "MambaCacheParams", "RequestInfo", "MambaCache", + "RequestInfo" ] diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index c6621864e8a..a63fa2ba212 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -2,7 +2,7 @@ import torch -from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo +from vllm.model_executor.mamba_metadata import RequestInfo class InputMetadata: @@ -45,7 +45,6 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None - self.mamba_cache_batch: List[MambaCache] = [] self.requests_info = requests_info def __repr__(self) -> str: diff --git a/vllm/model_executor/mamba_metadata.py b/vllm/model_executor/mamba_metadata.py index 225f6016ece..aa7346f33fb 100644 --- a/vllm/model_executor/mamba_metadata.py +++ b/vllm/model_executor/mamba_metadata.py @@ -5,7 +5,7 @@ @dataclass class MambaCacheParams: - seqlen_offset: int = 0 + is_prompt: bool = False conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() @@ -16,15 +16,3 @@ class RequestInfo: n: int = 1 -class MambaCache: - def __init__( - self, - request_info: RequestInfo, - layer_idx2mamba_cache: Optional[Dict[int, MambaCacheParams]] = None - ) -> None: - self.request_info = request_info - if layer_idx2mamba_cache is None: - self.layer_idx2mamba_cache = defaultdict(MambaCacheParams) - else: - self.layer_idx2mamba_cache = layer_idx2mamba_cache - diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 17a09f18dd3..6dbd515458c 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -33,7 +33,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput -from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -114,7 +114,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar # 2. 
Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and not cache_params.is_prompt: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), cache_params.conv_state, @@ -154,7 +154,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar A = -torch.exp(self.A_log.float()) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and not cache_params.is_prompt: scan_outputs = selective_state_update( cache_params.ssm_state, hidden_states[..., 0], @@ -187,50 +187,14 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata): - if input_metadata.is_prompt: - batch_size = hidden_states.shape[0] - conv_cache = torch.zeros( - batch_size, - self.config.mamba_expand * self.config.hidden_size, - self.config.mamba_d_conv, - device=hidden_states.device, - dtype=hidden_states.dtype - ) - ssm_cache = torch.zeros( - batch_size, - self.config.mamba_expand * self.config.hidden_size, - self.config.mamba_d_state, - device=hidden_states.device, - dtype=hidden_states.dtype - ) - cache = MambaCacheParams(0, conv_cache, ssm_cache) - else: - for mamba_cache_request in input_metadata.mamba_cache_batch: - # check if batch size of cache fits "n" - n = mamba_cache_request.request_info.n - if mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[0] < n: - expanded_dims_conv = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.shape[1:]) - conv_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state.expand(*expanded_dims_conv) - expanded_dims_ssm = (n, *mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.shape[1:]) - ssm_state = mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state.expand(*expanded_dims_ssm) - mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].conv_state = conv_state - mamba_cache_request.layer_idx2mamba_cache[self.layer_idx].ssm_state = ssm_state - - # mamba requires concatenated cache - conv_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].conv_state for req in input_metadata.mamba_cache_batch], dim=0) - ssm_state = torch.concat([req.layer_idx2mamba_cache[self.layer_idx].ssm_state for req in input_metadata.mamba_cache_batch], dim=0) - cache = MambaCacheParams(1, conv_state, ssm_state) + def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): + cache = MambaCacheParams( + input_metadata.is_prompt, + conv_state=conv_state[self.layer_idx], + ssm_state=ssm_state[self.layer_idx] + ) hidden_states = self.mamba_forward(hidden_states, cache_params=cache) - # split cache back to individual requests - sample_id = 0 - for req_mamba_metadata in input_metadata.mamba_cache_batch: - n = 1 if input_metadata.is_prompt else req_mamba_metadata.request_info.n - req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].conv_state=cache.conv_state[sample_id:sample_id + n] - 
req_mamba_metadata.layer_idx2mamba_cache[self.layer_idx].ssm_state=cache.ssm_state[sample_id:sample_id + n] - sample_id += n - return hidden_states @@ -352,6 +316,8 @@ def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, **kwargs): if residual is None: @@ -360,7 +326,12 @@ def forward(self, else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.mamba(hidden_states, input_metadata) + hidden_states = self.mamba( + hidden_states, + input_metadata, + conv_state, + ssm_state + ) # Fully Connected hidden_states, residual = self.pre_moe_layernorm( hidden_states, residual) @@ -433,7 +404,8 @@ def self_attention(self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: KVCache, - input_metadata: InputMetadata) -> torch.Tensor: + input_metadata: InputMetadata, + **kwargs) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # TODO - add embedding flag @@ -450,7 +422,8 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - residual: Optional[torch.Tensor]): + residual: Optional[torch.Tensor], + **kwargs): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -524,6 +497,8 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -534,7 +509,10 @@ def forward( hidden_states=hidden_states, kv_cache=kv_caches[i], input_metadata=input_metadata, - residual=residual) + residual=residual, + conv_state=conv_state, + ssm_state=ssm_state + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -593,9 +571,17 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + conv_state: torch.Tensor, + ssm_state: torch.Tensor + ): + hidden_states = self.model( + input_ids, + positions, + kv_caches, + input_metadata, + conv_state, + ssm_state + ) return hidden_states def sample( diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 7fea1bdceb8..66c76757b89 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -93,7 +93,7 @@ def get_cache_block_size( if is_mamba: attention_period = model_config.hf_config.attn_layer_period - num_layers = num_layers // attention_period + num_layers = max(num_layers // attention_period, 1) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index badca69520f..9cd8411773d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -151,12 +151,44 @@ def __init__( self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config - self.mamba_cache = defaultdict(lambda: {}) - self.request_id2mamba_cache: Dict[str, MambaCache] = {} + # cache in_wsl result + self.mamba_cache = None + self.mamba_cache4gc = None + self.request_id2index = {} + self.in_wsl = in_wsl() + self.kv_cache_dtype = kv_cache_dtype self.attn_backend = get_attn_backend( self.model_config.dtype if 
model_config is not None else None) + @torch.inference_mode() + def prepare_contiguous_mamba_cache(self): + is_mamba = self.model_config.hf_config.model_type == "jamba" + if not is_mamba: + return + hf_config = self.model_config.hf_config + num_layers = hf_config.num_hidden_layers + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + conv_state_shape = ( + num_layers, + max_batch_size, + hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_d_conv, + ) + ssm_state_shape = ( + num_layers, + max_batch_size, + hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_d_state, + ) + if self.mamba_cache is None: + self.mamba_cache = {} + self.mamba_cache = (torch.empty(size=conv_state_shape, dtype=torch.float16, device="cuda"), + torch.empty(size=ssm_state_shape, dtype=torch.float16, device="cuda")) + self.mamba_cache4gc = (torch.empty(size=conv_state_shape, dtype=torch.float16, device="cuda"), + torch.empty(size=ssm_state_shape, dtype=torch.float16, device="cuda")) + + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -890,18 +922,22 @@ def execute_model( if not sampling_metadata.perform_sampling: return None - batch_mamba_cache = self._prepare_mamba_cache(input_metadata) - input_metadata.mamba_cache_batch = batch_mamba_cache # list of caches + if self.mamba_cache is None: + self.prepare_contiguous_mamba_cache() + + conv_state, ssm_state, indecies = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0]) hidden_states = model_executable( input_ids=input_tokens, positions=input_positions, kv_caches=kv_caches, - input_metadata=input_metadata + input_metadata=input_metadata, + conv_state=conv_state, + ssm_state=ssm_state ) - - for request_mamba_cache in input_metadata.mamba_cache_batch: - self.request_id2mamba_cache[request_mamba_cache.request_info.request_id] = request_mamba_cache + for i,offset in enumerate(indecies): + self.mamba_cache[0][:,offset] = conv_state[:,i] + self.mamba_cache[1][:,offset] = ssm_state[:,i] # Sample the next token. 
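To make the buffer sizes in prepare_contiguous_mamba_cache above concrete, here is a back-of-the-envelope sketch using the JambaConfig defaults from this series (32 layers, hidden_size=4096, mamba_expand=2, mamba_d_conv=4, mamba_d_state=16) and an assumed maximum captured batch size of 256; the real bound comes from vLLM's _BATCH_SIZES_TO_CAPTURE:

import torch

# Assumed values for illustration; the patch reads them from JambaConfig and
# from _BATCH_SIZES_TO_CAPTURE rather than hard-coding them.
num_layers, max_batch_size = 32, 256
hidden_size, expand, d_conv, d_state = 4096, 2, 4, 16

conv_state_shape = (num_layers, max_batch_size, expand * hidden_size, d_conv)
ssm_state_shape = (num_layers, max_batch_size, expand * hidden_size, d_state)

def gib(shape, dtype=torch.float16) -> float:
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * torch.finfo(dtype).bits / 8 / 2**30

print(conv_state_shape, f"{gib(conv_state_shape):.2f} GiB")  # ~0.5 GiB in fp16
print(ssm_state_shape, f"{gib(ssm_state_shape):.2f} GiB")    # ~2.0 GiB in fp16

Note that the code above allocates two such pairs (mamba_cache for normal execution and mamba_cache4gc for graph capture), so the resident footprint is roughly double these figures.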
output = self.model.sample( @@ -910,14 +946,26 @@ def execute_model( ) return output - def _prepare_mamba_cache( + def _prepare_request_mamba_cache( self, - input_metadata: InputMetadata - ) -> List[MambaCache]: - return [self.request_id2mamba_cache.get( - request_info.request_id, - MambaCache(request_info) - ) for request_info in input_metadata.requests_info] + input_metadata: InputMetadata, + batch_size: int + ): + indices = [] + max_possible_bs = self.mamba_cache[0].shape[1] + for request_info in input_metadata.requests_info: + if request_info.request_id not in self.request_id2index: + first_free_index = [i not in self.request_id2index.values() for i in range(max_possible_bs)].index(True) + self.request_id2index[request_info.request_id] = first_free_index + indices.append(self.request_id2index[request_info.request_id]) + ## Pad the batch incase of running batch that was not captured via CG + padded_indices = indices + for _ in range(batch_size - len(indices)): + padded_indices += [[i not in set(self.request_id2index.values()).union(padded_indices) for i in range(max_possible_bs)].index(True)] + + conv_state = self.mamba_cache[0][:,padded_indices] + ssm_state = self.mamba_cache[1][:,padded_indices] + return conv_state,ssm_state,indices @torch.inference_mode() def profile_run(self) -> None: @@ -984,7 +1032,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() - self.request_id2mamba_cache = {} + self.request_id2index = {} return def remove_all_loras(self) -> bool: @@ -1106,6 +1154,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, + conv_state=self.mamba_cache4gc[0][:, :batch_size], + ssm_state=self.mamba_cache4gc[1][:, :batch_size] ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1144,6 +1194,8 @@ def capture( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, memory_pool, **kwargs, ) -> None: @@ -1157,6 +1209,8 @@ def capture( positions, kv_caches, attn_metadata, + conv_state, + ssm_state **kwargs, ) torch.cuda.synchronize() @@ -1172,6 +1226,9 @@ def capture( positions, kv_caches, attn_metadata, + input_metadata, + conv_state, + ssm_state **kwargs, ) torch.cuda.synchronize() @@ -1184,6 +1241,8 @@ def capture( "slot_mapping": attn_metadata.slot_mapping, "context_lens": attn_metadata.decode_metadata.context_lens, "block_tables": attn_metadata.decode_metadata.block_tables, + "conv_state": conv_state, + "ssm_state": ssm_state } self.output_buffers = {"hidden_states": hidden_states} return @@ -1194,6 +1253,8 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + conv_state:torch.Tensor, + ssm_state:torch.Tensor **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1208,9 +1269,17 @@ def forward( attn_metadata.decode_metadata.context_lens, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + self.input_buffers["conv_state"].copy_(conv_state, + non_blocking=True) + self.input_buffers["ssm_state"].copy_(ssm_state, + non_blocking=True) # Run the graph. 
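The staging copies above, the replay() that follows, and the copy-back of ssm_state/conv_state afterwards are the usual CUDA-graph contract: a captured graph always reads from and writes to the same storage, so per-step inputs must be copied into those fixed buffers before replay, and any state the graph mutates must be read back out of them. A generic sketch of the pattern (illustrative, not vLLM's API):

import torch

def capture(step, static_in: torch.Tensor, static_state: torch.Tensor):
    # Illustrative only: capture `step` once against fixed buffers.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        step(static_in, static_state)              # warm-up before capture
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = step(static_in, static_state)
    return graph, static_out

def replay(graph, static_in, static_state, static_out, new_in, state):
    static_in.copy_(new_in)        # stage the dynamic input into the captured buffer
    static_state.copy_(state)      # stage the recurrent state the same way
    graph.replay()                 # kernels run against the fixed buffers only
    state.copy_(static_state)      # read the mutated state back out, as done above
    return static_out              # outputs also live in fixed storage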
self.graph.replay() + # in-place edit of the mamba cache states as in the KV cache + ssm_state.copy_(self.input_buffers["ssm_state"]) + conv_state.copy_(self.input_buffers["conv_state"]) + # Return the output tensor. return self.output_buffers["hidden_states"] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 140c00ea130..47e052773a2 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,6 +184,7 @@ def _init_cache_engine(self): self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) + self.model_runner.prepare_contiguous_mamba_cache() def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -210,9 +211,9 @@ def cache_swap( def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: - if req_id in self.model_runner.request_id2mamba_cache: - del self.model_runner.request_id2mamba_cache[req_id] - logger.info(f"deleted { req_id } from mamba_cache") + if req_id in self.model_runner.request_id2index: + index = self.model_runner.request_id2index.pop(req_id) + logger.info(f"deleted { req_id } from mamba_cache with index = {index}") @torch.inference_mode() From 00bce1f33358e8cba924acb46c484334c5cc1c42 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 8 Apr 2024 16:37:21 +0300 Subject: [PATCH 005/110] dtype (#6) Co-authored-by: Mor Zusman --- vllm/worker/model_runner.py | 14 +++++++------- vllm/worker/worker.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9cd8411773d..258a285e7ec 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -162,9 +162,9 @@ def __init__( self.model_config.dtype if model_config is not None else None) @torch.inference_mode() - def prepare_contiguous_mamba_cache(self): + def prepare_contiguous_mamba_cache(self, dtype): is_mamba = self.model_config.hf_config.model_type == "jamba" - if not is_mamba: + if not is_mamba or self.mamba_cache is not None: return hf_config = self.model_config.hf_config num_layers = hf_config.num_hidden_layers @@ -183,10 +183,10 @@ def prepare_contiguous_mamba_cache(self): ) if self.mamba_cache is None: self.mamba_cache = {} - self.mamba_cache = (torch.empty(size=conv_state_shape, dtype=torch.float16, device="cuda"), - torch.empty(size=ssm_state_shape, dtype=torch.float16, device="cuda")) - self.mamba_cache4gc = (torch.empty(size=conv_state_shape, dtype=torch.float16, device="cuda"), - torch.empty(size=ssm_state_shape, dtype=torch.float16, device="cuda")) + self.mamba_cache = (torch.empty(size=conv_state_shape, dtype=dtype, device="cuda"), + torch.empty(size=ssm_state_shape, dtype=dtype, device="cuda")) + self.mamba_cache4gc = (torch.empty(size=conv_state_shape, dtype=dtype, device="cuda"), + torch.empty(size=ssm_state_shape, dtype=dtype, device="cuda")) def load_model(self) -> None: @@ -923,7 +923,7 @@ def execute_model( return None if self.mamba_cache is None: - self.prepare_contiguous_mamba_cache() + self.prepare_contiguous_mamba_cache(self.model_config.dtype) conv_state, ssm_state, indecies = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0]) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 47e052773a2..099846b823a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,7 +184,7 @@ def _init_cache_engine(self): self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache 
self.model_runner.set_block_size(self.cache_engine.block_size) - self.model_runner.prepare_contiguous_mamba_cache() + self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 39c27b7ab060d542d0c750920d8fe6fa52e8e613 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Apr 2024 13:33:28 +0300 Subject: [PATCH 006/110] N support (#8) * Return support for other models apart from jamba * Support n>1 * A little cleanup * Rename * Apply whitespace suggestions from code review * Add max batch size to the main func * Fixed attention kv cache bug * log where requests id are deleted from the dict to debug mode * Fix typo * Align with v0.3.3 vllm code * Remove comments * Take out model config from CUDAGraph object * Fix * Fix typo * Make the kv cache selection cleaner * Another typo * Took the num layers calc outside * Remove the -1 * Set as num layer / period --------- Co-authored-by: Mor Zusman Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/mamba_metadata.py | 7 +- vllm/model_executor/models/jamba.py | 6 +- vllm/worker/cache_engine.py | 22 ++-- vllm/worker/model_runner.py | 174 +++++++++++++++----------- vllm/worker/worker.py | 8 +- 5 files changed, 130 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/mamba_metadata.py b/vllm/model_executor/mamba_metadata.py index aa7346f33fb..7e349b5e49b 100644 --- a/vllm/model_executor/mamba_metadata.py +++ b/vllm/model_executor/mamba_metadata.py @@ -1,8 +1,9 @@ -from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, Optional, Tuple +from typing import List + import torch + @dataclass class MambaCacheParams: is_prompt: bool = False @@ -13,6 +14,6 @@ class MambaCacheParams: @dataclass class RequestInfo: request_id: str = '' - n: int = 1 + seqs_id: List[int] = field(default_factory=list) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6dbd515458c..73902d3e257 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -504,10 +504,12 @@ def forward( residual = None for i in range(len(self.layers)): layer = self.layers[i] - + kv_cache = None + if isinstance(layer, JambaAttentionDecoderLayer): + kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] hidden_states, residual = layer(positions=positions, hidden_states=hidden_states, - kv_cache=kv_caches[i], + kv_cache=kv_cache, input_metadata=input_metadata, residual=residual, conv_state=conv_state, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 66c76757b89..5bbcaa140b8 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,7 +30,7 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_layers(parallel_config) + self.num_layers = CacheEngine.get_num_attention_layers(model_config, parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size @@ -80,6 +80,18 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) + @staticmethod + def get_num_attention_layers( + model_config:ModelConfig, + parallel_config:ParallelConfig + ): + num_layers = model_config.get_num_layers(parallel_config) 
+ is_mamba = model_config.hf_config.model_type == "jamba" + if is_mamba: + attention_period = model_config.hf_config.attn_layer_period + num_layers = num_layers // attention_period + return num_layers + @staticmethod def get_cache_block_size( cache_config: CacheConfig, @@ -88,13 +100,7 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_layers(parallel_config) - is_mamba = model_config.hf_config.model_type == "jamba" - - if is_mamba: - attention_period = model_config.hf_config.attn_layer_period - num_layers = max(num_layers // attention_period, 1) - + num_layers = CacheEngine.get_num_attention_layers(model_config,parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 258a285e7ec..6c4cdf9a2e0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -154,7 +154,7 @@ def __init__( # cache in_wsl result self.mamba_cache = None self.mamba_cache4gc = None - self.request_id2index = {} + self.request_id2index: Dict[str, Dict[int, int]] = {} self.in_wsl = in_wsl() self.kv_cache_dtype = kv_cache_dtype @@ -441,7 +441,7 @@ def _prepare_prompt( requests_info=[ RequestInfo( request_id=req.request_id, - n=req.sampling_params.n + seqs_id=list(req.seq_data.keys()) ) for req in seq_group_metadata_list ] @@ -579,10 +579,9 @@ def _prepare_decode( requests_info=[ RequestInfo( request_id=req.request_id, - n=req.sampling_params.n + seqs_id=list(req.seq_data.keys()) ) for req in seq_group_metadata_list] - ) return PrepareDecodeMetadata( input_tokens=input_tokens, @@ -790,13 +789,7 @@ def prepare_input_tensors( "slot_mapping": slot_mapping, "num_prefills": num_prefills, "batch_type": batch_type, - "requests_info": [ - RequestInfo( - request_id=req.request_id, - n=req.sampling_params.n - ) - for req in seq_group_metadata_list - ] + "requests_info": input_metadata.requests_info } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -922,22 +915,29 @@ def execute_model( if not sampling_metadata.perform_sampling: return None - if self.mamba_cache is None: - self.prepare_contiguous_mamba_cache(self.model_config.dtype) - - conv_state, ssm_state, indecies = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0]) - - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - input_metadata=input_metadata, - conv_state=conv_state, - ssm_state=ssm_state - ) - for i,offset in enumerate(indecies): - self.mamba_cache[0][:,offset] = conv_state[:,i] - self.mamba_cache[1][:,offset] = ssm_state[:,i] + is_mamba = self.model_config.hf_config.model_type == "jamba" + indices = [] + conv_state = None + model_inputs = { + "input_ids":input_tokens, + "positions":input_positions, + "kv_caches":kv_caches, + "input_metadata":input_metadata, + } + if is_mamba: + if self.mamba_cache is None: + self.prepare_contiguous_mamba_cache(self.model_config.dtype) + conv_state, ssm_state, indices = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0]) + model_inputs = { + **model_inputs, + "conv_state":conv_state, + "ssm_state":ssm_state, + } + hidden_states = model_executable(**model_inputs) + if is_mamba: + for i, offset in enumerate(indices): + self.mamba_cache[0][:, offset] 
= conv_state[:, i] + self.mamba_cache[1][:, offset] = ssm_state[:, i] # Sample the next token. output = self.model.sample( @@ -946,6 +946,13 @@ def execute_model( ) return output + def _get_first_free_mamba_cache_index(self): + max_possible_bs = self.mamba_cache[0].shape[1] + occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()] + first_free_index = [i not in occupied for i in range(max_possible_bs)].index(True) + return first_free_index + + def _prepare_request_mamba_cache( self, input_metadata: InputMetadata, @@ -955,13 +962,26 @@ def _prepare_request_mamba_cache( max_possible_bs = self.mamba_cache[0].shape[1] for request_info in input_metadata.requests_info: if request_info.request_id not in self.request_id2index: - first_free_index = [i not in self.request_id2index.values() for i in range(max_possible_bs)].index(True) - self.request_id2index[request_info.request_id] = first_free_index - indices.append(self.request_id2index[request_info.request_id]) + self.request_id2index[request_info.request_id] = {} + for seq_id in request_info.seqs_id: + first_free_index = self._get_first_free_mamba_cache_index() + self.request_id2index[request_info.request_id][seq_id] = first_free_index + indices.append(first_free_index) + else: + for seq_id in request_info.seqs_id: + if seq_id not in self.request_id2index[request_info.request_id]: + first_free_index = self._get_first_free_mamba_cache_index() + ## case of decoding n>1 + if len(self.request_id2index[request_info.request_id].keys()) > 0: + self.mamba_cache[0][:,first_free_index].copy_(self.mamba_cache[0][:,list(self.request_id2index[request_info.request_id].values())[0]]) + self.mamba_cache[1][:,first_free_index].copy_(self.mamba_cache[1][:,list(self.request_id2index[request_info.request_id].values())[0]]) + self.request_id2index[request_info.request_id][seq_id] = first_free_index + indices.append(self.request_id2index[request_info.request_id][seq_id]) ## Pad the batch incase of running batch that was not captured via CG padded_indices = indices for _ in range(batch_size - len(indices)): - padded_indices += [[i not in set(self.request_id2index.values()).union(padded_indices) for i in range(max_possible_bs)].index(True)] + occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()] + padded_indices += [[i not in set(occupied).union(padded_indices) for i in range(max_possible_bs)].index(True)] conv_state = self.mamba_cache[0][:,padded_indices] ssm_state = self.mamba_cache[1][:,padded_indices] @@ -1140,6 +1160,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_cache_dtype=self.kv_cache_dtype, ) + is_mamba = self.model_config.hf_config.model_type == "jamba" if self.lora_config: lora_mapping = LoRAMapping( [0] * batch_size, @@ -1147,16 +1168,18 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - kv_caches, - attn_metadata, - memory_pool=self.graph_memory_pool, - conv_state=self.mamba_cache4gc[0][:, :batch_size], - ssm_state=self.mamba_cache4gc[1][:, :batch_size] - ) + graph_runner = CUDAGraphRunner(self.model,is_mamba) + capture_inputs = { + "input_ids" : input_tokens[:batch_size], + "positions" :input_positions[:batch_size], + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + "memory_pool":self.graph_memory_pool, + } + if is_mamba: + 
capture_inputs["conv_state"]=self.mamba_cache4gc[0][:, :batch_size] + capture_inputs["ssm_state"]=self.mamba_cache4gc[1][:, :batch_size] + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1182,11 +1205,12 @@ def vocab_size(self) -> int: class CUDAGraphRunner: - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, is_mamba: bool): self.model = model self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} + self.is_mamba = is_mamba def capture( self, @@ -1197,22 +1221,29 @@ def capture( conv_state: torch.Tensor, ssm_state: torch.Tensor, memory_pool, + conv_state: Optional[torch.Tensor] = None, + ssm_state: Optional[torch.Tensor] = None, **kwargs, ) -> None: assert self.graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): - self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - conv_state, - ssm_state - **kwargs, - ) + model_inputs = { + "input_ids":input_ids, + "positions":positions, + "kv_caches":kv_caches, + "attn_metadata":attn_metadata, + } + if self.is_mamba: + model_inputs = { + **model_inputs, + "conv_state":conv_state, + "ssm_state":ssm_state, + } + + with _maybe_cupy_nccl(): + self.model(**model_inputs) torch.cuda.synchronize() # Capture the graph. @@ -1220,17 +1251,8 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - input_metadata, - conv_state, - ssm_state - **kwargs, - ) + with _maybe_cupy_nccl(): + hidden_states = self.model(**model_inputs) torch.cuda.synchronize() # Save the input and output buffers. @@ -1244,6 +1266,13 @@ def capture( "conv_state": conv_state, "ssm_state": ssm_state } + if self.is_mamba: + self.input_buffers = { + **self.input_buffers, + "conv_state": conv_state, + "ssm_state": ssm_state, + } + self.output_buffers = {"hidden_states": hidden_states} return @@ -1253,8 +1282,8 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state:torch.Tensor, - ssm_state:torch.Tensor + conv_state:Optional[torch.Tensor] = None, + ssm_state:Optional[torch.Tensor] = None **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1269,16 +1298,19 @@ def forward( attn_metadata.decode_metadata.context_lens, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - self.input_buffers["conv_state"].copy_(conv_state, - non_blocking=True) - self.input_buffers["ssm_state"].copy_(ssm_state, - non_blocking=True) + if self.is_mamba: + self.input_buffers["conv_state"].copy_(conv_state, + non_blocking=True) + self.input_buffers["ssm_state"].copy_(ssm_state, + non_blocking=True) + # Run the graph. 
self.graph.replay() # in-place edit of the mamba cache states as in the KV cache - ssm_state.copy_(self.input_buffers["ssm_state"]) - conv_state.copy_(self.input_buffers["conv_state"]) + if self.is_mamba: + ssm_state.copy_(self.input_buffers["ssm_state"]) + conv_state.copy_(self.input_buffers["conv_state"]) # Return the output tensor. return self.output_buffers["hidden_states"] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 099846b823a..1707edace56 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,7 +184,9 @@ def _init_cache_engine(self): self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype) + is_mamba = self.model_config.hf_config.model_type == "jamba" + if is_mamba: + self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -212,8 +214,8 @@ def cache_swap( def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.model_runner.request_id2index: - index = self.model_runner.request_id2index.pop(req_id) - logger.info(f"deleted { req_id } from mamba_cache with index = {index}") + indices = self.model_runner.request_id2index.pop(req_id) + logger.debug(f"Deleted { req_id } from mamba_cache with indices = {indices}") @torch.inference_mode() From 7c7586808e616dd4aa85a2e81cf7aa345af0d6e7 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Apr 2024 13:46:40 +0300 Subject: [PATCH 007/110] Tensor parallelism (#7) * Return support for other models apart from jamba * Support n>1 * Revert 2 commits d054737 'Support n>1' b5167cc 'Return support for other models apart from jamba' * TP on input and output * Basic TP impl , working, correctness not working * TP is working * Roll back the verification that everything in the weights fits into the model * Cleanup * Use world size func * clean up * Import * Apply whitespace suggestions from code review * Organize imports * Add comment on the unsqueeze in conv1d * Organize and remove redundant code in forward pass * Remove print * Add comments Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> * White spaces * Set as A * better comment --------- Co-authored-by: Mor Zusman Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/models/jamba.py | 109 +++++++++++++++++----------- vllm/worker/model_runner.py | 11 ++- 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 73902d3e257..b0de0a23636 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -5,22 +5,20 @@ from typing import Dict, List, Optional, Tuple import torch -from torch import conv_transpose3d, nn -import os +from torch import nn from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.transformers_utils.configs.jamba import JambaConfig -from transformers.activations import ACT2FN +from torch.nn.parameter import Parameter from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import 
(LinearMethodBase, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) @@ -59,35 +57,65 @@ def __init__(self, config: JambaConfig, layer_idx): self.time_step_rank = config.mamba_dt_rank self.use_conv_bias = config.mamba_conv_bias self.use_bias = config.mamba_proj_bias - self.conv1d = nn.Conv1d( - in_channels=self.intermediate_size, - out_channels=self.intermediate_size, + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, bias=self.use_conv_bias, - kernel_size=self.conv_kernel_size, - groups=self.intermediate_size, - padding=self.conv_kernel_size - 1, ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. Can't do this in `weight_loader` since it already exists in `ColumnParallelLinear` and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.apply_inner_layernorms = config.mamba_inner_layernorms - - # projection of the input hidden states - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias) + self.in_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias + ) # selective projection used to make dt, B and C input dependant - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) - # time step projection (discretization) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False + ) + # time step projection (discretization) - In the forward we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True + ) - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded - A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] - A = A.expand(self.intermediate_size, -1).contiguous() + def weight_loader(param:Parameter, loaded_weight:torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_(loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[tp_rank]) - self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(self.intermediate_size)) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + def A_weight_loader(param:Parameter, loaded_weight:torch.Tensor): + weight_loader(param,-torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter(torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32 + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, { + "weight_loader": weight_loader + }) + set_weight_attrs(self.A, { + "weight_loader": A_weight_loader + }) + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True + ) + self.activation = config.hidden_act + self.apply_inner_layernorms = config.mamba_inner_layernorms + if self.apply_inner_layernorms: self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) @@ -108,8 +136,7 @@ def _apply_layernorms(self, dt, B, C): def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1, 2) - + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -135,23 +162,14 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) time_step, B, C = self._apply_layernorms(time_step, B, C) - # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. - # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed - # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized - # linear layers, and requires to call the forward pass directly. 
- # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)``` - dt_proj_bias = self.dt_proj.bias - self.dt_proj.bias = None - discrete_time_step = self.dt_proj(time_step).transpose(1, 2) - self.dt_proj.bias = dt_proj_bias - - A = -torch.exp(self.A_log.float()) + discrete_time_step = self.dt_proj(time_step)[0].transpose(1,2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if cache_params is not None and not cache_params.is_prompt: @@ -159,7 +177,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar cache_params.ssm_state, hidden_states[..., 0], discrete_time_step[..., 0], - A, + self.A, B[:, 0], C[:, 0], self.D, @@ -171,7 +189,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar scan_outputs, ssm_state = selective_scan_fn( hidden_states, discrete_time_step, - A, + self.A, B.transpose(1, 2), C.transpose(1, 2), self.D.float(), @@ -184,7 +202,7 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar cache_params.ssm_state.copy_(ssm_state) # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] return contextualized_states def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): @@ -625,6 +643,9 @@ def load_weights(self, if "rotary_emb.inv_freq" in name: continue + if "A_log" in name: + name = name.replace("A_log","A") + if ".self_attn." in name: name = name.replace(".self_attn", "") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6c4cdf9a2e0..d4b6ba6c9d0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,12 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.mamba_metadata import RequestInfo from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + with_pynccl_for_all_reduce, + get_tensor_model_parallel_world_size) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) @@ -169,16 +175,17 @@ def prepare_contiguous_mamba_cache(self, dtype): hf_config = self.model_config.hf_config num_layers = hf_config.num_hidden_layers max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + world_size = get_tensor_model_parallel_world_size() conv_state_shape = ( num_layers, max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_expand * hf_config.hidden_size // world_size, hf_config.mamba_d_conv, ) ssm_state_shape = ( num_layers, max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size, + hf_config.mamba_expand * hf_config.hidden_size // world_size, hf_config.mamba_d_state, ) if self.mamba_cache is None: From 30e6dcdb546c50f4107e6d72f1b6f831d1e72b6f Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Apr 2024 15:17:58 +0300 Subject: [PATCH 008/110] After merge fixes --- vllm/model_executor/__init__.py | 4 +--- vllm/worker/model_runner.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/__init__.py 
b/vllm/model_executor/__init__.py index 8fbbdf06526..8be2a869f33 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,13 +1,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo, MambaCache +from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", "set_random_seed", "MambaCacheParams", - "RequestInfo", - "MambaCache", "RequestInfo" ] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d4b6ba6c9d0..b56edce04b4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1225,8 +1225,6 @@ def capture( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, memory_pool, conv_state: Optional[torch.Tensor] = None, ssm_state: Optional[torch.Tensor] = None, From 5c0efdc2adfe2f0939d984dee4c037151acc2445 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 14 Apr 2024 18:10:11 +0300 Subject: [PATCH 009/110] Clean up --- vllm/model_executor/models/jamba.py | 2 +- vllm/worker/model_runner.py | 151 +++++++++++++--------------- vllm/worker/worker.py | 8 -- 3 files changed, 73 insertions(+), 88 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b0de0a23636..bdb5da2d89b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -583,7 +583,7 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + self.sampler = Sampler() def forward( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b56edce04b4..02a84f741cd 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -158,11 +158,10 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config # cache in_wsl result + self.is_mamba = self.model_config.hf_config.model_type == "jamba" self.mamba_cache = None self.mamba_cache4gc = None - self.request_id2index: Dict[str, Dict[int, int]] = {} - self.in_wsl = in_wsl() - self.kv_cache_dtype = kv_cache_dtype + self.request2i: Dict[str, Dict[int, int]] = {} self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) @@ -445,13 +444,6 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - requests_info=[ - RequestInfo( - request_id=req.request_id, - seqs_id=list(req.seq_data.keys()) - ) - for req in seq_group_metadata_list - ] ) return PreparePromptMetadata( @@ -583,12 +575,6 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - requests_info=[ - RequestInfo( - request_id=req.request_id, - seqs_id=list(req.seq_data.keys()) - ) - for req in seq_group_metadata_list] ) return PrepareDecodeMetadata( input_tokens=input_tokens, @@ -782,6 +768,10 @@ def prepare_input_tensors( batch_type = BatchType.PREFILL else: batch_type = BatchType.DECODE + requests_info = [ RequestInfo( + request_id=req.request_id, + seqs_id=list(req.seq_data.keys()) + ) for req in seq_group_metadata_list] metadata_dict = { "input_tokens": input_tokens, @@ -796,7 +786,7 @@ def 
prepare_input_tensors( "slot_mapping": slot_mapping, "num_prefills": num_prefills, "batch_type": batch_type, - "requests_info": input_metadata.requests_info + "requests_info": requests_info } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) @@ -835,24 +825,8 @@ def prepare_input_tensors( else: decode_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) + requests_info = metadata_dict.pop("requests_info") attn_metadata = self.attn_backend.make_metadata(**metadata_dict) - input_tokens = metadata_dict["input_tokens"] - input_positions = metadata_dict["input_positions"] - lora_mapping = metadata_dict["lora_mapping"] - lora_requests = metadata_dict["lora_requests"] - input_metadata = InputMetadata( - is_prompt=metadata_dict["is_prompt"], - slot_mapping=metadata_dict["slot_mapping"], - prompt_lens=metadata_dict["prompt_lens"], - max_seq_len=metadata_dict["max_seq_len"], - start_loc=metadata_dict["start_loc"], - max_context_len=metadata_dict["max_context_len"], - context_lens=metadata_dict["context_lens"], - block_tables=metadata_dict["block_tables"], - use_cuda_graph=metadata_dict["use_cuda_graph"], - kv_cache_dtype=metadata_dict["kv_cache_dtype"], - requests_info=metadata_dict["requests_info"] - ) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -882,7 +856,14 @@ def prepare_input_tensors( return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) + multi_modal_input, requests_info) + + def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.request2i: + indices = self.request2i.pop(req_id) + logger.debug(f"Deleted { req_id } from mamba_cache with indices = {indices}") + @torch.inference_mode() def execute_model( @@ -891,7 +872,7 @@ def execute_model( kv_caches: List[torch.Tensor] ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_input + lora_requests, lora_mapping, multi_modal_input, requests_info ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: @@ -905,6 +886,8 @@ def execute_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model + + indices = [] execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, @@ -913,6 +896,20 @@ def execute_model( } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) + + if self.is_mamba: + if self.mamba_cache is None: + self.prepare_contiguous_mamba_cache(self.model_config.dtype) + conv_state, ssm_state, indices = self._prepare_request_mamba_cache( + requests_info, + input_tokens.shape[0] + ) + execute_model_kwargs = { + **execute_model_kwargs, + "conv_state":conv_state, + "ssm_state":ssm_state, + } + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. 
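The bookkeeping behind `request2i` above is essentially a free-list over a fixed number of cache slots: each (request_id, seq_id) pair is pinned to the first unoccupied column of the contiguous conv/ssm tensors, and the column is returned to the pool when the request finishes. Below is a minimal sketch of that idea under simplifying assumptions; the class and method names are invented here, plain integers stand in for the GPU cache columns, and the n>1 fork-copy and CUDA-graph padding handled by the real code are noted but not implemented.

from typing import Dict, List


class SlotAllocator:
    """Toy free-list over a fixed number of Mamba cache slots (illustrative only)."""

    def __init__(self, max_batch_size: int) -> None:
        self.max_batch_size = max_batch_size
        # request_id -> {seq_id -> column index into the contiguous cache}
        self.request2slot: Dict[str, Dict[int, int]] = {}

    def _first_free(self) -> int:
        occupied = {s for seqs in self.request2slot.values() for s in seqs.values()}
        for i in range(self.max_batch_size):
            if i not in occupied:
                return i
        raise RuntimeError("no free Mamba cache slot")

    def assign(self, request_id: str, seq_ids: List[int]) -> List[int]:
        # Each sequence of a request gets its own slot so n>1 sampling can fork.
        seqs = self.request2slot.setdefault(request_id, {})
        for seq_id in seq_ids:
            if seq_id not in seqs:
                seqs[seq_id] = self._first_free()
        return [seqs[seq_id] for seq_id in seq_ids]

    def release(self, request_id: str) -> None:
        self.request2slot.pop(request_id, None)

On top of this scheme, the real `_prepare_request_mamba_cache` also copies the parent sequence's conv/ssm columns into a newly assigned slot when a request forks, and pads the index list up to the captured CUDA-graph batch size.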
@@ -922,26 +919,7 @@ def execute_model( if not sampling_metadata.perform_sampling: return None - is_mamba = self.model_config.hf_config.model_type == "jamba" - indices = [] - conv_state = None - model_inputs = { - "input_ids":input_tokens, - "positions":input_positions, - "kv_caches":kv_caches, - "input_metadata":input_metadata, - } - if is_mamba: - if self.mamba_cache is None: - self.prepare_contiguous_mamba_cache(self.model_config.dtype) - conv_state, ssm_state, indices = self._prepare_request_mamba_cache(input_metadata, input_tokens.shape[0]) - model_inputs = { - **model_inputs, - "conv_state":conv_state, - "ssm_state":ssm_state, - } - hidden_states = model_executable(**model_inputs) - if is_mamba: + if self.is_mamba: for i, offset in enumerate(indices): self.mamba_cache[0][:, offset] = conv_state[:, i] self.mamba_cache[1][:, offset] = ssm_state[:, i] @@ -955,40 +933,55 @@ def execute_model( def _get_first_free_mamba_cache_index(self): max_possible_bs = self.mamba_cache[0].shape[1] - occupied = [id for seq_ids in self.request_id2index.values() for id in seq_ids.values()] - first_free_index = [i not in occupied for i in range(max_possible_bs)].index(True) + occupied = [ + id + for seq_ids in self.request2i.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied + for i in range(max_possible_bs) + ].index(True) return first_free_index def _prepare_request_mamba_cache( self, - input_metadata: InputMetadata, + requests_info: List[RequestInfo], batch_size: int ): indices = [] max_possible_bs = self.mamba_cache[0].shape[1] - for request_info in input_metadata.requests_info: - if request_info.request_id not in self.request_id2index: - self.request_id2index[request_info.request_id] = {} + for request_info in requests_info: + cur_rid = request_info.request_id + if cur_rid not in self.request2i: + self.request2i[cur_rid] = {} for seq_id in request_info.seqs_id: - first_free_index = self._get_first_free_mamba_cache_index() - self.request_id2index[request_info.request_id][seq_id] = first_free_index - indices.append(first_free_index) + f_free_index = self._get_first_free_mamba_cache_index() + self.request2i[cur_rid][seq_id] = f_free_index + indices.append(f_free_index) else: for seq_id in request_info.seqs_id: - if seq_id not in self.request_id2index[request_info.request_id]: - first_free_index = self._get_first_free_mamba_cache_index() + if seq_id not in self.request2i[cur_rid]: + f_free_index = self._get_first_free_mamba_cache_index() ## case of decoding n>1 - if len(self.request_id2index[request_info.request_id].keys()) > 0: - self.mamba_cache[0][:,first_free_index].copy_(self.mamba_cache[0][:,list(self.request_id2index[request_info.request_id].values())[0]]) - self.mamba_cache[1][:,first_free_index].copy_(self.mamba_cache[1][:,list(self.request_id2index[request_info.request_id].values())[0]]) - self.request_id2index[request_info.request_id][seq_id] = first_free_index - indices.append(self.request_id2index[request_info.request_id][seq_id]) + i_exist = list(self.request2i[cur_rid].values())[0] + self.mamba_cache[0][:,f_free_index].copy_( + self.mamba_cache[0][:,i_exist] + ) + self.mamba_cache[1][:,f_free_index].copy_( + self.mamba_cache[1][:,i_exist] + ) + self.request2i[cur_rid][seq_id] = f_free_index + indices.append(self.request2i[cur_rid][seq_id]) ## Pad the batch incase of running batch that was not captured via CG padded_indices = indices for _ in range(batch_size - len(indices)): - occupied = [id for seq_ids in self.request_id2index.values() for id in 
seq_ids.values()] - padded_indices += [[i not in set(occupied).union(padded_indices) for i in range(max_possible_bs)].index(True)] + occu = [i for s_ids in self.request2i.values() for i in s_ids.values()] + padded_indices += [[ + i not in set(occu).union(padded_indices) + for i in range(max_possible_bs) + ].index(True)] conv_state = self.mamba_cache[0][:,padded_indices] ssm_state = self.mamba_cache[1][:,padded_indices] @@ -1032,6 +1025,7 @@ def profile_run(self) -> None: # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. + breakpoint() if self.vision_language_config: max_num_seqs = min( max_num_seqs, @@ -1059,7 +1053,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() - self.request_id2index = {} + self.request2i = {} return def remove_all_loras(self) -> bool: @@ -1167,7 +1161,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_cache_dtype=self.kv_cache_dtype, ) - is_mamba = self.model_config.hf_config.model_type == "jamba" if self.lora_config: lora_mapping = LoRAMapping( [0] * batch_size, @@ -1175,7 +1168,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model,is_mamba) + graph_runner = CUDAGraphRunner(self.model,self.is_mamba) capture_inputs = { "input_ids" : input_tokens[:batch_size], "positions" :input_positions[:batch_size], @@ -1183,7 +1176,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: "attn_metadata": attn_metadata, "memory_pool":self.graph_memory_pool, } - if is_mamba: + if self.is_mamba: capture_inputs["conv_state"]=self.mamba_cache4gc[0][:, :batch_size] capture_inputs["ssm_state"]=self.mamba_cache4gc[1][:, :batch_size] graph_runner.capture(**capture_inputs) @@ -1288,7 +1281,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, conv_state:Optional[torch.Tensor] = None, - ssm_state:Optional[torch.Tensor] = None + ssm_state:Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1707edace56..088f9a03916 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -25,7 +25,6 @@ from vllm.utils import is_hip from vllm.logger import init_logger -logger = init_logger(__name__) class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. 
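The CUDAGraphRunner changes above (extra conv/ssm buffers at capture time and the optional state arguments in forward) follow the standard CUDA-graph discipline: a replayed graph always reads and writes the same memory, so per-step inputs are staged into the captured buffers, the graph is replayed, and the recurrent Mamba states are copied back out so the contiguous cache sees the update. A condensed sketch of that pattern; the function name and argument list are illustrative, not vLLM's API.

import torch


def replay_with_mamba_state(graph: torch.cuda.CUDAGraph,
                            input_buffers: dict,
                            output_buffers: dict,
                            conv_state: torch.Tensor,
                            ssm_state: torch.Tensor) -> torch.Tensor:
    # Stage the current states into the tensors that were captured with the graph.
    input_buffers["conv_state"].copy_(conv_state, non_blocking=True)
    input_buffers["ssm_state"].copy_(ssm_state, non_blocking=True)
    graph.replay()
    # The graph updated its captured buffers in place; copy the new states back,
    # mirroring the in-place KV-cache update.
    conv_state.copy_(input_buffers["conv_state"])
    ssm_state.copy_(input_buffers["ssm_state"])
    return output_buffers["hidden_states"]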
@@ -211,13 +210,6 @@ def cache_swap( self.cache_engine.copy(blocks_to_copy) - def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.model_runner.request_id2index: - indices = self.model_runner.request_id2index.pop(req_id) - logger.debug(f"Deleted { req_id } from mamba_cache with indices = {indices}") - - @torch.inference_mode() def execute_model( self, From 19f11f3433581b56a44dc0042e8d9708a5605260 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Apr 2024 13:27:28 +0300 Subject: [PATCH 010/110] Add release mamba cache to executor_base --- vllm/engine/llm_engine.py | 5 +---- vllm/executor/executor_base.py | 3 +++ vllm/executor/gpu_executor.py | 4 ++++ vllm/executor/ray_gpu_executor.py | 11 +++++++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 17421c85ed7..6555a089480 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -658,10 +658,7 @@ def _process_model_outputs( ] if len(finished_seq_groups_req_ids) > 0: - self._run_workers( - "release_mamba_cache", - finished_seq_groups_req_ids= finished_seq_groups_req_ids, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + self.model_executor.release_mamba_cache(finished_seq_groups_req_ids) self.scheduler.free_finished_seq_groups() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbb6ec80f7b..3121a96f8cc 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -94,6 +94,9 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + @abstractmethod + def release_mamba_cache(self,requests_id:List[str]) -> None: + raise NotImplementedError class ExecutorAsyncBase(ExecutorBase): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index bae509f4802..06d677b13c0 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -92,6 +92,10 @@ def check_health(self) -> None: # it's running. return + def release_mamba_cache(self, requests_id:List[str]) -> None: + self.driver_worker.release_mamba_cache(requests_id) + + class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 7aca5e36107..e8c15f1a531 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -364,6 +364,12 @@ def _check_if_any_actor_is_dead(self): raise RuntimeError("At least one Worker is dead. " f"Dead Workers: {dead_actors}. ") + def release_mamba_cache(self, requests_id:List[str]) -> None: + self._run_workers( + "release_mamba_cache", + requests_id= requests_id, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): @@ -413,3 +419,8 @@ async def execute_model_async( # Only the driver worker returns the sampling results. 
output = all_outputs[0] return output + + async def check_health_async(self) -> None: + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() + From 1fb817aae4b0b0ba46e0ec04f59187409e1f8a73 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Apr 2024 13:27:42 +0300 Subject: [PATCH 011/110] Add jamba modifications --- vllm/model_executor/models/jamba.py | 134 ++++++++++++++++++---------- 1 file changed, 86 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index bdb5da2d89b..d33beac59a0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -6,13 +6,14 @@ import torch from torch import nn +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.transformers_utils.configs.jamba import JambaConfig from torch.nn.parameter import Parameter from vllm.config import LoRAConfig -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear, @@ -205,13 +206,45 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] return contextualized_states - def forward(self, hidden_states: torch.Tensor, input_metadata: InputMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor): - cache = MambaCacheParams( - input_metadata.is_prompt, - conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx] - ) - hidden_states = self.mamba_forward(hidden_states, cache_params=cache) + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor + ): + if attn_metadata.is_prompt: + max_seq_len = max(attn_metadata.prompt_lens) + batch_size = len(attn_metadata.prompt_lens) + padded_hidden_states = torch.zeros(( + batch_size, + max_seq_len, + hidden_states.shape[-1], + ), dtype=hidden_states.dtype, device=hidden_states.device) + offset = 0 + for i,prompt_len in enumerate(attn_metadata.prompt_lens): + padded_hidden_states[i,-prompt_len:].copy_(hidden_states[offset:offset + prompt_len]) + offset += prompt_len + cache = MambaCacheParams( + True, + conv_state=conv_state[self.layer_idx], + ssm_state=ssm_state[self.layer_idx] + ) + padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) + offset = 0 + for i,prompt_len in enumerate(attn_metadata.prompt_lens): + hidden_states[offset:offset + prompt_len].copy_(padded_hidden_states[i,-prompt_len:]) + offset += prompt_len + else: + cache = MambaCacheParams( + False, + conv_state=conv_state[self.layer_idx], + ssm_state=ssm_state[self.layer_idx] + ) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1).contiguous(), cache_params=cache) + hidden_states = hidden_states.squeeze(1).contiguous() + + return hidden_states @@ -289,7 +322,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_data[expert_id, :, :] = loaded_weight[:, shard] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, 
sequence_length, hidden_size = hidden_states.shape + num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (batch * sequence_length, n_experts) if self.num_total_experts > 1: @@ -310,7 +343,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - return final_hidden_states.view(batch_size, sequence_length, + return final_hidden_states.view(num_tokens, hidden_size) @@ -332,7 +365,7 @@ def __init__(self, config: JambaConfig, actual_num_experts: int, actual_num_expe def forward(self, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], conv_state: torch.Tensor, ssm_state: torch.Tensor, @@ -346,7 +379,7 @@ def forward(self, hidden_states = self.mamba( hidden_states, - input_metadata, + attn_metadata, conv_state, ssm_state ) @@ -381,7 +414,6 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim ** -0.5 - self.use_positional_embeddings = False self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( @@ -399,7 +431,7 @@ def __init__( linear_method=linear_method, ) - self.attn = PagedAttention( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling, @@ -407,7 +439,6 @@ def __init__( sliding_window=self.sliding_window, ) - self.moe = JambaMoE( num_experts=actual_num_experts, top_k=actual_num_experts_per_tok, @@ -421,16 +452,13 @@ def __init__( def self_attention(self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, **kwargs) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # TODO - add embedding flag - if self.use_positional_embeddings: - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -438,8 +466,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs): if residual is None: @@ -452,7 +480,7 @@ def forward( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.pre_moe_layernorm( @@ -513,8 +541,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor ) -> torch.Tensor: @@ -528,7 +556,7 @@ def forward( hidden_states, residual = layer(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, + attn_metadata=attn_metadata, residual=residual, conv_state=conv_state, ssm_state=ssm_state @@ -583,6 +611,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def 
forward( @@ -590,7 +620,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], - input_metadata: InputMetadata, + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor ): @@ -598,19 +628,24 @@ def forward( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, conv_state, ssm_state ) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + def sample( - self, - hidden_states: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, @@ -656,28 +691,31 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + if name in params_dict: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break else: for param_name, weight_name, expert_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + if name in params_dict: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 30ae4a130be06c6d9123121cce45852a1ff09864 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Apr 2024 13:29:04 +0300 Subject: [PATCH 012/110] Add minimun 1 attention layer --- vllm/worker/cache_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 5bbcaa140b8..ee6eb58afc6 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -89,7 +89,7 @@ def get_num_attention_layers( is_mamba = model_config.hf_config.model_type == "jamba" if is_mamba: attention_period = model_config.hf_config.attn_layer_period - num_layers = num_layers // attention_period + num_layers = max(num_layers // attention_period, 1) return num_layers @staticmethod From 7bd9c0ae3ad6cd622a9f7ad2f91564d4b08fb266 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Apr 2024 13:30:00 +0300 Subject: [PATCH 013/110] More fixes --- vllm/worker/model_runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 02a84f741cd..b5414ea2756 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -902,7 +902,8 @@ def execute_model( self.prepare_contiguous_mamba_cache(self.model_config.dtype) conv_state, ssm_state, indices = self._prepare_request_mamba_cache( requests_info, - input_tokens.shape[0] + input_tokens.shape[0] if + not attn_metadata.is_prompt else len(requests_info) ) execute_model_kwargs = { **execute_model_kwargs, @@ -1025,7 +1026,6 @@ def profile_run(self) -> None: # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. 
- breakpoint() if self.vision_language_config: max_num_seqs = min( max_num_seqs, @@ -1232,6 +1232,7 @@ def capture( "positions":positions, "kv_caches":kv_caches, "attn_metadata":attn_metadata, + **kwargs } if self.is_mamba: model_inputs = { @@ -1240,7 +1241,7 @@ def capture( "ssm_state":ssm_state, } - with _maybe_cupy_nccl(): + with _maybe_pynccl(): self.model(**model_inputs) torch.cuda.synchronize() @@ -1249,7 +1250,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_cupy_nccl(): + with _maybe_pynccl(): hidden_states = self.model(**model_inputs) torch.cuda.synchronize() From d5ac8e84ab8b564f688c5c800df554bae33e2577 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 15 Apr 2024 13:30:19 +0300 Subject: [PATCH 014/110] Delete mamba cache --- vllm/worker/worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 088f9a03916..4bbfc09d244 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -272,6 +272,8 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config) + def release_mamba_cache(self,requests_id:List[str]): + self.model_runner.release_mamba_cache(requests_id) def init_worker_distributed_environment( parallel_config: ParallelConfig, From 60b49b50c36a0572ba5e8a4c9441ac8cf9617415 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 10:08:44 +0300 Subject: [PATCH 015/110] Jamba padding to the left --- vllm/model_executor/models/jamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index d33beac59a0..400066ce0a7 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -223,7 +223,7 @@ def forward( ), dtype=hidden_states.dtype, device=hidden_states.device) offset = 0 for i,prompt_len in enumerate(attn_metadata.prompt_lens): - padded_hidden_states[i,-prompt_len:].copy_(hidden_states[offset:offset + prompt_len]) + padded_hidden_states[i,:prompt_len].copy_(hidden_states[offset:offset + prompt_len]) offset += prompt_len cache = MambaCacheParams( True, @@ -233,7 +233,7 @@ def forward( padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) offset = 0 for i,prompt_len in enumerate(attn_metadata.prompt_lens): - hidden_states[offset:offset + prompt_len].copy_(padded_hidden_states[i,-prompt_len:]) + hidden_states[offset:offset + prompt_len].copy_(padded_hidden_states[i,:prompt_len]) offset += prompt_len else: cache = MambaCacheParams( From c583fe8a677065011c770ab784b31f6f14ead58a Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 10:37:54 +0300 Subject: [PATCH 016/110] Clean up --- requirements-common.txt | 6 +-- vllm/engine/llm_engine.py | 1 - vllm/model_executor/input_metadata.py | 57 --------------------------- vllm/worker/model_runner.py | 12 +----- vllm/worker/worker.py | 4 -- 5 files changed, 3 insertions(+), 77 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index da08df3721e..0c7f243583e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -10,12 +10,10 @@ fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. 
prometheus_client >= 0.18.0 +tiktoken == 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.9.3 +outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -pynvml == 11.5.0 -triton >= 2.1.0 -outlines == 0.0.34 -tiktoken == 0.6.0 # Required for DBRX tokenizer mamba-ssm causal-conv1d diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6555a089480..0c327ca24c1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -661,7 +661,6 @@ def _process_model_outputs( self.model_executor.release_mamba_cache(finished_seq_groups_req_ids) self.scheduler.free_finished_seq_groups() - # Create the outputs. request_outputs: List[RequestOutput] = [] for scheduled_seq_group in scheduled_seq_groups: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index a63fa2ba212..8b137891791 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,58 +1 @@ -from typing import Dict, List, Optional -import torch - -from vllm.model_executor.mamba_metadata import RequestInfo - - -class InputMetadata: - """Metadata for input sequences. Used in PagedAttention. - - Args: - prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. - max_context_len: The maximum context length. - context_lens: the length of attention context for each sequence. - block_tables: The block tables. (Seq id -> list of physical block) - kv_cache_dtype: Data type to store kv cache. - """ - - def __init__( - self, - is_prompt: bool, - slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], - max_seq_len: Optional[int], - start_loc: Optional[torch.Tensor], - max_context_len: Optional[int], - context_lens: Optional[torch.Tensor], - block_tables: Optional[torch.Tensor], - use_cuda_graph: bool, - kv_cache_dtype: str, - requests_info: Optional[List[RequestInfo]] = None - ) -> None: - self.is_prompt = is_prompt - self.prompt_lens = prompt_lens - self.max_seq_len = max_seq_len - self.start_loc = start_loc - self.max_context_len = max_context_len - self.slot_mapping = slot_mapping - self.context_lens = context_lens - self.block_tables = block_tables - self.use_cuda_graph = use_cuda_graph - self.kv_cache_dtype = kv_cache_dtype - - # Set during the execution of the first attention op. - # FIXME(woosuk): This is a hack. 
- self.attn_bias = None - self.requests_info = requests_info - - def __repr__(self) -> str: - return ("InputMetadata(" - f"is_prompt={self.is_prompt}, " - f"max_context_len={self.max_context_len}, " - f"slot_mapping={self.slot_mapping}, " - f"context_lens={self.context_lens}, " - f"block_tables={self.block_tables}, " - f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5414ea2756..a69831d6c56 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,7 +6,6 @@ import numpy as np import torch import torch.nn as nn -from collections import defaultdict from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) @@ -23,12 +22,6 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.mamba_metadata import RequestInfo from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - with_pynccl_for_all_reduce, - get_tensor_model_parallel_world_size) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) @@ -443,7 +436,6 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, ) return PreparePromptMetadata( @@ -574,7 +566,6 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, ) return PrepareDecodeMetadata( input_tokens=input_tokens, @@ -826,7 +817,6 @@ def prepare_input_tensors( decode_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) requests_info = metadata_dict.pop("requests_info") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -869,7 +859,7 @@ def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[torch.Tensor] + kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input, requests_info diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4bbfc09d244..16663f18500 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,9 +21,6 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase -from vllm.lora.request import LoRARequest -from vllm.utils import is_hip -from vllm.logger import init_logger class Worker(WorkerBase): @@ -209,7 +206,6 @@ def cache_swap( if blocks_to_copy: self.cache_engine.copy(blocks_to_copy) - @torch.inference_mode() def execute_model( self, From c951b7d21fffa8339d57a96508f1c12301f39750 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 10:43:22 +0300 Subject: [PATCH 017/110] Add import --- vllm/transformers_utils/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 36edbbfce30..06dd8fcfb37 100644 --- 
a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,8 +2,14 @@ from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MPTConfig, RWConfig) +from vllm.transformers_utils.configs import ( + ChatGLMConfig, + DbrxConfig, + JAISConfig, + MPTConfig, + RWConfig, + JambaConfig, +) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, From da6d0f2a5365772f64c489878004e5b3b88e299f Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 11:10:41 +0300 Subject: [PATCH 018/110] Another clean up --- vllm/model_executor/models/jamba.py | 23 ++++++++++++++--------- vllm/transformers_utils/config.py | 9 +-------- vllm/worker/model_runner.py | 8 ++++++-- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 400066ce0a7..11f388ae214 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -25,8 +25,6 @@ VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, @@ -35,6 +33,8 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -63,7 +63,10 @@ def __init__(self, config: JambaConfig, layer_idx): output_size=self.intermediate_size, bias=self.use_conv_bias, ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. Can't do this in `weight_loader` since it already exists in `ColumnParallelLinear` and `set_weight_attrs` doesn't allow to override it + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.in_proj = MergedColumnParallelLinear( @@ -77,7 +80,9 @@ def __init__(self, config: JambaConfig, layer_idx): self.time_step_rank + self.ssm_state_size * 2, bias=False ) - # time step projection (discretization) - In the forward we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel. + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. 
self.dt_proj = ColumnParallelLinear( self.time_step_rank, self.intermediate_size, @@ -213,16 +218,16 @@ def forward( conv_state: torch.Tensor, ssm_state: torch.Tensor ): - if attn_metadata.is_prompt: - max_seq_len = max(attn_metadata.prompt_lens) - batch_size = len(attn_metadata.prompt_lens) + if attn_metadata.prefill_metadata is not None: + max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) + batch_size = len(attn_metadata.prefill_metadata.prompt_lens) padded_hidden_states = torch.zeros(( batch_size, max_seq_len, hidden_states.shape[-1], ), dtype=hidden_states.dtype, device=hidden_states.device) offset = 0 - for i,prompt_len in enumerate(attn_metadata.prompt_lens): + for i,prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): padded_hidden_states[i,:prompt_len].copy_(hidden_states[offset:offset + prompt_len]) offset += prompt_len cache = MambaCacheParams( @@ -232,7 +237,7 @@ def forward( ) padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) offset = 0 - for i,prompt_len in enumerate(attn_metadata.prompt_lens): + for i,prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): hidden_states[offset:offset + prompt_len].copy_(padded_hidden_states[i,:prompt_len]) offset += prompt_len else: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 06dd8fcfb37..366d3bb8ff2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,14 +2,7 @@ from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import ( - ChatGLMConfig, - DbrxConfig, - JAISConfig, - MPTConfig, - RWConfig, - JambaConfig, -) +from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig, JAISConfig, MPTConfig, RWConfig, JambaConfig _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a69831d6c56..40d2808454b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,7 +12,11 @@ from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, TensorizerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce +from vllm.distributed import ( + broadcast_tensor_dict, + with_pynccl_for_all_reduce, + get_tensor_model_parallel_world_size, +) from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) from vllm.logger import init_logger @@ -893,7 +897,7 @@ def execute_model( conv_state, ssm_state, indices = self._prepare_request_mamba_cache( requests_info, input_tokens.shape[0] if - not attn_metadata.is_prompt else len(requests_info) + attn_metadata.prefill_metadata is None else len(requests_info) ) execute_model_kwargs = { **execute_model_kwargs, From eb799234f56b44c623844366e71285f460718496 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 12:01:20 +0300 Subject: [PATCH 019/110] Align to main --- vllm/model_executor/models/__init__.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index aa7dc4b775e..5ec5ef2ae82 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -31,7 +31,8 @@ "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - 
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), + "LlavaForConditionalGeneration": + ("llava", "LlavaForConditionalGeneration"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), @@ -66,13 +67,17 @@ # Models partially supported by ROCm. # Architecture -> Reason. _ROCM_PARTIALLY_SUPPORTED_MODELS = { - "Qwen2ForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", - "MistralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", - "MixtralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention", + "Qwen2ForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", + "MistralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", + "MixtralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", } class ModelRegistry: + @staticmethod def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _OOT_MODELS: @@ -83,16 +88,15 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( f"Model architecture {model_arch} is not supported by " - "ROCm for now." - ) + "ROCm for now.") if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: logger.warning( f"Model architecture {model_arch} is partially supported " - "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] - ) + "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module(f"vllm.model_executor.models.{module_name}") + module = importlib.import_module( + f"vllm.model_executor.models.{module_name}") return getattr(module, model_cls_name, None) @staticmethod From 919edbaa381f2bc9646a6c3a924b36f434d832a1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 12:01:32 +0300 Subject: [PATCH 020/110] Fix reduce --- vllm/model_executor/models/jamba.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 11f388ae214..dda848568a2 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -23,8 +23,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, @@ -34,7 +32,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -462,7 +460,6 @@ def self_attention(self, **kwargs) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # TODO - add embedding flag attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = 
self.o_proj(attn_output) return output From 46685660bb4faa347b2599a6a6c1b3657c5fbd32 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 12:01:43 +0300 Subject: [PATCH 021/110] Another fix --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 40d2808454b..d077aabd8bc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -810,6 +810,7 @@ def prepare_input_tensors( num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") num_decode_tokens = metadata_dict.pop("num_decode_tokens") batch_type = metadata_dict.pop("batch_type") + requests_info = metadata_dict.pop("requests_info") # Create an attention metadata. prefill_attn_metadata = None @@ -820,7 +821,6 @@ def prepare_input_tensors( else: decode_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) - requests_info = metadata_dict.pop("requests_info") sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, From 11a073782f86765a88798229bd2bb5a971e3cc3a Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 12:09:13 +0300 Subject: [PATCH 022/110] Black format for jamba --- vllm/model_executor/models/jamba.py | 562 ++++++++++++++++------------ 1 file changed, 321 insertions(+), 241 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index dda848568a2..4fe65231f8b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -16,26 +16,39 @@ from vllm.config import LoRAConfig from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) + VocabParallelEmbedding, + ParallelLMHead, + DEFAULT_VOCAB_PADDING_SIZE, +) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) KVCache = Tuple[torch.Tensor, torch.Tensor] + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -61,65 +74,62 @@ def __init__(self, config: JambaConfig, layer_idx): output_size=self.intermediate_size, bias=self.use_conv_bias, ) - # unsqueeze to fit conv1d weights shape into the 
linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.in_proj = MergedColumnParallelLinear( - self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias + self.hidden_size, [self.intermediate_size] * 2, bias=self.use_bias ) # selective projection used to make dt, B and C input dependant self.x_proj = RowParallelLinear( self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, - bias=False + bias=False, ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. self.dt_proj = ColumnParallelLinear( - self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True + self.time_step_rank, self.intermediate_size, bias=True, skip_bias_add=True ) - def weight_loader(param:Parameter, loaded_weight:torch.Tensor): + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - param.data.copy_(loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[tp_rank]) + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[ + tp_rank + ] + ) - def A_weight_loader(param:Parameter, loaded_weight:torch.Tensor): - weight_loader(param,-torch.exp(loaded_weight.float())) + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter(torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32 - )) + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + ) + ) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - set_weight_attrs(self.D, { - "weight_loader": weight_loader - }) - set_weight_attrs(self.A, { - "weight_loader": A_weight_loader - }) + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) self.out_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=self.use_bias, - input_is_parallel=True + input_is_parallel=True, ) self.activation = config.hidden_act self.apply_inner_layernorms = config.mamba_inner_layernorms - + if self.apply_inner_layernorms: self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) @@ -138,13 +148,17 @@ def _apply_layernorms(self, dt, B, C): C = self.C_layernorm.forward(C.contiguous()) return dt, B, C - def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): + def mamba_forward( + self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None + ): # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) if cache_params is not None and not cache_params.is_prompt: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), @@ -161,7 +175,10 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar ) cache_params.conv_state.copy_(conv_states) hidden_states = causal_conv1d_fn( - hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, ) # 3. State Space Model sequence transformation @@ -169,13 +186,17 @@ def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCachePar ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] time_step, B, C = torch.split( - ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, ) time_step, B, C = self._apply_layernorms(time_step, B, C) - discrete_time_step = self.dt_proj(time_step)[0].transpose(1,2) + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + time_proj_bias = ( + self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + ) if cache_params is not None and not cache_params.is_prompt: scan_outputs = selective_state_update( cache_params.ssm_state, @@ -214,46 +235,54 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor + ssm_state: torch.Tensor, ): if attn_metadata.prefill_metadata is not None: max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) batch_size = len(attn_metadata.prefill_metadata.prompt_lens) - padded_hidden_states = torch.zeros(( - batch_size, - max_seq_len, - hidden_states.shape[-1], - ), dtype=hidden_states.dtype, device=hidden_states.device) + padded_hidden_states = torch.zeros( + ( + batch_size, + max_seq_len, + hidden_states.shape[-1], + ), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) offset = 0 - for i,prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): - padded_hidden_states[i,:prompt_len].copy_(hidden_states[offset:offset + prompt_len]) + for i, prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): + padded_hidden_states[i, :prompt_len].copy_( + hidden_states[offset : offset + prompt_len] + ) offset += prompt_len cache = MambaCacheParams( True, conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx] + ssm_state=ssm_state[self.layer_idx], + ) + padded_hidden_states = self.mamba_forward( + padded_hidden_states, cache_params=cache ) - padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) offset = 0 - for i,prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): - hidden_states[offset:offset + prompt_len].copy_(padded_hidden_states[i,:prompt_len]) + for i, prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): + hidden_states[offset : offset + prompt_len].copy_( + padded_hidden_states[i, :prompt_len] + ) 
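                    # Packing step: `hidden_states` arrives flattened as
                    # [total_prefill_tokens, hidden_size], so each prompt's slice is
                    # copied left-aligned into row i of the zero-initialized
                    # [batch, max_seq_len, hidden_size] buffer; `offset` then advances
                    # by prompt_len, and the mirror loop after mamba_forward copies the
                    # scanned outputs back into the flattened layout.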
offset += prompt_len else: cache = MambaCacheParams( False, conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx] + ssm_state=ssm_state[self.layer_idx], + ) + hidden_states = self.mamba_forward( + hidden_states.unsqueeze(1).contiguous(), cache_params=cache ) - hidden_states = self.mamba_forward(hidden_states.unsqueeze(1).contiguous(), cache_params=cache) hidden_states = hidden_states.squeeze(1).contiguous() - - return hidden_states - - class JambaMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -264,13 +293,13 @@ class JambaMoE(nn.Module): """ def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -285,34 +314,53 @@ def __init__( if self.num_total_experts > 1: # init expert router iff this layer has multiple experts - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - linear_method=None) + self.router = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None, + ) self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype, + ) + ) self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype, + ) + ) + + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -320,7 +368,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, if weight_name.endswith("gate_proj.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("up_proj.weight"): - param_data[expert_id, shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("down_proj.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] @@ -331,27 +381,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.num_total_experts > 1: router_logits, _ = self.router(hidden_states) else: - router_logits = torch.ones([hidden_states.shape[0], 1], device=hidden_states.device, - dtype=hidden_states.dtype) 
+ router_logits = torch.ones( + [hidden_states.shape[0], 1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=False, # Mixtral normalize the expert probs to 1. We don't! - inplace=True) + final_hidden_states = fused_moe( + hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=False, # Mixtral normalize the expert probs to 1. We don't! + inplace=True, + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states.view(num_tokens, - hidden_size) + return final_hidden_states.view(num_tokens, hidden_size) class JambaMambaDecoderLayer(nn.Module): - def __init__(self, config: JambaConfig, actual_num_experts: int, actual_num_experts_per_tok: int ,layer_idx: int) -> None: + def __init__( + self, + config: JambaConfig, + actual_num_experts: int, + actual_num_experts_per_tok: int, + layer_idx: int, + ) -> None: super().__init__() self.layer_idx = layer_idx self.config = config @@ -360,42 +419,41 @@ def __init__(self, config: JambaConfig, actual_num_experts: int, actual_num_expe num_experts=actual_num_experts, top_k=actual_num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - **kwargs): + intermediate_size=config.intermediate_size, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.mamba( - hidden_states, - attn_metadata, - conv_state, - ssm_state - ) + hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, ssm_state) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_moe_layernorm(hidden_states, residual) hidden_states = self.moe(hidden_states) return hidden_states, residual class JambaAttentionDecoderLayer(nn.Module): def __init__( - self, config: JambaConfig, actual_num_experts: int, actual_num_experts_per_tok: int ,layer_idx: int, linear_method: Optional[LinearMethodBase] = None + self, + config: JambaConfig, + actual_num_experts: int, + actual_num_experts_per_tok: int, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -416,7 +474,7 @@ def __init__( self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.sliding_window = config.sliding_window 
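        # The fused QKV projection emits [q_size | kv_size | kv_size] features per
        # token (q_size = num_heads * head_dim, kv_size = num_kv_heads * head_dim);
        # self_attention() below splits its output with exactly these sizes before
        # handing q, k and v to the attention backend.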
self.qkv_proj = QKVParallelLinear( @@ -446,18 +504,19 @@ def __init__( num_experts=actual_num_experts, top_k=actual_num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def self_attention(self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - **kwargs) -> torch.Tensor: + intermediate_size=config.intermediate_size, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) @@ -465,13 +524,14 @@ def self_attention(self, return output def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - **kwargs): + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -479,30 +539,32 @@ def forward( hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attention( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, ) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_moe_layernorm(hidden_states, residual) hidden_states = self.moe(hidden_states) return hidden_states, residual class JambaModel(nn.Module): def __init__( - self, - config: JambaConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, + self, + config: JambaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -516,37 +578,54 @@ def __init__( module_list = [] for i in range(config.num_hidden_layers): - is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False - is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False + is_attn = ( + True + if (i - self.config.attn_layer_offset) % self.config.attn_layer_period + == 0 + else False + ) + is_expert = ( + True + if (i - self.config.expert_layer_offset) + % self.config.expert_layer_period + == 0 + else False + ) actual_num_experts = config.num_experts if is_expert 
else 1 actual_num_experts_per_tok = config.num_experts_per_tok if is_expert else 1 if is_attn: - module_list.append(JambaAttentionDecoderLayer(config, - actual_num_experts=actual_num_experts, - actual_num_experts_per_tok=actual_num_experts_per_tok, - layer_idx=i, - linear_method=linear_method - )) + module_list.append( + JambaAttentionDecoderLayer( + config, + actual_num_experts=actual_num_experts, + actual_num_experts_per_tok=actual_num_experts_per_tok, + layer_idx=i, + linear_method=linear_method, + ) + ) else: - module_list.append(JambaMambaDecoderLayer(config, - actual_num_experts=actual_num_experts, - actual_num_experts_per_tok=actual_num_experts_per_tok, - layer_idx=i - )) + module_list.append( + JambaMambaDecoderLayer( + config, + actual_num_experts=actual_num_experts, + actual_num_experts_per_tok=actual_num_experts_per_tok, + layer_idx=i, + ) + ) self.layers = nn.ModuleList(module_list) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -554,15 +633,18 @@ def forward( layer = self.layers[i] kv_cache = None if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - residual=residual, - conv_state=conv_state, - ssm_state=ssm_state - ) + kv_cache = kv_caches[ + (i - self.config.attn_layer_offset) // self.config.attn_layer_period + ] + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + conv_state=conv_state, + ssm_state=ssm_state, + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -590,17 +672,15 @@ class JambaForCausalLM(nn.Module): embedding_padding_modules = ["lm_head"] def __init__( - self, - config: JambaConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, + self, + config: JambaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = JambaModel(config, - linear_method, - lora_config=lora_config) + self.model = JambaModel(config, linear_method, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -613,33 +693,31 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.sampler = Sampler() def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor - ): + self, + input_ids: torch.Tensor, 
+ positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - conv_state, - ssm_state + input_ids, positions, kv_caches, attn_metadata, conv_state, ssm_state ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, - sampling_metadata) + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor( + self.lm_head.weight, hidden_states, sampling_metadata + ) return logits def sample( @@ -650,11 +728,13 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -664,29 +744,29 @@ def load_weights(self, expert_params_mapping = [ # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.config.num_experts) for weight_name in ["down_proj", "up_proj", "gate_proj"] ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=True): # erez - might need to change later to False + model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=True + ): # erez - might need to change later to False if "rotary_emb.inv_freq" in name: continue if "A_log" in name: - name = name.replace("A_log","A") + name = name.replace("A_log", "A") if ".self_attn." in name: name = name.replace(".self_attn", "") - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -706,10 +786,9 @@ def load_weights(self, if name in params_dict: param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. 
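For reference, a minimal standalone sketch (not part of the patch) of how the expert parameter mapping above routes per-expert checkpoint tensors into the fused `ws` / `w2s` parameters; the `num_experts` value and the checkpoint key below are illustrative assumptions, not names taken from a real checkpoint:

num_experts = 4  # illustrative only; the real value comes from JambaConfig.num_experts

# Same comprehension as in load_weights above: gate_proj/up_proj -> "ws", down_proj -> "w2s".
expert_params_mapping = [
    ("ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
     f"experts.{expert_id}.{weight_name}.weight", expert_id)
    for expert_id in range(num_experts)
    for weight_name in ["down_proj", "up_proj", "gate_proj"]
]

checkpoint_key = "model.layers.1.moe.experts.2.up_proj.weight"  # hypothetical key
for param_name, weight_name, expert_id in expert_params_mapping:
    if weight_name in checkpoint_key:
        # The matched name is rewritten to the fused parameter, and expert_id tells
        # JambaMoE.weight_loader which expert slice of ws/w2s to fill.
        print(checkpoint_key.replace(weight_name, param_name), expert_id)  # -> model.layers.1.moe.ws 2
        break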
@@ -718,6 +797,7 @@ def load_weights(self, if name in params_dict: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) From 7e3415ec42c328247c06cdc61810c308e6b67c9f Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 14:48:51 +0300 Subject: [PATCH 023/110] Formatting --- vllm/model_executor/models/jamba.py | 17 +++++++---------- vllm/worker/model_runner.py | 4 ++-- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 4fe65231f8b..e65e797e6f1 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -241,13 +241,9 @@ def forward( max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) batch_size = len(attn_metadata.prefill_metadata.prompt_lens) padded_hidden_states = torch.zeros( - ( - batch_size, - max_seq_len, - hidden_states.shape[-1], - ), + (batch_size, max_seq_len, hidden_states.shape[-1]), dtype=hidden_states.dtype, - device=hidden_states.device, + device=hidden_states.device ) offset = 0 for i, prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): @@ -273,12 +269,12 @@ def forward( cache = MambaCacheParams( False, conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx], + ssm_state=ssm_state[self.layer_idx] ) hidden_states = self.mamba_forward( - hidden_states.unsqueeze(1).contiguous(), cache_params=cache + hidden_states.unsqueeze(1), cache_params=cache ) - hidden_states = hidden_states.squeeze(1).contiguous() + hidden_states = hidden_states.squeeze(1) return hidden_states @@ -489,7 +485,7 @@ def __init__( self.total_num_heads * self.head_dim, config.hidden_size, bias=False, - linear_method=linear_method, + linear_method=linear_method ) self.attn = Attention( @@ -629,6 +625,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None + for i in range(len(self.layers)): layer = self.layers[i] kv_cache = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d077aabd8bc..41fceb5fc3a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -916,8 +916,8 @@ def execute_model( if self.is_mamba: for i, offset in enumerate(indices): - self.mamba_cache[0][:, offset] = conv_state[:, i] - self.mamba_cache[1][:, offset] = ssm_state[:, i] + self.mamba_cache[0][:, offset].copy_(conv_state[:, i]) + self.mamba_cache[1][:, offset].copy_(ssm_state[:, i]) # Sample the next token. 
output = self.model.sample( From adbd2aeebac7e4c55d7a6c4648b9367c1af7440d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 15:01:37 +0300 Subject: [PATCH 024/110] Formatting with format.sh --- vllm/engine/llm_engine.py | 6 +- vllm/executor/executor_base.py | 3 +- vllm/executor/gpu_executor.py | 3 +- vllm/executor/ray_gpu_executor.py | 10 +- vllm/model_executor/__init__.py | 5 +- vllm/model_executor/mamba_metadata.py | 2 - vllm/model_executor/models/jamba.py | 223 +++++++++----------- vllm/transformers_utils/configs/__init__.py | 6 +- vllm/transformers_utils/configs/jamba.py | 72 ++++--- vllm/worker/cache_engine.py | 12 +- vllm/worker/model_runner.py | 117 +++++----- vllm/worker/worker.py | 6 +- 12 files changed, 223 insertions(+), 242 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0c327ca24c1..8a14c4e2c61 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -652,13 +652,13 @@ def _process_model_outputs( # Free the finished sequence groups. finished_seq_groups_req_ids = [ - seq_group.request_id - for seq_group in self.scheduler.running + seq_group.request_id for seq_group in self.scheduler.running if seq_group.is_finished() ] if len(finished_seq_groups_req_ids) > 0: - self.model_executor.release_mamba_cache(finished_seq_groups_req_ids) + self.model_executor.release_mamba_cache( + finished_seq_groups_req_ids) self.scheduler.free_finished_seq_groups() # Create the outputs. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 3121a96f8cc..79d8424f71a 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -95,9 +95,10 @@ def check_health(self) -> None: raise NotImplementedError @abstractmethod - def release_mamba_cache(self,requests_id:List[str]) -> None: + def release_mamba_cache(self, requests_id: List[str]) -> None: raise NotImplementedError + class ExecutorAsyncBase(ExecutorBase): @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 06d677b13c0..faac55d7a41 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -92,11 +92,10 @@ def check_health(self) -> None: # it's running. return - def release_mamba_cache(self, requests_id:List[str]) -> None: + def release_mamba_cache(self, requests_id: List[str]) -> None: self.driver_worker.release_mamba_cache(requests_id) - class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index e8c15f1a531..d6080db74b1 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -364,11 +364,10 @@ def _check_if_any_actor_is_dead(self): raise RuntimeError("At least one Worker is dead. " f"Dead Workers: {dead_actors}. 
") - def release_mamba_cache(self, requests_id:List[str]) -> None: - self._run_workers( - "release_mamba_cache", - requests_id= requests_id, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + def release_mamba_cache(self, requests_id: List[str]) -> None: + self._run_workers("release_mamba_cache", + requests_id=requests_id, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): @@ -423,4 +422,3 @@ async def execute_model_async( async def check_health_async(self) -> None: """Raises an error if engine is unhealthy.""" self._check_if_any_actor_is_dead() - diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 8be2a869f33..c1c231fcb4d 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -4,8 +4,5 @@ from vllm.model_executor.utils import set_random_seed __all__ = [ - "SamplingMetadata", - "set_random_seed", - "MambaCacheParams", - "RequestInfo" + "SamplingMetadata", "set_random_seed", "MambaCacheParams", "RequestInfo" ] diff --git a/vllm/model_executor/mamba_metadata.py b/vllm/model_executor/mamba_metadata.py index 7e349b5e49b..3ee6bdf14b8 100644 --- a/vllm/model_executor/mamba_metadata.py +++ b/vllm/model_executor/mamba_metadata.py @@ -15,5 +15,3 @@ class MambaCacheParams: class RequestInfo: request_id: str = '' seqs_id: List[int] = field(default_factory=list) - - diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e65e797e6f1..18827846710 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,4 @@ # coding=utf-8 - """Inference-only Jurassic model.""" from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple @@ -80,9 +79,9 @@ def __init__(self, config: JambaConfig, layer_idx): # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear( - self.hidden_size, [self.intermediate_size] * 2, bias=self.use_bias - ) + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) # selective projection used to make dt, B and C input dependant self.x_proj = RowParallelLinear( self.intermediate_size, @@ -92,18 +91,17 @@ def __init__(self, config: JambaConfig, layer_idx): # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. 
- self.dt_proj = ColumnParallelLinear( - self.time_step_rank, self.intermediate_size, bias=True, skip_bias_add=True - ) + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) def weight_loader(param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[ - tp_rank - ] - ) + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): weight_loader(param, -torch.exp(loaded_weight.float())) @@ -114,8 +112,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.intermediate_size // tp_size, self.ssm_state_size, dtype=torch.float32, - ) - ) + )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) set_weight_attrs(self.D, {"weight_loader": weight_loader}) @@ -131,9 +128,12 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.apply_inner_layernorms = config.mamba_inner_layernorms if self.apply_inner_layernorms: - self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) - self.B_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - self.C_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.rms_norm_eps) + self.B_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + self.C_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) else: self.dt_layernorm = None self.B_layernorm = None @@ -148,17 +148,16 @@ def _apply_layernorms(self, dt, B, C): C = self.C_layernorm.forward(C.contiguous()) return dt, B, C - def mamba_forward( - self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None - ): + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) if cache_params is not None and not cache_params.is_prompt: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), @@ -171,8 +170,8 @@ def mamba_forward( else: if cache_params is not None: conv_states = nn.functional.pad( - hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) hidden_states = causal_conv1d_fn( hidden_states, @@ -194,9 +193,8 @@ def mamba_forward( discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = ( - self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None - ) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) if cache_params is not None and not cache_params.is_prompt: scan_outputs = selective_state_update( cache_params.ssm_state, @@ -243,37 +241,32 @@ def forward( padded_hidden_states = torch.zeros( (batch_size, max_seq_len, hidden_states.shape[-1]), dtype=hidden_states.dtype, - device=hidden_states.device - ) + device=hidden_states.device) offset = 0 - for i, prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.prompt_lens): padded_hidden_states[i, :prompt_len].copy_( - hidden_states[offset : offset + prompt_len] - ) + hidden_states[offset:offset + prompt_len]) offset += prompt_len cache = MambaCacheParams( True, conv_state=conv_state[self.layer_idx], ssm_state=ssm_state[self.layer_idx], ) - padded_hidden_states = self.mamba_forward( - padded_hidden_states, cache_params=cache - ) + padded_hidden_states = self.mamba_forward(padded_hidden_states, + cache_params=cache) offset = 0 - for i, prompt_len in enumerate(attn_metadata.prefill_metadata.prompt_lens): - hidden_states[offset : offset + prompt_len].copy_( - padded_hidden_states[i, :prompt_len] - ) + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.prompt_lens): + hidden_states[offset:offset + prompt_len].copy_( + padded_hidden_states[i, :prompt_len]) offset += prompt_len else: - cache = MambaCacheParams( - False, - conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx] - ) - hidden_states = self.mamba_forward( - hidden_states.unsqueeze(1), cache_params=cache - ) + cache = MambaCacheParams(False, + conv_state=conv_state[self.layer_idx], + ssm_state=ssm_state[self.layer_idx]) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) hidden_states = hidden_states.squeeze(1) return hidden_states @@ -325,8 +318,7 @@ def __init__( self.hidden_size, device="cuda", dtype=self.params_dtype, - ) - ) + )) self.w2s = nn.Parameter( torch.empty( self.num_total_experts, @@ -334,8 +326,7 @@ def __init__( self.intermediate_size, device="cuda", dtype=self.params_dtype, - ) - ) + )) set_weight_attrs( self.ws, @@ -364,9 +355,8 @@ def weight_loader( if weight_name.endswith("gate_proj.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("up_proj.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] if 
weight_name.endswith("down_proj.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] @@ -389,17 +379,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.w2s, router_logits, self.top_k, - renormalize=False, # Mixtral normalize the expert probs to 1. We don't! + renormalize= + False, # Mixtral normalize the expert probs to 1. We don't! inplace=True, ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) class JambaMambaDecoderLayer(nn.Module): + def __init__( self, config: JambaConfig, @@ -417,8 +410,10 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -433,16 +428,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, ssm_state) + hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, + ssm_state) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm(hidden_states, residual) + hidden_states, residual = self.pre_moe_layernorm( + hidden_states, residual) hidden_states = self.moe(hidden_states) return hidden_states, residual class JambaAttentionDecoderLayer(nn.Module): + def __init__( self, config: JambaConfig, @@ -481,12 +480,10 @@ def __init__( bias=False, linear_method=linear_method, ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - linear_method=linear_method - ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + linear_method=linear_method) self.attn = Attention( self.num_heads, @@ -502,8 +499,10 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def self_attention( self, @@ -532,7 +531,8 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attention( positions=positions, @@ -541,12 +541,14 @@ def forward( attn_metadata=attn_metadata, ) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm(hidden_states, residual) + hidden_states, residual = self.pre_moe_layernorm( + hidden_states, residual) hidden_states = self.moe(hidden_states) return hidden_states, residual class JambaModel(nn.Module): + def __init__( self, config: JambaConfig, @@ -556,11 +558,8 @@ def __init__( super().__init__() 
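        # One of the two alternating Jamba block types: a Mamba mixer followed by the
        # MoE block defined above. JambaModel instantiates this class for layer indices
        # that do not fall on the attention period/offset, and passes
        # actual_num_experts=1 (i.e. a single "expert" MLP, no router) for layers
        # outside the expert period/offset.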
self.config = config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -574,19 +573,10 @@ def __init__( module_list = [] for i in range(config.num_hidden_layers): - is_attn = ( - True - if (i - self.config.attn_layer_offset) % self.config.attn_layer_period - == 0 - else False - ) - is_expert = ( - True - if (i - self.config.expert_layer_offset) - % self.config.expert_layer_period - == 0 - else False - ) + is_attn = (True if (i - self.config.attn_layer_offset) % + self.config.attn_layer_period == 0 else False) + is_expert = (True if (i - self.config.expert_layer_offset) % + self.config.expert_layer_period == 0 else False) actual_num_experts = config.num_experts if is_expert else 1 actual_num_experts_per_tok = config.num_experts_per_tok if is_expert else 1 @@ -599,8 +589,7 @@ def __init__( actual_num_experts_per_tok=actual_num_experts_per_tok, layer_idx=i, linear_method=linear_method, - ) - ) + )) else: module_list.append( JambaMambaDecoderLayer( @@ -608,11 +597,11 @@ def __init__( actual_num_experts=actual_num_experts, actual_num_experts_per_tok=actual_num_experts_per_tok, layer_idx=i, - ) - ) + )) self.layers = nn.ModuleList(module_list) - self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -630,9 +619,8 @@ def forward( layer = self.layers[i] kv_cache = None if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache = kv_caches[ - (i - self.config.attn_layer_offset) // self.config.attn_layer_period - ] + kv_cache = kv_caches[(i - self.config.attn_layer_offset) // + self.config.attn_layer_period] hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, @@ -690,9 +678,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def forward( @@ -704,17 +691,14 @@ def forward( conv_state: torch.Tensor, ssm_state: torch.Tensor, ): - hidden_states = self.model( - input_ids, positions, kv_caches, attn_metadata, conv_state, ssm_state - ) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, conv_state, ssm_state) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor( - self.lm_head.weight, hidden_states, sampling_metadata - ) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) return logits def sample( @@ -745,14 +729,17 @@ def load_weights( "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id, - ) - for expert_id in range(self.config.num_experts) + ) for expert_id in range(self.config.num_experts) for weight_name in ["down_proj", "up_proj", "gate_proj"] ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - 
model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=True + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=True ): # erez - might need to change later to False if "rotary_emb.inv_freq" in name: continue @@ -783,9 +770,10 @@ def load_weights( if name in params_dict: param = params_dict[name] weight_loader = param.weight_loader - weight_loader( - param, loaded_weight, weight_name, expert_id=expert_id - ) + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. @@ -794,7 +782,6 @@ def load_weights( if name in params_dict: param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 42084ac0067..4e4cdcb6ee1 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,10 +10,6 @@ from vllm.transformers_utils.configs.jamba import JambaConfig __all__ = [ - "ChatGLMConfig", - "DbrxConfig", - "MPTConfig", - "RWConfig", - "JAISConfig", + "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig", "JambaConfig" ] diff --git a/vllm/transformers_utils/configs/jamba.py b/vllm/transformers_utils/configs/jamba.py index 7c58fe35a87..88d38e7d6af 100644 --- a/vllm/transformers_utils/configs/jamba.py +++ b/vllm/transformers_utils/configs/jamba.py @@ -68,40 +68,40 @@ class JambaConfig(PretrainedConfig): keys_to_ignore_at_inference = ["past_key_values"] def __init__( - self, - vocab_size=65536, - tie_word_embeddings=False, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - output_router_logits=False, - router_aux_loss_coef=0.001, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - sliding_window=None, - attention_dropout=0.0, - num_experts_per_tok=2, - num_experts=16, - expert_layer_offset=1, - expert_layer_period=2, - attn_layer_period=8, - attn_layer_offset=4, - use_mamba_kernels=True, - mamba_d_state=16, - mamba_d_conv=4, - mamba_expand=2, - mamba_dt_rank="auto", - mamba_conv_bias=True, - mamba_proj_bias=False, - mamba_inner_layernorms=True, - **kwargs, + self, + vocab_size=65536, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_experts=16, + expert_layer_offset=1, + expert_layer_period=2, + attn_layer_period=8, + attn_layer_offset=4, + use_mamba_kernels=True, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_dt_rank="auto", + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_inner_layernorms=True, + **kwargs, ): self.vocab_size = vocab_size self.tie_word_embeddings = tie_word_embeddings @@ -136,7 +136,9 @@ def __init__( self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else 
mamba_dt_rank + self.mamba_dt_rank = math.ceil( + self.hidden_size / + 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.mamba_conv_bias = mamba_conv_bias self.mamba_proj_bias = mamba_proj_bias self.mamba_inner_layernorms = mamba_inner_layernorms diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ee6eb58afc6..cf76e41013e 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,7 +30,8 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = CacheEngine.get_num_attention_layers(model_config, parallel_config) + self.num_layers = CacheEngine.get_num_attention_layers( + model_config, parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size @@ -81,10 +82,8 @@ def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod - def get_num_attention_layers( - model_config:ModelConfig, - parallel_config:ParallelConfig - ): + def get_num_attention_layers(model_config: ModelConfig, + parallel_config: ParallelConfig): num_layers = model_config.get_num_layers(parallel_config) is_mamba = model_config.hf_config.model_type == "jamba" if is_mamba: @@ -100,7 +99,8 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = CacheEngine.get_num_attention_layers(model_config,parallel_config) + num_layers = CacheEngine.get_num_attention_layers( + model_config, parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 41fceb5fc3a..9841e7e010d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -186,11 +186,18 @@ def prepare_contiguous_mamba_cache(self, dtype): ) if self.mamba_cache is None: self.mamba_cache = {} - self.mamba_cache = (torch.empty(size=conv_state_shape, dtype=dtype, device="cuda"), - torch.empty(size=ssm_state_shape, dtype=dtype, device="cuda")) - self.mamba_cache4gc = (torch.empty(size=conv_state_shape, dtype=dtype, device="cuda"), - torch.empty(size=ssm_state_shape, dtype=dtype, device="cuda")) - + self.mamba_cache = (torch.empty(size=conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=ssm_state_shape, + dtype=dtype, + device="cuda")) + self.mamba_cache4gc = (torch.empty(size=conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=ssm_state_shape, + dtype=dtype, + device="cuda")) def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -763,10 +770,11 @@ def prepare_input_tensors( batch_type = BatchType.PREFILL else: batch_type = BatchType.DECODE - requests_info = [ RequestInfo( - request_id=req.request_id, - seqs_id=list(req.seq_data.keys()) - ) for req in seq_group_metadata_list] + requests_info = [ + RequestInfo(request_id=req.request_id, + seqs_id=list(req.seq_data.keys())) + for req in seq_group_metadata_list + ] metadata_dict = { "input_tokens": input_tokens, @@ -856,8 +864,9 @@ def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.request2i: indices = self.request2i.pop(req_id) - logger.debug(f"Deleted { req_id } from mamba_cache with indices = {indices}") - + logger.debug( + f"Deleted { req_id } from mamba_cache 
with indices = {indices}" + ) @torch.inference_mode() def execute_model( @@ -866,8 +875,8 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_input, requests_info - ) = self.prepare_input_tensors(seq_group_metadata_list) + lora_requests, lora_mapping, multi_modal_input, + requests_info) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -895,14 +904,12 @@ def execute_model( if self.mamba_cache is None: self.prepare_contiguous_mamba_cache(self.model_config.dtype) conv_state, ssm_state, indices = self._prepare_request_mamba_cache( - requests_info, - input_tokens.shape[0] if - attn_metadata.prefill_metadata is None else len(requests_info) - ) + requests_info, input_tokens.shape[0] if + attn_metadata.prefill_metadata is None else len(requests_info)) execute_model_kwargs = { **execute_model_kwargs, - "conv_state":conv_state, - "ssm_state":ssm_state, + "conv_state": conv_state, + "ssm_state": ssm_state, } hidden_states = model_executable(**execute_model_kwargs) @@ -929,22 +936,15 @@ def execute_model( def _get_first_free_mamba_cache_index(self): max_possible_bs = self.mamba_cache[0].shape[1] occupied = [ - id - for seq_ids in self.request2i.values() + id for seq_ids in self.request2i.values() for id in seq_ids.values() ] - first_free_index = [ - i not in occupied - for i in range(max_possible_bs) - ].index(True) + first_free_index = [i not in occupied + for i in range(max_possible_bs)].index(True) return first_free_index - - def _prepare_request_mamba_cache( - self, - requests_info: List[RequestInfo], - batch_size: int - ): + def _prepare_request_mamba_cache(self, requests_info: List[RequestInfo], + batch_size: int): indices = [] max_possible_bs = self.mamba_cache[0].shape[1] for request_info in requests_info: @@ -961,26 +961,26 @@ def _prepare_request_mamba_cache( f_free_index = self._get_first_free_mamba_cache_index() ## case of decoding n>1 i_exist = list(self.request2i[cur_rid].values())[0] - self.mamba_cache[0][:,f_free_index].copy_( - self.mamba_cache[0][:,i_exist] - ) - self.mamba_cache[1][:,f_free_index].copy_( - self.mamba_cache[1][:,i_exist] - ) + self.mamba_cache[0][:, f_free_index].copy_( + self.mamba_cache[0][:, i_exist]) + self.mamba_cache[1][:, f_free_index].copy_( + self.mamba_cache[1][:, i_exist]) self.request2i[cur_rid][seq_id] = f_free_index indices.append(self.request2i[cur_rid][seq_id]) ## Pad the batch incase of running batch that was not captured via CG padded_indices = indices for _ in range(batch_size - len(indices)): - occu = [i for s_ids in self.request2i.values() for i in s_ids.values()] + occu = [ + i for s_ids in self.request2i.values() for i in s_ids.values() + ] padded_indices += [[ i not in set(occu).union(padded_indices) for i in range(max_possible_bs) ].index(True)] - conv_state = self.mamba_cache[0][:,padded_indices] - ssm_state = self.mamba_cache[1][:,padded_indices] - return conv_state,ssm_state,indices + conv_state = self.mamba_cache[0][:, padded_indices] + ssm_state = self.mamba_cache[1][:, padded_indices] + return conv_state, ssm_state, indices @torch.inference_mode() def profile_run(self) -> None: @@ -1162,17 +1162,19 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model,self.is_mamba) + graph_runner = CUDAGraphRunner(self.model, 
self.is_mamba) capture_inputs = { - "input_ids" : input_tokens[:batch_size], - "positions" :input_positions[:batch_size], + "input_ids": input_tokens[:batch_size], + "positions": input_positions[:batch_size], "kv_caches": kv_caches, "attn_metadata": attn_metadata, - "memory_pool":self.graph_memory_pool, + "memory_pool": self.graph_memory_pool, } if self.is_mamba: - capture_inputs["conv_state"]=self.mamba_cache4gc[0][:, :batch_size] - capture_inputs["ssm_state"]=self.mamba_cache4gc[1][:, :batch_size] + capture_inputs["conv_state"] = self.mamba_cache4gc[ + 0][:, :batch_size] + capture_inputs["ssm_state"] = self.mamba_cache4gc[ + 1][:, :batch_size] graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1222,17 +1224,17 @@ def capture( # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). model_inputs = { - "input_ids":input_ids, - "positions":positions, - "kv_caches":kv_caches, - "attn_metadata":attn_metadata, + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, **kwargs } if self.is_mamba: model_inputs = { **model_inputs, - "conv_state":conv_state, - "ssm_state":ssm_state, + "conv_state": conv_state, + "ssm_state": ssm_state, } with _maybe_pynccl(): @@ -1275,8 +1277,8 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state:Optional[torch.Tensor] = None, - ssm_state:Optional[torch.Tensor] = None, + conv_state: Optional[torch.Tensor] = None, + ssm_state: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1293,9 +1295,8 @@ def forward( attn_metadata.decode_metadata.block_tables, non_blocking=True) if self.is_mamba: self.input_buffers["conv_state"].copy_(conv_state, - non_blocking=True) - self.input_buffers["ssm_state"].copy_(ssm_state, - non_blocking=True) + non_blocking=True) + self.input_buffers["ssm_state"].copy_(ssm_state, non_blocking=True) # Run the graph. 
self.graph.replay() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 16663f18500..990de79178a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -182,7 +182,8 @@ def _init_cache_engine(self): self.model_runner.set_block_size(self.cache_engine.block_size) is_mamba = self.model_config.hf_config.model_type == "jamba" if is_mamba: - self.model_runner.prepare_contiguous_mamba_cache(self.cache_engine.dtype) + self.model_runner.prepare_contiguous_mamba_cache( + self.cache_engine.dtype) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -268,9 +269,10 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config) - def release_mamba_cache(self,requests_id:List[str]): + def release_mamba_cache(self, requests_id: List[str]): self.model_runner.release_mamba_cache(requests_id) + def init_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, From 6daf2a2cbb9d30afaa6a3e462faffda1ab9eb02c Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 15:09:44 +0300 Subject: [PATCH 025/110] Adding to docs and more --- docs/source/models/supported_models.rst | 4 ++++ vllm/model_executor/models/jamba.py | 29 +++++++++++-------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5e5ce871f61..9ca0303873e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -79,6 +79,10 @@ Alongside each architecture, we include some popular models that use it. - Jais - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - + * - :code:`JambaForCausalLM` + - Jamba + - :code:`ai21labs/Jamba-v0.1`, etc. + - ✅︎ * - :code:`LlamaForCausalLM` - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 18827846710..092632d6ce3 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -757,31 +757,28 @@ def load_weights( # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name in params_dict: - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break else: for param_name, weight_name, expert_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name in params_dict: - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 7ee927bd90d7c17cb74a8faf609cc8f632efb0d1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 15:11:38 +0300 Subject: [PATCH 026/110] Add to readme --- README.md | 1 + vllm/model_executor/models/jamba.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8434c118833..ec0b5ed9839 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +- Jamba (`ai21labs/Jamba-v0.1`, etc.) - Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 092632d6ce3..58f4595add6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -769,9 +769,9 @@ def load_weights( param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + loaded_weight, + weight_name, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. From 87fa2994241d5b753a032f10aa2a1dcc87594ed4 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 15:54:32 +0300 Subject: [PATCH 027/110] Adding comments for prefill mamba --- vllm/model_executor/models/jamba.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 58f4595add6..8548619fc55 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -236,6 +236,9 @@ def forward( ssm_state: torch.Tensor, ): if attn_metadata.prefill_metadata is not None: + # Mamba doesn't support chunked prefill, + # We pad the hidden_states before the forward pass and + # unpad it again afterwards. 
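# A standalone sketch of the pad/unpad scheme described in the comment above,
# written against plain torch; the helper name, tensor shapes and the example
# lengths below are illustrative assumptions, not vLLM APIs.
import torch

def pad_run_unpad(flat_hidden, prompt_lens, batched_fn):
    # flat_hidden: (sum(prompt_lens), hidden_size) -- all prompts concatenated.
    batch_size, max_len = len(prompt_lens), max(prompt_lens)
    padded = torch.zeros(batch_size, max_len, flat_hidden.shape[-1],
                         dtype=flat_hidden.dtype, device=flat_hidden.device)
    offset = 0
    for i, plen in enumerate(prompt_lens):
        # Scatter each prompt into its own zero-padded row.
        padded[i, :plen].copy_(flat_hidden[offset:offset + plen])
        offset += plen
    padded = batched_fn(padded)  # e.g. a batched mamba forward pass
    offset = 0
    for i, plen in enumerate(prompt_lens):
        # Copy the valid positions back into the flat token layout.
        flat_hidden[offset:offset + plen].copy_(padded[i, :plen])
        offset += plen
    return flat_hidden

# Example: three prompts of lengths 2, 5 and 3 with an identity "model".
_ = pad_run_unpad(torch.randn(10, 8), [2, 5, 3], lambda x: x)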
max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) batch_size = len(attn_metadata.prefill_metadata.prompt_lens) padded_hidden_states = torch.zeros( From 8bca3b65b49c6db4ed5997c6cd1071b804821709 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 16 Apr 2024 15:56:11 +0300 Subject: [PATCH 028/110] Formating --- vllm/model_executor/models/jamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 8548619fc55..634d339c0c9 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -236,8 +236,8 @@ def forward( ssm_state: torch.Tensor, ): if attn_metadata.prefill_metadata is not None: - # Mamba doesn't support chunked prefill, - # We pad the hidden_states before the forward pass and + # Mamba doesn't support chunked prefill, + # We pad the hidden_states before the forward pass and # unpad it again afterwards. max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) batch_size = len(attn_metadata.prefill_metadata.prompt_lens) From b421877bcad7574a3398e477cf546f580dcb75d7 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 17 Apr 2024 15:56:57 +0300 Subject: [PATCH 029/110] Remove mamba-ssm and conv1d from the build system requirements --- pyproject.toml | 2 -- requirements-common.txt | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 342011d0d29..b870a4b8589 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,6 @@ requires = [ "setuptools >= 49.4.0", "torch == 2.2.1", "wheel", - "mamba-ssm", - "causal-conv1d" ] build-backend = "setuptools.build_meta" diff --git a/requirements-common.txt b/requirements-common.txt index 0c7f243583e..768f94e1d63 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -16,4 +16,4 @@ outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 mamba-ssm -causal-conv1d +causal-conv1d >= 1.2.0 From d9c33199bdf4d4c4e452f0da4680bdd118a26b42 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 28 Apr 2024 11:54:54 +0300 Subject: [PATCH 030/110] Remove autoconfig for jamba --- vllm/transformers_utils/configs/jamba.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/transformers_utils/configs/jamba.py b/vllm/transformers_utils/configs/jamba.py index 88d38e7d6af..d2f04c86801 100644 --- a/vllm/transformers_utils/configs/jamba.py +++ b/vllm/transformers_utils/configs/jamba.py @@ -150,6 +150,3 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) - - -AutoConfig.register('jamba', JambaConfig) From 1c0fad80404e3dc23fbbe61042e82fd9acb62bcb Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 28 Apr 2024 12:00:51 +0300 Subject: [PATCH 031/110] Move get_attention_num_layers to model_config --- vllm/config.py | 9 +++++++++ vllm/worker/cache_engine.py | 16 ++-------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bf31b03b7c6..355b1ea7e5c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -331,6 +331,15 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + num_layers = self.get_num_layers(parallel_config) + is_mamba = self.hf_config.model_type in ["jamba"] + if 
is_mamba: + attention_period = model_config.hf_config.attn_layer_period + num_layers = max(num_layers // attention_period, 1) + return num_layers + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index cf76e41013e..6dae97d1aa5 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,8 +30,7 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = CacheEngine.get_num_attention_layers( - model_config, parallel_config) + self.num_layers = model_config.get_num_attention_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size @@ -81,16 +80,6 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) - @staticmethod - def get_num_attention_layers(model_config: ModelConfig, - parallel_config: ParallelConfig): - num_layers = model_config.get_num_layers(parallel_config) - is_mamba = model_config.hf_config.model_type == "jamba" - if is_mamba: - attention_period = model_config.hf_config.attn_layer_period - num_layers = max(num_layers // attention_period, 1) - return num_layers - @staticmethod def get_cache_block_size( cache_config: CacheConfig, @@ -99,8 +88,7 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = CacheEngine.get_num_attention_layers( - model_config, parallel_config) + num_layers = model_config.get_num_attention_layers(parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) From 1033904bfa7c0a15da40c8b235a910c4f2ed84b5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 28 Apr 2024 12:28:15 +0300 Subject: [PATCH 032/110] Fix in model config --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 6696df66236..8baa544b515 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -280,7 +280,7 @@ def get_num_attention_layers(self, num_layers = self.get_num_layers(parallel_config) is_mamba = self.hf_config.model_type in ["jamba"] if is_mamba: - attention_period = model_config.hf_config.attn_layer_period + attention_period = self.hf_config.attn_layer_period num_layers = max(num_layers // attention_period, 1) return num_layers From 2b93182ac5495ae0d5b24324ccd116c51024370d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 28 Apr 2024 13:00:17 +0300 Subject: [PATCH 033/110] Formatting --- vllm/model_executor/__init__.py | 3 +- vllm/model_executor/models/jamba.py | 79 ++++++++++----------- vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/__init__.py | 3 +- vllm/transformers_utils/configs/jamba.py | 44 +++++++----- vllm/worker/cache_engine.py | 3 +- vllm/worker/model_runner.py | 57 +++++++-------- 7 files changed, 96 insertions(+), 97 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index c1c231fcb4d..7747a40de5a 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,6 +1,5 @@ -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed from vllm.model_executor.mamba_metadata import MambaCacheParams, 
RequestInfo +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 634d339c0c9..2782cdd63a3 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,49 +1,39 @@ # coding=utf-8 """Inference-only Jurassic model.""" -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from torch.nn.parameter import Parameter + from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.model_executor.mamba_metadata import MambaCacheParams - -from vllm.transformers_utils.configs.jamba import JambaConfig -from torch.nn.parameter import Parameter from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, - ParallelLMHead, - DEFAULT_VOCAB_PADDING_SIZE, -) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) +from vllm.transformers_utils.configs.jamba import JambaConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -51,10 +41,13 @@ # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ - Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
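# For reference, the selective SSM recurrence this docstring describes can be
# written per token t (diagonal A, single sequence) as
#   h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
#   y_t = <C_t, h_t> + D * x_t
# A minimal sketch with assumed shapes follows; it is not the fused
# mamba_ssm / causal_conv1d kernels imported above, only an illustration.
import torch

def selective_scan_sketch(x, dt, A, B, C, D):
    # x, dt: (seq_len, d); A: (d, n); B, C: (seq_len, n); D: (d,)
    seq_len, d = x.shape
    h = torch.zeros(d, A.shape[-1])
    ys = []
    for t in range(seq_len):
        dA = torch.exp(dt[t, :, None] * A)                      # (d, n)
        dBx = dt[t, :, None] * B[t, None, :] * x[t, :, None]    # (d, n)
        h = dA * h + dBx
        ys.append((h * C[t, None, :]).sum(-1) + D * x[t])       # (d,)
    return torch.stack(ys)                                      # (seq_len, d)

y = selective_scan_sketch(torch.randn(6, 4), torch.rand(6, 4),
                          -torch.rand(4, 3), torch.randn(6, 3),
                          torch.randn(6, 3), torch.randn(4))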
- A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) - ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, - and is why Mamba is called **selective** state spaces) + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) """ def __init__(self, config: JambaConfig, layer_idx): @@ -82,7 +75,7 @@ def __init__(self, config: JambaConfig, layer_idx): self.in_proj = MergedColumnParallelLinear(self.hidden_size, [self.intermediate_size] * 2, bias=self.use_bias) - # selective projection used to make dt, B and C input dependant + # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, @@ -572,17 +565,19 @@ def __init__( org_num_embeddings=config.vocab_size, ) - # init each model layer, decide if it's mamba/attention and has experts and pass it down + # init each model layer, decide if it's mamba/attention and + # has experts and pass it down module_list = [] for i in range(config.num_hidden_layers): - is_attn = (True if (i - self.config.attn_layer_offset) % - self.config.attn_layer_period == 0 else False) - is_expert = (True if (i - self.config.expert_layer_offset) % - self.config.expert_layer_period == 0 else False) + is_attn = ((i - self.config.attn_layer_offset) % + self.config.attn_layer_period == 0) + is_expert = ((i - self.config.expert_layer_offset) % + self.config.expert_layer_period == 0) actual_num_experts = config.num_experts if is_expert else 1 - actual_num_experts_per_tok = config.num_experts_per_tok if is_expert else 1 + actual_num_experts_per_tok = config.num_experts_per_tok \ + if is_expert else 1 if is_attn: module_list.append( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 366d3bb8ff2..87b27e08fe2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,7 +2,9 @@ from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig, JAISConfig, MPTConfig, RWConfig, JambaConfig +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + JAISConfig, JambaConfig, + MPTConfig, RWConfig) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4e4cdcb6ee1..4e8fdad727e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,9 +5,8 @@ # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig -from vllm.transformers_utils.configs.mpt import MPTConfig - from vllm.transformers_utils.configs.jamba import JambaConfig +from vllm.transformers_utils.configs.mpt import MPTConfig __all__ = [ "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/jamba.py b/vllm/transformers_utils/configs/jamba.py index d2f04c86801..440cfb2466d 100644 --- a/vllm/transformers_utils/configs/jamba.py +++ b/vllm/transformers_utils/configs/jamba.py @@ -1,14 +1,15 @@ """ Jamba model configuration""" import math + from transformers.configuration_utils import PretrainedConfig -from transformers import AutoConfig class JambaConfig(PretrainedConfig): r""" Args: vocab_size (`int`, *optional*, defaults to 65536): - Vocabulary size of the Jurassic model. Defines the number of different tokens that can be represented by the + Vocabulary size of the Jurassic model. Defines the + number of different tokens that can be represented by the `inputs_ids` passed when calling [`JurassicModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. @@ -17,26 +18,31 @@ class JambaConfig(PretrainedConfig): num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. + Number of attention heads for each attention layer + in the Transformer encoder. num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + This is the number of key_value heads that should + be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When converting + a multi-head checkpoint to a GQA checkpoint, each group key and + value head should be constructed by meanpooling all the original + heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to `8`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - # max_position_embeddings (`int`, *optional*, defaults to `4096*32`): - # The maximum sequence length that this model might ever be used with. Jurassic's sliding window attention - # allows sequence of up to 4096*32 tokens. + The non-linear activation function (function or string) + in the decoder. initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
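# The attn/expert period and offset fields documented in this config determine,
# per layer index, whether a layer is attention or mamba and whether it carries
# experts. A small sketch of that rule (mirroring the is_attn / is_expert
# checks in the modeling code; the default values below come from this config):
def layer_plan(num_hidden_layers=32, attn_layer_period=8, attn_layer_offset=4,
               expert_layer_period=2, expert_layer_offset=1):
    plan = []
    for i in range(num_hidden_layers):
        is_attn = (i - attn_layer_offset) % attn_layer_period == 0
        is_expert = (i - expert_layer_offset) % expert_layer_period == 0
        plan.append(("attention" if is_attn else "mamba",
                     "moe" if is_expert else "mlp"))
    return plan

# With these defaults, layers 4, 12, 20 and 28 are attention layers and every
# odd-indexed layer uses experts.
print(layer_plan()[:8])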
+ The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): The id of the padding token. bos_token_id (`int`, *optional*, defaults to 1): @@ -48,9 +54,11 @@ class JambaConfig(PretrainedConfig): rope_theta (`float`, *optional*, defaults to 1000000.0): The base period of the RoPE embeddings. sliding_window (`int`, *optional*): - Sliding window attention window size. If not specified, will default to `4096`. + Sliding window attention window size. If not specified, will + default to `4096`. num_experts_per_tok (`int`, *optional*, defaults to 2): - The number of experts to root per-token, can be also interpreted as the `top-p` routing + The number of experts to root per-token, can be also interpreted + as the `top-p` routing parameter num_experts (`int`, *optional*, defaults to 16): Number of experts per Sparse MLP layer. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 6dae97d1aa5..916d4fb01f9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,7 +30,8 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_attention_layers(parallel_config) + self.num_layers = model_config.get_num_attention_layers( + parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 387019d3cdf..0d9b01c5b07 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,11 +11,9 @@ get_attn_backend) from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import ( - broadcast_tensor_dict, - with_pynccl_for_all_reduce, - get_tensor_model_parallel_world_size, -) +from vllm.distributed import (broadcast_tensor_dict, + get_tensor_model_parallel_world_size, + with_pynccl_for_all_reduce) from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) from vllm.logger import init_logger @@ -147,19 +145,16 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config # cache in_wsl result + self.mamba_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] self.is_mamba = self.model_config.hf_config.model_type == "jamba" - self.mamba_cache = None - self.mamba_cache4gc = None self.request2i: Dict[str, Dict[int, int]] = {} self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) -<<<<<<< HEAD @torch.inference_mode() def prepare_contiguous_mamba_cache(self, dtype): - is_mamba = self.model_config.hf_config.model_type == "jamba" - if not is_mamba or self.mamba_cache is not None: + if not self.is_mamba or self.mamba_cache is not None: return hf_config = self.model_config.hf_config num_layers = hf_config.num_hidden_layers @@ -177,8 +172,6 @@ def prepare_contiguous_mamba_cache(self, dtype): hf_config.mamba_expand * 
hf_config.hidden_size // world_size, hf_config.mamba_d_state, ) - if self.mamba_cache is None: - self.mamba_cache = {} self.mamba_cache = (torch.empty(size=conv_state_shape, dtype=dtype, device="cuda"), @@ -192,7 +185,6 @@ def prepare_contiguous_mamba_cache(self, dtype): dtype=dtype, device="cuda")) - # Lazy initialization self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. @@ -203,6 +195,7 @@ def prepare_contiguous_mamba_cache(self, dtype): # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). self.graph_block_tables: torch.Tensor # Set after initial profiling. + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -597,7 +590,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, torch.Tensor, List[RequestInfo]]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -768,10 +761,7 @@ def prepare_input_tensors( def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.request2i: - indices = self.request2i.pop(req_id) - logger.debug( - f"Deleted { req_id } from mamba_cache with indices = {indices}" - ) + self.request2i.pop(req_id) @torch.inference_mode() def execute_model( @@ -826,7 +816,7 @@ def execute_model( if not self.is_driver_worker: return None - if self.is_mamba: + if self.is_mamba and self.mamba_cache is not None: for i, offset in enumerate(indices): self.mamba_cache[0][:, offset].copy_(conv_state[:, i]) self.mamba_cache[1][:, offset].copy_(ssm_state[:, i]) @@ -839,19 +829,24 @@ def execute_model( return output - def _get_first_free_mamba_cache_index(self): - max_possible_bs = self.mamba_cache[0].shape[1] - occupied = [ - id for seq_ids in self.request2i.values() - for id in seq_ids.values() - ] - first_free_index = [i not in occupied - for i in range(max_possible_bs)].index(True) - return first_free_index + def _get_first_free_mamba_cache_index(self) -> int: + if self.is_mamba and self.mamba_cache is not None: + max_possible_bs = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.request2i.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_bs) + ].index(True) + return first_free_index + return 0 def _prepare_request_mamba_cache(self, requests_info: List[RequestInfo], batch_size: int): indices = [] + if self.mamba_cache is None: + return max_possible_bs = self.mamba_cache[0].shape[1] for request_info in requests_info: cur_rid = request_info.request_id @@ -873,7 +868,7 @@ def _prepare_request_mamba_cache(self, requests_info: List[RequestInfo], self.mamba_cache[1][:, i_exist]) self.request2i[cur_rid][seq_id] = f_free_index indices.append(self.request2i[cur_rid][seq_id]) - ## Pad the batch incase of running batch that was not captured via CG + ## Pad the batch in case of running batch that was not captured via CG padded_indices = indices for _ in range(batch_size - len(indices)): occu = [ @@ -1205,7 +1200,7 @@ def forward( attn_metadata.decode_metadata.context_lens, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - if self.is_mamba: + if self.is_mamba and conv_state is not None and ssm_state is not None: 
self.input_buffers["conv_state"].copy_(conv_state, non_blocking=True) self.input_buffers["ssm_state"].copy_(ssm_state, non_blocking=True) @@ -1214,7 +1209,7 @@ def forward( self.graph.replay() # in-place edit of the mamba cache states as in the KV cache - if self.is_mamba: + if self.is_mamba and conv_state is not None and ssm_state is not None: ssm_state.copy_(self.input_buffers["ssm_state"]) conv_state.copy_(self.input_buffers["conv_state"]) From b2f86f8c5a26859570a7f251e98869fe2b1d14b4 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:21:28 +0300 Subject: [PATCH 034/110] Add layers_block_type_support to model config --- vllm/config.py | 54 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8baa544b515..14120edc6de 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2,7 +2,7 @@ import json import os from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union, Tuple import torch from packaging.version import Version @@ -275,14 +275,52 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size - def get_num_attention_layers(self, - parallel_config: "ParallelConfig") -> int: + def contains_seqlen_agnostic_layers(self) -> int: + return self.hf_config.model_type in ["jamba"] + + def get_layers_block_type(self, + parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - is_mamba = self.hf_config.model_type in ["jamba"] - if is_mamba: - attention_period = self.hf_config.attn_layer_period - num_layers = max(num_layers // attention_period, 1) - return num_layers + # Transformers supports layers_block_type @property + return getattr( + self.hf_config, + "layers_block_type", + ["attention"] * num_layers + ) + + def get_num_attention_layers(self, parallel_config: "ParallelConfig") -> int: + return len([t for t in self.get_layers_block_type( + parallel_config + ) if t == "attention"]) + + def get_num_seqlen_agnostic_layers( + self, + parallel_config: "ParallelConfig" + ) -> int: + return len([t for t in self.get_layers_block_type( + parallel_config + ) if t != "attention"]) + + def get_num_seqlen_agnostic_cache_shape( + self, + parallel_config + ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: + world_size = parallel_config.tensor_parallel_size + hidden_size = self.get_hidden_size() + conv_state_shape = None + temporal_state_shape = None + if self.hf_config.model_type in ["jamba"]: + conv_state_shape = ( + self.hf_config.mamba_expand * hidden_size // world_size, + self.hf_config.mamba_d_conv, + ) + temporal_state_shape = ( + self.hf_config.mamba_expand * self.hf_config.hidden_size // world_size, + self.hf_config.mamba_d_state, + ) + + return conv_state_shape, temporal_state_shape + class CacheConfig: From 7061df7c3510dd5a94d5f6eef3bdea7fccc5115d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:21:44 +0300 Subject: [PATCH 035/110] Update Jamba to support changes from main --- vllm/model_executor/models/jamba.py | 79 +++++++++++++++-------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2782cdd63a3..1a2c2d49dbf 100644 --- 
a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,8 +1,9 @@ # coding=utf-8 """Inference-only Jurassic model.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch +from transformers import JambaConfig from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -24,16 +25,15 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import SamplerOutput -from vllm.transformers_utils.configs.jamba import JambaConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -166,6 +166,7 @@ def mamba_forward(self, hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) + hidden_states = causal_conv1d_fn( hidden_states, conv_weights, @@ -246,8 +247,8 @@ def forward( offset += prompt_len cache = MambaCacheParams( True, - conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx], + conv_state=conv_state, + ssm_state=ssm_state, ) padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) @@ -259,8 +260,8 @@ def forward( offset += prompt_len else: cache = MambaCacheParams(False, - conv_state=conv_state[self.layer_idx], - ssm_state=ssm_state[self.layer_idx]) + conv_state=conv_state, + ssm_state=ssm_state) hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), cache_params=cache) hidden_states = hidden_states.squeeze(1) @@ -304,7 +305,6 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None, ) self.ws = nn.Parameter( @@ -443,8 +443,7 @@ def __init__( config: JambaConfig, actual_num_experts: int, actual_num_experts_per_tok: int, - layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -474,12 +473,12 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, config.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.attn = Attention( self.num_heads, @@ -548,7 +547,7 @@ class JambaModel(nn.Module): def __init__( self, config: JambaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -585,8 +584,7 @@ def __init__( config, actual_num_experts=actual_num_experts, actual_num_experts_per_tok=actual_num_experts_per_tok, - layer_idx=i, - linear_method=linear_method, + 
quant_config=quant_config )) else: module_list.append( @@ -616,17 +614,24 @@ def forward( for i in range(len(self.layers)): layer = self.layers[i] kv_cache = None + current_ssm_state = None + current_conv_state = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] + if isinstance(layer, JambaMambaDecoderLayer): + current_state_layer = i - (1 + (i - self.config.attn_layer_offset) // self.config.attn_layer_period) + current_ssm_state = ssm_state[current_state_layer] + current_conv_state = conv_state[current_state_layer] + hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - conv_state=conv_state, - ssm_state=ssm_state, + conv_state=current_conv_state, + ssm_state=current_ssm_state, ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -657,13 +662,16 @@ class JambaForCausalLM(nn.Module): def __init__( self, config: JambaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = JambaModel(config, linear_method, lora_config=lora_config) + self.model = JambaModel( + config, + quant_config=quant_config, + lora_config=lora_config + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -686,11 +694,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + seqlen_agnostic_cache: Tuple[torch.Tensor,torch.Tensor], ): - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, conv_state, ssm_state) + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + seqlen_agnostic_cache[0], + seqlen_agnostic_cache[1] + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -707,13 +720,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -732,13 +739,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=True - ): # erez - might need to change later to False + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue From 054faf14b0af8a4fb6e94d4633f3386f36a773d1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:21:58 +0300 Subject: [PATCH 036/110] Take Jamba config off since its now in transformers --- vllm/transformers_utils/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 87b27e08fe2..753c88f7e3f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -3,8 +3,7 @@ from transformers import AutoConfig, PretrainedConfig from vllm.transformers_utils.configs import 
(ChatGLMConfig, DbrxConfig, - JAISConfig, JambaConfig, - MPTConfig, RWConfig) + JAISConfig, MPTConfig, RWConfig) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, @@ -12,8 +11,7 @@ "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig, - "jamba": JambaConfig + "jais": JAISConfig } From fb3fc83ada7c430fef0758c347f3029da362c4c2 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:22:21 +0300 Subject: [PATCH 037/110] Take jamba config off --- vllm/transformers_utils/configs/__init__.py | 4 +- vllm/transformers_utils/configs/jamba.py | 160 -------------------- 2 files changed, 1 insertion(+), 163 deletions(-) delete mode 100644 vllm/transformers_utils/configs/jamba.py diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4e8fdad727e..78dc6207a03 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,10 +5,8 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig -from vllm.transformers_utils.configs.jamba import JambaConfig from vllm.transformers_utils.configs.mpt import MPTConfig __all__ = [ - "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig", - "JambaConfig" + "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig" ] diff --git a/vllm/transformers_utils/configs/jamba.py b/vllm/transformers_utils/configs/jamba.py deleted file mode 100644 index 440cfb2466d..00000000000 --- a/vllm/transformers_utils/configs/jamba.py +++ /dev/null @@ -1,160 +0,0 @@ -""" Jamba model configuration""" -import math - -from transformers.configuration_utils import PretrainedConfig - - -class JambaConfig(PretrainedConfig): - r""" - Args: - vocab_size (`int`, *optional*, defaults to 65536): - Vocabulary size of the Jurassic model. Defines the - number of different tokens that can be represented by the - `inputs_ids` passed when calling [`JurassicModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 14336): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer - in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should - be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi - Head Attention (MHA), if `num_key_value_heads=1 the model will use - Multi Query Attention (MQA) otherwise GQA is used. When converting - a multi-head checkpoint to a GQA checkpoint, each group key and - value head should be constructed by meanpooling all the original - heads within that group. For more details checkout - [this paper](https://arxiv.org/pdf/2305.13245.pdf). - If it is not specified, will default to `8`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) - in the decoder. 
- initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values - attentions (not used by all models). - Only relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - The id of the padding token. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the "beginning-of-sequence" token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the "end-of-sequence" token. - use_positional_embeddings (`bool`, *optional, default False) - flag indicating whether to use positional embeddings or not - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - sliding_window (`int`, *optional*): - Sliding window attention window size. If not specified, will - default to `4096`. - num_experts_per_tok (`int`, *optional*, defaults to 2): - The number of experts to root per-token, can be also interpreted - as the `top-p` routing - parameter - num_experts (`int`, *optional*, defaults to 16): - Number of experts per Sparse MLP layer. - expert_layer_period (`int`, *optional*, defaults to 2) - Once in this many layers, we will have an expert layer - expert_layer_offset(`int`, *optional*, defaults to 1) - The first layer index that contains an expert mlp layer - attn_layer_period (`int`, *optional*, defaults to 8) - Once in this many layers, we will have a vanilla attention layer - attn_layer_offset(`int`, *optional*, defaults to 4) - The first layer index that contains a vanilla attention mlp layer - """ - - model_type = "jamba" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=65536, - tie_word_embeddings=False, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - output_router_logits=False, - router_aux_loss_coef=0.001, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - sliding_window=None, - attention_dropout=0.0, - num_experts_per_tok=2, - num_experts=16, - expert_layer_offset=1, - expert_layer_period=2, - attn_layer_period=8, - attn_layer_offset=4, - use_mamba_kernels=True, - mamba_d_state=16, - mamba_d_conv=4, - mamba_expand=2, - mamba_dt_rank="auto", - mamba_conv_bias=True, - mamba_proj_bias=False, - mamba_inner_layernorms=True, - **kwargs, - ): - self.vocab_size = vocab_size - self.tie_word_embeddings = tie_word_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window - self.attention_dropout = attention_dropout - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - - self.use_cache = use_cache - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.expert_layer_period = expert_layer_period - 
self.expert_layer_offset = expert_layer_offset - self.attn_layer_period = attn_layer_period - self.attn_layer_offset = attn_layer_offset - - self.use_mamba_kernels = use_mamba_kernels - self.mamba_d_state = mamba_d_state - self.mamba_d_conv = mamba_d_conv - self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil( - self.hidden_size / - 16) if mamba_dt_rank == "auto" else mamba_dt_rank - self.mamba_conv_bias = mamba_conv_bias - self.mamba_proj_bias = mamba_proj_bias - self.mamba_inner_layernorms = mamba_inner_layernorms - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) From 6d8765d5043e0898748a2d1a66925b92c5ad283d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:22:32 +0300 Subject: [PATCH 038/110] Format --- vllm/worker/cache_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 916d4fb01f9..6dae97d1aa5 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,8 +30,7 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_attention_layers( - parallel_config) + self.num_layers = model_config.get_num_attention_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size From 10896ae6ff8e4e20ffa77a27e44fbd6eb0b79752 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:28:30 +0300 Subject: [PATCH 039/110] Refactor the model runner a little , make it more readable and chage terminology from mamba speicifc to seqlen agnostic cache --- vllm/worker/model_runner.py | 285 +++++++++++++++++++----------------- 1 file changed, 151 insertions(+), 134 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0d9b01c5b07..4ad4c38638b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -144,46 +144,37 @@ def __init__( self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config - # cache in_wsl result - self.mamba_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] - self.is_mamba = self.model_config.hf_config.model_type == "jamba" - self.request2i: Dict[str, Dict[int, int]] = {} + self.seqlen_agnostic_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] + self.seqlen_agnostic_gc_cache_buffer: Optional[Tuple[torch.Tensor, torch.Tensor]] + self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers() + self.seqlen_agnostic_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) - @torch.inference_mode() - def prepare_contiguous_mamba_cache(self, dtype): - if not self.is_mamba or self.mamba_cache is not None: + def prepare_seqlen_agnostic_cache(self, dtype): + if not self.contains_seqlen_agnostic_layers: return - hf_config = self.model_config.hf_config - num_layers = hf_config.num_hidden_layers + num_seqlen_agnostic_layers = self.model_config.get_num_seqlen_agnostic_layers(self.parallel_config) max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] - world_size = get_tensor_model_parallel_world_size() - conv_state_shape = ( - num_layers, - max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size // world_size, - hf_config.mamba_d_conv, - 
) - ssm_state_shape = ( - num_layers, - max_batch_size, - hf_config.mamba_expand * hf_config.hidden_size // world_size, - hf_config.mamba_d_state, - ) - self.mamba_cache = (torch.empty(size=conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=ssm_state_shape, - dtype=dtype, - device="cuda")) - self.mamba_cache4gc = (torch.empty(size=conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=ssm_state_shape, - dtype=dtype, - device="cuda")) + conv_state_shape, temporal_state_shape = self.model_config.get_num_seqlen_agnostic_cache_shape(self.parallel_config) + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in [ + "seqlen_agnostic_cache", + "seqlen_agnostic_gc_cache_buffer", + ]: + buffer = ( + torch.empty( + size=(num_seqlen_agnostic_layers,max_batch_size) + + conv_state_shape, dtype=dtype, + device="cuda"), + torch.empty( + size=(num_seqlen_agnostic_layers,max_batch_size) + + temporal_state_shape, dtype=dtype, + device="cuda") + ) + setattr(self,buffername, buffer) + # Lazy initialization self.model: torch.nn.Module # Set after load_model @@ -675,7 +666,6 @@ def prepare_input_tensors( seqs_id=list(req.seq_data.keys())) for req in seq_group_metadata_list ] - metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -758,10 +748,10 @@ def prepare_input_tensors( sampling_metadata, lora_requests, lora_mapping, multi_modal_input, requests_info) - def release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + def release_seqlen_agnostic_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: - if req_id in self.request2i: - self.request2i.pop(req_id) + if req_id in self.seqlen_agnostic_cache_indices_mapping: + self.seqlen_agnostic_cache_indices_mapping.pop(req_id) @torch.inference_mode() def execute_model( @@ -795,16 +785,15 @@ def execute_model( if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - if self.is_mamba: - if self.mamba_cache is None: - self.prepare_contiguous_mamba_cache(self.model_config.dtype) - conv_state, ssm_state, indices = self._prepare_request_mamba_cache( - requests_info, input_tokens.shape[0] if - attn_metadata.prefill_metadata is None else len(requests_info)) + current_seqlen_agnostic_cache = None + if self.contains_seqlen_agnostic_layers: + if getattr(self, "seqlen_agnostic_cache", None) is None: + self.prepare_seqlen_agnostic_cache(self.model_config.dtype) + batch_size = input_tokens.shape[0] if attn_metadata.prefill_metadata is None else len(requests_info) + current_seqlen_agnostic_cache, indices = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) execute_model_kwargs = { **execute_model_kwargs, - "conv_state": conv_state, - "ssm_state": ssm_state, + "seqlen_agnostic_cache": current_seqlen_agnostic_cache, } hidden_states = model_executable(**execute_model_kwargs) @@ -816,10 +805,9 @@ def execute_model( if not self.is_driver_worker: return None - if self.is_mamba and self.mamba_cache is not None: + if self.contains_seqlen_agnostic_layers: for i, offset in enumerate(indices): - self.mamba_cache[0][:, offset].copy_(conv_state[:, i]) - self.mamba_cache[1][:, offset].copy_(ssm_state[:, i]) + self._copy_seqlen_agnostic_cache(offset, i, current_seqlen_agnostic_cache) # Sample the next token. 
output = self.model.sample( @@ -829,11 +817,11 @@ def execute_model( return output - def _get_first_free_mamba_cache_index(self) -> int: - if self.is_mamba and self.mamba_cache is not None: - max_possible_bs = self.mamba_cache[0].shape[1] + def _first_free_index_in_seqlen_agnostic_cache(self) -> int: + if self.contains_seqlen_agnostic_layers and self.seqlen_agnostic_cache is not None: + max_possible_bs = self.seqlen_agnostic_cache[0].shape[1] occupied = [ - id for seq_ids in self.request2i.values() + id for seq_ids in self.seqlen_agnostic_cache_indices_mapping.values() for id in seq_ids.values() ] first_free_index = [ @@ -842,46 +830,77 @@ def _get_first_free_mamba_cache_index(self) -> int: return first_free_index return 0 - def _prepare_request_mamba_cache(self, requests_info: List[RequestInfo], - batch_size: int): - indices = [] - if self.mamba_cache is None: - return - max_possible_bs = self.mamba_cache[0].shape[1] + + def _copy_seqlen_agnostic_cache(self, index_to, index_from, from_buffer): + assert self.seqlen_agnostic_cache is not None + self.seqlen_agnostic_cache[0][:,index_to].copy_(from_buffer[0][:,index_from]) + self.seqlen_agnostic_cache[1][:,index_to].copy_(from_buffer[1][:,index_from]) + + + def _assign_seq_id_to_seqlen_agnostic_cache( + self, + cur_rid: str, + seqs_id: List[int] + ) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.seqlen_agnostic_cache_indices_mapping: + self.seqlen_agnostic_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_seqlen_agnostic_cache() + self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := self.seqlen_agnostic_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_seqlen_agnostic_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_seqlen_agnostic_cache( + index_from=index_exist, + index_to=first_free_index, + from_buffer=self.seqlen_agnostic_cache + ) + self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + + # def _find_seq_len_agnostic_pad_index(self, indices_for_current_run, max_possible_bs) -> int: + # occupied_indices = set([i for s_ids in + # self.seqlen_agnostic_cache_indices_mapping.values() + # for i in s_ids.values()]).union(indices_for_current_run) + # pad_index = [ i not in occupied_indices for i in range( + # max_possible_bs + # ) ].index(True) + # return pad_index + # + # + # + def _prepare_current_run_seqlen_agnostic_cache( + self, + requests_info: List[RequestInfo], + batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor] ,List[int]]: + indices_for_current_run = [] for request_info in requests_info: cur_rid = request_info.request_id - if cur_rid not in self.request2i: - self.request2i[cur_rid] = {} - for seq_id in request_info.seqs_id: - f_free_index = self._get_first_free_mamba_cache_index() - self.request2i[cur_rid][seq_id] = f_free_index - indices.append(f_free_index) - else: - for seq_id in request_info.seqs_id: - if seq_id not in self.request2i[cur_rid]: - f_free_index = self._get_first_free_mamba_cache_index() - ## case of decoding n>1 - i_exist = 
list(self.request2i[cur_rid].values())[0] - self.mamba_cache[0][:, f_free_index].copy_( - self.mamba_cache[0][:, i_exist]) - self.mamba_cache[1][:, f_free_index].copy_( - self.mamba_cache[1][:, i_exist]) - self.request2i[cur_rid][seq_id] = f_free_index - indices.append(self.request2i[cur_rid][seq_id]) + indices_for_current_run += self._assign_seq_id_to_seqlen_agnostic_cache( + cur_rid, + request_info.seqs_id + ) ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices - for _ in range(batch_size - len(indices)): - occu = [ - i for s_ids in self.request2i.values() for i in s_ids.values() - ] - padded_indices += [[ - i not in set(occu).union(padded_indices) - for i in range(max_possible_bs) - ].index(True)] + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_seqlen_agnostic_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) - conv_state = self.mamba_cache[0][:, padded_indices] - ssm_state = self.mamba_cache[1][:, padded_indices] - return conv_state, ssm_state, indices + conv_state = self.seqlen_agnostic_cache[0][:,padded_indices] + temporal_state = self.seqlen_agnostic_cache[1][:,padded_indices] + + return (conv_state, temporal_state), indices_for_current_run @torch.inference_mode() def profile_run(self) -> None: @@ -948,7 +967,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() - self.request2i = {} + self.seqlen_agnostic_cache_indices_mapping = {} return def remove_all_loras(self): @@ -1063,7 +1082,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model, self.is_mamba) + graph_runner = CUDAGraphRunner(self.model) capture_inputs = { "input_ids": input_tokens[:batch_size], "positions": input_positions[:batch_size], @@ -1071,11 +1090,12 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: "attn_metadata": attn_metadata, "memory_pool": self.graph_memory_pool, } - if self.is_mamba: - capture_inputs["conv_state"] = self.mamba_cache4gc[ - 0][:, :batch_size] - capture_inputs["ssm_state"] = self.mamba_cache4gc[ - 1][:, :batch_size] + if self.contains_seqlen_agnostic_layers: + assert self.seqlen_agnostic_gc_cache_buffer is not None + capture_inputs["seqlen_agnostic_cache"] = ( + self.seqlen_agnostic_gc_cache_buffer[0][:, :batch_size], + self.seqlen_agnostic_gc_cache_buffer[1][:, :batch_size], + ) graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1102,11 +1122,10 @@ def vocab_size(self) -> int: class CUDAGraphRunner: - def __init__(self, model: nn.Module, is_mamba: bool): + def __init__(self, model: nn.Module): self.model = model self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} - self.is_mamba = is_mamba self._graph: Optional[torch.cuda.CUDAGraph] = None @@ -1122,30 +1141,20 @@ def capture( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool, - conv_state: Optional[torch.Tensor] = None, - ssm_state: Optional[torch.Tensor] = None, **kwargs, ) -> None: assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). 
- model_inputs = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - **kwargs - } - if self.is_mamba: - model_inputs = { - **model_inputs, - "conv_state": conv_state, - "ssm_state": ssm_state, - } - with _maybe_pynccl(): - self.model(**model_inputs) + self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) torch.cuda.synchronize() # Capture the graph. @@ -1154,7 +1163,13 @@ def capture( self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 with _maybe_pynccl(): - hidden_states = self.model(**model_inputs) + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) torch.cuda.synchronize() # Save the input and output buffers. @@ -1165,16 +1180,8 @@ def capture( "slot_mapping": attn_metadata.slot_mapping, "context_lens": attn_metadata.decode_metadata.context_lens, "block_tables": attn_metadata.decode_metadata.block_tables, - "conv_state": conv_state, - "ssm_state": ssm_state + **kwargs, } - if self.is_mamba: - self.input_buffers = { - **self.input_buffers, - "conv_state": conv_state, - "ssm_state": ssm_state, - } - self.output_buffers = {"hidden_states": hidden_states} return @@ -1184,8 +1191,6 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state: Optional[torch.Tensor] = None, - ssm_state: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1200,18 +1205,30 @@ def forward( attn_metadata.decode_metadata.context_lens, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - if self.is_mamba and conv_state is not None and ssm_state is not None: - self.input_buffers["conv_state"].copy_(conv_state, - non_blocking=True) - self.input_buffers["ssm_state"].copy_(ssm_state, non_blocking=True) + + if "seqlen_agnostic_cache" in kwargs: + self.input_buffers["seqlen_agnostic_cache"][0].copy_( + kwargs["seqlen_agnostic_cache"][0], + non_blocking=True + ) + self.input_buffers["seqlen_agnostic_cache"][1].copy_( + kwargs["seqlen_agnostic_cache"][1], + non_blocking=True + ) # Run the graph. self.graph.replay() # in-place edit of the mamba cache states as in the KV cache - if self.is_mamba and conv_state is not None and ssm_state is not None: - ssm_state.copy_(self.input_buffers["ssm_state"]) - conv_state.copy_(self.input_buffers["conv_state"]) + if "seqlen_agnostic_cache" in kwargs: + kwargs["seqlen_agnostic_cache"][0].copy_( + self.input_buffers["seqlen_agnostic_cache"][0], + non_blocking=True + ) + kwargs["seqlen_agnostic_cache"][1].copy_( + self.input_buffers["seqlen_agnostic_cache"][1], + non_blocking=True + ) # Return the output tensor. 
return self.output_buffers["hidden_states"] From d1dc26f411482cac497d259e2fc27f403655d82d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:35:44 +0300 Subject: [PATCH 040/110] rename release mamba to release seqlen agnostic --- vllm/engine/llm_engine.py | 2 +- vllm/executor/executor_base.py | 2 +- vllm/executor/gpu_executor.py | 4 ++-- vllm/executor/ray_gpu_executor.py | 4 ++-- vllm/worker/worker.py | 10 ++++------ 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8004c18fab7..e0000df42df 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -518,7 +518,7 @@ def _process_model_outputs( ] if len(finished_seq_groups_req_ids) > 0: - self.model_executor.release_mamba_cache( + self.model_executor.release_seqlen_agnostic_cache( finished_seq_groups_req_ids) self.scheduler.free_finished_seq_groups() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 7b26b125102..f939c9973de 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -96,7 +96,7 @@ def check_health(self) -> None: raise NotImplementedError @abstractmethod - def release_mamba_cache(self, requests_id: List[str]) -> None: + def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: raise NotImplementedError def shutdown(self) -> None: diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 26241cb8118..8e847946df0 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -154,8 +154,8 @@ def check_health(self) -> None: # it's running. return - def release_mamba_cache(self, requests_id: List[str]) -> None: - self.driver_worker.release_mamba_cache(requests_id) + def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: + self.driver_worker.release_seqlen_agnostic_cache(requests_id) class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index a9bdae2cda9..d3a6997d617 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -319,8 +319,8 @@ def _check_if_any_actor_is_dead(self): raise RuntimeError("At least one Worker is dead. " f"Dead Workers: {dead_actors}. 
") - def release_mamba_cache(self, requests_id: List[str]) -> None: - self._run_workers("release_mamba_cache", + def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: + self._run_workers("release_seqlen_agnostic_cache", requests_id=requests_id, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0d4c6627e83..83fd5b35a5a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,10 +184,8 @@ def _init_cache_engine(self): self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - is_mamba = self.model_config.hf_config.model_type == "jamba" - if is_mamba: - self.model_runner.prepare_contiguous_mamba_cache( - self.cache_engine.dtype) + if self.model_config.contains_seqlen_agnostic_layers(): + self.model_runner.prepare_seqlen_agnostic_cache(self.cache_engine.dtype) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -281,8 +279,8 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config) - def release_mamba_cache(self, requests_id: List[str]): - self.model_runner.release_mamba_cache(requests_id) + def release_seqlen_agnostic_cache(self, requests_id: List[str]): + self.model_runner.release_seqlen_agnostic_cache(requests_id) def init_worker_distributed_environment( From 07c8cd26ba43e29e6e057bfd159ce3904c434f32 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:36:09 +0300 Subject: [PATCH 041/110] Move requirements of mamba to its own requirements --- requirements-common.txt | 2 -- requirements-mamba.txt | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 requirements-mamba.txt diff --git a/requirements-common.txt b/requirements-common.txt index 09480d38dd3..e9db261c6ae 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -17,5 +17,3 @@ lm-format-enforcer == 0.9.8 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 -mamba-ssm -causal-conv1d >= 1.2.0 diff --git a/requirements-mamba.txt b/requirements-mamba.txt new file mode 100644 index 00000000000..a34b0e321b9 --- /dev/null +++ b/requirements-mamba.txt @@ -0,0 +1,6 @@ +# Common dependencies +-r requirements-common.txt +-r requirements-cuda.txt + +mamba-ssm +causal-conv1d >= 1.2.0 From 5c11285fa72d594ff2d59b9d3b8925dd42fd884b Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:40:00 +0300 Subject: [PATCH 042/110] Remove mamba metadata since its mamba specific --- vllm/model_executor/mamba_metadata.py | 17 ----------------- vllm/model_executor/models/jamba.py | 10 ++++++++-- vllm/worker/model_runner.py | 7 ++++++- 3 files changed, 14 insertions(+), 20 deletions(-) delete mode 100644 vllm/model_executor/mamba_metadata.py diff --git a/vllm/model_executor/mamba_metadata.py b/vllm/model_executor/mamba_metadata.py deleted file mode 100644 index 3ee6bdf14b8..00000000000 --- a/vllm/model_executor/mamba_metadata.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - -import torch - - -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - - -@dataclass -class RequestInfo: - request_id: str = '' - seqs_id: List[int] = field(default_factory=list) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 
1a2c2d49dbf..18b5a209ca0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,6 @@ # coding=utf-8 """Inference-only Jurassic model.""" +from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -29,7 +29,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.mamba_metadata import MambaCacheParams from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -37,6 +36,13 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4ad4c38638b..bf64200fe6e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,5 @@ import contextlib +from dataclasses import dataclass import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -21,7 +22,6 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata -from vllm.model_executor.mamba_metadata import RequestInfo from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, @@ -40,6 +40,11 @@ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] +@dataclass +class RequestInfo: + request_id: str + seqs_id: List[int] + class PreparePromptMetadata(NamedTuple): input_tokens: List[int] From 2bb33609d804055fe295da7fc3a146940e4f0901 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:42:42 +0300 Subject: [PATCH 043/110] Align with master --- vllm/model_executor/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 7747a40de5a..fb98f4a6b46 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,7 +1,7 @@ -from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ - "SamplingMetadata", "set_random_seed", "MambaCacheParams", "RequestInfo" + "SamplingMetadata", + "set_random_seed", ] From a235c446fa0b02e1e7b5bc1321e733fbedf613ea Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 5 May 2024 18:46:17 +0300 Subject: [PATCH 044/110] Change comment --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bf64200fe6e..5fbe74a8306 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ 
-1224,7 +1224,7 @@ def forward( # Run the graph. self.graph.replay() - # in-place edit of the mamba cache states as in the KV cache + # in-place edit of the seqlen agnostic cache states as in the KV cache if "seqlen_agnostic_cache" in kwargs: kwargs["seqlen_agnostic_cache"][0].copy_( self.input_buffers["seqlen_agnostic_cache"][0], From af7a4ac2d94a15b43c3fb372c158b707aec22068 Mon Sep 17 00:00:00 2001 From: Tomer Asida Date: Mon, 6 May 2024 16:54:22 +0300 Subject: [PATCH 045/110] (1) implement contains_seqlen_agnostic_layers with use of self.get_num_seqlen_agnostic_layers (2) renaming + remove commented code --- vllm/config.py | 6 +++--- vllm/worker/model_runner.py | 16 ++-------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 14120edc6de..9c1ac0ff4fb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -275,8 +275,8 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size - def contains_seqlen_agnostic_layers(self) -> int: - return self.hf_config.model_type in ["jamba"] + def contains_seqlen_agnostic_layers(self, parallel_config: "ParallelConfig") -> bool: + return self.get_num_seqlen_agnostic_layers(parallel_config) > 0 def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: @@ -301,7 +301,7 @@ def get_num_seqlen_agnostic_layers( parallel_config ) if t != "attention"]) - def get_num_seqlen_agnostic_cache_shape( + def get_seqlen_agnostic_cache_shape( self, parallel_config ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5fbe74a8306..1debe002068 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -151,7 +151,7 @@ def __init__( self.vision_language_config = vision_language_config self.seqlen_agnostic_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] self.seqlen_agnostic_gc_cache_buffer: Optional[Tuple[torch.Tensor, torch.Tensor]] - self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers() + self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers(parallel_config) self.seqlen_agnostic_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.attn_backend = get_attn_backend( @@ -162,7 +162,7 @@ def prepare_seqlen_agnostic_cache(self, dtype): return num_seqlen_agnostic_layers = self.model_config.get_num_seqlen_agnostic_layers(self.parallel_config) max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] - conv_state_shape, temporal_state_shape = self.model_config.get_num_seqlen_agnostic_cache_shape(self.parallel_config) + conv_state_shape, temporal_state_shape = self.model_config.get_seqlen_agnostic_cache_shape(self.parallel_config) assert conv_state_shape is not None and temporal_state_shape is not None for buffername in [ "seqlen_agnostic_cache", @@ -871,18 +871,6 @@ def _assign_seq_id_to_seqlen_agnostic_cache( indices_for_current_run.append(index_for_current_run) return indices_for_current_run - - # def _find_seq_len_agnostic_pad_index(self, indices_for_current_run, max_possible_bs) -> int: - # occupied_indices = set([i for s_ids in - # self.seqlen_agnostic_cache_indices_mapping.values() - # for i in s_ids.values()]).union(indices_for_current_run) - # pad_index = [ i not in occupied_indices for i in range( - # max_possible_bs - # ) ].index(True) - # return pad_index - # - # - # def 
_prepare_current_run_seqlen_agnostic_cache( self, requests_info: List[RequestInfo], From 988718e99987bb95958cb0a79646800aab5c3d7c Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Mon, 6 May 2024 17:01:43 +0300 Subject: [PATCH 046/110] Jamba official hf (#14) * remove JambaConfig and use official one from transformers * changes in Jamba modeling file to align with official HF format --- vllm/model_executor/models/jamba.py | 213 ++++++++------------ vllm/transformers_utils/configs/__init__.py | 2 + 2 files changed, 89 insertions(+), 126 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 18b5a209ca0..6133b00ac68 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,15 +4,14 @@ from typing import Iterable, List, Optional, Tuple import torch -from transformers import JambaConfig -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn -from torch.nn.parameter import Parameter +from vllm.model_executor.layers.activation import SiluAndMul from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention + +from transformers import JambaConfig +from torch.nn.parameter import Parameter from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -33,6 +32,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import SamplerOutput +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -43,7 +45,6 @@ class MambaCacheParams: ssm_state: torch.Tensor = torch.Tensor() - # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -124,28 +125,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): input_is_parallel=True, ) self.activation = config.hidden_act - self.apply_inner_layernorms = config.mamba_inner_layernorms - - if self.apply_inner_layernorms: - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.rms_norm_eps) - self.B_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) - self.C_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) - else: - self.dt_layernorm = None - self.B_layernorm = None - self.C_layernorm = None - - def _apply_layernorms(self, dt, B, C): - if self.dt_layernorm is not None: - dt = self.dt_layernorm.forward(dt.contiguous()) - if self.B_layernorm is not None: - B = self.B_layernorm.forward(B.contiguous()) - if self.C_layernorm is not None: - C = self.C_layernorm.forward(C.contiguous()) - return dt, B, C + + self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) def mamba_forward(self, hidden_states: torch.Tensor, @@ -189,7 +172,9 @@ def mamba_forward(self, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1, ) - time_step, B, C = 
self._apply_layernorms(time_step, B, C) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -275,6 +260,36 @@ def forward( return hidden_states +class JambaMLP(nn.Module): + def __init__( + self, + config: JambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + class JambaMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -285,33 +300,27 @@ class JambaMoE(nn.Module): """ def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, + self, + config: JambaConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size + self.num_total_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // self.tp_size if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - if self.num_total_experts > 1: - # init expert router iff this layer has multiple experts - self.router = ReplicatedLinear( - self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - ) + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype) self.ws = nn.Parameter( torch.empty( @@ -366,14 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (batch * sequence_length, n_experts) - if self.num_total_experts > 1: - router_logits, _ = self.router(hidden_states) - else: - router_logits = torch.ones( - [hidden_states.shape[0], 1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) + router_logits, _ = self.router(hidden_states) final_hidden_states = fused_moe( hidden_states, @@ -394,28 +396,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class JambaMambaDecoderLayer(nn.Module): - def __init__( - self, - config: JambaConfig, - actual_num_experts: int, - actual_num_experts_per_tok: int, - layer_idx: int, + self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None ) -> None: super().__init__() self.layer_idx = layer_idx 
self.config = config self.mamba = JambaMambaMixer(config, layer_idx) - self.moe = JambaMoE( - num_experts=actual_num_experts, - top_k=actual_num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -436,20 +429,15 @@ def forward( hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, ssm_state) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm( - hidden_states, residual) - hidden_states = self.moe(hidden_states) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) return hidden_states, residual class JambaAttentionDecoderLayer(nn.Module): def __init__( - self, - config: JambaConfig, - actual_num_experts: int, - actual_num_experts_per_tok: int, - quant_config: Optional[QuantizationConfig] = None, + self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -494,16 +482,11 @@ def __init__( sliding_window=self.sliding_window, ) - self.moe = JambaMoE( - num_experts=actual_num_experts, - top_k=actual_num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def self_attention( self, @@ -542,12 +525,14 @@ def forward( attn_metadata=attn_metadata, ) # Fully Connected - hidden_states, residual = self.pre_moe_layernorm( - hidden_states, residual) - hidden_states = self.moe(hidden_states) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) return hidden_states, residual +ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer} + + class JambaModel(nn.Module): def __init__( @@ -570,40 +555,12 @@ def __init__( org_num_embeddings=config.vocab_size, ) - # init each model layer, decide if it's mamba/attention and - # has experts and pass it down - - module_list = [] + decoder_layers = [] for i in range(config.num_hidden_layers): - is_attn = ((i - self.config.attn_layer_offset) % - self.config.attn_layer_period == 0) - is_expert = ((i - self.config.expert_layer_offset) % - self.config.expert_layer_period == 0) - - actual_num_experts = config.num_experts if is_expert else 1 - actual_num_experts_per_tok = config.num_experts_per_tok \ - if is_expert else 1 - - if is_attn: - module_list.append( - JambaAttentionDecoderLayer( - config, - actual_num_experts=actual_num_experts, - 
actual_num_experts_per_tok=actual_num_experts_per_tok, - quant_config=quant_config - )) - else: - module_list.append( - JambaMambaDecoderLayer( - config, - actual_num_experts=actual_num_experts, - actual_num_experts_per_tok=actual_num_experts_per_tok, - layer_idx=i, - )) - - self.layers = nn.ModuleList(module_list) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] + decoder_layers.append(layer_class(config, layer_idx=i, quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -732,6 +689,8 @@ def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] expert_params_mapping = [ @@ -758,6 +717,8 @@ def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if 'experts' in name: + continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 78dc6207a03..3bccc425cc8 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,8 @@ from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.jamba import JambaConfig + __all__ = [ "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig" ] From 49ce3dfb41b981df98674f2512143cd5c7f7cff9 Mon Sep 17 00:00:00 2001 From: Tomer Asida Date: Tue, 7 May 2024 11:39:36 +0300 Subject: [PATCH 047/110] fixes missed in merge with official Jamba HF format --- vllm/transformers_utils/configs/__init__.py | 2 -- vllm/worker/worker.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 3bccc425cc8..78dc6207a03 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,8 +7,6 @@ from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.mpt import MPTConfig -from vllm.transformers_utils.configs.jamba import JambaConfig - __all__ = [ "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig" ] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 83fd5b35a5a..457dca1b8a4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,7 +184,7 @@ def _init_cache_engine(self): self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - if self.model_config.contains_seqlen_agnostic_layers(): + if self.model_config.contains_seqlen_agnostic_layers(self.parallel_config): self.model_runner.prepare_seqlen_agnostic_cache(self.cache_engine.dtype) def _warm_up_model(self) -> None: From 7add09ac9dbb6e700a917377ff80874df6d902ef Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 17 May 2024 01:00:58 +0000 Subject: [PATCH 048/110] fix merge error --- vllm/model_executor/models/jamba.py | 24 ++++++++++++------------ 1 file changed, 12 
insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 7a640b240f0..61913d0756b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -224,18 +224,18 @@ def forward( # Mamba doesn't support chunked prefill, # We pad the hidden_states before the forward pass and # unpad it again afterwards. - max_seq_len = max(attn_metadata.prefill_metadata.prompt_lens) - batch_size = len(attn_metadata.prefill_metadata.prompt_lens) + max_seq_len = max(attn_metadata.prefill_metadata.seq_lens) + batch_size = len(attn_metadata.prefill_metadata.seq_lens) padded_hidden_states = torch.zeros( (batch_size, max_seq_len, hidden_states.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device) offset = 0 - for i, prompt_len in enumerate( - attn_metadata.prefill_metadata.prompt_lens): - padded_hidden_states[i, :prompt_len].copy_( - hidden_states[offset:offset + prompt_len]) - offset += prompt_len + for i, seq_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + padded_hidden_states[i, :seq_len].copy_( + hidden_states[offset:offset + seq_len]) + offset += seq_len cache = MambaCacheParams( True, conv_state=conv_state, @@ -244,11 +244,11 @@ def forward( padded_hidden_states = self.mamba_forward(padded_hidden_states, cache_params=cache) offset = 0 - for i, prompt_len in enumerate( - attn_metadata.prefill_metadata.prompt_lens): - hidden_states[offset:offset + prompt_len].copy_( - padded_hidden_states[i, :prompt_len]) - offset += prompt_len + for i, seq_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + hidden_states[offset:offset + seq_len].copy_( + padded_hidden_states[i, :seq_len]) + offset += seq_len else: cache = MambaCacheParams(False, conv_state=conv_state, From 14fbab553a1db659311b2887c8ef07ad586a5950 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 20 May 2024 16:03:49 +0300 Subject: [PATCH 049/110] Fix bug where seqlen agnostic cache wasn't copied for non driver workers --- vllm/worker/model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ef825e56d47..c03080e72d5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -869,14 +869,14 @@ def execute_model( # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return None - if self.contains_seqlen_agnostic_layers: for i, offset in enumerate(indices): self._copy_seqlen_agnostic_cache(offset, i, current_seqlen_agnostic_cache) + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + # Sample the next token. 
output = self.model.sample( logits=logits, From e3dec15f5a72dab40d1ffb1781a21ed0084b34b9 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 25 Jun 2024 17:01:43 +0300 Subject: [PATCH 050/110] WIP - encapsulate Jamba cache managemnt inside the modeling file --- vllm/config.py | 20 +-- vllm/model_executor/models/jamba.py | 185 +++++++++++++++++++++++++++- vllm/worker/model_runner.py | 144 ++-------------------- vllm/worker/worker.py | 2 - 4 files changed, 194 insertions(+), 157 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3581da17013..9f1221c9c5f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -348,25 +348,7 @@ def get_num_seqlen_agnostic_layers( parallel_config ) if t != "attention"]) - def get_seqlen_agnostic_cache_shape( - self, - parallel_config - ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: - world_size = parallel_config.tensor_parallel_size - hidden_size = self.get_hidden_size() - conv_state_shape = None - temporal_state_shape = None - if self.hf_config.model_type in ["jamba"]: - conv_state_shape = ( - self.hf_config.mamba_expand * hidden_size // world_size, - self.hf_config.mamba_d_conv, - ) - temporal_state_shape = ( - self.hf_config.mamba_expand * self.hf_config.hidden_size // world_size, - self.hf_config.mamba_d_state, - ) - - return conv_state_shape, temporal_state_shape + diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 61913d0756b..7bf24c0fcc2 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -35,6 +35,7 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from vllm.worker.model_runner import RequestInfo KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -657,6 +658,9 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + self._capture = False + self.current_indices = [] + self.seqlen_agnostic_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -667,18 +671,193 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, - seqlen_agnostic_cache: Tuple[torch.Tensor,torch.Tensor], + **kwargs ): + self.prepare_seqlen_agnostic_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + requests_info = kwargs["requests_info"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(requests_info) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_seqlen_agnostic_cache( + requests_info, + batch_size + ) + else: + current_seqlen_agnostic_cache, indices = kwargs["seqlen_agnostic_capture_inputs"],[] + self.current_indices = indices + hidden_states = self.model( input_ids, positions, kv_caches, attn_metadata, - seqlen_agnostic_cache[0], - seqlen_agnostic_cache[1] + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1] ) + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_seqlen_agnostic_cache_by_indices(self.current_indices, current_seqlen_agnostic_cache) + return hidden_states + + def _copy_seqlen_agnostic_cache_by_indices(self, indices, current_seqlen_agnostic_cache): + for i, offset in enumerate(indices): + self._copy_seqlen_agnostic_cache(offset, i, current_seqlen_agnostic_cache) + + + def 
_copy_seqlen_agnostic_cache(self, index_to, index_from, from_buffer): + assert self.seqlen_agnostic_cache is not None + for i in [0,1]: + self.seqlen_agnostic_cache[i][:,index_to].copy_(from_buffer[i][:,index_from],non_blocking=True) + + + def _assign_seq_id_to_seqlen_agnostic_cache( + self, + cur_rid: str, + seqs_id: List[int] + ) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.seqlen_agnostic_cache_indices_mapping: + self.seqlen_agnostic_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_seqlen_agnostic_cache() + self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := self.seqlen_agnostic_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_seqlen_agnostic_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_seqlen_agnostic_cache( + index_from=index_exist, + index_to=first_free_index, + from_buffer=self.seqlen_agnostic_cache + ) + self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + + def _prepare_current_run_seqlen_agnostic_cache( + self, + requests_info: List[RequestInfo], + batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor] ,List[int]]: + indices_for_current_run = [] + for request_info in requests_info: + cur_rid = request_info.request_id + indices_for_current_run += self._assign_seq_id_to_seqlen_agnostic_cache( + cur_rid, + request_info.seqs_id + ) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_seqlen_agnostic_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.seqlen_agnostic_cache[0][:,padded_indices] + temporal_state = self.seqlen_agnostic_cache[1][:,padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + + def copy_inputs_before_cuda_grpahs(self, input_buffers, **kwargs): + requests_info = kwargs["requests_info"] + batch_size = len(requests_info) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) + self.current_indices = indices + + for i in [0,1]: + input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( + current_seqlen_agnostic_cache[i], + non_blocking=True + ) + + + def copy_outputs_after_cuda_grpahs(self, input_buffers, **kwargs): + self._copy_seqlen_agnostic_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"] + ) + + def capture_inputs(self,batch_size): + return ( + self.seqlen_agnostic_gc_cache_buffer[0][:, :batch_size], + self.seqlen_agnostic_gc_cache_buffer[1][:, :batch_size], + ) + + + def _first_free_index_in_seqlen_agnostic_cache(self) -> int: + if self.seqlen_agnostic_cache is not None: + max_possible_bs = self.seqlen_agnostic_cache[0].shape[1] + occupied = [ + id for seq_ids in self.seqlen_agnostic_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_bs) + ].index(True) + return first_free_index 
+ return 0 + + def get_seqlen_agnostic_cache_shape( + self, + ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv, + ) + temporal_state_shape = ( + self.config.mamba_expand * self.config.hidden_size // world_size, + self.config.mamba_d_state, + ) + + return conv_state_shape, temporal_state_shape + + + def prepare_seqlen_agnostic_cache(self): + if getattr(self, "seqlen_agnostic_cache", None) is not None: + return + # dtype = torch.get_default_dtype() + dtype = self.lm_head.weight.dtype + layers_type = self.config.layers_block_type + mamba_layers = sum([layer_type == "mamba" for layer_type in layers_type]) + num_seqlen_agnostic_layers = mamba_layers + max_batch_size = 256 + conv_state_shape, temporal_state_shape = self.get_seqlen_agnostic_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in [ + "seqlen_agnostic_cache", + "seqlen_agnostic_gc_cache_buffer", + ]: + buffer = ( + torch.empty( + size=(num_seqlen_agnostic_layers,max_batch_size) + + conv_state_shape, dtype=dtype, + device="cuda"), + torch.empty( + size=(num_seqlen_agnostic_layers,max_batch_size) + + temporal_state_shape, dtype=dtype, + device="cuda") + ) + setattr(self,buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head.weight, hidden_states, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c03080e72d5..60e9edf575e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -142,7 +142,6 @@ def __init__( self.seqlen_agnostic_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] self.seqlen_agnostic_gc_cache_buffer: Optional[Tuple[torch.Tensor, torch.Tensor]] self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers(parallel_config) - self.seqlen_agnostic_cache_indices_mapping: Dict[str, Dict[int, int]] = {} # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in @@ -170,28 +169,7 @@ def __init__( # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - def prepare_seqlen_agnostic_cache(self, dtype): - if not self.contains_seqlen_agnostic_layers: - return - num_seqlen_agnostic_layers = self.model_config.get_num_seqlen_agnostic_layers(self.parallel_config) - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] - conv_state_shape, temporal_state_shape = self.model_config.get_seqlen_agnostic_cache_shape(self.parallel_config) - assert conv_state_shape is not None and temporal_state_shape is not None - for buffername in [ - "seqlen_agnostic_cache", - "seqlen_agnostic_gc_cache_buffer", - ]: - buffer = ( - torch.empty( - size=(num_seqlen_agnostic_layers,max_batch_size) - + conv_state_shape, dtype=dtype, - device="cuda"), - torch.empty( - size=(num_seqlen_agnostic_layers,max_batch_size) + - temporal_state_shape, dtype=dtype, - device="cuda") - ) - setattr(self,buffername, buffer) + def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -849,30 +827,16 @@ def execute_model( "positions": input_positions, "kv_caches": kv_caches, "attn_metadata": attn_metadata, + "requests_info": requests_info } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - current_seqlen_agnostic_cache = None - if self.contains_seqlen_agnostic_layers: - if getattr(self, "seqlen_agnostic_cache", None) is None: - self.prepare_seqlen_agnostic_cache(self.model_config.dtype) - batch_size = input_tokens.shape[0] if attn_metadata.prefill_metadata is None else len(requests_info) - current_seqlen_agnostic_cache, indices = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) - execute_model_kwargs = { - **execute_model_kwargs, - "seqlen_agnostic_cache": current_seqlen_agnostic_cache, - } - hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) - if self.contains_seqlen_agnostic_layers: - for i, offset in enumerate(indices): - self._copy_seqlen_agnostic_cache(offset, i, current_seqlen_agnostic_cache) - # Only perform sampling in the driver worker. 
if not self.is_driver_worker: return None @@ -885,79 +849,6 @@ def execute_model( return output - def _first_free_index_in_seqlen_agnostic_cache(self) -> int: - if self.contains_seqlen_agnostic_layers and self.seqlen_agnostic_cache is not None: - max_possible_bs = self.seqlen_agnostic_cache[0].shape[1] - occupied = [ - id for seq_ids in self.seqlen_agnostic_cache_indices_mapping.values() - for id in seq_ids.values() - ] - first_free_index = [ - i not in occupied for i in range(max_possible_bs) - ].index(True) - return first_free_index - return 0 - - - def _copy_seqlen_agnostic_cache(self, index_to, index_from, from_buffer): - assert self.seqlen_agnostic_cache is not None - self.seqlen_agnostic_cache[0][:,index_to].copy_(from_buffer[0][:,index_from]) - self.seqlen_agnostic_cache[1][:,index_to].copy_(from_buffer[1][:,index_from]) - - - def _assign_seq_id_to_seqlen_agnostic_cache( - self, - cur_rid: str, - seqs_id: List[int] - ) -> List[int]: - indices_for_current_run = [] - for seq_id in seqs_id: - if cur_rid not in self.seqlen_agnostic_cache_indices_mapping: - self.seqlen_agnostic_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_seqlen_agnostic_cache() - self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index - index_for_current_run = first_free_index - ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := self.seqlen_agnostic_cache_indices_mapping[cur_rid]): - first_free_index = self._first_free_index_in_seqlen_agnostic_cache() - index_exist = list(seq_ids2indices.values())[0] - self._copy_seqlen_agnostic_cache( - index_from=index_exist, - index_to=first_free_index, - from_buffer=self.seqlen_agnostic_cache - ) - self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index - index_for_current_run = first_free_index - else: - index_for_current_run = self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] - - indices_for_current_run.append(index_for_current_run) - return indices_for_current_run - - def _prepare_current_run_seqlen_agnostic_cache( - self, - requests_info: List[RequestInfo], - batch_size: int - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor] ,List[int]]: - indices_for_current_run = [] - for request_info in requests_info: - cur_rid = request_info.request_id - indices_for_current_run += self._assign_seq_id_to_seqlen_agnostic_cache( - cur_rid, - request_info.seqs_id - ) - ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_seqlen_agnostic_cache() - - for _ in range(batch_size - len(indices_for_current_run)): - padded_indices.append(pad_index) - - conv_state = self.seqlen_agnostic_cache[0][:,padded_indices] - temporal_state = self.seqlen_agnostic_cache[1][:,padded_indices] - - return (conv_state, temporal_state), indices_for_current_run - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. 
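# ---------------------------------------------------------------------------
# Editor's note: the block below is an illustrative, self-contained sketch of
# the bookkeeping that this patch moves from the model runner into jamba.py.
# It is NOT part of the patch. The idea: every (request_id, seq_id) pair gets
# a free column ("slot") in a fixed-size Mamba state buffer, the slots for the
# current batch are looked up before the forward pass, and they are released
# when the request finishes. Names are simplified stand-ins for the real
# methods (_first_free_index_in_seqlen_agnostic_cache and friends).
# ---------------------------------------------------------------------------
from typing import Dict, List


class CacheSlotAllocator:
    """Minimal model of the request/sequence -> cache-index mapping."""

    def __init__(self, max_batch_size: int) -> None:
        self.max_batch_size = max_batch_size
        # request_id -> {seq_id -> slot index in the state buffers}
        self.mapping: Dict[str, Dict[int, int]] = {}

    def _first_free_index(self) -> int:
        occupied = {i for seqs in self.mapping.values() for i in seqs.values()}
        for i in range(self.max_batch_size):
            if i not in occupied:
                return i
        raise RuntimeError("no free Mamba cache slot")

    def assign(self, request_id: str, seq_ids: List[int]) -> List[int]:
        # Returns the slot indices to gather for the current forward pass.
        per_request = self.mapping.setdefault(request_id, {})
        slots: List[int] = []
        for seq_id in seq_ids:
            if seq_id not in per_request:
                per_request[seq_id] = self._first_free_index()
            slots.append(per_request[seq_id])
        return slots

    def release(self, finished_request_ids: List[str]) -> None:
        for request_id in finished_request_ids:
            self.mapping.pop(request_id, None)


# Example: two requests share the buffer; finishing one frees its slot.
allocator = CacheSlotAllocator(max_batch_size=4)
assert allocator.assign("req-a", [0]) == [0]
assert allocator.assign("req-b", [0, 1]) == [1, 2]
allocator.release(["req-a"])
assert allocator.assign("req-c", [0]) == [0]
# ----------------------------- end editor's note ----------------------------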
@@ -1136,11 +1027,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: "memory_pool": self.graph_memory_pool, } if self.contains_seqlen_agnostic_layers: - assert self.seqlen_agnostic_gc_cache_buffer is not None - capture_inputs["seqlen_agnostic_cache"] = ( - self.seqlen_agnostic_gc_cache_buffer[0][:, :batch_size], - self.seqlen_agnostic_gc_cache_buffer[1][:, :batch_size], - ) + capture_inputs["seqlen_agnostic_capture_inputs"] = self.model.capture_inputs(batch_size) graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1251,28 +1138,19 @@ def forward( self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - if "seqlen_agnostic_cache" in kwargs: - self.input_buffers["seqlen_agnostic_cache"][0].copy_( - kwargs["seqlen_agnostic_cache"][0], - non_blocking=True - ) - self.input_buffers["seqlen_agnostic_cache"][1].copy_( - kwargs["seqlen_agnostic_cache"][1], - non_blocking=True + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_inputs_before_cuda_grpahs( + self.input_buffers, + **kwargs ) # Run the graph. self.graph.replay() - # in-place edit of the seqlen agnostic cache states as in the KV cache - if "seqlen_agnostic_cache" in kwargs: - kwargs["seqlen_agnostic_cache"][0].copy_( - self.input_buffers["seqlen_agnostic_cache"][0], - non_blocking=True - ) - kwargs["seqlen_agnostic_cache"][1].copy_( - self.input_buffers["seqlen_agnostic_cache"][1], - non_blocking=True + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_outputs_after_cuda_grpahs( + self.input_buffers, + **kwargs ) # Return the output tensor. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7e20d65538b..7bf0194e7b5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -186,8 +186,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - if self.model_config.contains_seqlen_agnostic_layers(self.parallel_config): - self.model_runner.prepare_seqlen_agnostic_cache(self.cache_engine.dtype) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 92778c42920a1227636a8378f9d309e1fde9512c Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 25 Jun 2024 18:04:26 +0300 Subject: [PATCH 051/110] Cleanup --- vllm/config.py | 1 - vllm/transformers_utils/config.py | 2 +- vllm/transformers_utils/configs/__init__.py | 6 +++++- vllm/worker/model_runner.py | 12 ++---------- 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9f1221c9c5f..fc862a11431 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -348,7 +348,6 @@ def get_num_seqlen_agnostic_layers( parallel_config ) if t != "attention"]) - diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 753c88f7e3f..1756c91a612 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -11,7 +11,7 @@ "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig + "jais": JAISConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 78dc6207a03..0e486928824 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -8,5 +8,9 @@ 
from vllm.transformers_utils.configs.mpt import MPTConfig __all__ = [ - "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig" + "ChatGLMConfig", + "DbrxConfig", + "MPTConfig", + "RWConfig", + "JAISConfig", ] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 60e9edf575e..b1a48ca1125 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -139,8 +139,6 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.seqlen_agnostic_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] - self.seqlen_agnostic_gc_cache_buffer: Optional[Tuple[torch.Tensor, torch.Tensor]] self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers(parallel_config) # When using CUDA graph, the input block tables must be padded to @@ -820,18 +818,16 @@ def execute_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - - indices = [] execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, "kv_caches": kv_caches, "attn_metadata": attn_metadata, - "requests_info": requests_info } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - + if self.contains_seqlen_agnostic_layers: + execute_model_kwargs.update({"requests_info": requests_info}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. @@ -915,7 +911,6 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() - self.seqlen_agnostic_cache_indices_mapping = {} return def remove_all_loras(self): @@ -1137,7 +1132,6 @@ def forward( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_grpahs( self.input_buffers, @@ -1146,13 +1140,11 @@ def forward( # Run the graph. self.graph.replay() - if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_outputs_after_cuda_grpahs( self.input_buffers, **kwargs ) - # Return the output tensor. 
return self.output_buffers["hidden_states"] From db364277735f3d627d9982c8e3411c3f0e38f080 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 10:08:16 +0300 Subject: [PATCH 052/110] Typos and cleanup --- vllm/model_executor/input_metadata.py | 1 - vllm/model_executor/models/jamba.py | 2 +- vllm/worker/model_runner.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) delete mode 100644 vllm/model_executor/input_metadata.py diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py deleted file mode 100644 index 8b137891791..00000000000 --- a/vllm/model_executor/input_metadata.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 7bf24c0fcc2..193ea5a5463 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -787,7 +787,7 @@ def copy_inputs_before_cuda_grpahs(self, input_buffers, **kwargs): ) - def copy_outputs_after_cuda_grpahs(self, input_buffers, **kwargs): + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): self._copy_seqlen_agnostic_cache_by_indices( self.current_indices, input_buffers["seqlen_agnostic_capture_inputs"] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b1a48ca1125..5d63686f92c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1141,7 +1141,7 @@ def forward( # Run the graph. self.graph.replay() if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_outputs_after_cuda_grpahs( + self.model.copy_outputs_after_cuda_graphs( self.input_buffers, **kwargs ) From 7f6edfc9c82eba55c3cd7d4da84a2667f380f3b5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 10:13:01 +0300 Subject: [PATCH 053/110] Another typo --- vllm/model_executor/models/jamba.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 193ea5a5463..955c8be8d95 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -771,7 +771,7 @@ def _prepare_current_run_seqlen_agnostic_cache( return (conv_state, temporal_state), indices_for_current_run - def copy_inputs_before_cuda_grpahs(self, input_buffers, **kwargs): + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): requests_info = kwargs["requests_info"] batch_size = len(requests_info) ( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5d63686f92c..35e73e7db07 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1133,7 +1133,7 @@ def forward( self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_grpahs( + self.model.copy_inputs_before_cuda_graphs( self.input_buffers, **kwargs ) From ee5f0582fe075fcd1e6f01bf6f0087a84eea8455 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 11:28:12 +0300 Subject: [PATCH 054/110] Keep the finished requests ids after in the scheduler after each step and send them to the workers on every step --- vllm/core/scheduler.py | 13 ++++++++++++- vllm/engine/async_llm_engine.py | 1 + vllm/engine/llm_engine.py | 9 +-------- vllm/executor/executor_base.py | 4 ---- vllm/executor/gpu_executor.py | 4 ---- vllm/model_executor/models/jamba.py | 28 ++++++++++++++++++++------- vllm/sequence.py | 3 +++ 
vllm/worker/embedding_model_runner.py | 1 + vllm/worker/model_runner.py | 14 +++++++------- vllm/worker/worker.py | 13 ++++++------- 10 files changed, 52 insertions(+), 38 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fbde27f9982..ec0bef49381 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -298,6 +298,9 @@ def __init__( # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished in after the last step iter. + self.finished: List[SequenceGroup] = list() + # Time at previous scheduling step self.prev_time = 0.0 # Did we schedule a prompt at previous step? @@ -369,6 +372,12 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) + def get_last_step_finished_seq_groups(self) -> List[str]: + """Returns list of the finished request ids.""" + finisehd_req_ids = [seq_group.request_id for seq_group in self.finished] + self.finished = [] + return finisehd_req_ids + def _schedule_running( self, running_queue: deque, @@ -1012,8 +1021,10 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: + self.finished += [seq_group for seq_group in self.running + if seq_group.is_finished()] self.running = deque(seq_group for seq_group in self.running - if not seq_group.is_finished()) + if seq_group not in self.finished) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a31f10b7748..4ec83fd9384 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -221,6 +221,7 @@ async def step_async( blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, + finished_request_ids=self.scheduler.get_last_step_finished_seq_groups() ) output = await self.model_executor.execute_model_async( execute_model_req) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 16ce850e4e4..b176394ec95 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -597,14 +597,6 @@ def _process_model_outputs( self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. - finished_seq_groups_req_ids = [ - seq_group.request_id for seq_group in self.scheduler.running - if seq_group.is_finished() - ] - - if len(finished_seq_groups_req_ids) > 0: - self.model_executor.release_seqlen_agnostic_cache( - finished_seq_groups_req_ids) self.scheduler.free_finished_seq_groups() # Create the outputs. 
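# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, NOT part of the patch. This commit
# amounts to a "collect now, flush on the next step" pattern: request ids of
# finished sequence groups are buffered when the scheduler frees them, then
# drained exactly once while the next ExecuteModelRequest is built, so every
# worker (not only the driver) learns which Mamba cache slots to release.
# Names below are simplified stand-ins for the real scheduler methods.
# ---------------------------------------------------------------------------
from typing import List


class FinishedRequestBuffer:
    def __init__(self) -> None:
        self._pending: List[str] = []

    def record(self, finished_request_ids: List[str]) -> None:
        # Called when free_finished_seq_groups() drops finished groups.
        self._pending.extend(finished_request_ids)

    def flush(self) -> List[str]:
        # Called once per engine step when the model request is built.
        drained, self._pending = self._pending, []
        return drained


buffer = FinishedRequestBuffer()
buffer.record(["req-1", "req-2"])
assert buffer.flush() == ["req-1", "req-2"]  # shipped to the workers this step
assert buffer.flush() == []                  # nothing is sent twice
# ----------------------------- end editor's note ----------------------------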
@@ -681,6 +673,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, + finished_request_ids=self.scheduler.get_last_step_finished_seq_groups() ) output = self.model_executor.execute_model( execute_model_req=execute_model_req) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index a9acbf7eed0..08aa58999b1 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -92,10 +92,6 @@ def check_health(self) -> None: exception.""" raise NotImplementedError - @abstractmethod - def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: - raise NotImplementedError - def shutdown(self) -> None: """Shutdown the executor.""" return diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index f52fe4c3b86..359ce2fc0e7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -144,10 +144,6 @@ def check_health(self) -> None: # it's running. return - def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: - self.driver_worker.release_seqlen_agnostic_cache(requests_id) - - class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 955c8be8d95..30b22c08041 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -673,10 +673,13 @@ def forward( attn_metadata: AttentionMetadata, **kwargs ): - self.prepare_seqlen_agnostic_cache() + if getattr(self, "seqlen_agnostic_cache", None) is None: + self._prepare_seqlen_agnostic_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: requests_info = kwargs["requests_info"] + finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] + self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(requests_info) @@ -688,7 +691,11 @@ def forward( batch_size ) else: - current_seqlen_agnostic_cache, indices = kwargs["seqlen_agnostic_capture_inputs"],[] + ## CG capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) self.current_indices = indices hidden_states = self.model( @@ -773,6 +780,8 @@ def _prepare_current_run_seqlen_agnostic_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): requests_info = kwargs["requests_info"] + finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] + self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) batch_size = len(requests_info) ( current_seqlen_agnostic_cache, @@ -800,6 +809,12 @@ def capture_inputs(self,batch_size): ) + def _release_seqlen_agnostic_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.seqlen_agnostic_cache_indices_mapping: + self.seqlen_agnostic_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_seqlen_agnostic_cache(self) -> int: if self.seqlen_agnostic_cache is not None: max_possible_bs = self.seqlen_agnostic_cache[0].shape[1] @@ -813,7 +828,7 @@ def _first_free_index_in_seqlen_agnostic_cache(self) -> int: return first_free_index return 0 - def get_seqlen_agnostic_cache_shape( + def _get_seqlen_agnostic_cache_shape( self, ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: 
world_size = get_tensor_model_parallel_world_size() @@ -830,16 +845,15 @@ def get_seqlen_agnostic_cache_shape( return conv_state_shape, temporal_state_shape - def prepare_seqlen_agnostic_cache(self): - if getattr(self, "seqlen_agnostic_cache", None) is not None: - return + def _prepare_seqlen_agnostic_cache(self): # dtype = torch.get_default_dtype() dtype = self.lm_head.weight.dtype layers_type = self.config.layers_block_type mamba_layers = sum([layer_type == "mamba" for layer_type in layers_type]) num_seqlen_agnostic_layers = mamba_layers + # TODO: get from config max_batch_size = 256 - conv_state_shape, temporal_state_shape = self.get_seqlen_agnostic_cache_shape() + conv_state_shape, temporal_state_shape = self._get_seqlen_agnostic_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None for buffername in [ "seqlen_agnostic_cache", diff --git a/vllm/sequence.py b/vllm/sequence.py index 46ac33b7eca..d451a084dcb 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -822,6 +822,8 @@ class ExecuteModelRequest: num_lookahead_slots: int = 0 # The number of requests in the running queue. running_queue_size: int = 0 + # The number of requests in the running queue. + finished_request_ids: List[str] = field(default_factory=list) def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -834,4 +836,5 @@ def clone( blocks_to_copy=self.blocks_to_copy.copy(), num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, + finished_request_ids=self.finished_request_ids ) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d04bebbdc31..7eccdfc5cdc 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -49,6 +49,7 @@ def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], + finished_seq_groups_req_ids: Optional[List[str]] = None ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 35e73e7db07..b48d13aec1c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -792,16 +792,12 @@ def prepare_input_tensors( sampling_metadata, lora_requests, lora_mapping, multi_modal_input, requests_info) - def release_seqlen_agnostic_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.seqlen_agnostic_cache_indices_mapping: - self.seqlen_agnostic_cache_indices_mapping.pop(req_id) - @torch.inference_mode() def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], + finished_seq_groups_req_ids: Optional[List[str]] = None ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input, @@ -827,7 +823,10 @@ def execute_model( if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) if self.contains_seqlen_agnostic_layers: - execute_model_kwargs.update({"requests_info": requests_info}) + execute_model_kwargs.update({ + "requests_info": requests_info, + "finished_seq_groups_req_ids": finished_seq_groups_req_ids, + }) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. @@ -909,7 +908,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + finished_seq_groups_req_ids = [] + self.execute_model(seqs, kv_caches,finished_seq_groups_req_ids) torch.cuda.synchronize() return diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7bf0194e7b5..f12408ff940 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -218,7 +218,7 @@ def execute_model( seq_group_metadata_list = None else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list - + finished_seq_group_req_ids: List[str] blocks_to_swap_in: torch.Tensor blocks_to_swap_out: torch.Tensor blocks_to_copy: torch.Tensor @@ -242,11 +242,13 @@ def execute_model( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) + finished_seq_group_req_ids = execute_model_req.finished_request_ids data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, "blocks_to_copy": blocks_to_copy, + "finished_seq_group_req_ids": finished_seq_group_req_ids } broadcast_tensor_dict(data, src=0) else: @@ -255,7 +257,7 @@ def execute_model( blocks_to_swap_in = data["blocks_to_swap_in"] blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] - + finished_seq_group_req_ids = data["finished_seq_group_req_ids"] self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. @@ -263,7 +265,8 @@ def execute_model( return [] output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) + self.gpu_cache, + finished_seq_group_req_ids) # Worker only supports single-step execution. Wrap the output in a list # to conform to interface. @@ -293,10 +296,6 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config) - def release_seqlen_agnostic_cache(self, requests_id: List[str]): - self.model_runner.release_seqlen_agnostic_cache(requests_id) - - def init_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, From 2d423674bd934a8707c31eef3c4f01f558fa7024 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 11:37:28 +0300 Subject: [PATCH 055/110] Cleanup --- vllm/executor/gpu_executor.py | 1 + vllm/executor/ray_gpu_executor.py | 5 ----- vllm/worker/worker.py | 3 +++ 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 359ce2fc0e7..2b72b31b5f0 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -144,6 +144,7 @@ def check_health(self) -> None: # it's running. return + class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 7467f3403e2..afc1c886722 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -293,11 +293,6 @@ def _check_if_any_actor_is_dead(self): raise RuntimeError("At least one Worker is dead. " f"Dead Workers: {dead_actors}. 
") - def release_seqlen_agnostic_cache(self, requests_id: List[str]) -> None: - self._run_workers("release_seqlen_agnostic_cache", - requests_id=requests_id, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) - class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f12408ff940..d49d2c245cd 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -218,6 +218,7 @@ def execute_model( seq_group_metadata_list = None else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list + finished_seq_group_req_ids: List[str] blocks_to_swap_in: torch.Tensor blocks_to_swap_out: torch.Tensor @@ -258,6 +259,7 @@ def execute_model( blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] finished_seq_group_req_ids = data["finished_seq_group_req_ids"] + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. @@ -296,6 +298,7 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config) + def init_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, From 6a6378cfea08464ed59772438f8a0efef9271aa4 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 11:47:33 +0300 Subject: [PATCH 056/110] clean up requests after profile --- vllm/model_executor/models/jamba.py | 9 +++++---- vllm/worker/model_runner.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 30b22c08041..a21264031bf 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -678,8 +678,6 @@ def forward( if "seqlen_agnostic_capture_inputs" not in kwargs: requests_info = kwargs["requests_info"] - finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] - self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(requests_info) @@ -690,6 +688,8 @@ def forward( requests_info, batch_size ) + finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] + self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) else: ## CG capturing runs current_seqlen_agnostic_cache, indices = ( @@ -780,8 +780,6 @@ def _prepare_current_run_seqlen_agnostic_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): requests_info = kwargs["requests_info"] - finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] - self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) batch_size = len(requests_info) ( current_seqlen_agnostic_cache, @@ -789,6 +787,9 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): ) = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) self.current_indices = indices + finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] + self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) + for i in [0,1]: input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( current_seqlen_agnostic_cache[i], diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b48d13aec1c..a529895fc5e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -908,7 +908,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - finished_seq_groups_req_ids = [] + finished_seq_groups_req_ids = [seq.request_id for seq in seqs] self.execute_model(seqs, kv_caches,finished_seq_groups_req_ids) torch.cuda.synchronize() return From 1a8e2f94225ef1f561b5a20a40c8c8f8640ef66c Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 12:43:01 +0300 Subject: [PATCH 057/110] Update mamba requirements --- requirements-mamba.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-mamba.txt b/requirements-mamba.txt index a34b0e321b9..30a5b0948f0 100644 --- a/requirements-mamba.txt +++ b/requirements-mamba.txt @@ -2,5 +2,5 @@ -r requirements-common.txt -r requirements-cuda.txt -mamba-ssm -causal-conv1d >= 1.2.0 +mamba-ssm>=1.2.2 +causal-conv1d>=1.2.0 From eb89987067e2bdb2d4054327435c30f607d255a1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 12:43:31 +0300 Subject: [PATCH 058/110] Renaming --- vllm/core/scheduler.py | 16 ++++++++-------- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ec0bef49381..aa9cc1146a3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -299,7 +299,7 @@ def __init__( self.swapped: Deque[SequenceGroup] = deque() # Sequence groups finished in after the last step iter. - self.finished: List[SequenceGroup] = list() + self.previously_finished_request_id: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 @@ -372,11 +372,11 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def get_last_step_finished_seq_groups(self) -> List[str]: - """Returns list of the finished request ids.""" - finisehd_req_ids = [seq_group.request_id for seq_group in self.finished] - self.finished = [] - return finisehd_req_ids + def flush_last_step_finished_req_ids(self) -> List[str]: + """Flushs the list of request ids of previously finished seq_groups.""" + finished_request_ids = self.previously_finished_request_id + self.previously_finished_request_id = [] + return finished_request_ids def _schedule_running( self, @@ -1021,10 +1021,10 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.finished += [seq_group for seq_group in self.running + self.previously_finished_request_id += [seq_group.request_id for seq_group in self.running if seq_group.is_finished()] self.running = deque(seq_group for seq_group in self.running - if seq_group not in self.finished) + if not seq_group.is_finished()) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4ec83fd9384..1a80228b798 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -221,7 +221,7 @@ async def step_async( blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.get_last_step_finished_seq_groups() + finished_request_ids=self.scheduler.flush_last_step_finished_req_ids() ) output = await self.model_executor.execute_model_async( execute_model_req) diff --git 
a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b176394ec95..5cda953b79c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -673,7 +673,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.get_last_step_finished_seq_groups() + finished_request_ids=self.scheduler.flush_last_step_finished_req_ids() ) output = self.model_executor.execute_model( execute_model_req=execute_model_req) From 1cb8c1cada65d59288623c0dd6c03f47070e3e96 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 12:46:25 +0300 Subject: [PATCH 059/110] Renaming --- vllm/model_executor/models/jamba.py | 8 ++++---- vllm/worker/embedding_model_runner.py | 2 +- vllm/worker/model_runner.py | 8 ++++---- vllm/worker/worker.py | 10 +++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index a21264031bf..2f85e147be5 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -688,8 +688,8 @@ def forward( requests_info, batch_size ) - finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] - self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) + finished_request_ids = kwargs["finished_request_ids"] + self._release_seqlen_agnostic_cache(finished_request_ids) else: ## CG capturing runs current_seqlen_agnostic_cache, indices = ( @@ -787,8 +787,8 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): ) = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) self.current_indices = indices - finished_seq_groups_req_ids = kwargs["finished_seq_groups_req_ids"] - self._release_seqlen_agnostic_cache(finished_seq_groups_req_ids) + finished_request_ids = kwargs["finished_request_ids"] + self._release_seqlen_agnostic_cache(finished_request_ids) for i in [0,1]: input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 7eccdfc5cdc..bb8f5080246 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -49,7 +49,7 @@ def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], - finished_seq_groups_req_ids: Optional[List[str]] = None + finished_request_ids: Optional[List[str]] = None ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a529895fc5e..f63b5b6999c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -797,7 +797,7 @@ def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], - finished_seq_groups_req_ids: Optional[List[str]] = None + finished_request_ids: Optional[List[str]] = None ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input, @@ -825,7 +825,7 @@ def execute_model( if self.contains_seqlen_agnostic_layers: execute_model_kwargs.update({ "requests_info": requests_info, - "finished_seq_groups_req_ids": finished_seq_groups_req_ids, + "finished_request_ids": finished_request_ids, 
}) hidden_states = model_executable(**execute_model_kwargs) @@ -908,8 +908,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - finished_seq_groups_req_ids = [seq.request_id for seq in seqs] - self.execute_model(seqs, kv_caches,finished_seq_groups_req_ids) + finished_request_ids = [seq.request_id for seq in seqs] + self.execute_model(seqs, kv_caches,finished_request_ids) torch.cuda.synchronize() return diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d49d2c245cd..7493144e2c5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -219,7 +219,7 @@ def execute_model( else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list - finished_seq_group_req_ids: List[str] + finished_request_ids: List[str] blocks_to_swap_in: torch.Tensor blocks_to_swap_out: torch.Tensor blocks_to_copy: torch.Tensor @@ -243,13 +243,13 @@ def execute_model( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) - finished_seq_group_req_ids = execute_model_req.finished_request_ids + finished_request_ids = execute_model_req.finished_request_ids data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, "blocks_to_copy": blocks_to_copy, - "finished_seq_group_req_ids": finished_seq_group_req_ids + "finished_request_ids": finished_request_ids } broadcast_tensor_dict(data, src=0) else: @@ -258,7 +258,7 @@ def execute_model( blocks_to_swap_in = data["blocks_to_swap_in"] blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] - finished_seq_group_req_ids = data["finished_seq_group_req_ids"] + finished_request_ids = data["finished_request_ids"] self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -268,7 +268,7 @@ def execute_model( output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache, - finished_seq_group_req_ids) + finished_request_ids) # Worker only supports single-step execution. Wrap the output in a list # to conform to interface. From feca5d537714e4398228757f14c1fe0f347cdb52 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 12:49:56 +0300 Subject: [PATCH 060/110] Rename and docs --- vllm/model_executor/models/jamba.py | 2 +- vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2f85e147be5..2b4414b3c76 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -803,7 +803,7 @@ def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): input_buffers["seqlen_agnostic_capture_inputs"] ) - def capture_inputs(self,batch_size): + def get_seqlen_agnostic_capture_inputs(self,batch_size): return ( self.seqlen_agnostic_gc_cache_buffer[0][:, :batch_size], self.seqlen_agnostic_gc_cache_buffer[1][:, :batch_size], diff --git a/vllm/sequence.py b/vllm/sequence.py index d451a084dcb..037bbf8b83b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -822,7 +822,7 @@ class ExecuteModelRequest: num_lookahead_slots: int = 0 # The number of requests in the running queue. running_queue_size: int = 0 - # The number of requests in the running queue. + # Finished request ids since last step. 
finished_request_ids: List[str] = field(default_factory=list) def clone( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f63b5b6999c..da4171eb657 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1022,7 +1022,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: "memory_pool": self.graph_memory_pool, } if self.contains_seqlen_agnostic_layers: - capture_inputs["seqlen_agnostic_capture_inputs"] = self.model.capture_inputs(batch_size) + capture_inputs.update({"seqlen_agnostic_capture_inputs": self.model.get_seqlen_agnostic_capture_inputs(batch_size)} + ) graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner From 72c31cc6314cd02414aae401fa747c1130b5cfc5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 12:51:20 +0300 Subject: [PATCH 061/110] Format --- vllm/worker/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index da4171eb657..61496db60bd 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -167,7 +167,6 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - def load_model(self) -> None: with CudaMemoryProfiler() as m: From ddeb6892a0e9137aaa03b3af0dc427fa257cabe5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 15:01:05 +0300 Subject: [PATCH 062/110] Add mamba to Dockerfile --- Dockerfile | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/Dockerfile b/Dockerfile index ddca95c0e87..c868fc5049e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,6 +26,10 @@ COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ pip install -r requirements-cuda.txt +COPY requirements-mamba.txt requirements-mamba.txt +RUN pip install packaging +RUN pip install -r requirements-mamba.txt + # install development dependencies COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ @@ -87,6 +91,22 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### +#################### MAMBA Build IMAGE #################### +FROM dev as mamba-builder +# max jobs used for build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} + +WORKDIR /usr/src/mamba + +COPY requirements-mamba.txt requirements-mamba.txt + +# Download the wheel or build it if a pre-compiled release doesn't exist +RUN pip --verbose wheel -r requirements-mamba.txt \ + --no-build-isolation --no-deps --no-cache-dir + +#################### MAMBA Build IMAGE #################### + #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -105,6 +125,10 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose + +RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ + --mount=type=cache,target=/root/.cache/pip \ + pip install /usr/src/mamba/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### From 5d5a3bef35ddd0cf6fb5b00bef05358aa28d519a Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 15:09:14 +0300 Subject: [PATCH 063/110] Mamba disable 
prompt batching --- vllm/model_executor/models/jamba.py | 53 +++++++++++------------------ 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2b4414b3c76..6ddde7a7a38 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -222,40 +222,27 @@ def forward( ssm_state: torch.Tensor, ): if attn_metadata.prefill_metadata is not None: - # Mamba doesn't support chunked prefill, - # We pad the hidden_states before the forward pass and - # unpad it again afterwards. - max_seq_len = max(attn_metadata.prefill_metadata.seq_lens) - batch_size = len(attn_metadata.prefill_metadata.seq_lens) - padded_hidden_states = torch.zeros( - (batch_size, max_seq_len, hidden_states.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device) offset = 0 - for i, seq_len in enumerate( - attn_metadata.prefill_metadata.seq_lens): - padded_hidden_states[i, :seq_len].copy_( - hidden_states[offset:offset + seq_len]) - offset += seq_len + for i,prompt_len in enumerate(attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams( + True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0) + ) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward( + hidden_states[offset:offset + prompt_len].unsqueeze(0), + cache_params=cache + )[0] + ) + offset += prompt_len + else: cache = MambaCacheParams( - True, + False, conv_state=conv_state, - ssm_state=ssm_state, + ssm_state=ssm_state ) - padded_hidden_states = self.mamba_forward(padded_hidden_states, - cache_params=cache) - offset = 0 - for i, seq_len in enumerate( - attn_metadata.prefill_metadata.seq_lens): - hidden_states[offset:offset + seq_len].copy_( - padded_hidden_states[i, :seq_len]) - offset += seq_len - else: - cache = MambaCacheParams(False, - conv_state=conv_state, - ssm_state=ssm_state) - hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), - cache_params=cache) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), cache_params=cache) hidden_states = hidden_states.squeeze(1) return hidden_states @@ -409,7 +396,7 @@ def __init__( num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP - self.feed_forward = ffn_layer_class(config, quant_config) + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -463,7 +450,7 @@ def __init__( self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( @@ -490,7 +477,7 @@ def __init__( num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP - self.feed_forward = ffn_layer_class(config, quant_config) + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 85715fe20698f6e8a3e510a019976bc0bc1635e6 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 15:35:20 +0300 Subject: [PATCH 064/110] Format --- vllm/config.py | 37 ++- 
vllm/core/scheduler.py | 10 +- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 4 +- vllm/model_executor/models/jamba.py | 338 +++++++++++++------------- vllm/sequence.py | 3 +- vllm/worker/cache_engine.py | 3 +- vllm/worker/embedding_model_runner.py | 11 +- vllm/worker/model_runner.py | 41 ++-- 9 files changed, 221 insertions(+), 230 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fc862a11431..bb983c1ff86 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union, Tuple +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch from transformers import PretrainedConfig @@ -322,33 +322,30 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size - def contains_seqlen_agnostic_layers(self, parallel_config: "ParallelConfig") -> bool: + def contains_seqlen_agnostic_layers( + self, parallel_config: "ParallelConfig") -> bool: return self.get_num_seqlen_agnostic_layers(parallel_config) > 0 def get_layers_block_type(self, - parallel_config: "ParallelConfig") -> List[str]: + parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) # Transformers supports layers_block_type @property - return getattr( - self.hf_config, - "layers_block_type", - ["attention"] * num_layers - ) + return getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) - def get_num_attention_layers(self, parallel_config: "ParallelConfig") -> int: - return len([t for t in self.get_layers_block_type( - parallel_config - ) if t == "attention"]) + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + return len([ + t for t in self.get_layers_block_type(parallel_config) + if t == "attention" + ]) def get_num_seqlen_agnostic_layers( - self, - parallel_config: "ParallelConfig" - ) -> int: - return len([t for t in self.get_layers_block_type( - parallel_config - ) if t != "attention"]) - - + self, parallel_config: "ParallelConfig") -> int: + return len([ + t for t in self.get_layers_block_type(parallel_config) + if t != "attention" + ]) class CacheConfig: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index aa9cc1146a3..949e2df7d2e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -373,7 +373,7 @@ def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) def flush_last_step_finished_req_ids(self) -> List[str]: - """Flushs the list of request ids of previously finished seq_groups.""" + """Flushes the list of request ids of previously finished seq_groups.""" finished_request_ids = self.previously_finished_request_id self.previously_finished_request_id = [] return finished_request_ids @@ -1021,10 +1021,12 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.previously_finished_request_id += [seq_group.request_id for seq_group in self.running - if seq_group.is_finished()] + self.previously_finished_request_id += [ + seq_group.request_id for seq_group in self.running + if seq_group.is_finished() + ] self.running = deque(seq_group for seq_group in self.running - if not seq_group.is_finished()) + if not seq_group.is_finished()) def 
_allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1a80228b798..ce5d9e6ba99 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -221,8 +221,8 @@ async def step_async( blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.flush_last_step_finished_req_ids() - ) + finished_request_ids=self.scheduler. + flush_last_step_finished_req_ids()) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5cda953b79c..23d15212091 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -673,8 +673,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.flush_last_step_finished_req_ids() - ) + finished_request_ids=self.scheduler. + flush_last_step_finished_req_ids()) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6ddde7a7a38..31b60046516 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,21 +1,23 @@ # coding=utf-8 """Inference-only Jurassic model.""" from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn +from torch.nn.parameter import Parameter +from transformers import JambaConfig -from vllm.model_executor.layers.activation import SiluAndMul from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig -from transformers import JambaConfig -from torch.nn.parameter import Parameter -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -24,21 +26,20 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from 
vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import SamplerOutput -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from vllm.worker.model_runner import RequestInfo KVCache = Tuple[torch.Tensor, torch.Tensor] + @dataclass class MambaCacheParams: is_prompt: bool = False @@ -127,9 +128,12 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) self.activation = config.hidden_act - self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) - self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) def mamba_forward(self, hidden_states: torch.Tensor, @@ -223,36 +227,33 @@ def forward( ): if attn_metadata.prefill_metadata is not None: offset = 0 - for i,prompt_len in enumerate(attn_metadata.prefill_metadata.seq_lens): - cache = MambaCacheParams( - True, - conv_state=conv_state[i].unsqueeze(0), - ssm_state=ssm_state[i].unsqueeze(0) - ) + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) hidden_states[offset:offset + prompt_len].copy_( - self.mamba_forward( - hidden_states[offset:offset + prompt_len].unsqueeze(0), - cache_params=cache - )[0] - ) + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) offset += prompt_len else: - cache = MambaCacheParams( - False, - conv_state=conv_state, - ssm_state=ssm_state - ) - hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), cache_params=cache) + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) hidden_states = hidden_states.squeeze(1) return hidden_states class JambaMLP(nn.Module): + def __init__( - self, - config: JambaConfig, - quant_config: Optional[QuantizationConfig] = None, + self, + config: JambaConfig, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() hidden_size = config.hidden_size @@ -288,11 +289,11 @@ class JambaMoE(nn.Module): """ def __init__( - self, - config: JambaConfig, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + self, + config: JambaConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -384,11 +385,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class JambaMambaDecoderLayer(nn.Module): - def __init__( - self, config: JambaConfig, layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None - ) -> None: + + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: 
Optional[QuantizationConfig] = None) -> None: super().__init__() self.layer_idx = layer_idx self.config = config @@ -397,8 +399,10 @@ def __init__( num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP self.feed_forward = ffn_layer_class(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -419,7 +423,8 @@ def forward( hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, ssm_state) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -427,9 +432,11 @@ def forward( class JambaAttentionDecoderLayer(nn.Module): def __init__( - self, config: JambaConfig, layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -450,7 +457,7 @@ def __init__( self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( @@ -478,8 +485,10 @@ def __init__( num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP self.feed_forward = ffn_layer_class(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def self_attention( self, @@ -518,12 +527,16 @@ def forward( attn_metadata=attn_metadata, ) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual -ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer} +ALL_DECODER_LAYER_TYPES = { + "attention": JambaAttentionDecoderLayer, + "mamba": JambaMambaDecoderLayer +} class JambaModel(nn.Module): @@ -552,11 +565,14 @@ def __init__( decoder_layers = [] for i in range(config.num_hidden_layers): layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] - decoder_layers.append(layer_class(config, layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) + decoder_layers.append( + layer_class(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) self.layers = nn.ModuleList(decoder_layers) - self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def 
forward( self, @@ -579,7 +595,9 @@ def forward( kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] if isinstance(layer, JambaMambaDecoderLayer): - current_state_layer = i - (1 + (i - self.config.attn_layer_offset) // self.config.attn_layer_period) + current_state_layer = i - (1 + + (i - self.config.attn_layer_offset) + // self.config.attn_layer_period) current_ssm_state = ssm_state[current_state_layer] current_conv_state = conv_state[current_state_layer] @@ -627,12 +645,10 @@ def __init__( ) -> None: super().__init__() self.config = config - self.model = JambaModel( - config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config - ) + self.model = JambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -647,34 +663,27 @@ def __init__( ) self._capture = False self.current_indices = [] - self.seqlen_agnostic_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - **kwargs - ): - if getattr(self, "seqlen_agnostic_cache", None) is None: + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[KVCache], attn_metadata: AttentionMetadata, + **kwargs): + if getattr(self, "mamba_cache", None) is None: self._prepare_seqlen_agnostic_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: requests_info = kwargs["requests_info"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: - batch_size = len(requests_info) + batch_size = len(requests_info) ( current_seqlen_agnostic_cache, indices, - ) = self._prepare_current_run_seqlen_agnostic_cache( - requests_info, - batch_size - ) + ) = self._prepare_current_run_mamba_cache(requests_info, + batch_size) finished_request_ids = kwargs["finished_request_ids"] self._release_seqlen_agnostic_cache(finished_request_ids) else: @@ -685,129 +694,114 @@ def forward( ) self.current_indices = indices - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1] - ) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) if "seqlen_agnostic_capture_inputs" not in kwargs: - self._copy_seqlen_agnostic_cache_by_indices(self.current_indices, current_seqlen_agnostic_cache) + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) return hidden_states - - def _copy_seqlen_agnostic_cache_by_indices(self, indices, current_seqlen_agnostic_cache): + def _copy_mamba_cache_by_indices(self, indices, + current_seqlen_agnostic_cache): for i, offset in enumerate(indices): - self._copy_seqlen_agnostic_cache(offset, i, current_seqlen_agnostic_cache) + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + def _copy_mamba_cache(self, index_to, index_from, from_buffer): + assert self.mamba_cache is not None + for i in [0, 1]: + self.mamba_cache[i][:, index_to].copy_(from_buffer[i][:, + index_from], + non_blocking=True) - def _copy_seqlen_agnostic_cache(self, 
index_to, index_from, from_buffer): - assert self.seqlen_agnostic_cache is not None - for i in [0,1]: - self.seqlen_agnostic_cache[i][:,index_to].copy_(from_buffer[i][:,index_from],non_blocking=True) - - - def _assign_seq_id_to_seqlen_agnostic_cache( - self, - cur_rid: str, - seqs_id: List[int] - ) -> List[int]: + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> List[int]: indices_for_current_run = [] for seq_id in seqs_id: - if cur_rid not in self.seqlen_agnostic_cache_indices_mapping: - self.seqlen_agnostic_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_seqlen_agnostic_cache() - self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index index_for_current_run = first_free_index ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := self.seqlen_agnostic_cache_indices_mapping[cur_rid]): - first_free_index = self._first_free_index_in_seqlen_agnostic_cache() + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() index_exist = list(seq_ids2indices.values())[0] - self._copy_seqlen_agnostic_cache( - index_from=index_exist, - index_to=first_free_index, - from_buffer=self.seqlen_agnostic_cache - ) - self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] = first_free_index + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index index_for_current_run = first_free_index else: - index_for_current_run = self.seqlen_agnostic_cache_indices_mapping[cur_rid][seq_id] + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] indices_for_current_run.append(index_for_current_run) return indices_for_current_run - - def _prepare_current_run_seqlen_agnostic_cache( - self, - requests_info: List[RequestInfo], - batch_size: int - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor] ,List[int]]: + def _prepare_current_run_mamba_cache( + self, requests_info: List[RequestInfo], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] for request_info in requests_info: cur_rid = request_info.request_id - indices_for_current_run += self._assign_seq_id_to_seqlen_agnostic_cache( - cur_rid, - request_info.seqs_id - ) + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + cur_rid, request_info.seqs_id) ## Pad the batch in case of running batch that was not captured via CG padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_seqlen_agnostic_cache() + pad_index = self._first_free_index_in_mamba_cache() for _ in range(batch_size - len(indices_for_current_run)): padded_indices.append(pad_index) - conv_state = self.seqlen_agnostic_cache[0][:,padded_indices] - temporal_state = self.seqlen_agnostic_cache[1][:,padded_indices] + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] return (conv_state, temporal_state), indices_for_current_run - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): requests_info = kwargs["requests_info"] batch_size = len(requests_info) ( 
current_seqlen_agnostic_cache, indices, - ) = self._prepare_current_run_seqlen_agnostic_cache(requests_info, batch_size) + ) = self._prepare_current_run_mamba_cache(requests_info, batch_size) self.current_indices = indices finished_request_ids = kwargs["finished_request_ids"] self._release_seqlen_agnostic_cache(finished_request_ids) - for i in [0,1]: + for i in [0, 1]: input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( - current_seqlen_agnostic_cache[i], - non_blocking=True - ) - + current_seqlen_agnostic_cache[i], non_blocking=True) def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): - self._copy_seqlen_agnostic_cache_by_indices( + self._copy_mamba_cache_by_indices( self.current_indices, - input_buffers["seqlen_agnostic_capture_inputs"] - ) + input_buffers["seqlen_agnostic_capture_inputs"]) - def get_seqlen_agnostic_capture_inputs(self,batch_size): + def get_seqlen_agnostic_capture_inputs(self, batch_size): return ( - self.seqlen_agnostic_gc_cache_buffer[0][:, :batch_size], - self.seqlen_agnostic_gc_cache_buffer[1][:, :batch_size], + self.mamba_gc_cache_buffer[0][:, :batch_size], + self.mamba_gc_cache_buffer[1][:, :batch_size], ) - - def _release_seqlen_agnostic_cache(self, finished_seq_groups_req_ids: List[str]): + def _release_seqlen_agnostic_cache(self, + finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: - if req_id in self.seqlen_agnostic_cache_indices_mapping: - self.seqlen_agnostic_cache_indices_mapping.pop(req_id) - + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) - def _first_free_index_in_seqlen_agnostic_cache(self) -> int: - if self.seqlen_agnostic_cache is not None: - max_possible_bs = self.seqlen_agnostic_cache[0].shape[1] + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache is not None: + max_possible_bs = self.mamba_cache[0].shape[1] occupied = [ - id for seq_ids in self.seqlen_agnostic_cache_indices_mapping.values() + id for seq_ids in self.mamba_cache_indices_mapping.values() for id in seq_ids.values() ] first_free_index = [ @@ -816,9 +810,8 @@ def _first_free_index_in_seqlen_agnostic_cache(self) -> int: return first_free_index return 0 - def _get_seqlen_agnostic_cache_shape( - self, - ) -> Tuple[Optional[Tuple[int,int]],Optional[Tuple[int,int]]]: + def _get_mamba_cache_shape( + self, ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size conv_state_shape = ( @@ -829,36 +822,31 @@ def _get_seqlen_agnostic_cache_shape( self.config.mamba_expand * self.config.hidden_size // world_size, self.config.mamba_d_state, ) - return conv_state_shape, temporal_state_shape - def _prepare_seqlen_agnostic_cache(self): # dtype = torch.get_default_dtype() dtype = self.lm_head.weight.dtype layers_type = self.config.layers_block_type - mamba_layers = sum([layer_type == "mamba" for layer_type in layers_type]) + mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) num_seqlen_agnostic_layers = mamba_layers # TODO: get from config max_batch_size = 256 - conv_state_shape, temporal_state_shape = self._get_seqlen_agnostic_cache_shape() + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None - for buffername in [ - "seqlen_agnostic_cache", - "seqlen_agnostic_gc_cache_buffer", - ]: - buffer = ( - torch.empty( - size=(num_seqlen_agnostic_layers,max_batch_size) - + 
conv_state_shape, dtype=dtype, + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty( + size=(num_seqlen_agnostic_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, device="cuda"), - torch.empty( - size=(num_seqlen_agnostic_layers,max_batch_size) + - temporal_state_shape, dtype=dtype, - device="cuda") - ) - setattr(self,buffername, buffer) - + torch.empty( + size=(num_seqlen_agnostic_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -874,7 +862,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), diff --git a/vllm/sequence.py b/vllm/sequence.py index 037bbf8b83b..932690166c0 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -836,5 +836,4 @@ def clone( blocks_to_copy=self.blocks_to_copy.copy(), num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, - finished_request_ids=self.finished_request_ids - ) + finished_request_ids=self.finished_request_ids) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 9718dc21e4b..592eab82c78 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -30,7 +30,8 @@ def __init__( self.parallel_config = parallel_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_attention_layers(parallel_config) + self.num_layers = model_config.get_num_attention_layers( + parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index bb8f5080246..381a5c5b7be 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import BatchType, ModelRunner +from vllm.worker.model_runner import BatchType, ModelRunner, RequestInfo logger = init_logger(__name__) @@ -52,8 +52,8 @@ def execute_model( finished_request_ids: Optional[List[str]] = None ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) + lora_requests, lora_mapping, multi_modal_input, + _) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -87,7 +87,8 @@ def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, torch.Tensor, + Optional[List[RequestInfo]]]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -239,7 +240,7 @@ def prepare_input_tensors( ) return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input) + 
lora_requests, lora_mapping, multi_modal_input, None) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 61496db60bd..ae11c4e8acc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass import time +from dataclasses import dataclass from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union @@ -37,6 +37,7 @@ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] + @dataclass class RequestInfo: request_id: str @@ -139,7 +140,8 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.contains_seqlen_agnostic_layers = self.model_config.contains_seqlen_agnostic_layers(parallel_config) + self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( + parallel_config) # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in @@ -167,7 +169,6 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -625,7 +626,8 @@ def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor, List[RequestInfo]]: + Set[LoRARequest], LoRAMapping, torch.Tensor, + Optional[List[RequestInfo]]]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -821,10 +823,12 @@ def execute_model( } if self.vision_language_config: execute_model_kwargs.update({"image_input": multi_modal_input}) - if self.contains_seqlen_agnostic_layers: + if self.has_seqlen_agnostic: execute_model_kwargs.update({ - "requests_info": requests_info, - "finished_request_ids": finished_request_ids, + "requests_info": + requests_info, + "finished_request_ids": + finished_request_ids, }) hidden_states = model_executable(**execute_model_kwargs) @@ -908,7 +912,7 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers finished_request_ids = [seq.request_id for seq in seqs] - self.execute_model(seqs, kv_caches,finished_request_ids) + self.execute_model(seqs, kv_caches, finished_request_ids) torch.cuda.synchronize() return @@ -1020,9 +1024,12 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: "attn_metadata": attn_metadata, "memory_pool": self.graph_memory_pool, } - if self.contains_seqlen_agnostic_layers: - capture_inputs.update({"seqlen_agnostic_capture_inputs": self.model.get_seqlen_agnostic_capture_inputs(batch_size)} - ) + if self.has_seqlen_agnostic: + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1133,18 +1140,14 @@ def forward( self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_graphs( - self.input_buffers, - **kwargs - ) + self.model.copy_inputs_before_cuda_graphs(self.input_buffers, + **kwargs) # Run the graph. 
self.graph.replay() if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_outputs_after_cuda_graphs( - self.input_buffers, - **kwargs - ) + self.model.copy_outputs_after_cuda_graphs(self.input_buffers, + **kwargs) # Return the output tensor. return self.output_buffers["hidden_states"] From 8c6d82df7ff0fd35cb0c5668a162dfcf1e2690cc Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 17:19:57 +0300 Subject: [PATCH 065/110] Fix jamba bug --- vllm/model_executor/models/jamba.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 31b60046516..e91fe540703 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -35,7 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput -from vllm.worker.model_runner import RequestInfo +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -661,7 +661,6 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self._capture = False self.current_indices = [] self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -672,7 +671,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, **kwargs): if getattr(self, "mamba_cache", None) is None: - self._prepare_seqlen_agnostic_cache() + self._prepare_mamba_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: requests_info = kwargs["requests_info"] @@ -824,15 +823,13 @@ def _get_mamba_cache_shape( ) return conv_state_shape, temporal_state_shape - def _prepare_seqlen_agnostic_cache(self): - # dtype = torch.get_default_dtype() + def _prepare_mamba_cache(self): dtype = self.lm_head.weight.dtype layers_type = self.config.layers_block_type mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) num_seqlen_agnostic_layers = mamba_layers - # TODO: get from config - max_batch_size = 256 + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: From 628eec7f9d1c33639c0d52cbc7d18477c2070797 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 18:28:31 +0300 Subject: [PATCH 066/110] Renaming --- vllm/model_executor/models/jamba.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e91fe540703..60579f25094 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -674,14 +674,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, self._prepare_mamba_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: - requests_info = kwargs["requests_info"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: - batch_size = len(requests_info) + batch_size = len(request_ids_to_seq_ids) ( current_seqlen_agnostic_cache, indices, - ) = self._prepare_current_run_mamba_cache(requests_info, + ) = 
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, batch_size) finished_request_ids = kwargs["finished_request_ids"] self._release_seqlen_agnostic_cache(finished_request_ids) @@ -744,13 +744,12 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str, return indices_for_current_run def _prepare_current_run_mamba_cache( - self, requests_info: List[RequestInfo], batch_size: int + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] - for request_info in requests_info: - cur_rid = request_info.request_id + for request_id,seqs_id in request_ids_to_seq_ids.items(): indices_for_current_run += self._assign_seq_id_to_mamba_cache( - cur_rid, request_info.seqs_id) + request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG padded_indices = indices_for_current_run.copy() pad_index = self._first_free_index_in_mamba_cache() From 30030ce3002e7e74024a57390e53a514952d7cd4 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 18:28:45 +0300 Subject: [PATCH 067/110] WIP - Merge with main adaptations --- vllm/worker/model_runner.py | 62 ++++++++++++++++++++----------------- vllm/worker/worker.py | 1 + vllm/worker/worker_base.py | 3 ++ 3 files changed, 37 insertions(+), 29 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9361ab3e115..fe2515ddc2d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -52,10 +52,6 @@ TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") -@dataclass -class RequestInfo: - request_id: str - seqs_id: List[int] @dataclasses.dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): @@ -73,6 +69,7 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + request_ids_to_seq_ids : Optional[Dict[str, List[int]]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -81,6 +78,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -114,6 +112,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -343,6 +342,7 @@ def _prepare_model_input_tensors( block_tables: List[List[int]] = [] multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list) + request_ids_to_seq_ids: Dict[str,List[int]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -716,17 +716,11 @@ def _prepare_model_input_tensors( k: torch.cat(v, dim=0).to(self.device) for k, v in multi_modal_kwargs_list.items() } - # if self.vision_language_config: - # execute_model_kwargs.update({"image_input": multi_modal_input}) - # if self.has_seqlen_agnostic: - # execute_model_kwargs.update({ - # "requests_info": - # requests_info, - # "finished_request_ids": - # finished_request_ids, - 
# }) - # hidden_states = model_executable(**execute_model_kwargs) - + request_ids_to_seq_ids = { + seq_group_metadata.request_id: + list(seq_group_metadata.seq_data.keys()) + for seq_group_metadata in seq_group_metadata_list + } return self._model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -736,6 +730,7 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids ) @torch.inference_mode() @@ -923,20 +918,23 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: self.set_active_loras(set(), lora_mapping) graph_runner = CUDAGraphRunner(self.model) - # capture_inputs = { - # "input_ids": input_tokens[:batch_size], - # "positions": input_positions[:batch_size], - # "kv_caches": kv_caches, - # "attn_metadata": attn_metadata, - # "memory_pool": self.graph_memory_pool, - # } - # if self.has_seqlen_agnostic: - # capture_inputs.update({ - # "seqlen_agnostic_capture_inputs": - # self.model.get_seqlen_agnostic_capture_inputs( - # batch_size) - # }) - # graph_runner.capture(**capture_inputs) + capture_inputs = { + "input_ids": input_tokens[:batch_size], + "positions": input_positions[:batch_size], + "hidden_states": hidden_states[:batch_size] + if hidden_states is not None else None, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + "memory_pool": self.graph_memory_pool, + "stream":graph_capture_context.stream + } + if self.has_seqlen_agnostic: + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) + hidden_states = graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -1002,6 +1000,7 @@ def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + finished_request_ids: List[str] ) -> SamplerOutput: if self.lora_config: assert model_input.lora_requests is not None @@ -1021,12 +1020,17 @@ def execute_model( model_executable = self.model multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_request_ids":finished_request_ids, + "request_ids_to_seq_ids":model_input.request_ids_to_seq_ids, + } hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, **multi_modal_kwargs, + **seqlen_agnostic_kwargs ) # Compute the logits. 
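The core of the seqlen-agnostic plumbing above is the request_ids_to_seq_ids mapping that the runner derives from the scheduled sequence groups and hands to the model as extra keyword arguments. Below is a minimal, self-contained sketch of that derivation under simplified assumptions; FakeSeqGroupMetadata and the function name are placeholders invented for illustration and are not vLLM classes.

from typing import Dict, List


class FakeSeqGroupMetadata:
    """Placeholder standing in for a scheduled sequence group."""

    def __init__(self, request_id: str, seq_ids: List[int]) -> None:
        self.request_id = request_id
        # Keyed by sequence id, mirroring the real metadata's seq_data dict.
        self.seq_data = {seq_id: None for seq_id in seq_ids}


def build_request_ids_to_seq_ids(
        seq_groups: List[FakeSeqGroupMetadata]) -> Dict[str, List[int]]:
    # One entry per request; a request owns several sequence ids when it
    # samples more than one completion (n > 1).
    return {g.request_id: list(g.seq_data.keys()) for g in seq_groups}


if __name__ == "__main__":
    groups = [FakeSeqGroupMetadata("req-0", [0]),
              FakeSeqGroupMetadata("req-1", [1, 2])]
    print(build_request_ids_to_seq_ids(groups))
    # -> {'req-0': [0], 'req-1': [1, 2]}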
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e1944a4f1d6..ce9bee1eda2 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -257,6 +257,7 @@ def prepare_worker_input( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + finished_request_ids=execute_model_req.finished_request_ids ) @torch.inference_mode() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1df60eb1f38..095ccd05ba0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -124,6 +124,7 @@ class WorkerInput: blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None + finished_request_ids :Optional[List[str]] = None @classmethod def from_broadcasted_tensor_dict( @@ -139,6 +140,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + finished_request_ids=tensor_dict.pop("finished_request_ids"), ) def as_broadcastable_tensor_dict( @@ -151,6 +153,7 @@ def as_broadcastable_tensor_dict( "blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_copy": self.blocks_to_copy, + "finished_request_ids": self.finished_request_ids, } return tensor_dict From 3fba9bc8d481a68bef9e0f876bd818e7ce19de86 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 19:18:58 +0300 Subject: [PATCH 068/110] Fix --- vllm/model_executor/models/jamba.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 60579f25094..b661579acbc 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -458,7 +458,6 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( config.hidden_size, @@ -479,7 +478,6 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - sliding_window=self.sliding_window, ) num_experts = config.layers_num_experts[layer_idx] @@ -747,7 +745,7 @@ def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] - for request_id,seqs_id in request_ids_to_seq_ids.items(): + for request_id, seqs_id in request_ids_to_seq_ids.items(): indices_for_current_run += self._assign_seq_id_to_mamba_cache( request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG @@ -763,12 +761,13 @@ def _prepare_current_run_mamba_cache( return (conv_state, temporal_state), indices_for_current_run def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - requests_info = kwargs["requests_info"] - batch_size = len(requests_info) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) ( current_seqlen_agnostic_cache, indices, - ) = self._prepare_current_run_mamba_cache(requests_info, batch_size) + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) self.current_indices = indices finished_request_ids = kwargs["finished_request_ids"] From 45f3d96617cc2821a30842ef172ea8d408cecf04 Mon Sep 17 
00:00:00 2001 From: Mor Zusman Date: Wed, 26 Jun 2024 19:19:20 +0300 Subject: [PATCH 069/110] Formating --- vllm/sequence.py | 3 +- vllm/worker/embedding_model_runner.py | 1 - vllm/worker/model_runner.py | 49 ++++++++++++++------------- vllm/worker/model_runner_base.py | 5 ++- vllm/worker/worker.py | 3 +- vllm/worker/worker_base.py | 5 +-- 6 files changed, 32 insertions(+), 34 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 0f6c0803b50..1fb94222f49 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -918,5 +918,4 @@ def clone( num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, - finished_request_ids=self.finished_request_ids - ) + finished_request_ids=self.finished_request_ids) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 96a6cc1161e..9b1d83b9914 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -97,7 +97,6 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fe2515ddc2d..db1042dfc96 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,7 +1,6 @@ import dataclasses import gc import time -from dataclasses import dataclass import warnings from collections import defaultdict from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, @@ -69,7 +68,7 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - request_ids_to_seq_ids : Optional[Dict[str, List[int]]] = None + request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -342,7 +341,7 @@ def _prepare_model_input_tensors( block_tables: List[List[int]] = [] multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list) - request_ids_to_seq_ids: Dict[str,List[int]] = defaultdict(list) + request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -730,8 +729,7 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids - ) + request_ids_to_seq_ids=request_ids_to_seq_ids) @torch.inference_mode() def profile_run(self) -> None: @@ -805,10 +803,9 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - # finished_request_ids = [seq.request_id for seq in seqs] - # self.execute_model(seqs, kv_caches, finished_request_ids) model_input = self.prepare_model_input(seqs) - self.execute_model(model_input, kv_caches) + finished_request_ids = [seq.request_id for seq in seqs] + self.execute_model(model_input, kv_caches, finished_request_ids) torch.cuda.synchronize() return @@ -919,14 +916,21 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_runner = CUDAGraphRunner(self.model) capture_inputs = { - "input_ids": input_tokens[:batch_size], - "positions": input_positions[:batch_size], - "hidden_states": hidden_states[:batch_size] + "input_ids": + input_tokens[:batch_size], + "positions": + input_positions[:batch_size], + "hidden_states": + hidden_states[:batch_size] if hidden_states is not None else None, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - "memory_pool": self.graph_memory_pool, - "stream":graph_capture_context.stream + "kv_caches": + kv_caches, + "attn_metadata": + attn_metadata, + "memory_pool": + self.graph_memory_pool, + "stream": + graph_capture_context.stream } if self.has_seqlen_agnostic: capture_inputs.update({ @@ -997,11 +1001,9 @@ def prepare_model_input( @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - finished_request_ids: List[str] - ) -> SamplerOutput: + self, model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + finished_request_ids: Optional[List[str]]) -> SamplerOutput: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None @@ -1021,8 +1023,8 @@ def execute_model( multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { - "finished_request_ids":finished_request_ids, - "request_ids_to_seq_ids":model_input.request_ids_to_seq_ids, + "finished_request_ids": finished_request_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } hidden_states = model_executable( input_ids=model_input.input_tokens, @@ -1030,8 +1032,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, **multi_modal_kwargs, - **seqlen_agnostic_kwargs - ) + **seqlen_agnostic_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 9b1706035a3..57be35f80d6 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -147,9 +147,8 @@ def prepare_model_input( @torch.inference_mode() def execute_model( - self, - model_input: T, - kv_caches: Optional[List[torch.Tensor]], + self, model_input: T, kv_caches: Optional[List[torch.Tensor]], + finished_request_ids: Optional[List[str]] ) -> Optional[SamplerOutput]: """ Execute the model on the given input. 
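The finished_request_ids argument threaded into execute_model above exists so that a model with per-request state (here, the Mamba conv and SSM buffers) can recycle cache slots once a request completes. The snippet below is a rough, illustrative sketch of that slot bookkeeping only; it mirrors the assign/release/first-free-index logic and skips the tensor copies, and every name in it is invented for this example rather than taken from jamba.py.

from typing import Dict, List


class MambaSlotBookkeeper:
    """Toy tracker for which cache slot each (request, sequence) owns."""

    def __init__(self, max_batch_size: int) -> None:
        self.max_batch_size = max_batch_size
        # request_id -> {seq_id -> slot index}
        self.mapping: Dict[str, Dict[int, int]] = {}

    def _first_free_index(self) -> int:
        occupied = {slot for seqs in self.mapping.values()
                    for slot in seqs.values()}
        return next(i for i in range(self.max_batch_size)
                    if i not in occupied)

    def assign(self, request_id: str, seq_id: int) -> int:
        seqs = self.mapping.setdefault(request_id, {})
        if seq_id not in seqs:
            seqs[seq_id] = self._first_free_index()
        return seqs[seq_id]

    def release(self, finished_request_ids: List[str]) -> None:
        for req_id in finished_request_ids:
            self.mapping.pop(req_id, None)


book = MambaSlotBookkeeper(max_batch_size=4)
assert book.assign("req-0", 0) == 0
assert book.assign("req-1", 1) == 1
book.release(["req-0"])
assert book.assign("req-2", 2) == 0  # the freed slot is reused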
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ce9bee1eda2..202233eaf69 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -257,8 +257,7 @@ def prepare_worker_input( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - finished_request_ids=execute_model_req.finished_request_ids - ) + finished_request_ids=execute_model_req.finished_request_ids) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 095ccd05ba0..3a7202b7b2b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -124,7 +124,7 @@ class WorkerInput: blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None - finished_request_ids :Optional[List[str]] = None + finished_request_ids: Optional[List[str]] = None @classmethod def from_broadcasted_tensor_dict( @@ -255,7 +255,8 @@ def execute_model( if worker_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(model_input, self.kv_cache) + output = self.model_runner.execute_model( + model_input, self.kv_cache, worker_input.finished_request_ids) # Worker only supports single-step execution. Wrap the output in a # list to conform to interface. return [output] From 794f1c307ac420e0397924d1e792381a89d1f93e Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 10:21:01 +0300 Subject: [PATCH 070/110] deploy the finihsed request ids inside the modelinputs instead of worker inputs --- vllm/worker/model_runner.py | 17 ++++++++++++----- vllm/worker/model_runner_base.py | 4 ++-- vllm/worker/worker.py | 3 +-- vllm/worker/worker_base.py | 11 +++++------ 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index db1042dfc96..bb10574cc6b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -69,6 +69,7 @@ class ModelInputForGPU(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None + finished_request_ids: Optional[List[str]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -78,6 +79,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_request_ids": self.finished_request_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -112,6 +114,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_request_ids": self.finished_request_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -311,6 +314,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. 
Prepares metadata needed for the base model forward pass but not @@ -729,7 +733,8 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids) + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_request_ids=finished_request_ids) @torch.inference_mode() def profile_run(self) -> None: @@ -803,9 +808,9 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) finished_request_ids = [seq.request_id for seq in seqs] - self.execute_model(model_input, kv_caches, finished_request_ids) + model_input = self.prepare_model_input(seqs,finished_request_ids) + self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return @@ -972,6 +977,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -987,7 +993,8 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list, + finished_request_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, @@ -1023,7 +1030,7 @@ def execute_model( multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { - "finished_request_ids": finished_request_ids, + "finished_request_ids": model_input.finished_request_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } hidden_states = model_executable( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 57be35f80d6..52c902caae8 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -137,6 +137,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution @@ -147,8 +148,7 @@ def prepare_model_input( @torch.inference_mode() def execute_model( - self, model_input: T, kv_caches: Optional[List[torch.Tensor]], - finished_request_ids: Optional[List[str]] + self, model_input: T, kv_caches: Optional[List[torch.Tensor]] ) -> Optional[SamplerOutput]: """ Execute the model on the given input. 
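Moving finished_request_ids and request_ids_to_seq_ids onto the model input, as this commit does, means the per-step bookkeeping rides along with the broadcastable tensor dict rather than with the worker input. A simplified, hypothetical illustration of that pattern follows (a cut-down stand-in, not vLLM's actual dataclasses):

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass(frozen=True)
class TinyModelInput:
    """Cut-down stand-in for a GPU model input."""
    input_tokens: List[int] = field(default_factory=list)
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_request_ids: Optional[List[str]] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # The request bookkeeping travels with the model input, so
        # non-driver workers receive it through the same broadcast.
        return {
            "input_tokens": self.input_tokens,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_request_ids": self.finished_request_ids,
        }


model_input = TinyModelInput([1, 2, 3], {"req-0": [0]}, ["req-9"])
print(model_input.as_broadcastable_tensor_dict())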
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 202233eaf69..df121754350 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -256,8 +256,7 @@ def prepare_worker_input( num_seq_groups=num_seq_groups, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - finished_request_ids=execute_model_req.finished_request_ids) + blocks_to_copy=blocks_to_copy) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3a7202b7b2b..b9b7d9bba8d 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -124,7 +124,6 @@ class WorkerInput: blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None - finished_request_ids: Optional[List[str]] = None @classmethod def from_broadcasted_tensor_dict( @@ -140,7 +139,6 @@ def from_broadcasted_tensor_dict( blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - finished_request_ids=tensor_dict.pop("finished_request_ids"), ) def as_broadcastable_tensor_dict( @@ -152,8 +150,7 @@ def as_broadcastable_tensor_dict( "num_seq_groups": self.num_seq_groups, "blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "finished_request_ids": self.finished_request_ids, + "blocks_to_copy": self.blocks_to_copy } return tensor_dict @@ -230,7 +227,9 @@ def execute_model( execute_model_req=execute_model_req) model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_request_ids + )) if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() @@ -256,7 +255,7 @@ def execute_model( return [] output = self.model_runner.execute_model( - model_input, self.kv_cache, worker_input.finished_request_ids) + model_input, self.kv_cache) # Worker only supports single-step execution. Wrap the output in a # list to conform to interface. 
return [output] From 33eb4053e3ee92ede1f227cdc055f2d8102b4b05 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 10:35:22 +0300 Subject: [PATCH 071/110] fix --- vllm/worker/model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bb10574cc6b..5b5504b8711 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1009,8 +1009,7 @@ def prepare_model_input( @torch.inference_mode() def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - finished_request_ids: Optional[List[str]]) -> SamplerOutput: + kv_caches: List[torch.Tensor]) -> SamplerOutput: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None From 25c03e73ac1e36df1cf4b85b19d5de225fb252f2 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 10:36:44 +0300 Subject: [PATCH 072/110] Renaming --- vllm/core/scheduler.py | 10 +++++----- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d61efda3190..75b5286e578 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -292,7 +292,7 @@ def __init__( self.swapped: Deque[SequenceGroup] = deque() # Sequence groups finished in after the last step iter. - self.previously_finished_request_id: List[str] = list() + self.finished_request_id: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 @@ -367,10 +367,10 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def flush_last_step_finished_req_ids(self) -> List[str]: + def flush_finished_request_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" - finished_request_ids = self.previously_finished_request_id - self.previously_finished_request_id = [] + finished_request_ids = self.finished_request_id + self.finished_request_id = [] return finished_request_ids def _schedule_running( @@ -1036,7 +1036,7 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.previously_finished_request_id += [ + self.finished_request_id += [ seq_group.request_id for seq_group in self.running if seq_group.is_finished() ] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 80a28c704e3..0743f17d640 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -233,7 +233,7 @@ async def step_async( num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, finished_request_ids=self.scheduler. - flush_last_step_finished_req_ids()) + flush_finished_request_ids()) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e32d045f7f1..e7eca8fbcad 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -803,7 +803,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, finished_request_ids=self.scheduler. 
- flush_last_step_finished_req_ids()) + flush_finished_request_ids()) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: From 94d40a87c27e1aac981c114fe0e5c350ea872f86 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 10:47:01 +0300 Subject: [PATCH 073/110] Format --- vllm/engine/async_llm_engine.py | 4 ++-- vllm/engine/llm_engine.py | 4 ++-- vllm/worker/cpu_model_runner.py | 5 ++--- vllm/worker/embedding_model_runner.py | 6 +++--- vllm/worker/model_runner.py | 19 +++++++------------ vllm/worker/model_runner_base.py | 6 ++---- vllm/worker/neuron_model_runner.py | 5 ++--- vllm/worker/worker.py | 9 ++++----- vllm/worker/worker_base.py | 6 ++---- vllm/worker/xpu_model_runner.py | 7 +++---- 10 files changed, 29 insertions(+), 42 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0743f17d640..e33e6add050 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -232,8 +232,8 @@ async def step_async( blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler. - flush_finished_request_ids()) + finished_request_ids=self.scheduler.flush_finished_request_ids( + )) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e7eca8fbcad..f39bff7d9ec 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -802,8 +802,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler. - flush_finished_request_ids()) + finished_request_ids=self.scheduler.flush_finished_request_ids( + )) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index e3464c0d390..cb546080a55 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -325,9 +325,8 @@ def make_model_input_from_broadcasted_tensor_dict( ) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> CPUModelInput: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]]) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or # all decodes. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 9b1d83b9914..d5b23cd05c9 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -107,12 +107,12 @@ def make_model_input_from_broadcasted_tensor_dict( ) def prepare_model_input( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list, finished_request_ids) # Prepare PoolingMetadata. 
assert model_input.seq_lens is not None pooling_metadata = self._prepare_pooling(seq_group_metadata_list, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5b5504b8711..94197dc8e93 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -312,10 +312,8 @@ def get_max_block_per_batch(self) -> int: return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] - ) -> TModelInputForGPU: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]]) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. @@ -809,7 +807,7 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers finished_request_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input(seqs,finished_request_ids) + model_input = self.prepare_model_input(seqs, finished_request_ids) self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return @@ -975,8 +973,7 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], + self, seq_group_metadata_list: List[SequenceGroupMetadata], finished_request_ids: Optional[List[str]] ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including @@ -993,8 +990,7 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, - finished_request_ids) + seq_group_metadata_list, finished_request_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, @@ -1007,9 +1003,8 @@ def prepare_model_input( is_prompt=is_prompt) @torch.inference_mode() - def execute_model( - self, model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor]) -> SamplerOutput: + def execute_model(self, model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor]) -> SamplerOutput: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 52c902caae8..dbd0ddf8c96 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -135,10 +135,8 @@ def make_model_input_from_broadcasted_tensor_dict( @abstractmethod def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] - ) -> T: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]]) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution request. This method may move data to the worker's local device. 
It is diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index fec2c97e738..0c17288c94e 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -173,9 +173,8 @@ def make_model_input_from_broadcasted_tensor_dict( return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputForNeuron: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]]) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index df121754350..ba1964120e7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -252,11 +252,10 @@ def prepare_worker_input( device=self.device, dtype=torch.int64).view(-1, 2) - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + return WorkerInput(num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b9b7d9bba8d..e4fa4557b66 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -228,8 +228,7 @@ def execute_model( model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list, - execute_model_req.finished_request_ids - )) + execute_model_req.finished_request_ids)) if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() @@ -254,8 +253,7 @@ def execute_model( if worker_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model( - model_input, self.kv_cache) + output = self.model_runner.execute_model(model_input, self.kv_cache) # Worker only supports single-step execution. Wrap the output in a # list to conform to interface. return [output] diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index d9124a788a6..067644ea5c5 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -175,7 +175,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) + model_input = self.prepare_model_input(seqs, None) self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return @@ -188,9 +188,8 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputForXPU: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]]) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or From 976166f7ce9ba90c7f685cc5535c5329ee35cbdb Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 10:54:23 +0300 Subject: [PATCH 074/110] Typing and format --- vllm/model_executor/models/jamba.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b661579acbc..2fe22bdf39a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -701,12 +701,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, return hidden_states - def _copy_mamba_cache_by_indices(self, indices, - current_seqlen_agnostic_cache): + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor]): for i, offset in enumerate(indices): self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) - def _copy_mamba_cache(self, index_to, index_from, from_buffer): + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor]): assert self.mamba_cache is not None for i in [0, 1]: self.mamba_cache[i][:, index_to].copy_(from_buffer[i][:, @@ -782,7 +784,7 @@ def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): self.current_indices, input_buffers["seqlen_agnostic_capture_inputs"]) - def get_seqlen_agnostic_capture_inputs(self, batch_size): + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return ( self.mamba_gc_cache_buffer[0][:, :batch_size], self.mamba_gc_cache_buffer[1][:, :batch_size], @@ -808,7 +810,8 @@ def _first_free_index_in_mamba_cache(self) -> int: return 0 def _get_mamba_cache_shape( - self, ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size conv_state_shape = ( From 8181821122af2bcbd116d908cb75e39e8658d341 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 11:01:19 +0300 Subject: [PATCH 075/110] Cleanup --- Dockerfile | 5 ----- vllm/core/scheduler.py | 4 +--- vllm/worker/cache_engine.py | 1 + vllm/worker/model_runner.py | 7 +++++-- vllm/worker/model_runner_base.py | 4 +++- vllm/worker/worker.py | 10 ++++++---- vllm/worker/worker_base.py | 2 +- 7 files changed, 17 insertions(+), 16 deletions(-) diff --git a/Dockerfile b/Dockerfile index 90314d44a0c..f571e8be421 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,11 +47,6 @@ COPY requirements-mamba.txt requirements-mamba.txt RUN python3 -m pip install packaging RUN python3 -m pip install -r requirements-mamba.txt -# install development dependencies -COPY requirements-dev.txt requirements-dev.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-dev.txt - # cuda arch list 
used by torch # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 75b5286e578..1c72b534d19 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -290,10 +290,8 @@ def __init__( # Sequence groups in the SWAPPED state. # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() - - # Sequence groups finished in after the last step iter. + # Sequence groups finished since last step iter. self.finished_request_id: List[str] = list() - # Time at previous scheduling step self.prev_time = 0.0 # Did we schedule a prompt at previous step? diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c2ade312001..68136baa9b3 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -105,6 +105,7 @@ def get_cache_block_size( head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_attention_layers(parallel_config) + key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 94197dc8e93..eae8c54f94b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1003,8 +1003,11 @@ def prepare_model_input( is_prompt=is_prompt) @torch.inference_mode() - def execute_model(self, model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor]) -> SamplerOutput: + def execute_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + ) -> SamplerOutput: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index dbd0ddf8c96..ea571de7bf4 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -146,7 +146,9 @@ def prepare_model_input( @torch.inference_mode() def execute_model( - self, model_input: T, kv_caches: Optional[List[torch.Tensor]] + self, + model_input: T, + kv_caches: Optional[List[torch.Tensor]], ) -> Optional[SamplerOutput]: """ Execute the model on the given input. 
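
For orientation, the cache_engine hunk above computes the per-block KV-cache size from the attention-layer count. Below is a minimal, self-contained sketch of that arithmetic; the block size, head counts, layer count and dtype are illustrative assumptions, not values taken from this patch.

import torch

block_size = 16            # tokens per KV-cache block (illustrative)
num_kv_heads = 8           # illustrative
head_size = 128            # illustrative
num_attention_layers = 32  # only attention layers contribute KV cache
dtype = torch.float16

key_cache_block = block_size * num_kv_heads * head_size
value_cache_block = key_cache_block
total_elements = num_attention_layers * (key_cache_block + value_cache_block)
dtype_size = torch.tensor([], dtype=dtype).element_size()
print(f"{total_elements * dtype_size / 1024:.0f} KiB per cache block")
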
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ba1964120e7..e1944a4f1d6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -252,10 +252,12 @@ def prepare_worker_input( device=self.device, dtype=torch.int64).view(-1, 2) - return WorkerInput(num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e4fa4557b66..6032d0f4218 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -150,7 +150,7 @@ def as_broadcastable_tensor_dict( "num_seq_groups": self.num_seq_groups, "blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy + "blocks_to_copy": self.blocks_to_copy, } return tensor_dict From 4fdc35be22ea99f3b731b66518570b7d2ca4c0c9 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 11:25:43 +0300 Subject: [PATCH 076/110] Remove requirements-common and cuda from requirements-mamba --- requirements-mamba.txt | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/requirements-mamba.txt b/requirements-mamba.txt index 30a5b0948f0..1838e87d063 100644 --- a/requirements-mamba.txt +++ b/requirements-mamba.txt @@ -1,6 +1,3 @@ -# Common dependencies --r requirements-common.txt --r requirements-cuda.txt - +# Mamba dependencies mamba-ssm>=1.2.2 causal-conv1d>=1.2.0 From aadeca28f29f773845783c785cca2719ef566ebb Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 12:57:39 +0300 Subject: [PATCH 077/110] Fix --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index eae8c54f94b..d7979d79a99 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1029,7 +1029,7 @@ def execute_model( seqlen_agnostic_kwargs = { "finished_request_ids": model_input.finished_request_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } + } if self.has_seqlen_agnostic else {} hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, From fee775edbbd071025d6a38d074cba70f81ed963f Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 14:20:57 +0300 Subject: [PATCH 078/110] set finished requests ids as none on default --- vllm/worker/cpu_model_runner.py | 5 +++-- vllm/worker/embedding_model_runner.py | 5 +++-- vllm/worker/model_runner.py | 11 +++++++---- vllm/worker/model_runner_base.py | 5 +++-- vllm/worker/neuron_model_runner.py | 6 ++++-- vllm/worker/xpu_model_runner.py | 6 ++++-- 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index cb546080a55..da267582af0 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -325,8 +325,9 @@ def make_model_input_from_broadcasted_tensor_dict( ) def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]]) -> CPUModelInput: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None) -> 
CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or # all decodes. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d5b23cd05c9..a62dec113ed 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -107,8 +107,9 @@ def make_model_input_from_broadcasted_tensor_dict( ) def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d7979d79a99..b6e1d337a79 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -312,8 +312,10 @@ def get_max_block_per_batch(self) -> int: return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_model_input_tensors( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]]) -> TModelInputForGPU: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None + ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. @@ -973,8 +975,9 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index ea571de7bf4..f42be58bb25 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -135,8 +135,9 @@ def make_model_input_from_broadcasted_tensor_dict( @abstractmethod def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]]) -> T: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution request. This method may move data to the worker's local device. It is diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 0c17288c94e..7c5e38e66f9 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -173,8 +173,10 @@ def make_model_input_from_broadcasted_tensor_dict( return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]]) -> ModelInputForNeuron: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None + ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
is_prompt = seq_group_metadata_list[0].is_prompt diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 067644ea5c5..12a70e230af 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -188,8 +188,10 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]]) -> ModelInputForXPU: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids: Optional[List[str]] = None + ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or From 668f3d906dac737aaa4aab6915a20ff4091daa59 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 14:26:54 +0300 Subject: [PATCH 079/110] get attr to get num hidden layers --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 1e9ffc99d23..8e9ff522306 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -349,7 +349,8 @@ def get_num_attention_heads(self, return num_heads // parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: - total_num_hidden_layers = self.hf_text_config.num_hidden_layers + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) return total_num_hidden_layers // parallel_config.pipeline_parallel_size def contains_seqlen_agnostic_layers( From 10a44dc921c57b7afeb06048131f573302181ce6 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 14:29:53 +0300 Subject: [PATCH 080/110] Add jamba test --- tests/models/test_jamba.py | 48 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/models/test_jamba.py diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py new file mode 100644 index 00000000000..6865467df2d --- /dev/null +++ b/tests/models/test_jamba.py @@ -0,0 +1,48 @@ +import pytest + +MODELS = [ + "ai21labs/Jamba-tiny-random" +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + assert dtype == "float" + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) From cd9ba35121abca25bf8a3f9c7cb2bb091746416e Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 15:18:14 +0300 Subject: [PATCH 081/110] Ignore jamba test in cpu --- .buildkite/run-cpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f4fa24be1f2..9d4b2bb1cd5 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" docker exec cpu-test bash -c "cd tests; pip install pytest Pillow protobuf cd ../ - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported From 6df4f6965695d6701ce44756743370ab6e48b0c5 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 15:48:09 +0300 Subject: [PATCH 082/110] Cleanup --- vllm/worker/xpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 12a70e230af..6c349f92446 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -175,7 +175,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs, None) + model_input = self.prepare_model_input(seqs) self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return From 75dd84e5b0ea2970a742802eab003bc99904b779 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 16:50:50 +0300 Subject: [PATCH 083/110] Format and rename --- tests/models/test_jamba.py | 4 +--- vllm/core/scheduler.py | 12 ++++++------ vllm/engine/async_llm_engine.py | 4 ++-- vllm/engine/llm_engine.py | 4 ++-- vllm/model_executor/models/jamba.py | 11 +++++------ vllm/sequence.py | 4 ++-- vllm/worker/cpu_model_runner.py | 3 ++- vllm/worker/embedding_model_runner.py | 12 ++++-------- vllm/worker/model_runner.py | 24 ++++++++++++------------ vllm/worker/model_runner_base.py | 7 ++++--- vllm/worker/neuron_model_runner.py | 2 +- vllm/worker/worker_base.py | 2 +- vllm/worker/xpu_model_runner.py | 2 +- 13 files changed, 43 insertions(+), 48 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 6865467df2d..6e95f944ec9 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -1,8 +1,6 @@ import pytest -MODELS = [ - "ai21labs/Jamba-tiny-random" -] +MODELS = ["ai21labs/Jamba-tiny-random"] @pytest.mark.parametrize("model", MODELS) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1c72b534d19..d580b6014aa 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -291,7 +291,7 @@ def __init__( # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() # Sequence groups finished since last step iter. - self.finished_request_id: List[str] = list() + self.finished_requests_ids: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 # Did we schedule a prompt at previous step? 
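
The scheduler hunks in this commit implement an accumulate-then-drain pattern for finished request ids. A minimal sketch follows; ToyScheduler is an illustrative stand-in, not vLLM's Scheduler, and only the attribute and method names mirror the hunks above and below.

from typing import List


class ToyScheduler:
    def __init__(self) -> None:
        # Finished request ids since the last step iteration.
        self.finished_requests_ids: List[str] = []

    def free_finished_seq_groups(self, finished: List[str]) -> None:
        # Called as sequence groups complete; only records their ids here.
        self.finished_requests_ids += finished

    def flush_finished_requests_ids(self) -> List[str]:
        # Drained once per engine step so downstream state (e.g. the Mamba
        # cache in Jamba) is released exactly once per finished request.
        drained = self.finished_requests_ids
        self.finished_requests_ids = []
        return drained


sched = ToyScheduler()
sched.free_finished_seq_groups(["req-0", "req-1"])
assert sched.flush_finished_requests_ids() == ["req-0", "req-1"]
assert sched.flush_finished_requests_ids() == []  # already drained
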
@@ -365,11 +365,11 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def flush_finished_request_ids(self) -> List[str]: + def flush_finished_requests_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" - finished_request_ids = self.finished_request_id - self.finished_request_id = [] - return finished_request_ids + finished_requests_ids = self.finished_requests_ids + self.finished_requests_ids = [] + return finished_requests_ids def _schedule_running( self, @@ -1034,7 +1034,7 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.finished_request_id += [ + self.finished_requests_ids += [ seq_group.request_id for seq_group in self.running if seq_group.is_finished() ] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e33e6add050..e18ee54cfd1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -232,8 +232,8 @@ async def step_async( blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.flush_finished_request_ids( - )) + finished_requests_ids=self.scheduler. + flush_finished_requests_ids()) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f39bff7d9ec..5f6eeb328e2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -802,8 +802,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_request_ids=self.scheduler.flush_finished_request_ids( - )) + finished_requests_ids=self.scheduler. 
+ flush_finished_requests_ids()) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2fe22bdf39a..7ab3cc3e341 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -681,8 +681,8 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, batch_size) - finished_request_ids = kwargs["finished_request_ids"] - self._release_seqlen_agnostic_cache(finished_request_ids) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) else: ## CG capturing runs current_seqlen_agnostic_cache, indices = ( @@ -772,8 +772,8 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): batch_size) self.current_indices = indices - finished_request_ids = kwargs["finished_request_ids"] - self._release_seqlen_agnostic_cache(finished_request_ids) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) for i in [0, 1]: input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( @@ -790,8 +790,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): self.mamba_gc_cache_buffer[1][:, :batch_size], ) - def _release_seqlen_agnostic_cache(self, - finished_seq_groups_req_ids: List[str]): + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: self.mamba_cache_indices_mapping.pop(req_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index 1fb94222f49..658bbc3d8f6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -904,7 +904,7 @@ class ExecuteModelRequest: # Optional hidden states from prior step. previous_hidden_states: Optional[HiddenStates] = None # Finished request ids since last step. - finished_request_ids: List[str] = field(default_factory=list) + finished_requests_ids: List[str] = field(default_factory=list) def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -918,4 +918,4 @@ def clone( num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, - finished_request_ids=self.finished_request_ids) + finished_requests_ids=self.finished_requests_ids) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index da267582af0..e5b3c3d48d5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -327,7 +327,8 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None) -> CPUModelInput: + finished_requests_ids: Optional[List[str]] = None + ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
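
Between the engine hunks above and the runner hunks that follow, the drained id list rides on the execute-model request each step. A compact sketch of that hand-off; ToyExecuteModelRequest is a stand-in dataclass, not the real ExecuteModelRequest, though the default_factory=list default mirrors the sequence.py hunk above.

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyExecuteModelRequest:
    seq_group_metadata_list: List[dict]
    # Finished request ids since last step, defaulting to an empty list.
    finished_requests_ids: List[str] = field(default_factory=list)


def step(scheduler_finished: List[str]) -> ToyExecuteModelRequest:
    # The engine drains the scheduler once per step and attaches the ids to
    # the request handed to the executor/workers.
    return ToyExecuteModelRequest(seq_group_metadata_list=[],
                                  finished_requests_ids=scheduler_finished)


req = step(["req-42"])
print(req.finished_requests_ids)  # ['req-42']
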
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a62dec113ed..9b93149eec5 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -53,12 +53,8 @@ def __init__( vision_language_config=vision_language_config) @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithPoolingMetadata, - kv_caches: List[torch.Tensor], - finished_request_ids: Optional[List[str]] = None - ) -> Optional[PoolerOutput]: + def execute_model(self, model_input: ModelInputForGPUWithPoolingMetadata, + kv_caches: List[torch.Tensor]) -> Optional[PoolerOutput]: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None @@ -109,11 +105,11 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_request_ids) + seq_group_metadata_list, finished_requests_ids) # Prepare PoolingMetadata. assert model_input.seq_lens is not None pooling_metadata = self._prepare_pooling(seq_group_metadata_list, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b6e1d337a79..faf0042ca15 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -69,7 +69,7 @@ class ModelInputForGPU(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None - finished_request_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -79,7 +79,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_request_ids": self.finished_request_ids, + "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -114,7 +114,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_request_ids": self.finished_request_ids, + "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -312,9 +312,9 @@ def get_max_block_per_batch(self) -> int: return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. 
Prepares metadata needed for the base model forward pass but not @@ -734,7 +734,7 @@ def _prepare_model_input_tensors( lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_request_ids=finished_request_ids) + finished_requests_ids=finished_requests_ids) @torch.inference_mode() def profile_run(self) -> None: @@ -808,8 +808,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - finished_request_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input(seqs, finished_request_ids) + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input(seqs, finished_requests_ids) self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return @@ -977,7 +977,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -993,7 +993,7 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_request_ids) + seq_group_metadata_list, finished_requests_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, @@ -1030,7 +1030,7 @@ def execute_model( multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { - "finished_request_ids": model_input.finished_request_ids, + "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} hidden_states = model_executable( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index f42be58bb25..93de54848c5 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -135,9 +135,10 @@ def make_model_input_from_broadcasted_tensor_dict( @abstractmethod def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None) -> T: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None, + ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution request. This method may move data to the worker's local device. It is diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 7c5e38e66f9..2844c306a12 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -175,7 +175,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
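
The model_runner.py hunk above only forwards the bookkeeping kwargs when the model manages sequence-length-agnostic state (Mamba layers in Jamba); plain attention models receive nothing extra. A small sketch of that conditional, with an illustrative helper function rather than vLLM code.

from typing import Any, Dict, List, Optional


def build_seqlen_agnostic_kwargs(
        has_seqlen_agnostic: bool,
        finished_requests_ids: Optional[List[str]],
        request_ids_to_seq_ids: Dict[str, List[int]]) -> Dict[str, Any]:
    # Mirrors the dict built in execute_model: empty for attention-only
    # models, populated for models with seqlen-agnostic layers.
    return {
        "finished_requests_ids": finished_requests_ids,
        "request_ids_to_seq_ids": request_ids_to_seq_ids,
    } if has_seqlen_agnostic else {}


print(build_seqlen_agnostic_kwargs(True, ["req-0"], {"req-1": [0]}))
print(build_seqlen_agnostic_kwargs(False, ["req-0"], {"req-1": [0]}))  # {}
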
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6032d0f4218..6cbd6ce4349 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -228,7 +228,7 @@ def execute_model( model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list, - execute_model_req.finished_request_ids)) + execute_model_req.finished_requests_ids)) if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 6c349f92446..cdcd73db975 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -190,7 +190,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - finished_request_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: From 577f678d2e6ab705c080af1fd0e80969527c1db6 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 27 Jun 2024 17:03:35 +0300 Subject: [PATCH 084/110] Format --- vllm/worker/embedding_model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 9b93149eec5..a2b26a584f5 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -53,8 +53,11 @@ def __init__( vision_language_config=vision_language_config) @torch.inference_mode() - def execute_model(self, model_input: ModelInputForGPUWithPoolingMetadata, - kv_caches: List[torch.Tensor]) -> Optional[PoolerOutput]: + def execute_model( + self, + model_input: ModelInputForGPUWithPoolingMetadata, + kv_caches: List[torch.Tensor], + ) -> Optional[PoolerOutput]: if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None From 7bb332e34161e2e117620d96ca37e7911c935a16 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 00:54:36 +0300 Subject: [PATCH 085/110] change num_layers to num_attention_layers and add comment --- vllm/worker/cache_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 68136baa9b3..902e47de126 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -33,7 +33,8 @@ def __init__( self.device_config = device_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_attention_layers( + # Models like Jamba, have mixed typed layers, E.g Mamba + self.num_attention_layers = model_config.get_num_attention_layers( parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) @@ -72,7 +73,7 @@ def _allocate_kv_cache( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] - for _ in range(self.num_layers): + for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. 
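
Why the cache engine counts only attention layers: Jamba interleaves attention and Mamba blocks, and only attention blocks need KV-cache tensors, while Mamba blocks keep their own conv/ssm state (the mamba_cache buffers below). A tiny sketch with an illustrative layer layout, not Jamba's real one.

from typing import List

layers_block_type: List[str] = ["attention", "mamba", "mamba", "attention",
                                "mamba", "attention", "mamba", "mamba"]

num_attention_layers = sum(t == "attention" for t in layers_block_type)
num_mamba_layers = sum(t == "mamba" for t in layers_block_type)

# One (key, value) cache entry per attention layer only.
kv_cache = [None] * num_attention_layers
print(num_attention_layers, num_mamba_layers, len(kv_cache))
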
@@ -84,12 +85,12 @@ def _allocate_kv_cache( return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_layers): + for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) def swap_out(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_layers): + for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) From c05175877162426675fd215a2d196f68fd6926ae Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 02:31:06 +0300 Subject: [PATCH 086/110] Extended the finished reqeusts ids comment --- vllm/core/scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d580b6014aa..206a1099439 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -290,7 +290,9 @@ def __init__( # Sequence groups in the SWAPPED state. # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() - # Sequence groups finished since last step iter. + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. self.finished_requests_ids: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 @@ -368,7 +370,7 @@ def get_num_unfinished_seq_groups(self) -> int: def flush_finished_requests_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" finished_requests_ids = self.finished_requests_ids - self.finished_requests_ids = [] + self.finished_requests_ids = list() return finished_requests_ids def _schedule_running( From b6dc237415284a28a79ff6b76150ab2f5ca57323 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 02:34:15 +0300 Subject: [PATCH 087/110] Format and make the jamba code more readable, adding comments and explicitly declare vars --- vllm/core/scheduler.py | 2 +- vllm/model_executor/models/jamba.py | 88 ++++++++++++++++++----------- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 206a1099439..f84f8d4391a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -291,7 +291,7 @@ def __init__( # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests + # It lets the model know that any state associated with these requests # can and must be released after the current step. self.finished_requests_ids: List[str] = list() # Time at previous scheduling step diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 7ab3cc3e341..b67d6a4d275 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -659,7 +659,14 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - self.current_indices = [] + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. 
+ self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -668,10 +675,13 @@ def __init__( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, **kwargs): - if getattr(self, "mamba_cache", None) is None: + if not self.mamba_cache: self._prepare_mamba_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: @@ -684,7 +694,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) else: - ## CG capturing runs + ## CUDA graph capturing runs current_seqlen_agnostic_cache, indices = ( kwargs["seqlen_agnostic_capture_inputs"], [], @@ -703,17 +713,16 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def _copy_mamba_cache_by_indices( self, indices: List[int], - current_seqlen_agnostic_cache: Tuple[torch.Tensor]): + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): for i, offset in enumerate(indices): self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) def _copy_mamba_cache(self, index_to: int, index_from: int, - from_buffer: Tuple[torch.Tensor]): - assert self.mamba_cache is not None - for i in [0, 1]: - self.mamba_cache[i][:, index_to].copy_(from_buffer[i][:, - index_from], - non_blocking=True) + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) def _assign_seq_id_to_mamba_cache(self, cur_rid: str, seqs_id: List[int]) -> List[int]: @@ -763,32 +772,48 @@ def _prepare_current_run_mamba_cache( return (conv_state, temporal_state), indices_for_current_run def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). 
+ """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] batch_size = len(request_ids_to_seq_ids) ( - current_seqlen_agnostic_cache, + current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, batch_size) self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) - for i in [0, 1]: - input_buffers["seqlen_agnostic_capture_inputs"][i].copy_( - current_seqlen_agnostic_cache[i], non_blocking=True) + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the JambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ self._copy_mamba_cache_by_indices( self.current_indices, input_buffers["seqlen_agnostic_capture_inputs"]) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return ( - self.mamba_gc_cache_buffer[0][:, :batch_size], - self.mamba_gc_cache_buffer[1][:, :batch_size], - ) + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. + """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: @@ -796,14 +821,14 @@ def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): self.mamba_cache_indices_mapping.pop(req_id) def _first_free_index_in_mamba_cache(self) -> int: - if self.mamba_cache is not None: - max_possible_bs = self.mamba_cache[0].shape[1] + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] occupied = [ id for seq_ids in self.mamba_cache_indices_mapping.values() for id in seq_ids.values() ] first_free_index = [ - i not in occupied for i in range(max_possible_bs) + i not in occupied for i in range(max_possible_batch_size) ].index(True) return first_free_index return 0 @@ -828,21 +853,18 @@ def _prepare_mamba_cache(self): layers_type = self.config.layers_block_type mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) - num_seqlen_agnostic_layers = mamba_layers max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: - buffer = (torch.empty( - size=(num_seqlen_agnostic_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty( - size=(num_seqlen_agnostic_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) + buffer = (torch.empty(size=(mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) setattr(self, buffername, buffer) def compute_logits(self, hidden_states: torch.Tensor, From b0b08368a87fa64727bea1d174b1dd3123260e5c Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 11:35:36 +0300 Subject: [PATCH 088/110] Format --- vllm/sequence.py | 5 ++--- 
vllm/worker/model_runner.py | 14 ++------------ 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index b1ee3bae6fa..3d134d86013 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -877,10 +877,10 @@ class ExecuteModelRequest: running_queue_size: int = 0 # Optional hidden states from prior step. previous_hidden_states: Optional[HiddenStates] = None - # Finished request ids since last step. - finished_requests_ids: List[str] = field(default_factory=list) # The number of forward steps to run. num_steps: int = 1 + # Finished request ids since last step. + finished_requests_ids: List[str] = field(default_factory=list) def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -896,4 +896,3 @@ def clone( previous_hidden_states=self.previous_hidden_states, num_steps=self.num_steps, finished_requests_ids=self.finished_requests_ids) - ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 07a7b4adb6c..38121e40262 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1275,23 +1275,13 @@ def capture( torch.cuda.synchronize() # Save the input and output buffers. -<<<<<<< HEAD - self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - **kwargs, - } -======= if self.backend_name == "flashinfer": self.input_buffers = { "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, + **kwargs, } else: self.input_buffers = { @@ -1302,8 +1292,8 @@ def capture( "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, + **kwargs, } ->>>>>>> gh-main self.output_buffers = {"hidden_states": hidden_states} return hidden_states From e52e4d76b571f6a1d9201cb788dfe8a23517f93b Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 11:37:50 +0300 Subject: [PATCH 089/110] Resolve conflicts and format --- vllm/worker/model_runner.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 38121e40262..cff1c7b5524 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1313,23 +1313,17 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) -<<<<<<< HEAD - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_graphs(self.input_buffers, - **kwargs) - -======= if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) ->>>>>>> gh-main + + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_inputs_before_cuda_graphs(self.input_buffers, + **kwargs) + # Run the graph. 
self.graph.replay() if "seqlen_agnostic_capture_inputs" in self.input_buffers: From b4d49e04085376bc6bb7244d5c75224f094f2160 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 13:09:42 +0300 Subject: [PATCH 090/110] Add finished requests ids to the prepare model spec decoding --- vllm/spec_decode/draft_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index f30d2937612..fb954648f03 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -76,13 +76,17 @@ def __init__( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. """ self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input(seq_group_metadata_list) + return super().prepare_model_input( + seq_group_metadata_list, + finished_requests_ids + ) def update_model_input( self, model_input: ModelInputForGPUWithSamplingMetadata, From 68e27de192b95cd949bbc020633d643de709455c Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 13:11:42 +0300 Subject: [PATCH 091/110] Format --- vllm/spec_decode/draft_model_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index fb954648f03..beab00f8e2c 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -83,10 +83,8 @@ def prepare_model_input( TODO: In-place update model_input and remove this function. """ self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input( - seq_group_metadata_list, - finished_requests_ids - ) + return super().prepare_model_input(seq_group_metadata_list, + finished_requests_ids) def update_model_input( self, model_input: ModelInputForGPUWithSamplingMetadata, From 670ff3a102bebd183aaa94906b201002bf8b653a Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 14:24:21 +0300 Subject: [PATCH 092/110] Test cleanup --- tests/models/test_jamba.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index 6e95f944ec9..fb84a5fd2fa 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -32,6 +32,26 @@ def test_models( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test verifies that the Jamba state is cleaned up between + # steps. If it's not cleaned up, an error is expected.
+ with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_outputs = vllm_model.generate_greedy( + [example_prompts[0]] * 100, + 1 + ) + + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( From b7e31e3ba5253d16bb4a03bd121e64845b896a38 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Sun, 30 Jun 2024 14:49:24 +0300 Subject: [PATCH 093/110] Add message to test --- tests/models/test_jamba.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py index fb84a5fd2fa..d7e3a2fc4a7 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/test_jamba.py @@ -32,7 +32,6 @@ def test_models( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_state_cleanup( @@ -43,13 +42,13 @@ def test_state_cleanup( ) -> None: # This test verifies that the Jamba state is cleaned up between # steps. If it's not cleaned up, an error is expected. - with vllm_runner(model, dtype=dtype) as vllm_model: - for _ in range(10): - vllm_outputs = vllm_model.generate_greedy( - [example_prompts[0]] * 100, - 1 - ) - + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Jamba inner state wasn't cleaned up between steps, " + "could be related to finished_requests_ids") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( From 571f63d337ee300ac69e58dc9615d7c9a2c2dd19 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 1 Jul 2024 12:11:59 +0300 Subject: [PATCH 094/110] Add docstring in vllm/config.py --- vllm/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c0eecbda6a8..090610b1990 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -374,7 +374,8 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def contains_seqlen_agnostic_layers( self, parallel_config: "ParallelConfig") -> bool: - return self.get_num_seqlen_agnostic_layers(parallel_config) > 0 + """True for Mamba/SSM models (Jamba)""" + return self._get_num_seqlen_agnostic_layers(parallel_config) > 0 def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: @@ -390,7 +391,7 @@ def get_num_attention_layers(self, if t == "attention" ]) - def get_num_seqlen_agnostic_layers( + def _get_num_seqlen_agnostic_layers( self, parallel_config: "ParallelConfig") -> int: return len([ t for t in self.get_layers_block_type(parallel_config) From 49da326cb7d43274314f51530fd825e6e8a3a6b7 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 1 Jul 2024 12:12:49 +0300 Subject: [PATCH 095/110] rename flush to get_and_reset --- vllm/core/scheduler.py | 2 +- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f84f8d4391a..7f7f81a05c9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -367,7 +367,7 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def flush_finished_requests_ids(self) -> List[str]: + def get_and_reset_finished_requests_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" finished_requests_ids
= self.finished_requests_ids self.finished_requests_ids = list() diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5d6c7f051bd..b2a99d9c08f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -233,7 +233,7 @@ async def step_async( num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, finished_requests_ids=self.scheduler. - flush_finished_requests_ids()) + get_and_reset_finished_requests_ids()) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 07907684e0e..21908398159 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -823,7 +823,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, finished_requests_ids=self.scheduler. - flush_finished_requests_ids()) + get_and_reset_finished_requests_ids()) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: From 688732efc1e7296d7ec8baf06db0e4328381fa90 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 1 Jul 2024 12:32:36 +0300 Subject: [PATCH 096/110] Add comments --- vllm/model_executor/models/jamba.py | 5 ++++- vllm/worker/model_runner.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b67d6a4d275..526801095e4 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -679,9 +679,11 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, self._prepare_mamba_cache() if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: @@ -694,7 +696,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) else: - ## CUDA graph capturing runs + # CUDA graph capturing runs current_seqlen_agnostic_cache, indices = ( kwargs["seqlen_agnostic_capture_inputs"], [], @@ -705,6 +707,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata, current_seqlen_agnostic_cache[0], current_seqlen_agnostic_cache[1]) + if "seqlen_agnostic_capture_inputs" not in kwargs: self._copy_mamba_cache_by_indices(self.current_indices, current_seqlen_agnostic_cache) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cff1c7b5524..ad9f1ab6b94 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1039,6 +1039,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_capture_context.stream } if self.has_seqlen_agnostic: + # Only used by Mamba-based models CUDA graph atm (Jamba). 
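# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the patch hunks above: how the finished
# request ids introduced in this series reach the Mamba cache. All names are
# taken from the surrounding diffs; the flow is a condensed reading of the
# code, not an authoritative trace.
#
#   finished = scheduler.get_and_reset_finished_requests_ids()
#   execute_model_req = ExecuteModelRequest(..., finished_requests_ids=finished)
#   # on eager/prefill runs the model runner passes it through as a kwarg:
#   JambaForCausalLM.forward(..., finished_requests_ids=finished, ...)
#       -> self._release_mamba_cache(finished)   # frees the per-request slots
#
# During CUDA graph capture the kwargs path is bypassed and the static
# "seqlen_agnostic_capture_inputs" buffers are used instead, which is why the
# capture_inputs update just below is gated on has_seqlen_agnostic.
# ----------------------------------------------------------------------------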
capture_inputs.update({ "seqlen_agnostic_capture_inputs": self.model.get_seqlen_agnostic_capture_inputs( From 4a6b170debcc0fade63db27595fbb99cca2fc608 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 1 Jul 2024 12:45:14 +0300 Subject: [PATCH 097/110] Change to private and check finished through all of the queue --- vllm/core/scheduler.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7f7f81a05c9..ccb75cf964c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -293,7 +293,7 @@ def __init__( # Sequence groups finished requests ids since last step iteration. # It lets the model know that any state associated with these requests # can and must be released after the current step. - self.finished_requests_ids: List[str] = list() + self._finished_requests_ids: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 # Did we schedule a prompt at previous step? @@ -369,8 +369,8 @@ def get_num_unfinished_seq_groups(self) -> int: def get_and_reset_finished_requests_ids(self) -> List[str]: """Flushes the list of request ids of previously finished seq_groups.""" - finished_requests_ids = self.finished_requests_ids - self.finished_requests_ids = list() + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() return finished_requests_ids def _schedule_running( @@ -1036,10 +1036,11 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.finished_requests_ids += [ - seq_group.request_id for seq_group in self.running - if seq_group.is_finished() - ] + for queue in [self.running, self.swapped, self.waiting]: + self._finished_requests_ids += [ + seq_group.request_id for seq_group in queue + if seq_group.is_finished() + ] self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) From 2047a91dfb088170954fcadadeb5dafc321d6419 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 1 Jul 2024 18:07:46 +0300 Subject: [PATCH 098/110] CI From 5d932a4fd77f37c1c5bd59733be02311858ff5da Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Mon, 1 Jul 2024 18:17:41 +0000 Subject: [PATCH 099/110] Pipeline Parallelism Signed-off-by: Muralidhar Andoorveedu --- .buildkite/test-pipeline.yaml | 10 + tests/async_engine/test_async_llm_engine.py | 14 +- tests/async_engine/test_openapi_server_ray.py | 4 +- tests/basic_correctness/test_preemption.py | 24 +- tests/core/test_chunked_prefill_scheduler.py | 42 ++- tests/core/test_scheduler.py | 35 +- tests/distributed/test_pipeline_parallel.py | 149 ++++++++ tests/entrypoints/openai/test_chat.py | 4 +- tests/entrypoints/openai/test_completion.py | 4 +- tests/entrypoints/openai/test_embedding.py | 4 +- tests/entrypoints/openai/test_models.py | 4 +- tests/entrypoints/openai/test_vision.py | 4 +- tests/spec_decode/utils.py | 6 +- tests/tensorizer_loader/test_tensorizer.py | 4 +- tests/utils.py | 16 +- tests/worker/test_swap.py | 4 +- vllm/config.py | 25 +- vllm/core/block_manager_v1.py | 3 + vllm/core/block_manager_v2.py | 3 + vllm/core/scheduler.py | 17 +- vllm/distributed/parallel_state.py | 50 +-- vllm/distributed/utils.py | 11 +- vllm/engine/async_llm_engine.py | 79 ++++- vllm/engine/llm_engine.py | 65 +++- vllm/engine/output_processor/multi_step.py | 3 +- vllm/engine/output_processor/single_step.py | 18 +- vllm/executor/distributed_gpu_executor.py | 12 +- vllm/executor/executor_base.py | 25 ++ 
vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 12 +- vllm/executor/ray_gpu_executor.py | 71 +++- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 3 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/chatglm.py | 3 +- vllm/model_executor/models/commandr.py | 3 +- vllm/model_executor/models/dbrx.py | 3 +- vllm/model_executor/models/deepseek.py | 3 +- vllm/model_executor/models/deepseek_v2.py | 3 +- vllm/model_executor/models/falcon.py | 3 +- vllm/model_executor/models/gemma.py | 3 +- vllm/model_executor/models/gemma2.py | 3 +- vllm/model_executor/models/gpt2.py | 92 +++-- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/jais.py | 3 +- vllm/model_executor/models/llama.py | 105 ++++-- vllm/model_executor/models/llava.py | 4 +- vllm/model_executor/models/llava_next.py | 4 +- vllm/model_executor/models/minicpm.py | 3 +- vllm/model_executor/models/mixtral.py | 3 +- vllm/model_executor/models/mixtral_quant.py | 3 +- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/olmo.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/orion.py | 3 +- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/phi3_small.py | 3 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 3 +- vllm/model_executor/models/stablelm.py | 3 +- vllm/model_executor/models/starcoder2.py | 3 +- vllm/model_executor/models/xverse.py | 3 +- vllm/sequence.py | 31 ++ vllm/spec_decode/draft_model_runner.py | 16 +- vllm/worker/cache_engine.py | 4 + vllm/worker/cpu_model_runner.py | 5 +- vllm/worker/cpu_worker.py | 38 +- vllm/worker/embedding_model_runner.py | 9 +- vllm/worker/model_runner.py | 328 +++++++++++------- vllm/worker/model_runner_base.py | 5 +- vllm/worker/neuron_model_runner.py | 5 +- vllm/worker/neuron_worker.py | 2 +- vllm/worker/worker.py | 36 +- vllm/worker/worker_base.py | 40 ++- vllm/worker/xpu_model_runner.py | 7 +- vllm/worker/xpu_worker.py | 4 +- 80 files changed, 1134 insertions(+), 419 deletions(-) create mode 100644 tests/distributed/test_pipeline_parallel.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d96e3c6d192..d127278aaae 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -74,6 +74,16 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py +- label: Pipeline Parallelism Test + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + commands: + - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py + - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py + - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py + - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py + + - label: Engine Test mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 52d3394a96a..aa2b6e22208 100644 --- 
a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -5,6 +5,7 @@ import torch from vllm import SamplingParams +from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from ..utils import wait_for_gpu_memory_to_clear @@ -23,8 +24,11 @@ def __init__(self): self.add_request_calls = 0 self.abort_request_calls = 0 self.request_id = None + # Ugly, remove dependency when possible + self.parallel_config = ParallelConfig(1, 1, False) - async def step_async(self): + async def step_async(self, virtual_engine): + # PP size is 1, ignore virtual engine self.step_calls += 1 return [RequestOutput( request_id=self.request_id)] if self.request_id else [] @@ -32,6 +36,9 @@ async def step_async(self): async def process_model_inputs_async(self, *args, **kwargs): pass + async def stop_remote_worker_execution_loop_async(self): + pass + def generate(self, request_id): self.request_id = request_id @@ -41,6 +48,7 @@ def stop_generating(self): def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 + print(f'Request calls: {self.add_request_calls}') async def add_request_async(self, **kwargs): self.add_request_calls += 1 @@ -53,6 +61,9 @@ def abort_request(self, request_id): def has_unfinished_requests(self): return self.request_id is not None + def has_unfinished_requests_for_virtual_engine(self, virtual_engine): + return self.request_id is not None + class MockAsyncLLMEngine(AsyncLLMEngine): @@ -76,6 +87,7 @@ async def test_new_requests_event(): engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) + await asyncio.sleep(0) assert engine.engine.add_request_calls == 2 assert engine.engine.step_calls >= 2 await asyncio.sleep(0.001) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 332937b874e..cc05d79e568 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -4,7 +4,7 @@ # and debugging. 
import ray -from ..utils import RemoteOpenAIServer +from ..utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index d60cc95d754..7aed0d5e1fa 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -56,8 +56,8 @@ def test_chunked_prefill_recompute( max_num_seqs=max_num_seqs, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -91,10 +91,10 @@ def test_preemption( disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -147,10 +147,10 @@ def test_swap( ) as vllm_model: vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) + vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) for i in range(len(example_prompts)): hf_output_ids, _ = hf_outputs[i] @@ -214,8 +214,8 @@ def test_swap_infeasible( example_prompts, sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) # Verify the request is ignored and not hang. assert req_outputs[0].outputs[0].finish_reason == "length" @@ -252,8 +252,8 @@ def test_preemption_infeasible( sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) + assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + < ARTIFICIAL_PREEMPTION_MAX_CNT) # Verify the request is ignored and not hang. 
for req_output in req_outputs: diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index a3b76327e0a..7a5477175fa 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -3,7 +3,7 @@ import pytest # noqa -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, ParallelConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -40,7 +40,9 @@ def test_simple(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -82,7 +84,9 @@ def test_chunk(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -125,7 +129,9 @@ def test_complex(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -189,7 +195,9 @@ def test_maximal_decoding(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. 
@@ -276,7 +284,9 @@ def test_prompt_limit(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=48) @@ -305,7 +315,9 @@ def test_prompt_limit_exceed(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("2", prompt_length=48) @@ -330,7 +342,9 @@ def test_swap(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler.add_seq_group(seq_group) @@ -381,7 +395,9 @@ def test_running_prefill_prioritized_over_swap(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler.add_seq_group(seq_group) @@ -468,7 +484,9 @@ def test_chunked_prefill_preempt(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) _, seq_group = create_dummy_prompt("1", prompt_length=60) scheduler.add_seq_group(seq_group) @@ -529,7 +547,9 @@ def test_chunked_prefill_max_seqs(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=65) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index bae958211cb..d3d98524152 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -5,7 +5,8 @@ import pytest # noqa -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import (CacheConfig, LoRAConfig, ParallelConfig, + SchedulerConfig) from vllm.core.interfaces import AllocStatus from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget @@ -45,7 +46,9 @@ def test_scheduler_add_seq_group(): cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = 
ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) # Add seq group to scheduler. num_seq_group = 4 @@ -61,7 +64,9 @@ def test_scheduler_abort_seq_group(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) # Add multiple seq groups to scheduler. num_seq_group = 4 @@ -85,7 +90,9 @@ def test_scheduler_schedule_simple(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -124,7 +131,9 @@ def test_scheduler_prefill_prioritized(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) # Add seq groups to scheduler. _, seq_group_a = create_dummy_prompt("1", 1) @@ -151,7 +160,9 @@ def test_scheduler_schedule_preempt_abort(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) # Add seq groups to scheduler. seq_a, seq_group_a = create_dummy_prompt("1", block_size) @@ -202,7 +213,9 @@ def test_scheduler_max_seqs(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) all_seq_groups: List[SequenceGroup] = [] # Add seq groups to scheduler. 
@@ -240,7 +253,9 @@ def test_scheduler_delay_factor(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + None) # schedule first prompt seq_group_meta, seq_group = create_dummy_prompt("0", @@ -321,7 +336,9 @@ def initialize_scheduler(*, cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, lora_config) + parallel_config = ParallelConfig(1, 1, False) + scheduler = Scheduler(scheduler_config, cache_config, parallel_config, + lora_config) return scheduler diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py new file mode 100644 index 00000000000..6072a2dd718 --- /dev/null +++ b/tests/distributed/test_pipeline_parallel.py @@ -0,0 +1,149 @@ +import os + +import openai # use the official client for correctness check +import pytest +# using Ray for overall ease of process management, parallel requests, +# and debugging. +import ray + +from ..utils import VLLM_PATH, RemoteOpenAIServer + +# downloading lora to test lora requests + +# any model with a chat template should work here +MODEL_NAME = "meta-llama/Meta-Llama-3-8B" +EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0))) +CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0))) +TP_SIZE = int(os.getenv("TP_SIZE", 1)) +PP_SIZE = int(os.getenv("PP_SIZE", 1)) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(scope="module") +def ray_ctx(): + ray.init(runtime_env={"working_dir": VLLM_PATH}) + yield + ray.shutdown() + + +@pytest.fixture(scope="module") +def server(ray_ctx): + args = [ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--pipeline-parallel-size", + str(PP_SIZE), + "--tensor-parallel-size", + str(TP_SIZE), + "--distributed-executor-backend", + "ray", + ] + if CHUNKED_PREFILL: + args += [ + "--enable-chunked-prefill", + ] + if EAGER_MODE: + args += [ + "--enforce-eager", + ] + return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE) + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + + +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(server, client: openai.AsyncOpenAI, + model_name: str): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and 
len( + completion.choices[0].text) >= 5 + + +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME], +) +async def test_batch_completions(server, client: openai.AsyncOpenAI, + model_name: str): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index f4c0af1adfd..3e80214f24d 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download from openai import BadRequestError -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -77,7 +77,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index b05035713d7..4fe925495ee 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -16,7 +16,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -79,7 +79,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 7c7232dbcca..f8aa1c9143a 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -5,14 +5,14 @@ import pytest import ray -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 
fddfd755048..914ef6e19e1 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -6,7 +6,7 @@ # downloading lora to test lora requests from huggingface_hub import snapshot_download -from ...utils import RemoteOpenAIServer +from ...utils import VLLM_PATH, RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -22,7 +22,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index dbaaa349ad3..31131a51a17 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -24,13 +24,13 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init() + ray.init(runtime_env={"working_dir": VLLM_PATH}) yield ray.shutdown() @pytest.fixture(scope="module") -def server(): +def server(ray_ctx): return RemoteOpenAIServer([ "--model", MODEL_NAME, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 68802f0b846..86148291ae6 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -54,9 +54,9 @@ def new_execute_model(*args, **kwargs): return new_execute_model -def zero_kv_cache(cache_engine: CacheEngine): - assert cache_engine.gpu_cache - for key_blocks, value_blocks in cache_engine.gpu_cache: +def zero_kv_cache(cache_engine: List[CacheEngine]): + assert cache_engine[0].gpu_cache + for key_blocks, value_blocks in cache_engine[0].gpu_cache: key_blocks.zero_() value_blocks.zero_() diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index c8f86133f41..b2ebcc15cd0 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -22,7 +22,7 @@ tensorize_vllm_model) from ..conftest import VllmRunner, cleanup -from ..utils import RemoteOpenAIServer +from ..utils import VLLM_PATH, RemoteOpenAIServer # yapf conflicts with isort for this docstring @@ -220,6 +220,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): json.dumps(model_loader_extra_config), ] + ray.init(runtime_env={"working_dir": VLLM_PATH}) + server = RemoteOpenAIServer(openai_args) print("Server ready.") diff --git a/tests/utils.py b/tests/utils.py index 09107b5e7e2..ad4d097b0e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,7 +49,6 @@ class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds - @ray.remote(num_gpus=1) class _RemoteRunner: def __init__(self, cli_args: List[str], *, wait_url: str, @@ -92,7 +91,11 @@ def __del__(self): if hasattr(self, "proc"): self.proc.terminate() - def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: + def __init__(self, + cli_args: List[str], + *, + auto_port: bool = True, + num_gpus: int = 1) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -105,10 +108,11 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: self.host = str(args.host or 'localhost') self.port = int(args.port) - self._runner = self._RemoteRunner.remote( # type: ignore - cli_args, - wait_url=self.url_for("health"), - wait_timeout=self.MAX_SERVER_START_WAIT_S) + self._runner = ray.remote(num_gpus=num_gpus)( + 
self._RemoteRunner).remote( + cli_args, + wait_url=self.url_for("health"), + wait_timeout=self.MAX_SERVER_START_WAIT_S) self._wait_until_ready() diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index d941ffdb558..7aa439ba0a1 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -39,8 +39,8 @@ def test_swap() -> None: num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) # Randomly initialize the cache. - gpu_cache = worker.cache_engine.gpu_cache - cpu_cache = worker.cache_engine.cpu_cache + gpu_cache = worker.cache_engine[0].gpu_cache + cpu_cache = worker.cache_engine[0].cpu_cache num_layers = len(gpu_cache) for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] diff --git a/vllm/config.py b/vllm/config.py index 9854f175065..6d36d7ec29a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,6 +27,17 @@ _GB = 1 << 30 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_PP_SUPPORTED_MODELS = [ + "AquilaModel", + "AquilaForCausalLM", + "InternLMForCausalLM", + "LlamaForCausalLM", + "LLaMAForCausalLM", + "MistralForCausalLM", + "Phi3ForCausalLM", + "GPT2LMHeadModel", +] + class ModelConfig: """Configuration for the model. @@ -258,6 +269,13 @@ def verify_with_parallel_config( total_num_hidden_layers = getattr(self.hf_text_config, "num_hidden_layers", 0) pipeline_parallel_size = parallel_config.pipeline_parallel_size + architectures = getattr(self.hf_config, "architectures", []) + if not all(arch in _PP_SUPPORTED_MODELS + for arch in architectures) and pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported for the following " + f" architectures: {_PP_SUPPORTED_MODELS}.") + if total_num_hidden_layers % pipeline_parallel_size != 0: raise ValueError( f"Total number of hidden layers ({total_num_hidden_layers}) " @@ -665,9 +683,10 @@ def __init__( self._verify_args() def _verify_args(self) -> None: - if self.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is not supported yet.") + if (self.pipeline_parallel_size > 1 + and self.distributed_executor_backend == "mp"): + raise NotImplementedError("Pipeline parallelism is not supported " + "yet with multiprocessing.") if self.distributed_executor_backend not in ("ray", "mp", None): raise ValueError( "Unrecognized distributed executor backend. Supported values " diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 995ea04a5b3..e29eba375f4 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -471,6 +471,9 @@ def append_slots( def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. # Thus, it is always safe from OOM. + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.copy() # When using a sliding window, blocks will be eventually reused. diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 309775237a7..c2653cc9eda 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -293,6 +293,9 @@ def get_common_computed_block_ids( seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. 
+ return src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08..c59e3aa6953 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,7 +6,8 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import (CacheConfig, LoRAConfig, ParallelConfig, + SchedulerConfig) from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger @@ -255,10 +256,12 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + self.parallel_config = parallel_config # Note for LoRA scheduling: the current policy is extremely # simple and NOT fair. It can lead to starvation of some # LoRAs. This should be improved in the future. @@ -273,11 +276,19 @@ def __init__( BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= parallel_config.pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= parallel_config.pipeline_parallel_size + # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, - num_gpu_blocks=self.cache_config.num_gpu_blocks, - num_cpu_blocks=self.cache_config.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4ebb8703e0f..faf9177adc8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -416,7 +416,7 @@ def send_object(self, obj: Any, dst: int) -> None: assert dst < self.world_size, f"Invalid dst rank ({dst})" - assert dst != self.rank, ( + assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " "as the current rank.") @@ -446,7 +446,7 @@ def recv_object(self, src: int) -> Any: assert src < self.world_size, f"Invalid src rank ({src})" - assert src != self.rank, ( + assert src != self.rank_in_group, ( "Invalid source rank. Source rank is the same as the current rank." ) @@ -454,7 +454,7 @@ def recv_object(self, src: int) -> Any: # Receive object size rank_size = torch.distributed.recv(size_tensor, - src=src, + src=self.ranks[src], group=self.cpu_group) # Tensor to receive serialized objects into. 
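# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the patch hunks above: why the point-to-point
# helpers now index into self.ranks. The group object stores the *global* ranks
# of its members while callers work with *group-local* ranks, so a local rank
# has to be translated back to a global rank before handing it to
# torch.distributed. Toy values for illustration only:
ranks = [4, 5, 6, 7]                    # global ranks that form this (e.g. PP) group
rank = 6                                # this process's global rank
rank_in_group = ranks.index(rank)       # 2 -> the group-local rank
src = (rank_in_group - 1) % len(ranks)  # group-local "previous" rank -> 1
global_src = ranks[src]                 # 5 -> what torch.distributed expects
# ----------------------------------------------------------------------------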
@@ -464,7 +464,7 @@ def recv_object(self, src: int) -> Any: device="cpu") rank_object = torch.distributed.recv(object_tensor, - src=src, + src=self.ranks[src], group=self.cpu_group) assert rank_object == rank_size, ( @@ -491,10 +491,9 @@ def broadcast_tensor_dict( group = self.device_group metadata_group = self.cpu_group assert src < self.world_size, f"Invalid src rank ({src})" - src = self.ranks[src] - rank = self.rank - if rank == src: + rank_in_group = self.rank_in_group + if rank_in_group == src: metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, @@ -512,13 +511,13 @@ def broadcast_tensor_dict( if tensor.is_cpu: # use metadata_group for CPU tensors handle = torch.distributed.broadcast(tensor, - src=src, + src=self.ranks[src], group=metadata_group, async_op=True) else: # use group for GPU tensors handle = torch.distributed.broadcast(tensor, - src=src, + src=self.ranks[src], group=group, async_op=True) async_handles.append(handle) @@ -542,15 +541,16 @@ def broadcast_tensor_dict( # use metadata_group for CPU tensors handle = torch.distributed.broadcast( tensor, - src=src, + src=self.ranks[src], group=metadata_group, async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) async_handles.append(handle) _update_nested_dict(tensor_dict, key, tensor) else: @@ -575,7 +575,7 @@ def send_tensor_dict( metadata_group = self.cpu_group if dst is None: - dst = self.next_rank + dst = (self.rank_in_group + 1) % self.world_size assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] @@ -593,10 +593,14 @@ def send_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.send(tensor, dst=dst, group=metadata_group) + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) else: # use group for GPU tensors - torch.distributed.send(tensor, dst=dst, group=group) + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) return None def recv_tensor_dict( @@ -614,7 +618,7 @@ def recv_tensor_dict( metadata_group = self.cpu_group if src is None: - src = self.prev_rank + src = (self.rank_in_group - 1) % self.world_size assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) @@ -631,11 +635,13 @@ def recv_tensor_dict( if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.recv(tensor, - src=src, + src=self.ranks[src], group=metadata_group) else: # use group for GPU tensors - torch.distributed.recv(tensor, src=src, group=group) + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) _update_nested_dict(tensor_dict, key, tensor) else: _update_nested_dict(tensor_dict, key, value) @@ -654,7 +660,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: - dst = self.next_rank + dst = (self.rank_in_group + 1) % self.world_size pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -669,7 +675,7 @@ def recv(self, """Receives a tensor from the src rank.""" """NOTE: `src` is the local rank of the destination rank.""" if src is None: - src = self.prev_rank + src = (self.rank_in_group - 1) % self.world_size tensor = 
torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 0cd420c8e11..4e4206e5893 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import Sequence +from typing import Sequence, Tuple import torch @@ -46,3 +46,12 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> Tuple[int, int]: + layers_per_partition = divide(num_hidden_layers, pp_size) + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition + + return (start_layer, end_layer) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7db3bb28c6e..0ce511ce424 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" async def step_async( - self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + self, virtual_engine: int + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -221,7 +222,8 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() if not scheduler_outputs.is_empty(): # Execute the model. @@ -230,6 +232,7 @@ async def step_async( blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, ) @@ -248,16 +251,12 @@ async def step_async( # Tracing self.do_tracing(scheduler_outputs) - if not request_outputs: - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - await self.model_executor.stop_remote_worker_execution_loop_async() - return request_outputs + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + async def process_model_inputs_async( self, request_id: str, @@ -491,7 +490,8 @@ def _init_engine(self, *args, # order of the arguments. 
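# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the patch hunks above: what get_pp_indices
# computes when the layer count divides evenly. The real helper uses divide(),
# which also asserts divisibility; this toy version only mirrors the arithmetic.
def _pp_indices(num_hidden_layers: int, pp_rank: int, pp_size: int):
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    return start_layer, start_layer + layers_per_partition

# A 32-layer model split over 4 pipeline stages:
#   rank 0 -> (0, 8), rank 1 -> (8, 16), rank 2 -> (16, 24), rank 3 -> (24, 32)
# ----------------------------------------------------------------------------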
cache_config = kwargs["cache_config"] parallel_config = kwargs["parallel_config"] - if parallel_config.tensor_parallel_size == 1: + if (parallel_config.tensor_parallel_size == 1 + and parallel_config.pipeline_parallel_size == 1): num_gpus = cache_config.gpu_memory_utilization else: num_gpus = 1 @@ -499,7 +499,7 @@ def _init_engine(self, *args, self._engine_class).remote return engine_class(*args, **kwargs) - async def engine_step(self) -> bool: + async def engine_step(self, virtual_engine: int) -> bool: """Kick the engine to process the waiting requests. Returns True if there are in-progress requests.""" @@ -530,7 +530,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: request_outputs = await self.engine.step.remote() # type: ignore else: - request_outputs = await self.engine.step_async() + request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. for request_output in request_outputs: @@ -546,18 +546,65 @@ async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) async def run_engine_loop(self): - has_requests_in_progress = False + if self.engine_use_ray: + pipeline_parallel_size = 1 # type: ignore + else: + pipeline_parallel_size = \ + self.engine.parallel_config.pipeline_parallel_size + has_requests_in_progress = [False] * pipeline_parallel_size while True: - if not has_requests_in_progress: + if not any(has_requests_in_progress): logger.debug("Waiting for new requests...") + # Stop the execute model loop in parallel workers until there + # are more requests to process. This avoids waiting + # indefinitely in torch.distributed ops which may otherwise + # timeout, and unblocks the RPC thread in the workers so that + # they can process any other queued control plane messages, + # such as add/remove lora adapters. + if self.engine_use_ray: + await (self.engine.stop_remote_worker_execution_loop. + remote() # type: ignore + ) + else: + await self.engine.stop_remote_worker_execution_loop_async() await self._request_tracker.wait_for_new_requests() logger.debug("Got new requests!") + requests_in_progress = [ + asyncio.create_task(self.engine_step(ve)) + for ve in range(pipeline_parallel_size) + ] + has_requests_in_progress = [True] * pipeline_parallel_size # Abort if iteration takes too long due to unrecoverable errors # (eg. NCCL timeouts). try: async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): - has_requests_in_progress = await self.engine_step() + done, _ = await asyncio.wait( + requests_in_progress, + return_when=asyncio.FIRST_COMPLETED) + for _ in range(pipeline_parallel_size): + await asyncio.sleep(0) + for task in done: + result = task.result() + virtual_engine = requests_in_progress.index(task) + if self.engine_use_ray: + has_unfinished_requests = ( + await (self.engine. + has_unfinished_requests_for_virtual_engine. + remote( # type: ignore + virtual_engine))) + else: + has_unfinished_requests = ( + self.engine. + has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + self.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + has_requests_in_progress[virtual_engine] = False except asyncio.TimeoutError as exc: logger.error( "Engine iteration timed out. 
This should never happen!") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7e38c0e6b9..afbae5b9513 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -173,6 +173,7 @@ def __init__( "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " @@ -194,6 +195,7 @@ def __init__( load_config.download_dir, load_config.load_format, parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, model_config.enforce_eager, @@ -293,7 +295,11 @@ def __init__( # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + self.scheduler = [ + Scheduler(scheduler_config, cache_config, parallel_config, + lora_config) + for _ in range(parallel_config.pipeline_parallel_size) + ] # Metric Logging. if self.log_stats: @@ -510,8 +516,16 @@ def _add_processed_request( raise ValueError( "Either SamplingParams or PoolingParams must be provided.") - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() def process_model_inputs( self, @@ -681,7 +695,8 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: >>> # abort the request >>> engine.abort_request(request_id) """ - self.scheduler.abort_seq_group(request_id) + for scheduler in self.scheduler: + scheduler.abort_seq_group(request_id) def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" @@ -693,11 +708,20 @@ def get_decoding_config(self) -> DecodingConfig: def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" - return self.scheduler.get_num_unfinished_seq_groups() + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" - return self.scheduler.has_unfinished_seqs() + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. + """ + return self.scheduler[virtual_engine].has_unfinished_seqs() def _process_sequence_group_outputs( self, @@ -746,7 +770,8 @@ def _process_model_outputs( self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. - self.scheduler.free_finished_seq_groups() + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() # Create the outputs. 
request_outputs: List[Union[RequestOutput, @@ -812,7 +837,12 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: >>> if not (engine.has_unfinished_requests() or example_inputs): >>> break """ - seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + 0].schedule() if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( @@ -883,23 +913,28 @@ def _get_stats( # System State # Scheduler State - num_running_sys = len(self.scheduler.running) - num_swapped_sys = len(self.scheduler.swapped) - num_waiting_sys = len(self.scheduler.waiting) + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks gpu_cache_usage_sys = 0. if num_total_gpu is not None: - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( - ) + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. if num_total_cpu is not None and num_total_cpu > 0: - num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( - ) + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) # Iteration stats diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 8512ff83e41..7fd1faada62 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -141,4 +141,5 @@ def _process_seq_outputs(self, seq: Sequence, break if seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 07a68c65a6d..555652249e1 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -95,7 +95,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # not be used in the future iterations. parent.status = SequenceStatus.FINISHED_ABORTED seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) + for scheduler in self.scheduler: + scheduler.free_seq(parent) continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: @@ -133,7 +134,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + for scheduler in self.scheduler: + scheduler.fork_seq(parent, seq) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. @@ -141,7 +143,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # old sequences. 
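With one Scheduler per virtual engine, admission goes to the least-loaded scheduler and system-wide stats become sums over the list, as _add_processed_request and _get_stats above show. A toy illustration under assumed numbers; TinyScheduler is a stand-in, not vLLM's Scheduler:

from typing import List


class TinyScheduler:
    # Minimal stand-in tracking only a queue of sequence-group ids.
    def __init__(self) -> None:
        self.seq_groups: List[str] = []

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.seq_groups)

    def add_seq_group(self, seq_group: str) -> None:
        self.seq_groups.append(seq_group)


schedulers = [TinyScheduler() for _ in range(2)]  # pipeline_parallel_size = 2

# Admission: place each new sequence group on the least-loaded virtual engine.
for i in range(5):
    costs = [s.get_num_unfinished_seq_groups() for s in schedulers]
    schedulers[costs.index(min(costs))].add_seq_group(f"req-{i}")

# Stats: system-wide counts are sums over the per-virtual-engine schedulers.
num_unfinished = sum(s.get_num_unfinished_seq_groups() for s in schedulers)
print([len(s.seq_groups) for s in schedulers], num_unfinished)  # [3, 2] 5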
for seq, parent in child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) return # Beam search case @@ -226,13 +229,15 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + for scheduler in self.scheduler: + scheduler.fork_seq(parent, seq) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. for seq, parent in selected_child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) # Remove the unselected parent sequences from the sequence group and # free their memory in block manager. @@ -241,7 +246,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Remove the parent sequence if it is not selected for next # iteration seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + for scheduler in self.scheduler: + scheduler.free_seq(seq) def _check_beam_search_early_stopping( self, diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index d8693e636ac..3db82eb1fe7 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -69,7 +69,7 @@ def execute_model( if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", - async_run_remote_workers_only=True, + async_run_tensor_parallel_workers_only=True, **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. @@ -138,17 +138,17 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: """Runs the given method on all workers. Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. 
""" raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index d7c19622e27..9018c329510 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple @@ -110,6 +111,30 @@ def __del__(self): class ExecutorAsyncBase(ExecutorBase): + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + ) -> None: + # This locks each pipeline parallel stage so multiple virtual engines + # can't execute on the same stage at the same time + self.pp_locks = [ + asyncio.Lock() + for _ in range(parallel_config.pipeline_parallel_size) + ] + + super().__init__(model_config, cache_config, parallel_config, + scheduler_config, device_config, load_config, + lora_config, vision_language_config, + speculative_config) + @abstractmethod async def execute_model_async( self, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 5522b5322e6..c2910ccdcdb 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -45,7 +45,8 @@ def _get_worker_kwargs( lora_config=self.lora_config, vision_language_config=self.vision_language_config, speculative_config=self.speculative_config, - is_driver_worker=rank == 0, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), ) def _create_worker(self, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 6aebb470288..5bfeac0cf02 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -91,17 +91,17 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: """Runs the given method on all workers. Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. """ if max_concurrent_workers: @@ -114,7 +114,7 @@ def _run_workers( for worker in self.workers ] - if async_run_remote_workers_only: + if async_run_tensor_parallel_workers_only: # Just return futures return worker_outputs diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index faa500c2d79..e742d11bb3e 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -62,7 +62,8 @@ def _configure_ray_workers_use_nsight(self, def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - if self.parallel_config.tensor_parallel_size == 1: + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): # For single GPU case, we use a ray worker with constrained memory. 
num_gpus = self.cache_config.gpu_memory_utilization else: @@ -189,6 +190,26 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + for tp_rank in range(self.parallel_config.tensor_parallel_size): + rank = (pp_rank * + self.parallel_config.tensor_parallel_size) + tp_rank + if rank == 0: + pass + elif rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(self.workers[rank - 1]) + else: + self.non_driver_workers.append(self.workers[rank - 1]) + def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] ) -> Optional[List[SamplerOutput]]: @@ -204,7 +225,7 @@ def _run_workers( self, method: str, *args, - async_run_remote_workers_only: bool = False, + async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, @@ -215,10 +236,11 @@ def _run_workers( """Runs the given method on all workers. Can be used in the following ways: - - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than blocking - on the results. + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. - args/kwargs: All workers share the same args/kwargs - all_args/all_kwargs: args/kwargs for each worker are specified individually @@ -228,7 +250,9 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - count = len(self.workers) + count = len(self.workers) if not \ + async_run_tensor_parallel_workers_only \ + else len(self.non_driver_workers) all_worker_args = repeat(args, count) if all_args is None \ else islice(all_args, 1, None) all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ @@ -242,14 +266,17 @@ def _run_workers( ray_worker_outputs = [] else: # Start the ray workers first. 
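To make the tp_driver_workers / non_driver_workers split concrete, the same loop can be run on plain ranks. Assuming pipeline_parallel_size = 2 and tensor_parallel_size = 2 (ranks 0 through 3):

from typing import List

pipeline_parallel_size = 2  # assumed sizes for illustration
tensor_parallel_size = 2

tp_driver_ranks: List[int] = []   # rank 0 of each TP group, except global 0
non_driver_ranks: List[int] = []  # everyone else

for pp_rank in range(pipeline_parallel_size):
    for tp_rank in range(tensor_parallel_size):
        rank = pp_rank * tensor_parallel_size + tp_rank
        if rank == 0:
            continue  # the global driver worker
        elif rank % tensor_parallel_size == 0:
            tp_driver_ranks.append(rank)   # broadcasts to its TP group
        else:
            non_driver_ranks.append(rank)  # receives broadcasts

print(tp_driver_ranks)   # [2]
print(non_driver_ranks)  # [1, 3]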
+ ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) ] - if async_run_remote_workers_only: + if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs @@ -319,12 +346,32 @@ async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", - execute_model_req) + + async def _run_task_with_lock(task, lock, *args, **kwargs): + async with lock: + return await task(*args, **kwargs) + + tasks = [] + tasks.append( + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req))) + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] async def _start_worker_execution_loop(self): coros = [ worker.execute_method.remote("start_worker_execution_loop") - for worker in self.workers + for worker in self.non_driver_workers ] return await asyncio.gather(*coros) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 5777611079c..fec52e01688 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -29,7 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.arctic import ArcticConfig logger = init_logger(__name__) @@ -426,6 +426,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5cf5a199b76..ddc4e908451 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -338,6 +338,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a29aee4cffb..8387c8e37bd 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ 
-39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -286,6 +286,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5b5a69447e0..e6012a6d4e7 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -25,7 +25,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA @@ -365,6 +365,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 600c2990b36..2961f421eb6 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput @torch.compile @@ -353,6 +353,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 59af42445f3..210cf616526 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -381,6 +381,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 8fbda2638aa..e9ceca9b18c 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -48,7 +48,7 @@ ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class DeepseekMLP(nn.Module): @@ -387,6 +387,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 3d4f78c6647..3cf62afd9b4 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -48,7 +48,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class DeepseekV2MLP(nn.Module): @@ -475,6 +475,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 9618652f70d..89b0bbf014d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -44,7 +44,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import RWConfig FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -410,6 +410,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index efefb34814c..0a5a7ed3d04 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -339,6 +339,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4e35a9ec340..1f921c8bd09 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -37,7 +37,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from 
vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -338,6 +338,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index cc83f6eb6d9..81f709e49ea 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,7 +25,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_world_size) +from vllm.distributed.utils import get_pp_indices from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -38,7 +40,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPT2Attention(nn.Module): @@ -181,10 +183,18 @@ def __init__( self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList([ - GPT2Block(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer = get_pp_indices( + config.num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + self.h = nn.ModuleList( + [nn.Identity() for _ in range(self.start_layer)] + [ + GPT2Block(config, cache_config, quant_config) + for _ in range(self.start_layer, self.end_layer) + ] + [ + nn.Identity() + for _ in range(self.end_layer, config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -193,17 +203,27 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] - for i in range(len(self.h)): + for i in range(self.start_layer, self.end_layer): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) - hidden_states = self.ln_f(hidden_states) - 
return hidden_states + if get_pp_group().is_last_rank: + hidden_states = self.ln_f(hidden_states) + return hidden_states + else: + return IntermediateTensors({"hidden_states": hidden_states}) class GPT2LMHeadModel(nn.Module): @@ -228,9 +248,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -247,6 +268,16 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: @@ -260,16 +291,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." + name - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + try: + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. 
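The per-stage layer construction above (the Llama model below gets the same treatment) pads with nn.Identity outside this rank's [start_layer, end_layer) slice so global layer indices stay valid, while KV caches are indexed relative to start_layer. A minimal torch-only sketch, with nn.Linear standing in for a real decoder block:

import torch
from torch import nn


def build_stage_layers(num_hidden_layers: int, hidden_size: int,
                       start_layer: int, end_layer: int) -> nn.ModuleList:
    # Only this pipeline stage's slice is materialized; the rest are
    # nn.Identity placeholders so global layer indices stay valid.
    return nn.ModuleList(
        [nn.Identity() for _ in range(start_layer)] +
        [nn.Linear(hidden_size, hidden_size)  # stand-in for a decoder block
         for _ in range(start_layer, end_layer)] +
        [nn.Identity() for _ in range(end_layer, num_hidden_layers)])


layers = build_stage_layers(num_hidden_layers=4, hidden_size=8,
                            start_layer=2, end_layer=4)
print([type(m).__name__ for m in layers])
# ['Identity', 'Identity', 'Linear', 'Linear']

# The forward pass walks only the local slice; a per-stage KV cache list
# would be indexed with i - start_layer.
hidden_states = torch.randn(1, 8)
for i in range(2, 4):
    hidden_states = layers[i](hidden_states)
print(hidden_states.shape)  # torch.Size([1, 8])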
+ for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except KeyError: + continue diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 17bbe4e312f..7d0bf39c58f 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -273,6 +273,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 47fd5788a4c..de7f86af709 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -38,7 +38,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPTJAttention(nn.Module): @@ -239,6 +239,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index eb0fcc8f26a..3658b8fbf05 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -38,7 +38,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class GPTNeoXAttention(nn.Module): @@ -251,6 +251,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e75c567f589..283bc064b59 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -22,7 +22,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class InternLM2MLP(nn.Module): @@ -263,6 +263,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + 
intermediate_tensors: IntermediateTensors, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 869b8fc91fd..2758e2d0b59 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -40,7 +40,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import JAISConfig @@ -289,6 +289,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 54d01701f04..15e09c06759 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_pp_indices, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -46,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once from .interfaces import SupportsLoRA @@ -261,12 +262,20 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer = get_pp_indices( + config.num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + self.layers = nn.ModuleList( + [nn.Identity() for _ in range(self.start_layer)] + [ + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(self.start_layer, self.end_layer) + ] + [ + nn.Identity() + for _ in range(self.end_layer, config.num_hidden_layers) + ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -278,24 +287,38 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> 
Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + + if get_pp_group().is_last_rank: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + else: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) class LlamaForCausalLM(nn.Module, SupportsLoRA): @@ -372,10 +395,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) - return hidden_states + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -391,6 +415,20 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -416,9 +454,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + try: + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + except KeyError: + pass break else: # Skip loading extra bias for GPTQ models. @@ -437,10 +478,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + try: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except KeyError: + pass # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should @@ -452,7 +496,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: quantization_param_path, tp_rank, tp_size, self.config.num_hidden_layers, self.config.__class__.model_type): - layer_self_attn = self.model.layers[layer_idx].self_attn + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn if is_hip(): # The scaling factor convention we are assuming is diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index ba4496f9cfa..4914d033d9a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -18,7 +18,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, dummy_seq_data_for_clip) @@ -251,6 +251,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for LLaVA-1.5. @@ -311,6 +312,7 @@ def forward( positions, kv_caches, attn_metadata, + None, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 28143107467..545be8d5752 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -23,7 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData from vllm.multimodal.image import ImagePixelData -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, dummy_seq_data_for_clip, get_clip_patch_grid_length) @@ -422,6 +422,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for LlaVA-NeXT. 
@@ -476,6 +477,7 @@ def forward( positions, kv_caches, attn_metadata, + None, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index a76ed049828..33020432713 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -462,6 +462,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a662db6d28d..05c36b9c037 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -536,6 +536,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 1894c05e167..dde2da20b3b 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -47,7 +47,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class MixtralMLP(nn.Module): @@ -354,6 +354,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 5f9e4d86f3c..28dc5922cfe 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -22,7 +22,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -273,6 +273,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, 
kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 39270f71ec4..53215f32b92 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OlmoAttention(nn.Module): @@ -301,6 +301,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4bf59105dba..d12a51af5a7 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -39,7 +39,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -304,6 +304,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 133a10e6bb3..a298f0307f3 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -26,7 +26,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class OrionMLP(nn.Module): @@ -269,6 +269,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 008fceb624f..cc8e31fe1ad 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -57,7 +57,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -278,6 +278,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 0c5298eb6f1..706ae65201d 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -21,7 +21,7 @@ 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput def load_column_parallel_weight(param: torch.nn.Parameter, @@ -412,6 +412,7 @@ def forward( positions: Optional[torch.LongTensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: output_hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b6ea6ab3966..408c206c5e1 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -27,7 +27,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once @@ -245,6 +245,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e2d725af635..3691a3d2e36 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -45,7 +45,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -331,6 +331,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 564536f2dd2..b3e7dfef93e 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -50,7 +50,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class Qwen2MoeMLP(nn.Module): @@ -397,6 +397,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index a6ed3800bed..1098b3031b1 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -41,7 +41,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class StablelmMLP(nn.Module): @@ -250,6 +250,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 4324bf50d4a..6f3d5d51d03 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -40,7 +40,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput class Starcoder2Attention(nn.Module): @@ -262,6 +262,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index b61721999ca..08d3efd3312 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -43,7 +43,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA @@ -320,6 +320,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) diff --git a/vllm/sequence.py b/vllm/sequence.py index 22cb26dc08e..92a6df96219 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -745,6 +745,34 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings +@dataclass +class IntermediateTensors: + """For all pipeline stages except the last, we need to return the hidden + states and residuals to be sent to the next stage. This data structure + contains the hidden states and residuals for a request. + """ + + tensors: Dict[str, torch.Tensor] + + def __getitem__(self, key: Union[str, slice]): + if isinstance(key, str): + return self.tensors[key] + elif isinstance(key, slice): + return self.__class__({k: v[key] for k, v in self.tensors.items()}) + + def __setitem__(self, key: str, value): + self.tensors[key] = value + + def __len__(self): + return len(self.tensors) + + def __eq__(self, other: object): + return isinstance(other, self.__class__) and self + + def __repr__(self) -> str: + return f"IntermediateTensors(tensors={self.tensors})" + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, @@ -871,6 +899,8 @@ class ExecuteModelRequest: blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) # Blocks to copy. Source to dest block. 
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + # Virtual engine ID for pipeline parallel. + virtual_engine: int = 0 # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. @@ -889,6 +919,7 @@ def clone( blocks_to_swap_in=self.blocks_to_swap_in.copy(), blocks_to_swap_out=self.blocks_to_swap_out.copy(), blocks_to_copy=self.blocks_to_copy.copy(), + virtual_engine=self.virtual_engine, num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index f30d2937612..d91d57fc0da 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -6,7 +6,8 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -74,9 +75,8 @@ def __init__( List[SequenceGroupMetadata]] = None def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputForGPUWithSamplingMetadata: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. @@ -108,13 +108,14 @@ def update_model_input( seq.append_token_id(token_id, token_logprob.logprob) seq.update_num_computed_tokens(1) - return self.prepare_model_input(self.cached_seq_group_metadata_list) + return self.prepare_model_input(self.cached_seq_group_metadata_list, 0) @torch.inference_mode() def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: # Since we do not broadcast data inside execute_model anymore, @@ -130,6 +131,7 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): # Currently cuda graph is only supported by the decode phase. 
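For reference, the IntermediateTensors container added to vllm/sequence.py above is what non-final pipeline stages hand to the next stage in place of sampled output. A minimal usage sketch, not part of the patch, with made-up batch and hidden sizes, showing its dict-style lookup, batch slicing and length semantics:

    import torch
    from vllm.sequence import IntermediateTensors

    # Hidden states and residuals a non-final stage would forward; the
    # batch size (8) and hidden size (4096) are illustrative only.
    it = IntermediateTensors({
        "hidden_states": torch.zeros(8, 4096),
        "residual": torch.zeros(8, 4096),
    })

    hidden = it["hidden_states"]          # key lookup returns the raw tensor
    head = it[:4]                         # a slice narrows every tensor along dim 0
    it["residual"] = torch.ones(8, 4096)  # keys can be (re)assigned in place

    assert len(it) == 2
    assert head["hidden_states"].shape[0] == 4

The slice form returning a new IntermediateTensors is what lets the CUDA-graph capture path further down reuse one set of full-size intermediate buffers for every captured batch size.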
@@ -139,7 +141,8 @@ def execute_model( if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = ( + self.graph_runners[virtual_engine][graph_batch_size]) else: model_executable = self.model @@ -149,6 +152,7 @@ def execute_model( positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, ) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index fbd1343fea1..891e74f8ab9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -38,7 +38,11 @@ def __init__( self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks + if self.num_gpu_blocks: + self.num_gpu_blocks //= parallel_config.pipeline_parallel_size self.num_cpu_blocks = cache_config.num_cpu_blocks + if self.num_cpu_blocks: + self.num_cpu_blocks //= parallel_config.pipeline_parallel_size if cache_config.cache_dtype == "auto": self.dtype = model_config.dtype diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b83cc6f095b..26042a7dd88 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -13,7 +13,8 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, @@ -315,6 +316,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int, ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or @@ -351,6 +353,7 @@ def execute_model( self, model_input: CPUModelInput, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 30ee262c7a8..8089abd6906 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -167,8 +167,8 @@ def __init__( is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
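With pipeline parallelism the workers below keep one cache engine per virtual engine and route swap/copy operations by WorkerInput.virtual_engine, while CacheEngine divides the profiled block budget by pipeline_parallel_size. A stand-alone sketch of that routing pattern; _ToyCacheEngine and the block numbers are placeholders, not vLLM APIs:

    from typing import List, Tuple

    class _ToyCacheEngine:
        """Placeholder for CacheEngine/CPUCacheEngine: just records copy ops."""
        def __init__(self) -> None:
            self.copied: List[Tuple[int, int]] = []

        def copy(self, blocks_to_copy: List[Tuple[int, int]]) -> None:
            self.copied.extend(blocks_to_copy)

    pipeline_parallel_size = 2
    num_gpu_blocks = 7168  # hypothetical profiled total
    blocks_per_engine = num_gpu_blocks // pipeline_parallel_size  # 3584 per engine

    cache_engine = [_ToyCacheEngine() for _ in range(pipeline_parallel_size)]

    # A request scheduled on virtual engine 1 only touches that engine's cache.
    virtual_engine = 1
    cache_engine[virtual_engine].copy([(0, 3), (1, 4)])
    assert cache_engine[0].copied == []
    assert cache_engine[1].copied == [(0, 3), (1, 4)]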
- self.cache_engine: CPUCacheEngine - self.cpu_cache: List[torch.Tensor] + self.cache_engine: List[CPUCacheEngine] + self.cpu_cache: List[List[torch.Tensor]] def init_device(self) -> None: self.init_distributed_environment() @@ -242,25 +242,32 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: "initializing the engine.") def _init_cache_engine(self) -> None: - self.cache_engine = CPUCacheEngine(self.cache_config, - self.model_config, - self.parallel_config, - self.device_config) - self.cpu_cache = self.cache_engine.cpu_cache - self.model_runner.block_size = self.cache_engine.block_size - - assert self.cpu_cache is not None + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) # Populate the cache to warmup the memory - for layer_cache in self.cpu_cache: - layer_cache.fill_(0) + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.cpu_cache def execute_worker( @@ -269,12 +276,14 @@ def execute_worker( ) -> None: if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine.copy(worker_input.blocks_to_copy) + self.cache_engine[worker_input.virtual_engine].copy( + worker_input.blocks_to_copy) @torch.inference_mode() def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: assert execute_model_req is not None + virtual_engine = execute_model_req.virtual_engine num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, @@ -285,6 +294,7 @@ def prepare_worker_input( return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, ) def init_distributed_environment(self) -> None: diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 272917c7272..964f6d08bf0 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU logger = init_logger(__name__) @@ -57,6 +58,7 @@ def execute_model( self, model_input: ModelInputForGPUWithPoolingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[PoolerOutput]]: if num_steps > 1: @@ -73,10 +75,12 @@ def execute_model( assert model_input.attn_metadata is not None prefill_meta = 
model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata + virtual_engine = model_input.virtual_engine if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] else: model_executable = self.model @@ -115,6 +119,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int, ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 942063677a4..594f7fd4907 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -8,6 +8,7 @@ import numpy as np import torch +import torch.distributed import torch.nn as nn try: @@ -25,6 +26,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -37,7 +39,8 @@ from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( @@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + virtual_engine: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -89,6 +93,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -122,6 +127,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -179,7 +185,10 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.graph_runners: Dict[int, CUDAGraphRunner] = {} + + self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + {} for _ in range(self.parallel_config.pipeline_parallel_size) + ] self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. 
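ModelRunner now keeps a separate dictionary of captured CUDA graphs per virtual engine, so a decode step looks its runner up by virtual engine first and padded batch size second. A small sketch of that two-level lookup; _StubGraphRunner stands in for the real CUDAGraphRunner and is not part of the patch:

    from typing import Dict, List

    class _StubGraphRunner:
        def __init__(self, virtual_engine: int, batch_size: int) -> None:
            self.virtual_engine = virtual_engine
            self.batch_size = batch_size

    pipeline_parallel_size = 2
    batch_size_capture_list = [1, 2, 4, 8]

    graph_runners: List[Dict[int, _StubGraphRunner]] = [
        {} for _ in range(pipeline_parallel_size)
    ]
    for ve in range(pipeline_parallel_size):
        for bs in batch_size_capture_list:
            graph_runners[ve][bs] = _StubGraphRunner(ve, bs)

    # Mirrors: model_executable = self.graph_runners[virtual_engine][graph_batch_size]
    runner = graph_runners[1][4]
    assert (runner.virtual_engine, runner.batch_size) == (1, 4)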
# When using CUDA graph, the input block tables must be padded to @@ -787,9 +796,11 @@ def profile_run(self) -> None: max_num_seqs = min( max_num_seqs, int(max_num_batched_tokens / vlm_config.image_feature_size)) + batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ .dummy_data_for_profiling(model_config, seq_len) @@ -810,8 +821,14 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) - self.execute_model(model_input, kv_caches) + model_input = self.prepare_model_input(seqs, 0) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -847,7 +864,7 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number @@ -880,10 +897,18 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + intermediate_inputs = None + if not get_pp_group().is_first_rank: + intermediate_inputs = self.model.make_empty_intermediate_tensors( + batch_size=max_batch_size, + dtype=self.model_config.dtype, + device=self.device) # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. - hidden_states: Optional[torch.Tensor] = None + hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [ + None + ] * self.parallel_config.pipeline_parallel_size graph_batch_size = _get_graph_batch_size( self.scheduler_config.max_num_seqs) @@ -912,109 +937,120 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
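Both profile_run and capture_model above call self.model.make_empty_intermediate_tensors on non-first ranks to obtain placeholder inputs; the patch does not show how models implement it. One plausible shape for such a method, assuming the model forwards hidden_states and residual between stages; the key names and the 4096 hidden size are assumptions, not the patch's implementation:

    import torch
    from vllm.sequence import IntermediateTensors

    def make_empty_intermediate_tensors(batch_size: int,
                                        dtype: torch.dtype,
                                        device: torch.device,
                                        hidden_size: int = 4096) -> IntermediateTensors:
        # Zero-filled buffers a non-first pipeline rank feeds into its first
        # forward pass; real models would size these from their config.
        return IntermediateTensors({
            "hidden_states": torch.zeros((batch_size, hidden_size),
                                         dtype=dtype, device=device),
            "residual": torch.zeros((batch_size, hidden_size),
                                    dtype=dtype, device=device),
        })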
- for batch_size in reversed(batch_size_capture_list): - if self.attn_backend.get_name() == "flashinfer": - indptr_buffer = indptr_buffer[:batch_size + 1] - last_page_len_buffer = last_page_len_buffer[:batch_size] - - num_qo_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - num_kv_heads = self.model_config.get_num_kv_heads( - self.parallel_config) - if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = True + for virtual_engine in range( + self.parallel_config.pipeline_parallel_size): + for batch_size in reversed(batch_size_capture_list): + if self.attn_backend.get_name() == "flashinfer": + indptr_buffer = indptr_buffer[:batch_size + 1] + last_page_len_buffer = last_page_len_buffer[: + batch_size] + + num_qo_heads = ( + self.model_config.get_num_attention_heads( + self.parallel_config)) + num_kv_heads = self.model_config.get_num_kv_heads( + self.parallel_config) + if num_qo_heads // num_kv_heads >= 4: + use_tensor_cores = True + else: + use_tensor_cores = False + decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, + indices_buffer, last_page_len_buffer, "NHD", + use_tensor_cores) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.kv_cache_dtype, self.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange( + 0, batch_size + 1, dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange( + 0, batch_size, dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full( + (batch_size, ), self.block_size, dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len= + paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.device, + data_type=kv_cache_dtype, + use_cuda_graph=True, + decode_wrapper=decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() else: - use_tensor_cores = False - decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, indptr_buffer, indices_buffer, - last_page_len_buffer, "NHD", use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.kv_cache_dtype, self.model_config.dtype) - - paged_kv_indptr_tensor_host = torch.arange( - 0, batch_size + 1, dtype=torch.int32) - paged_kv_indices_tensor_host = torch.arange( - 0, batch_size, dtype=torch.int32) - paged_kv_last_page_len_tensor_host = torch.full( - (batch_size, ), self.block_size, dtype=torch.int32) - query_start_loc_host = torch.arange(0, - batch_size + 1, - dtype=torch.int32) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - slot_mapping=slot_mapping[:batch_size], - num_prefill_tokens=0, - num_decode_tokens=batch_size, - max_prefill_seq_len=0, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor_host, - paged_kv_indices=paged_kv_indices_tensor_host, - paged_kv_last_page_len= - paged_kv_last_page_len_tensor_host, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - 
seq_start_loc=None, - query_start_loc=query_start_loc_host, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=True, - decode_wrapper=decode_wrapper, - prefill_wrapper=None) - attn_metadata.begin_forward() - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner( + self.model, self.attn_backend.get_name()) + + if self.attn_backend.get_name() == "flashinfer": + graph_runner.flashinfer_indptr_buffer = indptr_buffer + graph_runner.flashinfer_indices_buffer = indices_buffer + graph_runner.flashinfer_last_page_len_buffer = \ + last_page_len_buffer + graph_runner.flashinfer_decode_workspace_buffer = \ + decode_workspace_buffer + graph_runner.flashinfer_decode_wrapper = \ + decode_wrapper + + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + hidden_or_intermediate_states[ + virtual_engine] # type: ignore + [:batch_size] + if hidden_or_intermediate_states[virtual_engine] + is not None else None, + intermediate_inputs[:batch_size] + if intermediate_inputs is not None else None, + kv_caches[virtual_engine], + attn_metadata, + memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, ) - - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) - self.set_active_loras(set(), lora_mapping) - - graph_runner = CUDAGraphRunner(self.model, - self.attn_backend.get_name()) - - if self.attn_backend.get_name() == "flashinfer": - graph_runner.flashinfer_indptr_buffer = indptr_buffer - graph_runner.flashinfer_indices_buffer = indices_buffer - graph_runner.flashinfer_last_page_len_buffer = \ - last_page_len_buffer - graph_runner.flashinfer_decode_workspace_buffer = \ - decode_workspace_buffer - graph_runner.flashinfer_decode_wrapper = \ - decode_wrapper - - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - hidden_states[:batch_size] - if hidden_states is not None else None, - kv_caches, - attn_metadata, - memory_pool=self.graph_memory_pool, - stream=graph_capture_context.stream, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[virtual_engine][batch_size] = ( + graph_runner) end_time = time.perf_counter() elapsed_time = end_time - start_time @@ -1047,6 +1083,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + 
virtual_engine: int = 0, ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1072,15 +1109,17 @@ def prepare_model_input( if seq_group_metadata_list else None) return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, - is_prompt=is_prompt) + is_prompt=is_prompt, + virtual_engine=virtual_engine) @torch.inference_mode() def execute_model( self, model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1124,27 +1163,34 @@ def execute_model( assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. + virtual_engine = model_input.virtual_engine if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] else: model_executable = self.model multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_states = model_executable( + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, ) - # Compute the logits. - logits = self.model.compute_logits(hidden_states, + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) - # Only perform sampling in the driver worker. if not self.is_driver_worker: return [] @@ -1159,9 +1205,12 @@ def execute_model( assert model_input.sampling_metadata is not None indices = model_input.sampling_metadata.selected_token_indices if model_input.is_prompt: - hidden_states = hidden_states.index_select(0, indices) + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) elif decode_meta.use_cuda_graph: - hidden_states = hidden_states[:len(indices)] + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states output.hidden_states = hidden_states @@ -1195,13 +1244,15 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - hidden_states: Optional[torch.Tensor], + hidden_or_intermediate_states: Optional[Union[IntermediateTensors, + torch.Tensor]], + intermediate_inputs: Optional[IntermediateTensors], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool: Optional[Tuple[int, int]], stream: torch.cuda.Stream, **kwargs, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: assert self._graph is None # Run the model a few times without capturing the graph. 
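The execute_model changes above make the return type rank-dependent: every stage except the last returns its IntermediateTensors, and only the last stage computes logits and samples. A condensed mirror of that tail, with the distributed group and model calls passed in as plain callables rather than the real vLLM objects:

    from typing import Callable

    def finish_pipeline_step(hidden_or_intermediate_states,
                             is_last_rank: bool,
                             is_driver_worker: bool,
                             compute_logits: Callable,
                             sample: Callable):
        """Condensed sketch of the tail of ModelRunner.execute_model above."""
        if not is_last_rank:
            # Hand the IntermediateTensors back; the worker forwards them on.
            return hidden_or_intermediate_states
        logits = compute_logits(hidden_or_intermediate_states)
        if not is_driver_worker:
            return []          # non-driver tensor-parallel ranks do not sample
        return [sample(logits)]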
# This is to make sure that the captured graph does not include the @@ -1213,6 +1264,7 @@ def capture( positions, kv_caches, attn_metadata, + intermediate_inputs, **kwargs, ) torch.cuda.synchronize() @@ -1220,18 +1272,27 @@ def capture( # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - output_hidden_states = self.model( + output_hidden_or_intermediate_states = self.model( input_ids, positions, kv_caches, attn_metadata, + intermediate_inputs, **kwargs, ) - if hidden_states is not None: - hidden_states.copy_(output_hidden_states) + if hidden_or_intermediate_states is not None: + if get_pp_group().is_last_rank: + hidden_or_intermediate_states.copy_( + output_hidden_or_intermediate_states) + else: + for key in hidden_or_intermediate_states.tensors: + hidden_or_intermediate_states[key].copy_( + output_hidden_or_intermediate_states[key]) else: - hidden_states = output_hidden_states - del output_hidden_states + hidden_or_intermediate_states = ( + output_hidden_or_intermediate_states) + + del output_hidden_or_intermediate_states # make sure `output_hidden_states` is deleted # in the graph's memory pool gc.collect() @@ -1255,8 +1316,15 @@ def capture( attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - self.output_buffers = {"hidden_states": hidden_states} - return hidden_states + if intermediate_inputs is not None: + self.input_buffers.update(intermediate_inputs.tensors) + if get_pp_group().is_last_rank: + self.output_buffers = { + "hidden_states": hidden_or_intermediate_states + } + else: + self.output_buffers = hidden_or_intermediate_states + return hidden_or_intermediate_states def forward( self, @@ -1264,6 +1332,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. @@ -1280,11 +1349,18 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if intermediate_tensors is not None: + for key in intermediate_tensors.tensors: + self.input_buffers[key].copy_(intermediate_tensors[key], + non_blocking=True) # Run the graph. self.graph.replay() # Return the output tensor. 
- return self.output_buffers["hidden_states"] + if get_pp_group().is_last_rank: + return self.output_buffers["hidden_states"] + else: + return self.output_buffers def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 959cfc0b9ca..9c150a6b5ef 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,7 +5,8 @@ import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -137,6 +138,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int, ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution @@ -150,6 +152,7 @@ def execute_model( self, model_input: T, kv_caches: Optional[List[torch.Tensor]], + intermediate_tensors: Optional[IntermediateTensors], num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: """ diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 2ccf4a50a87..9be80bf7b4a 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -175,6 +176,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int, ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -207,6 +209,7 @@ def execute_model( self, model_input: ModelInputForNeuron, kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 307c107ddef..f7525e049ee 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -80,7 +80,7 @@ def do_metadata_broadcast(self) -> bool: return False @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return None @torch.inference_mode() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index cc27d06b511..5b572829099 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -59,9 +59,9 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." - + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." 
if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -99,9 +99,9 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: CacheEngine + self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[torch.tensor]] = None + self.gpu_cache: Optional[List[List[torch.tensor]]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -217,10 +217,15 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config, - self.device_config) - self.gpu_cache = self.cache_engine.gpu_cache + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.gpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: @@ -234,12 +239,13 @@ def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.gpu_cache @torch.inference_mode() def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. @@ -261,20 +267,24 @@ def prepare_worker_input( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, ) @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine # Issue cache operations. 
if (worker_input.blocks_to_swap_in is not None and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine.swap_in(worker_input.blocks_to_swap_in) + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) if (worker_input.blocks_to_swap_out is not None and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine.swap_out(worker_input.blocks_to_swap_out) + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) if (worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine.copy(worker_input.blocks_to_copy) + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d867e15bdf8..d4d28507693 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,10 +6,11 @@ import torch -from vllm.distributed import broadcast_tensor_dict +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, is_hip, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -124,6 +125,7 @@ class WorkerInput: blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None + virtual_engine: int = 0 @classmethod def from_broadcasted_tensor_dict( @@ -139,6 +141,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict.pop("virtual_engine"), ) def as_broadcastable_tensor_dict( @@ -151,6 +154,7 @@ def as_broadcastable_tensor_dict( "blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_copy": self.blocks_to_copy, + "virtual_engine": self.virtual_engine, } return tensor_dict @@ -181,11 +185,13 @@ def do_metadata_broadcast(self) -> bool: @property @abstractmethod - def kv_cache(self) -> Optional[List[torch.Tensor]]: + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: """ - Get the kv cache to pass to the worker's model runner. Used by the - default `execute_model`. If the worker's model runner does not follow - the ModelRunnerBase interface, then inherit from WorkerBase instead. + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. 
""" raise NotImplementedError @@ -227,7 +233,8 @@ def execute_model( execute_model_req=execute_model_req) model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine)) num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: @@ -255,8 +262,23 @@ def execute_model( if worker_input.num_seq_groups == 0: return [] - return self.model_runner.execute_model(model_input, self.kv_cache, - num_steps) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict()) + + output = self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, intermediate_tensors, + num_steps) + + if not get_pp_group().is_last_rank: + get_pp_group().send_tensor_dict(output.tensors) + return [None] + + # Worker only supports single-step execution. Wrap the output in a + # list to conform to interface. + return output class WorkerWrapperBase: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 99fd7da5edd..5c02b924a69 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( @@ -175,7 +176,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) + model_input = self.prepare_model_input(seqs, 0) self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return @@ -190,6 +191,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: @@ -334,6 +336,7 @@ def execute_model( self, model_input: ModelInputForXPU, kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 773ee9f8159..7a51f2b2c72 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -85,8 +85,8 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
- self.cache_engine: CacheEngine - self.gpu_cache: List[torch.Tensor] + self.cache_engine: List[CacheEngine] + self.gpu_cache: Optional[List[List[torch.Tensor]]] def init_device(self) -> None: if self.device_config.device.type == "xpu" and is_xpu(): From 3c15001910d9bb68729415b2c89b75b2feacf9d5 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Mon, 1 Jul 2024 20:06:33 +0000 Subject: [PATCH 100/110] Make scheduler use pipeline parallel size directly Signed-off-by: Muralidhar Andoorveedu --- tests/core/test_chunked_prefill_scheduler.py | 42 +++++-------------- tests/core/test_scheduler.py | 35 ++++------------ .../output_processor/test_multi_step.py | 8 ++-- vllm/core/scheduler.py | 7 ++-- vllm/engine/llm_engine.py | 3 +- vllm/engine/output_processor/interfaces.py | 2 +- vllm/engine/output_processor/multi_step.py | 2 +- vllm/engine/output_processor/single_step.py | 2 +- 8 files changed, 31 insertions(+), 70 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 7a5477175fa..a3b76327e0a 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -3,7 +3,7 @@ import pytest # noqa -from vllm.config import CacheConfig, ParallelConfig, SchedulerConfig +from vllm.config import CacheConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -40,9 +40,7 @@ def test_simple(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -84,9 +82,7 @@ def test_chunk(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -129,9 +125,7 @@ def test_complex(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -195,9 +189,7 @@ def test_maximal_decoding(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. 
@@ -284,9 +276,7 @@ def test_prompt_limit(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=48) @@ -315,9 +305,7 @@ def test_prompt_limit_exceed(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("2", prompt_length=48) @@ -342,9 +330,7 @@ def test_swap(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler.add_seq_group(seq_group) @@ -395,9 +381,7 @@ def test_running_prefill_prioritized_over_swap(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler.add_seq_group(seq_group) @@ -484,9 +468,7 @@ def test_chunked_prefill_preempt(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) _, seq_group = create_dummy_prompt("1", prompt_length=60) scheduler.add_seq_group(seq_group) @@ -547,9 +529,7 @@ def test_chunked_prefill_max_seqs(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] _, seq_group = create_dummy_prompt("1", prompt_length=65) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index d3d98524152..b9750c67bc5 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -5,8 +5,7 @@ import pytest # noqa -from vllm.config import (CacheConfig, LoRAConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, LoRAConfig, SchedulerConfig) from vllm.core.interfaces import AllocStatus from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget @@ -46,9 +45,7 @@ def test_scheduler_add_seq_group(): cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, 
cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq group to scheduler. num_seq_group = 4 @@ -64,9 +61,7 @@ def test_scheduler_abort_seq_group(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) # Add multiple seq groups to scheduler. num_seq_group = 4 @@ -90,9 +85,7 @@ def test_scheduler_schedule_simple(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. @@ -131,9 +124,7 @@ def test_scheduler_prefill_prioritized(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. _, seq_group_a = create_dummy_prompt("1", 1) @@ -160,9 +151,7 @@ def test_scheduler_schedule_preempt_abort(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. seq_a, seq_group_a = create_dummy_prompt("1", block_size) @@ -213,9 +202,7 @@ def test_scheduler_max_seqs(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) all_seq_groups: List[SequenceGroup] = [] # Add seq groups to scheduler. 
@@ -253,9 +240,7 @@ def test_scheduler_delay_factor(): cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - None) + scheduler = Scheduler(scheduler_config, cache_config, None) # schedule first prompt seq_group_meta, seq_group = create_dummy_prompt("0", @@ -336,9 +321,7 @@ def initialize_scheduler(*, cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - parallel_config = ParallelConfig(1, 1, False) - scheduler = Scheduler(scheduler_config, cache_config, parallel_config, - lora_config) + scheduler = Scheduler(scheduler_config, cache_config, lora_config) return scheduler diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 4f32a622546..88f3fad4c79 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(), stop_checker=stop_checker, @@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(), stop_checker=stop_checker, @@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), stop_checker=stop_checker, @@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, - scheduler=scheduler, + scheduler=[scheduler], seq_counter=seq_counter, get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), stop_checker=stop_checker, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c59e3aa6953..f126476a99d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -256,12 +256,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.parallel_config = parallel_config # Note for LoRA scheduling: the current policy is extremely # simple and NOT fair. It can lead to starvation of some # LoRAs. This should be improved in the future. @@ -278,11 +277,11 @@ def __init__( num_gpu_blocks = cache_config.num_gpu_blocks if num_gpu_blocks: - num_gpu_blocks //= parallel_config.pipeline_parallel_size + num_gpu_blocks //= pipeline_parallel_size num_cpu_blocks = cache_config.num_cpu_blocks if num_cpu_blocks: - num_cpu_blocks //= parallel_config.pipeline_parallel_size + num_cpu_blocks //= pipeline_parallel_size # Create the block space manager. 
self.block_manager = BlockSpaceManagerImpl( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index afbae5b9513..311912a2c4c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -296,8 +296,7 @@ def __init__( # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ - Scheduler(scheduler_config, cache_config, parallel_config, - lora_config) + Scheduler(scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size) for _ in range(parallel_config.pipeline_parallel_size) ] diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 9ddb6a3648b..92aecebe6ec 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC): def create_output_processor( scheduler_config: SchedulerConfig, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: "StopChecker", diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 7fd1faada62..25d15df9f91 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): def __init__( self, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: StopChecker, diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 555652249e1..fa672e1feda 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -33,7 +33,7 @@ def __init__( self, scheduler_config: SchedulerConfig, detokenizer: Detokenizer, - scheduler: Scheduler, + scheduler: List[Scheduler], seq_counter: Counter, stop_checker: StopChecker, ): From 1ff2cdb93082a196b3d828f0e57678392b53c3ea Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Mon, 1 Jul 2024 20:15:07 +0000 Subject: [PATCH 101/110] Change ABC defaults for prepare_model_input Signed-off-by: Muralidhar Andoorveedu --- tests/core/test_scheduler.py | 4 ++-- vllm/core/scheduler.py | 3 +-- vllm/engine/llm_engine.py | 3 ++- vllm/spec_decode/draft_model_runner.py | 7 ++++--- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/embedding_model_runner.py | 2 +- vllm/worker/model_runner.py | 2 +- vllm/worker/model_runner_base.py | 2 +- vllm/worker/neuron_model_runner.py | 2 +- vllm/worker/xpu_model_runner.py | 2 +- 10 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index b9750c67bc5..bae958211cb 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -5,7 +5,7 @@ import pytest # noqa -from vllm.config import (CacheConfig, LoRAConfig, SchedulerConfig) +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget @@ -321,7 +321,7 @@ def initialize_scheduler(*, cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = 
Scheduler(scheduler_config, cache_config, lora_config) + scheduler = Scheduler(scheduler_config, cache_config, lora_config) return scheduler diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f126476a99d..5fb3b78141b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,8 +6,7 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union -from vllm.config import (CacheConfig, LoRAConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 311912a2c4c..5a84e285b59 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -296,7 +296,8 @@ def __init__( # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ - Scheduler(scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size) + Scheduler(scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size) for _ in range(parallel_config.pipeline_parallel_size) ] diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index d91d57fc0da..b4c953162e2 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -75,8 +75,9 @@ def __init__( List[SequenceGroupMetadata]] = None def prepare_model_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int) -> ModelInputForGPUWithSamplingMetadata: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. 
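After this refactor the Scheduler no longer takes a ParallelConfig; it receives pipeline_parallel_size directly so it can split the profiled block counts, and LLMEngine builds one Scheduler per virtual engine. A small sketch of that construction with a stub class; _StubScheduler only mirrors the block division and is not the real vllm.core.scheduler.Scheduler:

    from typing import List

    class _StubScheduler:
        """Stand-in for the Scheduler's new (config-free) block accounting."""
        def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int,
                     pipeline_parallel_size: int = 1) -> None:
            self.num_gpu_blocks = num_gpu_blocks // pipeline_parallel_size
            self.num_cpu_blocks = num_cpu_blocks // pipeline_parallel_size

    pp_size = 2
    # 8192/1024 are made-up profiled totals; each virtual engine gets its share.
    schedulers: List[_StubScheduler] = [
        _StubScheduler(8192, 1024, pp_size) for _ in range(pp_size)
    ]
    assert all(s.num_gpu_blocks == 4096 and s.num_cpu_blocks == 512
               for s in schedulers)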
@@ -108,7 +109,7 @@ def update_model_input( seq.append_token_id(token_id, token_logprob.logprob) seq.update_num_computed_tokens(1) - return self.prepare_model_input(self.cached_seq_group_metadata_list, 0) + return self.prepare_model_input(self.cached_seq_group_metadata_list) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 26042a7dd88..f46e9e8aba9 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -316,7 +316,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int, + virtual_engine: int = 0, ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 964f6d08bf0..faf6e99ab64 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -119,7 +119,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - virtual_engine: int, + virtual_engine: int = 0, ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 594f7fd4907..6891a6c63c9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -821,7 +821,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs, 0) + model_input = self.prepare_model_input(seqs) intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 9c150a6b5ef..f66bb466228 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -138,7 +138,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int, + virtual_engine: int = 0, ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 9be80bf7b4a..ab8e4852812 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -176,7 +176,7 @@ def make_model_input_from_broadcasted_tensor_dict( def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int, + virtual_engine: int = 0, ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 5c02b924a69..73b771c4395 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -176,7 +176,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs, 0) + model_input = self.prepare_model_input(seqs) self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return From 548f4e85f9aa1c29920b8e2d92dddf08be2a4b30 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Mon, 1 Jul 2024 23:14:36 +0000 Subject: [PATCH 102/110] Add basic comm ops tests with TP and PP. Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index bf0f31df02f..7302d484954 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, (r + 1) for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - t = all_tensors[rank] + t = all_tensors[rank % tp_size] t = tensor_model_parallel_all_reduce(t) assert torch.allclose(t, expected) @@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) - t = all_tensors[rank] + t = all_tensors[rank % tp_size] t = tensor_model_parallel_all_gather(t, all_gather_dimension) assert torch.allclose(t, expected) @@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, "f": torch.tensor([], dtype=torch.float32, device="cuda"), } - if rank == 0: + if (rank % tp_size) == 0: broadcast_tensor_dict(test_dict, src=0) else: recv_dict = broadcast_tensor_dict(src=0) @@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target): "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) def test_multi_process_pipeline_parallel(pp_size, test_target): multi_process_parallel(1, pp_size, test_target) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pp_size", [2]) +@pytest.mark.parametrize("test_target", [ + send_recv_test_worker, send_recv_tensor_dict_test_worker, + all_reduce_test_worker, all_gather_test_worker, + broadcast_tensor_dict_test_worker +]) +def test_multi_process_tensor_parallel_pipeline_parallel( + tp_size, pp_size, test_target): + multi_process_parallel(tp_size, pp_size, test_target) From 5a4b323c5372772889fa7918d0539e79bfe8166e Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Mon, 1 Jul 2024 23:24:26 +0000 Subject: [PATCH 103/110] Fix phi3v for test Signed-off-by: Muralidhar Andoorveedu --- vllm/model_executor/models/phi3v.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bc3d3f0fbf1..02fb42f6306 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -36,7 +36,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import ImagePixelData -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision @@ -390,9 +390,13 @@ def _parse_and_validate_image_input( return None - 
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, **kwargs: object): + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: @@ -407,6 +411,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states From c92257c63384b8f36c4d28d8641f0cb9eba5efce Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Tue, 2 Jul 2024 04:52:05 +0000 Subject: [PATCH 104/110] Address Nick nits and fix CUDAGraph correctness Signed-off-by: Muralidhar Andoorveedu --- vllm/model_executor/models/gpt2.py | 8 ++++---- vllm/model_executor/models/llama.py | 8 ++++---- vllm/worker/model_runner.py | 4 ++-- vllm/worker/worker_base.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 81f709e49ea..55f2e27410d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -219,12 +219,12 @@ def forward( kv_caches[i - self.start_layer], attn_metadata) - if get_pp_group().is_last_rank: - hidden_states = self.ln_f(hidden_states) - return hidden_states - else: + if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + class GPT2LMHeadModel(nn.Module): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 15e09c06759..af75b6bee10 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -311,15 +311,15 @@ def forward( residual, ) - if get_pp_group().is_last_rank: - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - else: + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + class LlamaForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6891a6c63c9..69ba6c62071 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1359,8 +1359,8 @@ def forward( # Return the output tensor. 
if get_pp_group().is_last_rank: return self.output_buffers["hidden_states"] - else: - return self.output_buffers + + return self.output_buffers def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d4d28507693..118173a4ca9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -141,7 +141,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict.pop("virtual_engine"), + virtual_engine=tensor_dict["virtual_engine"], ) def as_broadcastable_tensor_dict( From 10d8f3caa4591b5cb6597512d64970af4c54dfc1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 21:46:06 +0300 Subject: [PATCH 105/110] Formatting and fixing llm engine --- vllm/engine/async_llm_engine.py | 5 +++-- vllm/engine/llm_engine.py | 5 +++-- vllm/spec_decode/draft_model_runner.py | 2 +- vllm/worker/model_runner.py | 7 ++----- vllm/worker/xpu_model_runner.py | 8 ++++---- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a159e7e0213..13b4635cb88 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -224,6 +224,8 @@ async def step_async( """ seq_group_metadata_list, scheduler_outputs = self.scheduler[ virtual_engine].schedule() + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): # Execute the model. @@ -235,8 +237,7 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=self.scheduler. - get_and_reset_finished_requests_ids()) + finished_requests_ids=finished_requests_ids) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5a0d872db67..a7428d01010 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -846,6 +846,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") seq_group_metadata_list, scheduler_outputs = self.scheduler[ 0].schedule() + finished_requests_ids = self.scheduler[ + 0].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( @@ -855,8 +857,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=self.scheduler.
- get_and_reset_finished_requests_ids()) + finished_requests_ids=finished_requests_ids) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index d8efef20440..5eb84361543 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -79,7 +79,7 @@ def prepare_model_input( seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: + ) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 917782b6f8d..e663d4596e7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -840,9 +840,7 @@ def profile_run(self) -> None: kv_caches = [None] * num_layers finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( - seqs, - finished_requests_ids= finished_requests_ids - ) + seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( @@ -1086,8 +1084,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: }) graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][batch_size] = ( - graph_runner) + self.graph_runners[virtual_engine][batch_size] = (graph_runner) end_time = time.perf_counter() elapsed_time = end_time - start_time diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 3ee090493e0..e652f1b1042 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -189,10 +189,10 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: From 1331a8fa9f4042b79cd913a9b8d764aeebef929d Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 21:56:08 +0300 Subject: [PATCH 106/110] Align with main and format --- vllm/worker/cache_engine.py | 5 +- vllm/worker/model_runner.py | 107 ++++++++++++++++++------------------ 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index f78948142ca..252440c7b7e 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -109,11 +109,12 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_attention_layers(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block - total = num_layers * (key_cache_block + value_cache_block) + total = num_attention_layers * (key_cache_block + value_cache_block) if cache_config.cache_dtype == "auto": 
dtype = model_config.dtype else: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e663d4596e7..bd30281471d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -95,9 +95,9 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, - "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -131,9 +131,9 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, - "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -1032,59 +1032,60 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: use_cuda_graph=True, ) - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) - self.set_active_loras(set(), lora_mapping) + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model, - self.attn_backend.get_name()) + graph_runner = CUDAGraphRunner( + self.model, self.attn_backend.get_name()) - if self.attn_backend.get_name() == "flashinfer": - graph_runner.flashinfer_indptr_buffer = indptr_buffer - graph_runner.flashinfer_indices_buffer = indices_buffer - graph_runner.flashinfer_last_page_len_buffer = \ - last_page_len_buffer - graph_runner.flashinfer_decode_workspace_buffer = \ - decode_workspace_buffer - graph_runner.flashinfer_decode_wrapper = \ - decode_wrapper - - capture_inputs = { - "input_ids": - input_tokens[:batch_size], - "positions": - input_positions[:batch_size], - "hidden_or_intermediate_states": - hidden_or_intermediate_states[ - virtual_engine] # type: ignore - [:batch_size] - if hidden_or_intermediate_states[virtual_engine] - is not None else None, - "intermediate_inputs": - intermediate_inputs[:batch_size] - if intermediate_inputs is not None else None, - "kv_caches": - kv_caches[virtual_engine], - "attn_metadata": - attn_metadata, - "memory_pool": - self.graph_memory_pool, - "stream": - graph_capture_context.stream - } - if self.has_seqlen_agnostic: - # Only used by Mamba-based models CUDA graph atm (Jamba). 
- capture_inputs.update({ - "seqlen_agnostic_capture_inputs": - self.model.get_seqlen_agnostic_capture_inputs( - batch_size) - }) - graph_runner.capture(**capture_inputs) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][batch_size] = (graph_runner) + if self.attn_backend.get_name() == "flashinfer": + graph_runner.flashinfer_indptr_buffer = indptr_buffer + graph_runner.flashinfer_indices_buffer = indices_buffer + graph_runner.flashinfer_last_page_len_buffer = \ + last_page_len_buffer + graph_runner.flashinfer_decode_workspace_buffer = \ + decode_workspace_buffer + graph_runner.flashinfer_decode_wrapper = \ + decode_wrapper + + capture_inputs = { + "input_ids": + input_tokens[:batch_size], + "positions": + input_positions[:batch_size], + "hidden_or_intermediate_states": + hidden_or_intermediate_states[ + virtual_engine] # type: ignore + [:batch_size] + if hidden_or_intermediate_states[virtual_engine] + is not None else None, + "intermediate_inputs": + intermediate_inputs[:batch_size] + if intermediate_inputs is not None else None, + "kv_caches": + kv_caches[virtual_engine], + "attn_metadata": + attn_metadata, + "memory_pool": + self.graph_memory_pool, + "stream": + graph_capture_context.stream + } + if self.has_seqlen_agnostic: + # Only used by Mamba-based models CUDA graph atm (Jamba) + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) + graph_runner.capture(**capture_inputs) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[virtual_engine][batch_size] = ( + graph_runner) end_time = time.perf_counter() elapsed_time = end_time - start_time From 21c92b4a6310c3382144964bb41ddc08dff4dfa4 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 22:50:45 +0300 Subject: [PATCH 107/110] Fix bug --- vllm/spec_decode/draft_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 5eb84361543..26391d8e4b9 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -85,8 +85,10 @@ def prepare_model_input( TODO: In-place update model_input and remove this function. 
""" self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input(seq_group_metadata_list, - finished_requests_ids) + return super().prepare_model_input( + seq_group_metadata_list, + finished_requests_ids=finished_requests_ids + ) def update_model_input( self, model_input: ModelInputForGPUWithSamplingMetadata, From 726ccad83b842688ec35d1cbe12b6c382b41de44 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 22:52:17 +0300 Subject: [PATCH 108/110] Format --- vllm/spec_decode/draft_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 26391d8e4b9..1c7b8c07e89 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -87,8 +87,7 @@ def prepare_model_input( self.cached_seq_group_metadata_list = seq_group_metadata_list return super().prepare_model_input( seq_group_metadata_list, - finished_requests_ids=finished_requests_ids - ) + finished_requests_ids=finished_requests_ids) def update_model_input( self, model_input: ModelInputForGPUWithSamplingMetadata, From 4b6a4917e551d56567faa64a2a69e41867fb3019 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 23:52:21 +0300 Subject: [PATCH 109/110] Add intermediate tensors --- vllm/model_executor/models/jamba.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 526801095e4..063db0183fa 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -36,6 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE +from vllm.sequence import IntermediateTensors, SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -674,6 +675,7 @@ def __init__( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): if not self.mamba_cache: self._prepare_mamba_cache() From da5d94a099e1b7bf203504d4f13c8a99d9d563f0 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Tue, 2 Jul 2024 23:53:18 +0300 Subject: [PATCH 110/110] Format --- vllm/model_executor/models/jamba.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 063db0183fa..c485d3779d9 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -34,9 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -673,8 +672,11 @@ def __init__( config.vocab_size) self.sampler = Sampler() - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], attn_metadata: AttentionMetadata, + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, **kwargs): if not self.mamba_cache: