From 7a665edf55dd30d43fab2c02fe07a7dc1c8de131 Mon Sep 17 00:00:00 2001
From: HIT-cwh <2892770585@qq.com>
Date: Wed, 19 Jun 2024 17:09:06 +0800
Subject: [PATCH] support old version internlm2

---
 xtuner/model/modules/dispatch/__init__.py  |   2 +
 xtuner/model/modules/dispatch/internlm2.py | 278 ++++++++++++++++++++-
 2 files changed, 272 insertions(+), 8 deletions(-)

diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 7cb159515..c2033c3c8 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -125,6 +125,8 @@
 )

 ROTE_DISPATCH_MAPPING = dict(
+    InternLM2RotaryEmbedding=LazyObject(
+        'xtuner.model.modules.dispatch.internlm2', 'InternLM2RotaryEmbedding'),
     InternLMRotaryEmbedding=LazyObject(
         'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
     MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index 5b855d4ab..34ea53e22 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -1,8 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from inspect import signature
 from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange
 from mmengine import MessageHub
 from transformers.cache_utils import Cache, StaticCache
@@ -10,7 +14,54 @@
 from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                       post_process_for_sequence_parallel_attn,
                                       pre_process_for_sequence_parallel_attn)
-from .attention import SUPPORT_FLASH2, flash_attn_wo_mask, varlen_flash_attn
+from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask,
+                        varlen_flash_attn)
+from .triton_kernels import apply_rotary_emb
+
+
+# Copied from https://huggingface.co/internlm/internlm2-20b/blob/fa45716009471c75016da0ba85308cff1afd030a/modeling_internlm2.py#L97  # noqa: E501
+class InternLM2RotaryEmbedding(nn.Module):
+    """Rotary Position Embedding for the InternLM2 model.
+
+    Credits to the Reddit user /u/lucidrains.
+    """
+
+    def __init__(self,
+                 dim,
+                 max_position_embeddings=2048,
+                 base=10000,
+                 device=None,
+                 scaling_factor=1.0):
+        super().__init__()
+        self.scaling_factor = scaling_factor
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            base**(torch.arange(0, dim, 2,
+                                dtype=torch.int64).float().to(device) / dim))
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+        # For BC we register cos and sin cached
+        self.max_seq_len_cached = max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
+            position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(
+            device_type, str) and device_type != 'mps' else 'cpu'
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float()
+                     @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 def rotate_half(x):
@@ -60,6 +111,117 @@ def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
                                  head_dim)


+def _is_legacy(rote):
+    params = signature(rote.forward).parameters
+    return 'seq_len' in params
+
+
+def _internlm2_attn_forward_legacy(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+):
+    if 'padding_mask' in kwargs:
+        warnings.warn(
+            'Passing `padding_mask` is deprecated and will be removed in v4.37'
+            'Please make sure use `attention_mask` instead.`')
+
+        # overwrite attention_mask with padding_mask
+        attention_mask = kwargs.pop('padding_mask')
+
+    output_attentions = False
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv_states = self.wqkv(hidden_states)
+
+    qkv_states = rearrange(
+        qkv_states,
+        'b q (h gs d) -> b q h gs d',
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
+    )
+
+    query_states = qkv_states[..., :self.num_key_value_groups, :]
+    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat kv for sequence parallel
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if SUPPORT_FLASH2:
+        # the shape of attention_mask used by flash_attn and
+        # F.scaled_dot_product_attention are different
+        assert attention_mask is None or attention_mask.ndim == 2, \
+            ('When using flash_attn, attention_mask.ndim should equal to 2.'
+             f'But got attention_mask.shape = {attention_mask.shape}.'
+             'We can pass the `attn_implementation="flash_attention_2"` flag '
+             'to `.from_pretrained` method when instantiating a Internlm2 '
+             'model.')
+        # flash attn 2 need (bs, seq_len, nhead, h_dim)
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        causal = self.is_causal and q_len != 1
+
+        if attention_mask is not None:
+            attn_output = flash_attn_w_mask(
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                causal=causal,
+                training=self.training)
+        else:
+            attn_output = flash_attn_wo_mask(
+                query_states,
+                key_states,
+                value_states,
+                causal=causal,
+                training=self.training)
+    else:
+        # use flash attention implemented by pytorch
+        # do not support sequence parallel
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask)
+        attn_output = attn_output.transpose(1, 2)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.wo(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
 def internlm2_attn_forward(
     self,
     hidden_states: torch.Tensor,
@@ -70,6 +232,12 @@ def internlm2_attn_forward(
     use_cache: bool = False,
     cache_position: Optional[torch.LongTensor] = None,
 ):
+    if _is_legacy(self.rotary_emb):
+        return _internlm2_attn_forward_legacy(self, hidden_states,
+                                              attention_mask, position_ids,
+                                              past_key_value,
+                                              output_attentions, use_cache)
+
     if isinstance(past_key_value, StaticCache):
         raise ValueError(
             '`static` cache implementation is not compatible with '
@@ -171,6 +339,101 @@ def internlm2_attn_forward(
     return attn_output, attn_weights, past_key_value


+def _internlm2_varlen_attn_forward_legacy(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+           Optional[Tuple[torch.Tensor]]]:
+
+    message_hub = MessageHub.get_instance('varlen_attn_args')
+    rank = dist.get_rank()
+    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
+    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
+    use_varlen_atten = (cumulative_len is not None)
+
+    bsz, q_len, _ = hidden_states.size()
+
+    assert bsz == 1, (f'If utilizing local attention, the batch size should be'
+                      f' set to 1, but got {bsz}')
+
+    qkv_states = self.wqkv(hidden_states)
+    qkv_states = rearrange(
+        qkv_states,
+        'b q (h gs d) -> b q h gs d',
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
+    )
+
+    query_states = qkv_states[..., :self.num_key_value_groups, :]
+    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    kv_seq_len = key_states.shape[-3]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    if use_varlen_atten:
+        # Adapt to the new version of rote
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states = apply_rotary_emb(query_states, cos.squeeze(0),
+                                        sin.squeeze(0))
+        key_states = apply_rotary_emb(key_states, cos.squeeze(0),
+                                      sin.squeeze(0))
+    else:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+    # repeat kv for sequence parallel
+    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
+    assert SUPPORT_FLASH2
+    if use_varlen_atten:
+        attn_output = varlen_flash_attn(
+            query_states,
+            key_states,
+            value_states,
+            cumulative_len,
+            max_seqlen,
+            training=self.training)
+    else:
+        attn_output = flash_attn_wo_mask(
+            query_states,
+            key_states,
+            value_states,
+            causal=True,
+            training=False)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.wo(attn_output)
+
+    # Due to the implementation of the PyTorch version of flash attention,
+    # even when the output_attentions flag is set to True, it is not possible
+    # to return the attn_weights.
+    return attn_output, None, past_key_value
+
+
 def internlm2_varlen_attn_forward(
     self,
     hidden_states: torch.Tensor,
@@ -183,6 +446,11 @@ def internlm2_varlen_attn_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
            Optional[Tuple[torch.Tensor]]]:

+    if _is_legacy(self.rotary_emb):
+        return _internlm2_varlen_attn_forward_legacy(
+            self, hidden_states, attention_mask, position_ids, past_key_value,
+            output_attentions, use_cache)
+
     if isinstance(past_key_value, StaticCache):
         raise ValueError(
             '`static` cache implementation is not compatible with '
@@ -219,13 +487,7 @@ def internlm2_varlen_attn_forward(
     key_states = key_states.transpose(1, 2)
     value_states = value_states.transpose(1, 2)

-    try:
-        cos, sin = self.rotary_emb(value_states, position_ids)
-    except RuntimeError:
-        raise RuntimeError(
-            'You are using the old version of InternLM2 model. The '
-            '`modeling_internlm2.py` is outdated. Please update the InternLM2 '
-            'model.')
+    cos, sin = self.rotary_emb(value_states, position_ids)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                     cos, sin)
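
Note for reviewers (not part of the patch): the dispatch above hinges on `_is_legacy`, which tells old and new `modeling_internlm2.py` files apart by inspecting the rotary embedding's `forward` signature -- the old rotary embedding takes an explicit `seq_len` argument, while the new one takes `position_ids`. Below is a minimal, self-contained sketch of that detection idea, assuming only standard PyTorch; the two stand-in classes and the `is_legacy` helper name are illustrative and do not exist in xtuner or transformers.

from inspect import signature

import torch
import torch.nn as nn


class OldStyleRotaryEmbedding(nn.Module):
    """Stand-in mimicking the legacy API: forward(x, seq_len=None)."""

    def forward(self, x, seq_len=None):
        # Real code would return cos/sin caches; values are irrelevant here.
        return torch.zeros(1), torch.zeros(1)


class NewStyleRotaryEmbedding(nn.Module):
    """Stand-in mimicking the current API: forward(x, position_ids)."""

    def forward(self, x, position_ids):
        return torch.zeros(1), torch.zeros(1)


def is_legacy(rote: nn.Module) -> bool:
    # Same check as `_is_legacy` in the patch: an explicit `seq_len`
    # parameter only appears in the old rotary embedding's forward,
    # so its presence identifies an outdated modeling file.
    return 'seq_len' in signature(rote.forward).parameters


if __name__ == '__main__':
    assert is_legacy(OldStyleRotaryEmbedding())
    assert not is_legacy(NewStyleRotaryEmbedding())
    print('old-version detection via signature inspection works')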