diff --git a/xtuner/configs/yi/yi_34b/yi_34b_qlora_alpaca_enzh_e3.py b/xtuner/configs/yi/yi_34b/yi_34b_qlora_alpaca_enzh_e3.py new file mode 100644 index 000000000..6de89fd66 --- /dev/null +++ b/xtuner/configs/yi/yi_34b/yi_34b_qlora_alpaca_enzh_e3.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from bitsandbytes.optim import PagedAdamW32bit +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR +from peft import LoraConfig +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset import ConcatDataset, process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn, + template_map_fn_factory) +from xtuner.engine import DatasetInfoHook, EvaluateChatHook +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = '01-ai/Yi-34B' + +# Data +alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese' +alpaca_en_path = 'tatsu-lab/alpaca' +prompt_template = PROMPT_TEMPLATE.default +max_length = 2048 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 16 +dataloader_num_workers = 0 +max_epochs = 3 +optim_type = PagedAdamW32bit +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = SYSTEM_TEMPLATE.alpaca +evaluation_inputs = [ + '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +alpaca_en = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path=alpaca_en_path), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=alpaca_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +alpaca_zh = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path=alpaca_zh_path), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=alpaca_zh_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +train_dataset = dict( + type=ConcatDataset, + datasets_cfg=dict(alpaca_en=alpaca_en, alpaca_zh=alpaca_zh)) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = dict( + type=CosineAnnealingLR, + eta_min=lr * 0.1, + by_epoch=True, + T_max=max_epochs, + convert_to_iter_based=True) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 100 iterations. + logger=dict(type=LoggerHook, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) diff --git a/xtuner/configs/yi/yi_6b/yi_6b_qlora_alpaca_enzh_e3.py b/xtuner/configs/yi/yi_6b/yi_6b_qlora_alpaca_enzh_e3.py new file mode 100644 index 000000000..5148a71e2 --- /dev/null +++ b/xtuner/configs/yi/yi_6b/yi_6b_qlora_alpaca_enzh_e3.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from bitsandbytes.optim import PagedAdamW32bit +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR +from peft import LoraConfig +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset import ConcatDataset, process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn, + template_map_fn_factory) +from xtuner.engine import DatasetInfoHook, EvaluateChatHook +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = '01-ai/Yi-6B' + +# Data +alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese' +alpaca_en_path = 'tatsu-lab/alpaca' +prompt_template = PROMPT_TEMPLATE.default +max_length = 2048 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 16 +dataloader_num_workers = 0 +max_epochs = 3 +optim_type = PagedAdamW32bit +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = SYSTEM_TEMPLATE.alpaca +evaluation_inputs = [ + '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +alpaca_en = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path=alpaca_en_path), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=alpaca_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +alpaca_zh = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path=alpaca_zh_path), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=alpaca_zh_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +train_dataset = dict( + type=ConcatDataset, + datasets_cfg=dict(alpaca_en=alpaca_en, alpaca_zh=alpaca_zh)) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = dict( + type=CosineAnnealingLR, + eta_min=lr * 0.1, + by_epoch=True, + T_max=max_epochs, + convert_to_iter_based=True) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 100 iterations. + logger=dict(type=LoggerHook, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) diff --git a/xtuner/model/modules/dispatch.py b/xtuner/model/modules/dispatch.py index 67111872a..181ead803 100644 --- a/xtuner/model/modules/dispatch.py +++ b/xtuner/model/modules/dispatch.py @@ -10,6 +10,7 @@ baichuan_13b_attn_forward) from .internlm import internlm_attn_forward from .llama import llama_attn_forward +from .yi import yi_attn_forward NO_ATTN_WEIGHTS_MSG = ( 'Due to the implementation of the PyTorch version of flash attention, ' @@ -70,6 +71,17 @@ def dispath_baichuan_13b_attn_forward(model): module) +def dispatch_yi_attn_forward(model): + if digit_version(torch.__version__) < digit_version('2.0.0'): + # flash attention is only supported after pytorch2.0 + return + print_log('dispatch yi attn forward', 'current') + print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING) + for module in model.modules(): + if type(module).__name__ == 'YiAttention': + module.forward = types.MethodType(yi_attn_forward, module) + + def dispatch_modules(model): model_name = model.__class__.__name__.lower() if 'internlm' in model_name: @@ -80,3 +92,5 @@ def dispatch_modules(model): dispath_baichuan2_norm_head_forward(model) dispath_baichuan_7b_attn_forward(model) dispath_baichuan_13b_attn_forward(model) + if 'yi' in model_name: + dispatch_yi_attn_forward(model) diff --git a/xtuner/model/modules/yi.py b/xtuner/model/modules/yi.py new file mode 100644 index 000000000..04b84b07f --- /dev/null +++ b/xtuner/model/modules/yi.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, + # so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, + repeats=n_rep). + + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +def yi_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # use flash attention implemented by pytorch + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # Due to the implementation of the PyTorch version of flash attention, + # even when the output_attentions flag is set to True, it is not possible + # to return the attn_weights. + return attn_output, None, past_key_value