-
Notifications
You must be signed in to change notification settings - Fork 496
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
357 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,352 @@ | ||
""" | ||
Copyright 2023-2024 SGLang Team | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
# Adapted from | ||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1 | ||
"""Inference-only OLMo model compatible with HuggingFace weights.""" | ||
from typing import Iterable, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
from transformers import OlmoConfig | ||
from vllm.distributed import get_tensor_model_parallel_world_size | ||
from vllm.model_executor.layers.linear import ( | ||
MergedColumnParallelLinear, | ||
QKVParallelLinear, | ||
RowParallelLinear, | ||
) | ||
from vllm.model_executor.layers.rotary_embedding import get_rope | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
ParallelLMHead, | ||
VocabParallelEmbedding, | ||
) | ||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||
|
||
from sglang.srt.layers.activation import SiluAndMul | ||
from sglang.srt.layers.logits_processor import LogitsProcessor | ||
from sglang.srt.layers.quantization.base_config import QuantizationConfig | ||
from sglang.srt.layers.radix_attention import RadixAttention | ||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch | ||
|
||
|
||
class OlmoAttention(nn.Module): | ||
""" | ||
This is the attention block where the output is computed as | ||
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` | ||
(plus another skip connection). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: OlmoConfig, | ||
layer_id: int = 0, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
): | ||
super().__init__() | ||
self.config = config | ||
self.hidden_size = config.hidden_size | ||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() | ||
self.total_num_heads = config.num_attention_heads | ||
|
||
assert self.hidden_size % self.total_num_heads == 0 | ||
assert self.total_num_heads % tensor_model_parallel_world_size == 0 | ||
|
||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size | ||
self.head_dim = self.hidden_size // self.total_num_heads | ||
self.max_position_embeddings = config.max_position_embeddings | ||
self.rope_theta = config.rope_theta | ||
self.clip_qkv = config.clip_qkv | ||
|
||
# Attention input projection. Projects x -> (q, k, v) | ||
self.qkv_proj = QKVParallelLinear( | ||
self.hidden_size, | ||
self.head_dim, | ||
self.total_num_heads, | ||
bias=config.attention_bias, | ||
) | ||
|
||
# Rotary embeddings. | ||
self.rotary_emb = get_rope( | ||
self.head_dim, | ||
rotary_dim=self.head_dim, | ||
max_position=self.max_position_embeddings, | ||
base=self.rope_theta, | ||
) | ||
self.scaling = self.head_dim**-0.5 | ||
self.attn = RadixAttention( | ||
self.num_heads, | ||
self.head_dim, | ||
self.scaling, | ||
num_kv_heads=self.num_heads, | ||
layer_id=layer_id, | ||
) | ||
|
||
# Attention output projection. | ||
self.o_proj = RowParallelLinear( | ||
self.hidden_size, | ||
self.hidden_size, | ||
bias=config.attention_bias, | ||
) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
forward_batch: ForwardBatch, | ||
) -> torch.Tensor: | ||
qkv, _ = self.qkv_proj(hidden_states) | ||
if self.clip_qkv is not None: | ||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) | ||
q, k, v = qkv.chunk(chunks=3, dim=-1) | ||
q, k = self.rotary_emb(positions, q, k) | ||
attn_output = self.attn(q, k, v, forward_batch) | ||
output, _ = self.o_proj(attn_output) | ||
return output | ||
|
||
|
||
class OlmoMLP(nn.Module): | ||
""" | ||
This is the MLP block where the output is computed as | ||
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` | ||
(plus another skip connection). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: OlmoConfig, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
): | ||
super().__init__() | ||
self.config = config | ||
self.hidden_size = config.hidden_size | ||
self.intermediate_size = config.intermediate_size | ||
|
||
# Feed-forward input projection. | ||
self.gate_up_proj = MergedColumnParallelLinear( | ||
self.hidden_size, | ||
[self.intermediate_size] * 2, | ||
bias=False, | ||
quant_config=quant_config, | ||
) | ||
|
||
# Activation function. | ||
self.act_fn = SiluAndMul() | ||
|
||
# Feed-forward output projection. | ||
self.down_proj = RowParallelLinear( | ||
self.intermediate_size, | ||
self.hidden_size, | ||
bias=False, | ||
quant_config=quant_config, | ||
) | ||
|
||
def forward( | ||
self, | ||
x: torch.Tensor, | ||
) -> torch.Tensor: | ||
gate_up, _ = self.gate_up_proj(x) | ||
x = self.act_fn(gate_up) | ||
x, _ = self.down_proj(x) | ||
return x | ||
|
||
|
||
class OlmoDecoderLayer(nn.Module): | ||
""" | ||
This is a typical transformer block where the output is | ||
computed as ``MLP(LN(x + Attention(LN(x))))`` | ||
(plus another skip connection). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: OlmoConfig, | ||
layer_id: int = 0, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
): | ||
super().__init__() | ||
# Attention block. | ||
self.self_attn = OlmoAttention(config, layer_id, quant_config) | ||
|
||
# MLP block. | ||
self.mlp = OlmoMLP(config, quant_config) | ||
|
||
# LayerNorm | ||
self.input_layernorm = nn.LayerNorm( | ||
config.hidden_size, elementwise_affine=False, bias=False | ||
) | ||
self.post_attention_layernorm = nn.LayerNorm( | ||
config.hidden_size, elementwise_affine=False, bias=False | ||
) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
forward_batch: ForwardBatch, | ||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: | ||
# Attention block. | ||
residual = hidden_states | ||
hidden_states = self.input_layernorm(hidden_states) | ||
hidden_states = self.self_attn(positions, hidden_states, forward_batch) | ||
hidden_states = hidden_states + residual | ||
|
||
# MLP block. | ||
residual = hidden_states | ||
hidden_states = self.post_attention_layernorm(hidden_states) | ||
hidden_states = self.mlp(hidden_states) | ||
hidden_states = residual + hidden_states | ||
return hidden_states | ||
|
||
|
||
class OlmoModel(nn.Module): | ||
|
||
def __init__( | ||
self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None | ||
): | ||
super().__init__() | ||
self.config = config | ||
|
||
self.embed_tokens = VocabParallelEmbedding( | ||
config.vocab_size, config.hidden_size | ||
) | ||
self.layers = nn.ModuleList( | ||
[ | ||
OlmoDecoderLayer(config, layer_idx, quant_config) | ||
for layer_idx in range(config.num_hidden_layers) | ||
] | ||
) | ||
self.norm = nn.LayerNorm( | ||
config.hidden_size, elementwise_affine=False, bias=False | ||
) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
forward_batch: ForwardBatch, | ||
input_embeds: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
""" | ||
:param input_ids: A tensor of shape `(batch_size, seq_len)`. | ||
""" | ||
# Get embeddings of input. | ||
# shape: (batch_size, seq_len, d_model) | ||
|
||
if input_embeds is None: | ||
hidden_states = self.embed_tokens(input_ids) | ||
else: | ||
hidden_states = input_embeds | ||
|
||
# Apply blocks one-by-one. | ||
for layer_idx, decoder_layer in enumerate(self.layers): | ||
# shape: (batch_size, seq_len, d_model) | ||
hidden_states = decoder_layer( | ||
positions, | ||
hidden_states, | ||
forward_batch, | ||
) | ||
|
||
# Apply final layer norm. | ||
# shape: (batch_size, seq_len or 1, d_model) | ||
hidden_states = self.norm(hidden_states) | ||
return hidden_states | ||
|
||
|
||
class OlmoForCausalLM(nn.Module): | ||
""" | ||
Extremely barebones HF model wrapper. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: OlmoConfig, | ||
cache_config=None, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
): | ||
super().__init__() | ||
self.config = config | ||
self.model = OlmoModel(config, quant_config) | ||
if config.tie_word_embeddings: | ||
self.lm_head = self.model.embed_tokens | ||
else: | ||
self.unpadded_vocab_size = config.vocab_size | ||
self.lm_head = ParallelLMHead( | ||
self.unpadded_vocab_size, | ||
config.hidden_size, | ||
org_num_embeddings=config.vocab_size, | ||
quant_config=quant_config, | ||
) | ||
self.logits_processor = LogitsProcessor(config) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
forward_batch: ForwardBatch, | ||
input_embeds: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
hidden_states = self.model( | ||
input_ids=input_ids, | ||
positions=positions, | ||
forward_batch=forward_batch, | ||
input_embeds=input_embeds, | ||
) | ||
return self.logits_processor( | ||
input_ids, hidden_states, self.lm_head.weight, forward_batch | ||
) | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): | ||
stacked_params_mapping = [ | ||
# (param_name, shard_name, shard_id) | ||
("qkv_proj", "q_proj", "q"), | ||
("qkv_proj", "k_proj", "k"), | ||
("qkv_proj", "v_proj", "v"), | ||
("gate_up_proj", "gate_proj", 0), | ||
("gate_up_proj", "up_proj", 1), | ||
] | ||
params_dict = dict(self.named_parameters(remove_duplicate=False)) | ||
for name, loaded_weight in weights: | ||
if "rotary_emb.inv_freq" in name: | ||
continue | ||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: | ||
# Models trained using ColossalAI may include these tensors in | ||
# the checkpoint. Skip them. | ||
continue | ||
# With tie_word_embeddings, we can skip lm_head.weight | ||
# The weight might appear unnecessarily in the files if the model is | ||
# processed with quantization, LoRA, fine-tuning, etc. | ||
if self.config.tie_word_embeddings and "lm_head.weight" in name: | ||
continue | ||
for param_name, weight_name, shard_id in stacked_params_mapping: | ||
if weight_name not in name: | ||
continue | ||
name = name.replace(weight_name, param_name) | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
param = params_dict[name] | ||
weight_loader = param.weight_loader | ||
weight_loader(param, loaded_weight, shard_id) | ||
break | ||
else: | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
param = params_dict[name] | ||
weight_loader = getattr(param, "weight_loader", default_weight_loader) | ||
weight_loader(param, loaded_weight) | ||
|
||
|
||
EntryClass = OlmoForCausalLM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters