
Commit

support old version internlm2
HIT-cwh committed Jun 19, 2024
1 parent 5f2bca4 commit 7a665ed
Showing 2 changed files with 272 additions and 8 deletions.
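In short, the patch keeps a single attention forward per variant but detects at runtime whether the model's rotary embedding still follows the old InternLM2 interface, and routes to a dedicated legacy implementation if so. Below is a minimal, standalone sketch of that version check, with dummy classes used purely for illustration (they are not xtuner's classes):

from inspect import signature


class OldStyleRotary:
    # pre-refactor InternLM2 rotary embeddings expose a `seq_len` parameter
    def forward(self, x, seq_len=None):
        ...


class NewStyleRotary:
    # the current interface takes `position_ids` instead
    def forward(self, x, position_ids):
        ...


def is_legacy(rotary_emb) -> bool:
    # the presence of `seq_len` in the forward signature marks the old API
    return 'seq_len' in signature(rotary_emb.forward).parameters


assert is_legacy(OldStyleRotary())
assert not is_legacy(NewStyleRotary())
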
2 changes: 2 additions & 0 deletions xtuner/model/modules/dispatch/__init__.py
@@ -125,6 +125,8 @@
)

ROTE_DISPATCH_MAPPING = dict(
InternLM2RotaryEmbedding=LazyObject(
'xtuner.model.modules.dispatch.internlm2', 'InternLM2RotaryEmbedding'),
InternLMRotaryEmbedding=LazyObject(
'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
278 changes: 270 additions & 8 deletions xtuner/model/modules/dispatch/internlm2.py
@@ -1,16 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from mmengine import MessageHub
from transformers.cache_utils import Cache, StaticCache

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn)
from .attention import SUPPORT_FLASH2, flash_attn_wo_mask, varlen_flash_attn
from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask,
varlen_flash_attn)
from .triton_kernels import apply_rotary_emb


# Copied from https://huggingface.co/internlm/internlm2-20b/blob/fa45716009471c75016da0ba85308cff1afd030a/modeling_internlm2.py#L97 # noqa: E501
class InternLM2RotaryEmbedding(nn.Module):
"""Rotary Position Embedding for the InternLM2 model.
Credits to the Reddit user /u/lucidrains.
"""

def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
base**(torch.arange(0, dim, 2,
dtype=torch.int64).float().to(device) / dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings

@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(
device_type, str) and device_type != 'mps' else 'cpu'
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float()
@ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
@@ -60,6 +111,117 @@ def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
head_dim)


def _is_legacy(rote):
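    """Return True if `rote` is a pre-refactor InternLM2 rotary embedding.

    The old `modeling_internlm2.py` exposes a `seq_len` parameter on
    `InternLM2RotaryEmbedding.forward`, whereas the current version takes
    `position_ids` instead.
    """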
params = signature(rote.forward).parameters
return 'seq_len' in params


def _internlm2_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
if 'padding_mask' in kwargs:
warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure to use `attention_mask` instead.')

# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop('padding_mask')

output_attentions = False

bsz, q_len, _ = hidden_states.size()

qkv_states = self.wqkv(hidden_states)
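    # wqkv packs the query, key and value projections into one tensor: each
    # group of (2 + num_key_value_groups) heads holds the query heads that
    # share a single key head and value head.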

qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)

query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat kv for sequence parallel
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if SUPPORT_FLASH2:
        # the attention_mask shapes expected by flash_attn and
        # F.scaled_dot_product_attention are different
assert attention_mask is None or attention_mask.ndim == 2, \
            ('When using flash_attn, attention_mask.ndim should be 2, but '
             f'got attention_mask.shape = {attention_mask.shape}. You can '
             'pass the `attn_implementation="flash_attention_2"` flag to '
             'the `.from_pretrained` method when instantiating an InternLM2 '
             'model.')
# flash attn 2 need (bs, seq_len, nhead, h_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

causal = self.is_causal and q_len != 1

if attention_mask is not None:
attn_output = flash_attn_w_mask(
query_states,
key_states,
value_states,
attention_mask,
causal=causal,
training=self.training)
else:
attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
causal=causal,
training=self.training)
else:
        # use the flash attention implemented by PyTorch
        # (this path does not support sequence parallel)
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


def internlm2_attn_forward(
self,
hidden_states: torch.Tensor,
@@ -70,6 +232,12 @@ def internlm2_attn_forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
if _is_legacy(self.rotary_emb):
return _internlm2_attn_forward_legacy(self, hidden_states,
attention_mask, position_ids,
past_key_value,
output_attentions, use_cache)

if isinstance(past_key_value, StaticCache):
raise ValueError(
'`static` cache implementation is not compatible with '
@@ -171,6 +339,101 @@ def internlm2_attn_forward(
return attn_output, attn_weights, past_key_value


def _internlm2_varlen_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:

message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
use_varlen_atten = (cumulative_len is not None)

bsz, q_len, _ = hidden_states.size()

assert bsz == 1, (f'If utilizing local attention, the batch size should be'
f' set to 1, but got {bsz}')

qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)

query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]

kv_seq_len = key_states.shape[-3]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

if use_varlen_atten:
# Adapt to the new version of rote
cos, sin = self.rotary_emb(value_states, position_ids)
query_states = apply_rotary_emb(query_states, cos.squeeze(0),
sin.squeeze(0))
key_states = apply_rotary_emb(key_states, cos.squeeze(0),
sin.squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# repeat kv for sequence parallel
key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

assert SUPPORT_FLASH2
if use_varlen_atten:
attn_output = varlen_flash_attn(
query_states,
key_states,
value_states,
cumulative_len,
max_seqlen,
training=self.training)
else:
attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
causal=True,
training=False)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.wo(attn_output)

# Due to the implementation of the PyTorch version of flash attention,
# even when the output_attentions flag is set to True, it is not possible
# to return the attn_weights.
return attn_output, None, past_key_value


def internlm2_varlen_attn_forward(
self,
hidden_states: torch.Tensor,
@@ -183,6 +446,11 @@ def internlm2_varlen_attn_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:

if _is_legacy(self.rotary_emb):
return _internlm2_varlen_attn_forward_legacy(
self, hidden_states, attention_mask, position_ids, past_key_value,
output_attentions, use_cache)

if isinstance(past_key_value, StaticCache):
raise ValueError(
'`static` cache implementation is not compatible with '
@@ -219,13 +487,7 @@ def internlm2_varlen_attn_forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

try:
cos, sin = self.rotary_emb(value_states, position_ids)
except RuntimeError:
raise RuntimeError(
'You are using the old version of InternLM2 model. The '
'`modeling_internlm2.py` is outdated. Please update the InternLM2 '
'model.')
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin)

