diff --git a/README.md b/README.md
index fce3de3b70430..e59a1c60cc369 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
-- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
+- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index f4dd5d52ad873..ceb658bbd5c66 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -101,7 +101,7 @@ Alongside each architecture, we include some popular models that use it.
     -
   * - :code:`OLMoForCausalLM`
     - OLMo
-    - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
+    - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
     -
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1317e51b2dd11..d9816828d007d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -26,7 +26,6 @@ requests
 ray
 peft
 awscli
-ai2-olmo # required for OLMo

 # Benchmarking
 aiohttp
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index c0aeab5dd3032..6afb2f31c1334 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -42,7 +42,7 @@
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
-    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
+    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index b92003bc0e067..15527569b9e20 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -1,53 +1,36 @@
 # coding=utf-8
 # Adapted from
-# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
-# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
-# Copyright 2023 The vLLM team.
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
+# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
+# Copyright 2024 The vLLM team.
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
-# BSD 3-Clause License
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 #
-# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
-# All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-#   contributors may be used to endorse or promote products derived from
-#   this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple import torch -# this model must need this dependency -from hf_olmo import OLMoConfig from torch import nn +from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -55,7 +38,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -70,55 +53,52 @@ class OlmoAttention(nn.Module): def __init__( self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.hidden_size = config.d_model - assert config.d_model % config.n_heads == 0 + self.hidden_size = config.hidden_size tensor_model_parallel_world_size = ( get_tensor_model_parallel_world_size()) - self.total_num_heads = self.config.n_heads + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv - # Layer norms. - self.attn_norm = nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False) # Attention input projection. Projects x -> (q, k, v) - self.att_proj = QKVParallelLinear( - config.d_model, + self.qkv_proj = QKVParallelLinear( + self.hidden_size, self.head_dim, self.total_num_heads, - bias=config.include_bias, + bias=config.attention_bias, linear_method=linear_method, ) # Rotary embeddings. - if self.config.rope: - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling) # Attention output projection. 
-        self.attn_out = RowParallelLinear(
-            config.d_model,
-            config.d_model,
-            bias=config.include_bias,
+        self.o_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=config.attention_bias,
             linear_method=linear_method,
         )
 
@@ -129,13 +109,13 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.attn_norm(hidden_states)
-        qkv, _ = self.att_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        if self.config.rope:
-            q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
-        output, _ = self.attn_out(attn_output)
+        output, _ = self.o_proj(attn_output)
         return output
 
 
@@ -148,37 +128,30 @@ class OlmoMLP(nn.Module):
 
     def __init__(
         self,
-        config: OLMoConfig,
+        config: OlmoConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
-        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
-                            is not None else config.mlp_ratio * config.d_model)
-
-        # Layer norms.
-        self.ff_norm = nn.LayerNorm(config.d_model,
-                                    elementwise_affine=False,
-                                    bias=False)
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
 
         # Feed-forward input projection.
-        self.ff_proj = MergedColumnParallelLinear(
-            config.d_model,
-            [self.hidden_size // 2] * 2,
-            bias=config.include_bias,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
             linear_method=linear_method,
         )
 
         # Activation function.
-        self.act = SiluAndMul()
-        self.act.output_multiplier = 0.5
-        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+        self.act_fn = SiluAndMul()
 
         # Feed-forward output projection.
-        self.ff_out = RowParallelLinear(
-            int(self.act.output_multiplier * self.hidden_size),
-            config.d_model,
-            bias=config.include_bias,
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
             linear_method=linear_method,
         )
 
@@ -186,19 +159,13 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        # Add feed-forward projection.
-        # shape: (batch_size, seq_len, d_model)
-        og_x = x
-        x = self.ff_norm(x)
-        x, _ = self.ff_proj(x)
-        x = self.act(x)
-        x, _ = self.ff_out(x)
-        x = og_x + x
-
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
         return x
 
 
-class OlmoBlock(nn.Module):
+class OlmoDecoderLayer(nn.Module):
     """
     This is a typical transformer block where the output is computed as
     ``MLP(LN(x + Attention(LN(x))))``
@@ -206,15 +173,23 @@ class OlmoBlock(nn.Module):
     """
 
     def __init__(self,
-                 config: OLMoConfig,
+                 config: OlmoConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         # Attention block.
-        self.attn = OlmoAttention(config, linear_method)
+        self.self_attn = OlmoAttention(config, linear_method)
 
         # MLP block.
         self.mlp = OlmoMLP(config, linear_method)
 
+        # LayerNorm
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            elementwise_affine=False,
+                                            bias=False)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     elementwise_affine=False,
+                                                     bias=False)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -223,52 +198,37 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Attention block.
-        og_x = hidden_states
-        x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
-        x = x + og_x
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
+                                       attn_metadata)
+        hidden_states = hidden_states + residual
 
         # MLP block.
-        hidden_states = self.mlp(x)
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
         return hidden_states
 
 
 class OlmoModel(nn.Module):
 
     def __init__(self,
-                 config: OLMoConfig,
+                 config: OlmoConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
-        self.transformer = nn.ModuleDict(
-            dict(
-                wte=VocabParallelEmbedding(
-                    config.embedding_size or config.vocab_size,
-                    config.d_model,
-                ),
-                ln_f=nn.LayerNorm(config.d_model,
-                                  elementwise_affine=False,
-                                  bias=False),
-            ))
-
-        blocks = [
-            OlmoBlock(config, linear_method) for i in range(config.n_layers)
-        ]
-        if self.config.block_group_size > 1:
-            raise NotImplementedError("Block group size > 1 not supported yet")
-        else:
-            self.transformer.update({"blocks": nn.ModuleList(blocks)})
-
-        if not config.weight_tying:
-            self.transformer.update({
-                "ff_out":
-                ColumnParallelLinear(
-                    config.d_model,
-                    config.embedding_size or config.vocab_size,
-                    bias=config.include_bias,
-                    linear_method=linear_method,
-                )
-            })
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
+            OlmoDecoderLayer(config, linear_method)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size,
+                                 elementwise_affine=False,
+                                 bias=False)
 
     def forward(
         self,
@@ -282,39 +242,49 @@ def forward(
         """
         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
 
         # Apply blocks one-by-one.
-        for block_idx, block in enumerate(self.transformer.blocks):
+        for layer_idx, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
-            x = block(
+            hidden_states = decoder_layer(
                 positions,
-                x,
-                kv_caches[block_idx],
+                hidden_states,
+                kv_caches[layer_idx],
                 attn_metadata,
             )
 
         # Apply final layer norm.
         # shape: (batch_size, seq_len or 1, d_model)
-        x = self.transformer.ln_f(x)  # type: ignore
-        return x
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
 
 
-class OLMoForCausalLM(nn.Module):
+class OlmoForCausalLM(nn.Module):
     """
     Extremely barebones HF model wrapper.
""" def __init__(self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.linear_method = linear_method self.model = OlmoModel(config, linear_method) - self.lm_head_weight = (self.model.transformer.wte.weight - if config.weight_tying else - self.model.transformer.ff_out.weight) + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -348,20 +318,39 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - # attention - if ".att" in name: - name = name.replace(".att", ".attn.att") - # mlp - if ".ff_proj" in name: - name = name.replace(".ff_proj", ".mlp.ff_proj") - # Reverse the weight for the MergeColumnParallelLinear - loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) - if ".ff_out" in name and "transformer.ff_out" not in name: - name = name.replace(".ff_out", ".mlp.ff_out") - # there is no bias in olmo - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)