diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_e3.py
index 966284430..207da5f8f 100644
--- a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_e3.py
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_e3.py
@@ -10,7 +10,7 @@
 from xtuner.dataset import process_hf_dataset
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

@@ -118,7 +118,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_e3.py
index dc1f2e69c..591974519 100644
--- a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_e3.py
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_e3.py
@@ -11,7 +11,7 @@
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
                                     template_map_fn_factory)
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

@@ -136,7 +136,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_oasst1_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_oasst1_e3.py
index 1b1274d17..92dc2bfb4 100644
--- a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_oasst1_e3.py
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_enzh_oasst1_e3.py
@@ -11,7 +11,7 @@
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
                                     oasst1_map_fn, template_map_fn_factory)
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

@@ -149,7 +149,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_zh_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_zh_e3.py
index 82ec21a33..24e0b638c 100644
--- a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_zh_e3.py
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_alpaca_zh_e3.py
@@ -10,7 +10,7 @@
 from xtuner.dataset import process_hf_dataset
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import alpaca_zh_map_fn, template_map_fn_factory
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

@@ -118,7 +118,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
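
The same one-line registration of ThroughputHook recurs in each remaining config in this patch. For models whose dimensions cannot be read from `model.config`, the hook also accepts them explicitly (see `xtuner/engine/hooks/throughput_hook.py` later in this diff). A minimal sketch with illustrative 7B-scale values, which are not taken from this patch:

# Sketch only: the keyword arguments mirror ThroughputHook.__init__, and every value
# below is illustrative (LLaMA-7B-like numbers), not part of this patch. Note that
# before_run() still multiplies mlp_ratio by 1.5 to account for gate_proj.
from xtuner.engine import ThroughputHook

custom_hooks = [
    dict(
        type=ThroughputHook,
        use_activation_checkpointing=True,
        hidden_size=4096,
        num_layers=32,
        vocab_size=32000,
        mlp_ratio=11008 / 4096),
]
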
diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_oasst1_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_oasst1_e3.py
index 572eb88d9..d3fa65a2b 100644
--- a/xtuner/configs/internlm/internlm_7b/internlm_7b_full_oasst1_e3.py
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_full_oasst1_e3.py
@@ -10,7 +10,7 @@
 from xtuner.dataset import process_hf_dataset
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE

@@ -119,7 +119,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
diff --git a/xtuner/configs/llama/llama2_70b/llama2_70b_full_wizardlm_e1.py b/xtuner/configs/llama/llama2_70b/llama2_70b_full_wizardlm_e1.py
new file mode 100644
index 000000000..31944f072
--- /dev/null
+++ b/xtuner/configs/llama/llama2_70b/llama2_70b_full_wizardlm_e1.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
+from torch.optim import AdamW
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.map_fns import template_map_fn_factory, wizardlm_map_fn
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
+from xtuner.model import SupervisedFinetune
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+#                          PART 1  Settings                           #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'meta-llama/Llama-2-70b-hf'
+
+# Data
+data_path = 'WizardLM/WizardLM_evol_instruct_V2_196k'
+prompt_template = PROMPT_TEMPLATE.llama2_chat
+max_length = 2048
+pack_to_max_length = True
+
+# Scheduler & Optimizer
+batch_size = 1  # per_device
+accumulative_counts = 4  # 1bs * 4acc * 32gpu = 128 batchsize
+dataloader_num_workers = 0
+max_epochs = 3
+optim_type = AdamW
+lr = 2e-5
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1  # grad clip
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = ''
+evaluation_inputs = [
+    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
+]
+
+#######################################################################
+#                      PART 2  Model & Tokenizer                      #
+#######################################################################
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    trust_remote_code=True,
+    padding_side='right')
+
+model = dict(
+    type=SupervisedFinetune,
+    llm=dict(
+        type=AutoModelForCausalLM.from_pretrained,
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
+        trust_remote_code=True))
+
+#######################################################################
+#                      PART 3  Dataset & Dataloader                   #
+#######################################################################
+train_dataset = dict(
+    type=process_hf_dataset,
+    dataset=dict(type=load_dataset, path=data_path),
+    tokenizer=tokenizer,
+    max_length=max_length,
+    dataset_map_fn=wizardlm_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
+    shuffle_before_pack=True,
+    pack_to_max_length=pack_to_max_length)
+
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=dataloader_num_workers,
+    dataset=train_dataset,
+    sampler=dict(type=DefaultSampler, shuffle=True),
+    collate_fn=dict(type=default_collate_fn))
+
+#######################################################################
+#                    PART 4  Scheduler & Optimizer                    #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+    type=AmpOptimWrapper,
+    optimizer=dict(
+        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+    accumulative_counts=accumulative_counts,
+    loss_scale='dynamic',
+)
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
+param_scheduler = dict(
+    type=CosineAnnealingLR,
+    eta_min=lr * 0.1,
+    by_epoch=True,
+    T_max=max_epochs,
+    convert_to_iter_based=True)
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
+
+#######################################################################
+#                           PART 5  Runtime                           #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook,
+        tokenizer=tokenizer,
+        every_n_iters=evaluation_freq,
+        evaluation_inputs=evaluation_inputs,
+        system=SYSTEM,
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per epoch.
+    checkpoint=dict(type=CheckpointHook, interval=1),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
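
The accumulative_counts comment above assumes 32 GPUs (1 sample per device, 4 accumulation steps, 32 devices gives a global batch of 128). The `train/tflops` and `train/tokens_per_sec` scalars logged by the new hook are per GPU, so a cluster-wide figure is obtained by scaling with the world size. A small sketch with hypothetical numbers:

# Illustrative arithmetic only: the 32-GPU count comes from the comment in the config
# above, and the per-GPU token rate is a made-up placeholder, not a measured value.
per_device_batch_size = 1
accumulative_counts = 4
num_gpus = 32
global_batch_size = per_device_batch_size * accumulative_counts * num_gpus  # 128

tokens_per_sec_per_gpu = 350.0  # hypothetical 'train/tokens_per_sec' reading
cluster_tokens_per_sec = tokens_per_sec_per_gpu * num_gpus
print(global_batch_size, cluster_tokens_per_sec)
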
diff --git a/xtuner/configs/llama/llama2_7b/llama2_7b_full_wizardlm_e1.py b/xtuner/configs/llama/llama2_7b/llama2_7b_full_wizardlm_e1.py
index c4e557541..27b27d776 100644
--- a/xtuner/configs/llama/llama2_7b/llama2_7b_full_wizardlm_e1.py
+++ b/xtuner/configs/llama/llama2_7b/llama2_7b_full_wizardlm_e1.py
@@ -10,7 +10,7 @@
 from xtuner.dataset import process_hf_dataset
 from xtuner.dataset.collate_fns import default_collate_fn
 from xtuner.dataset.map_fns import template_map_fn_factory, wizardlm_map_fn
-from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook, ThroughputHook
 from xtuner.model import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE

@@ -119,7 +119,8 @@
         every_n_iters=evaluation_freq,
         evaluation_inputs=evaluation_inputs,
         system=SYSTEM,
-        prompt_template=prompt_template)
+        prompt_template=prompt_template),
+    dict(type=ThroughputHook)
 ]

 # configure default hooks
diff --git a/xtuner/dataset/huggingface.py b/xtuner/dataset/huggingface.py
index 13bd00a2f..ae261b185 100644
--- a/xtuner/dataset/huggingface.py
+++ b/xtuner/dataset/huggingface.py
@@ -6,23 +6,25 @@
 from datasets import DatasetDict
 from mmengine import print_log
 from mmengine.config import Config, ConfigDict
+from torch import distributed as dist

 from xtuner.registry import BUILDER, MAP_FUNC
 from .utils import Packer, encode_fn


-def process_hf_dataset(dataset,
-                       tokenizer,
-                       max_length,
-                       dataset_map_fn=None,
-                       template_map_fn=None,
-                       max_dataset_length=None,
-                       split='train',
-                       remove_unused_columns=False,
-                       rename_maps=[],
-                       shuffle_before_pack=True,
-                       pack_to_max_length=True,
-                       input_ids_with_output=True):
+def process(dataset,
+            tokenizer,
+            max_length,
+            dataset_map_fn=None,
+            template_map_fn=None,
+            max_dataset_length=None,
+            split='train',
+            remove_unused_columns=False,
+            rename_maps=[],
+            shuffle_before_pack=True,
+            pack_to_max_length=True,
+            input_ids_with_output=True,
+            map_num_proc=32):
     """Post-process the dataset loaded from the Hugging Face Hub, or a local
     dataset.

@@ -51,6 +53,7 @@
         input_ids_with_output: Whether to put the groundtruth output
             corresponding to the question into the dataset. Typically set
             it to True during training and False during testing.
+        map_num_proc: Max number of processes when mapping the dataset.
     """

     if isinstance(dataset, DatasetDict):
@@ -74,7 +77,7 @@
         if isinstance(dataset_map_fn, str):
             dataset_map_fn = MAP_FUNC.get(dataset_map_fn)

-        dataset = dataset.map(dataset_map_fn)
+        dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)

     # Add prompt template, such as <|System|>: xxx <|User|>: xxx <|Bot|>: xxx
     if template_map_fn is not None:
@@ -82,7 +85,7 @@
                 template_map_fn, Config) or isinstance(template_map_fn,
                                                        ConfigDict):
             template_map_fn = BUILDER.build(template_map_fn)
-        dataset = dataset.map(template_map_fn)
+        dataset = dataset.map(template_map_fn, num_proc=map_num_proc)

     for old, new in rename_maps:
         dataset = dataset.rename_column(old, new)
@@ -110,13 +113,28 @@
             max_length=max_length,
             input_ids_with_output=input_ids_with_output),
         remove_columns=list(dataset.column_names)
-        if remove_unused_columns else None)
+        if remove_unused_columns else None,
+        num_proc=map_num_proc)

     # pack to max length
     if pack_to_max_length and split == 'train':
         if shuffle_before_pack:
             dataset = dataset.shuffle()
             dataset = dataset.flatten_indices()
-        dataset = dataset.map(Packer(max_length), batched=True)
+        dataset = dataset.map(
+            Packer(max_length), batched=True, num_proc=map_num_proc)

     return dataset
+
+
+def process_hf_dataset(*args, **kwargs):
+    if not (dist.is_available() and dist.is_initialized()):
+        return process(*args, **kwargs)
+
+    if dist.get_rank() == 0:
+        dataset = process(*args, **kwargs)
+        objects = [dataset]
+    else:
+        objects = [None]
+    dist.broadcast_object_list(objects, src=0)
+    return objects[0]
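
The rewritten `process_hf_dataset` above runs the expensive preprocessing only on rank 0 and distributes the result with `torch.distributed.broadcast_object_list`. A self-contained sketch of that pattern, with a stand-in object instead of a real Hugging Face dataset; launch it with something like `torchrun --nproc_per_node=2 demo.py` (the file name is hypothetical):

# Minimal sketch of the rank-0-then-broadcast pattern used by process_hf_dataset.
# The "dataset" here is a placeholder dict, not a datasets.Dataset.
import torch.distributed as dist


def build_expensive_object():
    return {'input_ids': list(range(10))}  # stands in for the processed dataset


def main():
    dist.init_process_group(backend='gloo')  # torchrun supplies the env:// settings
    if dist.get_rank() == 0:
        objects = [build_expensive_object()]
    else:
        objects = [None]
    # Collective call: afterwards objects[0] holds rank 0's payload on every rank,
    # because the object is pickled on rank 0 and broadcast to the others.
    dist.broadcast_object_list(objects, src=0)
    print(dist.get_rank(), objects[0])
    dist.destroy_process_group()


if __name__ == '__main__':
    main()
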
""" if isinstance(dataset, DatasetDict): @@ -74,7 +77,7 @@ def process_hf_dataset(dataset, if isinstance(dataset_map_fn, str): dataset_map_fn = MAP_FUNC.get(dataset_map_fn) - dataset = dataset.map(dataset_map_fn) + dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc) # Add prompt template, such as <|System|>: xxx <|User|>: xxx <|Bot|>: xxx if template_map_fn is not None: @@ -82,7 +85,7 @@ def process_hf_dataset(dataset, template_map_fn, Config) or isinstance(template_map_fn, ConfigDict): template_map_fn = BUILDER.build(template_map_fn) - dataset = dataset.map(template_map_fn) + dataset = dataset.map(template_map_fn, num_proc=map_num_proc) for old, new in rename_maps: dataset = dataset.rename_column(old, new) @@ -110,13 +113,28 @@ def process_hf_dataset(dataset, max_length=max_length, input_ids_with_output=input_ids_with_output), remove_columns=list(dataset.column_names) - if remove_unused_columns else None) + if remove_unused_columns else None, + num_proc=map_num_proc) # pack to max length if pack_to_max_length and split == 'train': if shuffle_before_pack: dataset = dataset.shuffle() dataset = dataset.flatten_indices() - dataset = dataset.map(Packer(max_length), batched=True) + dataset = dataset.map( + Packer(max_length), batched=True, num_proc=map_num_proc) return dataset + + +def process_hf_dataset(*args, **kwargs): + if not (dist.is_available() and dist.is_initialized()): + return process(*args, **kwargs) + + if dist.get_rank() == 0: + dataset = process(*args, **kwargs) + objects = [dataset] + else: + objects = [None] + dist.broadcast_object_list(objects, src=0) + return objects[0] diff --git a/xtuner/engine/__init__.py b/xtuner/engine/__init__.py index 6e2d891c6..038364677 100644 --- a/xtuner/engine/__init__.py +++ b/xtuner/engine/__init__.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .hooks import DatasetInfoHook, EvaluateChatHook +from ._strategy import DeepSpeedStrategy +from .hooks import DatasetInfoHook, EvaluateChatHook, ThroughputHook -__all__ = ['EvaluateChatHook', 'DatasetInfoHook'] +__all__ = [ + 'EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook', + 'DeepSpeedStrategy' +] diff --git a/xtuner/engine/_strategy/__init__.py b/xtuner/engine/_strategy/__init__.py new file mode 100644 index 000000000..bac6095f9 --- /dev/null +++ b/xtuner/engine/_strategy/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .deepspeed import DeepSpeedStrategy + +__all__ = ['DeepSpeedStrategy'] diff --git a/xtuner/engine/_strategy/deepspeed.py b/xtuner/engine/_strategy/deepspeed.py new file mode 100644 index 000000000..b031dece6 --- /dev/null +++ b/xtuner/engine/_strategy/deepspeed.py @@ -0,0 +1,15 @@ +from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy + +from xtuner.registry import STRATEGIES + + +@STRATEGIES.register_module() +class DeepSpeedStrategy(MMEngineDeepSpeedStrategy): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + from transformers.integrations.deepspeed import HfDeepSpeedConfig + + # hf_deepspeed_config has to be saved as an attribute. + self.hf_deepspeed_config = HfDeepSpeedConfig(self.config) diff --git a/xtuner/engine/hooks/__init__.py b/xtuner/engine/hooks/__init__.py index a8ad7f8ec..c23b6b347 100644 --- a/xtuner/engine/hooks/__init__.py +++ b/xtuner/engine/hooks/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/xtuner/engine/hooks/__init__.py b/xtuner/engine/hooks/__init__.py
index a8ad7f8ec..c23b6b347 100644
--- a/xtuner/engine/hooks/__init__.py
+++ b/xtuner/engine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dataset_info_hook import DatasetInfoHook
 from .evaluate_chat_hook import EvaluateChatHook
+from .throughput_hook import ThroughputHook

-__all__ = ['EvaluateChatHook', 'DatasetInfoHook']
+__all__ = ['EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook']
diff --git a/xtuner/engine/hooks/throughput_hook.py b/xtuner/engine/hooks/throughput_hook.py
new file mode 100644
index 000000000..0a087b865
--- /dev/null
+++ b/xtuner/engine/hooks/throughput_hook.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import torch
+from mmengine.hooks import Hook
+from mmengine.model.wrappers import is_model_wrapper
+from torch.utils._pytree import tree_flatten
+
+DATA_BATCH = Optional[Union[dict, tuple, list]]
+
+
+class ThroughputHook(Hook):
+    priority = 'BELOW_NORMAL'
+
+    def __init__(self,
+                 use_activation_checkpointing=None,
+                 hidden_size=None,
+                 num_layers=None,
+                 vocab_size=None,
+                 mlp_ratio=None):
+        self.use_activation_checkpointing = use_activation_checkpointing
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.vocab_size = vocab_size
+        self.mlp_ratio = mlp_ratio
+
+    def before_run(self, runner) -> None:
+        if is_model_wrapper(runner.model):
+            model = runner.model.module
+        else:
+            model = runner.model
+        self.use_activation_checkpointing = \
+            (self.use_activation_checkpointing or
+             self._guess_use_activation_checkpointing(model))
+        self.hidden_size = self.hidden_size or model.config.hidden_size
+        self.num_layers = self.num_layers or model.config.num_hidden_layers
+        self.vocab_size = self.vocab_size or model.config.vocab_size
+        self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
+                                            model.config.hidden_size)
+        self.mlp_ratio *= 1.5  # has gate_proj
+        return
+
+    def _get_batch_size_and_sequence_len(self, data_batch):
+        data_list, _ = tree_flatten(data_batch)
+        for data in data_list:
+            if isinstance(data, torch.Tensor):
+                return data.size(0), data.size(1)
+        raise RuntimeError('No tensor found in the batch')
+
+    def _guess_use_activation_checkpointing(self, model):
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing'):
+                return module.gradient_checkpointing
+        return False
+
+    def after_train_iter(self,
+                         runner,
+                         batch_idx: int,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
+        """Calculate the FLOPs per iteration, following the Megatron paper
+        https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""
+
+        batch_size, sequence_len = self._get_batch_size_and_sequence_len(
+            data_batch)
+
+        message_hub = runner.message_hub
+        iter_time = message_hub.get_scalar('train/time').current()
+
+        flops_per_iteration = (
+            (3 + int(self.use_activation_checkpointing)) *
+            ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
+             self.hidden_size**2 +
+             4 * batch_size * sequence_len**2 * self.hidden_size)
+        ) * self.num_layers + \
+            6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+
+        avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
+        tokens_per_sec_per_gpu = batch_size * sequence_len / (
+            iter_time + 1e-12)
+
+        message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
+        message_hub.update_scalar('train/tokens_per_sec',
+                                  tokens_per_sec_per_gpu)
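
For reference, the estimate that after_train_iter() logs can be reproduced outside the runner. The sketch below copies the formula from the hook above; the model dimensions and the iteration time are illustrative LLaMA-7B-style placeholders, not values taken from this patch:

# Standalone sketch of the Megatron-style FLOPs estimate used by ThroughputHook.
# All numeric defaults below are illustrative, not part of this patch.
def estimate_tflops_per_gpu(batch_size,
                            sequence_len,
                            iter_time,
                            hidden_size=4096,
                            num_layers=32,
                            vocab_size=32000,
                            mlp_ratio=(11008 / 4096) * 1.5,
                            use_activation_checkpointing=True):
    flops_per_iteration = (
        (3 + int(use_activation_checkpointing)) *
        ((8 + mlp_ratio * 4) * batch_size * sequence_len * hidden_size**2 +
         4 * batch_size * sequence_len**2 * hidden_size)
    ) * num_layers + \
        6 * batch_size * sequence_len * hidden_size * vocab_size
    return flops_per_iteration / 1e12 / (iter_time + 1e-12)


# e.g. one packed 2048-token sample per GPU and a hypothetical 3 s iteration time
print(estimate_tflops_per_gpu(batch_size=1, sequence_len=2048, iter_time=3.0))
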
diff --git a/xtuner/registry.py b/xtuner/registry.py
index 7c8907e0b..3bd69829d 100644
--- a/xtuner/registry.py
+++ b/xtuner/registry.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.registry import STRATEGIES as MMENGINE_STRATEGIES
 from mmengine.registry import Registry

-__all__ = ['BUILDER', 'MAP_FUNC']
+__all__ = ['BUILDER', 'MAP_FUNC', 'STRATEGIES']

 BUILDER = Registry('builder')
 MAP_FUNC = Registry('map_fn')
+STRATEGIES = Registry('strategy', parent=MMENGINE_STRATEGIES)
diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py
index 2e73994af..ebfda801e 100644
--- a/xtuner/tools/train.py
+++ b/xtuner/tools/train.py
@@ -11,6 +11,7 @@
 from mmengine.logging import print_log
 from mmengine.registry import RUNNERS
 from mmengine.runner import Runner
+from mmengine.utils import digit_version
 from peft import get_peft_model, prepare_model_for_kbit_training
 from transformers import TrainingArguments

@@ -165,10 +166,14 @@ def main():

     if args.deepspeed:
         try:
-            import deepspeed  # pre-check # noqa: F401
+            import deepspeed
         except ImportError:
             raise ImportError(
                 'deepspeed is not installed properly, please check.')
+        if digit_version(deepspeed.__version__) < digit_version('0.12.3'):
+            raise RuntimeError('Please upgrade your DeepSpeed version '
+                               'by using the command `pip install '
+                               'deepspeed>=0.12.3`')
         optim_wrapper = cfg.optim_wrapper.type
         if optim_wrapper == 'DeepSpeedOptimWrapper':
             print_log(
@@ -224,13 +229,15 @@
                 level=logging.WARNING)
             grad_clip = mm_max_norm
         ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)
+        exclude_frozen_parameters = True if digit_version(
+            deepspeed.__version__) >= digit_version('0.10.1') else None
         strategy = dict(
-            type='DeepSpeedStrategy',
+            type='xtuner.DeepSpeedStrategy',
             config=ds_cfg,
             gradient_accumulation_steps=grad_accum,
             train_micro_batch_size_per_gpu=train_bs,
             gradient_clipping=grad_clip,
-            exclude_frozen_parameters=True)
+            exclude_frozen_parameters=exclude_frozen_parameters)
         cfg.__setitem__('strategy', strategy)
         optim_wrapper = dict(
             type='DeepSpeedOptimWrapper',
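
A note on the version gates added to train.py: mmengine's digit_version turns a dotted version string into a comparable tuple, so the comparisons above are numeric rather than lexicographic, and with the 0.12.3 floor enforced first the later 0.10.1 check always passes, leaving exclude_frozen_parameters set to True whenever DeepSpeed is importable. A minimal sketch of the comparison; the version strings are examples only:

# Example comparisons only; the 0.10.1 and 0.12.3 thresholds come from the diff above.
from mmengine.utils import digit_version

assert digit_version('0.12.10') > digit_version('0.12.3')   # tuple-wise, not string-wise
assert digit_version('0.12.6') >= digit_version('0.10.1')   # so the second gate passes
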