diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 4b9f51976..9c4080fae 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -1,4 +1,4 @@ -exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/ +exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/|^xtuner/_lite/modelings/|^xtuner/_lite/accelerate/dispatches/huggingface/ repos: - repo: https://gitee.com/openmmlab/mirrors-flake8 rev: 5.0.4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f6bbfd633..245f17c69 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/ +exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/|^xtuner/_lite/modelings/|^xtuner/_lite/accelerate/dispatches/huggingface/ repos: - repo: https://github.com/PyCQA/flake8 rev: 5.0.4 diff --git a/requirements/lmdeploy.txt b/requirements/lmdeploy.txt new file mode 100644 index 000000000..25ef3916f --- /dev/null +++ b/requirements/lmdeploy.txt @@ -0,0 +1 @@ +lmdeploy>=0.6.2 --no-deps \ No newline at end of file diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 3a4d2f84e..a90dae7b7 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,16 +1,11 @@ -# Minimum 0.40.0.post4 to fix some 4-bit precision bugs -bitsandbytes>=0.40.0.post4 # Minimum 2.16.0 to fix some bugs, see https://github.com/huggingface/datasets/pull/6444 datasets>=2.16.0 -einops -# Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44 -lagent>=0.1.2 +einop +# Avoid `import cv2` failed +opencv-python==4.7.0.72 # Minimum 0.10.3 to support distributed evaluation for MMBench # see https://github.com/open-mmlab/mmengine/pull/1469 mmengine>=0.10.3 -openpyxl -# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476 -peft>=0.4.0 scikit-image scipy SentencePiece @@ -23,5 +18,7 @@ torchvision # https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923 # transformers >= 4.43.0 use _flash_attention_forward but not self._flash_attention_forward # to calculate attn output which lead to bc braeking -transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4 +transformers>=4.45 transformers_stream_generator +loguru +pydantic \ No newline at end of file diff --git a/setup.py b/setup.py index 7a95dfab4..fe4d1b4f1 100644 --- a/setup.py +++ b/setup.py @@ -117,10 +117,12 @@ def gen_packages_items(): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Utilities', ], # Python maximum version <3.11, to support mpi4py-mpich - python_requires='>=3.8, <3.11', + python_requires='>=3.8, <3.13', license='Apache License 2.0', install_requires=parse_requirements('requirements/runtime.txt'), extras_require={ diff --git a/tools/fsdp_sft.py b/tools/fsdp_sft.py new file mode 100644 index 000000000..d0d6bd00d --- /dev/null +++ b/tools/fsdp_sft.py @@ -0,0 +1,1036 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
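The `exclude` updates in the two pre-commit configs above add the vendored `xtuner/_lite/modelings/` and `xtuner/_lite/accelerate/dispatches/huggingface/` code to the same skip list as the existing `modeling_internlm2_reward` converter. As a minimal illustration (not part of this patch; pre-commit treats `exclude` as a Python regular expression searched against each staged file path, and the file names below are made up):

```python
import re

# Same alternation as in .pre-commit-config.yaml after this change.
EXCLUDE = (r"^tests/data/"
           r"|^xtuner/tools/model_converters/modeling_internlm2_reward/"
           r"|^xtuner/_lite/modelings/"
           r"|^xtuner/_lite/accelerate/dispatches/huggingface/")

# Hypothetical paths: the first two are skipped by the linters, the last is checked.
for path in ("xtuner/_lite/modelings/foo.py",
             "xtuner/_lite/accelerate/dispatches/huggingface/bar.py",
             "xtuner/_lite/algorithms/sft.py"):
    print(path, "excluded" if re.search(EXCLUDE, path) else "checked")
```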
+import argparse +import copy +import math +import os +import sys +import re +import time +import shutil +import requests +import gc +from collections import OrderedDict +from contextlib import nullcontext +from concurrent.futures import wait +from datetime import datetime, timedelta +from functools import partial + +import torch +import torch.distributed as dist +from torch.nn import functional as F +import torch.distributed.checkpoint as dcp +from torch.distributed._composable.fsdp import MixedPrecisionPolicy +from accelerate.utils import set_module_tensor_to_device +from mmengine import mkdir_or_exist +from mmengine.runner import set_random_seed +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env + +from peft import LoraConfig, get_peft_model +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + apply_activation_checkpointing +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_model_state_dict, + get_state_dict, set_state_dict) +from torch.distributed.checkpoint.stateful import Stateful + +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.api import CPUOffload, ShardingStrategy +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import _or_policy +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.utils.data import ConcatDataset, DataLoader +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.utils.import_utils import is_flash_attn_2_available + +from xtuner._lite import (AutoTokenizer, get_device, get_logger, + get_torch_device_module) +from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_hf_code, LoadWoInit, + packed_sequence, varlen_attn_is_available, profile_time_and_memory) +from xtuner._lite.algorithms.sft import SftCollator, SftTokenizeFunction +from xtuner._lite.chat import CHAT_TEMPLATE_MAP +from xtuner._lite.datasets import (DATASET_CLS_MAP, OPENAI_CONVERT_MAP, + SoftPackDataset, HardPackDataset, load_datasets) +from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler, + get_dp_mesh, get_sp_mesh, + pad_for_sequence_parallel, + reduce_sequence_parallel_loss, + setup_parallel, split_for_sequence_parallel) + +from xtuner._lite.parallel import (ParallelSampler, get_dp_mesh, get_fsdp_mesh, + get_sp_mesh, get_tp_mesh, get_world_mesh, get_same_data_mesh, + pad_for_sequence_parallel, setup_parallel, + reduce_sequence_parallel_loss, + split_for_sequence_parallel) +from xtuner._lite.parallel.megatron import megatron_parallelize +from xtuner._lite.parallel.fsdp import clip_grad_norm_ + +gc.disable() +logger = get_logger() + +DEVICE = get_device() +DEVICE_MODULE = get_torch_device_module() + +SUPPORT_DATA_FORMATS = OPENAI_CONVERT_MAP.keys() + +def log_format(rank, debug=False): + + sp_rank = get_sp_mesh().get_local_rank() + dp_rank = get_dp_mesh().get_local_rank() + tp_rank = get_tp_mesh().get_local_rank() + fsdp_rank = get_fsdp_mesh().get_local_rank() + + formatter = f'[XTuner][RANK {rank}][DP {dp_rank}][SP {sp_rank}][TP {tp_rank}]' + formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]' + + if debug: + formatter += '[{name}:' + formatter += '{function}:' + formatter += '{line}]' + + formatter += ' {message}' + return formatter + +def send_to_feishu(web_hook, msg): + + header = { + 
"Content-Type" : "application/json;charset=UTF-8" + } + + body = { + "msg_type" : "text", + "content" : { "text" : f"所有人{msg}"} + } + + try: + requests.post(url=web_hook, json=body, headers=header, timeout=1) + except requests.exceptions.RequestException: + pass + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + + model_args = parser.add_argument_group('model', 'Model Related Settings') + model_args.add_argument('--llm', help='repo id or local path of the model') + model_args.add_argument( + '-t', + '--tokenizer', + help=('repo id or local path of the tokenizer. ' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--chat-template', + choices=CHAT_TEMPLATE_MAP.keys(), + help=('repo id or local path of the tokenizer. ' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--use-lora', action='store_true', help='Apply the adapter to LLM.') + model_args.add_argument( + '--lora-targets', + default=None, + nargs='*', + help='The names of the modules to apply the adapter to. ') + model_args.add_argument( + '--lora-r', default=64, type=int, help="Not updating vit's parameters") + model_args.add_argument( + '--lora-alpha', + default=16, + type=int, + help='The alpha parameter for Lora scaling.') + model_args.add_argument( + '--lora-dropout', + default=0.1, + type=float, + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--lora-bias', + default='none', + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--dtype', + default='auto', + choices=['fp16', 'bf16', 'auto'], + help=("the dtype of the model forward. When set to 'auto', it will " + 'automatically determine whether bf16 is available, ' + 'prioritizing the use of bf16.')) + + model_args.add_argument( + '--selective-recompute', + default=1.0, + type=float, + help=('the ratio of re-computation for transforemer layers. ' + 'The maximum is 1; the larger the value, the less memory ' + 'required for training. The default is 1, meaning all layers ' + 'need to be re-computated.')) + model_args.add_argument( + '--shard-strategy', + default='full', + choices=['full', 'hybrid'], + help=('The sharding strategy to be used for distributed training.')) + model_args.add_argument('--cpu-offload', action='store_true', help=('')) + model_args.add_argument('--sp-size', type=int, default=1, help='') + data_args = parser.add_argument_group('data', 'Dataset Related Settings') + data_args.add_argument( + '--datasets', + nargs='*', + help=('repo id or local path or dir of the datasets. For repo ids, ' + 'the `dset-sources` needs to be appropriately set to ' + '`modelscope` or `huggingface`. For local dir, all json and ' + 'jsonl files will be loaded by default. The type of loaded ' + 'files can be controlled by setting `dset-file-type`')) + data_args.add_argument( + '--dset-file-types', + nargs='*', + default=DATASET_CLS_MAP.keys(), + choices=DATASET_CLS_MAP.keys(), + help='the file type that needs to be loaded') + data_args.add_argument( + '--dset-sources', + nargs='*', + default=['local'], + choices=['local', 'huggingface', 'modelscope'], + help=('the source of each dataset; it can accept one or the same ' + 'number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets come from the same source. 
' + '`local` represents the local path, `huggingface` represents ' + 'the open-source data in the Huggingface Hub, `modelscope` ' + 'indicates the open-source data in the Modelscope Hub.')) + data_args.add_argument( + '--dset-formats', + nargs='*', + default=['openai'], + help=('the format of each dataset; it can accept one or the same ' + 'number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets are the same format.')) + data_args.add_argument( + '--dset-sample-ratios', + nargs='*', + type=float, + default=[1.0], + help=('the sample ratio of each dataset; it can accept one or the ' + 'same number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets use the same sample ratio.')) + data_args.add_argument( + '--dset-cache-dir', + help=('the cache dir of the loaded datasets. When the `datasets` is ' + 'set, the loaded datasets will be cached to this dir. If the ' + '`datasets` are not set, the cached dataset in this dir will be ' + 'loaded.')) + data_args.add_argument( + '--dset-pack-level', + choices=['hard', 'soft'], + help=('the level of data packing. When `hard`, multiple data will be ' + 'packed to `max_length`, potentially causing some data to be ' + 'truncated, and the length of the packed data will always ' + 'be `max_length`; When `soft`, it will pack multiple data ' + 'into nearly `max_length` without truncating the data.')) + data_args.add_argument( + '--global-pack', + action='store_true', + help='A subsequence in the packed data comes from different files.') + data_args.add_argument( + '--max-length', + type=int, + default=2048, + help=('the maximum length of each piece of data, any excess will be ' + 'truncated.')) + data_args.add_argument( + '--num-workers', + type=int, + default=8, + help='how many subprocesses to use for data loading.') + data_args.add_argument('--file-pattern', type=str, default=None) + data_args.add_argument('--group-by-length', action='store_true') + + optim_args = parser.add_argument_group('optim', 'Optim Related Settings') + optim_args.add_argument( + '--mirco-batch-size', + type=int, + default=1, + help='batch size for each forward + backward pass') + optim_args.add_argument( + '--global-batch-size', + type=int, + default=16, + help='batch size for each optimizer step') + + optim_args.add_argument( + '--lr', default=4e-5, type=float, help='learning rate.') + optim_args.add_argument( + '--lr-min', default=6e-6, type=float, help='min learning rate.') + optim_args.add_argument( + '--wd', default=0.01, type=float, help='weight decay.') + optim_args.add_argument( + '--max-grad-norm', default=1, type=float, help='gradient clipping') + optim_args.add_argument( + '-e', '--epochs', default=1, type=int, help='total training epochs.') + optim_args.add_argument( + '--warmup-ratio', + default=0.03, + type=float, + help=('the proportion of training steps for learning rate warm-up in ' + 'relation to the total training steps.')) + + parser.add_argument('-c', '--config', default=None) + parser.add_argument( + '--work-dir', + default='work_dirs', + help='the dir to save logs and checkpoints') + parser.add_argument( + '--feishu-webhook', default=None, help='Webhook of Feishu Group Chat Bot') + parser.add_argument('--gc-interval', default=100, type=int) + parser.add_argument( + '--checkpoint-interval', + default=-1, + type=float, + help=('how many steps to save a checkpoint; it can be a floating ' + 'point number less than 1, or an integer greater than or equal ' + "to 1. 
When it's a floating point, it will be multiplied by the " + 'total number of training steps.')) + parser.add_argument( + '--checkpoint-max-keep', + default=1, + type=int, + help=('Maximum number of saved checkpoints.')) + parser.add_argument( + '--checkpoint-drop-optimizer', + action='store_true', + help=('only model parameters are saved when saving a checkpoint. ' + 'This can significantly reduce the size of checkpoint files, ' + 'but the saved checkpoints cannot be resumed.')) + parser.add_argument( + '--log-interval', default=1, type=int, help='log interval') + parser.add_argument( + '--resume', + action='store_true', + help='resume training from the latest checkpoint in `work-dir`.') + parser.add_argument( + '--seed', type=int, default=0, help='random seed for the training') + parser.add_argument( + '--debug', action='store_true', help='Set logger level to `DEBUG`') + args = parser.parse_args() + return args + + +def is_interval(step, total_steps, interval): + return (step + 1) % interval == 0 or (step + 1) == total_steps + + +def map_meta_modules(model, meta_model): + modules = {name: mod for name, mod in model.named_modules()} + meta_module_map = { + mod: modules[name] + for name, mod in meta_model.named_modules() + } + return meta_module_map + + +def build_llm_model(args, config, world_size, dtype=torch.float32): + with LoadWoInit(): + llm = AutoModelForCausalLM.from_pretrained( + args.llm, config=config, attn_implementation='flash_attention_2', + trust_remote_code=True) + + # Ensure all numerical values in the optimizer are fp32. + # FSDP will use low precision during forward. + llm.to(dtype) + + if args.use_lora: + llm.requires_grad_(False) + if world_size > 1: + llm.to(dtype) + + if args.lora_targets is None: + llm_cls = llm.__class__.__name__ + args.lora_targets = LORA_TARGET_MAP[llm_cls] + llm_lora_cfg = LoraConfig( + target_modules=args.lora_targets, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias=args.lora_bias, + task_type='CAUSAL_LM') + llm = get_peft_model(llm, llm_lora_cfg) + + return llm + + +class TrainState(Stateful): + + def __init__(self, total_steps, seed): + super().__init__() + + self.seed = seed + self.cur_step = -1 + self.total_steps = total_steps + self.if_nan_skip_steps = 0 + + def load_state_dict(self, state_dict): + assert self.total_steps == state_dict['total_steps'] + self.cur_step = state_dict['current_step'] + self.if_nan_skip_steps = state_dict['if_nan_skip_steps'] + + def state_dict(self): + return { + 'seed': self.seed, 'current_step': self.cur_step, + 'total_steps': self.total_steps, + 'if_nan_skip_steps': self.if_nan_skip_steps + } + + def step(self): + self.cur_step = self.cur_step + 1 + + def found_nan(self): + self.if_nan_skip_steps += 1 + + +def find_latest_timestamp(work_dir): + # Initialize variables to keep track of the latest timestamp and its corresponding directory + latest_timestamp = None + + # Iterate over all files and directories in the specified directory + for entry in os.listdir(work_dir): + full_path = os.path.join(work_dir, entry) + + # Check if the entry is a directory + if os.path.isdir(full_path): + try: + # Try to interpret the directory name as a timestamp + timestamp = datetime.strptime(entry, '%Y%m%d%H%M%S') + + # Update the latest timestamp and directory if this one is more recent + if latest_timestamp is None or timestamp > latest_timestamp: + latest_timestamp = timestamp + except ValueError: + # If conversion fails, skip this entry + continue + + if latest_timestamp is not None: 
latest_timestamp = latest_timestamp.strftime( '%Y%m%d%H%M%S') + + return latest_timestamp + + +def find_checkpoints(directory, prefix='ckpt'): + + if prefix == 'ckpt': + pattern = r'^ckpt-(\d+)$' + elif prefix == 'hf': + pattern = r'^hf-(\d+)$' + else: + raise ValueError + + latest_step = -1 + latest_checkpoint = None + + all_folders = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] + + checkpoints = [] + for folder in all_folders: + match = re.match(pattern, folder) + if match: + # Store the folder name and the matched step number as a tuple + checkpoints.append((folder, int(match.group(1)))) + + checkpoints.sort(key=lambda x: x[1]) + + return [os.path.join(directory, folder[0]) for folder in checkpoints] + + + + +# @logger.catch +def sft(args): + ########################################################################### + # 1. Environment # + ########################################################################### + setup_parallel(sp_size=args.sp_size, tp_size=1) + set_random_seed(args.seed) + + dp_mesh = get_dp_mesh() + tp_mesh = get_tp_mesh() + sp_mesh = get_sp_mesh() + fsdp_mesh = get_fsdp_mesh() # dp_size * sp_size + world_mesh = get_world_mesh() # dp_size * sp_size * tp_size + + dp_size = dp_mesh.size() + tp_size = tp_mesh.size() + sp_size = sp_mesh.size() + world_size = world_mesh.size() + + cpu_comm_timeout = timedelta(minutes=60) + gloo_group = dist.new_group(backend='gloo', timeout=cpu_comm_timeout) + + if args.global_batch_size < dp_size or args.global_batch_size % dp_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + 'should be divisible by the ' + f'dp_size({dp_size}).') + + if (args.global_batch_size / dp_size) % args.mirco_batch_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the dp_size({dp_size})' + f' * `mirco_batch_size`({args.mirco_batch_size})') + + rank = dist.get_rank() + + if args.resume: + mkdir_or_exist(args.work_dir) + timestamp = find_latest_timestamp(args.work_dir) + + if timestamp is None: + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + else: + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + + objects = [timestamp] + dist.broadcast_object_list(objects, src=0) + timestamp = objects[0] + + args.work_dir = os.path.join(args.work_dir, timestamp) + mkdir_or_exist(args.work_dir) + + log_file = os.path.join(args.work_dir, f'rank{rank}.log') + + logger.remove() + # Change the log format printed in the terminal + lvl = 'DEBUG' if args.debug else 'INFO' + logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug)) + # Change the format saved in the log file + logger.add(log_file, format=log_format(rank), backtrace=True, catch=True) + + if args.feishu_webhook and rank == 0: + def log_handler(record): + if record['level'].name == "WARNING": + send_to_feishu(args.feishu_webhook, f"[WARNING] {record['message']}\n{args.work_dir}") + elif record['level'].name == "TRACE": + send_to_feishu(args.feishu_webhook, f"[TRACE] {record['message']}\n{args.work_dir}") + elif record['level'].name == "ERROR": + send_to_feishu(args.feishu_webhook, f"[ERROR] Task failed\n{args.work_dir}") + + logger.add(sys.stderr, level='TRACE', filter=log_handler, catch=True) + + logger.trace('Task started') + + logger.info(args) + if rank == 0: + env = collect_env() + import transformers + + import xtuner + env['Transformers'] = transformers.__version__ + env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}' + runtime_env = OrderedDict() + runtime_env.update(env) + runtime_env['Seed'] = 
args.seed + runtime_env['World Size'] = world_size + runtime_env['DP Size'] = dp_size + runtime_env['SP Size'] = sp_size + runtime_env['TP Size'] = tp_size + # runtime_env['Distributed launcher'] = dist_launcher + + runtime_env_info = '\n ' + '\n '.join( + f'{k}: {v}' for k, v in runtime_env.items()) + dash_line = '-' * 60 + logger.info('\n' + dash_line + '\nRuntime environment:' + + runtime_env_info + '\n' + dash_line + '\n') + # ------------------- Environment End ------------------------------ # + + ########################################################################### + # 2. Dataset & Dataloader # + ########################################################################### + + start_load_data_t = time.time() + + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer if args.tokenizer else args.llm, + trust_remote_code=True, + use_fast=False, + padding_side='right') + + if args.chat_template: + chat_template = CHAT_TEMPLATE_MAP[args.chat_template] + else: + chat_template = None + + tokenize_fns = [] + for dset_format in args.dset_formats: + # If your data format is not in `SUPPORT_DATA_FORMATS`, you should + # redefine a `tokenize_fn`, defining how to convert a piece of raw + # data into tokenized data. + # The tokenized data must include `input_ids`, `labels``, + # and `num_tokens`. + tokenize_fn = SftTokenizeFunction(tokenizer, chat_template, + dset_format) + tokenize_fns.append(tokenize_fn) + + _datasets = load_datasets( + paths=args.datasets, + cache_dir=args.dset_cache_dir, + file_types=args.dset_file_types, + sources=args.dset_sources, + sample_ratios=args.dset_sample_ratios, + map_fns=tokenize_fns, + file_pattern=args.file_pattern, + max_length=args.max_length + ) + + if args.dset_pack_level and rank == 0 and args.debug: + # Only the tokenized datasets can count the number of tokens + num_tokens = sum(dset.num_tokens.sum() for dset in _datasets) + logger.debug(f'[Dataset] {num_tokens} tokens.') + + if args.dset_pack_level == 'soft': + train_dataset = SoftPackDataset(_datasets, target=args.max_length, blend=args.global_pack) + elif args.dset_pack_level == 'hard': + raise NotImplementedError + else: + train_dataset = ConcatDataset(_datasets) + + if args.dset_pack_level and rank == 0: + ori_samples = sum([len(dset) for dset in _datasets]) + packed_samples = len(train_dataset) + logger.info(f'[Dataset] (Original) {ori_samples} samples.') + logger.info(f'[Dataset] (Packed) {packed_samples} samples.') + + assert varlen_attn_is_available() + collator = SftCollator( + pack_batch=varlen_attn_is_available(), + max_length=args.max_length) + + if args.group_by_length: + sampler = LengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size) + else: + sampler = ParallelSampler( + train_dataset, dp_mesh, args.global_batch_size, shuffle=True) + + gc.collect() + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.mirco_batch_size, + num_workers=args.num_workers, + # Ensure to round up or drop last based on the `global_batch_size`, + # if you want to replace a custom sampler. 
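+ # With the default `ParallelSampler`/`LengthGroupedSampler`, each DP rank consumes + # `global_batch_size // dp_size` samples per optimizer step, drawn `mirco_batch_size` + # at a time over `iters_per_step` gradient-accumulation iterations.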
+ sampler=sampler, + collate_fn=collator, + persistent_workers=args.num_workers > 0) + + if rank == 0: + logger.info(f'[Dataloader] {len(train_dataloader)} batches.') + _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)] + _first_batch = collator(_first_batch) + _decoded = tokenizer.batch_decode(_first_batch['input_ids']) + logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}') + logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}') + dist.barrier() + + gc.collect() + load_data_cost_time = time.time() - start_load_data_t + logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s') + # ------------------- Dataset & Dataloader End --------------------- # + + ########################################################################### + # 3. FSDP # + ########################################################################### + + start_model_t = time.time() + + if args.dtype == 'auto': + args.dtype = 'bf16' if DEVICE_MODULE.is_bf16_supported() else 'fp16' + + if args.dtype == 'fp16': + dtype = torch.float16 + autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype) + scaler = ShardedGradScaler() + elif args.dtype == 'bf16': + if DEVICE_MODULE.is_bf16_supported(): + dtype = torch.bfloat16 + autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype) + scaler = None + else: + raise RuntimeError('The device does not support `bf16`, ' + 'please set `dtype` to `fp16`.') + else: + raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, ' + f'but found {args.dtype}.') + + llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True) + if is_flash_attn_2_available(): + llm_cfg.attn_implementation = 'flash_attention_2' + + llm_cfg.use_cache = False + llm_cfg.torch_dtype = dtype + + + + # Only load parameters on rank 0 to avoid each rank repeatedly loading the + # same model into the CPU, wasting memory + if rank == 0: + with torch.device('cpu'): + rank0_llm = build_llm_model(args, llm_cfg, world_size, dtype) + else: + rank0_llm = None + + dist.monitored_barrier(group=gloo_group, timeout=cpu_comm_timeout) + + with torch.device('meta'): + # Ensure all numerical values in the optimizer are fp32. + # FSDP will use low precision during forward. + llm = build_llm_model(args, llm_cfg, world_size, dtype) + dispatch_hf_code(llm) + for module in llm.modules(): + for p_name, param in module.named_parameters(recurse=False): + if param.requires_grad: + param_fp32 = torch.nn.Parameter( + param.to(dtype=torch.float32)) + setattr(module, p_name, param_fp32) + + mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype) + + with profile_time_and_memory('[Parallelize LLM]'): + megatron_parallelize( + llm, + rank0_llm, + dp_mesh=fsdp_mesh, + tp_mesh=tp_mesh, + mp_policy=mp_policy, + recompute_ratio=args.selective_recompute, + reshard_after_forward=True) + + llm.train() + + dist.barrier() + gc.collect() + # -------------------------- FSDP End ------------------------------ # + + ########################################################################### + # 4. 
Optimizer & Scheduler # + ########################################################################### + requried_grad_params = [ + param for param in llm.parameters() if param.requires_grad + ] + optimizer = AdamW( + requried_grad_params, + lr=args.lr, + weight_decay=args.wd, + betas=(0.9, 0.95)) + + global_batch_size = args.global_batch_size + mirco_batch_size = args.mirco_batch_size + + # `iter` means once forward+backward + # `step` means once optimizer step + # `iters_per_step` means gradient accumulative counts + iters_per_step = global_batch_size // mirco_batch_size // dp_size + iters_per_epoch = len(train_dataloader) + steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step) + + total_epochs = args.epochs + total_steps = steps_per_epoch * total_epochs + if_nan_skip_steps = 0 + train_state = TrainState(total_steps, args.seed) + + if args.checkpoint_interval == -1: + checkpoint_interval = total_steps + elif args.checkpoint_interval < 1: + checkpoint_interval = int(total_steps * args.checkpoint_interval) + else: + checkpoint_interval = int(args.checkpoint_interval) + + warmup_steps = int(args.warmup_ratio * total_steps) + + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + warmup_scheduler = LambdaLR(optimizer, warmup_fn) + + cosine_scheduler = CosineAnnealingLR( + optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min) + + start_step = 0 + gc.collect() + # ---------------- Optimizer & Scheduler End ----------------------- # + + ########################################################################### + # 5. (Optional) Resume # + ########################################################################### + if args.resume: + + + _checkpoints = find_checkpoints(args.work_dir) + + latest_checkpoint = None + + for _ckpt_dir in reversed(_checkpoints): + if os.path.exists(os.path.join(_ckpt_dir, '.metadata')): + latest_checkpoint = _ckpt_dir + break + + if latest_checkpoint: + + with profile_time_and_memory('[Resume]'): + _options = StateDictOptions( + cpu_offload=True, ignore_frozen_params=True) + (shard_model_state_dict, + shard_optimizer_state_dict) = get_state_dict( + llm, optimizer, options=_options) + state_dict = { + 'model': shard_model_state_dict, + 'optimizer': shard_optimizer_state_dict, + 'train_state': train_state, + 'warmup_scheduler': warmup_scheduler, + 'cosine_scheduler': cosine_scheduler + } + + # inplace state_dict + dcp.load( + state_dict=state_dict, + checkpoint_id=latest_checkpoint, + ) + + _options = StateDictOptions( + cpu_offload=True, strict=False) + set_state_dict( + llm, + optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optimizer"], + options=_options + ) + + start_step = train_state.cur_step + 1 + + else: + logger.warning(f'There is no checkpoint available for resuming training in {args.work_dir}.') + + ########################################################################### + # 6. Training # + ########################################################################### + ckpt_handle = None + start_train_t = time.time() + DEVICE_MODULE.empty_cache() + DEVICE_MODULE.reset_peak_memory_stats() + max_memory = DEVICE_MODULE.max_memory_allocated() + logger.info('[Train] Begin Train Loop. 
The current GPU memory is ' + f'{(max_memory / 1024**3):.1f}GB') + + for step in range(start_step, total_steps): + + if is_interval(step + 1, total_steps, args.gc_interval): + gc.collect() + + epoch = step // steps_per_epoch + epoch_inner_step = step % steps_per_epoch + if epoch_inner_step == 0 or step == start_step: + # For the first step of each epoch, the data order needs to be + # readjusted. + # Or after resuming, for the first step, the dataloader needs to + # be adjusted to the position before resume. + train_dataloader.sampler.set_epoch(epoch, epoch_inner_step * iters_per_step ) + data_iterator = iter(train_dataloader) + + train_state.step() + + if step <= warmup_steps: + warmup_scheduler.step() + cur_lr = warmup_scheduler.get_last_lr()[0] + else: + cosine_scheduler.step() + cur_lr = cosine_scheduler.get_last_lr()[0] + + DEVICE_MODULE.reset_peak_memory_stats() + + step_loss = 0 + step_data_time = 0 + step_start_t = time.time() + step_consumed_tokens = 0 + + _data_start_t = time.time() + + step_data_list = [next(data_iterator) for _ in range(iters_per_step)] + rank_grad_tokens = 0 + for _iter in range(iters_per_step): + _iter_data = step_data_list[_iter] + _iter_labels = _iter_data['labels'][:, 1:] + rank_grad_tokens += (_iter_labels >= 0).sum() + rank_grad_tokens = rank_grad_tokens.to(DEVICE) + dist.all_reduce(rank_grad_tokens) + global_grad_tokens = rank_grad_tokens / sp_size / tp_size + + + step_data_time = time.time() - _data_start_t + + for _iter in range(iters_per_step): + + data = step_data_list[_iter] + input_ids = data['input_ids'][:, :-1].to(DEVICE) + + labels = data['labels'][:, 1:].to(DEVICE) + num_tokens = data['num_tokens'].to(DEVICE) + + if num_tokens[-1] == 1: + num_tokens = num_tokens[:-1] + else: + num_tokens[-1] = num_tokens[-1] - 1 + + if sp_size > 1: + # `dim` is 1 as the shape of tensor is (bs, seq_len, ...) + input_ids = pad_for_sequence_parallel(input_ids, 0, sp_mesh, dim=1) + _num_pad = input_ids.numel() - num_tokens.sum() + if _num_pad > 0: + _num_pad = torch.IntTensor([_num_pad]).to(DEVICE) + num_tokens = torch.cat([num_tokens, _num_pad], dim=-1) + + input_ids = split_for_sequence_parallel( + input_ids, dim=1, sp_mesh=sp_mesh) + + labels = pad_for_sequence_parallel(labels, -100,sp_mesh, dim=1) + labels = split_for_sequence_parallel( + labels, dim=1, sp_mesh=sp_mesh) + + packed_ctx = packed_sequence(num_tokens, sp_mesh=sp_mesh) + + with packed_ctx, autocast if args.use_lora else nullcontext(): + loss = llm(input_ids=input_ids, labels=labels, label_shifted=True, use_cache=False).loss + + loss = loss * (labels >= 0).sum() / global_grad_tokens * dp_size + + if scaler and args.use_lora: + scaler.scale(loss).backward() + else: + loss.backward() + + step_loss += loss.item() + step_consumed_tokens += num_tokens.sum() / sp_size / tp_size + + step_reduced_loss = torch.Tensor([step_loss]).to(DEVICE) + dist.all_reduce(step_reduced_loss) + step_reduced_loss = step_reduced_loss.item() / world_size + + grad_norm = clip_grad_norm_( + requried_grad_params, fsdp_mesh, args.max_grad_norm) + + if grad_norm.isnan() or grad_norm.isinf(): + train_state.found_nan() + logger.warning(f"[Step {step}] The grad norm is NaN or Inf, skip this step. 
Skipped {train_state.if_nan_skip_steps} steps in total.") + optimizer.zero_grad() + else: + optimizer.step() + optimizer.zero_grad() + + step_time = time.time() - step_start_t + eta = step_time * (total_steps - step) + eta = timedelta(seconds=int(eta)) + tgs = int(step_consumed_tokens / step_time) + max_memory = DEVICE_MODULE.max_memory_allocated() + if is_interval(step, total_steps, args.log_interval): + logger.info(f'[Train] (Epoch {epoch + 1}) Step ' + f'{step + 1}/{total_steps} ' + f'lr: {cur_lr:.6f} loss: {step_loss:.3f} ' + f'loss(reduced): {step_reduced_loss:.3f} ' + f'grad_norm: {grad_norm:.2f} ' + f'if_nan_skip: {train_state.if_nan_skip_steps} ' + f'max_memory: {(max_memory / 1024**3):.1f}GB ' + f'text_tokens: {step_consumed_tokens} ' + f'tgs: {tgs} data_time: {step_data_time:.2f}s ' + f'time: {step_time:.2f}s ' + f'eta: {eta}') + + if is_interval(step, total_steps, max(1, int(total_steps * 0.1))): + logger.trace(f'Step {step}/{total_steps}, loss {step_loss:.3f}, tgs {tgs}') + + if is_interval(step, total_steps, checkpoint_interval): + + num_digits = len(str(abs(total_steps))) + work_dir = args.work_dir + ckpt_dir = os.path.join(work_dir, f'ckpt-{step+1:0{num_digits}}') + hf_dir = os.path.join(work_dir, f'hf-{step+1:0{num_digits}}') + + with profile_time_and_memory('[HF Checkpoint]'): + + from torch.distributed._tensor import DTensor + + if rank == 0: + llm_state_dict = {} + + for name, param in llm.state_dict().items(): + if isinstance(param, DTensor): + with torch.no_grad(): + full_param = param.full_tensor().cpu() + else: + full_param = param.cpu() + + if rank == 0: + llm_state_dict[name] = full_param + + if rank == 0: + rank0_llm.load_state_dict(llm_state_dict) + rank0_llm.save_pretrained(hf_dir) + tokenizer.save_pretrained(hf_dir) + + dist.barrier() + + saved_hf_checkpoints = find_checkpoints(args.work_dir, prefix='hf') + + if len(saved_hf_checkpoints) > args.checkpoint_max_keep: + for _ckpt in saved_hf_checkpoints[:-args.checkpoint_max_keep]: + if rank == 0: + shutil.rmtree(_ckpt) + logger.info('[HF Checkpoint] Delete the oldest checkpoint.') + + + if args.checkpoint_drop_optimizer: + logger.warning('The saved checkpoint cannot be resumed. 
' + 'If you want to save a resumable checkpoint, ' + 'please remove `--checkpoint-drop-optimizer` ' + 'from the command.') + else: + + with profile_time_and_memory('[PT Checkpoint]'): + if ckpt_handle is not None: + wait([ckpt_handle]) + + # FSDP cannot be saved via torch.save + # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501 + _options = StateDictOptions( + cpu_offload=True, ignore_frozen_params=True) + (shard_model_state_dict, + shard_optimizer_state_dict) = get_state_dict( + llm, optimizer, options=_options) + + state_dict = { + 'model': shard_model_state_dict, + 'optimizer': shard_optimizer_state_dict, + 'train_state': train_state.state_dict(), + 'warmup_scheduler': warmup_scheduler.state_dict(), + 'cosine_scheduler': cosine_scheduler.state_dict() + } + + mkdir_or_exist(ckpt_dir) + ckpt_handle = dcp.async_save(state_dict, checkpoint_id=ckpt_dir, process_group=gloo_group) + + saved_checkpoints = find_checkpoints(args.work_dir) + + if len(saved_checkpoints) > args.checkpoint_max_keep: + for _ckpt in saved_checkpoints[:-args.checkpoint_max_keep]: + if rank == 0: + shutil.rmtree(_ckpt) + logger.info('[PT Checkpoint] Delete the oldest checkpoint.') + + if ckpt_handle is not None: + wait([ckpt_handle]) + + logger.trace('Task Finished') + + train_cost_time = time.time() - start_train_t + logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}') + # ------------------------ Training End ---------------------------- # + +if __name__ == '__main__': + + args = parse_args() + sft(args) diff --git a/tools/fsdp_tp_sft.py b/tools/fsdp_tp_sft.py new file mode 100644 index 000000000..321829496 --- /dev/null +++ b/tools/fsdp_tp_sft.py @@ -0,0 +1,786 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import copy +import math +import os +import sys +import time +from collections import OrderedDict +from datetime import datetime, timedelta +from functools import partial + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from accelerate.utils import set_module_tensor_to_device +from mmengine import mkdir_or_exist +from mmengine.dist import infer_launcher, init_dist +from mmengine.runner import set_random_seed +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env +from torch.distributed._tensor import Replicate +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + apply_activation_checkpointing +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_model_state_dict, + get_state_dict) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.api import CPUOffload, ShardingStrategy +from torch.distributed.tensor.parallel import (ColwiseParallel, + RowwiseParallel, + parallelize_module) +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.utils.data import ConcatDataset, DataLoader +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.utils.import_utils import (is_flash_attn_2_available, + is_torch_sdpa_available) + +from xtuner._lite import AutoTokenizer, get_logger +from xtuner._lite.accelerate import dispatch_modules, packed_sequence +from xtuner._lite.chat import CHAT_TEMPLATE_MAP +from xtuner._lite.datasets import (OPENAI_FORMAT_MAP, SoftPackerForText, + TextCollator, 
TextOnlineTokenizeDataset, + TextTokenizedDataset, TextTokenizeFunction) +from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets, + load_from_cache) +from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler, + get_dp_mesh, get_tp_mesh, setup_parallel) +from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit, + checkpoint_check_fn, dp_tp_lazy_init, + layer_auto_wrap_policy) + +logger = get_logger() + +SUPPORT_DATA_FORMATS = OPENAI_FORMAT_MAP.keys() + + +def log_format(rank, debug=False): + + formatter = f'[XTuner][RANK {rank}]' + formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]' + + if debug: + formatter += '[{name}:' + formatter += '{function}:' + formatter += '{line}]' + + formatter += ' {message}' + return formatter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + + model_args = parser.add_argument_group('model', 'Model Related Settings') + model_args.add_argument('--llm', help='repo id or local path of the model') + model_args.add_argument( + '-t', + '--tokenizer', + help=('repo id or local path of the tokenizer. ' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--chat-template', + choices=CHAT_TEMPLATE_MAP.keys(), + help=('the chat template used to format the conversations. ' + 'Must be one of the keys of `CHAT_TEMPLATE_MAP`.')) + model_args.add_argument( + '--dtype', + default='auto', + choices=['fp16', 'bf16', 'auto'], + help=("the dtype of the model forward. When set to 'auto', it will " + 'automatically determine whether bf16 is available, ' + 'prioritizing the use of bf16.')) + + model_args.add_argument( + '--selective-recompute', + default=1.0, + type=float, + help=('the ratio of re-computation for transformer layers. ' + 'The maximum is 1; the larger the value, the less memory ' + 'required for training. The default is 1, meaning all layers ' + 'need to be re-computed.')) + model_args.add_argument( + '--shard-strategy', + default='full', + choices=['full', 'hybrid'], + help=('The sharding strategy to be used for distributed training.')) + model_args.add_argument('--cpu-offload', action='store_true', help=('')) + model_args.add_argument( + '--tp-size', type=int, default=1, help='Tensor Parallel Size') + + data_args = parser.add_argument_group('data', 'Dataset Related Settings') + data_args.add_argument( + '--datasets', + nargs='*', + help=('repo id or local path or dir of the datasets. For repo ids, ' + 'the `dset-sources` needs to be appropriately set to ' + '`modelscope` or `huggingface`. For local dir, all json and ' + 'jsonl files will be loaded by default. The type of loaded ' + 'files can be controlled by setting `dset-file-type`')) + data_args.add_argument( + '--dset-file-types', + nargs='*', + default=LOAD_FN_MAP.keys(), + choices=LOAD_FN_MAP.keys(), + help='the file type that needs to be loaded') + data_args.add_argument( + '--dset-sources', + nargs='*', + default=['local'], + choices=['local', 'huggingface', 'modelscope'], + help=('the source of each dataset; it can accept one or the same ' + 'number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets come from the same source. 
' + '`local` represents the local path, `huggingface` represents ' + 'the open-source data in the Huggingface Hub, `modelscope` ' + 'indicates the open-source data in the Modelscope Hub.')) + data_args.add_argument( + '--dset-formats', + nargs='*', + default=['openai'], + help=('the format of each dataset; it can accept one or the same ' + 'number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets are the same format.')) + data_args.add_argument( + '--dset-sample-ratios', + nargs='*', + default=[1.0], + help=('the sample ratio of each dataset; it can accept one or the ' + 'same number of args as the number of `datasets`, with one arg ' + 'indicating that all datasets use the same sample ratio.')) + data_args.add_argument( + '--dset-cache-dir', + help=('the cache dir of the loaded datasets. When the `datasets` is ' + 'set, the loaded datasets will be cached to this dir. If the ' + '`datasets` are not set, the cached dataset in this dir will be ' + 'loaded.')) + data_args.add_argument( + '--dset-from-cache', + action='store_true', + help=('Load data directly from `dset-cache-dir`. This can save time ' + 'on online tokenization, but if the tokenizer changed, ' + 'recaching is needed.')) + data_args.add_argument( + '--dset-pack-level', + choices=['hard', 'soft'], + help=('the level of data packing. When `hard`, multiple data will be ' + 'packed to `max_length`, potentially causing some data to be ' + 'truncated, and the length of the packed data will always ' + 'be `max_length`; When `soft`, it will pack multiple data ' + 'into nearly `max_length` without truncating the data.')) + data_args.add_argument( + '--max-length', + type=int, + default=2048, + help=('the maximum length of each piece of data, any excess will be ' + 'truncated.')) + data_args.add_argument( + '--num-workers', + type=int, + default=0, + help='how many subprocesses to use for data loading.') + data_args.add_argument( + '--num-proc', + type=int, + default=8, + help='how many subprocesses to use for data mapping.') + data_args.add_argument('--file-pattern', type=str, default=None) + data_args.add_argument('--group-by-length', action='store_true') + + optim_args = parser.add_argument_group('optim', 'Optim Related Settings') + optim_args.add_argument( + '--mirco-batch-size', + type=int, + default=1, + help='batch size for each forward + backward pass') + optim_args.add_argument( + '--global-batch-size', + type=int, + default=16, + help='batch size for each optimizer step') + + optim_args.add_argument( + '--lr', default=4e-5, type=float, help='learning rate.') + optim_args.add_argument( + '--lr-min', default=6e-6, type=float, help='min learning rate.') + optim_args.add_argument( + '--wd', default=0.01, type=float, help='weight decay.') + optim_args.add_argument( + '--max-grad-norm', default=1, type=float, help='gradient clipping') + optim_args.add_argument( + '-e', '--epochs', default=1, type=int, help='total training epochs.') + optim_args.add_argument( + '--warmup-ratio', + default=0.03, + type=float, + help=('the proportion of training steps for learning rate warm-up in ' + 'relation to the total training steps.')) + + parser.add_argument('-c', '--config', default=None) + parser.add_argument( + '--work-dir', + default='work_dirs', + help='the dir to save logs and checkpoints') + parser.add_argument( + '--checkpoint-interval', + default=-1, + type=float, + help=('how many steps to save a checkpoint; it can be a floating ' + 'point number less than 1, or an integer greater than or equal ' + "to 
1. When it's a floating point, it will be multiplied by the " + 'total number of training steps.')) + parser.add_argument( + '--checkpoint-drop-optimizer', + action='store_true', + help=('only model parameters are saved when saving a checkpoint. ' + 'This can significantly reduce the size of checkpoint files, ' + 'but the saved checkpoints cannot be resumed.')) + parser.add_argument( + '--log-interval', default=1, type=int, help='log interval') + parser.add_argument( + '--resume', + type=str, + default=None, + help='specify checkpoint path to be resumed from.') + parser.add_argument( + '--seed', type=int, default=0, help='random seed for the training') + parser.add_argument( + '--debug', action='store_true', help='Set logger level to `DEBUG`') + args = parser.parse_args() + return args + + +def is_interval(step, total_steps, interval): + return (step + 1) % interval == 0 or (step + 1) == total_steps + + +def map_meta_modules(model, meta_model): + modules = {name: mod for name, mod in model.named_modules()} + meta_module_map = { + mod: modules[name] + for name, mod in meta_model.named_modules() + } + return meta_module_map + + +def build_llm_model(args, config, world_size, dtype=torch.float32): + with LoadWoInit(): + llm = AutoModelForCausalLM.from_pretrained( + args.llm, config=config, trust_remote_code=True) + + # Ensure all numerical values in the optimizer are fp32. + # FSDP will use low precision during forward. + llm.to(dtype) + return llm + + +# @logger.catch +def sft(args): + ########################################################################### + # 1. Environment # + ########################################################################### + dist_launcher = infer_launcher() + init_dist(dist_launcher) + set_random_seed(args.seed) + + world_size = int(os.environ['WORLD_SIZE']) + tp_size = args.tp_size + dp_size = world_size // tp_size + + if args.global_batch_size < dp_size or args.global_batch_size % dp_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + 'should be divisible by the ' + f'world_size({world_size}).') + + if (args.global_batch_size / dp_size) % args.mirco_batch_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the world_size({world_size})' + f' * `mirco_batch_size`({args.mirco_batch_size})') + + if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir): + if len(os.listdir(args.dset_cache_dir)) and not args.dset_from_cache: + raise RuntimeError(f'`{args.dset_cache_dir}` is not an empty ' + 'folder, which may lead to inaccurate ' + 'cache results.') + + setup_parallel(tp_size=tp_size) + dp_mesh = get_dp_mesh() + tp_mesh = get_tp_mesh() + + rank = dist.get_rank() + + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + + objects = [timestamp] + dist.broadcast_object_list(objects, src=0) + timestamp = objects[0] + + args.work_dir = os.path.join(args.work_dir, timestamp) + mkdir_or_exist(args.work_dir) + + log_file = os.path.join(args.work_dir, f'rank{rank}.log') + + # Change the log format printed in the terminal + lvl = 'DEBUG' if args.debug else 'INFO' + logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug)) + # Change the format saved in the log file + logger.add(log_file, format=log_format(rank), backtrace=True, catch=True) + + logger.info(args) + if rank == 0: + env = collect_env() + import transformers + + import xtuner + env['Transformers'] = transformers.__version__ + env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}' + runtime_env = 
OrderedDict() + runtime_env.update(env) + runtime_env['Seed'] = args.seed + runtime_env['World Size'] = world_size + runtime_env['Distributed launcher'] = dist_launcher + + runtime_env_info = '\n ' + '\n '.join( + f'{k}: {v}' for k, v in runtime_env.items()) + dash_line = '-' * 60 + logger.info('\n' + dash_line + '\nRuntime environment:' + + runtime_env_info + '\n' + dash_line + '\n') + # ------------------- Environment End ------------------------------ # + + ########################################################################### + # 2. Dataset & Dataloader # + ########################################################################### + + start_load_data_t = time.time() + + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer if args.tokenizer else args.llm, + trust_remote_code=True, + padding_side='right') + + if args.dset_from_cache: + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForText.from_cache, max_length=args.max_length) + elif args.dset_pack_level == 'hard': + raise NotImplementedError + else: + init_fn = partial( + TextTokenizeFunction.from_cache, max_length=args.max_length) + _datasets = load_from_cache(args.dset_cache_dir, init_fn) + dist.barrier() + else: + chat_template = CHAT_TEMPLATE_MAP[args.chat_template] + tokenize_fns = [] + init_fns = [] + for dset_format in args.dset_formats: + # If your data format is not in `SUPPORT_DATA_FORMATS`, you should + # redefine a `tokenize_fn`, defining how to convert a piece of raw + # data into tokenized data. + # The tokenized data must include `input_ids`, `labels``, + # and `num_tokens`. + tokenize_fn = TextTokenizeFunction(tokenizer, chat_template, + dset_format) + + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForText, max_length=args.max_length) + elif args.dset_cache_dir: + init_fn = partial( + TextTokenizedDataset, max_length=args.max_length) + else: + init_fn = partial( + TextOnlineTokenizeDataset, tokenize_fn=tokenize_fn) + # Online tokenization is used when not using a pack dataset, + # saving startup time. + tokenize_fn = None + + tokenize_fns.append(tokenize_fn) + init_fns.append(init_fn) + + _datasets = load_datasets( + paths=args.datasets, + cache_dir=args.dset_cache_dir, + file_types=args.dset_file_types, + sources=args.dset_sources, + sample_ratios=args.dset_sample_ratios, + num_proc=args.num_proc, + map_fns=tokenize_fns, + init_fns=init_fns, + file_pattern=args.file_pattern) + + if (args.dset_pack_level or args.cache_dir) and rank == 0 and args.debug: + # Only the tokenized datasets can count the number of tokens + num_tokens = sum(sum(dset['num_tokens']) for dset in _datasets) + logger.debug(f'[Dataset] {num_tokens} tokens.') + + train_dataset = ConcatDataset(_datasets) + + if args.dset_pack_level and rank == 0: + ori_samples = sum([len(dset) for dset in _datasets]) + packed_samples = len(train_dataset) + logger.info(f'[Dataset] (Original) {ori_samples} samples.') + logger.info(f'[Dataset] (Packed) {packed_samples} samples.') + + pack_batch = is_flash_attn_2_available() + collator = TextCollator(pack_batch=pack_batch) + + if args.group_by_length: + sampler = LengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size) + else: + sampler = ParallelSampler( + train_dataset, dp_mesh, args.global_batch_size, shuffle=True) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.mirco_batch_size, + num_workers=args.num_workers, + # Ensure to round up or drop last based on the `global_batch_size`, + # if you want to replace a custom sampler. 
+ sampler=sampler, + collate_fn=collator, + persistent_workers=args.num_workers > 0) + + if rank == 0: + logger.info(f'[Dataloader] {len(train_dataloader)} batches.') + _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)] + _first_batch = collator(_first_batch) + _decoded = tokenizer.batch_decode(_first_batch['input_ids']) + logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}') + logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}') + dist.barrier() + + load_data_cost_time = time.time() - start_load_data_t + logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s') + # ------------------- Dataset & Dataloader End --------------------- # + + ########################################################################### + # 3. FSDP # + ########################################################################### + + start_model_t = time.time() + + if args.dtype == 'auto': + args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16' + + if args.dtype == 'fp16': + dtype = torch.float16 + elif args.dtype == 'bf16': + if torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + raise RuntimeError('The device does not support `bf16`, ' + 'please set `dtype` to `fp16`.') + else: + raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, ' + f'but found {args.dtype}.') + + llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True) + if is_flash_attn_2_available(): + llm_cfg.attn_implementation = 'flash_attention_2' + elif is_torch_sdpa_available(): + llm_cfg.attn_implementation = 'sdpa' + + llm_cfg.use_cache = False + llm_cfg.torch_dtype = dtype + + with torch.device('meta'): + # Ensure all numerical values in the optimizer are fp32. + # FSDP will use low precision during forward. 
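+ # Building on the `meta` device allocates no real weight storage here; only rank 0 + # loads the actual weights onto CPU below, and FSDP materializes the shards through + # `param_init_fn` (`dp_tp_lazy_init`) together with `sync_module_states=True`.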
+ meta_llm = build_llm_model(args, llm_cfg, world_size, torch.float32) + + if pack_batch: + dispatch_modules(meta_llm) + + # Only load parameters on rank 0 to avoid each rank repeatedly loading the + # same model into the CPU, wasting memory + if rank == 0: + with torch.device('cpu'): + llm = build_llm_model(args, llm_cfg, world_size, dtype) + rank0_meta_llm = copy.deepcopy(meta_llm) + meta_llm_map = map_meta_modules(llm, meta_llm) + else: + meta_llm_map = None + + dist.barrier() + + if args.tp_size > 1: + layer_tp_plan = { + 'attention.wqkv': ColwiseParallel(), + 'attention.wo': RowwiseParallel(), + 'feed_forward.w1': ColwiseParallel(), + 'feed_forward.w2': RowwiseParallel(), + 'feed_forward.w3': ColwiseParallel(), + } + + for layer in meta_llm.model.layers: + attention = layer.attention + attention.num_heads = attention.num_heads // tp_mesh.size() + attention.hidden_size = attention.hidden_size // tp_mesh.size() + parallelize_module( + module=layer, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + + meta_llm = parallelize_module( + module=meta_llm, + device_mesh=tp_mesh, + parallelize_plan={ + 'model.tok_embeddings': + RowwiseParallel(input_layouts=Replicate(), ), + 'output': ColwiseParallel(output_layouts=Replicate(), ), + }) + + param_init_fn = partial( + dp_tp_lazy_init, + module_map=meta_llm_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh) + + if args.shard_strategy == 'full': + fsdp_device_mesh = dp_mesh + strategy = ShardingStrategy.FULL_SHARD + elif args.shard_strategy == 'hybrid': + fsdp_device_mesh = init_device_mesh('cuda', (dp_size // 8, 8)) + strategy = ShardingStrategy.HYBRID_SHARD + else: + raise ValueError + + torch.cuda.reset_peak_memory_stats() + shard_llm = FSDP( + meta_llm, + device_mesh=fsdp_device_mesh, + sharding_strategy=strategy, + cpu_offload=CPUOffload(offload_params=args.cpu_offload), + auto_wrap_policy=layer_auto_wrap_policy, + mixed_precision=MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype), + device_id=torch.cuda.current_device(), + use_orig_params=True, + param_init_fn=param_init_fn, + sync_module_states=True, + ) + + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Model] During building the FSDP model, the peak GPU memory ' + f'is {max_memory/1024**3:.1f}GB.') + + if args.selective_recompute: + check_fn = partial( + checkpoint_check_fn, + target=RECOMPUTE_MODULES, + selective=args.selective_recompute) + apply_activation_checkpointing(shard_llm, check_fn=check_fn) + + fsdp_cost_time = time.time() - start_model_t + logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s') + # -------------------------- FSDP End ------------------------------ # + + ########################################################################### + # 4. 
Optimizer & Scheduler # + ########################################################################### + requried_grad_params = [ + param for param in shard_llm.parameters() if param.requires_grad + ] + optimizer = AdamW( + requried_grad_params, + lr=args.lr, + weight_decay=args.wd, + betas=(0.9, 0.95)) + + global_batch_size = args.global_batch_size + mirco_batch_size = args.mirco_batch_size + + # `iter` means once forward+backward + # `step` means once optimizer step + # `iters_per_step` means gradient accumulative counts + iters_per_step = global_batch_size // mirco_batch_size // dp_size + iters_per_epoch = len(train_dataloader) + steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step) + + total_epochs = args.epochs + total_steps = steps_per_epoch * total_epochs + + if args.checkpoint_interval == -1: + checkpoint_interval = total_steps + elif args.checkpoint_interval < 1: + checkpoint_interval = int(total_steps * args.checkpoint_interval) + else: + checkpoint_interval = int(args.checkpoint_interval) + + warmup_steps = int(args.warmup_ratio * total_steps) + + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + warmup_scheduler = LambdaLR(optimizer, warmup_fn) + + cosine_scheduler = CosineAnnealingLR( + optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min) + + start_step = 0 + + # ---------------- Optimizer & Scheduler End ----------------------- # + + ########################################################################### + # 5. Training # + ########################################################################### + + start_train_t = time.time() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Train] Begin Train Loop. The current GPU memory is ' + f'{(max_memory / 1024**3):.1f}GB') + for step in range(start_step, total_steps): + + epoch = step // steps_per_epoch + epoch_inner_step = step % steps_per_epoch + if epoch_inner_step == 0 or step == start_step: + # For the first step of each epoch, the data order needs to be + # readjusted. + # Or after resuming, for the first step, the dataloader needs to + # be adjusted to the position before resume. + + train_dataloader.sampler.set_epoch(epoch, epoch_inner_step) + data_iterator = iter(train_dataloader) + + if step < warmup_steps: + warmup_scheduler.step() + cur_lr = warmup_scheduler.get_last_lr()[0] + else: + cosine_scheduler.step() + cur_lr = cosine_scheduler.get_last_lr()[0] + + torch.cuda.reset_peak_memory_stats() + + step_loss = 0 + step_data_time = 0 + step_start_t = time.time() + step_consumed_tokens = 0 + for _ in range(iters_per_step): + + _data_start_t = time.time() + data = next(data_iterator) + step_data_time += time.time() - _data_start_t + + input_ids = data['input_ids'].cuda() + labels = data['labels'].cuda() + attention_mask = data['attention_mask'].cuda() + num_tokens = data['num_tokens'].cuda() + + packed_ctx = packed_sequence(num_tokens, enable=pack_batch) + + with packed_ctx: + + outputs = shard_llm( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask) + + avg_iter_loss = outputs.loss / iters_per_step + avg_iter_loss.backward() + + step_loss += avg_iter_loss.item() + if args.dset_pack_level == 'soft': + # During a soft pack process, the data with a length that is + # still smaller than the max length after packing, will be + # padded to the max length. The last element of num tokens + # represents the count of pad tokens. 
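+ # For example (illustrative numbers): if the packed length is 8192 and two
+ # samples of 5000 and 3000 tokens are packed together, 192 pad tokens are
+ # appended, so num_tokens == [5000, 3000, 192] and only num_tokens[:-1]
+ # is counted as real text below.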
+ step_consumed_tokens += num_tokens[:-1].sum() / tp_size + else: + step_consumed_tokens += num_tokens.sum() / tp_size + + grad_norm = shard_llm.clip_grad_norm_(args.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + + step_time = time.time() - step_start_t + eta = step_time * (total_steps - step) + eta = timedelta(seconds=int(eta)) + tgs = int(step_consumed_tokens / step_time) + max_memory = torch.cuda.max_memory_allocated() + if is_interval(step, total_steps, args.log_interval): + logger.info(f'[Train] (Epoch {epoch + 1}) Step ' + f'{step + 1}/{total_steps} ' + f'lr: {cur_lr:.6f} loss: {step_loss:.3f} ' + f'grad_norm: {grad_norm:.2f} ' + f'max_memory: {(max_memory / 1024**3):.1f}GB ' + f'text_tokens: {step_consumed_tokens} ' + f'tgs: {tgs} data_time: {step_data_time:.2f}s ' + f'time: {step_time:.2f}s ' + f'eta: {eta}') + + if is_interval(step, total_steps, checkpoint_interval): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Checkpoint] Before saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + num_digits = len(str(abs(total_steps))) + work_dir = args.work_dir + ckpt_dir = os.path.join(work_dir, f'ckpt-{step+1:0{num_digits}}') + hf_dir = os.path.join(work_dir, f'hf-{step+1:0{num_digits}}') + _options = StateDictOptions(cpu_offload=True, full_state_dict=True) + + full_model_state_dict = get_model_state_dict( + shard_llm, options=_options) + if rank == 0: + saved_llm = copy.deepcopy(rank0_meta_llm) + saved_llm.to(dtype) + for name, param in full_model_state_dict.items(): + set_module_tensor_to_device(saved_llm, name, 'cpu', param) + + saved_llm.save_pretrained(hf_dir) + tokenizer.save_pretrained(hf_dir) + del saved_llm + + dist.barrier() + del full_model_state_dict + + if args.checkpoint_drop_optimizer: + logger.warning('The saved checkpoint cannot be resumed. ' + 'If you want to save a resumable checkpoint, ' + 'please remove `--checkpoint-drop-optimizer` ' + 'from the command.') + else: + # FSDP cannot be saved via torch.save + # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501 + _options = StateDictOptions( + cpu_offload=True, ignore_frozen_params=True) + (shard_model_state_dict, + shard_optimizer_state_dict) = get_state_dict( + shard_llm, optimizer, options=_options) + + state_dict = { + 'model': shard_model_state_dict, + 'optimizer': shard_optimizer_state_dict, + 'step': step, + 'total_steps': total_steps, + 'warmup_scheduler': warmup_scheduler.state_dict(), + 'cosine_scheduler': cosine_scheduler.state_dict() + } + + writer = dcp.FileSystemWriter(ckpt_dir) + mkdir_or_exist(ckpt_dir) + dcp.save(state_dict, writer) + + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Checkpoint] During saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + train_cost_time = time.time() - start_train_t + logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}') + # ------------------------ Training End ---------------------------- # + + +if __name__ == '__main__': + + args = parse_args() + sft(args) diff --git a/tools/llava/convert_xtuner_weights_to_hf.py b/tools/llava/convert_xtuner_weights_to_hf.py new file mode 100644 index 000000000..e9f3ece9e --- /dev/null +++ b/tools/llava/convert_xtuner_weights_to_hf.py @@ -0,0 +1,147 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
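+# Converts an XTuner-style LLaVA checkpoint (built from an mmengine config via
+# BUILDER) into the HuggingFace `LlavaForConditionalGeneration` layout.
+# Illustrative usage (paths are placeholders):
+#   python tools/llava/convert_xtuner_weights_to_hf.py <config.py> <save_dir>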
+# Modified from https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/llava/convert_llava_weights_to_hf.py # noqa: E501
+import argparse
+
+from transformers import CLIPImageProcessor, CLIPVisionModel
+from xtuner._lite.modelings.llava import LlavaForConditionalGeneration, EnhancedLlavaConfig, LlavaProcessor
+from mmengine import Config
+from xtuner.registry import BUILDER
+from mmengine import print_log
+from xtuner._lite.parallel.fsdp import LoadWoInit
+
+LLM_PREFIX = 'language_model'
+VIT_PREFIX = 'vision_tower'
+PROJECTOR_MAPPING = {
+    'model.0': 'multi_modal_projector.linear_1',
+    'model.2': 'multi_modal_projector.linear_2',
+}
+
+
+def convert_state_dict_to_hf(state_dict, mapping):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.endswith('.inv_freq'):
+            continue
+        for key_to_modify, new_key in mapping.items():
+            if key_to_modify in key:
+                key = key.replace(key_to_modify, new_key)
+        new_state_dict[key] = value
+    return new_state_dict
+
+
+def add_prefix(state_dict, prefix):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        new_state_dict[f'{prefix}.{key}'] = value
+    return new_state_dict
+
+
+def convert_to_hf(cfg, save_dir):
+    print_log('Loading XTuner Checkpoint...', 'current')
+    model = BUILDER.build(cfg.model)
+
+    # get state_dict
+    llm = model.llm
+    if model.use_llm_lora:
+        llm = model.llm.merge_and_unload()
+    llm.config.use_cache = True
+
+    llm_state_dict = llm.state_dict()
+    llm_state_dict = add_prefix(llm_state_dict, LLM_PREFIX)
+
+    visual_encoder = model.visual_encoder
+    if model.use_visual_encoder_lora:
+        visual_encoder = model.visual_encoder.merge_and_unload()
+    assert isinstance(visual_encoder, CLIPVisionModel),\
+        'This conversion format only supports CLIPVisionModel.'
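+    # The three state dicts below are merged into a single flat dict that follows
+    # the HF Llava naming scheme: LLM weights get the `language_model.` prefix,
+    # ViT weights get `vision_tower.`, and projector keys are renamed through
+    # PROJECTOR_MAPPING, e.g. `model.0.weight` -> `multi_modal_projector.linear_1.weight`.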
+
+    visual_encoder_state_dict = visual_encoder.state_dict()
+    visual_encoder_state_dict = add_prefix(
+        visual_encoder_state_dict, VIT_PREFIX)
+
+    projector_state_dict = model.projector.state_dict()
+    projector_state_dict = convert_state_dict_to_hf(
+        projector_state_dict, PROJECTOR_MAPPING)
+
+    state_dict = {
+        **projector_state_dict,
+        **llm_state_dict,
+        **visual_encoder_state_dict
+    }
+
+    tokenizer = BUILDER.build(cfg.tokenizer)
+
+    # init model
+    text_config = llm.config
+    vision_config = visual_encoder.config
+
+    # `<image>` is the placeholder token expected by the HF Llava format.
+    img_token = '<image>'
+    need_resize = False
+    if len(tokenizer.encode(img_token, add_special_tokens=False)) > 1:
+        tokenizer.add_tokens([img_token], special_tokens=True)
+        img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+
+        print_log(f'[Tokenizer] Added a new token `{img_token}`, '
+                  f'token id is {img_token_id}, the new vocab size is '
+                  f'{len(tokenizer)}', 'current')
+
+        llm_vocab_size = text_config.vocab_size
+        if llm_vocab_size < len(tokenizer):
+            # We added an image token, so the embeddings need to be resized.
+            need_resize = True
+    else:
+        img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+
+    print_log('Building an empty HF Llava...', 'current')
+    config = EnhancedLlavaConfig(
+        text_config=text_config,
+        vision_config=vision_config,
+        image_token_index=img_token_id,
+        attn_implementation='eager')
+
+    with LoadWoInit():
+        llava = LlavaForConditionalGeneration(config)
+
+    print_log('Loading HF format state dict...', 'current')
+    llava.load_state_dict(state_dict, strict=True, assign=True)
+
+    if need_resize:
+        ori_emb_shape = llava.get_input_embeddings().weight.shape
+        llava.resize_token_embeddings(len(tokenizer))
+        new_emb_shape = llava.get_input_embeddings().weight.shape
+        print_log('Pad the parameters of `embeddings` and `output` from '
+                  f'shape {ori_emb_shape} to shape {new_emb_shape}',
+                  'current')
+
+    # processor
+    image_processor = BUILDER.build(cfg.image_processor)
+    assert isinstance(image_processor, CLIPImageProcessor),\
+        'This conversion format only supports CLIPImageProcessor.'
+
+    processor = LlavaProcessor(
+        tokenizer=tokenizer, image_processor=image_processor)
+
+    # save
+    print_log('Saving HF Llava...', 'current')
+    llava.save_pretrained(save_dir)
+    processor.save_pretrained(save_dir)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('config')
+    parser.add_argument('save_dir')
+    args = parser.parse_args()
+
+    cfg = Config.fromfile(args.config)
+    convert_to_hf(cfg, args.save_dir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/llava/llava_data_examples.json b/tools/llava/llava_data_examples.json
new file mode 100644
index 000000000..b559264ed
--- /dev/null
+++ b/tools/llava/llava_data_examples.json
@@ -0,0 +1,6 @@
+{
+    "med": {
+        "annotations": "/mnt/hwfile/gmai/litianbin/QA_Pairs_new/QA_Pairs/internvl_format/fundus_data/",
+        "sample_ratio": 1.0
+    }
+}
diff --git a/tools/llava/llava_pretrain.py b/tools/llava/llava_pretrain.py
new file mode 100644
index 000000000..166f8f539
--- /dev/null
+++ b/tools/llava/llava_pretrain.py
@@ -0,0 +1,934 @@
+# Copyright (c) OpenMMLab. All rights reserved.
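+# Pre-training script for a LLaVA-style model assembled from a separate LLM and
+# a CLIP ViT. Illustrative launch (assuming torchrun; arguments are placeholders;
+# `--datasets` expects a JSON like tools/llava/llava_data_examples.json):
+#   torchrun --nproc-per-node 8 tools/llava/llava_pretrain.py \
+#       --llm <llm_path> --vit openai/clip-vit-large-patch14-336 \
+#       --chat-template <template_name> --datasets <datasets.json>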
+import argparse +import copy +import math +import os +import shutil +import sys +import time +from collections import OrderedDict +from contextlib import nullcontext +from datetime import datetime, timedelta +from functools import partial + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from accelerate.utils import set_module_tensor_to_device +from datasets import Dataset +from mmengine import load, mkdir_or_exist +from mmengine.dist import infer_launcher, init_dist +from mmengine.runner import set_random_seed +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env +from peft import LoraConfig, get_peft_model +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + apply_activation_checkpointing +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_model_state_dict, + get_state_dict) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import _or_policy +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.utils.data import ConcatDataset, DataLoader +from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor, + CLIPVisionModel) +from transformers.utils.import_utils import (is_flash_attn_2_available, + is_torch_sdpa_available) + +from xtuner._lite import AutoTokenizer, get_logger +from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_modules, + packed_sequence) +from xtuner._lite.chat import CHAT_TEMPLATE_MAP +from xtuner._lite.datasets import (LlavaCollator, LlavaRawDataset,LlavaTokenizedDataset, + LlavaTokenizeFunction, SoftPackerForLlava) +from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets, + load_from_cache) +from xtuner._lite.modelings import register_remote_code, LlavaForConditionalGeneration, EnhancedLlavaConfig, LlavaProcessor +from xtuner._lite.parallel import LengthGroupedSampler, ParallelSampler +from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit, + all_required_grad_wrap_policy, + checkpoint_check_fn, dp_lazy_init, + layer_auto_wrap_policy) + +logger = get_logger() + + +def log_format(rank, debug=False): + + formatter = f'[XTuner][RANK {rank}]' + formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]' + + if debug: + formatter += '[{name}:' + formatter += '{function}:' + formatter += '{line}]' + + formatter += ' {message}' + return formatter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + + model_args = parser.add_argument_group('model', 'Model Related Settings') + model_args.add_argument('--llm', help='repo id or local path of the model') + model_args.add_argument( + '--vit', + default='openai/clip-vit-large-patch14-336', + help='repo id or local path of the model') + model_args.add_argument( + '-t', + '--tokenizer', + help=('repo id or local path of the tokenizer. ' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--chat-template', + choices=CHAT_TEMPLATE_MAP.keys(), + help=('repo id or local path of the tokenizer. 
' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--freeze-llm', + action='store_true', + help="Not updating LLM's parameters") + model_args.add_argument( + '--freeze-vit', + action='store_true', + help="Not updating vit's parameters") + model_args.add_argument( + '--llm-use-lora', + action='store_true', + help='Apply the adapter to LLM.') + model_args.add_argument( + '--llm-lora-targets', + default=None, + nargs='*', + help='The names of the modules to apply the adapter to. ') + model_args.add_argument( + '--llm-lora-r', + default=64, + type=int, + help="Not updating vit's parameters") + model_args.add_argument( + '--llm-lora-alpha', + default=16, + type=int, + help='The alpha parameter for Lora scaling.') + model_args.add_argument( + '--llm-lora-dropout', + default=0.1, + type=float, + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--llm-lora-bias', + default='none', + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--vit-use-lora', + action='store_true', + help='Apply the adapter to Vit.') + model_args.add_argument( + '--vit-lora-targets', + default=None, + type=str, + help='The names of the modules to apply the adapter to. ') + model_args.add_argument( + '--vit-lora-r', + default=64, + type=int, + help="Not updating vit's parameters") + model_args.add_argument( + '--vit-lora-alpha', + default=16, + type=int, + help='The alpha parameter for vit Lora scaling.') + model_args.add_argument( + '--vit-lora-dropout', + default=0.1, + type=float, + help='The dropout probability for vit Lora layers.') + model_args.add_argument( + '--vit-lora-bias', + default='none', + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--dtype', + default='auto', + choices=['fp16', 'bf16', 'auto'], + help=("the dtype of the model forward. When set to 'auto', it will " + 'automatically determine whether bf16 is available, ' + 'prioritizing the use of bf16.')) + + model_args.add_argument( + '--selective-recompute', + default=1.0, + type=float, + help=('the ratio of re-computation for transforemer layers. ' + 'The maximum is 1; the larger the value, the less memory ' + 'required for training. The default is 1, meaning all layers ' + 'need to be re-computated.')) + + data_args = parser.add_argument_group('data', 'Dataset Related Settings') + data_args.add_argument( + '--datasets', + help=('repo id or local path or dir of the datasets. For repo ids, ' + 'the `dset-sources` needs to be appropriately set to ' + '`modelscope` or `huggingface`. For local dir, all json and ' + 'jsonl files will be loaded by default. The type of loaded ' + 'files can be controlled by setting `dset-file-type`')) + data_args.add_argument( + '--dset-file-types', + nargs='*', + default=LOAD_FN_MAP.keys(), + choices=LOAD_FN_MAP.keys(), + help='the file type that needs to be loaded') + data_args.add_argument( + '--dset-cache-dir', + help=('the cache dir of the loaded datasets. When the `datasets` is ' + 'set, the loaded datasets will be cached to this dir. If the ' + '`datasets` are not set, the cached dataset in this dir will be ' + 'loaded.')) + data_args.add_argument( + '--dset-from-cache', + action='store_true', + help=('Load data directly from `dset-cache-dir`. This can save time ' + 'on online tokenization, but if the tokenizer changed, ' + 'recaching is needed.')) + data_args.add_argument( + '--dset-pack-level', + choices=['soft'], + help=('the level of data packing. 
When `hard`, multiple data will be ' + 'packed to `max_length`, potentially causing some data to be ' + 'truncated, and the length of the packed data will always ' + 'be `max_length`; When `soft`, it will pack multiple data ' + 'into nearly `max_length` without truncating the data.')) + data_args.add_argument('--group-by-length', action='store_true') + data_args.add_argument( + '--max-length', + type=int, + default=2048, + help=('the maximum length of each piece of data, any excess will be ' + 'truncated.')) + data_args.add_argument( + '--num-workers', + type=int, + default=0, + help='how many subprocesses to use for data loading.') + data_args.add_argument( + '--num-proc', + type=int, + default=8, + help='how many subprocesses to use for data mapping.') + + optim_args = parser.add_argument_group('optim', 'Optim Related Settings') + optim_args.add_argument( + '--mirco-batch-size', + type=int, + default=1, + help='batch size for each forward + backward pass') + optim_args.add_argument( + '--global-batch-size', + type=int, + default=16, + help='batch size for each parameter update') + + optim_args.add_argument( + '--lr', default=4e-5, type=float, help='learning rate.') + optim_args.add_argument( + '--wd', default=0.01, type=float, help='weight decay.') + optim_args.add_argument( + '--max-grad-norm', default=1, type=float, help='gradient clipping') + optim_args.add_argument( + '-e', '--epochs', default=1, type=int, help='total training epochs.') + optim_args.add_argument( + '--warmup-ratio', + default=0.03, + type=float, + help=('the proportion of training steps for learning rate warm-up in ' + 'relation to the total training steps.')) + + parser.add_argument('-c', '--config', default=None) + parser.add_argument( + '--work-dir', + default='work_dirs', + help='the dir to save logs and checkpoints') + parser.add_argument( + '--checkpoint-interval', + default=-1, + type=float, + help=('how many steps to save a checkpoint; it can be a floating ' + 'point number less than 1, or an integer greater than or equal ' + "to 1. When it's a floating point, it will be multiplied by the " + 'total number of training steps.')) + parser.add_argument( + '--checkpoint-drop-optimizer', + action='store_true', + help=('only model parameters are saved when saving a checkpoint. 
' + 'This can significantly reduce the size of checkpoint files, ' + 'but the saved checkpoints cannot be resumed.')) + parser.add_argument( + '--log-interval', default=1, type=int, help='log interval') + parser.add_argument( + '--resume', + type=str, + default=None, + help='specify checkpoint path to be resumed from.') + parser.add_argument( + '--seed', type=int, default=0, help='random seed for the training') + parser.add_argument( + '--debug', action='store_true', help='Set logger level to `DEBUG`') + args = parser.parse_args() + return args + + +def is_interval(step, total_steps, interval): + return (step + 1) % interval == 0 or (step + 1) == total_steps + + +def map_meta_modules(model, meta_model): + modules = {name: mod for name, mod in model.named_modules()} + meta_module_map = { + mod: modules[name] + for name, mod in meta_model.named_modules() + } + return meta_module_map + + +def build_llava_model(args, + config, + tokenizer, + world_size, + device='cpu', + dtype=torch.float32, + resize_emb=False): + + with torch.device(device): + _cfg = copy.deepcopy(config) + llava = LlavaForConditionalGeneration(_cfg) + + # llava has not loaded the pre-trained parameters of llm and vit + if device != 'meta': + del llava.language_model + del llava.vision_tower + with LoadWoInit(): + llm = AutoModelForCausalLM.from_pretrained( + args.llm, config=_cfg.text_config) + vit = CLIPVisionModel.from_pretrained( + args.vit, config=_cfg.vision_config) + llava.language_model = llm + llava.vision_tower = vit + + llava.to(dtype) + + if resize_emb: + ori_emb_shape = llava.get_input_embeddings().weight.shape + llava.resize_token_embeddings(len(tokenizer)) + new_emb_shape = llava.get_input_embeddings().weight.shape + logger.info('Pad the parameters of `embbedings` and `output` from ' + f'shape {ori_emb_shape} to shape {new_emb_shape}') + + if args.freeze_llm or args.llm_use_lora: + llava.language_model.requires_grad_(False) + if world_size > 1: + llava.language_model.to(dtype) + + if args.freeze_vit or args.vit_use_lora: + llava.vision_tower.requires_grad_(False) + if world_size > 1: + llava.vision_tower.to(dtype) + + if args.llm_use_lora: + llm = llava.language_model + if args.llm_lora_targets is None: + llm_cls = llm.__class__.__name__ + args.llm_lora_targets = LORA_TARGET_MAP[llm_cls] + llm_lora_cfg = LoraConfig( + target_modules=args.llm_lora_targets, + r=args.llm_lora_r, + lora_alpha=args.llm_lora_alpha, + lora_dropout=args.llm_lora_dropout, + bias=args.llm_lora_bias, + task_type='CAUSAL_LM') + lora_llm = get_peft_model(llm, llm_lora_cfg) + llava.language_model = lora_llm + + if args.vit_use_lora: + vit = llava.vision_tower + if args.vit_lora_targets is None: + vit_cls = vit.__class__.__name__ + args.vit_lora_targets = LORA_TARGET_MAP[vit_cls] + vit_lora_cfg = LoraConfig( + target_modules=args.vit_lora_targets, + r=args.vit_lora_r, + lora_alpha=args.vit_lora_alpha, + lora_dropout=args.vit_lora_dropout, + bias=args.vit_lora_bias, + ) + llava.vision_tower = get_peft_model(vit, vit_lora_cfg) + + return llava + + +# @logger.catch +def llava_pretrain(args): + ########################################################################### + # 1. 
Environment # + ########################################################################### + if args.llm_use_lora: + args.freeze_llm = True + + if args.vit_use_lora: + args.freeze_vit = True + + dist_launcher = infer_launcher() + init_dist(dist_launcher) + set_random_seed(args.seed) + + world_size = int(os.environ['WORLD_SIZE']) + dp_size = world_size + + if args.global_batch_size < dp_size or args.global_batch_size % dp_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the world_size{world_size}.') + + if (args.global_batch_size / dp_size) % args.mirco_batch_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the world_size{world_size}*' + f'`mirco_batch_size`({args.mirco_batch_size})') + + # During data packing, it is essential to tokenize the data in + # advance, cache the tokenized data, so that it can be quickly + # loaded for the second training without the need to re-tokenize. + if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir): + if len(os.listdir(args.dset_cache_dir)): + logger.warning(f'`{args.dset_cache_dir}` is not an empty ' + 'folder, which may lead to inaccurate ' + 'cache results.') + + device_mesh = init_device_mesh( + 'cuda', (dp_size, ), mesh_dim_names=('dp', )) + + dp_mesh = device_mesh['dp'] + + rank = dp_mesh.get_local_rank() + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + + objects = [timestamp] + dist.broadcast_object_list(objects, src=0) + timestamp = objects[0] + + args.work_dir = os.path.join(args.work_dir, timestamp) + mkdir_or_exist(args.work_dir) + + log_file = os.path.join(args.work_dir, f'rank{rank}.log') + + # Change the log format printed in the terminal + lvl = 'DEBUG' if args.debug else 'INFO' + logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug)) + # Change the format saved in the log file + logger.add(log_file, format=log_format(rank), backtrace=True, catch=True) + + logger.info(args) + if rank == 0: + env = collect_env() + import transformers + + import xtuner + env['Transformers'] = transformers.__version__ + env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}' + runtime_env = OrderedDict() + runtime_env.update(env) + runtime_env['Seed'] = args.seed + runtime_env['World Size'] = world_size + runtime_env['Distributed launcher'] = dist_launcher + + runtime_env_info = '\n ' + '\n '.join( + f'{k}: {v}' for k, v in runtime_env.items()) + dash_line = '-' * 60 + logger.info('\n' + dash_line + '\nRuntime environment:' + + runtime_env_info + '\n' + dash_line + '\n') + + shutil.copy(__file__, args.work_dir) + + # ------------------- Environment End ------------------------------ # + + ########################################################################### + # 2. 
Dataset & Dataloader # + ########################################################################### + + start_load_data_t = time.time() + + chat_template = CHAT_TEMPLATE_MAP[args.chat_template] + + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer if args.tokenizer else args.llm, + trust_remote_code=True, + padding_side='right') + + register_remote_code() + _text_config = AutoConfig.from_pretrained(args.llm) + _vision_config = AutoConfig.from_pretrained(args.vit).vision_config + _img_processor = AutoProcessor.from_pretrained(args.vit).image_processor + processor = LlavaProcessor(_img_processor, tokenizer) + img_processor = processor.image_processor + + _crop_size = img_processor.crop_size + patch_size = _vision_config.patch_size + img_size = (_crop_size['height'], _crop_size['width']) + per_img_tokens = (img_size[0] // patch_size) * (img_size[1] // patch_size) + + img_token = chat_template.image_token + need_resize_emb = False + if len(tokenizer.encode(img_token, add_special_tokens=False)) > 1: + tokenizer.add_tokens([img_token], special_tokens=True) + img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0] + logger.info(f'[Tokenizer] Added a new token `{img_token}`, ' + f'token id is {img_token_id}, the new vocab size is ' + f'{len(tokenizer)}') + + _llm_vocab_size = _text_config.vocab_size + + if _llm_vocab_size < len(tokenizer): + need_resize_emb = True + else: + img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0] + + if args.dset_from_cache: + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForLlava.from_cache, + image_processor=img_processor, + max_length=args.max_length) + else: + init_fn = partial( + LlavaTokenizedDataset.from_cache, + image_processor=img_processor, + max_length=args.max_length) + _datasets = load_from_cache(args.dset_cache_dir, init_fn) + dist.barrier() + else: + dset_infos = load(args.datasets) + + sample_ratios = [] + annotations = [] + init_fns = [] + tokenize_fns = [] + for _, info in dset_infos.items(): + if 'format' in info: + dset_format = info['format'] + else: + dset_format = 'llava' + + if 'image_dir' in info: + image_dir = info['image_dir'] + else: + image_dir = None + + # If your data format is not in `SUPPORT_DATA_FORMATS`, you should + # redefine a `tokenize_fn`, defining how to convert a piece of raw + # data into tokenized data. + # The tokenized data must include `input_ids`, `labels``, + # and `num_tokens`. + tokenize_fn = LlavaTokenizeFunction(tokenizer, chat_template, + per_img_tokens, image_dir, + dset_format) + + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForLlava, + image_processor=img_processor, + max_length=args.max_length) + elif args.dset_cache_dir: + init_fn = partial( + LlavaTokenizedDataset, + image_processor=img_processor, + max_length=args.max_length) + else: + init_fn = partial( + LlavaRawDataset, + image_processor=processor.image_processor, + tokenize_fn=tokenize_fn) + # Online tokenization is used when not using a pack dataset, + # saving startup time. 
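+ # `tokenize_fn` is reset to None below, so `load_datasets` receives None as
+ # this dataset's map_fn and skips the offline map pass; LlavaRawDataset keeps
+ # the tokenize_fn bound into `init_fn` and tokenizes samples online instead.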
+ tokenize_fn = None + + init_fns.append(init_fn) + tokenize_fns.append(tokenize_fn) + sample_ratios.append(info['sample_ratio']) + annotations.append(info['annotations']) + + _datasets = load_datasets( + paths=annotations, + sources='local', + cache_dir=args.dset_cache_dir, + file_types=args.dset_file_types, + sample_ratios=sample_ratios, + num_proc=args.num_proc, + map_fns=tokenize_fns, + init_fns=init_fns) + + if args.dset_pack_level and rank == 0: + # Only the tokenized datasets can count the number of tokens + total_tokens = sum(dset.total_tokens for dset in _datasets) + logger.debug(f'[Dataset] {total_tokens} tokens.') + + train_dataset = ConcatDataset(_datasets) + + if args.dset_pack_level and rank == 0: + ori_samples = sum([len(dset) for dset in _datasets]) + packed_samples = len(train_dataset) + logger.info(f'[Dataset] (Original) {ori_samples} samples.') + logger.info(f'[Dataset] (Packed) {packed_samples} samples.') + + pack_batch = is_flash_attn_2_available() + collator = LlavaCollator(pack_batch=pack_batch) + + if args.group_by_length: + sampler = LengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size) + else: + sampler = ParallelSampler( + train_dataset, dp_mesh, args.global_batch_size, shuffle=True) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.mirco_batch_size, + num_workers=args.num_workers, + sampler=sampler, + collate_fn=collator, + persistent_workers=args.num_workers > 0) + + if rank == 0: + logger.info(f'[Dataloader] {len(train_dataloader)} batches.') + _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)] + _first_batch = collator(_first_batch) + _decoded = tokenizer.batch_decode(_first_batch['input_ids']) + logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}') + logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}') + dist.barrier() + + load_data_cost_time = time.time() - start_load_data_t + logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s') + # ------------------- Dataset & Dataloader End --------------------- # + + ########################################################################### + # 3. FSDP # + ########################################################################### + + start_model_t = time.time() + + if args.dtype == 'auto': + args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16' + + if args.dtype == 'fp16': + dtype = torch.float16 + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype) + scaler = ShardedGradScaler() + elif args.dtype == 'bf16': + if torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype) + scaler = None + else: + raise RuntimeError('The device does not support `bf16`, ' + 'please set `dtype` to `fp16`.') + else: + raise RuntimeError('`dtype` only supports `fp16`,`bf16`, or `auto`, ' + f'but found {args.dtype}.') + + use_lora = (args.llm_use_lora or args.vit_use_lora) + if not use_lora: + autocast = nullcontext() + scaler = None + + + if is_flash_attn_2_available(): + _text_config.attn_implementation = 'flash_attention_2' + elif is_torch_sdpa_available(): + _text_config.attn_implementation = 'sdpa' + + _text_config.use_cache = False + + llava_config = EnhancedLlavaConfig( + _vision_config, _text_config, image_token_index=img_token_id) + + # model parameters must be in fp32. + # this ensures that all numerical values in the optimizer are in fp32. + # FSDP will use low precision during forward. 
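+ # Building on the 'meta' device allocates parameter metadata without real
+ # storage; the actual weights are only loaded on rank 0 (below), materialized
+ # through FSDP's `param_init_fn` (dp_lazy_init) and broadcast to the other
+ # ranks via `sync_module_states=True`.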
+ meta_llava = build_llava_model(args, llava_config, tokenizer, world_size, + 'meta', torch.float32, need_resize_emb) + + if pack_batch or args.dset_pack_level: + dispatch_modules(meta_llava) + + # Only load parameters on rank 0 to avoid each rank repeatedly loading the + # same model into the CPU, wasting memory + if rank == 0: + + llava = build_llava_model(args, llava_config, tokenizer, world_size, + 'cpu', dtype, need_resize_emb) + rank0_meta_llava = copy.deepcopy(meta_llava) + meta_llava_map = map_meta_modules(llava, meta_llava) + else: + meta_llava_map = None + + dist.barrier() + + param_init_fn = partial( + dp_lazy_init, module_map=meta_llava_map, dp_mesh=dp_mesh) + + policies = [layer_auto_wrap_policy] + if args.llm_use_lora or args.vit_use_lora: + policies.append(all_required_grad_wrap_policy) + + torch.cuda.reset_peak_memory_stats() + shard_llava = FSDP( + meta_llava, + device_mesh=dp_mesh, + auto_wrap_policy=partial(_or_policy, policies=policies), + mixed_precision=MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype), + device_id=torch.cuda.current_device(), + use_orig_params=True, + param_init_fn=param_init_fn, + sync_module_states=True, + ) + + max_memory = torch.cuda.max_memory_allocated() + logger.info('The peak GPU memory when building the FSDP model is ' + f'{max_memory/1024**3:.1f}GB.') + + if args.selective_recompute: + check_fn = partial( + checkpoint_check_fn, + target=RECOMPUTE_MODULES, + selective=args.selective_recompute) + apply_activation_checkpointing(shard_llava, check_fn=check_fn) + + fsdp_cost_time = time.time() - start_model_t + logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s') + # -------------------------- FSDP End ------------------------------ # + + ########################################################################### + # 4. Optimizer & Scheduler # + ########################################################################### + requried_grad_params = [ + param for param in shard_llava.parameters() if param.requires_grad + ] + optimizer = AdamW( + requried_grad_params, lr=args.lr, weight_decay=args.wd, fused=True) + + global_batch_size = args.global_batch_size + mirco_batch_size = args.mirco_batch_size + + # `iter` means once forward+backward + # `step` means once optimizer step + # `iters_per_step` means gradient accumulative counts + iters_per_step = global_batch_size // mirco_batch_size // dp_size + iters_per_epoch = len(train_dataloader) + steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step) + + total_epochs = args.epochs + total_steps = steps_per_epoch * total_epochs + + if args.checkpoint_interval == -1: + checkpoint_interval = total_steps + elif args.checkpoint_interval < 1: + checkpoint_interval = int(total_steps * args.checkpoint_interval) + else: + checkpoint_interval = int(args.checkpoint_interval) + + warmup_steps = int(args.warmup_ratio * total_steps) + + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + warmup_scheduler = LambdaLR(optimizer, warmup_fn) + + cosine_scheduler = CosineAnnealingLR( + optimizer, T_max=total_steps - warmup_steps, eta_min=0) + + start_step = 0 + + # ---------------- Optimizer & Scheduler End ----------------------- # + + ########################################################################### + # 5. 
Training # + ########################################################################### + + start_train_t = time.time() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Train] Begin Train Loop. The current GPU memory is ' + f'{(max_memory / 1024**3):.1f}GB') + for step in range(start_step, total_steps): + + epoch = step // steps_per_epoch + epoch_inner_step = step % steps_per_epoch + if epoch_inner_step == 0 or step == start_step: + # For the first step of each epoch, the data order needs to be + # readjusted. + # Or after resuming, for the first step, the dataloader needs to + # be adjusted to the position before resume. + # train_dataloader.sampler.set_epoch(epoch, inner_step) + # train_dataloader.sampler.set_epoch(epoch, epoch_inner_step) + train_dataloader.sampler.set_epoch(epoch) + data_iterator = iter(train_dataloader) + + if step < warmup_steps: + warmup_scheduler.step() + cur_lr = warmup_scheduler.get_lr()[0] + else: + cosine_scheduler.step() + cur_lr = cosine_scheduler.get_lr()[0] + + torch.cuda.reset_peak_memory_stats() + + step_loss = 0 + step_data_time = 0 + step_start_t = time.time() + step_consumed_tokens = 0 + step_consumed_img_tokens = 0 + for _ in range(iters_per_step): + + _data_start_t = time.time() + data = next(data_iterator) + step_data_time += time.time() - _data_start_t + + input_ids = data['input_ids'].cuda() + pixel_values = data['pixel_values'] + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.cuda() + labels = data['labels'].cuda() + attention_mask = data['attention_mask'].cuda() + num_tokens = data['num_tokens'].cuda() + num_img_tokens = data['num_img_tokens'].cuda() + + packed_ctx = packed_sequence(num_tokens, enable=pack_batch) + + with packed_ctx: + with autocast if use_lora else nullcontext(): + outputs = shard_llava( + input_ids=input_ids, + labels=labels, + pixel_values=pixel_values, + attention_mask=attention_mask) + avg_iter_loss = outputs.loss / iters_per_step + + if scaler and use_lora: + scaler.scale(avg_iter_loss).backward() + else: + avg_iter_loss.backward() + + step_loss += avg_iter_loss.item() + step_consumed_tokens += num_tokens.sum() + step_consumed_img_tokens += num_img_tokens.sum() + + grad_norm = shard_llava.clip_grad_norm_(args.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + + step_text_tokens = step_consumed_tokens - step_consumed_img_tokens + step_img_tokens = step_consumed_img_tokens + step_time = time.time() - step_start_t + eta = step_time * (total_steps - step) + eta = timedelta(seconds=int(eta)) + tgs = int(step_consumed_tokens / step_time) + max_memory = torch.cuda.max_memory_allocated() + if is_interval(step, total_steps, args.log_interval): + logger.info( + f'[Train] (Epoch {epoch}) Step {step+1}/{total_steps} ' # noqa: E501 + f'lr: {cur_lr:.6f} loss: {step_loss:.3f} ' + f'grad_norm: {grad_norm:.2f} ' + f'max_memory: {(max_memory / 1024**3):.1f}GB ' + f'text_tokens: {step_text_tokens} ' + f'image_tokens: {step_img_tokens} ' + f'tgs: {tgs} data_time: {step_data_time:.2f}s ' + f'time: {step_time:.2f}s ' + f'eta: {eta}') + + if is_interval(step, total_steps, checkpoint_interval): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Checkpoint] Before saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + num_digits = len(str(abs(total_steps))) + work_dir = args.work_dir + ckpt_dir = os.path.join(work_dir, 
f'ckpt-{step:0{num_digits}}') + hf_dir = os.path.join(work_dir, f'hf-{step:0{num_digits}}') + _options = StateDictOptions(cpu_offload=True, full_state_dict=True) + + full_model_state_dict = get_model_state_dict( + shard_llava, options=_options) + if rank == 0: + saved_llava = copy.deepcopy(rank0_meta_llava) + saved_llava.to(dtype) + for name, param in full_model_state_dict.items(): + set_module_tensor_to_device(saved_llava, name, 'cpu', + param) + + if args.llm_use_lora: + merged_llm = saved_llava.language_model.merge_and_unload() + saved_llava.language_model = merged_llm + + if args.vit_use_lora: + merged_vit = saved_llava.vision_tower.merge_and_unload() + saved_llava.vision_tower = merged_vit + + saved_llava.save_pretrained(hf_dir) + tokenizer.save_pretrained(hf_dir) + processor.save_pretrained(hf_dir) + del saved_llava + + dist.barrier() + del full_model_state_dict + + if args.checkpoint_drop_optimizer: + logger.warning('[Checkpoint] The saved checkpoint cannot be ' + 'resumed. If you want to save a resumable ' + 'checkpoint, please remove ' + '`--checkpoint-drop-optimizer` ' + 'from the command.') + else: + # FSDP cannot be saved via torch.save + # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501 + _options = StateDictOptions( + cpu_offload=True, ignore_frozen_params=True) + (shard_model_state_dict, + shard_optimizer_state_dict) = get_state_dict( + shard_llava, optimizer, options=_options) + + state_dict = { + 'model': shard_model_state_dict, + 'optimizer': shard_optimizer_state_dict, + 'step': step, + 'total_steps': total_steps, + 'warmup_scheduler': warmup_scheduler.state_dict(), + 'cosine_scheduler': cosine_scheduler.state_dict() + } + + writer = dcp.FileSystemWriter(ckpt_dir) + mkdir_or_exist(ckpt_dir) + dcp.save(state_dict, writer) + + max_memory = torch.cuda.max_memory_allocated() + logger.info( + '[Checkpoint] During saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + train_cost_time = time.time() - start_train_t + logger.info(f'[Train] Cost {train_cost_time}s') + # ------------------------ Training End ---------------------------- # + + +if __name__ == '__main__': + + args = parse_args() + llava_pretrain(args) diff --git a/tools/llava/llava_sft.py b/tools/llava/llava_sft.py new file mode 100644 index 000000000..c207dabfc --- /dev/null +++ b/tools/llava/llava_sft.py @@ -0,0 +1,890 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
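+# Supervised fine-tuning of an HF-format LLaVA checkpoint (e.g. one produced by
+# llava_pretrain.py or convert_xtuner_weights_to_hf.py). Illustrative launch
+# (assuming torchrun; arguments are placeholders):
+#   torchrun --nproc-per-node 8 tools/llava/llava_sft.py \
+#       --llava <hf_llava_dir> --chat-template <template_name> \
+#       --datasets <datasets.json>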
+import argparse +import copy +import math +import os +import shutil +import sys +import time +from collections import OrderedDict +from contextlib import nullcontext +from datetime import datetime, timedelta +from functools import partial + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from accelerate.utils import set_module_tensor_to_device +from datasets import Dataset +from mmengine import load, mkdir_or_exist +from mmengine.dist import infer_launcher, init_dist +from mmengine.runner import set_random_seed +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env +from peft import LoraConfig, get_peft_model +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + apply_activation_checkpointing +from torch.distributed.checkpoint.state_dict import (StateDictOptions, + get_model_state_dict, + get_state_dict) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import _or_policy +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from torch.utils.data import ConcatDataset, DataLoader +from transformers import (AutoConfig, AutoProcessor) +from transformers.utils.import_utils import (is_flash_attn_2_available, + is_torch_sdpa_available) +from xtuner._lite.modelings.llava import LlavaForConditionalGeneration +from xtuner._lite import AutoTokenizer, get_logger +from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_modules, + packed_sequence) +from xtuner._lite.chat import CHAT_TEMPLATE_MAP +from xtuner._lite.datasets import (LlavaCollator, + LlavaRawDataset, LlavaTokenizedDataset, + LlavaTokenizeFunction, SoftPackerForLlava) +from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets, + load_from_cache) +from xtuner._lite.modelings import register_remote_code +from xtuner._lite.parallel import LengthGroupedSampler, ParallelSampler +from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit, + all_required_grad_wrap_policy, + checkpoint_check_fn, dp_lazy_init, + layer_auto_wrap_policy) + +logger = get_logger() + + +def log_format(rank, debug=False): + + formatter = f'[XTuner][RANK {rank}]' + formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]' + + if debug: + formatter += '[{name}:' + formatter += '{function}:' + formatter += '{line}]' + + formatter += ' {message}' + return formatter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + + model_args = parser.add_argument_group('model', 'Model Related Settings') + model_args.add_argument( + '--llava', help='repo id or local path of the model') + model_args.add_argument( + '-t', + '--tokenizer', + help=('repo id or local path of the tokenizer. ' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--chat-template', + choices=CHAT_TEMPLATE_MAP.keys(), + help=('repo id or local path of the tokenizer. 
' + 'Defaults to the same as `model`')) + model_args.add_argument( + '--freeze-llm', + action='store_true', + help="Not updating LLM's parameters") + model_args.add_argument( + '--freeze-vit', + action='store_true', + help="Not updating vit's parameters") + model_args.add_argument( + '--llm-use-lora', + action='store_true', + help='Apply the adapter to LLM.') + model_args.add_argument( + '--llm-lora-targets', + default=None, + nargs='*', + help='The names of the modules to apply the adapter to. ') + model_args.add_argument( + '--llm-lora-r', + default=64, + type=int, + help="Not updating vit's parameters") + model_args.add_argument( + '--llm-lora-alpha', + default=16, + type=int, + help='The alpha parameter for Lora scaling.') + model_args.add_argument( + '--llm-lora-dropout', + default=0.1, + type=float, + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--llm-lora-bias', + default='none', + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--vit-use-lora', + action='store_true', + help='Apply the adapter to Vit.') + model_args.add_argument( + '--vit-lora-targets', + default=None, + type=str, + help='The names of the modules to apply the adapter to. ') + model_args.add_argument( + '--vit-lora-r', + default=64, + type=int, + help="Not updating vit's parameters") + model_args.add_argument( + '--vit-lora-alpha', + default=16, + type=int, + help='The alpha parameter for vit Lora scaling.') + model_args.add_argument( + '--vit-lora-dropout', + default=0.1, + type=float, + help='The dropout probability for vit Lora layers.') + model_args.add_argument( + '--vit-lora-bias', + default='none', + help='The dropout probability for Lora layers.') + model_args.add_argument( + '--dtype', + default='auto', + choices=['fp16', 'bf16', 'auto'], + help=("the dtype of the model forward. When set to 'auto', it will " + 'automatically determine whether bf16 is available, ' + 'prioritizing the use of bf16.')) + + model_args.add_argument( + '--selective-recompute', + default=1.0, + type=float, + help=('the ratio of re-computation for transforemer layers. ' + 'The maximum is 1; the larger the value, the less memory ' + 'required for training. The default is 1, meaning all layers ' + 'need to be re-computated.')) + + data_args = parser.add_argument_group('data', 'Dataset Related Settings') + data_args.add_argument( + '--datasets', + help=('repo id or local path or dir of the datasets. For repo ids, ' + 'the `dset-sources` needs to be appropriately set to ' + '`modelscope` or `huggingface`. For local dir, all json and ' + 'jsonl files will be loaded by default. The type of loaded ' + 'files can be controlled by setting `dset-file-type`')) + data_args.add_argument( + '--dset-file-types', + nargs='*', + default=LOAD_FN_MAP.keys(), + choices=LOAD_FN_MAP.keys(), + help='the file type that needs to be loaded') + data_args.add_argument( + '--dset-cache-dir', + help=('the cache dir of the loaded datasets. When the `datasets` is ' + 'set, the loaded datasets will be cached to this dir. If the ' + '`datasets` are not set, the cached dataset in this dir will be ' + 'loaded.')) + data_args.add_argument( + '--dset-from-cache', + action='store_true', + help=('Load data directly from `dset-cache-dir`. This can save time ' + 'on online tokenization, but if the tokenizer changed, ' + 'recaching is needed.')) + data_args.add_argument( + '--dset-pack-level', + choices=['soft'], + help=('the level of data packing. 
When `hard`, multiple data will be ' + 'packed to `max_length`, potentially causing some data to be ' + 'truncated, and the length of the packed data will always ' + 'be `max_length`; When `soft`, it will pack multiple data ' + 'into nearly `max_length` without truncating the data.')) + data_args.add_argument('--group-by-length', action='store_true') + data_args.add_argument( + '--max-length', + type=int, + default=2048, + help=('the maximum length of each piece of data, any excess will be ' + 'truncated.')) + data_args.add_argument( + '--num-workers', + type=int, + default=1, + help='how many subprocesses to use for data loading.') + data_args.add_argument( + '--num-proc', + type=int, + default=8, + help='how many subprocesses to use for data mapping.') + + optim_args = parser.add_argument_group('optim', 'Optim Related Settings') + optim_args.add_argument( + '--mirco-batch-size', + type=int, + default=1, + help='batch size for each forward + backward pass') + optim_args.add_argument( + '--global-batch-size', + type=int, + default=16, + help='batch size for each parameter update') + + optim_args.add_argument( + '--lr', default=4e-5, type=float, help='learning rate.') + optim_args.add_argument( + '--wd', default=0.01, type=float, help='weight decay.') + optim_args.add_argument( + '--max-grad-norm', default=1, type=float, help='gradient clipping') + optim_args.add_argument( + '-e', '--epochs', default=1, type=int, help='total training epochs.') + optim_args.add_argument( + '--warmup-ratio', + default=0.03, + type=float, + help=('the proportion of training steps for learning rate warm-up in ' + 'relation to the total training steps.')) + + parser.add_argument('-c', '--config', default=None) + parser.add_argument( + '--work-dir', + default='work_dirs', + help='the dir to save logs and checkpoints') + parser.add_argument( + '--checkpoint-interval', + default=0.25, + type=float, + help=('how many steps to save a checkpoint; it can be a floating ' + 'point number less than 1, or an integer greater than or equal ' + "to 1. When it's a floating point, it will be multiplied by the " + 'total number of training steps.')) + parser.add_argument( + '--checkpoint-drop-optimizer', + action='store_true', + help=('only model parameters are saved when saving a checkpoint. 
' + 'This can significantly reduce the size of checkpoint files, ' + 'but the saved checkpoints cannot be resumed.')) + parser.add_argument( + '--log-interval', default=1, type=int, help='log interval') + parser.add_argument( + '--resume', + type=str, + default=None, + help='specify checkpoint path to be resumed from.') + parser.add_argument( + '--seed', type=int, default=0, help='random seed for the training') + parser.add_argument( + '--debug', action='store_true', help='Set logger level to `DEBUG`') + args = parser.parse_args() + return args + + +def is_interval(step, total_steps, interval): + return (step + 1) % interval == 0 or (step + 1) == total_steps + + +def map_meta_modules(model, meta_model): + modules = {name: mod for name, mod in model.named_modules()} + meta_module_map = { + mod: modules[name] + for name, mod in meta_model.named_modules() + } + return meta_module_map + + +def build_llava_model(args, config, world_size, dtype=torch.float32): + _cfg = copy.deepcopy(config) + + with LoadWoInit(): + llava = LlavaForConditionalGeneration.from_pretrained( + args.llava, config=_cfg) + + llava.to(dtype) + + if args.freeze_llm or args.llm_use_lora: + llava.language_model.requires_grad_(False) + if world_size > 1: + llava.language_model.to(dtype) + + if args.freeze_vit or args.vit_use_lora: + llava.vision_tower.requires_grad_(False) + if world_size > 1: + llava.vision_tower.to(dtype) + + if args.llm_use_lora: + llm = llava.language_model + if args.llm_lora_targets is None: + llm_cls = llm.__class__.__name__ + args.llm_lora_targets = LORA_TARGET_MAP[llm_cls] + llm_lora_cfg = LoraConfig( + target_modules=args.llm_lora_targets, + r=args.llm_lora_r, + lora_alpha=args.llm_lora_alpha, + lora_dropout=args.llm_lora_dropout, + bias=args.llm_lora_bias, + task_type='CAUSAL_LM') + lora_llm = get_peft_model(llm, llm_lora_cfg) + llava.language_model = lora_llm + + if args.vit_use_lora: + vit = llava.vision_tower + if args.vit_lora_targets is None: + vit_cls = vit.__class__.__name__ + args.vit_lora_targets = LORA_TARGET_MAP[vit_cls] + vit_lora_cfg = LoraConfig( + target_modules=args.vit_lora_targets, + r=args.vit_lora_r, + lora_alpha=args.vit_lora_alpha, + lora_dropout=args.vit_lora_dropout, + bias=args.vit_lora_bias, + ) + llava.vision_tower = get_peft_model(vit, vit_lora_cfg) + + return llava + + +# @logger.catch +def llava_sft(args): + ########################################################################### + # 1. Environment # + ########################################################################### + if args.llm_use_lora: + args.freeze_llm = True + + if args.vit_use_lora: + args.freeze_vit = True + + dist_launcher = infer_launcher() + init_dist(dist_launcher) + set_random_seed(args.seed) + + world_size = int(os.environ['WORLD_SIZE']) + dp_size = world_size + + if args.global_batch_size < dp_size or args.global_batch_size % dp_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the world_size{world_size}.') + + if (args.global_batch_size / dp_size) % args.mirco_batch_size: + raise ValueError(f'The `global_batch_size`({args.global_batch_size}) ' + f'should be divisible by the world_size{world_size}*' + f'`mirco_batch_size`({args.mirco_batch_size})') + + # During data packing, it is essential to tokenize the data in + # advance, cache the tokenized data, so that it can be quickly + # loaded for the second training without the need to re-tokenize. 
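+ # An illustrative two-stage workflow (paths are placeholders): a first run
+ # passes `--datasets data.json --dset-cache-dir ./sft_cache` to tokenize and
+ # cache the data, and later runs pass `--dset-from-cache --dset-cache-dir
+ # ./sft_cache` to reuse the cache without re-tokenizing.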
+ if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir): + if len(os.listdir(args.dset_cache_dir)): + logger.warning(f'`{args.dset_cache_dir}` is not an empty ' + 'folder, which may lead to inaccurate ' + 'cache results.') + + device_mesh = init_device_mesh( + 'cuda', (dp_size, ), mesh_dim_names=('dp', )) + + dp_mesh = device_mesh['dp'] + + rank = dp_mesh.get_local_rank() + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + + objects = [timestamp] + dist.broadcast_object_list(objects, src=0) + timestamp = objects[0] + + args.work_dir = os.path.join(args.work_dir, timestamp) + mkdir_or_exist(args.work_dir) + + log_file = os.path.join(args.work_dir, f'rank{rank}.log') + + # Change the log format printed in the terminal + lvl = 'DEBUG' if args.debug else 'INFO' + logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug)) + # Change the format saved in the log file + logger.add(log_file, format=log_format(rank), backtrace=True, catch=True) + + logger.info(args) + if rank == 0: + env = collect_env() + import transformers + + import xtuner + env['Transformers'] = transformers.__version__ + env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}' + runtime_env = OrderedDict() + runtime_env.update(env) + runtime_env['Seed'] = args.seed + runtime_env['World Size'] = world_size + runtime_env['Distributed launcher'] = dist_launcher + + runtime_env_info = '\n ' + '\n '.join( + f'{k}: {v}' for k, v in runtime_env.items()) + dash_line = '-' * 60 + logger.info('\n' + dash_line + '\nRuntime environment:' + + runtime_env_info + '\n' + dash_line + '\n') + + shutil.copy(__file__, args.work_dir) + # ------------------- Environment End ------------------------------ # + + ########################################################################### + # 2. 
Dataset & Dataloader # + ########################################################################### + + start_load_data_t = time.time() + + chat_template = CHAT_TEMPLATE_MAP[args.chat_template] + + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer if args.tokenizer else args.llava, + trust_remote_code=True, + padding_side='right') + + llava_config = AutoConfig.from_pretrained(args.llava) + + processor = AutoProcessor.from_pretrained( + args.llava, trust_remote_code=True) + img_processor = processor.image_processor + + _crop_size = processor.image_processor.crop_size + patch_size = llava_config.vision_config.patch_size + img_size = (_crop_size['height'], _crop_size['width']) + per_img_tokens = (img_size[0] // patch_size) * (img_size[1] // patch_size) + + img_token = chat_template.image_token + assert len(tokenizer.convert_tokens_to_ids([img_token])) == 1 + + if args.dset_from_cache: + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForLlava.from_cache, + image_processor=img_processor, + max_length=args.max_length) + else: + init_fn = partial( + LlavaTokenizedDataset.from_cache, + image_processor=img_processor, + max_length=args.max_length) + _datasets = load_from_cache(args.dset_cache_dir, init_fn) + dist.barrier() + else: + dset_infos = load(args.datasets) + + sample_ratios = [] + annotations = [] + init_fns = [] + tokenize_fns = [] + for _, info in dset_infos.items(): + if 'format' in info: + dset_format = info['format'] + else: + dset_format = 'llava' + + if 'image_dir' in info: + image_dir = info['image_dir'] + else: + image_dir = None + + # If your data format is not in `SUPPORT_DATA_FORMATS`, you should + # redefine a `tokenize_fn`, defining how to convert a piece of raw + # data into tokenized data. + # The tokenized data must include `input_ids`, `labels``, + # and `num_tokens`. 
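+            # A minimal sketch of such a custom `tokenize_fn` (hypothetical and
+            # for illustration only; the real conversion depends on your raw
+            # data format):
+            #
+            #   def my_tokenize_fn(raw_item):
+            #       ids = tokenizer.encode(raw_item['text'])
+            #       return {'input_ids': ids,
+            #               'labels': list(ids),
+            #               'num_tokens': len(ids)}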
+ tokenize_fn = LlavaTokenizeFunction(tokenizer, chat_template, + per_img_tokens, image_dir, + dset_format) + + if args.dset_pack_level == 'soft': + init_fn = partial( + SoftPackerForLlava, + image_processor=img_processor, + max_length=args.max_length) + else: + init_fn = partial( + LlavaTokenizedDataset, + image_processor=img_processor, + max_length=args.max_length) + + + init_fns.append(init_fn) + tokenize_fns.append(tokenize_fn) + sample_ratios.append(info['sample_ratio']) + annotations.append(info['annotations']) + + _datasets = load_datasets( + paths=annotations, + sources='local', + cache_dir=args.dset_cache_dir, + file_types=args.dset_file_types, + sample_ratios=sample_ratios, + num_proc=args.num_proc, + map_fns=tokenize_fns, + init_fns=init_fns) + + if args.dset_pack_level and rank == 0: + # Only the tokenized datasets can count the number of tokens + total_tokens = sum(dset.total_tokens for dset in _datasets) + logger.debug(f'[Dataset] {total_tokens} tokens.') + + train_dataset = ConcatDataset(_datasets) + + if args.dset_pack_level and rank == 0: + ori_samples = sum([dset.num_samples for dset in _datasets]) + packed_samples = len(train_dataset) + logger.info(f'[Dataset] (Original) {ori_samples} samples.') + logger.info(f'[Dataset] (Packed) {packed_samples} samples.') + + pack_batch = is_flash_attn_2_available() + collator = LlavaCollator(pack_batch=pack_batch) + + if args.group_by_length: + sampler = LengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size) + else: + sampler = ParallelSampler( + train_dataset, dp_mesh, args.global_batch_size, shuffle=False) + + dist.barrier() + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.mirco_batch_size, + num_workers=args.num_workers, + sampler=sampler, + collate_fn=collator, + persistent_workers=args.num_workers > 0) + + if rank == 0: + logger.info(f'[Dataloader] {len(train_dataloader)} batches.') + _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)] + _first_batch = collator(_first_batch) + _decoded = tokenizer.batch_decode(_first_batch['input_ids']) + logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}') + logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}') + dist.barrier() + + load_data_cost_time = time.time() - start_load_data_t + logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s') + # ------------------- Dataset & Dataloader End --------------------- # + + ########################################################################### + # 3. 
FSDP # + ########################################################################### + + start_model_t = time.time() + + if args.dtype == 'auto': + args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16' + + if args.dtype == 'fp16': + dtype = torch.float16 + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype) + scaler = ShardedGradScaler() + elif args.dtype == 'bf16': + if torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype) + scaler = None + else: + raise RuntimeError('The device does not support `bf16`, ' + 'please set `dtype` to `fp16`.') + else: + raise RuntimeError('`dtype` only supports `fp16`,`bf16`, or `auto`, ' + f'but found {args.dtype}.') + + use_lora = args.llm_use_lora or args.vit_use_lora + if not use_lora: + autocast = nullcontext() + scaler = None + + if is_flash_attn_2_available(): + llava_config.text_config.attn_implementation = 'flash_attention_2' + elif is_torch_sdpa_available(): + llava_config.text_config.attn_implementation = 'sdpa' + llava_config.text_config.use_cache = False + + with torch.device('meta'): + # Ensure all numerical values in the optimizer are fp32. + # FSDP will use low precision during forward. + meta_llava = build_llava_model(args, llava_config, world_size, + torch.float32) + + if pack_batch or args.dset_pack_level: + dispatch_modules(meta_llava) + + # Only load parameters on rank 0 to avoid each rank repeatedly loading the + # same model into the CPU, wasting memory + if rank == 0: + with torch.device('cpu'): + llava = build_llava_model(args, llava_config, world_size, dtype) + rank0_meta_llava = copy.deepcopy(meta_llava) + meta_llava_map = map_meta_modules(llava, meta_llava) + else: + meta_llava_map = None + + dist.barrier() + + param_init_fn = partial( + dp_lazy_init, module_map=meta_llava_map, dp_mesh=dp_mesh) + + policies = [layer_auto_wrap_policy] + if args.llm_use_lora or args.vit_use_lora: + policies.append(all_required_grad_wrap_policy) + + torch.cuda.reset_peak_memory_stats() + shard_llava = FSDP( + meta_llava, + device_mesh=dp_mesh, + auto_wrap_policy=partial(_or_policy, policies=policies), + mixed_precision=MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype), + device_id=torch.cuda.current_device(), + use_orig_params=True, + param_init_fn=param_init_fn, + sync_module_states=True, + ) + + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Model] The peak GPU memory when building the FSDP model is ' + f'{max_memory/1024**3:.1f}GB.') + + if args.selective_recompute: + check_fn = partial( + checkpoint_check_fn, + target=RECOMPUTE_MODULES, + selective=args.selective_recompute) + apply_activation_checkpointing(shard_llava, check_fn=check_fn) + + fsdp_cost_time = time.time() - start_model_t + logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s') + # -------------------------- FSDP End ------------------------------ # + + ########################################################################### + # 4. 
Optimizer & Scheduler # + ########################################################################### + requried_grad_params = [ + param for param in shard_llava.parameters() if param.requires_grad + ] + optimizer = AdamW( + requried_grad_params, lr=args.lr, weight_decay=args.wd, fused=True) + + global_batch_size = args.global_batch_size + mirco_batch_size = args.mirco_batch_size + + # `iter` means once forward+backward + # `step` means once optimizer step + # `per_step_iters` means gradient accumulative counts + per_step_iters = global_batch_size // mirco_batch_size // dp_size + per_epoch_iters = len(train_dataloader) + per_epoch_steps = math.ceil(per_epoch_iters / per_step_iters) + + total_epochs = args.epochs + total_steps = per_epoch_steps * total_epochs + + if args.checkpoint_interval == -1: + checkpoint_interval = total_steps + elif args.checkpoint_interval < 1: + checkpoint_interval = int(total_steps * args.checkpoint_interval) + else: + checkpoint_interval = int(args.checkpoint_interval) + + warmup_steps = int(args.warmup_ratio * total_steps) + + def warmup_fn(x): + return x / warmup_steps if x < warmup_steps else 1 + + warmup_scheduler = LambdaLR(optimizer, warmup_fn) + + cosine_scheduler = CosineAnnealingLR( + optimizer, T_max=total_steps - warmup_steps, eta_min=0) + + start_step = 0 + + # ---------------- Optimizer & Scheduler End ----------------------- # + + ########################################################################### + # 5. Training # + ########################################################################### + + start_train_t = time.time() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Train] Begin Train Loop. The current GPU memory is ' + f'{(max_memory / 1024**3):.1f}GB') + for step in range(start_step, total_steps): + + epoch = step // per_epoch_steps + epoch_inner_step = step % per_epoch_steps + if epoch_inner_step == 0 or step == start_step: + # For the first step of each epoch, the data order needs to be + # readjusted. + # Or after resuming, for the first step, the dataloader needs to + # be adjusted to the position before resume. 
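+            # Worked example (illustrative numbers): with per_epoch_steps=100
+            # and start_step=150, the resumed step falls in epoch 1 with
+            # epoch_inner_step=50, so the sampler below is fast-forwarded to
+            # the 50th step of that epoch rather than starting it over.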
+ # train_dataloader.sampler.set_epoch(epoch, inner_step) + train_dataloader.sampler.set_epoch(epoch, epoch_inner_step) + data_iterator = iter(train_dataloader) + + if step < warmup_steps: + warmup_scheduler.step() + cur_lr = warmup_scheduler.get_lr()[0] + else: + cosine_scheduler.step() + cur_lr = cosine_scheduler.get_lr()[0] + + torch.cuda.reset_peak_memory_stats() + + step_loss = 0 + step_data_time = 0 + step_start_t = time.time() + step_consumed_tokens = 0 + step_consumed_img_tokens = 0 + for _ in range(per_step_iters): + + _data_start_t = time.time() + data = next(data_iterator) + step_data_time += time.time() - _data_start_t + + input_ids = data['input_ids'].cuda() + pixel_values = data['pixel_values'] + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.cuda() + labels = data['labels'].cuda() + attention_mask = data['attention_mask'].cuda() + num_tokens = data['num_tokens'].cuda() + num_img_tokens = data['num_img_tokens'].cuda() + + packed_ctx = packed_sequence(num_tokens, enable=pack_batch) + + with packed_ctx: + with autocast if use_lora else nullcontext(): + outputs = shard_llava( + input_ids=input_ids, + labels=labels, + pixel_values=pixel_values, + attention_mask=attention_mask) + avg_iter_loss = outputs.loss / per_step_iters + + if scaler and use_lora: + scaler.scale(avg_iter_loss).backward() + else: + avg_iter_loss.backward() + + step_loss += avg_iter_loss.item() + step_consumed_img_tokens += num_img_tokens.sum() + + if args.dset_pack_level == 'soft': + # During a soft pack process, the data with a length that is + # still smaller than the max length after packing, will be + # padded to the max length. The last element of num tokens + # represents the count of pad tokens. + step_consumed_tokens += num_tokens[:-1].sum() + else: + step_consumed_tokens += num_tokens.sum() + + grad_norm = shard_llava.clip_grad_norm_(args.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + + step_text_tokens = step_consumed_tokens - step_consumed_img_tokens + step_img_tokens = step_consumed_img_tokens + step_time = time.time() - step_start_t + eta = step_time * (total_steps - step) + eta = timedelta(seconds=int(eta)) + tgs = int(step_consumed_tokens / step_time) + max_memory = torch.cuda.max_memory_allocated() + if is_interval(step, total_steps, args.log_interval): + logger.info( + f'[Train] (Epoch {epoch}) Step {step+1}/{total_steps} ' # noqa: E501 + f'lr: {cur_lr:.6f} loss: {step_loss:.3f} ' + f'grad_norm: {grad_norm:.2f} ' + f'max_memory: {(max_memory / 1024**3):.1f}GB ' + f'text_tokens: {step_text_tokens} ' + f'image_tokens: {step_img_tokens} ' + f'tgs: {tgs} data_time: {step_data_time:.2f}s ' + f'time: {step_time:.2f}s ' + f'eta: {eta}') + + if is_interval(step, total_steps, checkpoint_interval): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + max_memory = torch.cuda.max_memory_allocated() + logger.info('[Checkpoint] Before saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + digits = len(str(abs(total_steps))) + work_dir = args.work_dir + + ckpt_id = f'{(step+1):0{digits}}-of-{total_steps:0{digits}}' + ckpt_dir = os.path.join(work_dir, f'ckpt-{ckpt_id}') + hf_dir = os.path.join(work_dir, f'hf-{ckpt_id}') + _options = StateDictOptions(cpu_offload=True, full_state_dict=True) + + full_model_state_dict = get_model_state_dict( + shard_llava, options=_options) + if rank == 0: + saved_llava = copy.deepcopy(rank0_meta_llava) + saved_llava.to(dtype) + for name, param in full_model_state_dict.items(): + 
set_module_tensor_to_device(saved_llava, name, 'cpu', + param) + + if args.llm_use_lora: + merged_llm = saved_llava.language_model.merge_and_unload() + saved_llava.language_model = merged_llm + + if args.vit_use_lora: + merged_vit = saved_llava.vision_tower.merge_and_unload() + saved_llava.vision_tower = merged_vit + + saved_llava.save_pretrained(hf_dir) + tokenizer.save_pretrained(hf_dir) + processor.save_pretrained(hf_dir) + del saved_llava + + dist.barrier() + del full_model_state_dict + + if args.checkpoint_drop_optimizer: + logger.warning('[Checkpoint] The saved checkpoint cannot be ' + 'resumed. If you want to save a resumable ' + 'checkpoint, please remove ' + '`--checkpoint-drop-optimizer` ' + 'from the command.') + else: + # FSDP cannot be saved via torch.save + # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501 + _options = StateDictOptions( + cpu_offload=True, ignore_frozen_params=True) + (shard_model_state_dict, + shard_optimizer_state_dict) = get_state_dict( + shard_llava, optimizer, options=_options) + + state_dict = { + 'model': shard_model_state_dict, + 'optimizer': shard_optimizer_state_dict, + 'step': step, + 'total_steps': total_steps, + 'warmup_scheduler': warmup_scheduler.state_dict(), + 'cosine_scheduler': cosine_scheduler.state_dict() + } + + writer = dcp.FileSystemWriter(ckpt_dir) + mkdir_or_exist(ckpt_dir) + dcp.save(state_dict, writer) + + max_memory = torch.cuda.max_memory_allocated() + logger.info( + '[Checkpoint] During saving checkpoint, the peak GPU ' + f'memory is {max_memory/1024**3:.1f}GB.') + + train_cost_time = time.time() - start_train_t + logger.info(f'[Train] Cost {train_cost_time}s') + # ------------------------ Training End ---------------------------- # + + +if __name__ == '__main__': + + args = parse_args() + llava_sft(args) diff --git a/tools/llava_pretrain.json b/tools/llava_pretrain.json new file mode 100644 index 000000000..727907d74 --- /dev/null +++ b/tools/llava_pretrain.json @@ -0,0 +1,7 @@ +{ + "llava_pretrain": { + "image_dir": "/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/LLaVA-Pretrain/images", + "annotations": "/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json", + "sample_ratio": 1.0 + } +} diff --git a/xtuner/_lite/__init__.py b/xtuner/_lite/__init__.py new file mode 100644 index 000000000..192fd51f8 --- /dev/null +++ b/xtuner/_lite/__init__.py @@ -0,0 +1,65 @@ +import sys +from loguru import logger +import os +import subprocess + +from .auto import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from .device import get_device, get_torch_device_module + + +_LOGGER = None + +def log_format(debug=False): + + formatter = '[XTuner][{time:YYYY-MM-DD HH:mm:ss}][{level}]' + + if debug: + formatter += '[{name}:' + formatter += '{function}:' + formatter += '{line}]' + + formatter += ' {message}' + return formatter + + +def get_logger(level="INFO"): + global _LOGGER + if _LOGGER is None: + # Remove the original logger in Python to prevent duplicate printing. 
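+        # `_LOGGER` acts as a module-level singleton: the default loguru sink
+        # is dropped once and a single configured sink is added, so repeated
+        # calls such as
+        #
+        #   from xtuner._lite import get_logger
+        #   log = get_logger()
+        #   log.info('ready')
+        #
+        # all share one handler instead of printing every message twice.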
+ logger.remove() + logger.add(sys.stderr, level=level, format=log_format(debug=level=="DEBUG")) + _LOGGER = logger + return _LOGGER + + +def get_repo_git_info(repo_path): + original_directory = os.getcwd() + os.chdir(repo_path) + + try: + branch = subprocess.check_output( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + stderr=subprocess.STDOUT + ).strip().decode('utf-8') + + commit_id = subprocess.check_output( + ['git', 'rev-parse', 'HEAD'], + stderr=subprocess.STDOUT + ).strip().decode('utf-8') + + remote_url = subprocess.check_output( + ['git', 'remote', 'get-url', 'origin'], + stderr=subprocess.STDOUT + ).strip().decode('utf-8') + + return branch, commit_id, remote_url + except subprocess.CalledProcessError as e: + return None, None, None + finally: + os.chdir(original_directory) + + +__all__ = [ + 'AutoConfig', 'AutoModelForCausalLM', 'AutoTokenizer', 'get_device', + 'get_torch_device_module' +] diff --git a/xtuner/_lite/accelerate/__init__.py b/xtuner/_lite/accelerate/__init__.py new file mode 100644 index 000000000..15316a252 --- /dev/null +++ b/xtuner/_lite/accelerate/__init__.py @@ -0,0 +1,14 @@ +from .dispatches import dispatch_hf_code +from .generate import contiguous_batching_generate +from .load import LoadWoInit +from .lora import LORA_TARGET_MAP +from .packed import pack_sequence, packed_sequence, unpack_sequence +from .utils import (lmdeploy_is_available, npu_is_available, liger_kernel_is_available, + profile_time_and_memory, varlen_attn_is_available) + +__all__ = [ + 'dispatch_hf_code', 'contiguous_batching_generate', 'LoadWoInit', + 'LORA_TARGET_MAP', 'pack_sequence', 'packed_sequence', 'unpack_sequence', + 'varlen_attn_is_available', 'lmdeploy_is_available', 'npu_is_available', + 'profile_time_and_memory' +] diff --git a/xtuner/_lite/accelerate/dispatches/__init__.py b/xtuner/_lite/accelerate/dispatches/__init__.py new file mode 100644 index 000000000..6f470c258 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .huggingface import dispatch_hf_code + +__all__ = ['dispatch_hf_code'] diff --git a/xtuner/_lite/accelerate/dispatches/_attention.py b/xtuner/_lite/accelerate/dispatches/_attention.py new file mode 100644 index 000000000..c723e549d --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/_attention.py @@ -0,0 +1,179 @@ +import math + +import numpy as np +import torch +from torch.nn import functional as F + +from xtuner._lite import get_device +from xtuner._lite.parallel import sequence_parallel_wrapper + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import (index_first_axis, pad_input, + unpad_input) +except ImportError: + pass + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def upad_qkv(query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + # Different from the origin version as sequence parallel change + # the number of attention heads. + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), + indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
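+        # e.g. (illustrative): with kv_seq_len=6 and query_length=2, a
+        # left-padded mask row [0, 0, 1, 1, 1, 1] keeps only its last two
+        # columns [1, 1], i.e. the key positions aligned with the new queries.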
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \ + unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +@sequence_parallel_wrapper +def flash_attn_wo_mask( + query_states, + key_states, + value_states, + dropout_p=0.0, + softmax_scale=None, + causal=True, + window_size=(-1, -1), # -1 means infinite context window +): + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size) + return attn_output + + +@sequence_parallel_wrapper +def flash_attn_w_mask( + query_states, # bs, q_len, nhead, h_dim + key_states, + value_states, + attention_mask, + causal=True, + dropout_p=0.0, + window_size=(-1, -1), # -1 means infinite context window +): + batch_size, q_len = query_states.shape[:2] + query_states, key_states, value_states, indices_q, \ + cu_seq_lens, max_seq_lens = upad_qkv( + query_states, key_states, value_states, attention_mask, q_len) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout_p, + causal=causal, + window_size=window_size) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + return attn_output + + +@sequence_parallel_wrapper +def varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_p=0., + causal=True, + window_size=(-1, -1), # -1 means infinite context window +): + q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( + 0, 1), value_states.flatten(0, 1) + + device = get_device() + if device == 'cuda': + attn_output = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cumulative_len, + cumulative_len, + max_seqlen, + max_seqlen, + dropout_p=dropout_p, + return_attn_probs=False, + causal=causal, + window_size=window_size) + attn_output = attn_output.unsqueeze(0) + elif device == 'npu': + import torch_npu + atten_mask_npu = torch.from_numpy( + np.triu(np.ones([max_seqlen, max_seqlen]), k=1)).bool().to(device) + head_num = q_unpad.shape[1] + attn_output = torch_npu.npu_fusion_attention( + q_unpad, + k_unpad, + v_unpad, + head_num, + pse=None, + padding_mask=None, + atten_mask=atten_mask_npu, + scale=1.0 / math.sqrt(q_unpad.shape[-1]), + keep_prob=1, + input_layout='TND', + actual_seq_qlen=tuple(cumulative_len[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cumulative_len[1:].cpu().numpy().tolist()), + pre_tockens=2147483647, + next_tockens=0, + inner_precise=0)[0] + attn_output = attn_output.unsqueeze(0) + return attn_output diff --git a/xtuner/_lite/accelerate/dispatches/_fused/__init__.py b/xtuner/_lite/accelerate/dispatches/_fused/__init__.py new file mode 100644 index 000000000..60b38809b --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/_fused/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .layer_norm import layer_norm_forward +from .rms_norm import rms_norm_forward + +# from .rotary import apply_rotary_emb + +__all__ = ['rms_norm_forward', 'layer_norm_forward', 'apply_rotary_emb'] diff --git a/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py b/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py new file mode 100644 index 000000000..010ff07c5 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from torch.distributed._tensor import DTensor + +def layer_norm_forward(self, hidden_states): + + if isinstance(self.weight, DTensor): + weight = self.weight.full_tensor() + else: + weight = self.weight + + if isinstance(self.bias, DTensor): + bias = self.bias.full_tensor() + else: + bias = self.bias + + if isinstance(hidden_states, DTensor): + hidden_states = hidden_states.full_tensor() + else: + hidden_states = hidden_states + + return F.layer_norm( + hidden_states, self.normalized_shape, weight, bias, self.eps + ) \ No newline at end of file diff --git a/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py b/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py new file mode 100644 index 000000000..0ef772f9e --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from transformers.utils.import_utils import is_flash_attn_2_available +from torch.distributed._tensor import DTensor +from xtuner._lite.accelerate import lmdeploy_is_available, npu_is_available + + +def rms_norm_forward(self, hidden_states): + + from torch.distributed._functional_collectives import AsyncCollectiveTensor + if isinstance(hidden_states, AsyncCollectiveTensor): + hidden_states = hidden_states.wait() + if (hidden_states.device == torch.device('cpu') + or self.weight.device == torch.device('cpu')): + raise RuntimeError( + 'Can not use triton kernels on cpu. Please set `USE_TRITON_KERNEL`' + ' environment variable to 0 before training.') + + if isinstance(self.weight, DTensor): + weight = self.weight.full_tensor() + else: + weight = self.weight + + if lmdeploy_is_available() and not self.training: + from lmdeploy.pytorch.kernels import rms_norm + ret = rms_norm(hidden_states, weight, eps=self.variance_epsilon) + elif is_flash_attn_2_available(): + # from flash_attn.ops.triton.layer_norm import rms_norm_fn + try: + from flash_attn.ops.triton.layernorm import rms_norm_fn + except ImportError: + try: + from flash_attn.ops.triton.layer_norm import rms_norm_fn + except ImportError: + import flash_attn + raise ImportError(f'flash_attn version {flash_attn.__version__}') + ret = rms_norm_fn( + hidden_states, weight, None, eps=self.variance_epsilon) + + elif npu_is_available(): + import torch_npu + ret = torch_npu.npu_rms_norm( + hidden_states, weight, epsilon=self.variance_epsilon)[0] + return ret diff --git a/xtuner/_lite/accelerate/dispatches/_fused/rotary.py b/xtuner/_lite/accelerate/dispatches/_fused/rotary.py new file mode 100644 index 000000000..1e09c1662 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/_fused/rotary.py @@ -0,0 +1,327 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py # noqa:E501 +from typing import Optional, Union + +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + rotary_dim, + seqlen_ro, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + \ + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, + # then store to 1st and 2nd halves of OUT + X = X + ( + rm[:, None] * stride_x_seqlen + + rk_half[None, :] * stride_x_headdim) + # This is different from the official implementation as the shapes of + # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of + # (seqlen_ro, rotary_dim // 2). + COS = COS + (rm_cs[:, None] * rotary_dim + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_half[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & + (rk_half[None, :] < rotary_dim_half), + other=1.0).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & + (rk_half[None, :] < rotary_dim_half), + other=0.0).to(tl.float32) + x0 = tl.load( + X, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + ( + rm[:, None] * stride_out_seqlen + + rk_half[None, :] * stride_out_headdim) + tl.store( + OUT, + o0, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately + # since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and + # sin = SIN[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right + # outputs for the even and for the odd indices. 
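+        # Concretely (per rotary pair, illustrative): for an even lane i,
+        #   out_i     = x_i * cos - x_{i+1} * sin
+        # and for the following odd lane,
+        #   out_{i+1} = x_i * sin + x_{i+1} * cos,
+        # which is what the `tl.where` on `rk % 2` below selects between.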
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, BLOCK_K) // 2 + # This is different from the official implementation as the shapes of + # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of + # (seqlen_ro, rotary_dim // 2). + X0 = X + ( + rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + X1 = X + ( + rm[:, None] * stride_x_seqlen + + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & + (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & + (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load( + X0, + mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), + other=0.0).to(tl.float32) + x1 = tl.load( + X1, + mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), + other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + ( + rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store( + OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim) + sin: (seqlen_ro, rotary_dim) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, ('If cu_seqlens is passed in, ' + 'then max_seqlen must be passed') + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + # rotary_dim *= 2 + assert rotary_dim <= headdim, 'rotary_dim must be <= headdim' + assert headdim <= 256, 'Only support headdim <= 256' + assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen' + + assert ( + cos.dtype == sin.dtype + ), f'cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}' + assert (x.dtype == cos.dtype), ( + f'Input and cos/sin must have the same dtype, ' + f'got {x.dtype} and {cos.dtype}') + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch, ) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = (32 if rotary_dim <= 32 else + (64 if rotary_dim <= 64 else + (128 if rotary_dim <= 128 else 256))) + + def grid(META): + return (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads) + + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton + # (cpu tensor?) 
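+    # e.g. (illustrative): if `x` lives on cuda:1 while the current device is
+    # cuda:0, the kernel would be launched on cuda:0 with pointers it cannot
+    # access, producing the error quoted above.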
+ with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + rotary_dim, + seqlen_ro, + output.stride(0) + if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) + if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output + + +class ApplyRotaryEmb(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward( + cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. + # Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +apply_rotary_emb = ApplyRotaryEmb.apply diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py b/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py new file mode 100644 index 000000000..4fb4b5c36 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import types + +from xtuner._lite import get_logger + +logger = get_logger() + + +def _dispatch_forward_fn(module, dispatch_fn): + module.forward = types.MethodType(dispatch_fn, module) + + +def _dispatch_qwen2_attn_flash_forward(module): + assert module.__class__.__name__ in ['Qwen2FlashAttention2', 'Qwen2Attention', 'Qwen2SdpaAttention'] + from .qwen2 import qwen2_attn_flash_forward + from xtuner._lite.accelerate import varlen_attn_is_available + if varlen_attn_is_available(): + _dispatch_forward_fn(module, qwen2_attn_flash_forward) + return qwen2_attn_flash_forward.__name__ + +def _dispatch_qwen2_casual_forward(module): + assert module.__class__.__name__ in ['Qwen2ForCausalLM'] + from .qwen2 import qwen2_casual_forward + _dispatch_forward_fn(module, qwen2_casual_forward) + return qwen2_casual_forward.__name__ + + +def _dispatch_internlm2_varlen_attn_forward(module): + assert module.__class__.__name__ in ['InternLM2FlashAttention2', 'InternLM2Attention', 'InternLM2SdpaAttention'] + from .internlm2 import internlm2_varlen_attn_forward + from xtuner._lite.accelerate import varlen_attn_is_available + if varlen_attn_is_available(): + _dispatch_forward_fn(module, internlm2_varlen_attn_forward) + return internlm2_varlen_attn_forward.__name__ + +def _dispatch_internlm2_casual_forward(module): + assert module.__class__.__name__ in ['InternLM2ForCausalLM'] + from .internlm2 import internlm2_causal_forward + _dispatch_forward_fn(module, internlm2_causal_forward) + return internlm2_causal_forward.__name__ + + +def _dispatch_internlm3_varlen_self_attn_forward(module): + assert module.__class__.__name__ in ['InternLM3FlashSelfAttention2'] + from .internlm3 import internlm3_self_attn_forward + from xtuner._lite.accelerate import varlen_attn_is_available + if varlen_attn_is_available(): + _dispatch_forward_fn(module, internlm3_self_attn_forward) + return internlm3_self_attn_forward.__name__ + +def _dispatch_internlm3_varlen_cross_attn_forward(module): + assert module.__class__.__name__ in ['InternLM3FlashCrossAttention2'] + from .internlm3 import internlm3_cross_attn_forward + from xtuner._lite.accelerate import varlen_attn_is_available + if varlen_attn_is_available(): + _dispatch_forward_fn(module, internlm3_cross_attn_forward) + return internlm3_cross_attn_forward.__name__ + +def _dispatch_internlm3_cross_decoder_forward(module): + assert module.__class__.__name__ == 'InternLM3CrossDecoder' + from .internlm3 import internlm3_cross_decoder_forward + _dispatch_forward_fn(module, internlm3_cross_decoder_forward) + return internlm3_cross_decoder_forward.__name__ + + +def _dispatch_internlm2_reward_forward(module): + assert module.__class__.__name__ == 'InternLM2ForRewardModel' + from .internlm2 import internlm2_reward_forward + _dispatch_forward_fn(module, internlm2_reward_forward) + return internlm2_reward_forward.__name__ + + +# HACK +def _dispatch_qwen2_reward_forward(module): + assert module.__class__.__name__ == 'Qwen2ForRewardModel' + from .internlm2 import internlm2_reward_forward + _dispatch_forward_fn(module, internlm2_reward_forward) + return internlm2_reward_forward.__name__ + + +def _dispatch_clip_attn_forward(module): + assert module.__class__.__name__ == 'CLIPAttention' + from .clip import clip_flash_attn_forward + _dispatch_forward_fn(module, clip_flash_attn_forward) + return clip_flash_attn_forward.__name__ + + +def _dispatch_rms_norm_forward(module): + from .._fused import rms_norm_forward + _dispatch_forward_fn(module, rms_norm_forward) + return rms_norm_forward.__name__ + + +def 
_dispatch_internvl2_forward(module): + assert module.__class__.__name__ == 'InternVLChatModel' + from .internvl2 import internvl2_forward + _dispatch_forward_fn(module, internvl2_forward) + return internvl2_forward.__name__ + + +def _dispatch_llama_varlen_attn_forward(module): + assert module.__class__.__name__ == 'LlamaFlashAttention2' + from .llama import llama_flash_attn_forward + _dispatch_forward_fn(module, llama_flash_attn_forward) + return llama_flash_attn_forward.__name__ + + +def _dispatch_llama_casual_forward(module): + assert module.__class__.__name__ in ['LlamaForCausalLM'] + from .llama import llama_casual_forward + _dispatch_forward_fn(module, llama_casual_forward) + return llama_casual_forward.__name__ + + +def _dispatch_minicpmv_forward(module): + assert module.__class__.__name__ == 'MiniCPMV' + from .minicpmv import minicpmv_forward + _dispatch_forward_fn(module, minicpmv_forward) + return minicpmv_forward.__name__ + + +DISPATCH_MAP = { + 'Qwen2RMSNorm': _dispatch_rms_norm_forward, + 'Qwen2FlashAttention2': _dispatch_qwen2_attn_flash_forward, + 'Qwen2Attention': _dispatch_qwen2_attn_flash_forward, + 'Qwen2SdpaAttention': _dispatch_qwen2_attn_flash_forward, + 'Qwen2ForCausalLM': _dispatch_qwen2_casual_forward, + 'InternLM2Attention': _dispatch_internlm2_varlen_attn_forward, + 'InternLM2SdpaAttention': _dispatch_internlm2_varlen_attn_forward, + 'InternLM2FlashAttention2': _dispatch_internlm2_varlen_attn_forward, + 'InternLM2ForCausalLM': _dispatch_internlm2_casual_forward, + 'CLIPAttention': _dispatch_clip_attn_forward, + 'InternLM2ForRewardModel': _dispatch_internlm2_reward_forward, + 'Qwen2ForRewardModel': _dispatch_qwen2_reward_forward, + 'InternLM2RMSNorm': _dispatch_rms_norm_forward, + 'InternVLChatModel': _dispatch_internvl2_forward, # to support sp and liger + 'LlamaFlashAttention2': _dispatch_llama_varlen_attn_forward, + 'LlamaForCausalLM': _dispatch_llama_casual_forward, + 'LlamaRMSNorm': _dispatch_rms_norm_forward, + 'MiniCPMV': _dispatch_minicpmv_forward, # to support sp and liger +} + + +def dispatch_hf_code(model): + from xtuner._lite import get_logger + logger = get_logger() + + for name, module in model.named_modules(): + cls_name = module.__class__.__name__ + if cls_name in DISPATCH_MAP: + dispatched = DISPATCH_MAP[cls_name](module) + if dispatched is not None: + logger.debug( + f'Dispatch {name}({cls_name}) forward to `{dispatched}`') diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/clip.py b/xtuner/_lite/accelerate/dispatches/huggingface/clip.py new file mode 100644 index 000000000..ed99bf771 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/clip.py @@ -0,0 +1,98 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from transformers import CLIPVisionModel + +from .._attention import flash_attn_wo_mask + + +def clip_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel.""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states).view(bsz, tgt_len, + self.num_heads, -1) + key_states = self.k_proj(hidden_states).view(bsz, tgt_len, self.num_heads, + -1) + value_states = self.v_proj(hidden_states).view(bsz, tgt_len, + self.num_heads, -1) + 
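+    # At this point query/key/value are shaped (bsz, tgt_len, num_heads,
+    # head_dim), which is the layout `flash_attn_wo_mask` (and the underlying
+    # `flash_attn_func`) expects; the original eager-attention path below
+    # appears to be kept only for reference.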
+ # proj_shape = (bsz * self.num_heads, -1, self.head_dim) + # query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + # key_states = key_states.view(*proj_shape) + # value_states = value_states.view(*proj_shape) + + # src_len = key_states.size(1) + # attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + # raise ValueError( + # f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + # f" {attn_weights.size()}" + # ) + + # # apply the causal_attention_mask first + # if causal_attention_mask is not None: + # if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + # f" {causal_attention_mask.size()}" + # ) + # attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + # attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, tgt_len, src_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + # ) + # attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + # attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # if output_attentions: + # # this operation is a bit akward, but it's required to + # # make sure that attn_weights keeps its gradient. + # # In order to do so, attn_weights have to reshaped + # # twice and have to be reused in the following + # attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + # attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + # else: + # attn_weights_reshaped = None + + # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # attn_output = torch.bmm(attn_probs, value_states) + + # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + # raise ValueError( + # f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + # f" {attn_output.size()}" + # ) + + # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + # attn_output = attn_output.transpose(1, 2) + # attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + self.dropout if self.training else 0, + causal=causal_attention_mask is not None).view(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py b/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py new file mode 100644 index 000000000..fa81f0936 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py @@ -0,0 +1,677 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Optional, Tuple, Union +import inspect + +import torch +from einops import rearrange +from mmengine import MessageHub +from transformers.cache_utils import StaticCache, Cache +from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast + +from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available +from .._attention import flash_attn_wo_mask, varlen_flash_attn + + +class InternLM2RotaryEmbedding(torch.nn.Module): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=1000000, + device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / ( + base**(torch.arange(0, dim, 2).float().to(device) / dim)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + + def forward(self, x, seq_len): + # x: [bs, num_attention_heads, seq_len, head_size] + if (seq_len > self.max_seq_len_cached + or self.cos_cached.device != x.device + or self.cos_cached.dtype != x.dtype): + self.max_seq_len_cached = seq_len + assert self.inv_freq.dtype == torch.float32 + t = torch.arange( + self.max_seq_len_cached, + device=x.device, + dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos().to(x.dtype) + self.sin_cached = emb.sin().to(x.dtype) + return ( + self.cos_cached[:seq_len, ...], + self.sin_cached[:seq_len, ...], + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_old(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, + repeats=n_rep). 
+ + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) + to (batch, seqlen, num_attention_heads, head_dim)""" + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, + None, :].expand(batch, slen, + num_key_value_heads, n_rep, + head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, + head_dim) + + +def _internlm2_varlen_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501 + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with ' + '`attn_implementation==flash_attention_2` make sure to use `sdpa` ' + 'in the mean time, and open an issue at ' + 'https://github.com/huggingface/transformers') + + bsz, q_len, _ = hidden_states.size() + attn_context = MessageHub.get_instance('packed_sequence') + + position_ids = attn_context.get_info('position_ids') + sp_mesh = attn_context.get_info('sp_mesh') + assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}' + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + signature = inspect.signature(self.rotary_emb.forward) + if 'seq_len' in signature.parameters: + # old + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1) + query_states, key_states = apply_rotary_pos_emb_old(query_states, key_states, cos, sin, position_ids) + else: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.transpose(1, 
2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.wqkv.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + + cumulative_lengths = attn_context.get_info('cumulative_lengths') + if cumulative_lengths is not None and bsz == 1: + max_seqlen = attn_context.get_info('max_seqlen') + attn_output = varlen_flash_attn(query_states, key_states, value_states, + cumulative_lengths, max_seqlen,training=self.training, sp_mesh=sp_mesh) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=self.training, + sp_mesh=sp_mesh) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + # Due to the implementation of the PyTorch version of flash attention, + # even when the output_attentions flag is set to True, it is not possible + # to return the attn_weights. + return attn_output, None, past_key_value + + +def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite implementation of LlamaAttention.forward. + + Add continuous batching support. Add paged attention support. TP support. 
+ """ + from lmdeploy.pytorch.kernels import \ + apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy + from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd + attn_ctx = MessageHub.get_instance('paged_attention') + kv_seq_length = attn_ctx.get_info('kv_seq_length') + q_seq_length = attn_ctx.get_info('q_seq_length') + q_start_loc = attn_ctx.get_info('q_start_loc') + block_offsets = attn_ctx.get_info('block_offsets') + max_q_seq_length = attn_ctx.get_info('max_q_seq_length') + max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length') + position_ids = attn_ctx.get_info('position_ids') + + # position_ids + def __qkv_proj(hidden_states): + """qkv_proj.""" + # from torch.distributed import get_rank + # if get_rank() == 0: + # breakpoint() + # else: + # import time + # time.sleep(10000) + qkv_states = self.wqkv(hidden_states[0]).unsqueeze(0) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> (b q) h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = query_states.flatten(1, 2) + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + return query_states, key_states, value_states + + from lmdeploy.pytorch.kernels import \ + apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy + + def __rotary_emb_fn(query_states, key_states, value_states): + """rotary embedding func.""" + # breakpoint() + # query_states = query_states.unsqueeze(0).transpose(1, 2) + # key_states = key_states.unsqueeze(0).transpose(1, 2) + # value_states = value_states.unsqueeze(0).transpose(1, 2) + if self.layer_idx == 0: + + cos, sin = self.rotary_emb(value_states, position_ids) + attn_ctx.update_info('rotary_cos_sin', (cos, sin)) + else: + cos, sin = attn_ctx.get_info('rotary_cos_sin') + + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + # cos, sin) + # query_states = query_states.transpose(1, 2).squeeze(0) + # key_states = key_states.transpose(1, 2).squeeze(0) + # value_states = value_states.transpose(1, 2).squeeze(0) + query_states, key_states = apply_rotary_pos_emb_lmdeploy( + query_states, + key_states, + cos, + sin, + q_embed=query_states, + k_embed=key_states) + + return query_states, key_states, value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + + fill_kv_cache( + key_states, + value_states, + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + ) + + attn_output = query_states + paged_attention_fwd( + query_states, + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + # max_kv_seq_length=max_kv_seq_length, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + attn_output = self.wo(attn_output) + + return attn_output, None, past_key_value + + +def _flash_att_infer( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite implementation of LlamaAttention.forward. + + Add continuous batching support. 
Add paged attention support. TP support. + """ + from lmdeploy.pytorch.kernels import \ + apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy + from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd + attn_ctx = MessageHub.get_instance('paged_attention') + kv_seq_length = attn_ctx.get_info('kv_seq_length') + q_seq_length = attn_ctx.get_info('q_seq_length') + q_start_loc = attn_ctx.get_info('q_start_loc') + block_offsets = attn_ctx.get_info('block_offsets') + max_q_seq_length = attn_ctx.get_info('max_q_seq_length') + max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length') + cumulative_length = attn_ctx.get_info('cumulative_length') + is_prefilling = attn_ctx.get_info('is_prefilling') + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> (b q) h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = query_states.flatten(1, 2) + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + + from lmdeploy.pytorch.kernels import \ + apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy + + if self.layer_idx == 0: + + cos, sin = self.rotary_emb(value_states, position_ids) + attn_ctx.update_info('rotary_cos_sin', (cos, sin)) + else: + cos, sin = attn_ctx.get_info('rotary_cos_sin') + + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + # cos, sin) + # query_states = query_states.transpose(1, 2).squeeze(0) + # key_states = key_states.transpose(1, 2).squeeze(0) + # value_states = value_states.transpose(1, 2).squeeze(0) + query_states, key_states = apply_rotary_pos_emb_lmdeploy( + query_states, + key_states, + cos, + sin, + q_embed=query_states, + k_embed=key_states) + + fill_kv_cache( + key_states, + value_states, + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + ) + + # attn_output = query_states + + from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + if is_prefilling: + # breakpoint() + key_states = repeat_kv_bshd( + key_states.unsqueeze(0), self.num_key_value_groups).squeeze(0) + value_states = repeat_kv_bshd( + value_states.unsqueeze(0), self.num_key_value_groups).squeeze(0) + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cumulative_length, + cumulative_length, + max_q_seq_length, + max_kv_seq_length, + causal=True) + # attn_output = varlen_flash_attn(query_states, key_states, value_states, + # cumulative_length, max_q_seq_length) + else: + query_states = query_states.unsqueeze(1) + attn_output = flash_attn_with_kvcache( + query_states, + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + cache_seqlens=kv_seq_length, + block_table=block_offsets, + causal=True) + attn_output = attn_output.squeeze(1) + # paged_attention_fwd( + # query_states, + # past_key_value[self.layer_idx][0], + # past_key_value[self.layer_idx][1], + # attn_output, + # block_offsets, + # q_start_loc=q_start_loc, + # q_seqlens=q_seq_length, + # kv_seqlens=kv_seq_length, + # max_seqlen=max_q_seq_length, + # ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + attn_output = self.wo(attn_output) + + return attn_output, None, past_key_value + + +def internlm2_varlen_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + 
position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + + lmdeploy_ctx = MessageHub.get_instance('paged_attention') + + if lmdeploy_is_available() and len(lmdeploy_ctx.runtime_info) > 0: + + # return _contiguous_batching_forward_impl( + # self, hidden_states, position_ids, past_key_value) + return _flash_att_infer(self, hidden_states, position_ids, + past_key_value) + else: + return _internlm2_varlen_attn_forward(self, hidden_states, + attention_mask, position_ids, + past_key_value, + output_attentions, use_cache) + + +def internlm2_reward_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, SequenceClassifierOutputWithPast]: + """labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels + for computing the sequence classification/regression loss. + + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + reward_scores = self.v_head(hidden_states).squeeze(-1) + + loss = None + + # hidden_states = outputs[0] + # hidden_states = self.v_head(hidden_states) + # # get end reward token's score + # ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1,1) + + # reward_scores = torch.gather(hidden_states.squeeze(-1), 1, ends) + + loss = None + + # if not return_dict: + # ssoutput = (reward_scores,) + outputs[1:] + # return (loss,) + output if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=reward_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + + +def internlm2_causal_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + label_shifted: bool = False, +) -> Union[Tuple, 
CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + if labels is None: + logits = self.output(hidden_states) + else: + + if liger_kernel_is_available(): + # unable to return logits when using Liger Kernel + logits = None + + if label_shifted: + shift_hidden_states = hidden_states + shift_labels = labels + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_hidden_states.device) + + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + loss_fct = LigerFusedLinearCrossEntropyLoss() + loss = loss_fct(self.output.weight, shift_hidden_states, shift_labels, self.output.bias) + + else: + logits = self.output(hidden_states) + + if label_shifted: + shift_logits = logits + shift_labels = labels + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git 
a/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py b/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py new file mode 100644 index 000000000..eefc77991 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py @@ -0,0 +1,259 @@ +from typing import List, Optional, Tuple, Union + +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import CausalLMOutputWithPast +import torch.distributed as dist +from torch.distributed.nn.functional import all_gather +from mmengine.logging import MessageHub +import copy +from xtuner._lite.parallel.setup import get_sp_mesh +import math +import os +from xtuner._lite.parallel.sequence import split_for_sequence_parallel + +try: + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss +except ImportError: + LigerFusedLinearCrossEntropyLoss = None + + +def rescale_sp_loss(loss_per_sp_rank, + labels_per_sp_rank, + sp_mesh = None, + ignore_index=-100): + if sp_mesh is None: + sp_group = get_sp_mesh().get_group() + else: + sp_group = sp_mesh.get_group() + + if (sp_group is None) or (dist.get_world_size(sp_group) == 1): + return loss_per_sp_rank + + shift_labels = labels_per_sp_rank + active_tokens = (shift_labels != ignore_index).long().sum() + global_active_tokens = copy.deepcopy(active_tokens) + dist.all_reduce(global_active_tokens, group=sp_group) + loss_weight = active_tokens / global_active_tokens * dist.get_world_size( + group=sp_group) + + if active_tokens == 0: + # convert nan to 0 just for logging + loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank) + + return loss_per_sp_rank * loss_weight + + +def internvl2_forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_flags: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + sp_mesh = get_sp_mesh() + sp_size = sp_mesh.size() + if sp_size > 1: + sp_group = sp_mesh.get_group() + sp_rank = dist.get_rank(sp_group) + + no_split_input_ids = os.environ.get('NO_SPLIT_INPUT_IDS') + split_input_ids = not no_split_input_ids + if split_input_ids: + pad_id = 0 + orig_len_input_ids = input_ids.shape[1] + image_flags = image_flags.squeeze(-1) + assert input_ids.shape[0] == 1, 'batch size must be 1 for sequence parallel' + + if orig_len_input_ids % sp_size != 0: + max_inputs_len = math.ceil(orig_len_input_ids / sp_size) * sp_size + _temp = input_ids.new_full((1, max_inputs_len - orig_len_input_ids), pad_id) + input_ids_new = torch.cat([input_ids, _temp], dim=-1) + else: + input_ids_new = input_ids + input_ids_list = torch.split(input_ids_new, input_ids_new.shape[1] // sp_size, dim=-1) + input_ids_rank_pre = input_ids_list[sp_rank].contiguous() + input_embeds_rank_pre = self.language_model.get_input_embeddings()(input_ids_rank_pre).clone() + + input_embeds = all_gather(input_embeds_rank_pre, group=sp_group) + + input_embeds = torch.cat(input_embeds, dim=1) + input_embeds = input_embeds[:, :orig_len_input_ids] + else: + input_embeds = 
self.language_model.get_input_embeddings()(input_ids).clone() + + no_split_pixel_values = os.environ.get('NO_SPLIT_PIXEL_VALUES') + split_pixel_values = not no_split_pixel_values + if split_pixel_values: + orig_img_batch = pixel_values.shape[0] + if orig_img_batch % sp_size != 0: + max_inputs_len = math.ceil(orig_img_batch / sp_size) * sp_size + pad_img_batch = max_inputs_len - orig_img_batch + pad_pixel_values_ = pixel_values.new_zeros(pad_img_batch, 3, + pixel_values.shape[2], + pixel_values.shape[3]) + pixel_values = torch.cat([pixel_values, pad_pixel_values_], dim=0) + pixel_values = torch.split(pixel_values, len(pixel_values) // sp_size, dim=0) + pixel_values = pixel_values[sp_rank].contiguous() + + vit_embeds = self.extract_feature(pixel_values) + + vit_embeds = all_gather(vit_embeds, group=sp_group) + + vit_embeds = torch.cat(vit_embeds, dim=0)[:orig_img_batch] + else: + vit_embeds = self.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags == 1] + else: + image_flags = image_flags.squeeze(-1) + input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() + + vit_embeds = self.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags == 1] + + # vit_batch_size = pixel_values.shape[0] + + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + + # if torch.distributed.get_rank() == 0: + # print( + # f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}') + + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.img_context_token_id) + try: + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) + except Exception as e: + vit_embeds = vit_embeds.reshape(-1, C) + print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, ' + f'vit_embeds.shape={vit_embeds.shape}') + n_token = selected.sum() + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token] + + input_embeds = input_embeds.reshape(B, N, C) + + if sp_size > 1: + attn_context = MessageHub.get_instance('packed_sequence') + position_ids = attn_context.get_info('position_ids') + + is_ref_forward = attn_context.get_info('is_ref_forward') + + # TODO: phi3 attention + attn_context.update_info('global_position_ids', position_ids) + attention_mask = None + + if is_ref_forward is not None: + input_embeds = split_for_sequence_parallel( + input_embeds, dim=1, sp_mesh=sp_mesh) + if labels is not None: + labels = split_for_sequence_parallel( + labels, dim=1, sp_mesh=sp_mesh) + else: + if labels is not None: + assert position_ids.size(1) == input_embeds.shape[1] == labels.shape[1], \ + f'{position_ids.size(1)} {input_embeds.shape[1]} {labels.shape[1]}' + else: + assert position_ids.size(1) == input_embeds.shape[1], \ + f'{position_ids.size(1)} {input_embeds.shape[1]}' + + assert position_ids.size(1) % sp_size == 0 + # `dim` is 1 as the shape of tensor is (bs, seq_len) + position_ids = split_for_sequence_parallel( + position_ids, dim=1, sp_mesh=sp_mesh) + input_embeds = split_for_sequence_parallel( + input_embeds, dim=1, sp_mesh=sp_mesh) + if labels is not None: + labels = split_for_sequence_parallel( + labels, dim=1, sp_mesh=sp_mesh) + + attn_context.update_info('position_ids', position_ids) + + use_liger_kernel = os.environ.get('USE_LIGER_KERNEL') + if use_liger_kernel and labels is not None and self.training: + output_attentions = output_attentions if output_attentions is not None else self.language_model.config.output_attentions + 
output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.language_model.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.language_model.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.language_model.model( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.language_model.config.hidden_size) + shift_labels = shift_labels.view(-1) + + if LigerFusedLinearCrossEntropyLoss is None: + raise ImportError('LigerFusedLinearCrossEntropyLoss is not available, ' + 'please install liger-kernel by "pip install liger_kernel".') + lce = LigerFusedLinearCrossEntropyLoss() + if hasattr(self.language_model, 'lm_head'): + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + else: + loss = lce(self.language_model.output.weight, shift_hidden_states, shift_labels) + if sp_size > 1: + loss = rescale_sp_loss(loss, shift_labels, sp_mesh=sp_mesh) + logits = None + else: + outputs = self.language_model( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if sp_size > 1: + loss = rescale_sp_loss(loss, shift_labels, sp_mesh=sp_mesh) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/llama.py b/xtuner/_lite/accelerate/dispatches/huggingface/llama.py new file mode 100644 index 000000000..0630f6f58 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/llama.py @@ -0,0 +1,257 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
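+# Forward replacements for HuggingFace Llama, intended to be dispatched onto
+# the stock attention / causal-LM forwards: `llama_flash_attn_forward` runs
+# packed (varlen) flash attention driven by the `packed_sequence` MessageHub
+# context, and `llama_casual_forward` optionally computes the loss with
+# Liger's fused linear cross-entropy so full logits are never materialized.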
+from typing import List, Optional, Tuple, Union + +import torch +from mmengine import MessageHub +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available + +from .._attention import flash_attn_wo_mask, varlen_flash_attn + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def llama_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs # FlashAttentionKwargs in the future +) -> Tuple[torch.Tensor, Optional[torch.Tensor], +Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + attn_context = MessageHub.get_instance('packed_sequence') + sp_mesh = attn_context.get_info('sp_mesh') + + position_ids = attn_context.get_info('position_ids') + assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}' + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not 
None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + cumulative_lengths = attn_context.get_info('cumulative_lengths') + if cumulative_lengths is not None and bsz == 1: + max_seqlen = attn_context.get_info('max_seqlen') + attn_output = varlen_flash_attn(query_states, key_states, value_states, + cumulative_lengths, max_seqlen, dropout_p=dropout_rate, + training=self.training, sp_mesh=sp_mesh) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + dropout_p=dropout_rate, + training=self.training) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + + +def llama_casual_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + label_shifted = False, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + if labels is None: + loss = None + logits = self.lm_head(hidden_states) + else: + if liger_kernel_is_available(): + # unable to return logits when using Liger Kernel + logits = None + + if label_shifted: + shift_hidden_states = hidden_states + shift_labels = labels + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_hidden_states.device) + + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + loss_fct = LigerFusedLinearCrossEntropyLoss() + loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels, self.lm_head.bias) + + else: + logits = self.lm_head(hidden_states) + + if label_shifted: + shift_logits = logits + shift_labels = labels + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py b/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py new file mode 100644 index 000000000..5e326ea5f --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py @@ -0,0 +1,56 @@ +import os 
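+# Replacement forward for MiniCPM-V: when the `USE_LIGER_KERNEL` environment
+# variable is set, labels are provided and the model is training, the LLM
+# backbone (`self.llm.model`) is run directly and the loss is computed with
+# Liger's fused linear cross-entropy against `lm_head.weight`, so logits are
+# never materialized; otherwise the call falls through to `self.llm(...)`.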
+from transformers.modeling_outputs import CausalLMOutputWithPast + +try: + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss +except ImportError: + LigerFusedLinearCrossEntropyLoss = None + + +def minicpmv_forward(self, data, **kwargs): + vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) + use_liger_kernel = os.environ.get('USE_LIGER_KERNEL') + labels = data.get('labels') + if use_liger_kernel and labels is not None and self.training: + output_attentions = self.config.output_attentions + output_hidden_states = self.config.output_hidden_states + return_dict = self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.llm.model( + inputs_embeds=vllm_embedding, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.llm.config.hidden_size) + shift_labels = shift_labels.view(-1) + + if LigerFusedLinearCrossEntropyLoss is None: + raise ImportError('LigerFusedLinearCrossEntropyLoss is not available, ' + 'please install liger-kernel by "pip install liger_kernel".') + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.llm.lm_head.weight, shift_hidden_states, shift_labels) + if not return_dict: + output = (None,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + return self.llm( + input_ids=None, + inputs_embeds=vllm_embedding, + labels=labels, + **kwargs + ) diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py b/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py new file mode 100644 index 000000000..1b85e37f1 --- /dev/null +++ b/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py @@ -0,0 +1,455 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
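+# Forward replacements for HuggingFace Qwen2, intended to be dispatched onto
+# the stock attention / causal-LM forwards. `_qwen2_attn_varlen_forward`
+# reads packed-sequence metadata from the `packed_sequence` MessageHub
+# context, while `_qwen2_attn_contiguous_batching_forward` serves
+# lmdeploy-style paged inference via the `paged_attention` context.
+#
+# Rough training-side usage sketch (for illustration only; the dispatch step
+# that actually patches the HF classes is handled elsewhere in xtuner and is
+# not shown here):
+#
+#     from xtuner._lite.accelerate.packed import packed_sequence
+#
+#     # `num_tokens` holds the per-sample lengths of a packed
+#     # (batch-size-1) sequence of shape (1, sum(num_tokens)).
+#     with packed_sequence(num_tokens, enable=True):
+#         loss = model(input_ids=packed_ids, labels=packed_labels,
+#                      use_cache=False).loss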
+import inspect +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from mmengine import MessageHub +from transformers.cache_utils import Cache +from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb, + repeat_kv) +from transformers.modeling_outputs import CausalLMOutputWithPast +from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available +from .._attention import flash_attn_wo_mask, varlen_flash_attn + +SUPPORT_FLASH2 = False + +try: + from flash_attn import flash_attn_func + _flash_supports_window_size = 'window_size' in list( + inspect.signature(flash_attn_func).parameters) + SUPPORT_FLASH2 = True +except ImportError: + pass + + +def _qwen2_attn_varlen_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +): + is_training = self.training + + # assert is_training == (past_key_value is None) + + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37' + ' Please make sure use `attention_mask` instead.`') + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + bsz, q_len, _ = hidden_states.size() + + attn_context = MessageHub.get_instance('packed_sequence') + position_ids = attn_context.get_info('position_ids') + assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}' + sp_mesh = attn_context.get_info('sp_mesh') + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + 'The cache structure has changed since version v4.36. 
' + f'If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, ' + 'please make sure to initialize the attention class ' + 'with a layer index.') + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, + self.layer_idx) + + cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value + # `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if (getattr(self.config, 'sliding_window', None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + 'past key must have a shape of (`batch_size, num_heads, ' + 'self.config.sliding_window-1, head_dim`), got' + f' {past_key.shape}') + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, + torch.ones_like(attention_mask[:, -1:])], + dim=-1) + + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads for sequence parallel + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for + # training stability reasons, therefore the input hidden states gets + # silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
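+    # The target dtype is resolved in order of preference: the active
+    # autocast dtype, the `_pre_quantization_dtype` recorded on the config
+    # for quantized models, and finally the dtype of the `q_proj` weight.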
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # ----------------- flash attention forward ------------------------# + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and q_len != 1 + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, 'sliding_window', None) is not None + and kv_seq_len > self.config.sliding_window + and self.layer_idx < self.config.max_window_layers + and self.config.use_sliding_window) + + window_size = (self.config.sliding_window, + self.config.sliding_window) if use_sliding_windows else (-1, + -1) + + assert SUPPORT_FLASH2 + cumulative_lengths = attn_context.get_info('cumulative_lengths') + if cumulative_lengths is not None and SUPPORT_FLASH2 and bsz == 1: + max_seqlen = attn_context.get_info('max_seqlen') + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_lengths, + max_seqlen, + causal=causal, + dropout_p=dropout_rate, + window_size=window_size, + training=self.training, + sp_mesh=sp_mesh) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=causal, + dropout_p=dropout_rate, + window_size=window_size, + training=self.training, + sp_mesh=sp_mesh) + + # ---------------- flash attention forward end ------------------- # + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + + +def _qwen2_attn_contiguous_batching_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +): + + + from lmdeploy.pytorch.kernels import \ + apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy + from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd + attn_ctx = MessageHub.get_instance('paged_attention') + kv_seq_length = attn_ctx.get_info('kv_seq_length') + q_seq_length = attn_ctx.get_info('q_seq_length') + q_start_loc = attn_ctx.get_info('q_start_loc') + block_offsets = attn_ctx.get_info('block_offsets') + max_q_seq_length = attn_ctx.get_info('max_q_seq_length') + max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length') + cumulative_length = attn_ctx.get_info('cumulative_length') + is_prefilling = attn_ctx.get_info('is_prefilling') + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) 
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + fill_kv_cache( + key_states.transpose(1, 2), + value_states.transpose(1, 2), + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + ) + + # ----------------- flash attention forward ------------------------# + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + causal = self.is_causal and q_len != 1 + + use_sliding_windows = False + + window_size = (self.config.sliding_window, + self.config.sliding_window) if use_sliding_windows else (-1, + -1) + # TODO support sliding window attention + from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + if is_prefilling: + + key_states = repeat_kv( + key_states, self.num_key_value_groups) + value_states = repeat_kv( + value_states, self.num_key_value_groups) + + attn_output = flash_attn_varlen_func( + query_states.transpose(1,2).squeeze(0), + key_states.transpose(1,2).squeeze(0), + value_states.transpose(1,2).squeeze(0), + cumulative_length, + cumulative_length, + max_q_seq_length, + max_kv_seq_length, + causal=True) + else: + # breakpoint() + query_states = query_states.transpose(1,2).transpose(0,1) + + attn_output = flash_attn_with_kvcache( + query_states, + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], + cache_seqlens=kv_seq_length, + block_table=block_offsets, + causal=True) + attn_output = attn_output.squeeze(1) + + # ---------------- flash attention forward end ------------------- # + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + +def qwen2_attn_flash_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +): + + lmdeploy_ctx = MessageHub.get_instance('paged_attention') + + if lmdeploy_is_available() and len(lmdeploy_ctx.runtime_info) > 0: + + return _qwen2_attn_contiguous_batching_forward(self, hidden_states,attention_mask, position_ids, + past_key_value, use_cache) + else: + return _qwen2_attn_varlen_forward(self, hidden_states, + attention_mask, position_ids, + past_key_value, + output_attentions, use_cache) + + + + +def qwen2_casual_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + label_shifted = False, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked 
language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if labels is None: + loss = None + logits = self.lm_head(hidden_states) + else: + if liger_kernel_is_available(): + # unable to return logits when using Liger Kernel + logits = None + + if label_shifted: + shift_hidden_states = hidden_states + shift_labels = labels + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_hidden_states.device) + + from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + loss_fct = LigerFusedLinearCrossEntropyLoss() + loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels, self.lm_head.bias) + + else: + logits = self.lm_head(hidden_states) + + if label_shifted: + shift_logits = logits + shift_labels = labels + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/xtuner/_lite/accelerate/generate.py b/xtuner/_lite/accelerate/generate.py new file mode 100644 index 000000000..d75dffeb8 --- /dev/null +++ b/xtuner/_lite/accelerate/generate.py @@ -0,0 +1,204 @@ +import torch +from mmengine import MessageHub + +from xtuner._lite import get_logger +from .packed import pack_sequence, packed_cumulative_length + +logger = get_logger() + + +@torch.no_grad() +def sample(logits,do_sample=True, top_k=0, top_p=0.9, temperature=1.0): + + if not do_sample: + return logits.argmax(-1) + + # Apply temperature if necessary + if temperature != 1.0: + logits = logits / temperature + + # Apply top-k if necessary + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + _, topk_indices = logits.topk(top_k,dim=-1) + mask = torch.ones_like(logits, dtype=torch.bool) + mask.scatter_(-1, topk_indices, False) + logits.masked_fill_(mask, -torch.inf) + + # Apply top-p (nucleus sampling) if necessary + if top_p < 1.0: + + sorted_logits, sorted_indices = torch.sort(logits, dim=-1) + cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + mask = (cum_probs <= (1- top_p)) + mask[:,-1] = False + sorted_logits.masked_fill_(mask, -torch.inf) + + logits.scatter_( -1, sorted_indices, sorted_logits) + + probs = logits.softmax(-1) + + return torch.multinomial(probs, 1).squeeze(-1) + + +@torch.no_grad() +def contiguous_batching_generate(model, + input_ids, + stop_token_ids=[], + max_batch_size=64, + max_new_tokens=128, + max_length=2048, + do_sample=False, + top_k=0, + top_p=1.0, + temperature=1.0, + tp_size=1, + device='cuda'): + + model.config.use_cache = True + + from lmdeploy.pytorch.config import CacheConfig, ModelConfig + from lmdeploy.pytorch.engine.cache_engine import CacheEngine + from lmdeploy.logger import get_logger + get_logger('lmdeploy').setLevel('ERROR') + + block_size = 256 + max_batch_size = min(max_batch_size, len(input_ids)) + num_blocks = max_length // block_size * max_batch_size + cache_config = CacheConfig(max_batch_size, block_size, num_blocks, + num_blocks) + + + if model.config.architectures[0] == 'InternLM3ForCausalLM': + model.config.num_hidden_layers = model.config.num_self_decoder_layers + 1 + + model_config = ModelConfig.from_hf_config(model.config) + cache_engine = CacheEngine(cache_config, model_config, world_size=tp_size) + + block_table = torch.arange(num_blocks).reshape(max_batch_size, -1) + + _packed_ids, _num_tokens = pack_sequence(input_ids[:max_batch_size]) + _position_ids = [ + torch.arange(seq.numel()) for seq in input_ids[:max_batch_size] + ] + _packed_pos_ids = torch.cat(_position_ids, dim=0).unsqueeze(0) + _cumulative_length = packed_cumulative_length(_num_tokens) + + next_input_ids = _packed_ids.to(device) + next_position_ids = _packed_pos_ids.to(device) + next_start_pos = _cumulative_length[:-1].to(device) + next_end_pos = (_cumulative_length[1:] - 1).to(device) + next_query_length = _num_tokens.to(device) + next_cache_length = _num_tokens.to(device) + next_block_table = block_table.to(device).to(torch.int32) + next_cumulative_length = _cumulative_length.to(device) + next_is_prefilling = True + + num_sessions = len(input_ids) + computing = [i for i in range(max_batch_size)] + waiting = [i for i in range(max_batch_size, num_sessions)] + + responses = [[] for _ in range(num_sessions)] + + while len(waiting) or len(computing): + + attn_ctx = MessageHub.get_instance('paged_attention') + 
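+        # Publish this step's paged-attention metadata (block table, query
+        # and KV lengths, start offsets, prefill flag) through the MessageHub
+        # so the dispatched attention forwards can read it during `model(...)`.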
attn_ctx.update_info('block_offsets', next_block_table) + attn_ctx.update_info('kv_seq_length', next_cache_length) + attn_ctx.update_info('q_seq_length', next_query_length) + attn_ctx.update_info('position_ids', next_position_ids) + attn_ctx.update_info('max_kv_seq_length', next_cache_length.max()) + attn_ctx.update_info('max_q_seq_length', next_query_length.max()) + attn_ctx.update_info('q_start_loc', next_start_pos) + attn_ctx.update_info('cumulative_length', next_cumulative_length) + attn_ctx.update_info('is_prefilling', next_is_prefilling) + + # logger.info('Begin Prefilling') + + outputs = model( + input_ids=next_input_ids, + position_ids=next_position_ids, + past_key_values=cache_engine.gpu_cache, + cache_position=next_position_ids, + ) + + for key in list(attn_ctx.runtime_info.keys()): + attn_ctx.pop_info(key) + + # TODO (pppppM) support sampling + sampled = sample(outputs.logits[0, next_end_pos],do_sample=do_sample, top_k=top_k, top_p=top_p, temperature=temperature) + + _next_input_ids = [] + _next_position_ids = [] + _next_computing = [] + _next_query_length = [] + _next_cache_length = [] + _next_block_table = [] + + for i, sess_id in enumerate(computing): + token_id = sampled[i] + responses[sess_id].append(token_id.item()) + + _sess_new_tokens = len(responses[sess_id]) + _sess_len = _sess_new_tokens + input_ids[sess_id].numel() + + stop = ( + _sess_new_tokens >= max_new_tokens or _sess_len >= max_length + or token_id in stop_token_ids) + + if stop: + # session ended + if len(waiting): + # next step is prefilling + new_sess_id = waiting.pop(0) + new_sess = input_ids[new_sess_id].to(device) + + _new_sess_len = new_sess.size(-1) + # new session override the cache of the stopped session + _next_block_table.append(next_block_table[i]) + _next_computing.append(new_sess_id) + _next_input_ids.append(new_sess) + _next_position_ids.append(torch.arange(_new_sess_len)) + _next_query_length.append(_new_sess_len) + _next_cache_length.append(_new_sess_len) + else: + # next step is decoding + _next_computing.append(sess_id) + _next_block_table.append(next_block_table[i]) + _next_input_ids.append(token_id.reshape(1, -1)) + _next_position_ids.append( + torch.arange(_sess_len - 1, _sess_len)) + _next_query_length.append(1) + _next_cache_length.append(_sess_len) + + computing = _next_computing + if len(computing) == 0: + # All sessions have ended. 
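+                # `computing` only empties when no waiting session was left
+                # to reuse the freed cache blocks, so `waiting` must already
+                # be empty as well.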
+ assert len(waiting) == 0 + break + + _packed_ids, _num_tokens = pack_sequence(_next_input_ids) + _cumulative_length = packed_cumulative_length(_num_tokens) + + next_input_ids = _packed_ids.to(device) + next_position_ids = torch.cat(_next_position_ids, dim=0).unsqueeze(0) + next_position_ids = next_position_ids.to(device) + next_start_pos = _cumulative_length[:-1].to(device) + next_end_pos = (_cumulative_length[1:] - 1).to(device) + next_query_length = torch.IntTensor(_next_query_length).to(device) + next_cache_length = torch.IntTensor(_next_cache_length).to(device) + next_block_table = torch.stack(_next_block_table).to(device) + + next_cumulative_length = _cumulative_length.to(device) + next_is_prefilling = False + + for i in range(len(cache_engine.gpu_cache)): + cache_engine.gpu_cache.pop() + + del cache_engine + torch.cuda.empty_cache() + + model.config.use_cache = False + + return responses diff --git a/xtuner/_lite/accelerate/load.py b/xtuner/_lite/accelerate/load.py new file mode 100644 index 000000000..bf8a35513 --- /dev/null +++ b/xtuner/_lite/accelerate/load.py @@ -0,0 +1,32 @@ +import torch + + +class LoadWoInit: + """Context manager that disable parameter initialization.""" + + def __init__(self): + self.constant_ = torch.nn.init.constant_ + self.zeros_ = torch.nn.init.zeros_ + self.ones_ = torch.nn.init.ones_ + self.uniform_ = torch.nn.init.uniform_ + self.normal_ = torch.nn.init.normal_ + self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_ + self.kaiming_normal_ = torch.nn.init.kaiming_normal_ + + def __enter__(self, *args, **kwargs): + torch.nn.init.constant_ = lambda *args, **kwargs: None + torch.nn.init.zeros_ = lambda *args, **kwargs: None + torch.nn.init.ones_ = lambda *args, **kwargs: None + torch.nn.init.uniform_ = lambda *args, **kwargs: None + torch.nn.init.normal_ = lambda *args, **kwargs: None + torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None + torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None + + def __exit__(self, *args, **kwargs): + torch.nn.init.constant_ = self.constant_ + torch.nn.init.zeros_ = self.zeros_ + torch.nn.init.ones_ = self.ones_ + torch.nn.init.uniform_ = self.uniform_ + torch.nn.init.normal_ = self.normal_ + torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_ + torch.nn.init.kaiming_normal_ = self.kaiming_normal_ diff --git a/xtuner/_lite/accelerate/lora.py b/xtuner/_lite/accelerate/lora.py new file mode 100644 index 000000000..ad3c9a3f5 --- /dev/null +++ b/xtuner/_lite/accelerate/lora.py @@ -0,0 +1,5 @@ +LORA_TARGET_MAP = { + 'InternLM2ForCausalLM': ['wqkv', 'wo', 'w1', 'w2', 'w3'], + 'CLIPVisionModel': + ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'] +} diff --git a/xtuner/_lite/accelerate/packed.py b/xtuner/_lite/accelerate/packed.py new file mode 100644 index 000000000..773e5822a --- /dev/null +++ b/xtuner/_lite/accelerate/packed.py @@ -0,0 +1,72 @@ +from contextlib import contextmanager +from typing import List, Union + +import torch + +from xtuner._lite import get_device +from xtuner._lite.parallel import get_sp_mesh, split_for_sequence_parallel + + +def unpack_sequence(packed: torch.Tensor, + num_tokens: Union[torch.Tensor, List], + dim=1): + + if isinstance(num_tokens, torch.Tensor): + num_tokens = num_tokens.tolist() + sequences = torch.split(packed, num_tokens, dim=dim) + return sequences + + +def pack_sequence(sequences, dim=1): + num_tokens = torch.IntTensor([seq.size(dim) for seq in sequences]) + packed = torch.cat(sequences, dim=dim) + return packed, num_tokens.to(packed.device) + + +def 
packed_cumulative_length(num_tokens: torch.Tensor): + + device = num_tokens.device + _zero_pad = torch.zeros(1, device=device) + _pad_length = torch.cat([_zero_pad, num_tokens]).int() + return torch.cumsum(_pad_length, 0).int() + + +@contextmanager +def packed_sequence(num_tokens, enable=True, sp_mesh=None): + from mmengine import MessageHub + ctx = MessageHub.get_instance('packed_sequence') + + device = get_device() + if enable: + num_tokens = num_tokens.to(device) + device = num_tokens.device + _zero_length = torch.zeros(1, device=device) + _pad_length = torch.cat([_zero_length, num_tokens]).int() + cumulative_lengths = torch.cumsum(_pad_length, 0).int() + position_ids = [torch.arange(num.item()) for num in num_tokens] + position_ids = torch.cat(position_ids, dim=0).to(device) + position_ids = position_ids.unsqueeze(0) + if sp_mesh: + # `dim` is 1 as the shape of tensor is (bs, seq_len) + position_ids = split_for_sequence_parallel( + position_ids, dim=1, sp_mesh=sp_mesh) + + # ctx.update_info('num_tokens', num_tokens) + ctx.update_info('position_ids', position_ids) + ctx.update_info('cumulative_lengths', cumulative_lengths) + ctx.update_info('max_seqlen', num_tokens.max()) + ctx.update_info('sp_mesh', sp_mesh) + + else: + # ctx.update_info('num_tokens', None) + ctx.update_info('position_ids', None) + ctx.update_info('cumulative_lengths', None) + ctx.update_info('max_seqlen', None) + ctx.update_info('sp_mesh', None) + yield + + # ctx.update_info('num_tokens', None) + ctx.update_info('position_ids', None) + ctx.update_info('cumulative_lengths', None) + ctx.update_info('max_seqlen', None) + ctx.update_info('sp_mesh', None) diff --git a/xtuner/_lite/accelerate/utils.py b/xtuner/_lite/accelerate/utils.py new file mode 100644 index 000000000..86629fef5 --- /dev/null +++ b/xtuner/_lite/accelerate/utils.py @@ -0,0 +1,56 @@ +import time +from contextlib import contextmanager +from transformers.utils.import_utils import is_flash_attn_2_available +from xtuner._lite import get_device, get_logger, get_torch_device_module + +logger = get_logger() + + + + +def npu_is_available(): + return get_device() == 'npu' + + +def varlen_attn_is_available(): + + return is_flash_attn_2_available() or npu_is_available() + + +def lmdeploy_is_available(): + + available = False + try: + import lmdeploy # noqa: F401 + available = True + except ImportError: + available = False + + return available + +def liger_kernel_is_available(): + + available = False + try: + import liger_kernel # noqa: F401 + available = True + except ImportError: + available = False + + return available + + +@contextmanager +def profile_time_and_memory(desc): + + torch_device = get_torch_device_module() + start_t = time.time() + torch_device.reset_peak_memory_stats() + + yield + + max_memory = torch_device.max_memory_allocated() + cost_time = time.time() - start_t + + logger.success(f'{desc} Elapsed time {cost_time:.2f} seconds, ' + f'peak gpu memory {max_memory/1024**3:.1f}G') diff --git a/xtuner/_lite/algorithms/__init__.py b/xtuner/_lite/algorithms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/_lite/algorithms/ppo/__init__.py b/xtuner/_lite/algorithms/ppo/__init__.py new file mode 100644 index 000000000..2a76959bf --- /dev/null +++ b/xtuner/_lite/algorithms/ppo/__init__.py @@ -0,0 +1,10 @@ +from .dataset import RewardBufferCollator, InferDataset, PPOTokenizeFunction, RewardBuffer +from .loss import (CriticLoss, PPOPolicyLoss, compute_advantages_and_returns, + compute_kl_rewards, gather_logprobs) +from .model 
import build_actor_model, build_reward_model + +__all__ = [ + 'PPOCollator', 'PPODataset', 'PPOTokenizeFunction', 'CriticLoss', + 'PPOPolicyLoss', 'compute_advantages_and_returns', 'compute_rewards', + 'gather_logprobs', 'build_actor_model', 'build_reward_model' +] diff --git a/xtuner/_lite/algorithms/ppo/dataset.py b/xtuner/_lite/algorithms/ppo/dataset.py new file mode 100644 index 000000000..4a40d75b4 --- /dev/null +++ b/xtuner/_lite/algorithms/ppo/dataset.py @@ -0,0 +1,170 @@ +import torch +import json +import numpy as np +from xtuner._lite.chat.messages.chat import ChatMsg +from xtuner._lite.datasets import OPENAI_CONVERT_MAP +from torch import nn +from ..sft import SftCollator, SftTokenizeFunction + + +class InferDataset(torch.utils.data.Dataset): + + def __init__(self, prompts, responses): + super().__init__() + + assert len(prompts) == len(responses) + self.prompts = prompts + self.responses = responses + self.policies = None + + def __len__(self): + return len(self.prompts) + + + def __getitem__(self, item): + + prompt = self.prompts[item] + response = self.responses[item] + num_prefill_tokens = len(prompt) + + input_ids = prompt + response + labels = [-100] * (num_prefill_tokens - 1) + response + [-100] + + return { + 'input_ids': input_ids, + 'labels': labels, + 'num_tokens': len(input_ids) + } + + + + +FASTER = False +class RewardBuffer(torch.utils.data.Dataset): + + def __init__(self, clip_min=-5,clip_max = 5, normalize=True, faster=False): + super().__init__() + + + self.clip_min = clip_min + self.clip_max = clip_max + + self.normalize = normalize + + if self.normalize: + self.bn = nn.BatchNorm1d(1, momentum=None, affine=False) + else: + self.bn = None + + self._num_action_tokens = 0 + self._num_total_tokens = 0 + self._trajectories = [] + + self._current_mean = 0 + + @property + def running_mean(self): + return self.bn.running_mean.item() + + @property + def current_mean(self): + return self._current_mean + + @property + def num_action_tokens(self): + return self._num_action_tokens.item() + + @property + def num_total_tokens(self): + return self._num_total_tokens + + def update(self, trajectories): + + rewards = [data['reward'] for data in trajectories] + + for i in range(len(trajectories)): + trajectories[i]['ori_reward'] = trajectories[i]['reward'] + + rewards = torch.tensor(rewards) + + self._current_mean = rewards.mean().item() + + rewards = rewards.clip(self.clip_min, self.clip_max) + + if self.normalize: + self.bn.train() + _ = self.bn(rewards.unsqueeze(-1)) + self.bn.eval() + rewards = self.bn(rewards.unsqueeze(-1)) + + for i in range(len(trajectories)): + trajectories[i]['reward'] = rewards[i].item() + + num_total_tokens = 0 + num_action_tokens = 0 + for data in trajectories: + labels = np.array(data['labels']) + num_total_tokens += labels.size + num_action_tokens += (labels >= 0).sum() + + self._num_action_tokens = num_action_tokens + self._num_total_tokens = num_total_tokens + + self._trajectories = trajectories + + def dump_jsonl(self, path, tokenizer, debug=False): + + with open(path, 'w', encoding='utf8') as f: + for data in self._trajectories: + json_line = { + 'num_tokens': data['num_tokens'], + 'reward': data['ori_reward'], + 'sequence': tokenizer.decode(data['input_ids']), + } + + if debug: + json_line['input_ids'] = data['input_ids'] + json_line['labels'] = data['labels'] + + json_str = json.dumps(json_line, ensure_ascii=False) + f.write(json_str + '\n') + + def __len__(self): + return len(self._trajectories) + + + def __getitem__(self, item): + + return 
self._trajectories[item] + + +class PPOTokenizeFunction(SftTokenizeFunction): + + def __init__(self, + tokenizer, + chat_template, + raw_format='openai', + sys_prompt=None): + super().__init__(tokenizer, chat_template, raw_format) + self.sys_prompt = sys_prompt + + def __call__(self, item): + + formatter = OPENAI_CONVERT_MAP[self.raw_format] + msg = formatter(item) + if self.sys_prompt is not None: + sys_msg = ChatMsg(role='system', content=self.sys_prompt) + msg.messages = [sys_msg] + msg.messages + tokenized = msg.tokenize(self.tokenizer, self.chat_template) + + return tokenized + + +class RewardBufferCollator(SftCollator): + + def __call__(self, instances): + + data = super().__call__(instances) + data['rewards'] = [item['reward'] for item in instances] + + return data diff --git a/xtuner/_lite/algorithms/ppo/loss.py b/xtuner/_lite/algorithms/ppo/loss.py new file mode 100644 index 000000000..167379719 --- /dev/null +++ b/xtuner/_lite/algorithms/ppo/loss.py @@ -0,0 +1,125 @@ +import torch +from torch.nn import functional as F + +from xtuner._lite import get_logger + +logger = get_logger() + + + +def gather_logprobs(logits, labels): + log_probs = F.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return log_probs_labels.squeeze(-1) + +@torch.no_grad() +def compute_kl_rewards(logprobs, ref_logprobs, reward_score, kl_coef=0.01): + + assert logprobs.ndim == 1 + last_mask = torch.zeros_like(logprobs, dtype=torch.int) + last_mask[-1] = 1 + + kl = (ref_logprobs - logprobs) + kl_reward = kl_coef * kl * (1 - last_mask) + + last_reward = reward_score * last_mask + + rewards = kl_reward + last_reward + + return rewards + +@torch.no_grad() +def compute_advantages_and_returns(values, rewards, gamma=1.0, gae_lambda=0.99): + # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + """Function that computes advantages and returns from rewards and + values. Calculated as in the original PPO paper: + https://arxiv.org/abs/1707.06347 Note that rewards may include a KL + divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + """ + lastgaelam = 0 + advantages_reversed = [] + + assert values.numel() == rewards.numel(), f'{values.numel()}, {rewards.numel()}' + length = rewards.numel() + + for t in reversed(range(0, length)): + nextvalues = values[t + 1] if t < length - 1 else 0.0 + # Since old_rewards and old_values are masked with action_mask, + # i.e. 
they have 0's at pad tokens, + # delta will be 0 if current t is at a pad token, + # so will lastgaelam + delta = rewards[t] + gamma * nextvalues - values[t] + lastgaelam = delta + gamma * gae_lambda * lastgaelam + advantages_reversed.append(lastgaelam) + + advantages = torch.stack(advantages_reversed[::-1], dim=0) + returns = advantages + values + return advantages.detach(), returns + + +class CriticLoss(torch.nn.Module): + """Loss function for critic model.""" + + def __init__(self, + cliprange_value: float = 0.5, + loss_type: str = 'per_seq'): + super().__init__() + self.cliprange_value = cliprange_value + self.loss_type = loss_type + + assert self.loss_type in ['per_token', 'per_seq'] + + def critic_loss_fn(self, values, old_values, returns, loss_factor=None): + values_clipped = old_values + (values - old_values).clamp( + -self.cliprange_value, self.cliprange_value) + vf_loss1 = (values_clipped - returns)**2 + vf_loss2 = (values - returns)**2 + if self.loss_type == 'per_seq': + vf_loss = torch.max(vf_loss1, vf_loss2).mean(-1) + elif self.loss_type == 'per_token': + assert loss_factor is not None + vf_loss = torch.sum(torch.max(vf_loss1, vf_loss2) * loss_factor) + return 0.5 * vf_loss + + def forward(self, + values: torch.Tensor, + old_values, + returns, + loss_factor=None): + + loss = self.critic_loss_fn( + values=values, + old_values=old_values, + returns=returns, + loss_factor=loss_factor) + return loss + + +class PPOPolicyLoss(torch.nn.Module): + """Loss function for policy model.""" + + def __init__(self, cliprange: float = 0.2, loss_type: str = 'per_seq'): + super().__init__() + self.cliprange = cliprange + self.loss_type = loss_type + assert self.loss_type in ['per_token', 'per_seq'] + + def forward(self, logprobs, old_logprobs, advantages, loss_factor=None): + ratio = (logprobs - old_logprobs).exp() + pg_loss1 = -ratio * advantages + pg_loss2 = -ratio.clamp(1 - self.cliprange, + 1 + self.cliprange) * advantages + if self.loss_type == 'per_seq': + pg_loss = torch.max(pg_loss1, pg_loss2).mean(dim=-1) + elif self.loss_type == 'per_token': + assert loss_factor is not None + pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2)) * loss_factor + return pg_loss diff --git a/xtuner/_lite/algorithms/ppo/model.py b/xtuner/_lite/algorithms/ppo/model.py new file mode 100644 index 000000000..2f90e810f --- /dev/null +++ b/xtuner/_lite/algorithms/ppo/model.py @@ -0,0 +1,46 @@ +import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.utils.import_utils import (is_flash_attn_2_available, + is_torch_sdpa_available) + +from xtuner._lite.accelerate import LoadWoInit + + +def build_actor_model(model_path, dtype=torch.float32, trust_remote_code=True): + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if is_flash_attn_2_available(): + config.attn_implementation = 'flash_attention_2' + elif is_torch_sdpa_available(): + config.attn_implementation = 'sdpa' + + with LoadWoInit(): + policy = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation='flash_attention_2', + torch_dtype=dtype, + trust_remote_code=trust_remote_code) + + return policy + + +def build_reward_model(model_path, dtype=torch.float32, trust_remote_code=True): + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if is_flash_attn_2_available(): + config.attn_implementation = 'flash_attention_2' + elif is_torch_sdpa_available(): + config.attn_implementation = 'sdpa' + + config.use_cache = False + config.torch_dtype = dtype + with 
LoadWoInit(): + reward = AutoModel.from_pretrained( + model_path, + attn_implementation='flash_attention_2', + torch_dtype=dtype, + trust_remote_code=trust_remote_code) + + reward.model.use_cache = False + + return reward diff --git a/xtuner/_lite/algorithms/sft/__init__.py b/xtuner/_lite/algorithms/sft/__init__.py new file mode 100644 index 000000000..01a3a63a2 --- /dev/null +++ b/xtuner/_lite/algorithms/sft/__init__.py @@ -0,0 +1,3 @@ +from .dataset import SftCollator, SftTokenizeFunction + +__all__ = ['SftCollator', 'SftTokenizeFunction'] diff --git a/xtuner/_lite/algorithms/sft/dataset.py b/xtuner/_lite/algorithms/sft/dataset.py new file mode 100644 index 000000000..bbb9e9608 --- /dev/null +++ b/xtuner/_lite/algorithms/sft/dataset.py @@ -0,0 +1,109 @@ +import torch +from torch.nn.utils.rnn import pad_sequence + +from xtuner._lite import get_logger +from xtuner._lite.datasets import OPENAI_CONVERT_MAP + +logger = get_logger() + + +class SftTokenizeFunction(): + + def __init__(self, tokenizer, chat_template, raw_format='openai'): + + self.tokenizer = tokenizer + self.chat_template = chat_template + self.raw_format = raw_format + + def __call__(self, item): + + formatter = OPENAI_CONVERT_MAP[self.raw_format] + msg = formatter(item) + tokenized = msg.tokenize(self.tokenizer, self.chat_template) + return tokenized + + +class SftCollator(): + + def __init__(self, pad_token_id=0, ignore_id=-100, pack_batch=False, max_length=None): + self.pack_batch = pack_batch + self.pad_token_id = pad_token_id + self.ignore_id = ignore_id + self.max_length = max_length + + def __call__(self, instances): + + _instances = [] + for ins in instances: + if isinstance(ins, list): + _instances.extend(ins) + else: + _instances.append(ins) + + instances = _instances + + input_ids = [] + labels = [] + num_tokens = [] + + for data in instances: + + _input_ids = data['input_ids'] + _labels = data['labels'] + _num_tokens = data['num_tokens'] + + # TODO remove list + if isinstance(_num_tokens, list): + assert len(_num_tokens) == 1 + _num_tokens = _num_tokens[0] + + assert isinstance(_num_tokens, int) + + if self.max_length: + _input_ids = _input_ids[:self.max_length] + _labels = _labels[:self.max_length] + _num_tokens = min(_num_tokens, self.max_length) + + input_ids.append(torch.LongTensor(_input_ids)) + labels.append(torch.LongTensor(_labels)) + num_tokens.append(_num_tokens) + + attention_mask = [torch.ones_like(ids) for ids in input_ids] + num_tokens = torch.IntTensor(num_tokens) + + if len(instances) > 1 and self.pack_batch: + + input_ids = torch.cat(input_ids, dim=0).unsqueeze(0) + labels = torch.cat(labels, dim=0).unsqueeze(0) + attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0) + + elif len(instances) > 1 and not self.pack_batch: + + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=self.pad_token_id) + labels = pad_sequence( + labels, batch_first=True, padding_value=self.ignore_id) + attention_mask = pad_sequence( + attention_mask, batch_first=True, padding_value=0) + else: + input_ids = torch.stack(input_ids) + labels = torch.stack(labels) + attention_mask = torch.stack(attention_mask) + + if input_ids.shape != labels.shape: + logger.error(f'[instances] {instances}') + logger.error(f'[num_tokens] {num_tokens}') + logger.error(f'[input_ids] {input_ids}') + logger.error(f'[labels] {labels}') + raise RuntimeError('The shape of input_ids and labels must be ' + f'equal, but found {input_ids.shape} and ' + f'{labels.shape}.') + + data_dict = { + 'input_ids': input_ids, + 'labels': 
labels, + 'num_tokens': num_tokens, + 'attention_mask': attention_mask.bool() + } + + return data_dict diff --git a/xtuner/_lite/auto.py b/xtuner/_lite/auto.py new file mode 100644 index 000000000..d6e931836 --- /dev/null +++ b/xtuner/_lite/auto.py @@ -0,0 +1,185 @@ +import math +import os +from typing import Literal, Optional + +import torch +from transformers import BitsAndBytesConfig, PretrainedConfig + +from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 + +if os.environ.get('XTUNER_USE_MODELSCOPE'): + from modelscope import AutoTokenizer # noqa: F401 + from modelscope import AutoConfig + from modelscope import AutoModelForCausalLM as OriAutoModelForCausalLM +else: + from transformers import AutoTokenizer # noqa: F401 + from transformers import AutoConfig + from transformers import AutoModelForCausalLM as OriAutoModelForCausalLM + + +def download_model_from_hub( + model_name_or_path: str, + from_hub: Literal['huggingface', 'modelscope'] = 'huggingface', + cache_dir: Optional[str] = None, +) -> str: + """Automatically download model from the HUB. + + Note: + If `model_name_or_path` is a local path, it will return the path + directly without downloading it again. + + Args: + model_name_or_path (str): The model name, model path or repo id. + config (str | None): The config path. Default is None. + from_hub (str): The model hosting hub, modelscope, or huggingface. + Default is huggingface. + cache_dir (str | None): + The save path when downloading the model. If it is None, it + will be stored in the default location of the HUB. For + Huggingface, it's ~/.cache/huggingface/hub, for ModelScope, + it's ~/.cache/modelscope/hub. + + Returns: + str: The local path of the model. + """ + if os.path.isdir(model_name_or_path): + model_path = model_name_or_path + elif from_hub == 'huggingface': + from huggingface_hub import snapshot_download + model_path = snapshot_download( + repo_id=model_name_or_path, cache_dir=cache_dir) + elif from_hub == 'modelscope': + from modelscope import snapshot_download + model_path = snapshot_download( + model_id=model_name_or_path, cache_dir=cache_dir) + else: + # TODO support openxlab + raise NotImplementedError('The model does not support downloading ' + f'from {from_hub}, it only supports ' + '`huggingface` and `modelscope`.') + + return model_path + + +class AutoModelForCausalLM: + """Enhanced version of Huggingface's `AutoModelForCausalLM`. + + Compared to HuggingFace's `AutoModelForCausalLM`, the following three + features have been added: + + 1. Load the model from either HuggingFace or ModelScope based on the + environment variable `XTUNER_USE_MODELSCOPE` (bool). + 2. Automatically enables Flash Attention. If `flash-attn` is already + installed, Flash Attention 2 will be used. If there is no + `flash-attn`, use Flash Attention 1 when torch version is less than + 2.2. When torch version is greater than or equal to 2.2, use Flash + Attention 2. + 3. When the length of the target sequence during training exceeds the + maximum length of the original model, the rope scaling is + automatically set to the `linear` type with a factor of 1." + + Note: + If the model is built through `from_config`, it will not automatically + enable flash attention or modify rope scaling. + + Note: + If you want to load the model on ModelScope, please set the + environment variable `XTUNER_USE_MODELSCOPE=1`. 
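    Example:
        An illustrative call (the repo id and target length below are only
        placeholders; the attention and rope-scaling kwargs are filled in
        automatically as described above):

            from xtuner._lite.auto import AutoModelForCausalLM

            # picks flash-attn/sdpa automatically; if 32k exceeds the
            # checkpoint's native context window, linear rope scaling is set
            model = AutoModelForCausalLM.from_pretrained(
                'internlm/internlm2-chat-7b',
                max_position_embeddings=32 * 1024)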
+ """ + + @classmethod + def from_config(cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True, + **kwargs): + """Consistent with the usage of HuggingFace's AutoModelForCausalLM.""" + return AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True, + quantization_config: Optional[BitsAndBytesConfig] = None, + max_position_embeddings: Optional[int] = None, + **kwargs): + """Consistent with the usage of HuggingFace's AutoModelForCausalLM.""" + config = cls.from_config( + pretrained_model_name_or_path, trust_remote_code=True) + + attn_kwargs = cls._flash_attn_kwargs(config) + kwargs.update(attn_kwargs) + + if max_position_embeddings: + long_ctx_kwargs = cls._long_ctx_kwargs(config, + max_position_embeddings) + kwargs.update(long_ctx_kwargs) + + if 'torch_dtype' not in kwargs: + if torch.cuda.is_bf16_supported(): + kwargs.update(torch_dtype=torch.bfloat16) + else: + kwargs.update(torch_dtype=torch.float16) + + model = OriAutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + quantization_config=quantization_config, + **kwargs) + + from xtuner._lite.accelerate import dispatch_modules + dispatch_modules(model, use_varlen_attn=True) + + return model + + @staticmethod + def _flash_attn_kwargs(config: PretrainedConfig) -> dict: + """Arguments Required to Enable Flash Attention.""" + cls_name = type(config).__name__ + _built_in_flash_attn_1 = ('LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', + 'Qwen2Config', 'Starcoder2Config', + 'Starcoder2Config') + + _built_in_flash_attn_2 = ('InternLMConfig', 'InternLM2Config', + 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', + 'Qwen2Config', 'Starcoder2Config', + 'Starcoder2Config') + + attn_kwargs = {} + if SUPPORT_FLASH2 and cls_name in _built_in_flash_attn_2: + attn_kwargs.update(attn_implementation='flash_attention_2') + elif SUPPORT_FLASH1 and cls_name in _built_in_flash_attn_1: + attn_kwargs.update(attn_implementation='sdpa') + + return attn_kwargs + + @staticmethod + def _long_ctx_kwargs(config: PretrainedConfig, + max_position_embeddings: int) -> dict: + """Arguments Required for Long Context Training.""" + ori_rope_scaling = getattr(config, 'rope_scaling', None) + if ori_rope_scaling is None: + ori_rope_scaling = {'factor': 1} + + if 'factor' in ori_rope_scaling.keys(): + ori_rope_scaling_factor = ori_rope_scaling['factor'] + else: + ori_rope_scaling_factor = 1 + + ori_ctx_len = getattr(config, 'max_position_embeddings', None) + + long_text_kwargs = {} + if ori_ctx_len: + ori_ctx_len *= ori_rope_scaling_factor + if max_position_embeddings > ori_ctx_len: + scaling_factor = float( + math.ceil(max_position_embeddings / ori_ctx_len)) + + new_rope_scaling = {'type': 'linear', 'factor': scaling_factor} + long_text_kwargs.update(dict(rope_scaling=new_rope_scaling)) + return long_text_kwargs diff --git a/xtuner/_lite/chat/__init__.py b/xtuner/_lite/chat/__init__.py new file mode 100644 index 000000000..6443e50b4 --- /dev/null +++ b/xtuner/_lite/chat/__init__.py @@ -0,0 +1,6 @@ +from .messages import ChatMessages +from .templates import CHAT_TEMPLATE_MAP, ChatTemplate, HybridChatTemplate + +__all__ = [ + 'ChatMessages', 'CHAT_TEMPLATE_MAP', 'ChatTemplate', 'HybridChatTemplate' +] diff --git a/xtuner/_lite/chat/backends/__init__.py b/xtuner/_lite/chat/backends/__init__.py new 
file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/_lite/chat/messages/__init__.py b/xtuner/_lite/chat/messages/__init__.py new file mode 100644 index 000000000..b8c75b45d --- /dev/null +++ b/xtuner/_lite/chat/messages/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseMessages +from .chat import ChatMessages + +__all__ = ['BaseMessages', 'ChatMessages'] diff --git a/xtuner/_lite/chat/messages/base.py b/xtuner/_lite/chat/messages/base.py new file mode 100644 index 000000000..0266986ac --- /dev/null +++ b/xtuner/_lite/chat/messages/base.py @@ -0,0 +1,31 @@ +from abc import abstractclassmethod, abstractmethod +from typing import Dict + +from pydantic import BaseModel +from transformers import PreTrainedTokenizer + +from ..templates import ChatTemplate + + +class BaseMessages(BaseModel): + + @abstractmethod + def add(self, role: str, content): + pass + + @abstractmethod + def pop(self): + pass + + @abstractmethod + def get_prompt(self, chat_template: ChatTemplate) -> str: + pass + + @abstractmethod + def tokenize(self, tokenizer: PreTrainedTokenizer, + chat_template: ChatTemplate) -> Dict: + pass + + @abstractclassmethod + def from_dict(cls, item: Dict) -> 'BaseMessages': + pass diff --git a/xtuner/_lite/chat/messages/chat.py b/xtuner/_lite/chat/messages/chat.py new file mode 100644 index 000000000..af1756e8e --- /dev/null +++ b/xtuner/_lite/chat/messages/chat.py @@ -0,0 +1,213 @@ +import copy +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel +from transformers import PreTrainedTokenizer + +from xtuner._lite import get_logger +from xtuner.utils import IGNORE_INDEX +from ..templates import ChatTemplate, HybridChatTemplate +from .base import BaseMessages + +logger = get_logger() + + +class TextContentItem(BaseModel): + type: Literal['text'] = 'text' + text: str + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + return self.text + + +class ImageContentItem(BaseModel): + type: Literal['image_url'] = 'image_url' + image_url: str + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + return chat_template.image_token + + +MultModalContentType = Union[TextContentItem, ImageContentItem] +ContentType = Union[str, List[MultModalContentType]] + + +class ChatMsg(BaseModel): + + role: Literal['assistant', 'user', 'system'] + content: ContentType + loss: Optional[bool] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.loss is None: + if self.role == 'system': + self.loss = False + elif self.role == 'user': + self.loss = False + elif self.role == 'assistant': + self.loss = True + else: + raise NotImplementedError + + def collect_img_urls(self) -> List[str]: + img_urls = [] + if isinstance(self.content, list): + for item in self.content: + if isinstance(item, ImageContentItem): + img_urls.append(item.image_url) + return img_urls + + def get_prompt(self, chat_template: ChatTemplate) -> str: + + if isinstance(self.content, str): + text = self.content + elif isinstance(self.content, list): + text = '' + for i, item in enumerate(self.content): + if i == 0: + text += item.apply_chat_template(chat_template) + else: + text += '\n' + item.apply_chat_template(chat_template) + else: + raise NotImplementedError + + if self.role == 'system': + prompt = chat_template.decorate_system(text) + elif self.role == 'user': + prompt = chat_template.decorate_user(text) + elif self.role == 'assistant': + prompt = chat_template.decorate_assistant(text) + else: + raise 
NotImplementedError + + return prompt + + def tokenize( + self, + tokenizer: PreTrainedTokenizer, + chat_template: ChatTemplate, + ): + + decorated = self.get_prompt(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return { + 'input_ids': token_ids, + 'labels': label_ids, + } + + +class ChatMessages(BaseMessages): + + messages: List[ChatMsg] + + def add(self, role, content, loss=False): + self.messages.append(ChatMsg(role=role, content=content, loss=loss)) + + def pop(self): + return self.messages.pop() + + def get_prompt(self, chat_template: ChatTemplate) -> str: + + prompt = '' + + for msg in self.messages: + prompt += msg.get_prompt(chat_template) + if msg.role == 'assistant': + prompt += chat_template.sep + return prompt + + def tokenize(self, tokenizer: PreTrainedTokenizer, + chat_template: ChatTemplate) -> Dict: + + input_ids = tokenizer.encode('', add_special_tokens=True) + labels = [IGNORE_INDEX for _ in input_ids] + image_urls = [] + + + + for msg in self.messages: + res = msg.tokenize(tokenizer, chat_template) + token_ids, label_ids = res['input_ids'], res['labels'] + + input_ids.extend(token_ids) + labels.extend(label_ids) + + image_urls.extend(msg.collect_img_urls()) + + if msg.role == 'assistant': + sep = chat_template.sep + sep_tokens = tokenizer.encode(sep, add_special_tokens=False) + input_ids.extend(sep_tokens) + labels.extend([IGNORE_INDEX] * len(sep_tokens)) + + if len(input_ids) != len(labels): + logger.error(f'[messages] {self.messages}') + logger.error(f'[input_ids] {input_ids}') + logger.error(f'[labels] {labels}') + raise RuntimeError('The lengths of input_ids and labels must be ' + f'equal, but found {len(input_ids)} and ' + f'{len(labels)}.') + + training_data = { + 'input_ids': input_ids, + 'labels': labels, + 'num_tokens': len(input_ids), + } + + if len(image_urls) > 0: + training_data['image_urls'] = image_urls + + return training_data + + @classmethod + def from_str(cls, prompt: str) -> 'ChatMessages': + + msg = ChatMsg(role='user', content=prompt) + return cls(messages=[msg]) + + @classmethod + def from_dict(cls, item: dict) -> 'ChatMessages': + ''' + item + { + 'messages':[ + {'role':'user', 'content':'hello'}, + {'role':'assistant', 'content':'hello!'}, + ], + } + ''' + return cls(**item) + + +if __name__ == '__main__': + + data = { + 'messages': [ + { + 'role': 'user', + 'content': 'hello' + }, + { + 'role': 'assistant', + 'content': 'hello!' 
+ }, + ] + } + + messages = ChatMessages.from_dict(data) + chat_template = ChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + ) + + print(messages.get_prompt(chat_template)) diff --git a/xtuner/_lite/chat/templates/__init__.py b/xtuner/_lite/chat/templates/__init__.py new file mode 100644 index 000000000..7ed468e20 --- /dev/null +++ b/xtuner/_lite/chat/templates/__init__.py @@ -0,0 +1,29 @@ +from .chat import ChatTemplate +from .hybrid import HybridChatTemplate + +CHAT_TEMPLATE_MAP = { + 'internlm2': + HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>', + stop_words=['<|im_end|>']), + 'qwen2': + HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>', + stop_words=['<|im_end|>', '<|endoftext|>']), + 'llama3': + HybridChatTemplate( + system=('<|start_header_id|>system<|end_header_id|>\n\n{system}' + '<|eot_id|>'), + user=('<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>' + '<|start_header_id|>assistant<|end_header_id|>\n\n'), + assistant='{assistant}<|eot_id|>', + sep='', + stop_words=['<|eot_id|>']), + +} + +__all__ = ['ChatTemplate', 'HybridChatTemplate'] diff --git a/xtuner/_lite/chat/templates/chat.py b/xtuner/_lite/chat/templates/chat.py new file mode 100644 index 000000000..9ce574fef --- /dev/null +++ b/xtuner/_lite/chat/templates/chat.py @@ -0,0 +1,59 @@ +from typing import List + +from pydantic import BaseModel, field_validator + + +class ChatTemplate(BaseModel): + """Define a Pydantic data model for a hybrid chat with attributes for + system, user and assistant chat as well as function and interpreter calls + and results.""" + + # Normal Chat + system: str # System message format + user: str # User message format + assistant: str # Assistant message format + stop_words: List[str] # List of stop words + sep: str = '\n' + + def decorate_system(self, text: str) -> str: + """Decorate text with the `system` template.""" + return self.system.format(system=text) + + def decorate_assistant(self, text: str) -> str: + """Decorate text with the `assistant` template.""" + return self.assistant.format(assistant=text) + + def decorate_user(self, text: str) -> str: + """Decorate text with the `user` template.""" + return self.user.format(user=text) + + @field_validator('system') + def check_system(cls, v: str) -> str: + """Validate that `system` contains '{system}'. + + If not, raises a ValueError. + """ + if v is not None and '{system}' not in v: + raise ValueError("system must contain the keyword '{system}'") + return v + + @field_validator('user') + def check_user(cls, v: str) -> str: + """Validate that `user` contains '{user}'. + + If not, raises a ValueError. + """ + if v is not None and '{user}' not in v: + raise ValueError("user must contain the keyword '{user}'") + return v + + @field_validator('assistant') + def check_assistant(cls, v: str) -> str: + """Validate that `assistant` contains '{assistant}'. + + If not, raises a ValueError. 
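        For example, the following template is rejected because the assistant
        format string uses a different placeholder (illustrative snippet; with
        pydantic v2 the ValueError surfaces wrapped in a ValidationError):

            from xtuner._lite.chat import ChatTemplate

            ChatTemplate(
                system='<|im_start|>system\n{system}<|im_end|>\n',
                user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
                assistant='{reply}<|im_end|>\n',  # no '{assistant}' -> error
                stop_words=['<|im_end|>'])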
+ """ + if v is not None and '{assistant}' not in v: + raise ValueError( + "assistant must contain the keyword '{assistant}'") + return v diff --git a/xtuner/_lite/chat/templates/hybrid.py b/xtuner/_lite/chat/templates/hybrid.py new file mode 100644 index 000000000..d9f563c0d --- /dev/null +++ b/xtuner/_lite/chat/templates/hybrid.py @@ -0,0 +1,196 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, field_validator + + +class HybridChatTemplate(BaseModel): + """Define a Pydantic data model for a hybrid chat with attributes for + system, user and assistant chat as well as function and interpreter calls + and results.""" + + # Normal Chat + system: str # System message format + user: str # User message format + assistant: str # Assistant message format + stop_words: List[str] # List of stop words + sep: str = '\n' + + # Multimodal Chat + # Predefined token and index for images + image_token: str = '' + image_token_index: int = -100 + + # Agent Chat + + # Interpreter and function related strings + files: Optional[str] = None + + functions: Optional[str] = None # Function description format + function_call: Optional[str] = None # Function call format + function_result: Optional[str] = None # Function result format + + code_interpreter: Optional[str] = None + code_interpreter_call: Optional[str] = None # Interpreter call format + code_interpreter_result: Optional[str] = None # Interpreter result format + + function_token: Optional[str] = None + code_interpreter_token: Optional[str] = None + action_start_token: Optional[str] = None + action_end_token: Optional[str] = None + + @property + def mm_token_maps(self) -> Dict[str, int]: + """Return a dictionary that maps multimodal tokens to corresponding + token indexes.""" + return {self.image_token: self.image_token_index} + + def decorate_system(self, text: str) -> str: + """Decorate text with the `system` template.""" + return self.system.format(system=text) + + def decorate_assistant(self, text: str) -> str: + """Decorate text with the `assistant` template.""" + return self.assistant.format(assistant=text) + + def decorate_user(self, text: str) -> str: + """Decorate text with the `user` template.""" + return self.user.format(user=text) + + def decorate_files(self, text: str) -> str: + """Decorate text with the `functions` template.""" + return self.files.format(files=text) + + def decorate_functions(self, text: str) -> str: + """Decorate text with the `functions` template.""" + return self.functions.format(functions=text) + + def decorate_function_call(self, text: str, func: str) -> str: + """Decorate text with the `function_call` template.""" + return self.function_call.format(assistant=text, function_call=func) + + def decorate_function_result(self, text: str) -> str: + """Decorate text with the `function_result` template.""" + return self.function_result.format(function_result=text) + + def decorate_code_interpreter(self, text: str) -> str: + """Decorate text with the `code_interpreter` template.""" + return self.code_interpreter.format(code_interpreter=text) + + def decorate_code_interpreter_call(self, text: str, func: str) -> str: + """Decorate text with the `code_interpreter_call` template.""" + return self.code_interpreter_call.format( + assistant=text, code_interpreter_call=func) + + def decorate_code_interpreter_result(self, text: str) -> str: + """Decorate text with the `code_interpreter_result` template.""" + return self.code_interpreter_result.format( + code_interpreter_result=text) + + 
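# Taken together, the decorate_* hooks above let a single template render
# plain chat turns, tool calls and tool results. A minimal illustration of one
# tool-call round trip follows; the template strings are made-up placeholders,
# not the entries shipped in CHAT_TEMPLATE_MAP:

from xtuner._lite.chat import HybridChatTemplate

tmpl = HybridChatTemplate(
    system='<|im_start|>system\n{system}<|im_end|>\n',
    user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
    assistant='{assistant}<|im_end|>\n',
    stop_words=['<|im_end|>'],
    function_call='{assistant}<|action_start|>{function_call}<|action_end|>\n',
    function_result='<|im_start|>environment\n{function_result}<|im_end|>\n')

prompt = tmpl.decorate_user('What is 2 + 2?')                   # user turn
prompt += tmpl.decorate_function_call(                          # tool call
    'Let me call the calculator.',
    '{"name": "calc", "arguments": {"expr": "2 + 2"}}')
prompt += tmpl.decorate_function_result('4')                    # tool result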
@field_validator('system') + def check_system(cls, v: str) -> str: + """Validate that `system` contains '{system}'. + + If not, raises a ValueError. + """ + if v is not None and '{system}' not in v: + raise ValueError("system must contain the keyword '{system}'") + return v + + @field_validator('user') + def check_user(cls, v: str) -> str: + """Validate that `user` contains '{user}'. + + If not, raises a ValueError. + """ + if v is not None and '{user}' not in v: + raise ValueError("user must contain the keyword '{user}'") + return v + + @field_validator('assistant') + def check_assistant(cls, v: str) -> str: + """Validate that `assistant` contains '{assistant}'. + + If not, raises a ValueError. + """ + if v is not None and '{assistant}' not in v: + raise ValueError( + "assistant must contain the keyword '{assistant}'") + return v + + @field_validator('function_call') + def check_function_call(cls, v: str) -> str: + """Validate that `function_call` contains '{function_call}'. + + If not, raises a ValueError. + """ + if (v is not None and '{function_call}' not in v + and '{assistant}' not in v): + raise ValueError( + "function_call must contain the keywords '{function_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError( + "function_call must contain the keyword '{assistant}' and " + "'{function_call}'") + return v + + @field_validator('function_result') + def check_function_result(cls, v: str) -> str: + """Validate that `function_result` contains '{function_result}'. + + If not, raises a ValueError. + """ + if v is not None and '{function_result}' not in v: + raise ValueError( + "function_result must contain the keyword '{function_result}'") + return v + + @field_validator('functions') + def check_functions(cls, v: str) -> str: + """Validate that `functions` contains '{functions}'. + + If not, raises a ValueError. + """ + if v is not None and '{functions}' not in v: + raise ValueError( + "functions must contain the keyword '{functions}'") + return v + + @field_validator('code_interpreter') + def check_code_interpreter(cls, v: str) -> str: + """Validate that `code_interpreter` contains '{code_interpreter}'. + + If not, raises a ValueError. + """ + if v is not None and '{code_interpreter}' not in v: + raise ValueError('code_interpreter must contain the keyword ' + "'{code_interpreter}'") + return v + + @field_validator('code_interpreter_call') + def check_code_interpreter_call(cls, v: str) -> str: + """Validate that `code_interpreter_call` contains + '{code_interpreter_call}'. + + If not, raises a ValueError. + """ + if (v is not None and '{code_interpreter_call}' not in v + and '{assistant}' not in v): + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") + return v + + @field_validator('code_interpreter_result') + def check_code_interpreter_result(cls, v: str) -> str: + """Validate that `code_interpreter_result` contains + '{code_interpreter_result}'. + + If not, raises a ValueError. 
+ """ + if v is not None and '{code_interpreter_result}' not in v: + raise ValueError( + 'code_interpreter_result must contain the keyword ' + "'{code_interpreter_result}'") + return v diff --git a/xtuner/_lite/datasets/__init__.py b/xtuner/_lite/datasets/__init__.py new file mode 100644 index 000000000..f7dc3b293 --- /dev/null +++ b/xtuner/_lite/datasets/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .json import JsonDataset +from .jsonl import JsonlDataset +from .pack import SoftPackDataset, HardPackDataset +from .utils import DATASET_CLS_MAP, OPENAI_CONVERT_MAP, load_datasets + +__all__ = [ + 'JsonDataset', 'JsonlDataset', 'SoftPackDataset', 'DATASET_CLS_MAP', + 'OPENAI_CONVERT_MAP', 'load_datasets' +] diff --git a/xtuner/_lite/datasets/internvl2/__init__.py b/xtuner/_lite/datasets/internvl2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/_lite/datasets/internvl2/conversation.py b/xtuner/_lite/datasets/internvl2/conversation.py new file mode 100644 index 000000000..10f7f6b14 --- /dev/null +++ b/xtuner/_lite/datasets/internvl2/conversation.py @@ -0,0 +1,393 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Dict, List, Tuple, Union + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + INTERNVL_ZH = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = '{system_message}' + # The system message + system_message: str = '' + # The names of two roles + roles: Tuple[str] = ('USER', 'ASSISTANT') + # All messages. Each item is (role, message). 
+ messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = '\n' + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ': ' # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = '' if system_prompt == '' else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + else: + ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ': ' + + message.replace('\r\n', '\n').replace('\n\n', '\n') + ) + ret += '\n\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = '[INST] ' + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + ' ' + else: + ret += tag + ' ' + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == 'chatglm2' else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = '' + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f'[Round {i//2 + round_add_n}]{self.sep}' + + if message: + ret += f'{role}:{message}{self.sep}' + else: + ret += f'{role}:' + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = '' if system_prompt == '' else system_prompt + self.sep + '\n' + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + '\n' + else: + 
ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = '' + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + '\n' + ' ' + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + # if i % 2 == 0: + # ret += "" + if message: + ret += role + ':' + message + seps[i % 2] + '\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ':\n' + message + seps[i % 2] + if i % 2 == 1: + ret += '\n\n' + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ': ' + '' + message + '' + else: + ret += role + ': ' + '' + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ':\n' + message + self.sep + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = '' + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + + return ret + elif self.sep_style == SeparatorStyle.INTERNVL_ZH: + seps = [self.sep, self.sep2] + ret = self.system_message + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.MPT: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f'Invalid style: {self.sep_style}') + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. 
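        A typical round trip looks like this (illustrative; 'internlm2-chat'
        is one of the templates registered at the bottom of this file):

            from xtuner._lite.datasets.internvl2.conversation import get_conv_template

            conv = get_conv_template('internlm2-chat')
            conv.append_message(conv.roles[0], 'Describe this image.')
            conv.append_message(conv.roles[1], None)      # placeholder reply
            prompt = conv.get_prompt()                    # feed to the model
            conv.update_last_message('It shows a cat.')   # fill in afterwards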
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{'role': 'system', 'content': self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({'role': 'user', 'content': msg}) + else: + if msg is not None: + ret.append({'role': 'assistant', 'content': msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + 'template_name': self.name, + 'system_message': self.system_message, + 'roles': self.roles, + 'messages': self.messages, + 'offset': self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f'{template.name} has been registered.' + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference +# is that during training, the preprocessing function for the Hermes-2 template doesn't add +# at the beginning of the tokenized sequence, while the internlm2-chat template does. +# Therefore, they are completely equivalent during inference. +register_conv_template( + Conversation( + name='Hermes-2', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + stop_token_ids=[ + 2, + 6, + 7, + 8, + ], + stop_str='<|endoftext|>', + ) +) + + +register_conv_template( + Conversation( + name='internlm2-chat', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + stop_token_ids=[ + 2, + 92543, + 92542 + ] + ) +) + + +register_conv_template( + Conversation( + name='phi3-chat', + system_template='<|system|>\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. 
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|user|>\n', '<|assistant|>\n'), + sep_style=SeparatorStyle.MPT, + sep='<|end|>', + stop_token_ids=[ + 2, + 32000, + 32007 + ] + ) +) diff --git a/xtuner/_lite/datasets/internvl2/dataset.py b/xtuner/_lite/datasets/internvl2/dataset.py new file mode 100644 index 000000000..bc7bf9653 --- /dev/null +++ b/xtuner/_lite/datasets/internvl2/dataset.py @@ -0,0 +1,295 @@ +import os +import torch.distributed as dist +from mmengine.utils import mkdir_or_exist +from torch.utils.data import ConcatDataset, DataLoader, Dataset +import numpy as np +import json +import math +from concurrent.futures import ProcessPoolExecutor +from tqdm import tqdm +import copy +import random + +from ..json import calculate_json_sha256 +from ..jsonl import calculate_jsonl_sha256 +from ..pack import SoftPackDataset + +from xtuner._lite import get_logger +from xtuner._lite.parallel import get_dp_mesh, VLMLengthGroupedSampler, ParallelSampler + +logger = get_logger() + + +def _load_json_or_jsonl(json_path): + if json_path.endswith('.json'): + with open(json_path) as f: + data = json.load(f) + elif json_path.endswith('.jsonl'): + with open(json_path) as f: + data = f.readlines() + else: + raise ValueError(f'Unsupported file format: {json_path}, ' + f'only support .json and .jsonl.') + return data + + +class BaseOrigDataset(Dataset): + def __init__(self, + data_name, + data, + chat_template, + tokenizer, + max_length, + image_token_str='', + group_by_length=False, + pack_data=False, + pack_data_cache_dir=None, + random_sample=False): + self.data_name = data_name + self.max_length = max_length + self.group_by_length = group_by_length + self.pack_data = pack_data + self.pack_data_cache_dir = pack_data_cache_dir + self.chat_template = chat_template + self.image_token_str = image_token_str + self.tokenizer = tokenizer + self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8)) + + try: + self.root = data['media_root'] + except KeyError: + self.root = data.get('root', '') + logger.info(f"{dist.get_rank()} ======= Start to process dataset: {os.path.basename(data['annotation'])}") + + self.annotation = data['annotation'] + self._is_jsonl = self.annotation.endswith('.jsonl') + self.raw_data = _load_json_or_jsonl(self.annotation) + + # -------------------pack--------------------------------------- + self.num_tokens = None + self.pack_data_cache_dir = pack_data_cache_dir + if pack_data: + assert pack_data_cache_dir is not None, 'pack_data_cache_dir must be provided when pack_data is True' + self.num_tokens = self.calc_packing_info() + assert len(self.num_tokens) == len( + self.raw_data), f'===={len(self.num_tokens)} neq {len(self.raw_data)}====' + + repeat_time = data.get('repeat_time', 1) + if repeat_time < 1: + # If repeat_time is less than 1, select a portion of the data + if random_sample: + num_samples = int(len(self.raw_data) * repeat_time) + sampled = random.sample([i for i in range(len(self.raw_data))], num_samples) + self.raw_data = [self.raw_data[index] for index in sampled] + if pack_data: + self.num_tokens = self.num_tokens[sampled] + else: + num_samples = int(len(self.raw_data) * repeat_time) + self.raw_data = self.raw_data[:num_samples] + if pack_data: + self.num_tokens = self.num_tokens[:num_samples] + + if repeat_time > 1: + assert isinstance(repeat_time, int) + # Repeat the list if repeat_time is greater than 1 + self.raw_data = 
self.raw_data * repeat_time + if pack_data: + self.num_tokens = np.tile(self.num_tokens, repeat_time) + + if pack_data: + assert len(self.num_tokens) == len(self.raw_data), f' {len(self.num_tokens)} neq {len(self.raw_data)}' + + self.group_length = [] + if self.group_by_length and not pack_data: + self.group_length = self.calc_group_len() + + def __len__(self): + return len(self.raw_data) + + def calc_group_len(self): + raise NotImplementedError + + def calc_packing_info(self): + if os.path.exists(self.pack_data_cache_dir): + assert os.path.isdir(self.pack_data_cache_dir) + else: + mkdir_or_exist(self.pack_data_cache_dir) + + # TODO: more rubost way to calculate the hash + if self._is_jsonl: + file_hash = calculate_jsonl_sha256(self.annotation) + else: + file_hash = calculate_json_sha256(self.annotation) + file_cache_dir = os.path.join(self.pack_data_cache_dir, file_hash) + if not os.path.exists(file_cache_dir): + mkdir_or_exist(file_cache_dir) + + if 'num_tokens.npy' in os.listdir(file_cache_dir): + _cached_file = os.path.join(file_cache_dir, 'num_tokens.npy') + num_tokens = np.load(_cached_file) + logger.info(f"Load num_tokens from cache: {os.path.basename(self.annotation)}") + else: + logger.info(f"Start calculating the cache of num_tokens: {os.path.basename(self.annotation)}") + num_tokens = self.count_tokens_for_pack(file_cache_dir) + return num_tokens + + def count_tokens_for_pack(self, cache_dir=None): + num_samples = len(self.raw_data) + + if dist.is_available(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + world_size = 1 + rank = 0 + + num_per_rank = math.ceil(num_samples / world_size) + + start = rank * num_per_rank + end = (rank + 1) * num_per_rank + dataset_shard = self.raw_data[start:end] + + desc = f'[Rank {rank}] {os.path.basename(self.annotation)}' + with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor: + tokenized = list( + tqdm( + executor.map(self.pre_tokenize_fn_for_pack, dataset_shard, + chunksize=min(max(1, len(dataset_shard) // self.tokenizer_workers), 500)), + desc=desc, + total=len(dataset_shard))) + + _num_tokens = [data['num_tokens'] for data in tokenized] + _num_tokens = np.array(_num_tokens) + + if dist.is_available(): + num_tokens = [None] * world_size + dist.all_gather_object(num_tokens, _num_tokens) + num_tokens = np.concatenate(num_tokens, axis=0) + else: + num_tokens = _num_tokens + + if rank == 0 and cache_dir: + save_path = os.path.join(cache_dir, 'num_tokens.npy') + np.save(save_path, num_tokens) + + return num_tokens + + def pre_tokenize_fn_for_pack(self, data): + raise NotImplementedError + + def process_text(self, conversations, media_type='image', image_grids=None): + while conversations and conversations[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + conversations = conversations[1:] + + assert len(conversations) % 2 == 0, f'Invalid conversation length: {len(conversations)}' + + input_ = '' + out_conversation = [] + for msg in conversations: + if msg['from'] == 'human': + input_ += msg['value'].strip() + elif msg['from'] == 'gpt': + out_conversation.append({ + 'input': input_, + 'output': msg['value'].strip() + }) + input_ = '' + else: + raise NotImplementedError(f'Unsupported message type: {msg}') + + input_ids, labels = [], [] + for i, single_turn_conversation in enumerate(out_conversation): + input_ = single_turn_conversation.get('input', '') + if input_ is None: + input_ = '' + input_ = self.chat_template['user'].format(user=input_) + + if i == 0: + input_ = 
self._process_media_format_first_round(input_, media_type, image_grids) + # TODO: support system prompt + # input_ = self.chat_template['system'] + input_ + input_encode = self.tokenizer.encode(input_, add_special_tokens=True) + else: + input_encode = self.tokenizer.encode(input_, add_special_tokens=False) + + input_ids += input_encode + labels += [-100] * len(input_encode) + + output_text = single_turn_conversation.get('output', '') + output_encode = self.chat_template['assistant'].format(assistant=output_text) + output_encode = self.tokenizer.encode(output_encode, add_special_tokens=False) + input_ids += output_encode + labels += copy.deepcopy(output_encode) + + if len(input_ids) > self.max_length: + input_ids = input_ids[:self.max_length] + labels = labels[:self.max_length] + logger.info( + f'Warning: input_ids length({len(input_ids)}) ' + f'is longer than max_length, cut to {self.max_length}') + return {'input_ids': input_ids, 'labels': labels} + + def _process_media_format_first_round(self, input_, media_type, image_grids): + raise NotImplementedError + + @property + def modality_length(self): + return self.group_length + + @property + def length(self): + group_length = np.array(self.group_length) + group_length = np.abs(group_length).tolist() + return group_length + + +def build_dataset(args, datasets): + assert len(datasets) > 0, 'No dataset found.' + if args.dset_pack: + train_dataset = SoftPackDataset(datasets, + target=args.pack_max_length, + blend=args.concat_before_pack) + else: + train_dataset = ConcatDataset(datasets) + if dist.get_rank() == 0: + logger.info(f'[Dataset] (Original) {len(train_dataset)} samples.') + return train_dataset + + +def build_train_dataloader(args, train_dataset, collate_fn): + dp_mesh = get_dp_mesh() + if args.group_by_length: + if args.dset_pack: + length_property = 'longest' + else: + length_property = 'length' + sampler = VLMLengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size, + seed=args.seed, + length_property=length_property) + elif args.group_by_modality_length: + if args.dset_pack: + raise NotImplementedError + else: + sampler = VLMLengthGroupedSampler(train_dataset, dp_mesh, + args.global_batch_size, + seed=args.seed, + length_property='modality_length') + else: + sampler = ParallelSampler( + train_dataset, dp_mesh, args.global_batch_size, seed=args.seed, shuffle=True) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.mirco_batch_size, + num_workers=args.num_workers, + sampler=sampler, + collate_fn=collate_fn, + persistent_workers=args.num_workers > 0) + + if dist.get_rank() == 0: + logger.info(f'[Dataloader] {len(train_dataloader)} batches.') + + dist.barrier() + return train_dataloader diff --git a/xtuner/_lite/datasets/internvl2/process.py b/xtuner/_lite/datasets/internvl2/process.py new file mode 100644 index 000000000..f0c48d752 --- /dev/null +++ b/xtuner/_lite/datasets/internvl2/process.py @@ -0,0 +1,705 @@ +import io + +from transformers.trainer_pt_utils import LabelSmoother + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +from typing import Dict +import torch +import torchvision.transforms as T +import transformers +from .conversation import get_conv_template +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +import sys + + +IMG_CONTEXT_TOKEN = '' +IMG_START_TOKEN = '' +IMG_END_TOKEN = '' +QUAD_START_TOKEN = '' +QUAD_END_TOKEN = '' +REF_START_TOKEN = '' +REF_END_TOKEN = '' +BOX_START_TOKEN = '' +BOX_END_TOKEN = '' +IMAGENET_MEAN = (0.485, 0.456, 0.406) 
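# Together with the STD / CLIP / SIGLIP values below, these are the per-channel
# normalization statistics; `build_transform` further down selects one
# (MEAN, STD) pair according to `normalize_type` ('imagenet', 'clip' or 'siglip').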
+IMAGENET_STD = (0.229, 0.224, 0.225) +CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) +CLIP_STD = (0.2686295, 0.2613025, 0.2757711) +SIGLIP_MEAN = (0.5, 0.5, 0.5) +SIGLIP_STD = (0.5, 0.5, 0.5) +IGNORE_INDEX = -100 + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def simulate_jpeg_degradation(quality): + def jpeg_degrade(img): + with io.BytesIO() as output: + img.convert('RGB').save(output, format='JPEG', quality=quality) + output.seek(0) # Move the reading cursor to the start of the stream + img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory + return img_jpeg + return jpeg_degrade + + +# Define the JPEG compression quality range, pre-create all JPEG compression functions +qualities = list(range(75, 101)) +jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities} + + +def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'): + if normalize_type == 'imagenet': + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + elif normalize_type == 'clip': + MEAN, STD = CLIP_MEAN, CLIP_STD + elif normalize_type == 'siglip': + MEAN, STD = SIGLIP_MEAN, SIGLIP_STD + else: + raise NotImplementedError + if is_train: # use data augumentation + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + else: + if pad2square is False: # now we use this transform function by default + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + else: + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + + return transform + + +def preprocess( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = 
f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] + ': ' + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == '': + break + turn_len = len(tokenizer(turn).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + instruction_len -= 1 + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + logger.info(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
+ ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_mpt( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep) + re_turns = [conv.sep.join(turns[:3])] # system + user + gpt + for conv_idx in range(3, len(turns), 2): + re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt + cur_len = 0 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(re_turns): + if turn == '': + break + turn_len = len(tokenizer(turn).input_ids) + 1 + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + instruction_len = len(tokenizer(parts[0]).input_ids) + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0])) + # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0])) + # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len]) + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
+ ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_phi3( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + tokenizer.padding_side = 'right' + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|> + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(int(tokenizer.pad_token_id)).sum()) + + turns = conversation.split(conv.sep) + re_turns = [conv.sep.join(turns[:3])] # system + user + gpt + for conv_idx in range(3, len(turns), 2): + re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>') + target[target == endoftext_id] = IGNORE_TOKEN_ID + + for i, turn in enumerate(re_turns): + if turn == '': + break + if i == 0: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) - 1 + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if i == 0: + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + else: + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0])) + # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0])) + # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len]) + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + print(repr(tokenizer.decode(z))) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
+ ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_phi3_fast( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + + assert len(conv.messages) % 2 == 0, f'{ds_name}, {len(conv.messages)}, {conv.messages}' + inputs = conv.messages[::2] + outputs = conv.messages[1::2] + + input_ids, labels = [], [] + # input_texts = '' + system_prompt = conv.system_template.format(system_message=conv.system_message) + input_text = system_prompt + conv.sep + # input_texts += input_text + input_encode = tokenizer.encode(input_text, add_special_tokens=True) + input_ids += input_encode + labels += [IGNORE_INDEX] * len(input_encode) + + real_num_images = 0 + for input_, output_ in zip(inputs, outputs): + # output_[0] = '<|assistant|>\n' + # 放到 input 而不是 output 是为了和官方对齐 + input_text = ''.join(input_) + conv.sep + output_[0] + + if not text_only: + real_num_images += input_text.count('') + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + input_text = input_text.replace('', image_tokens, 1) + assert '' not in input_text, f'error: {ds_name}, {input_text}' + output_text = output_[1] + conv.sep + + input_encode = tokenizer.encode(input_text, add_special_tokens=False) + output_encode = tokenizer.encode(output_text, add_special_tokens=False) + input_ids += input_encode + input_ids += output_encode + labels += [IGNORE_INDEX] * len(input_encode) + labels += output_encode + + # input_texts += input_text + # input_texts += output_text + + if not text_only: + assert real_num_images == num_image, f'{ds_name} data error: {real_num_images} vs. {num_image}' + # print(input_texts) + # assert input_ids.count(32013) == num_image_token_list[ + # 0], f'error1: {input_ids}, {num_image_token_list[0]}, {input_texts}' + if len(input_ids) > tokenizer.model_max_length: + print(f'WARNING: input_ids length {len(input_ids)} exceeds ' + f'model_max_length {tokenizer.model_max_length}. truncated!') + input_ids = input_ids[:tokenizer.model_max_length] + labels = labels[:tokenizer.model_max_length] + + # if not text_only: + # if input_ids.count(32013) != num_image_token_list[0]: + # print(f'WARNING: IMG_CONTEXT_TOKEN is broken. {input_ids.count(32013)} vs. 
{num_image_token_list[0]}') + + input_ids = torch.tensor(input_ids, dtype=torch.long)[None] + labels = torch.tensor(labels, dtype=torch.long)[None] + assert input_ids.size() == labels.size() + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_internlm( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + sentence['value'] = sentence['value'].strip() + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) # 浦语里面 pad_token_id = eos_token_id + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID # + parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n + info = parts[0] + conv.roles[1] + temp_len = len(tokenizer(info).input_ids) - 1 # 去除tokenizer的 + target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID + cur_len = cur_len + temp_len + + for index in range(1, len(parts) - 1): + info = parts[index] + part1, part2 = info.split(conv.roles[0]) + temp_len = len(tokenizer(part1).input_ids) - 1 + cur_len = cur_len + temp_len + part = conv.roles[0] + part2 + conv.roles[1] + temp_len = len(tokenizer(part).input_ids) - 1 + target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID + cur_len = cur_len + temp_len + last_info = parts[-1] + temp_len = len(tokenizer(last_info).input_ids) - 1 + cur_len = cur_len + temp_len + + target[cur_len:] = IGNORE_TOKEN_ID + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + print(repr(tokenizer.decode(z))) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. 
This dataset is {ds_name}.') + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def dynamic_num_patch(size, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + orig_width, orig_height = size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + if use_thumbnail and blocks > 1: + blocks += 1 + return blocks + + +def packing_collate(features, pack_batch=True, pad_id=0): + input_ids = [] + labels = [] + pixel_values = [] + num_tokens = [] + num_img_tokens = [] + image_flags = [] + + for data in features: + input_ids.append(torch.LongTensor(data['input_ids'])) + labels.append(torch.LongTensor(data['labels'])) + num_tokens.extend(data['num_tokens']) + num_img_tokens.extend(data['num_img_tokens']) + pixel_values.append(data['pixel_values']) + image_flags.append(data['image_flags']) + + 
attention_mask = [ids.ne(pad_id) for ids in input_ids] + num_tokens = torch.IntTensor(num_tokens) + num_img_tokens = torch.IntTensor(num_img_tokens) + + if len(features) > 1 and pack_batch: + # batch packing + input_ids = torch.cat(input_ids, dim=0).unsqueeze(0) + labels = torch.cat(labels, dim=0).unsqueeze(0) + attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0) + image_flags = torch.cat(image_flags, dim=0) + pixel_values = torch.cat(pixel_values, dim=0) + elif len(features) > 1 and not pack_batch: + raise NotImplementedError + else: + raise NotImplementedError + + data_dict = { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask.bool(), + 'pixel_values': pixel_values, + 'image_flags': image_flags, + 'num_tokens': num_tokens, + 'num_img_tokens': num_img_tokens, + } + + return data_dict \ No newline at end of file diff --git a/xtuner/_lite/datasets/json.py b/xtuner/_lite/datasets/json.py new file mode 100644 index 000000000..3efd91a1d --- /dev/null +++ b/xtuner/_lite/datasets/json.py @@ -0,0 +1,173 @@ +import hashlib +import inspect +import json +import math +import os +import random +from concurrent.futures import ProcessPoolExecutor +from mmengine import mkdir_or_exist +import numpy as np +import torch +from torch import distributed as dist +from tqdm import tqdm +from xtuner._lite import get_logger + + +logger = get_logger() + +def calculate_json_sha256(file_path): + with open(file_path, 'rb') as f: + data = f.read() + + hash_object = hashlib.sha256(data) + hash_hex = hash_object.hexdigest() + return hash_hex + + +def calculate_tokenize_fn_sha256(tokenize_fn): + """Calculate SHA-256 hash for an instance method's source code.""" + # Get the source code of the method + fn_source = inspect.getsource(tokenize_fn.__call__) + return hashlib.sha256(fn_source.encode('utf-8')).hexdigest() + + +class JsonDataset(torch.utils.data.Dataset): + + def __init__(self, + path, + sample_ratio=1.0, + tokenize_fn=None, + cache_dir=None, + max_length=None): + super().__init__() + + self.tokenize_fn = tokenize_fn + self.path = path + self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8)) + + if cache_dir: + if os.path.exists(cache_dir): + assert os.path.isdir(cache_dir) + else: + mkdir_or_exist(cache_dir) + + file_hash = calculate_json_sha256(path) + file_cache_dir = os.path.join(cache_dir, file_hash) + + if file_hash not in os.listdir(cache_dir): + mkdir_or_exist(file_cache_dir) + + if self.tokenize_fn: + tok_hash = calculate_tokenize_fn_sha256(tokenize_fn) + tok_cache_dir = os.path.join(file_cache_dir, tok_hash) + if tok_hash not in os.listdir(file_cache_dir): + mkdir_or_exist(tok_cache_dir) + + if 'num_tokens.npy' in os.listdir(tok_cache_dir): + _cached_file = os.path.join(tok_cache_dir, + 'num_tokens.npy') + num_tokens = np.load(_cached_file) + else: + num_tokens = self.count_tokens(tok_cache_dir) + else: + num_tokens = None + + else: + num_tokens = None + + with open(self.path) as f: + dataset = json.load(f) + + _sampled = [i for i in range(len(dataset))] + + if max_length is not None: + assert isinstance(max_length, int) + _filtered = [x for i, x in enumerate(_sampled) if num_tokens[i] < max_length] + + if len(_filtered) < len(_sampled): + missed_num = len(_sampled) - len(_filtered) + logger.warning(f"{path} has {missed_num} prompt length>{max_length}, discard.") + + _sampled = _filtered + + _target_num_samples = int(len(_sampled) * sample_ratio) + self.sampled = _sampled * int(sample_ratio) + self.sampled.extend(random.sample(_sampled, 
_target_num_samples - len(self.sampled))) + + if num_tokens is not None: + num_tokens = num_tokens[self.sampled] + + self.num_tokens = num_tokens + self.dataset = None + + def count_tokens(self, cache_dir=None): + + dataset = [] + + with open(self.path) as f: + dataset = json.load(f) + + num_samples = len(dataset) + + if dist.is_available(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + world_size = 1 + rank = 0 + + num_per_rank = math.ceil(num_samples / world_size) + + start = rank * num_per_rank + end = (rank + 1) * num_per_rank + dataset_shard = dataset[start:end] + + desc = f'[Rank {rank}] {self.path}' + chunk_size = min(1024, max(1, len(dataset_shard) // self.tokenizer_workers)) + with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor: + tokenized = list( + tqdm( + executor.map(self.tokenize_fn, dataset_shard, + chunksize=chunk_size), + desc=desc, + total=len(dataset_shard))) + + _num_tokens = [data['num_tokens'] for data in tokenized] + _num_tokens = np.array(_num_tokens) + + if dist.is_available(): + num_tokens = [None] * world_size + dist.all_gather_object(num_tokens, _num_tokens) + num_tokens = np.concatenate(num_tokens, axis=0) + else: + num_tokens = _num_tokens + + if rank == 0 and cache_dir: + save_path = os.path.join(cache_dir, 'num_tokens.npy') + np.save(save_path, num_tokens) + + return num_tokens + + def __len__(self): + return len(self.sampled) + + def __getitem__(self, item): + """Returns a dict containing packed data in the given item. + + Args: + item: An index to retrieve packed data. + + Returns: + A dict including packed input_ids, labels, and cumulative_len. + """ + if self.dataset is None: + with open(self.path) as f: + self.dataset = json.load(f) + + raw_data = self.dataset[self.sampled[item]] + + if self.tokenize_fn: + tokenized_data = self.tokenize_fn(raw_data) + return tokenized_data + else: + return raw_data diff --git a/xtuner/_lite/datasets/jsonl.py b/xtuner/_lite/datasets/jsonl.py new file mode 100644 index 000000000..3bfc2c4bb --- /dev/null +++ b/xtuner/_lite/datasets/jsonl.py @@ -0,0 +1,205 @@ +import hashlib +import inspect +import json +import math +import os +import random +from concurrent.futures import ProcessPoolExecutor +from mmengine import mkdir_or_exist +import numpy as np +import torch +from torch import distributed as dist +from tqdm import tqdm +from xtuner._lite import get_logger + +logger = get_logger() + + +def calculate_jsonl_sha256(path): + with open(path, 'rb') as f: + file_hash = hashlib.sha256() + file_hash.update(f.read()) + return file_hash.hexdigest() + + +def calculate_tokenize_fn_sha256(tokenize_fn): + """Calculate SHA-256 hash for an instance method's source code.""" + # Get the source code of the method + fn_source = inspect.getsource(tokenize_fn.__call__) + return hashlib.sha256(fn_source.encode('utf-8')).hexdigest() + + +class JsonlDataset(torch.utils.data.Dataset): + + def __init__(self, + path, + sample_ratio=1.0, + tokenize_fn=None, + cache_dir=None, + max_length=None,): + super().__init__() + + self.tokenize_fn = tokenize_fn + self.path = path + self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8)) + + if cache_dir: + if os.path.exists(cache_dir): + assert os.path.isdir(cache_dir) + else: + mkdir_or_exist(cache_dir) + + file_hash = calculate_jsonl_sha256(path) + file_cache_dir = os.path.join(cache_dir, file_hash) + + if file_hash not in os.listdir(cache_dir): + mkdir_or_exist(file_cache_dir) + + if 'offsets.npy' in os.listdir(file_cache_dir): + _cached_file 
= os.path.join(file_cache_dir, 'offsets.npy') + offsets = np.load(_cached_file) + else: + offsets = self.count_offsets(file_cache_dir) + + if self.tokenize_fn: + tok_hash = calculate_tokenize_fn_sha256(tokenize_fn) + tok_cache_dir = os.path.join(file_cache_dir, tok_hash) + if tok_hash not in os.listdir(file_cache_dir): + mkdir_or_exist(tok_cache_dir) + + if 'num_tokens.npy' in os.listdir(tok_cache_dir): + _cached_file = os.path.join(tok_cache_dir, + 'num_tokens.npy') + num_tokens = np.load(_cached_file) + else: + num_tokens = self.count_tokens(offsets, tok_cache_dir) + else: + num_tokens = None + + offsets = offsets + num_tokens = num_tokens + + else: + offsets = self.count_offsets() + num_tokens = None + if max_length is not None: + assert self.tokenize_fn + num_tokens = self.count_tokens(offsets) + + _sampled = [i for i in range(len(offsets))] + + if max_length is not None: + assert isinstance(max_length, int) + _filtered = [x for i, x in enumerate(_sampled) if num_tokens[i] < max_length] + + if len(_filtered) < len(_sampled): + missed_num = len(_sampled) - len(_filtered) + logger.warning(f"{path} has {missed_num} prompt length>{max_length}, discard.") + + _sampled = _filtered + + _target_num_samples = int(len(_sampled) * sample_ratio) + self.sampled = _sampled * int(sample_ratio) + self.sampled.extend(random.sample(_sampled, _target_num_samples - len(self.sampled))) + + if num_tokens is not None: + num_tokens = num_tokens[self.sampled] + + self.num_tokens = num_tokens + self.offsets = offsets[self.sampled] + + + def count_offsets(self, cache_dir=None): + + offsets = [0] + with open(self.path) as f: + + lines = f.readlines() + for line in lines[:-1]: + offsets.append(offsets[-1]+len(line.encode())) + + offsets = np.array(offsets) + + if dist.get_rank() == 0 and cache_dir: + save_path = os.path.join(cache_dir, 'offsets.npy') + np.save(save_path, offsets) + + return offsets + + def _tokenize_by_offset(self, offset): + + with open(self.path, 'r') as f: + f.seek(offset) + data = json.loads(f.readline()) + return self.tokenize_fn(data) + + def count_tokens(self, offsets, cache_dir=None): + + num_samples = len(offsets) + + if dist.is_available(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + world_size = 1 + rank = 0 + + num_per_rank = math.ceil(num_samples / world_size) + + start = rank * num_per_rank + end = (rank + 1) * num_per_rank + offsets_shard = offsets[start:end] + + + desc = f'[Rank {rank}] {self.path}' + chunk_size = min(1024, max(1, len(offsets_shard) // self.tokenizer_workers)) + + with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor: + tokenized = list( + tqdm( + executor.map( + self._tokenize_by_offset, + offsets_shard, + chunksize=chunk_size), + desc=desc, + total=len(offsets_shard))) + + _num_tokens = [data['num_tokens'] for data in tokenized] + _num_tokens = np.array(_num_tokens) + + if dist.is_available(): + num_tokens = [None] * world_size + dist.all_gather_object(num_tokens, _num_tokens) + num_tokens = np.concatenate(num_tokens, axis=0) + else: + num_tokens = _num_tokens + + if rank == 0 and cache_dir: + save_path = os.path.join(cache_dir, 'num_tokens.npy') + np.save(save_path, num_tokens) + + return num_tokens + + def __len__(self): + return len(self.offsets) + + def __getitem__(self, item): + """Returns a dict containing packed data in the given item. + + Args: + item: An index to retrieve packed data. + + Returns: + A dict including packed input_ids, labels, and cumulative_len. 
+ """ + with open(self.path, 'r') as f: + f.seek(self.offsets[item]) + line = f.readline() + + raw_data = json.loads(line) + + if self.tokenize_fn: + tokenized_data = self.tokenize_fn(raw_data) + return tokenized_data + else: + return raw_data diff --git a/xtuner/_lite/datasets/pack.py b/xtuner/_lite/datasets/pack.py new file mode 100644 index 000000000..2cbfce863 --- /dev/null +++ b/xtuner/_lite/datasets/pack.py @@ -0,0 +1,267 @@ +import random + +import numpy as np +import torch +from datasets import Dataset, concatenate_datasets +from torch.utils.data import ConcatDataset +import bisect +import itertools + + +class SoftPackDataset(torch.utils.data.Dataset): + + def __init__(self, datasets, target=2048, blend=False, sort=False): + + if blend: + num_tokens = [ + np.concatenate([dset.num_tokens for dset in datasets]) + ] + datasets = [ConcatDataset(datasets)] + else: + num_tokens = [dset.num_tokens for dset in datasets] + self.datasets = datasets + self.target = target + + pack_infos = [] + for i, dataset in enumerate(self.datasets): + _infos = self.get_pack_infos(dataset, i, num_tokens[i]) + pack_infos.append(_infos) + self.pack_infos = concatenate_datasets(pack_infos) + + @property + def longest(self): + return self.pack_infos['longest'] + + def get_pack_infos(self, dataset, dataset_id, num_tokens): + # _ori_lens = dataset['num_tokens'] + inds = [i for i in range(len(dataset))] + random.shuffle(inds) + + item_buffer = [] + length_buffer = [] + longest = 0 + + pack_infos = [] + for shfl_i in inds: + if num_tokens[shfl_i] + sum(length_buffer) <= self.target: + item_buffer.append(shfl_i) + length_buffer.append(num_tokens[shfl_i]) + longest = max(longest, num_tokens[shfl_i]) + else: + if len(item_buffer) > 0: + info = { + 'dataset_id': dataset_id, + 'indices': item_buffer, + 'longest': int(longest) + } + pack_infos.append(info) + + item_buffer = [shfl_i] + length_buffer = [num_tokens[shfl_i]] + longest = num_tokens[shfl_i] + + if len(item_buffer) > 0: + info = { + 'dataset_id': dataset_id, + 'indices': item_buffer, + 'longest': int(longest) + } + + pack_infos.append(info) + + pack_infos = Dataset.from_list(pack_infos) + + return pack_infos + + def __len__(self): + return len(self.pack_infos) + + def __getitem__(self, item): + indices = self.pack_infos[item]['indices'] + dataset_id = self.pack_infos[item]['dataset_id'] + return [self.datasets[dataset_id][i] for i in indices] + + +class HardPackDataset(torch.utils.data.Dataset): + + + def __init__(self, datasets, target=2048, blend=True, sort=False): + + if blend: + num_tokens = [ + np.concatenate([dset.num_tokens for dset in datasets]) + ] + datasets = [ConcatDataset(datasets)] + else: + num_tokens = [dset.num_tokens for dset in datasets] + self.datasets = datasets + self.target = target + + pack_infos = [] + for i, dataset in enumerate(self.datasets): + _info = self.get_pack_info(dataset, i, num_tokens[i]) + pack_infos.append(_info) + + _ranges_left = [] + _ranges_right = [] + _num_packed_samples = [] + _indices = [] + _max_length_per_pack = [] + _dataset_id = [] + for info in pack_infos: + _ranges_left.extend(info['ranges_left']) + _ranges_right.extend(info['ranges_right']) + _num_packed_samples.append(info['num_packed_samples']) + _indices.extend(info['indices']) + _max_length_per_pack.extend(info['max_length_per_pack']) + _dataset_id.extend(info['dataset_id']) + + self.pack_infos = { + 'ranges_left': _ranges_left, + 'ranges_right': _ranges_right, + 'num_packed_samples': _num_packed_samples, + 'indices': _indices, + 
'max_length_per_pack':_max_length_per_pack, + 'dataset_id': _dataset_id + } + + + @classmethod + def _cal_max_length(cls, begin, end, shfl_item_rngs_left, + shfl_item_rngs_right): + left = bisect.bisect(shfl_item_rngs_right, begin) + right = bisect.bisect(shfl_item_rngs_left, end) + max_length = 0 + for i in range(left, right): + item_begin = shfl_item_rngs_left[i] + item_end = shfl_item_rngs_right[i] + inner_l = max(begin, item_begin) - item_begin + inner_r = min(end, item_end) - item_begin + trunc_size = inner_r - inner_l + max_length = max(max_length, trunc_size) + return max_length + + def get_pack_info(self, dataset, dataset_id, num_tokens): + + # The number of data items after packing + num_packed_samples = int(num_tokens.sum() / self.target) + + # Shuffle the order of the original dataset + # The packing will proceed according to the order after shuffle. + # Assume the following conditions hold: + # (1) shfl_inds = [3, 1, 2, 0] + # (2) self._ori_lens[3] + self._ori_lens[1] = max_length + # (3) self._ori_lens[2] + self._ori_lens[0] = max_length + # Ultimately, dataset[3] and dataset[1] will be combined into a new + # data, and dataset[2] and dataset[0] will be combined into a new data. + inds = [i for i in range(len(dataset))] + # if seed is not None: + # random.seed(seed) + random.shuffle(inds) + shfl_inds = inds + + # shuffled cumulative lengths + shfl_lens = [num_tokens[i] for i in shfl_inds] + shfl_acc_lens = list(itertools.accumulate(shfl_lens)) + + shfl_item_rngs_left = [0] + shfl_acc_lens[:-1] + shfl_item_rngs_right = shfl_acc_lens + + max_length_per_pack = [] + belong_dataset_ids = [] + for i in range(num_packed_samples): + begin = i * self.target + end = (i + 1) * self.target + max_length_per_pack.append( + self._cal_max_length(begin, end, shfl_item_rngs_left, + shfl_item_rngs_right)) + belong_dataset_ids.append(dataset_id) + + pack_infos = { + 'ranges_left': shfl_item_rngs_left, + 'ranges_right': shfl_item_rngs_right, + 'num_packed_samples': num_packed_samples, + 'indices': shfl_inds, + 'dataset_id': belong_dataset_ids, + 'max_length_per_pack': max_length_per_pack + } + + # pack_infos = Dataset.from_list(pack_infos) + + return pack_infos + + def _pack_ids_and_labels_in_range(self, begin: int, end: int): + """Packs ids and labels in a given range using bisection method. + + Args: + begin: Index indicating the beginning of the range. + end: Index indicating the end of the range. + + Returns: + A tuple containing packed ids, labels, and cumulative lengths. 
+ """ + + # Use binary search to find dataset positions that fall within begin + # and end range + left = bisect.bisect(self.pack_infos['ranges_left'], begin) + right = bisect.bisect(self.pack_infos['ranges_right'], end) + + trunc_input_ids = [] + trunc_labels = [] + trunc_sizes = [] + + for i in range(left, right): + + # Determine the real range we will cut in current original item + item_begin = self.pack_infos['ranges_left'][i] + item_end = self.pack_infos['ranges_right'][i] + + # Calculate exact positions within current dataset item + inner_l = max(begin, item_begin) - item_begin + inner_r = min(end, item_end) - item_begin + + # Get original data and labels + ori_idx = self.pack_infos['indices'][i] + ori_dataset_id = self.pack_infos['dataset_id'][i] + ori_input_ids = self.datasets[ori_dataset_id][ori_idx]['input_ids'] + ori_labels = self.datasets[ori_dataset_id][ori_idx]['labels'] + + # Add original data and labels from calculated positions + # to trunc_ids and trunc_labels + trunc_input_ids.extend(ori_input_ids[inner_l:inner_r]) + trunc_labels.extend(ori_labels[inner_l:inner_r]) + trunc_sizes.append(inner_r - inner_l) + + # return populated lists of truncated ids, labels and their cumulative + # lengths + return trunc_input_ids, trunc_labels, trunc_sizes + + def __len__(self): + return len(self.pack_infos['indices']) + + def __getitem__(self, item): + """Returns a dict containing packed data in the given item. + + Args: + item: An index to retrieve packed data. + + Returns: + A dict including packed input_ids, labels, and cumulative_len. + """ + # The cumulative length from the start position of this data + begin = item * self.target + # The cumulative length from the end position of this data + end = (item + 1) * self.target + + # Extract data within the range from the shuffled original dataset. 
+ _res = self._pack_ids_and_labels_in_range(begin, end) + packed_input_ids, packed_labels, num_tokens = _res + assert self.target == len(packed_input_ids) == len(packed_labels) + + packed = { + 'input_ids': packed_input_ids, + 'labels': packed_labels, + 'num_tokens': num_tokens, + } + + return packed \ No newline at end of file diff --git a/xtuner/_lite/datasets/streaming.py b/xtuner/_lite/datasets/streaming.py new file mode 100644 index 000000000..bb567e780 --- /dev/null +++ b/xtuner/_lite/datasets/streaming.py @@ -0,0 +1,159 @@ +import json + +import numpy as np +from torch.utils.data import IterableDataset + + +class Streaming: + + def __init__(self, file, max_epoch=1): + self.file = file + self.offset = 0 + self.epoch = 1 + self.max_epoch = max_epoch + + def __iter__(self): + return self + + def __next__(self): + + with open(self.file) as f: + f.seek(self.offset) + line = f.readline() + + if not line and self.epoch < self.max_epoch: + self.offset = 0 + self.epoch += 1 + return next(self) + + elif not line and self.epoch == self.max_epoch: + raise StopIteration + + self.offset = f.tell() + return line + + +# import torch + +# class MultiStreamingDataset(torch.utils.data.IterableDataset): + +# def __init__(self, streamings, weights, max_length, tokenize_fn, seed, dp_rank, dp_world_size, crossover = False): + +# assert len(streamings) == len(weights) +# self.streamings = streamings +# self.activated = [True for _ in self.streamings] +# for sid, stream in enumerate(self.streamings): +# stream.offset = 0 +# try: +# for _ in range(self.dp_rank): +# next(stream) +# except StopIteration: +# self.activated[sid] = False + +# self.random_state = np.random.RandomState(seed + dp_rank) +# self.weights = weights + +# self.max_length = max_length +# self.tokenize_fn = tokenize_fn +# self.dp_rank = dp_rank +# self.dp_world_size = dp_world_size +# self.crossover = crossover + +# def reactivate(self): +# self.activated = [True for _ in self.streamings] +# for stream in self.streamings: +# stream.offset = 0 +# for _ in range(self.dp_rank): +# next(stream) + +# @property +# def probabilities(self): +# if sum(self.activated) == 0: +# self.reactivate() + +# probs = (np.array(self.weights) * self.activated) / sum(self.weights[self.activated]) +# return probs + +# @property +# def num_streamings(self): +# assert len(self.iterators) == len(self.weights) +# return len(self.weights) + +# def per_rank_next(self, streaming_id): + +# sid = streaming_id +# streaming = self.streamings[sid] + +# try: +# data = next(streaming) +# except StopIteration: +# self.activated[sid] = False +# sid = self.random_state.choice( +# self.num_streamings, p=self.probabilities) +# return self.per_rank_next(sid) + +# try: +# for _ in range(self.dp_world_size): +# next(streaming) +# except StopIteration: +# self.activated[sid] = False + +# return data, sid + +# def __iter__(self): +# worker_info = torch.utils.data.get_worker_info() + +# if worker_info and worker_info.num_workers > 1: +# raise NotImplementedError + +# input_ids = [] +# labels = [] +# num_tokens = [] +# while True: +# sid = self.random_state.choice( +# self.num_streamings, p=self.probabilities) + +# while len(input_ids) < self.max_length: +# if self.crossover: +# sid = self.random_state.choice( +# self.num_streamings, p=self.probabilities) + +# line, sid = self.per_rank_next(sid) + +# tokenized = self.tokenize_fn(json.loads(line)) + +# input_ids.extend(tokenized['input_ids']) +# labels.extend(tokenized['labels']) +# num_tokens.extend(tokenized['num_tokens']) + +# 
remain_tokens = max(sum(num_tokens) - self.max_length, 0) +# num_tokens[-1] = num_tokens[-1] - remain_tokens + +# packed_ids = input_ids[:self.max_length] +# packed_labels = labels[:self.max_length] +# packed_tokens = num_tokens + +# if remain_tokens: +# input_ids = input_ids[self.max_length:] +# labels = labels[self.max_length:] +# num_tokens = [remain_tokens] + +# yield {'input_ids': packed_ids, 'labels': packed_labels, 'num_tokens': packed_tokens} + +if __name__ == '__main__': + import json + streaming = Streaming( + '/mnt/hwfile/xtuner/huanghaian/data/databricks-dolly-15k/databricks-dolly-15k.jsonl' + ) + + data = next(streaming) + print(json.loads(data)) + + data = next(streaming) + print(json.loads(data)) + + data = next(streaming) + print(json.loads(data)) + + data = next(streaming) + print(json.loads(data)) diff --git a/xtuner/_lite/datasets/utils/__init__.py b/xtuner/_lite/datasets/utils/__init__.py new file mode 100644 index 000000000..b16dc419d --- /dev/null +++ b/xtuner/_lite/datasets/utils/__init__.py @@ -0,0 +1,6 @@ +from .convert import OPENAI_CONVERT_MAP +from .load import DATASET_CLS_MAP, load_datasets +from .utils import apply_exif_orientation, move_data_to_device + +__all__ = ['OPENAI_CONVERT_MAP', 'DATASET_CLS_MAP', 'load_datasets', + 'apply_exif_orientation', 'move_data_to_device'] diff --git a/xtuner/_lite/datasets/utils/convert.py b/xtuner/_lite/datasets/utils/convert.py new file mode 100644 index 000000000..b78a12d51 --- /dev/null +++ b/xtuner/_lite/datasets/utils/convert.py @@ -0,0 +1,234 @@ +import re + +from xtuner._lite.chat import ChatMessages + + +class XTunerFormat2Openai(): + + @classmethod + def source_format(cls): + data = { + 'conversation': [{ + 'system': 'SYSTEM', + 'input': 'INPUT', + 'output': 'OUTPUT' + }, { + 'input': 'INPUT', + 'output': 'OUTPUT' + }] + } + return data + + @classmethod + def target_format(cls): + data = { + 'messages': [ + { + 'role': 'system', + 'content': 'SYSTEM' + }, + { + 'role': 'user', + 'content': 'INPUT' + }, + { + 'role': 'assistant', + 'content': 'OUTPUT' + }, + { + 'role': 'user', + 'content': 'INPUT' + }, + { + 'role': 'assistant', + 'content': 'OUTPUT' + }, + ] + } + return data + + @staticmethod + def convert(data): + ROLE_MAPPING = { + 'system': 'system', + 'input': 'user', + 'output': 'assistant' + } + messages = [] + for single_turn_conversation in data['conversation']: + for role, content in single_turn_conversation.items(): + messages.append({ + 'role': ROLE_MAPPING[role], + 'content': content + }) + return ChatMessages.from_dict({'messages': messages}) + + +class Alpaca2Openai(): + + @classmethod + def source_format(cls): + data = { + 'instruction': 'INSTRUCTION', + 'input': 'INPUT', + 'output': 'OUTPUT', + } + return data + + @classmethod + def target_format(cls): + data = { + 'messages': [ + { + 'role': 'user', + 'content': 'INSTRUCTION\nINPUT' + }, + { + 'role': 'assistant', + 'content': 'OUTPUT' + }, + ] + } + return data + + @staticmethod + def convert(data): + if data.get('output') == '': + return ChatMessages.from_dict({'messages': []}) + else: + return ChatMessages.from_dict({ + 'messages': [ + { + 'role': 'user', + 'content': f"{data['instruction']}\n{data['input']}" + }, + { + 'role': 'assistant', + 'content': f"{data['output']}" + }, + ] + }) + + +def llava_to_openai(data): + + image_token = '' + conversations = data['conversations'] + messages = [] + + if 'image' in data: + image_urls = data['image'] + if isinstance(image_urls, str): + image_urls = [image_urls] + else: + image_urls = None + + while 
conversations and conversations[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + conversations = conversations[1:] + + image_id = 0 + for convs in conversations: + if convs['from'] == 'human': + pattern = f'({image_token})' + chunks = re.split(pattern, convs['value']) + + text_content = [] + img_content = [] + + for chunk in chunks: + if chunk == image_token: + url = image_urls[image_id] + if not isinstance(url, str): + raise TypeError(data) + # assert , image_url + item = dict(type='image_url', image_url=url) + img_content.append(item) + image_id += 1 + elif len(chunk.strip()): + item = dict(type='text', text=chunk.strip()) + text_content.append(item) + + msg = {'role': 'user', 'content': img_content + text_content} + messages.append(msg) + + elif convs['from'] == 'gpt': + msg = {'role': 'assistant', 'content': convs['value']} + messages.append(msg) + else: + raise NotImplementedError + + return ChatMessages.from_dict({'messages': messages}) + + +def llava_to_openai_interleave(data): + + image_token = '' + conversations = data['conversations'] + messages = [] + + if 'image' in data: + image_urls = data['image'] + if isinstance(image_urls, str): + image_urls = [image_urls] + else: + image_urls = None + + while conversations and conversations[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + conversations = conversations[1:] + + image_id = 0 + for convs in conversations: + if convs['from'] == 'human': + pattern = f'({image_token})' + chunks = re.split(pattern, convs['value']) + + content = [] + + for chunk in chunks: + if chunk == image_token: + url = image_urls[image_id] + if not isinstance(url, str): + raise TypeError(data) + # assert , image_url + item = dict(type='image_url', image_url=url) + content.append(item) + image_id += 1 + elif len(chunk.strip()): + item = dict(type='text', text=chunk.strip()) + content.append(item) + + msg = {'role': 'user', 'content': content} + messages.append(msg) + + elif convs['from'] == 'gpt': + msg = {'role': 'assistant', 'content': convs['value']} + messages.append(msg) + else: + raise NotImplementedError + + return ChatMessages.from_dict({'messages': messages}) + + +def official_openai(data): + if 'messages' in data: + return ChatMessages.from_dict(data) + elif 'message_data' in data: + return ChatMessages.from_dict({'messages': data['message_data']}) + elif 'dialogs' in data: + return ChatMessages.from_dict({'messages': data['dialogs']}) + else: + return ChatMessages.from_dict({'messages': data}) + +OPENAI_CONVERT_MAP = { + 'llava': + llava_to_openai, + 'llava_interleave': + llava_to_openai_interleave, + 'alpaca': + Alpaca2Openai.convert, + 'xtuner': + XTunerFormat2Openai.convert, + 'openai': official_openai, +} diff --git a/xtuner/_lite/datasets/utils/load.py b/xtuner/_lite/datasets/utils/load.py new file mode 100644 index 000000000..7a39f3c75 --- /dev/null +++ b/xtuner/_lite/datasets/utils/load.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
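# Illustrative sketch only, not part of this module: the converters registered
# in `OPENAI_CONVERT_MAP` (convert.py above) normalize raw records into
# OpenAI-style messages before tokenization, and `load_datasets` defined later
# in this file wraps local .json/.jsonl files into JsonDataset/JsonlDataset
# instances. The record, file path, cache directory and tokenize function in
# this sketch are made-up placeholders.
#
#   from xtuner._lite.datasets import OPENAI_CONVERT_MAP, load_datasets
#
#   record = {'instruction': 'Translate to French.',
#             'input': 'Good morning',
#             'output': 'Bonjour'}
#   msgs = OPENAI_CONVERT_MAP['alpaca'](record)
#   # -> ChatMessages with one user turn ('Translate to French.\nGood morning')
#   #    and one assistant turn ('Bonjour')
#
#   datasets = load_datasets(
#       paths=['data/sft.jsonl'],          # hypothetical local file
#       sources='local',
#       sample_ratios=1.0,
#       map_fns=tokenize_fn,               # assumed: a tokenize function, e.g.
#                                          # an SftTokenizeFunction instance
#       cache_dir='work_dirs/dset_cache')  # hypothetical cache location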
+import json +import math +import os +import random +import re + +from torch import distributed as dist +from tqdm import tqdm + +from xtuner._lite import get_logger +from ..json import JsonDataset +from ..jsonl import JsonlDataset + +logger = get_logger() + +DATASET_CLS_MAP = {'.jsonl': JsonlDataset, '.json': JsonDataset} + + +def load_hf_dataset(path, + split='train', + sample_ratio=1.0, + cache_dir=None, + map_fn=None): + from datasets import load_dataset + dataset = load_dataset(path)[split] + + if map_fn: + dataset = dataset.map(map_fn, num_proc=8) + + if sample_ratio != 1: + ori_samples = len(dataset) + target_samples = int(sample_ratio * ori_samples) + indices = random.choices([i for i in range(ori_samples)], + k=target_samples) + dataset = dataset.select(indices) + + dataset = dataset.to_list() + + # if init_fn: + # dataset = init_fn(dataset) + + # if cache_dir and isinstance(dataset, CacheDataset): + # dataset.cache(cache_dir) + + return dataset + + +def load_from_cache(cache_dir, init_fn): + + if dist.is_available(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + world_size = 1 + rank = 0 + + sub_cache_dirs = [] + for _path in tqdm(os.listdir(cache_dir)): + path = os.path.join(cache_dir, _path) + if os.path.isdir(path): + sub_cache_dirs.append(path) + + num_dsets = len(sub_cache_dirs) + avg_num = math.ceil(num_dsets / world_size) + start = rank * avg_num + end = min((rank + 1) * avg_num, num_dsets) + desc = f'[Rank {rank}] Loading Cached Dataset' + + rank_datasets = [] + for ind in tqdm(range(start, end), desc=desc): + dset = init_fn(sub_cache_dirs[ind]) + rank_datasets.append(dset) + + if dist.is_available() and world_size > 1: + dist.barrier() + buffers = [None] * world_size + dist.all_gather_object(buffers, rank_datasets) + world_datasets = [] + for dsets_per_rank in buffers: + world_datasets.extend(dsets_per_rank) + + assert len(world_datasets) == num_dsets + else: + world_datasets = rank_datasets + + return world_datasets + + +def load_local_datasets(paths, + file_types, + file_pattern=None, + cache_dir=None, + sample_ratios=1.0, + map_fns=None, + max_length=None): + + if isinstance(paths, str): + paths = [paths] + + if isinstance(sample_ratios, (tuple, list)): + + if len(sample_ratios) == 1: + sample_ratios = list(sample_ratios) * len(paths) + + if len(sample_ratios) != len(paths): + raise RuntimeError(f'There are {len(paths)} paths, but only ' + f'{len(sample_ratios)} sample ratios were set.') + + if map_fns is None: + map_fns = [None] * len(paths) + + if isinstance(map_fns, (tuple, list)): + + if len(map_fns) == 1: + map_fns = list(map_fns) * len(paths) + + if len(map_fns) != len(paths): + raise RuntimeError(f'There are {len(paths)} paths, but only' + f'{len(map_fns)} map fns were set.') + + files = [] + file_sample_ratios = [] + file_map_fns = [] + + for pid, path in enumerate(paths): + if os.path.isdir(path): + dir_files = [] + for root, dirs, _files in os.walk(path, followlinks=True): + dirs.sort() + for relative_path in sorted(_files): + suffix = os.path.splitext(relative_path)[-1] + absolute_path = os.path.join(root, relative_path) + if file_pattern is not None: + if bool(re.match(file_pattern, absolute_path)): + dir_files.append(absolute_path) + elif suffix in file_types: + dir_files.append(absolute_path) + + _num_dir_files = len(dir_files) + if _num_dir_files == 0: + raise RuntimeError( + f'There are no files with the suffix {file_types}' + f'in `{path}`.') + + logger.info(f'Found {len(dir_files)} files in {path}') + files.extend(dir_files) 
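            # Every file discovered under this directory inherits the sample
            # ratio and map_fn that were specified for the directory entry
            # itself.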
+            file_sample_ratios.extend([sample_ratios[pid]] * _num_dir_files)
+            file_map_fns.extend([map_fns[pid]] * _num_dir_files)
+
+        elif os.path.isfile(path):
+            files.append(path)
+            file_sample_ratios.append(sample_ratios[pid])
+            file_map_fns.append(map_fns[pid])
+
+        else:
+            raise RuntimeError(f'`{path}` not found.')
+
+    num_files = len(files)
+
+    datasets = []
+    for i in range(num_files):
+        _path = files[i]
+        _ratio = file_sample_ratios[i]
+        _map_fn = file_map_fns[i]
+        _suffix = os.path.splitext(_path)[-1]
+
+        dataset_cls = DATASET_CLS_MAP[_suffix]
+        _dataset = dataset_cls(_path, _ratio, _map_fn, cache_dir, max_length)
+        datasets.append(_dataset)
+
+    return datasets
+
+
+def load_datasets(paths,
+                  sources='local',
+                  sample_ratios=1.0,
+                  file_types=DATASET_CLS_MAP.keys(),
+                  file_pattern=None,
+                  cache_dir=None,
+                  map_fns=None,
+                  max_length=None):
+
+    if isinstance(paths, str):
+        paths = [paths]
+
+    num_paths = len(paths)
+
+    if isinstance(sample_ratios, (float, int)):
+        sample_ratios = [sample_ratios] * num_paths
+
+    if isinstance(sample_ratios, (tuple, list)):
+
+        if len(sample_ratios) == 1:
+            sample_ratios = list(sample_ratios) * num_paths
+
+        if len(sample_ratios) != num_paths:
+            raise RuntimeError(f'There are {num_paths} paths, but only '
+                               f'{len(sample_ratios)} sample ratios were set.')
+
+    if isinstance(sources, str):
+        sources = [sources]
+
+    if isinstance(sources, (tuple, list)):
+
+        if len(sources) == 1:
+            sources = list(sources) * num_paths
+
+        if len(sources) != num_paths:
+            raise RuntimeError(f'There are {num_paths} paths, but only '
+                               f'{len(sources)} sources were set.')
+
+    if not isinstance(map_fns, (tuple, list)):
+        map_fns = [map_fns] * num_paths
+
+    if isinstance(map_fns, (tuple, list)):
+
+        if len(map_fns) == 1:
+            map_fns = list(map_fns) * num_paths
+
+        if len(map_fns) != num_paths:
+            raise RuntimeError(f'There are {num_paths} paths, but only '
+                               f'{len(map_fns)} map fns were set.')
+
+    local_inds = [i for i, src in enumerate(sources) if src == 'local']
+    local_paths = [paths[ind] for ind in local_inds]
+    local_map_fns = [map_fns[ind] for ind in local_inds]
+    local_sample_ratios = [sample_ratios[ind] for ind in local_inds]
+
+    hf_inds = [i for i, src in enumerate(sources) if src == 'huggingface']
+    hf_paths = [paths[ind] for ind in hf_inds]
+    hf_map_fns = [map_fns[ind] for ind in hf_inds]
+    hf_sample_ratios = [sample_ratios[ind] for ind in hf_inds]
+
+    datasets = []
+    if len(local_inds):
+        local_datasets = load_local_datasets(local_paths, file_types,
+                                             file_pattern, cache_dir,
+                                             local_sample_ratios,
+                                             local_map_fns, max_length)
+        datasets.extend(local_datasets)
+
+    if len(hf_inds):
+        cached_infos = {}
+        for i in range(len(hf_inds)):
+            if cache_dir:
+                digits = len(str(abs(len(hf_inds))))
+                cache_id = (f'cache-hf-{i+1:0{digits}}-of-'
+                            f'{len(hf_inds):0{digits}}')
+                sub_cache_dir = os.path.join(cache_dir, cache_id)
+            else:
+                sub_cache_dir = None
+            # `load_hf_dataset` does not accept `max_length`, so it is not
+            # forwarded here.
+            dset = load_hf_dataset(
+                hf_paths[i],
+                sample_ratio=hf_sample_ratios[i],
+                map_fn=hf_map_fns[i],
+                cache_dir=sub_cache_dir)
+            datasets.append(dset)
+            if cache_dir:
+                infos = {
+                    'path': hf_paths[i],
+                    'num_samples': dset.num_samples,
+                    'num_tokens': dset.total_tokens
+                }
+                cached_infos[cache_id] = infos
+
+        if cache_dir:
+            _path = os.path.join(cache_dir, 'hf_infos.json')
+            with open(_path, 'w') as f:
+                json.dump(cached_infos, f)
+
+    return datasets
+
+
+def load_ms_dataset():
+    pass
diff --git a/xtuner/_lite/datasets/utils/utils.py b/xtuner/_lite/datasets/utils/utils.py
new file mode 100644
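A minimal usage sketch for the dataset loader above (the file path, hub id, ratios and max length are illustrative assumptions, not part of this patch). `load_datasets` broadcasts per-path sources, sample ratios and map fns, then concatenates the local and HuggingFace datasets it builds:

    from xtuner._lite.datasets import OPENAI_CONVERT_MAP, load_datasets

    # One source/ratio/map fn per path: 'local' paths are dispatched through
    # DATASET_CLS_MAP by file suffix, 'huggingface' paths go through
    # datasets.load_dataset.
    dsets = load_datasets(
        paths=['data/my_sft.jsonl', 'some-org/some-hub-dataset'],
        sources=['local', 'huggingface'],
        sample_ratios=[1.0, 0.25],
        map_fns=None,  # or one converter per path, e.g. from OPENAI_CONVERT_MAP
        max_length=32768)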
index 000000000..19f572ac0 --- /dev/null +++ b/xtuner/_lite/datasets/utils/utils.py @@ -0,0 +1,66 @@ +from PIL import Image +import torch +from collections.abc import Mapping + +_EXIF_ORIENT = 274 # exif 'Orientation' tag + + +def apply_exif_orientation(image): + """ + Applies the exif orientation correctly. + + This code exists per the bug: + https://github.com/python-pillow/Pillow/issues/3973 + with the function `ImageOps.exif_transpose`. The Pillow source raises errors with + various methods, especially `tobytes` + + Function based on: + https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59 + https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527 + + Args: + image (PIL.Image): a PIL image + + Returns: + (PIL.Image): the PIL image with exif orientation applied, if applicable + """ + if not hasattr(image, "getexif"): + return image + + try: + exif = image.getexif() + except Exception: # https://github.com/facebookresearch/detectron2/issues/1885 + exif = None + + if exif is None: + return image + + orientation = exif.get(_EXIF_ORIENT) + + method = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90, + }.get(orientation) + + if method is not None: + return image.transpose(method) + return image + + +def move_data_to_device(data, device='cuda'): + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: move_data_to_device(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(move_data_to_device(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": device} + return data.to(non_blocking=True, **kwargs) + return data diff --git a/xtuner/_lite/device.py b/xtuner/_lite/device.py new file mode 100644 index 000000000..ab6890c5f --- /dev/null +++ b/xtuner/_lite/device.py @@ -0,0 +1,32 @@ +import torch + + +def get_device(): + device = None + if torch.cuda.is_available(): + device = 'cuda' + else: + try: + import torch_npu # noqa: F401 + device = 'npu' + except ImportError: + pass + + if device is None: + raise NotImplementedError( + 'Supports only CUDA or NPU. 
If your device is CUDA or NPU, ' + 'please make sure that your environmental settings are ' + 'configured correctly.') + + return device + + +def get_torch_device_module(): + + device = get_device() + if device == 'cuda': + return torch.cuda + elif device == 'npu': + return torch.npu + else: + raise NotImplementedError diff --git a/xtuner/_lite/modelings/__init__.py b/xtuner/_lite/modelings/__init__.py new file mode 100644 index 000000000..de930bff1 --- /dev/null +++ b/xtuner/_lite/modelings/__init__.py @@ -0,0 +1,10 @@ +from .internlm2 import InternLM2Config, InternLM2ForCausalLM +from .llava.modeling_llava import LlavaForConditionalGeneration +from .llava.configuration_llava import EnhancedLlavaConfig +from .llava.processing_llava import LlavaProcessor + +def register_remote_code(): + from transformers import AutoConfig, AutoModelForCausalLM + AutoConfig.register('internlm2', InternLM2Config, exist_ok=True) + AutoModelForCausalLM.register( + InternLM2Config, InternLM2ForCausalLM, exist_ok=True) diff --git a/xtuner/_lite/modelings/internlm2/__init__.py b/xtuner/_lite/modelings/internlm2/__init__.py new file mode 100644 index 000000000..e43d72d4a --- /dev/null +++ b/xtuner/_lite/modelings/internlm2/__init__.py @@ -0,0 +1,2 @@ +from .configuration_internlm2 import InternLM2Config +from .modeling_internlm2 import InternLM2ForCausalLM diff --git a/xtuner/_lite/modelings/internlm2/configuration_internlm2.py b/xtuner/_lite/modelings/internlm2/configuration_internlm2.py new file mode 100644 index 000000000..8b8107947 --- /dev/null +++ b/xtuner/_lite/modelings/internlm2/configuration_internlm2.py @@ -0,0 +1,175 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate + an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. 
+ intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is necessary to ensure exact reproducibility + of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. 
+ """ + _auto_class = 'AutoConfig' + model_type = 'internlm2' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = 'eager' + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}') + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in [ + 'linear', 'dynamic' + ]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, + (float, int)) or rope_scaling_factor < 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} " + f'of type {type(rope_scaling_factor)}') diff --git a/xtuner/_lite/modelings/internlm2/modeling_internlm2.py b/xtuner/_lite/modelings/internlm2/modeling_internlm2.py new file mode 100644 index 000000000..69ddc6196 --- /dev/null +++ b/xtuner/_lite/modelings/internlm2/modeling_internlm2.py @@ -0,0 +1,1899 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch InternLM2.5 model.""" +import math +import queue +import threading +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, logging, + replace_return_docstrings) + +try: + from transformers.generation.streamers import BaseStreamer +except Exception: + BaseStreamer = None + +from .configuration_internlm2 import InternLM2Config + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import (index_first_axis, pad_input, + unpad_input) +except: + pass + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'InternLM2Config' + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102 + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class InternLM2RMSNorm(nn.Module): + """InternLM2RMSNorm is equivalent to T5LayerNorm.""" + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm) + + +class InternLM2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for the InternLM2 model. 
Credits to the Reddit user /u/lucidrains.""" + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + **(torch.arange(0, self.dim, 2, + dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / + self.max_position_embeddings) - + (self.scaling_factor - 1))**( + self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to( + x.device) / self.dim)) + self.register_buffer( + 'inv_freq', inv_freq, + persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class InternLM2MLP(nn.Module): + """MLP for InternLM2 model.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.w1 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +class InternLM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config: InternLM2Config, + layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ' + 'lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` ' + 'when creating this class.') + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).') + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + self.wo = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + # split qkv_states by tp size + key_value_slicing = (self.num_key_value_heads * + self.head_dim) // self.config.pretraining_tp + qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0) + qkv_states = torch.cat( + [ + F.linear(hidden_states, qkv_slice) + for qkv_slice in qkv_slices + ], + dim=-1 # pylint: disable=E1102 + ) + else: + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, + 'b q h gs d -> b q (h gs) d').transpose(1, 2) + key_states = qkv_states[..., -2, :].transpose(1, 2) + value_states = qkv_states[..., -1, :].transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) 
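+        # Grouped Query Attention: each key/value head serves
+        # `num_key_value_groups` query heads, so the cached (unexpanded) K/V
+        # states are tiled up to `num_heads` before the dense softmax(QK^T)V
+        # computed below.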
+ + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.wo.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([ + F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102 + for i in range(self.config.pretraining_tp) + ]) + else: + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class InternLM2FlashAttention2(InternLM2Attention): + """ + InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, + # that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) + # produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` ' + 'make sure to use `sdpa` in the mean time, and open an issue at ' + 'https://github.com/huggingface/transformers') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # dropout_rate = self.attention_dropout if self.training else 0.0 + dropout_rate = 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (InternLM2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.wqkv.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in' + f' {target_dtype}.') + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.hidden_size).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value # pylint: disable=E0606 + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. + # For details, please see the comment in InternLM2FlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606 + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) # pylint: disable=E0606 + else: + attn_output = flash_attn_func( # pylint: disable=E0606 + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( # pylint: disable=E0606 + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( # pylint: disable=E0606 + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( # pylint: disable=E0606 + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606 + query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LllamaSdpaAttention with Llama->InternLM2 +class InternLM2SdpaAttention(InternLM2Attention): + """ + InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `InternLM2Attention` as the weights of the module stays untouched. The only changes are on the forward pass + to adapt to SDPA API. + """ + + # Adapted from InternLM2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. 
`model.config.attn_implementation = "manual"` + # once this is implemented. + logger.warning_once( + 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` ' + 'does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. ' + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == 'cuda' and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of + # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph + # options. An inline conditional prevents dynamic shapes from compiling. + is_causal = bool(causal_mask is None and q_len > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102 + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, None, past_key_value + + +INTERNLM2_ATTENTION_CLASSES = { + 'eager': InternLM2Attention, + 'flash_attention_2': InternLM2FlashAttention2, + 'sdpa': InternLM2SdpaAttention, +} + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2 +class InternLM2DecoderLayer(nn.Module): + """InternLM2 Decoder Layer. 
This module is a single layer of the InternLM2 model.""" + + def __init__(self, config: InternLM2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.attention = INTERNLM2_ATTENTION_CLASSES[ + config.attn_implementation]( + config=config, layer_idx=layer_idx) + + self.feed_forward = InternLM2MLP(config) + self.attention_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +InternLM2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`InternLM2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2PreTrainedModel(PreTrainedModel): + """ + InternLM2 pretraiend model's base class. + """ + + config_class = InternLM2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['InternLM2DecoderLayer'] + _skip_keys_device_placement = ['past_key_values'] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +InternLM2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. 
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2Model(InternLM2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`] + Args: + config: InternLM2Config + """ + + _auto_class = 'AutoModel' + + def __init__(self, config: InternLM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx) + + self.layers = nn.ModuleList([ + InternLM2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one' + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.' 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance( + past_key_values, + Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, + cache_position, past_key_values, + output_attentions) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[ + 2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length + # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at + # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is + # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. + # See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config.attn_implementation == 'flash_attention_2': + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
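+        # The mask assembled below is additive: attendable positions stay 0,
+        # disallowed positions are filled with the dtype's most negative
+        # value, and padding columns from `attention_mask` are masked out as
+        # well before the 4D mask reaches the attention layers.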
+ past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] if isinstance( + attention_mask, torch.Tensor) else past_seen_tokens + + sequence_length + 1) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError( + 'Custom 4D attention mask should be passed in inverted form with max==0`' + ) + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone( + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, : + mask_length] + attention_mask[:, + None, + None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, : + mask_length] = causal_mask[:, :, :, : + mask_length].masked_fill( + padding_mask, + min_dtype) + if (self.config.attn_implementation == 'sdpa' + and attention_mask is not None + and attention_mask.device.type == 'cuda' + and not output_attentions): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype) # pylint: disable=E1120 + + return causal_mask + + +# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM +class InternLM2ForCausalLM(InternLM2PreTrainedModel): + """Causal language model (CLM) for InternLM2.""" + + _auto_class = 'AutoModelForCausalLM' + _tied_weights_keys = ['output.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = InternLM2Model(config) + self.vocab_size = config.vocab_size + self.output = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + output_slices = self.output.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [ + F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.output(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[ + 0] if cache_position is not None else past_key_values.get_seq_length( + ) + max_cache_length = ( + torch.tensor( + past_key_values.get_max_length(), + device=input_ids.device) + if past_key_values.get_max_length() is not None else None) + cache_length = past_length if max_cache_length is None else torch.min( + max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[ + 1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - + past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
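The three cases enumerated in the comments above can be shown with a small stand-alone helper. This is a hypothetical illustration, not part of the model code; it only mirrors the slicing logic used in `prepare_inputs_for_generation`.

```python
import torch

def trim_unprocessed(input_ids, attention_mask, past_length):
    """Toy illustration of the three cases commented above (not the real method)."""
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        # 1 - part of the input lives only in the cache (e.g. inputs_embeds were passed)
        return input_ids[:, -(attention_mask.shape[1] - past_length):]
    elif past_length < input_ids.shape[1]:
        # 2 - drop the prefix that the KV cache already covers
        return input_ids[:, past_length:]
    # 3 - everything that is left is unprocessed
    return input_ids

ids = torch.arange(10).unsqueeze(0)                # a 10-token prompt
mask = torch.ones(1, 10, dtype=torch.long)
print(trim_unprocessed(ids, mask, past_length=7))  # tensor([[7, 8, 9]])
```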
+ + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if (max_cache_length is not None and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length): + attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130 + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + input_length = position_ids.shape[ + -1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange( + past_length, + past_length + input_length, + device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update({ + 'position_ids': position_ids, + 'cache_position': cache_position, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past), ) + return reordered_past + + def build_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + meta_instruction=''): + if history is None: + history = [] + if tokenizer.add_bos_token: + prompt = '' + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors='pt') + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: + str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory ' + '(上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such ' + 'as English and 中文.', + **kwargs, + ): + if history is None: + history = [] + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = { + k: v.to(self.device) + for k, v in inputs.items() if torch.is_tensor(v) + } + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0] + ] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('<|im_end|>')[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + if history is None: + history = [] + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + 'The version of `transformers` is too low. Please make sure ' + 'that you have installed `transformers>=4.28.0`.') + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + """ + Streamer used in generate to print words one by one. + """ + + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = '' + self.cache = [] + self.received_inputs = False + self.queue.put( + (self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError('ChatStreamer only supports batch size 1') + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode( + self.cache, skip_special_tokens=True) + if token.strip() != '<|im_end|>': + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a sequence classification head on top (linear layer). 
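For reference, a minimal sketch of how the `chat`/`stream_chat` helpers defined above are typically invoked once the modeling file is loaded with `trust_remote_code`. The checkpoint name is only a placeholder and requires enough GPU memory for the model.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

path = 'internlm/internlm2-chat-7b'  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

response, history = model.chat(tokenizer, 'Hello', history=[])
print(response)

# stream_chat yields the growing response together with the updated history
for partial, _ in model.stream_chat(tokenizer, 'Tell me about InternLM', history=history):
    print(partial)
```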
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): + """Sequence Classification Head for InternLM2 Model.""" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq( + input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), + sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype + in (torch.long, torch.int)): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2 +@add_start_docstrings( + """ +The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
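A small sketch of the last-non-padding-token pooling used in the sequence-classification forward above. The `pad_token_id` of 0 and the tensor contents are arbitrary example values; the indexing follows the code exactly, including the modulo trick that maps rows without padding to the final position.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],     # right-padded row
                          [8, 9, 3, 4, 2]])    # row with no padding
logits = torch.randn(2, 5, 3)                  # (batch, seq_len, num_labels)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]   # -1 wraps to the last position
pooled_logits = logits[torch.arange(2), sequence_lengths]
print(sequence_lengths)      # tensor([2, 4])
print(pooled_logits.shape)   # torch.Size([2, 3])
```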
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel): + """Question Answering model for InternLM2.""" + + base_model_prefix = 'transformer' + + def __init__(self, config): + super().__init__(config) + self.transformer = InternLM2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.tok_embeddings + + def set_input_embeddings(self, value): + self.transformer.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to( + start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss, ) + + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
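The extractive-QA loss computed in the forward above reduces to an average of two cross-entropies over the start and end positions; out-of-range positions are clamped and then ignored. A toy sketch with random logits and made-up positions:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len = 2, 16
start_logits = torch.randn(batch, seq_len)
end_logits = torch.randn(batch, seq_len)
start_positions = torch.tensor([3, 5]).clamp(0, seq_len)   # clamp mirrors the code above
end_positions = torch.tensor([7, 9]).clamp(0, seq_len)

loss_fct = CrossEntropyLoss(ignore_index=seq_len)          # clamped-out positions are ignored
total_loss = (loss_fct(start_logits, start_positions)
              + loss_fct(end_logits, end_positions)) / 2
print(total_loss)
```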
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForTokenClassification(InternLM2PreTrainedModel): + """Token classification model for InternLM2.""" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + if getattr(config, 'classifier_dropout', None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, 'hidden_dropout', None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
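The token-classification head described above amounts to dropout, a linear score layer, and a per-token cross-entropy over the flattened sequence. A minimal sketch with made-up shapes and label count:

```python
import torch
from torch import nn

hidden_states = torch.randn(2, 6, 16)        # (batch, seq_len, hidden_size)
labels = torch.randint(0, 4, (2, 6))         # 4 made-up entity classes
head = nn.Sequential(nn.Dropout(0.1), nn.Linear(16, 4))

logits = head(hidden_states)
loss = nn.CrossEntropyLoss()(logits.view(-1, 4), labels.view(-1))
print(logits.shape, loss.item())
```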
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/xtuner/_lite/modelings/internvl2/__init__.py b/xtuner/_lite/modelings/internvl2/__init__.py new file mode 100644 index 000000000..8652be2d9 --- /dev/null +++ b/xtuner/_lite/modelings/internvl2/__init__.py @@ -0,0 +1,3 @@ +from .modeling_intern_vit import InternVisionModel + +__all__ = ['InternVisionModel'] diff --git a/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py b/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py new file mode 100644 index 000000000..32f469c4b --- /dev/null +++ b/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py @@ -0,0 +1,119 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class InternVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to + instantiate a vision encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + Number of color channels in the input images (e.g., 3 for RGB). + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries and values in the self-attention layers. + hidden_size (`int`, *optional*, defaults to 3200): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 25): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 12800): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + qk_normalization (`bool`, *optional*, defaults to `True`): + Whether to normalize the queries and keys in the self-attention layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. 
+ use_flash_attn (`bool`, *optional*, defaults to `True`): + Whether to use flash attention mechanism. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for stochastic depth. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.1): + A factor for layer scale. + """ + + model_type = 'intern_vit_6b' + + def __init__( + self, + num_channels=3, + patch_size=14, + image_size=224, + qkv_bias=False, + hidden_size=3200, + num_attention_heads=25, + intermediate_size=12800, + qk_normalization=True, + num_hidden_layers=48, + use_flash_attn=True, + hidden_act='gelu', + norm_type='rms_norm', + layer_norm_eps=1e-6, + dropout=0.0, + drop_path_rate=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=0.1, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.dropout = dropout + self.drop_path_rate = drop_path_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.norm_type = norm_type + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.use_flash_attn = use_flash_attn + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if 'vision_config' in config_dict: + config_dict = config_dict['vision_config'] + + if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' 
+ ) + + return cls.from_dict(config_dict, **kwargs) \ No newline at end of file diff --git a/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py b/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py new file mode 100644 index 000000000..a8d36d9e3 --- /dev/null +++ b/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py @@ -0,0 +1,432 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from timm.models.layers import DropPath +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_intern_vit import InternVisionConfig + +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import \ + flash_attn_varlen_qkvpacked_func + has_flash_attn = True +except: + print('FlashAttention2 is not installed.') + has_flash_attn = False + +logger = logging.get_logger(__name__) + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None + + +class InternRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +try: + from apex.normalization import FusedRMSNorm + + InternRMSNorm = FusedRMSNorm # noqa + + logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm') +except ImportError: + # using the normal InternRMSNorm + pass +except Exception: + logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm') + pass + + +NORM2FN = { + 'rms_norm': InternRMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). 
\ + reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, width) + ], dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_flash_attn = config.use_flash_attn and has_flash_attn + if config.use_flash_attn and not has_flash_attn: + print('Warning: Flash Attention is not available, use_flash_attn is set to False.') + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).' + ) + + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj_drop = nn.Dropout(config.dropout) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + if self.use_flash_attn: + self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _naive_attn(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + qkv = self.qkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + + if self.qk_normalization: + q, k, v = qkv.unbind(2) + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False + ) + outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) + outs = self.proj_drop(outs) + return outs + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self._naive_attn(hidden_states) if not 
self.use_flash_attn else self._flash_attn(hidden_states) + return x + + +class InternMLP(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.act = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + def __init__(self, config: InternVisionConfig, drop_path_rate: float): + super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: + """ + Args: + hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1) + + hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2) + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InternEncoderLayer`]. + + Args: + config (`InternConfig`): + The corresponding vision configuration for the `InternEncoder`. + """ + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + # stochastic depth decay rule + # TODO: error + # dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers)) + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
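The residual update inside `InternVisionEncoderLayer` above pairs each sub-block with a learnable per-channel layer scale (`ls1`/`ls2`) and an optional drop path. A toy sketch of one such update, with `nn.Identity` standing in for the attention block and drop path omitted (identity when the rate is 0):

```python
import torch
from torch import nn

embed_dim = 8
norm1 = nn.LayerNorm(embed_dim)
attn = nn.Identity()                               # stand-in for InternAttention
ls1 = nn.Parameter(0.1 * torch.ones(embed_dim))    # per-channel layer scale

x = torch.randn(2, 5, embed_dim)
x = x + attn(norm1(x).to(x.dtype)) * ls1           # residual branch scaled channel-wise
print(x.shape)                                     # torch.Size([2, 5, 8])
```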
+ """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + encoder_layer, + hidden_states) + else: + layer_outputs = encoder_layer( + hidden_states, + ) + hidden_states = layer_outputs + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states + ) + + +class InternVisionModel(PreTrainedModel): + main_input_name = 'pixel_values' + _supports_flash_attn_2 = True + config_class = InternVisionConfig + _no_split_modules = ['InternVisionEncoderLayer'] + + def __init__(self, config: InternVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder(config) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) + pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and pixel_embeds is None: + raise ValueError('You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + if len(pixel_values.shape) == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git 
a/xtuner/_lite/modelings/llava/__init__.py b/xtuner/_lite/modelings/llava/__init__.py new file mode 100644 index 000000000..036324005 --- /dev/null +++ b/xtuner/_lite/modelings/llava/__init__.py @@ -0,0 +1,3 @@ +from .configuration_llava import EnhancedLlavaConfig +from .modeling_llava import LlavaForConditionalGeneration +from .processing_llava import LlavaProcessor diff --git a/xtuner/_lite/modelings/llava/configuration_internlm2.py b/xtuner/_lite/modelings/llava/configuration_internlm2.py new file mode 100644 index 000000000..8b8107947 --- /dev/null +++ b/xtuner/_lite/modelings/llava/configuration_internlm2.py @@ -0,0 +1,175 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate + an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is necessary to ensure exact reproducibility + of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. 
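A hypothetical instantiation of the config class defined below, showing the `rope_scaling` validation it performs. The values are the documented defaults, not those of any particular released checkpoint.

```python
from xtuner._lite.modelings.llava.configuration_internlm2 import InternLM2Config

cfg = InternLM2Config(vocab_size=103168, hidden_size=4096,
                      rope_scaling={'type': 'dynamic', 'factor': 2.0})
print(cfg.num_key_value_heads)   # falls back to num_attention_heads (32)

# an unsupported scaling type is rejected by _rope_scaling_validation
try:
    InternLM2Config(rope_scaling={'type': 'yarn', 'factor': 2.0})
except ValueError as err:
    print(err)
```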
+ """ + _auto_class = 'AutoConfig' + model_type = 'internlm2' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = 'eager' + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}') + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in [ + 'linear', 'dynamic' + ]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, + (float, int)) or rope_scaling_factor < 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} " + f'of type {type(rope_scaling_factor)}') diff --git a/xtuner/_lite/modelings/llava/configuration_llava.py b/xtuner/_lite/modelings/llava/configuration_llava.py new file mode 100644 index 000000000..f5ec7bbfa --- /dev/null +++ b/xtuner/_lite/modelings/llava/configuration_llava.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Llava model configuration""" +import os +from typing import Union +from transformers.configuration_utils import PretrainedConfig, custom_object_save +from transformers.utils import logging +from transformers import CONFIG_MAPPING, AutoModelForCausalLM, AutoConfig + +logger = logging.get_logger(__name__) + +class EnhancedLlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an + Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llava-9B. + + e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + _auto_class = 'AutoConfig' + model_type = "enhanced_llava" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + + if text_config["model_type"] == 'internlm2': + from .configuration_internlm2 import InternLM2Config + from .modeling_internlm2 import InternLM2ForCausalLM + AutoConfig.register('internlm2', InternLM2Config) + AutoModelForCausalLM.register( + InternLM2Config, InternLM2ForCausalLM) + text_config['auto_map']['AutoConfig'] = 'configuration_internlm2.InternLM2Config' + text_config['auto_map']['AutoModel'] = 'modeling_internlm2.InternLM2ForCausalLM' + text_config['auto_map']['AutoModelForCausalLM'] = 'modeling_internlm2.InternLM2ForCausalLM' + text_config = InternLM2Config(**text_config) + else: + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + + super().__init__(**kwargs) + + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + super().save_pretrained(save_directory, push_to_hub, **kwargs) + + if self.text_config._auto_class is not None: + custom_object_save(self.text_config, save_directory, config=self.text_config) + +AutoConfig.register('enhanced_llava', EnhancedLlavaConfig, exist_ok=True) \ No newline at end of file diff --git a/xtuner/_lite/modelings/llava/modeling_internlm2.py b/xtuner/_lite/modelings/llava/modeling_internlm2.py new file mode 100644 index 000000000..69ddc6196 --- /dev/null +++ b/xtuner/_lite/modelings/llava/modeling_internlm2.py @@ -0,0 +1,1899 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch InternLM2.5 model.""" +import math +import queue +import threading +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, logging, + replace_return_docstrings) + +try: + from transformers.generation.streamers import BaseStreamer +except Exception: + BaseStreamer = None + +from .configuration_internlm2 import InternLM2Config + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import (index_first_axis, pad_input, + unpad_input) +except: + pass + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'InternLM2Config' + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102 + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class InternLM2RMSNorm(nn.Module): + """InternLM2RMSNorm is equivalent to T5LayerNorm.""" + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm) + + +class InternLM2RotaryEmbedding(nn.Module): + """Rotary Position Embedding for the InternLM2 model. 
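`_get_unpad_data` above turns a padding mask into the flattened token indices and cumulative sequence lengths (`cu_seqlens`) that the varlen flash-attention path consumes. A standalone re-run of the same three tensor operations on a toy right-padded batch:

```python
# Standalone sketch of what _get_unpad_data computes for a right-padded batch.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],    # sample 0: 3 real tokens
                               [1, 1, 0, 0]])   # sample 1: 2 real tokens

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)          # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)     # tensor([0, 1, 2, 4, 5]) -> positions of real tokens in the flattened batch
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> boundaries passed as cu_seqlens_q/k
```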
Credits to the Reddit user /u/lucidrains.""" + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + **(torch.arange(0, self.dim, 2, + dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / + self.max_position_embeddings) - + (self.scaling_factor - 1))**( + self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to( + x.device) / self.dim)) + self.register_buffer( + 'inv_freq', inv_freq, + persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class InternLM2MLP(nn.Module): + """MLP for InternLM2 model.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.w1 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +class InternLM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config: InternLM2Config, + layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ' + 'lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` ' + 'when creating this class.') + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).') + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + self.wo = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + # split qkv_states by tp size + key_value_slicing = (self.num_key_value_heads * + self.head_dim) // self.config.pretraining_tp + qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0) + qkv_states = torch.cat( + [ + F.linear(hidden_states, qkv_slice) + for qkv_slice in qkv_slices + ], + dim=-1 # pylint: disable=E1102 + ) + else: + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, + 'b q h gs d -> b q (h gs) d').transpose(1, 2) + key_states = qkv_states[..., -2, :].transpose(1, 2) + value_states = qkv_states[..., -1, :].transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) 
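The attention classes above project hidden states through a single fused `wqkv` linear layer and then `rearrange` the result with `gs = 2 + num_key_value_groups`, so the grouped query heads and the shared key/value heads (grouped-query attention) all come out of one tensor before `repeat_kv` expands the key/value heads. A shape-only sketch with illustrative sizes (not the shipped config):

```python
# Shape-only sketch of the fused wqkv split; sizes are illustrative, not from a real checkpoint.
import torch
from einops import rearrange

bsz, q_len = 2, 7
num_heads, num_key_value_heads, head_dim = 32, 8, 128
num_key_value_groups = num_heads // num_key_value_heads          # 4 query heads per kv head

# wqkv packs (num_heads + 2 * num_key_value_heads) * head_dim output features.
qkv_states = torch.randn(bsz, q_len, (num_heads + 2 * num_key_value_heads) * head_dim)

qkv_states = rearrange(qkv_states, 'b q (h gs d) -> b q h gs d',
                       gs=2 + num_key_value_groups, d=head_dim)

query_states = rearrange(qkv_states[..., :num_key_value_groups, :], 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]

print(query_states.shape)  # torch.Size([2, 7, 32, 128]); the model then transposes to (b, heads, q, d)
print(key_states.shape)    # torch.Size([2, 7, 8, 128])
print(value_states.shape)  # torch.Size([2, 7, 8, 128])
```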
+ + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.wo.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([ + F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102 + for i in range(self.config.pretraining_tp) + ]) + else: + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class InternLM2FlashAttention2(InternLM2Attention): + """ + InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, + # that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) + # produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` ' + 'make sure to use `sdpa` in the mean time, and open an issue at ' + 'https://github.com/huggingface/transformers') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # dropout_rate = self.attention_dropout if self.training else 0.0 + dropout_rate = 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (InternLM2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.wqkv.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in' + f' {target_dtype}.') + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.hidden_size).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value # pylint: disable=E0606 + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. + # For details, please see the comment in InternLM2FlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606 + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) # pylint: disable=E0606 + else: + attn_output = flash_attn_func( # pylint: disable=E0606 + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( # pylint: disable=E0606 + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( # pylint: disable=E0606 + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( # pylint: disable=E0606 + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606 + query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LllamaSdpaAttention with Llama->InternLM2 +class InternLM2SdpaAttention(InternLM2Attention): + """ + InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `InternLM2Attention` as the weights of the module stays untouched. The only changes are on the forward pass + to adapt to SDPA API. + """ + + # Adapted from InternLM2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. 
`model.config.attn_implementation = "manual"` + # once this is implemented. + logger.warning_once( + 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` ' + 'does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. ' + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with + # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == 'cuda' and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of + # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph + # options. An inline conditional prevents dynamic shapes from compiling. + is_causal = bool(causal_mask is None and q_len > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102 + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, None, past_key_value + + +INTERNLM2_ATTENTION_CLASSES = { + 'eager': InternLM2Attention, + 'flash_attention_2': InternLM2FlashAttention2, + 'sdpa': InternLM2SdpaAttention, +} + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2 +class InternLM2DecoderLayer(nn.Module): + """InternLM2 Decoder Layer. 
This module is a single layer of the InternLM2 model.""" + + def __init__(self, config: InternLM2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.attention = INTERNLM2_ATTENTION_CLASSES[ + config.attn_implementation]( + config=config, layer_idx=layer_idx) + + self.feed_forward = InternLM2MLP(config) + self.attention_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +InternLM2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`InternLM2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2PreTrainedModel(PreTrainedModel): + """ + InternLM2 pretraiend model's base class. + """ + + config_class = InternLM2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['InternLM2DecoderLayer'] + _skip_keys_device_placement = ['past_key_values'] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +InternLM2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. 
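As the `past_key_values` description above notes, both a `Cache` object and the legacy tuple-of-tuples format are accepted; `InternLM2Model.forward` upgrades the legacy form via `DynamicCache.from_legacy_cache` and converts back before returning. A toy round trip, assuming a recent `transformers` that ships `cache_utils` (as required by this PR):

```python
# Toy illustration of the legacy tuple cache <-> Cache round trip used in forward().
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) tuple per layer, shaped (batch, kv_heads, seq, head_dim).
legacy = tuple(
    (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())   # 5 tokens already cached

round_tripped = cache.to_legacy_cache()
print(len(round_tripped), round_tripped[0][0].shape)  # 2 torch.Size([1, 8, 5, 64])
```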
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2Model(InternLM2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`] + Args: + config: InternLM2Config + """ + + _auto_class = 'AutoModel' + + def __init__(self, config: InternLM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx) + + self.layers = nn.ModuleList([ + InternLM2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one' + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.' 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance( + past_key_values, + Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, + cache_position, past_key_values, + output_attentions) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[ + 2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length + # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at + # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is + # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. + # See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config.attn_implementation == 'flash_attention_2': + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length( + ) if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] if isinstance( + attention_mask, torch.Tensor) else past_seen_tokens + + sequence_length + 1) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError( + 'Custom 4D attention mask should be passed in inverted form with max==0`' + ) + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone( + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, : + mask_length] + attention_mask[:, + None, + None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, : + mask_length] = causal_mask[:, :, :, : + mask_length].masked_fill( + padding_mask, + min_dtype) + if (self.config.attn_implementation == 'sdpa' + and attention_mask is not None + and attention_mask.device.type == 'cuda' + and not output_attentions): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
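The mask construction above is easiest to read on a concrete case: an additive mask is filled with the dtype minimum, upper-triangulated, then gated by `cache_position` so new tokens can still attend to everything already in the cache. A standalone sketch of the same arithmetic for 3 new tokens appended after 2 cached ones:

```python
# Standalone sketch of the additive causal mask built in _update_causal_mask.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 5            # 3 new tokens, 5 total key positions
cache_position = torch.arange(2, 5)              # absolute positions of the new tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

# Row i is the query at position cache_position[i]; min_dtype marks the future keys it must not see.
print((causal_mask == min_dtype).int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```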
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype) # pylint: disable=E1120 + + return causal_mask + + +# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM +class InternLM2ForCausalLM(InternLM2PreTrainedModel): + """Causal language model (CLM) for InternLM2.""" + + _auto_class = 'AutoModelForCausalLM' + _tied_weights_keys = ['output.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = InternLM2Model(config) + self.vocab_size = config.vocab_size + self.output = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + output_slices = self.output.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [ + F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.output(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[ + 0] if cache_position is not None else past_key_values.get_seq_length( + ) + max_cache_length = ( + torch.tensor( + past_key_values.get_max_length(), + device=input_ids.device) + if past_key_values.get_max_length() is not None else None) + cache_length = past_length if max_cache_length is None else torch.min( + max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[ + 1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - + past_length):] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
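The loss branch in `InternLM2ForCausalLM.forward` above aligns predictions and targets by shifting one position: the logits for position t are scored against the label at position t + 1, and `-100` labels are skipped through `CrossEntropyLoss`'s default `ignore_index`. A toy sketch of that alignment:

```python
# Toy sketch of the shift-by-one loss alignment used in InternLM2ForCausalLM.forward.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(1, 4, vocab_size)               # (batch, seq_len, vocab)
labels = torch.tensor([[-100, -100, 3, 9]])          # -100 marks unsupervised (prompt) tokens

shift_logits = logits[..., :-1, :].contiguous()      # predictions made at positions 0..2
shift_labels = labels[..., 1:].contiguous()          # targets are the *next* tokens, 1..3

loss_fct = CrossEntropyLoss()                        # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # scalar loss averaged over the 2 supervised positions
```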
+ + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if (max_cache_length is not None and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length): + attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130 + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + input_length = position_ids.shape[ + -1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange( + past_length, + past_length + input_length, + device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update({ + 'position_ids': position_ids, + 'cache_position': cache_position, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past), ) + return reordered_past + + def build_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + meta_instruction=''): + if history is None: + history = [] + if tokenizer.add_bos_token: + prompt = '' + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors='pt') + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: + str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory ' + '(上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such ' + 'as English and 中文.', + **kwargs, + ): + if history is None: + history = [] + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = { + k: v.to(self.device) + for k, v in inputs.items() if torch.is_tensor(v) + } + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0] + ] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('<|im_end|>')[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + if history is None: + history = [] + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + 'The version of `transformers` is too low. Please make sure ' + 'that you have installed `transformers>=4.28.0`.') + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + """ + Streamer used in generate to print words one by one. + """ + + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = '' + self.cache = [] + self.received_inputs = False + self.queue.put( + (self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError('ChatStreamer only supports batch size 1') + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode( + self.cache, skip_special_tokens=True) + if token.strip() != '<|im_end|>': + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a sequence classification head on top (linear layer). 
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): + """Sequence Classification Head for InternLM2 Model.""" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq( + input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), + sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype + in (torch.long, torch.int)): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2 +@add_start_docstrings( + """ +The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel): + """Question Answering model for InternLM2.""" + + base_model_prefix = 'transformer' + + def __init__(self, config): + super().__init__(config) + self.transformer = InternLM2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.tok_embeddings + + def set_input_embeddings(self, value): + self.transformer.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, + List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
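+
+        Example (illustrative only; `model`, `inputs`, `starts` and `ends` are placeholders):
+
+            >>> outputs = model(**inputs, start_positions=starts, end_positions=ends)
+            >>> outputs.loss  # average of the start- and end-position cross-entropy losses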
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to( + start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss, ) + + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForTokenClassification(InternLM2PreTrainedModel): + """Token classification model for InternLM2.""" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + if getattr(config, 'classifier_dropout', None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, 'hidden_dropout', None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/xtuner/_lite/modelings/llava/modeling_llava.py b/xtuner/_lite/modelings/llava/modeling_llava.py new file mode 100644 index 000000000..b987db7b5 --- /dev/null +++ b/xtuner/_lite/modelings/llava/modeling_llava.py @@ -0,0 +1,573 @@ +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Llava model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ModelOutput +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers import AutoModel, AutoModelForCausalLM +from .configuration_llava import EnhancedLlavaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlavaConfig" + + + +@dataclass +# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava +class LlavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: EnhancedLlavaConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +LLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlavaConfig`] or [`LlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_START_DOCSTRING, +) +class LlavaPreTrainedModel(PreTrainedModel): + config_class = EnhancedLlavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + # important: this ported version of Llava isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +LLAVA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The LLAVA model which consists of a vision backbone and a language model.""", + LLAVA_START_DOCSTRING, +) +class LlavaForConditionalGeneration(LlavaPreTrainedModel): + + _auto_class = 'AutoModel' + + def __init__(self, config: EnhancedLlavaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config( + config.text_config, + attn_implementation=config._attn_implementation) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. 
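+        # Worked example (editorial illustration, not from the original source): with
+        # input_ids = [text, <image>, text] and num_image_patches = 3, the mask is
+        # [0, 1, 0], so cumsum([1, 3, 1]) - 1 = [0, 3, 4]: the two text tokens land at
+        # positions 0 and 4, while positions 1-3 are reserved for the image patches.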
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
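+        # (Editorial note) Padding rows are zeroed rather than dropped so the merged
+        # batch stays rectangular; during cached generation the zeroed key/value
+        # states are what identify these non-attended positions.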
+        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+        indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+        final_embedding[batch_indices, indices_to_mask] = 0
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids
+
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[int] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
+        if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # ------------- start add this ----------------
+            if pixel_values is None and self.training:
+                # All of the input is text (no image in this batch). If text-only
+                # ranks skipped the vision tower, distributed training could
+                # deadlock, so run it on a dummy all-zero image instead.
+                # print('===================all of the input is text==============')
+                image_size = self.config.vision_config.image_size
+                pixel_values = torch.zeros(input_ids.shape[0], 3, image_size, image_size,
+                                           dtype=torch.float32,
+                                           device=input_ids.device)
+                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+                # This is not memory efficient: output_hidden_states=True keeps every layer's hidden states.
+                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+                if vision_feature_select_strategy == "default":
+                    selected_image_feature = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    selected_image_feature = selected_image_feature
+                else:
+                    raise ValueError(
+                        f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+                    )
+                image_features = self.multi_modal_projector(selected_image_feature)
+                inputs_embeds = inputs_embeds.to(image_features.dtype)
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features[0:0], inputs_embeds, input_ids, attention_mask, labels
+                )
+            # ------------- end add this ----------------
+            # 2. Merge text and images
+            elif pixel_values is not None and input_ids.shape[1] != 1:
+                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+                # This is not memory efficient: output_hidden_states=True keeps every layer's hidden states.
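+                # (Editorial note) "default" drops the first returned position
+                # (typically the CLIP CLS token) and keeps only patch features;
+                # "full" keeps every position produced by the vision tower.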
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature) + inputs_embeds = inputs_embeds.to(image_features.dtype) + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + else: + cache_length = past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + elif self.config.image_token_index in input_ids: + input_ids = input_ids[:, input_ids.shape[1] - 1 :] + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + ) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) + +AutoModel.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True) +AutoModelForCausalLM.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True) \ No newline at end of file diff --git a/xtuner/_lite/modelings/llava/processing_llava.py b/xtuner/_lite/modelings/llava/processing_llava.py new file mode 100644 index 000000000..230975575 --- /dev/null +++ b/xtuner/_lite/modelings/llava/processing_llava.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Llava.
+"""
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType
+
+
+class LlavaProcessor(ProcessorMixin):
+    r"""
+    Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
+
+    [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
+    [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`CLIPImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length=None,
+        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None: + image_inputs = self.image_processor(images, return_tensors=return_tensors) + else: + image_inputs = {} + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) \ No newline at end of file diff --git a/xtuner/_lite/parallel/__init__.py b/xtuner/_lite/parallel/__init__.py new file mode 100644 index 000000000..e631d05f9 --- /dev/null +++ b/xtuner/_lite/parallel/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
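+# Convenience re-exports for the lightweight parallel utilities: samplers,
+# collective-communication helpers and device-mesh accessors.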
+from .comm import all_to_all, all_to_all_list, barrier +from .sampler import LengthGroupedSampler, ParallelSampler, VLMLengthGroupedSampler +from .sequence import * # noqa: F401, F403 +from .setup import (get_dp_mesh, get_fsdp_mesh, get_sp_mesh, get_tp_mesh, + get_world_mesh, get_same_data_mesh, setup_parallel, + get_ep_mesh, get_experts_fsdp_mesh) +from .utils import MetaStateful + +__all__ = [ + 'ParallelSampler', + 'LengthGroupedSampler', + 'VLMLengthGroupedSampler', + 'all_to_all', + 'all_to_all_list', + 'get_dp_mesh', + 'get_same_data_mesh', + 'get_fsdp_mesh', + 'get_sp_mesh', + 'get_tp_mesh', + 'get_world_mesh', + 'setup_parallel', + 'MetaStateful', + 'get_ep_mesh', + 'get_experts_fsdp_mesh', + 'barrier' +] diff --git a/xtuner/_lite/parallel/comm.py b/xtuner/_lite/parallel/comm.py new file mode 100644 index 000000000..47daf4fb6 --- /dev/null +++ b/xtuner/_lite/parallel/comm.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.distributed_c10d import (_get_pg_default_device, + _object_to_tensor, + _tensor_to_object) + + +# Modified from https://github.com/microsoft/DeepSpeed/blob/ffd0a0e3ef24bfd00c2e5f35019d2674cc01ec14/deepspeed/sequence/layer.py#L15 # noqa: E501 +def _all_to_all( + input: Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [ + t.contiguous() + for t in torch.tensor_split(input, world_size, scatter_dim) + ] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input: Input tensor + sp_group: Sequence parallel process group + scatter_dim: Scatter dimension + gather_dim: Gather dimension + """ + + @staticmethod + def forward(ctx: Any, input: Tensor, sp_group: dist.ProcessGroup, + scatter_dim: int, gather_dim: int): + ctx.sp_group = sp_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(sp_group) + output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim, + gather_dim) + return output + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple: + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.sp_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input: Tensor, + sp_group: dist.ProcessGroup, + scatter_dim: int = 2, + gather_dim: int = 1, +): + """Convenience function to apply the all-to-all operation with scatter and + gather dimensions. + + Notes: + We have wrapped the `torch.distributed.all_to_all` function to + enable automatic differentiation of the all-to-all operation. + + Args: + input: The input tensor for which all-to-all communication is performed + sp_group: The sequence parallel process group. + scatter_dim: The dimension along which the input tensor is scattered + (default: 2). + gather_dim: The dimension along which the output tensor is gathered + (default: 1). + + Returns: + The output tensor after the all-to-all communication. 
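+
+    Example (illustrative sketch; assumes an initialized sequence-parallel process
+    group ``sp_group`` and a placeholder tensor ``hidden_states`` laid out as
+    ``(batch, seq, heads, head_dim)``):
+        >>> # scatter the head dimension across ranks, gather the full sequence
+        >>> output = all_to_all(hidden_states, sp_group, scatter_dim=2, gather_dim=1)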
+ """ + return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim) + + +def all_to_all_list(object_list, group=None): + current_device = _get_pg_default_device(group) + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + tensor_list, size_list = zip( + * + [_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list = list(tensor_list) + size_list = torch.cat(size_list) + buffer = [None] * world_size + + dist.all_gather_object(buffer, size_list, group=group) + size_this_rank = [] + for size_list in buffer: + size_this_rank.append(size_list[rank]) + + target_tensor_list = [ + torch.empty(size.item(), dtype=torch.uint8, device=current_device) + for size in size_this_rank + ] + dist.all_to_all(target_tensor_list, tensor_list, group=group) + + for i in range(len(target_tensor_list)): + obj_view = target_tensor_list[i].type(torch.uint8) + target_tensor_list[i] = _tensor_to_object(obj_view, size_this_rank[i], + group) + + return target_tensor_list + + +def barrier(): + if not dist.is_available(): + return + + rank = dist.get_rank() + if rank == 0: + objects = [1] + else: + objects = [None] + + dist.broadcast_object_list(objects, src=0) + return diff --git a/xtuner/_lite/parallel/device.py b/xtuner/_lite/parallel/device.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/_lite/parallel/fsdp/__init__.py b/xtuner/_lite/parallel/fsdp/__init__.py new file mode 100644 index 000000000..4f5d2b80f --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/__init__.py @@ -0,0 +1,11 @@ +from .checkpointing import RECOMPUTE_MODULES, checkpoint_check_fn, checkpoint +from .lazy import LoadWoInit, dp_lazy_init, dp_sp_lazy_init, lazy_init_megatron +from .wrap import (all_required_grad_wrap_policy, layer_and_emb_wrap_policy, + layer_auto_wrap_policy, token_embedding_wrap_policy) +from .clip_grad import clip_grad_norm_ +__all__ = [ + 'RECOMPUTE_MODULES', 'checkpoint_check_fn', 'LoadWoInit', 'dp_lazy_init', + 'all_required_grad_wrap_policy', 'layer_auto_wrap_policy', + 'token_embedding_wrap_policy', 'lazy_init_megatron', 'dp_sp_lazy_init', + 'layer_and_emb_wrap_policy', 'checkpoint' +] diff --git a/xtuner/_lite/parallel/fsdp/checkpointing.py b/xtuner/_lite/parallel/fsdp/checkpointing.py new file mode 100644 index 000000000..1fe998987 --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/checkpointing.py @@ -0,0 +1,92 @@ +import random +from typing import Any, Tuple +import torch +import torch.nn as nn +from torch.utils.checkpoint import ( + _checkpoint_without_reentrant_generator, + _DEFAULT_DETERMINISM_MODE, +) +from contextlib import nullcontext, contextmanager +from torch.distributed._composable.contract import contract + + +RECOMPUTE_MODULES = ('InternLM2DecoderLayer', 'CLIPEncoderLayer') + + +def checkpoint_check_fn(submodule, target=RECOMPUTE_MODULES, selective=1.0): + ret = False + if type(submodule).__name__ in target: + if random.uniform(0, 1) < selective: + ret = True + return ret + + +@contextmanager +def _no_hook(module: nn.Module): + r""" + Disable hooks installed by checkpoint to avoid unintentional recursion + during backward recomputation. 
+ """ + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook + + +# Support **kwargs +@contract() +def checkpoint(module: nn.Module) -> nn.Module: + torch._C._log_api_usage_once("torch.distributed.checkpoint") + + def forward_pre_hook(module: nn.Module, *args) -> None: + if checkpoint.state(module).enable_hook: + def context_fns(): + return nullcontext(), _no_hook(module) + + checkpoint.state( + module + )._ac_generator = _checkpoint_without_reentrant_generator( + module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *args[0], **args[1] + ) + next(checkpoint.state(module)._ac_generator) + + def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any: + if checkpoint.state(module).enable_hook: + try: + next(checkpoint.state(module)._ac_generator) + except StopIteration: + pass + else: + raise RuntimeError( + "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!" + ) + + # Ensure that we no longer hold on to the generator. always_call=True helps ensure we + # clear this even in the case of exception in fwd pass. + checkpoint.state(module)._ac_generator = None + + checkpoint.state(module).enable_hook = True + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + module.register_forward_hook(forward_hook, prepend=True, always_call=True) + return module + + +if __name__ == '__main__': + + class MyModel(nn.Module): + + def __init__(self): + super().__init__() + self.l1 = nn.Linear(10, 10) + self.l2 = nn.Linear(10, 10) + + def forward(self, x, b, a=4, c=4): + print(b, a, c) + return self.l2(self.l1(x)) + + # from torch.distributed._composable.checkpoint_activation import checkpoint + model = MyModel() + checkpoint(model) # apply activation checkpointing only to l1 + model(torch.zeros(2, 10), 2, a=5, c=6).sum().backward() diff --git a/xtuner/_lite/parallel/fsdp/clip_grad.py b/xtuner/_lite/parallel/fsdp/clip_grad.py new file mode 100644 index 000000000..a6cecbc95 --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/clip_grad.py @@ -0,0 +1,88 @@ +from torch.nn.utils.clip_grad import _no_grad +import torch +from typing import List, Optional, Tuple, Union, Dict +from torch import Tensor +from torch import distributed as dist +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@_no_grad +def clip_grad_norm_( + parameters, + fsdp_mesh, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach= None, +) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + # for grouped_device_grads in group_tensors_by_device_mesh(device_grads).values(): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise 
RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + local_sharded_norm = torch.linalg.vector_norm( + torch.stack([norm.to_local().to(first_device) for norm in norms]), norm_type, dtype=torch.float32 + ) + + if norm_type == 2: + total_norm = local_sharded_norm**norm_type + dist.all_reduce(total_norm, group=fsdp_mesh.get_group(mesh_dim=0)) + total_norm = total_norm ** (1 / norm_type) + else: + raise NotImplementedError + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device.to(g.dtype)) + + return total_norm diff --git a/xtuner/_lite/parallel/fsdp/lazy.py b/xtuner/_lite/parallel/fsdp/lazy.py new file mode 100644 index 000000000..149f141fb --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/lazy.py @@ -0,0 +1,153 @@ +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor, distribute_tensor + +from xtuner._lite import get_logger, get_torch_device_module + +logger = get_logger() + +DEVICE_MODULE = get_torch_device_module() + + +@torch.no_grad +def dp_lazy_init(module, module_map, dp_mesh): + device = DEVICE_MODULE.current_device() + module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False) + + if dp_mesh.get_local_rank() == 0: + master_module = module_map[module] + master_params = { + name: param + for name, param in master_module.named_parameters(recurse=False) + } + master_buffers = { + name: buffer + for name, buffer in master_module.named_buffers(recurse=False) + } + + for name, param in module.named_parameters(recurse=False): + + p_copy = master_params[name].to(device).to(param.dtype) + # if param.requires_grad: + # p_copy = p_copy.to(device).to(param.dtype) + # else: + # p_copy = p_copy.to(device) + param.data.copy_(p_copy) + + for name, buffer in module.named_buffers(recurse=False): + + b_copy = master_buffers[name].to(device).to(buffer.dtype) + # b_copy = b_copy.to(device) + buffer.data.copy_(b_copy) + + +@torch.no_grad +def dp_sp_lazy_init(module, module_map, dp_mesh, sp_mesh): + device = DEVICE_MODULE.current_device() + module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False) + + if dp_mesh.get_local_rank() == 0 and sp_mesh.get_local_rank() == 0: + master_module = module_map[module] + master_params = { + name: param + for name, 
param in master_module.named_parameters(recurse=False) + } + master_buffers = { + name: buffer + for name, buffer in master_module.named_buffers(recurse=False) + } + + for name, param in module.named_parameters(recurse=False): + p_copy = master_params[name].to(device).to(param.dtype) + param.data.copy_(p_copy) + + for name, buffer in module.named_buffers(recurse=False): + b_copy = master_buffers[name].to(device).to(buffer.dtype) + buffer.data.copy_(b_copy) + + +@torch.no_grad +def lazy_init_megatron(module, rank0_map, dp_mesh, tp_mesh=None, pp_mesh=None): + device = DEVICE_MODULE.current_device() + + if dp_mesh.get_rank() == 0: + rank0_module = rank0_map[module] + rank0_params = { + name: param + for name, param in rank0_module.named_parameters(recurse=False) + } + rank0_buffers = { + name: buffer + for name, buffer in rank0_module.named_buffers(recurse=False) + } + else: + rank0_params = None + rank0_buffers = None + + param_shapes = { + name: param.full_tensor().shape + if isinstance(param, DTensor) else param.shape + for name, param in module.named_parameters(recurse=False) + } + + module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False) + + for name, param in module.named_parameters(recurse=False): + dtype = param.dtype + if dp_mesh.get_rank() == 0: + rank0_param = rank0_params[name].to(device).to(dtype) + else: + full_shape = param_shapes[name] + rank0_param = torch.zeros(full_shape, dtype=dtype, device=device) + + dist.broadcast(rank0_param, src=0) + + if isinstance(param, DTensor): + mesh = param.device_mesh + assert mesh == tp_mesh + placements = param.placements + rank0_param = distribute_tensor(rank0_param, mesh, placements) + + param.data.copy_(rank0_param) + dist.barrier() + + # TP does not shard buffers + for name, buffer in module.named_buffers(recurse=False): + if dp_mesh.get_rank() == 0: + rank0_buffer = rank0_buffers[name].to(device).to(buffer.dtype) + else: + rank0_buffer = torch.empty_like(buffer).to(device) + + dist.broadcast(rank0_buffer, src=0) + buffer.data.copy_(rank0_buffer) + + +class LoadWoInit: + """Context manager that disable parameter initialization.""" + + def __init__(self): + self.constant_ = torch.nn.init.constant_ + self.zeros_ = torch.nn.init.zeros_ + self.ones_ = torch.nn.init.ones_ + self.uniform_ = torch.nn.init.uniform_ + self.normal_ = torch.nn.init.normal_ + self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_ + self.kaiming_normal_ = torch.nn.init.kaiming_normal_ + + def __enter__(self, *args, **kwargs): + torch.nn.init.constant_ = lambda *args, **kwargs: None + torch.nn.init.zeros_ = lambda *args, **kwargs: None + torch.nn.init.ones_ = lambda *args, **kwargs: None + torch.nn.init.uniform_ = lambda *args, **kwargs: None + torch.nn.init.normal_ = lambda *args, **kwargs: None + torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None + torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None + + def __exit__(self, *args, **kwargs): + torch.nn.init.constant_ = self.constant_ + torch.nn.init.zeros_ = self.zeros_ + torch.nn.init.ones_ = self.ones_ + torch.nn.init.uniform_ = self.uniform_ + torch.nn.init.normal_ = self.normal_ + torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_ + torch.nn.init.kaiming_normal_ = self.kaiming_normal_ diff --git a/xtuner/_lite/parallel/fsdp/precision.py b/xtuner/_lite/parallel/fsdp/precision.py new file mode 100644 index 000000000..b8d1e7249 --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/precision.py @@ -0,0 +1,23 @@ +import torch +from torch import nn + + +def 
set_require_grad_param_to_fp32(model: nn.Module): + + def traverse(module: nn.Module): + + for m_name, child in module.named_children(): + + all_require_grad = True + for p_name, param in child.named_parameters(): + + if not param.requires_grad: + all_require_grad = False + break + + if all_require_grad: + child.to(torch.float32) + + traverse(child) + + traverse(model) diff --git a/xtuner/_lite/parallel/fsdp/wrap.py b/xtuner/_lite/parallel/fsdp/wrap.py new file mode 100644 index 000000000..f1d97c071 --- /dev/null +++ b/xtuner/_lite/parallel/fsdp/wrap.py @@ -0,0 +1,80 @@ +from torch import nn + +from xtuner._lite import get_logger + +logger = get_logger() +_LAYERS = [ + 'InternLM2DecoderLayer', 'CLIPVisionModel', 'LlavaMultiModalProjector' +] + + +def layer_auto_wrap_policy( + module, + recurse: bool, + nonwrapped_numel: int, + layer_cls=_LAYERS, +) -> bool: + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap for + # the leaf node or reminder + return module.__class__.__name__ in layer_cls + + +def layer_and_emb_wrap_policy( + module, + recurse: bool, + nonwrapped_numel: int, + vocab_size, + layer_cls=_LAYERS, +) -> bool: + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap for + # the leaf node or reminder + if module.__class__.__name__ in layer_cls or isinstance( + module, nn.Embedding): + return True + elif isinstance(module, nn.Linear): + return module.weight.size(0) == vocab_size + else: + return False + + +def token_embedding_wrap_policy( + module, + recurse: bool, + nonwrapped_numel: int, + vocab_size: int, +) -> bool: + if recurse: + # always recurse + return True + + if isinstance(module, (nn.Embedding, nn.Linear)): + if module.weight.size(0) == vocab_size: + return True + + return False + + +def all_required_grad_wrap_policy( + module, + recurse: bool, + nonwrapped_numel: int, +) -> bool: + if recurse: + # always recurse + return True + + requires_grads = [p.requires_grad for p in module.parameters()] + + if len(requires_grads) and all(requires_grads): + logger.debug(module) + return True + + return False diff --git a/xtuner/_lite/parallel/loss.py b/xtuner/_lite/parallel/loss.py new file mode 100644 index 000000000..44fd286f8 --- /dev/null +++ b/xtuner/_lite/parallel/loss.py @@ -0,0 +1,18 @@ +from torch import distributed as dist +from torch.distributed.nn.functional import all_reduce + + +def dist_softmax(logits, mesh, dim=-1, temperature=1): + + logits = logits / temperature + + max_values = logits.max(dim, keepdim=True) + all_reduce(max_values, dist.ReduceOp.MAX, mesh.group()) + + exps = (logits - max_values).exp() + sums = exps.sum(dim, keepdim=True) + all_reduce(sums, dist.ReduceOp.SUM, mesh.group()) + + probs = exps / sums + + return probs diff --git a/xtuner/_lite/parallel/megatron/__init__.py b/xtuner/_lite/parallel/megatron/__init__.py new file mode 100644 index 000000000..87e746247 --- /dev/null +++ b/xtuner/_lite/parallel/megatron/__init__.py @@ -0,0 +1,52 @@ +from .internlm2 import (megatron_internlm2, megatron_internlm2_casual, + megatron_internlm2_reward) + +from .qwen2 import (megatron_qwen2_casual, megatron_qwen2, megatron_qwen2_reward) +from .internvl2 import megatron_internvl2_casual +from .minicpmv import megatron_minicpmv_casual +from .llama import megatron_llama, megatron_llama_casual +from .janus import megatron_janus_casual + +MEGATRON_MAP = { + 'InternLM2ForCausalLM': megatron_internlm2_casual, + 'InternLM2ForRewardModel': megatron_internlm2_reward, 
+ 'InternLM2Model': megatron_internlm2, + 'Qwen2ForCausalLM': megatron_qwen2_casual, + 'Qwen2Model': megatron_qwen2, + 'Qwen2ForRewardModel': megatron_qwen2_reward, + 'InternVLChatModel': megatron_internvl2_casual, + 'MiniCPMV': megatron_minicpmv_casual, + 'MultiModalityCausalLM': megatron_janus_casual, + 'LlamaModel': megatron_llama, + 'LlamaForCausalLM': megatron_llama_casual +} + + +def megatron_parallelize(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True, + **kwargs): + + cls_name = model.__class__.__name__ + if cls_name not in MEGATRON_MAP: + raise NotImplementedError + + parallel_fn = MEGATRON_MAP[cls_name] + + model = parallel_fn( + model, + rank0_model, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward, + **kwargs) + + return model diff --git a/xtuner/_lite/parallel/megatron/internlm2.py b/xtuner/_lite/parallel/megatron/internlm2.py new file mode 100644 index 000000000..0da9437ff --- /dev/null +++ b/xtuner/_lite/parallel/megatron/internlm2.py @@ -0,0 +1,303 @@ +from functools import partial +from packaging import version + +import torch +from torch import nn +from torch.distributed._tensor import Replicate, distribute_tensor +from torch.distributed.tensor.parallel import (ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + RowwiseParallel, + parallelize_module) + +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules + +logger = get_logger() + + +def _tp_internlm2(model, tp_mesh): + + layer_tp_plan = { + # by default ColwiseParallel input layouts is replicated + # and RowwiseParallel output layouts is replicated + 'attention.wqkv': + ColwiseParallel(), + 'attention.wo': + RowwiseParallel(), + 'attention_norm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ), + 'feed_forward.w1': + ColwiseParallel(), + 'feed_forward.w2': + RowwiseParallel(), + 'feed_forward.w3': + ColwiseParallel(), + 'ffn_norm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ) + } + + tp_size = tp_mesh.size() + for layer in model.layers: + attention = layer.attention + num_key_value_heads = attention.num_key_value_heads + num_heads = attention.num_heads + hidden_size = attention.hidden_size + + attention.num_heads = num_heads // tp_size + attention.num_key_value_heads = num_key_value_heads // tp_size + attention.hidden_size = hidden_size // tp_size + + attn_norm = layer.attention_norm + attn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()]))) + + ffn_norm = layer.ffn_norm + ffn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()]))) + + parallelize_module( + module=layer, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + norm = model.norm + dist_norm_w = nn.Parameter( + distribute_tensor(norm.weight, tp_mesh, [Replicate()])) + norm.register_parameter('weight', dist_norm_w) + + # emb = model.tok_embeddings + # dist_emb_w = nn.Parameter( + # distribute_tensor(emb.weight, tp_mesh, [Replicate()])) + # emb.register_parameter('weight', dist_emb_w) + + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'model.tok_embeddings': + 
RowwiseParallel(input_layouts=Replicate(), ), + 'model.norm':PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + use_local_output=True + ), + }) + + +def megatron_internlm2(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + if tp_mesh and tp_mesh.size() > 1: + _tp_internlm2(model, tp_mesh) + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + from torch.distributed._composable import checkpoint + from torch.distributed._composable.fsdp import fully_shard + num_layers = len(model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + + for i, block in enumerate(model.layers): + + block.apply(param_init_fn) + + # # As an optimization, do not reshard after forward for the last + # # transformer block since FSDP would prefetch it immediately + # if i < num_layers - 1: + # _reshard = reshard_after_forward + # else: + # _reshard = False + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + if version.parse(torch.__version__) >= version.parse("2.5.0"): + for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]): + layer_cur.set_modules_to_forward_prefetch([layer_next]) + + model.tok_embeddings.apply(param_init_fn) + model.norm.apply(param_init_fn) + + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + +def megatron_internlm2_casual(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + megatron_internlm2( + model.model, + rank0_model.model if dp_mesh.get_rank() == 0 else None, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward) + + if tp_mesh and tp_mesh.size() > 1: + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'output': ColwiseParallel(output_layouts=Replicate(), ), + }) + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + model.output.apply(param_init_fn) + + from torch.distributed._composable.fsdp import fully_shard + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + +def megatron_internlm2_reward(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + megatron_internlm2( + model.model, + rank0_model.model if dp_mesh.get_rank() == 0 else None, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward) + + if tp_mesh and tp_mesh.size() > 1: + + head_0 = model.v_head[0] + dist_head_0 = nn.Parameter( + distribute_tensor(head_0.weight, tp_mesh, [Replicate()])) + head_0.register_parameter('weight', dist_head_0) + + head_norm = model.v_head[1] + dist_head_norm = nn.Parameter( + distribute_tensor(head_norm.weight, tp_mesh, [Replicate()])) + dist_head_bias = 
nn.Parameter( + distribute_tensor(head_norm.bias, tp_mesh, [Replicate()])) + head_norm.register_parameter('weight', dist_head_norm) + head_norm.register_parameter('bias', dist_head_bias) + + head_1 = model.v_head[3] + dist_head_1 = nn.Parameter( + distribute_tensor(head_1.weight, tp_mesh, [Replicate()])) + head_1.register_parameter('weight', dist_head_1) + + + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'v_head.0': PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + ), + 'v_head.0': PrepareModuleOutput( + output_layouts=(Replicate(),), + desired_output_layouts=(Replicate(),), + use_local_output=True + ), + 'v_head.1': PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + ), + 'v_head.1': PrepareModuleOutput( + output_layouts=(Replicate(),), + desired_output_layouts=(Replicate(),), + use_local_output=True + ), + 'v_head.3': PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + ), + 'v_head.3': PrepareModuleOutput( + output_layouts=(Replicate(),), + desired_output_layouts=(Replicate(),), + use_local_output=True + ), + + }) + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + model.v_head.apply(param_init_fn) + + from torch.distributed._composable.fsdp import fully_shard + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) diff --git a/xtuner/_lite/parallel/megatron/internvl2.py b/xtuner/_lite/parallel/megatron/internvl2.py new file mode 100644 index 000000000..0796a033b --- /dev/null +++ b/xtuner/_lite/parallel/megatron/internvl2.py @@ -0,0 +1,122 @@ +from functools import partial +from torch.distributed._composable.fsdp import fully_shard +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules +from ..fsdp import checkpoint + +logger = get_logger() + + +def megatron_internvl2_casual(meta_model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True, + small_model=True): # if 70b model, set to False + if tp_mesh.size() > 1: + raise NotImplementedError + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(meta_model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + num_layers = len(meta_model.language_model.model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + for i, block in enumerate(meta_model.language_model.model.layers): + block.apply(param_init_fn) + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + has_forward_prefetch = hasattr(meta_model.language_model.model.layers[0], 'set_modules_to_forward_prefetch') + if has_forward_prefetch: + for layer_cur, layer_next in zip(meta_model.language_model.model.layers[:-1], + meta_model.language_model.model.layers[1:]): + layer_cur.set_modules_to_forward_prefetch([layer_next]) + + if small_model: + meta_model.vision_model.apply(param_init_fn) + fully_shard( + meta_model.vision_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) 
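+        # With ``small_model=True`` the whole vision tower is a single FSDP unit and
+        # only its encoder layers are activation-checkpointed below; the ``else``
+        # branch instead shards the vision encoder layer by layer, which is intended
+        # for much larger (e.g. 70B-scale) checkpoints.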
+ + for i, layers in enumerate(meta_model.vision_model.encoder.layers): + checkpoint(layers) + + if has_forward_prefetch: + meta_model.vision_model.set_modules_to_forward_prefetch([meta_model.language_model.model.layers[0]]) + else: + # visual + for i, block in enumerate(meta_model.vision_model.encoder.layers): + block.apply(param_init_fn) + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + checkpoint(block) + + if has_forward_prefetch: + for layer_cur, layer_next in zip(meta_model.vision_model.encoder.layers[:-1], + meta_model.vision_model.encoder.layers[1:]): + layer_cur.set_modules_to_forward_prefetch([layer_next]) + + meta_model.vision_model.encoder.layers[-1].set_modules_to_forward_prefetch([meta_model.language_model.model.layers[0]]) + + meta_model.vision_model.embeddings.apply(param_init_fn) + fully_shard( + meta_model.vision_model.embeddings, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + if has_forward_prefetch: + meta_model.vision_model.embeddings.set_modules_to_forward_prefetch([meta_model.vision_model.encoder.layers[0]]) + + meta_model.mlp1.apply(param_init_fn) + try: + meta_model.language_model.model.tok_embeddings.apply(param_init_fn) + except AttributeError: + meta_model.language_model.model.embed_tokens.apply(param_init_fn) + + meta_model.language_model.model.norm.apply(param_init_fn) + try: + meta_model.language_model.output.apply(param_init_fn) + except AttributeError: + meta_model.language_model.lm_head.apply(param_init_fn) + + model = fully_shard( + meta_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3 + + if has_forward_prefetch and small_model: + model.set_modules_to_forward_prefetch([model.vision_model]) + elif has_forward_prefetch: + model.set_modules_to_forward_prefetch([model.vision_model.embeddings]) + return model diff --git a/xtuner/_lite/parallel/megatron/janus.py b/xtuner/_lite/parallel/megatron/janus.py new file mode 100644 index 000000000..e8eec0477 --- /dev/null +++ b/xtuner/_lite/parallel/megatron/janus.py @@ -0,0 +1,125 @@ +from functools import partial +from torch.distributed._composable.fsdp import fully_shard +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules +from ..fsdp import checkpoint + +logger = get_logger() + + +def megatron_janus_casual(meta_model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True, + freeze_style='mode1'): + if tp_mesh.size() > 1: + raise NotImplementedError + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(meta_model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + num_layers = len(meta_model.language_model.model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + for i, block in enumerate(meta_model.language_model.model.layers): + block.apply(param_init_fn) + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + if freeze_style == 'mode1': + meta_model.language_model.lm_head.apply(param_init_fn) + fully_shard( + meta_model.language_model.lm_head, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + 
meta_model.gen_head.apply(param_init_fn) + fully_shard( + meta_model.gen_head, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + meta_model.aligner.apply(param_init_fn) + fully_shard( + meta_model.aligner, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + meta_model.gen_aligner.apply(param_init_fn) + fully_shard( + meta_model.gen_aligner, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + meta_model.vision_model.apply(param_init_fn) + meta_model.gen_vision_model.apply(param_init_fn) + meta_model.gen_embed.apply(param_init_fn) + meta_model.language_model.model.embed_tokens.apply(param_init_fn) + meta_model.language_model.model.norm.apply(param_init_fn) + + model = fully_shard( + meta_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3 + + # TODO: Bug + # model.set_reshard_after_backward(False) + + elif freeze_style == 'mode2': + meta_model.gen_vision_model.apply(param_init_fn) + fully_shard( + meta_model.gen_vision_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + meta_model.gen_vision_model.set_reshard_after_backward(False) + + meta_model.vision_model.apply(param_init_fn) + fully_shard( + meta_model.vision_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + meta_model.vision_model.set_reshard_after_backward(False) + + meta_model.gen_head.apply(param_init_fn) + meta_model.aligner.apply(param_init_fn) + meta_model.gen_aligner.apply(param_init_fn) + meta_model.gen_embed.apply(param_init_fn) + meta_model.language_model.model.embed_tokens.apply(param_init_fn) + meta_model.language_model.model.norm.apply(param_init_fn) + meta_model.language_model.lm_head.apply(param_init_fn) + model = fully_shard( + meta_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + return model diff --git a/xtuner/_lite/parallel/megatron/llama.py b/xtuner/_lite/parallel/megatron/llama.py new file mode 100644 index 000000000..e100e88f0 --- /dev/null +++ b/xtuner/_lite/parallel/megatron/llama.py @@ -0,0 +1,216 @@ +from functools import partial +from packaging import version + +import torch +from torch import nn +from torch.distributed._tensor import Replicate, distribute_tensor +from torch.distributed.tensor.parallel import (ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + parallelize_module) + +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules + +logger = get_logger() + + +def _tp_llama(model, tp_mesh): + + layer_tp_plan = { + + # by default ColwiseParallel input layouts is replicated + # and RowwiseParallel output layouts is replicated + 'self_attn.q_proj': + ColwiseParallel(), + 'self_attn.k_proj': + ColwiseParallel(), + 'self_attn.v_proj': + ColwiseParallel(), + 'self_attn.o_proj': + RowwiseParallel(), + 'input_layernorm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ), + 'mlp.up_proj': + ColwiseParallel(), + 'mlp.down_proj': + RowwiseParallel(), + 'mlp.gate_proj': + ColwiseParallel(), + 'post_attention_layernorm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ) + } + + tp_size = tp_mesh.size() + for layer in model.layers: + attention = layer.self_attn + 
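+        # After column/row-parallelizing the q/k/v/o projections each rank owns only
+        # 1/tp_size of the attention heads, so the cached head counts and hidden size
+        # on the module are shrunk to match the sharded projections.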
num_key_value_heads = attention.num_key_value_heads + num_heads = attention.num_heads + hidden_size = attention.hidden_size + + attention.num_heads = num_heads // tp_size + attention.num_key_value_heads = num_key_value_heads // tp_size + attention.hidden_size = hidden_size // tp_size + + attn_norm = layer.input_layernorm + attn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()]))) + + ffn_norm = layer.post_attention_layernorm + ffn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()]))) + + parallelize_module( + module=layer, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + + norm = model.norm + dist_norm_w = nn.Parameter( + distribute_tensor(norm.weight, tp_mesh, [Replicate()])) + norm.register_parameter('weight', dist_norm_w) + + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'model.embed_tokens': + RowwiseParallel(input_layouts=Replicate(), ), + 'model.norm':PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + use_local_output=True + ), + }) + + +def megatron_llama(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + + if tp_mesh and tp_mesh.size() > 1: + _tp_llama(model, tp_mesh) + + from torch.distributed._composable import checkpoint + from torch.distributed._composable.fsdp import fully_shard + num_layers = len(model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + + for i, block in enumerate(model.layers): + + block.apply(param_init_fn) + + # # As an optimization, do not reshard after forward for the last + # # transformer block since FSDP would prefetch it immediately + # if i < num_layers - 1: + # _reshard = reshard_after_forward + # else: + # _reshard = False + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + if version.parse(torch.__version__) >= version.parse("2.5.0"): + for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]): + layer_cur.set_modules_to_forward_prefetch([layer_next]) + + + model.embed_tokens.apply(param_init_fn) + model.norm.apply(param_init_fn) + if hasattr(model, 'rotary_emb'): + model.rotary_emb.apply(param_init_fn) + + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + +def megatron_llama_casual(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + megatron_llama( + model.model, + rank0_model.model if dp_mesh.get_rank() == 0 else None, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward) + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + + if tp_mesh and tp_mesh.size() > 1: + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + 
parallelize_plan={ + 'lm_head': ColwiseParallel(output_layouts=Replicate(), ), + }) + + model.lm_head.apply(param_init_fn) + + from torch.distributed._composable.fsdp import fully_shard + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) diff --git a/xtuner/_lite/parallel/megatron/minicpmv.py b/xtuner/_lite/parallel/megatron/minicpmv.py new file mode 100644 index 000000000..2e667cca0 --- /dev/null +++ b/xtuner/_lite/parallel/megatron/minicpmv.py @@ -0,0 +1,79 @@ +from functools import partial +from torch.distributed._composable.fsdp import fully_shard +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules +from ..fsdp import checkpoint + +logger = get_logger() + + +def megatron_minicpmv_casual(meta_model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + if tp_mesh.size() > 1: + raise NotImplementedError + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(meta_model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + # visual + meta_model.vpm.apply(param_init_fn) + fully_shard( + meta_model.vpm, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + for i, layers in enumerate(meta_model.vpm.encoder.layers): + checkpoint(layers) + + # resampler + meta_model.resampler.apply(param_init_fn) + fully_shard( + meta_model.resampler, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + # llm + num_layers = len(meta_model.llm.model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + for i, block in enumerate(meta_model.llm.model.layers): + block.apply(param_init_fn) + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + meta_model.llm.model.embed_tokens.apply(param_init_fn) + meta_model.llm.model.norm.apply(param_init_fn) + meta_model.llm.lm_head.apply(param_init_fn) + + model = fully_shard( + meta_model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3 + return model diff --git a/xtuner/_lite/parallel/megatron/qwen2.py b/xtuner/_lite/parallel/megatron/qwen2.py new file mode 100644 index 000000000..90c65309b --- /dev/null +++ b/xtuner/_lite/parallel/megatron/qwen2.py @@ -0,0 +1,270 @@ +from functools import partial +from packaging import version + +import torch +from torch import nn +from torch.distributed._tensor import Replicate, distribute_tensor +from torch.distributed.tensor.parallel import (ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + parallelize_module) + +from xtuner._lite import get_logger +from ..fsdp.lazy import lazy_init_megatron +from .utils import map_rank0_modules + +logger = get_logger() + + +def _tp_qwen2(model, tp_mesh): + + layer_tp_plan = { + + # by default ColwiseParallel input layouts is replicated + # and RowwiseParallel output layouts is replicated + 'self_attn.q_proj': + ColwiseParallel(), + 'self_attn.k_proj': + ColwiseParallel(), + 'self_attn.v_proj': + ColwiseParallel(), + 'self_attn.o_proj': + RowwiseParallel(), + 'input_layernorm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ), + 
'mlp.up_proj': + ColwiseParallel(), + 'mlp.down_proj': + RowwiseParallel(), + 'mlp.gate_proj': + ColwiseParallel(), + 'post_attention_layernorm': + PrepareModuleInput( + input_layouts=(Replicate(), ), + desired_input_layouts=(Replicate(), ), + use_local_output=True + ) + } + + tp_size = tp_mesh.size() + for layer in model.layers: + attention = layer.self_attn + num_key_value_heads = attention.num_key_value_heads + num_heads = attention.num_heads + hidden_size = attention.hidden_size + + attention.num_heads = num_heads // tp_size + attention.num_key_value_heads = num_key_value_heads // tp_size + attention.hidden_size = hidden_size // tp_size + + attn_norm = layer.input_layernorm + attn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()]))) + + ffn_norm = layer.post_attention_layernorm + ffn_norm.register_parameter( + 'weight', + nn.Parameter( + distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()]))) + + parallelize_module( + module=layer, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + + norm = model.norm + dist_norm_w = nn.Parameter( + distribute_tensor(norm.weight, tp_mesh, [Replicate()])) + norm.register_parameter('weight', dist_norm_w) + + # emb = model.embed_tokens + # dist_emb_w = nn.Parameter( + # distribute_tensor(emb.weight, tp_mesh, [Replicate()])) + # emb.register_parameter('weight', dist_emb_w) + + # model.norm.apply(param_init_fn) + # model.embed_tokens.apply(param_init_fn) + + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'model.embed_tokens': + RowwiseParallel(input_layouts=Replicate(), ), + 'model.norm':PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Replicate(),), + use_local_output=True + ), + }) + + +def megatron_qwen2(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + + if tp_mesh and tp_mesh.size() > 1: + _tp_qwen2(model, tp_mesh) + + from torch.distributed._composable import checkpoint + from torch.distributed._composable.fsdp import fully_shard + num_layers = len(model.layers) + num_recompute_layers = int(num_layers * recompute_ratio) + + for i, block in enumerate(model.layers): + + block.apply(param_init_fn) + + # # As an optimization, do not reshard after forward for the last + # # transformer block since FSDP would prefetch it immediately + # if i < num_layers - 1: + # _reshard = reshard_after_forward + # else: + # _reshard = False + + fully_shard( + block, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + if i < num_recompute_layers: + checkpoint(block) + + if version.parse(torch.__version__) >= version.parse("2.5.0"): + for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]): + layer_cur.set_modules_to_forward_prefetch([layer_next]) + + model.embed_tokens.apply(param_init_fn) + model.norm.apply(param_init_fn) + + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + +def megatron_qwen2_casual(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + megatron_qwen2( + model.model, + rank0_model.model if 
dp_mesh.get_rank() == 0 else None, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward) + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + + if tp_mesh and tp_mesh.size() > 1: + model = parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'lm_head': ColwiseParallel(output_layouts=Replicate(), ), + }) + + model.lm_head.apply(param_init_fn) + + from torch.distributed._composable.fsdp import fully_shard + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) + + +def megatron_qwen2_reward(model, + rank0_model, + dp_mesh, + tp_mesh=None, + pp_mesh=None, + mp_policy=None, + recompute_ratio=1.0, + reshard_after_forward=True): + megatron_qwen2( + model.model, + rank0_model.model if dp_mesh.get_rank() == 0 else None, + dp_mesh, + tp_mesh=tp_mesh, + pp_mesh=pp_mesh, + mp_policy=mp_policy, + recompute_ratio=recompute_ratio, + reshard_after_forward=reshard_after_forward) + + if dp_mesh.get_rank() == 0: + rank0_map = map_rank0_modules(model, rank0_model) + else: + rank0_map = None + + if tp_mesh and tp_mesh.size() > 1: + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan={ + 'score.0': ColwiseParallel(), + 'score.2': RowwiseParallel(), + }) + + param_init_fn = partial( + lazy_init_megatron, + rank0_map=rank0_map, + dp_mesh=dp_mesh, + tp_mesh=tp_mesh, + ) + + model.v_head.apply(param_init_fn) + + from torch.distributed._composable.fsdp import fully_shard + fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward) diff --git a/xtuner/_lite/parallel/megatron/utils.py b/xtuner/_lite/parallel/megatron/utils.py new file mode 100644 index 000000000..5e4ed2dd2 --- /dev/null +++ b/xtuner/_lite/parallel/megatron/utils.py @@ -0,0 +1,7 @@ +def map_rank0_modules(model, rank0_model): + rank0_modules = {name: mod for name, mod in rank0_model.named_modules()} + rank0_map = { + mod: rank0_modules[name] + for name, mod in model.named_modules() + } + return rank0_map diff --git a/xtuner/_lite/parallel/sampler.py b/xtuner/_lite/parallel/sampler.py new file mode 100644 index 000000000..91b286f86 --- /dev/null +++ b/xtuner/_lite/parallel/sampler.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import random +from typing import Iterator, Optional, Sized + +import torch +from mmengine.dist import sync_random_seed +from torch.distributed.device_mesh import DeviceMesh +from torch.utils.data import ConcatDataset as TorchConcatDataset +from torch.utils.data import Sampler + + +class ParallelSampler(Sampler): + """The default data sampler for both distributed and non-distributed + environment. + + It has several differences from the PyTorch ``DistributedSampler`` as + below: + + 1. This sampler supports non-distributed environment. + + 2. The round up behaviors are a little different. + + - If ``round_up=True``, this sampler will add extra samples to make the + number of samples is evenly divisible by the world size. And + this behavior is the same as the ``DistributedSampler`` with + ``drop_last=False``. + - If ``round_up=False``, this sampler won't remove or add any samples + while the ``DistributedSampler`` with ``drop_last=True`` will remove + tail samples. 
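+
+    For example (illustrative numbers): with ``len(dataset) = 10``, a
+    ``global_batch_size`` of 4 and a data-parallel world size of 2,
+    ``round_up=True`` repeats indices until 12 are available so each rank
+    iterates over 6 samples, while ``round_up=False`` leaves each rank with 5.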
+ + Args: + dataset (Sized): The dataset. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + """ + + def __init__( + self, + dataset: Sized, + dp_mesh: DeviceMesh, + global_batch_size: int, + shuffle: bool = True, + seed: Optional[int] = None, + round_up: bool = True, + ) -> None: + rank = dp_mesh.get_local_rank() + world_size = dp_mesh.size() + + assert global_batch_size % world_size == 0 + self.global_batch_size = global_batch_size + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.step = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil( + len(self.dataset) / + global_batch_size) * global_batch_size // world_size + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + + # subsample + indices = indices[self.rank:self.total_size:self.world_size] + + return iter(indices[self.step:]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples - self.step + + def set_epoch(self, epoch: int, step=0) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + self.step = step + + +def get_length_grouped_indices(max_lengths, + group_batch_size, + dp_size, + seed=None): + if seed is not None: + torch.manual_seed(seed) + random.seed(seed) + + assert all(leng != 0 + for leng in max_lengths), 'Should not have zero length.' 
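+    # Grouping scheme (summary of the code below): shuffle all indices, cut them
+    # into mega-batches of ``group_batch_size``, sort each mega-batch by length
+    # (longest first), slice it into consecutive groups of ``dp_size`` indices so
+    # the samples dispatched to the data-parallel ranks in one step have similar
+    # lengths, then shuffle the groups so an epoch is not ordered longest-first.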
+ indices = torch.randperm(len(max_lengths)) + megabatches = [ + indices[i:i + group_batch_size].tolist() + for i in range(0, len(max_lengths), group_batch_size) + ] + output = [] + for megabatch in megabatches: + megabatch = sorted( + megabatch, key=lambda i: max_lengths[i], reverse=True) + grouped_megabatch = [ + megabatch[i:i + dp_size] for i in range(0, len(megabatch), dp_size) + ] + random.shuffle(grouped_megabatch) + for group in grouped_megabatch: + output.extend(group) + + return output + + +class LengthGroupedSampler(Sampler): + + def __init__(self, + dataset: Sized, + dp_mesh: DeviceMesh, + global_batch_size: int, + length_attr: str = 'longest', + mega_batch_mult: Optional[int] = None, + seed: Optional[int] = None, + round_up: bool = True) -> None: + rank = dp_mesh.get_local_rank() + world_size = dp_mesh.size() + self.rank = rank + self.world_size = world_size + assert global_batch_size % world_size == 0 + + self.dataset = dataset + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.step = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil( + len(self.dataset) / + global_batch_size) * global_batch_size // world_size + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + if mega_batch_mult is None: + # Default for mega_batch_mult: 50 or the number to get 4 + # megabatches, whichever is smaller. + mega_batch_mult = min( + len(self.dataset) // (global_batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + self.group_batch_size = mega_batch_mult * global_batch_size + + if isinstance(self.dataset, TorchConcatDataset): + max_lengths = [] + for sub_dataset in self.dataset.datasets: + if hasattr(sub_dataset, length_attr): + max_lengths.extend(getattr(sub_dataset, length_attr)) + else: + raise ValueError + self.max_lengths = max_lengths + else: + if hasattr(self.dataset, length_attr): + self.max_lengths = getattr(self.dataset, length_attr) + assert isinstance(self.max_lengths, (list, tuple)) + + self.global_batch_size = global_batch_size + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + seed = self.seed + self.epoch + indices = get_length_grouped_indices( + max_lengths=self.max_lengths, + group_batch_size=self.group_batch_size, + dp_size=self.world_size, + seed=seed) + assert len(set(indices)) == len(indices) + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + # subsample + assert len(indices) == self.total_size + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + return iter(indices[self.step:]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples - self.step + + def set_epoch(self, epoch: int, step=0) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch + self.step = step + + +def vlm_get_length_grouped_indices(max_lengths, group_batch_size, generator=None, **kwargs): + + def process(lengths, group_batch_size, generator=None): + indices = torch.randperm(len(lengths), generator=generator) + megabatches = [ + indices[i:i + group_batch_size].tolist() + for i in range(0, len(lengths), group_batch_size) + ] + megabatches = [ + sorted(megabatch, key=lambda i: lengths[i], reverse=True) + for megabatch in megabatches + ] + return megabatches + + lengths = max_lengths + assert all(leng != 0 for leng in lengths), 'Should not have zero length.' + if all(leng > 0 for leng in lengths) or all(leng < 0 for leng in lengths): + # all samples are in the same modality + megabatches = process(lengths, group_batch_size, generator=generator) + else: + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) + if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) + for i, l in enumerate(lengths) + if l < 0]) + mm_megabatches = [] + for mm_megabatch in process( + mm_lengths, group_batch_size, generator=generator): + mm_megabatches.append([mm_indices[i] for i in mm_megabatch]) + lang_megabatches = [] + for lang_megabatch in process( + lang_lengths, group_batch_size, generator=generator): + lang_megabatches.append([lang_indices[i] for i in lang_megabatch]) + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + last_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + + megabatch_indices = torch.randperm( + len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(last_batch) > 0: + megabatches.append( + sorted( + last_batch, key=lambda i: abs(lengths[i]), reverse=True)) + + # The rest is to get the biggest batch first. + # Since each megabatch is sorted by descending length, + # the longest element is the first + megabatch_maximums = [ + abs(lengths[megabatch[0]]) for megabatch in megabatches + ] + max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item() + # Switch to put the longest element in first position + megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][ + 0], megabatches[0][0] + + return [i for megabatch in megabatches for i in megabatch] + + +class VLMLengthGroupedSampler(Sampler): + + def __init__(self, + dataset: Sized, + dp_mesh: DeviceMesh, + global_batch_size: int, + mega_batch_mult: Optional[int] = None, + seed: Optional[int] = None, + round_up: bool = True, + length_property='length') -> None: + rank = dp_mesh.get_local_rank() + world_size = dp_mesh.size() + self.rank = rank + self.world_size = world_size + assert global_batch_size % world_size == 0 + + self.dataset = dataset + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.step = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil( + len(self.dataset) / + global_batch_size) * global_batch_size // world_size + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + if mega_batch_mult is None: + # Default for mega_batch_mult: 50 or the number to get 4 + # megabatches, whichever is smaller. 
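+            # e.g. len(dataset) = 100_000 and global_batch_size = 64 gives
+            # min(100_000 // 256, 50) = 50, i.e. group_batch_size = 3200.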
+ mega_batch_mult = min( + len(self.dataset) // (global_batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + self.group_batch_size = mega_batch_mult * global_batch_size + + if isinstance(self.dataset, TorchConcatDataset): + max_lengths = [] + for sub_dataset in self.dataset.datasets: + max_lengths.extend(getattr(sub_dataset, length_property)) + self.max_lengths = max_lengths + else: + self.max_lengths = getattr(self.dataset, length_property) + assert isinstance(self.max_lengths, (list, tuple)) + + self.global_batch_size = global_batch_size + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + indices = vlm_get_length_grouped_indices( + max_lengths=self.max_lengths, + group_batch_size=self.group_batch_size, + dp_size=self.world_size, + generator=generator) + assert len(set(indices)) == len(indices) + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + # subsample + assert len(indices) == self.total_size + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + return iter(indices[self.step:]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples - self.step + + def set_epoch(self, epoch: int, step=0) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + self.step = step \ No newline at end of file diff --git a/xtuner/_lite/parallel/sequence/__init__.py b/xtuner/_lite/parallel/sequence/__init__.py new file mode 100644 index 000000000..2dbdcaacf --- /dev/null +++ b/xtuner/_lite/parallel/sequence/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dist import init_dist + +from .attention import (post_process_for_sequence_parallel_attn, + pre_process_for_sequence_parallel_attn, + sequence_parallel_wrapper) +from .data_collate import (pad_cumulative_len_for_sequence_parallel, + pad_for_sequence_parallel) +from .ops import (gather_for_sequence_parallel, gather_forward_split_backward, + split_for_sequence_parallel, split_forward_gather_backward) +from .reduce_loss import reduce_sequence_parallel_loss + +__all__ = [ + 'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn', + 'post_process_for_sequence_parallel_attn', 'split_for_sequence_parallel', + 'init_dist', 'gather_for_sequence_parallel', + 'split_forward_gather_backward', 'gather_forward_split_backward', + 'pad_cumulative_len_for_sequence_parallel', 'pad_for_sequence_parallel', + 'reduce_sequence_parallel_loss' +] diff --git a/xtuner/_lite/parallel/sequence/attention.py b/xtuner/_lite/parallel/sequence/attention.py new file mode 100644 index 000000000..5866a9f33 --- /dev/null +++ b/xtuner/_lite/parallel/sequence/attention.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
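+# Usage sketch (illustrative only; ``local_flash_attn`` is an assumed name for
+# any local attention callable with a (q, k, v, ...) signature):
+#
+#     sp_attn = sequence_parallel_wrapper(local_flash_attn)
+#     out = sp_attn(q, k, v, sp_mesh=get_sp_mesh())
+#
+# With a sequence-parallel size > 1, q/k/v enter as (b, s / sp, heads, dim)
+# shards, are redistributed to (b, s, heads / sp, dim) via all_to_all, attended
+# locally, and the output is mapped back to (b, s / sp, heads, dim).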
+import torch.distributed as dist + +from ..comm import all_to_all +from ..setup import get_sp_mesh + + +def pre_process_for_sequence_parallel_attn(query_states, + key_states, + value_states, + sp_mesh, + scatter_dim=2, + gather_dim=1): + sp_size = sp_mesh.size() + n_head = query_states.shape[2] + assert n_head % sp_size == 0, \ + ('The number of attention heads should be divisible by ' + f'sequence_parallel_world_size. But got n_head = {n_head} and ' + f'sequence_parallel_world_size = {sp_size}.') + + # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim) + sp_group = sp_mesh.get_group() + query_states = all_to_all( + query_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim) + key_states = all_to_all( + key_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim) + value_states = all_to_all( + value_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim) + + return query_states, key_states, value_states + + +def post_process_for_sequence_parallel_attn(attn_output, + sp_mesh, + scatter_dim=1, + gather_dim=2): + # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim) + sp_group = sp_mesh.get_group() + output = all_to_all( + attn_output, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim) + return output + + +def sequence_parallel_wrapper(local_attn): + + def sequence_parallel_attn(query_states, key_states, value_states, *args, + **kwargs): + training = kwargs.pop('training', True) + sp_mesh = kwargs.pop('sp_mesh', None) + + if sp_mesh: + sp_size = sp_mesh.size() + else: + sp_size = get_sp_mesh().size() + + enable_sequence_parallel = sp_size > 1 + if enable_sequence_parallel: + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states, sp_mesh) + + out = local_attn(query_states, key_states, value_states, *args, + **kwargs) + + if enable_sequence_parallel: + out = post_process_for_sequence_parallel_attn(out, sp_mesh).contiguous() + + return out + + return sequence_parallel_attn diff --git a/xtuner/_lite/parallel/sequence/data_collate.py b/xtuner/_lite/parallel/sequence/data_collate.py new file mode 100644 index 000000000..32c8f161b --- /dev/null +++ b/xtuner/_lite/parallel/sequence/data_collate.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..setup import get_sp_mesh + + +def pad_for_sequence_parallel(tensor, padding_value, sp_mesh, dim=-1): + + sp_size = sp_mesh.size() + length = tensor.shape[dim] + if length % sp_size == 0: + return tensor + + pad_num = sp_size - (length % sp_size) + pad_shape = (*tensor.shape[:dim], pad_num, + *tensor.shape[dim + 1:]) if dim != -1 else ( + *tensor.shape[:dim], pad_num) + pad = torch.full( + pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device) + tensor = torch.cat([tensor, pad], dim=dim) + return tensor + + +# This function only meets the following two conditions: +# 1. use_varlen_attn = True +# 2. 
pack_to_max_length = True and the lengths of each sequence are different
+def pad_cumulative_len_for_sequence_parallel(cumulative_len):
+    assert len(cumulative_len) == 1
+    seqlen = cumulative_len[0][-1]
+    sp_size = get_sp_mesh().size()
+    if seqlen % sp_size == 0:
+        return cumulative_len, None
+
+    bs = len(cumulative_len)
+    pad_len = sp_size - (seqlen % sp_size)
+    seqlen_new = seqlen + pad_len
+    attention_mask = torch.zeros(
+        bs, seqlen_new, dtype=torch.bool, device=cumulative_len[0].device)
+    attention_mask[:, :seqlen] = True
+
+    for i, cu_len in enumerate(cumulative_len):
+        pad = torch.tensor([seqlen_new],
+                           device=cu_len.device,
+                           dtype=cu_len.dtype)
+        cumulative_len[i] = torch.cat([cu_len, pad], dim=0)
+
+    return cumulative_len, attention_mask
diff --git a/xtuner/_lite/parallel/sequence/ops.py b/xtuner/_lite/parallel/sequence/ops.py
new file mode 100644
index 000000000..fb0ba0d86
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/ops.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+
+
+def split_for_sequence_parallel(input, dim: int, sp_mesh):
+    """Splits the input tensor along a given dimension for sequence parallel.
+
+    Args:
+        input: The input tensor to be split.
+        dim: The dimension along which the tensor should be split.
+        sp_mesh: The sequence parallel device mesh.
+
+    Returns:
+        The split tensor corresponding to the current rank's chunk.
+    """
+    sp_group = sp_mesh.get_group()
+    sp_size = sp_mesh.size()
+    if sp_size == 1:
+        return input
+
+    rank = dist.get_rank(sp_group)
+    dim_size = input.size(dim)
+    assert dim_size % sp_size == 0, (
+        f'The dimension to split ({dim_size}) is not a multiple of '
+        f'sp size ({sp_size}), cannot split tensor evenly')
+
+    tensor_list = torch.split(input, dim_size // sp_size, dim=dim)
+    output = tensor_list[rank].contiguous()
+
+    return output
+
+
+def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
+    """Gathers the input tensor along a given dimension for sequence parallel.
+
+    Args:
+        input: The input tensor to be gathered.
+        dim: The dimension along which the tensor should be gathered.
+        sp_group: The sequence parallel process group.
+
+    Returns:
+        The gathered tensor concatenated along the specified dimension.
+    """
+    input = input.contiguous()
+    world_size = dist.get_world_size(sp_group)
+
+    if world_size == 1:
+        return input
+
+    tensor_list = [torch.empty_like(input) for _ in range(world_size)]
+    assert input.device.type == 'cuda'
+    dist.all_gather(tensor_list, input, group=sp_group)
+
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+    return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+    """Gather the input during forward.
+
+    Scale and split the grad and keep only the corresponding chunk to the rank
+    during backward.
+ """ + + @staticmethod + def forward(ctx, input, dim, sp_group, grad_scale): + ctx.dim = dim + ctx.sp_group = sp_group + ctx.grad_scale = grad_scale + return gather_for_sequence_parallel(input, dim, sp_group) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == 'up': + grad_output = grad_output * dist.get_world_size(ctx.sp_group) + elif ctx.grad_scale == 'down': + grad_output = grad_output / dist.get_world_size(ctx.sp_group) + + return (split_for_sequence_parallel(grad_output, ctx.dim, + ctx.sp_group), None, None, None) + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank during + forward. + + Scale and gather the grad during backward. + """ + + @staticmethod + def forward(ctx, input, dim, sp_group, grad_scale): + ctx.dim = dim + ctx.sp_group = sp_group + ctx.grad_scale = grad_scale + return split_for_sequence_parallel(input, dim, sp_group) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == 'up': + grad_output = grad_output * dist.get_world_size(ctx.sp_group) + elif ctx.grad_scale == 'down': + grad_output = grad_output / dist.get_world_size(ctx.sp_group) + return (gather_for_sequence_parallel(grad_output, ctx.dim, + ctx.sp_group), None, None, None) + + +def split_forward_gather_backward(input, dim, sp_group, grad_scale=None): + """Split tensors according to the sp rank during forward propagation and + gather the grad from the whole sp group during backward propagation. + + 1. When do we need this? input.requires_grad = True + + 2. Why we need grad scale? + + We have to scale down the grads as `gather_forward_split_backward` scales + up the grads. + """ + return _SplitForwardGatherBackward.apply(input, dim, sp_group, grad_scale) + + +def gather_forward_split_backward(input, dim, sp_group, grad_scale=None): + """Gather tensors from the whole sp group during forward propagation and + split the grad according to the sp rank during backward propagation. + + 1. When do we need this? + + When sp is greater than 1, we need to slice the input `x` along + sequence length dimension before it is passed into the model and get + `sub_seq_x`. We then pass `sub_seq_x` into model and get output + `sub_seq_out`. If the loss calculation process needs to use the complete + output, we have to gather the `sub_seq_out` in all sp ranks during forward + propagation and split the grad during backward propagation. + + 2. Why we need grad scale? + Here is a simple case. + + -------- SP 1 ----------- + Suppose here is a toy model with only one linear module + (in_features = 2, out_features = 1) and the input x has shape(2, 2). + Y = [[y1], = [[w11x11 + w21x12], = [[x11, x12], dot [[w11], + [y2]] [w11x21 + w21x22]] [x21, x22]] [w21]] + z = mean(Y) = (y1 + y2) / 2 + Here is the partial derivative of z with respect to w11: + ∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 + ∂z / ∂y2 * ∂y2 / ∂w11 + = 1/2 * x11 + 1/2 * x21 = (x11 + x21) / 2 + + -------- SP 2 ----------- + When sequence parallel world size is set to 2, we will split the input x + and scatter them to the two rank in the same sequence parallel group. + ```Step 1 + Y_rank0 = [[y1]] = [[w11x11 + w21x12]] = [[x11, x12]] dot [[w11, w21]]^T + Y_rank1 = [[y2]] = [[w11x21 + w21x22]] = [[x21, x22]] dot [[w11, w21]]^T + ``` + + Then, we have to gather them: + ```Step 2 + Y_rank0 = [[y1], + detach([y2])] + Y_rank1 = [detach([y1]), + [y2]] + ``` + Note that y2 in Y_rank0 does not have grad, neither does y1 in Y_rank1. 
+
+    Similarly, we calculate the loss in each rank:
+    ```Step 3
+    z_rank0 = mean(Y_rank0) = (y1 + detach(y2)) / 2
+    z_rank1 = mean(Y_rank1) = (detach(y1) + y2) / 2
+    ```
+    So the partial derivative of loss_rank0 with respect to w11:
+    ```∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 = x11 / 2```
+    The same for rank1:
+    ```∂z / ∂w11 = ∂z / ∂y2 * ∂y2 / ∂w11 = x21 / 2```
+
+    Finally, we need to all_reduce them:
+    ```Step 4
+    In both ranks:
+    ∂z / ∂w11 = (x11 / 2 + x21 / 2) / 2 = (x11 + x21) / 4
+    ```
+
+    In SP2, the gradient of each param is only half of that in SP1.
+    So we should scale up the grad during the backward process in Step 2.
+    """  # noqa: E501
+    return _GatherForwardSplitBackward.apply(input, dim, sp_group, grad_scale)
diff --git a/xtuner/_lite/parallel/sequence/reduce_loss.py b/xtuner/_lite/parallel/sequence/reduce_loss.py
new file mode 100644
index 000000000..11e204806
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/reduce_loss.py
@@ -0,0 +1,33 @@
+import torch
+import torch.distributed as dist
+
+from ..setup import get_sp_mesh
+
+
+class _ReduceLoss(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, mean_loss, loss_scale, process_group):
+        ctx.mode = process_group
+        if loss_scale == 0:
+            # convert nan to 0 just for logging
+            mean_loss = torch.nan_to_num(mean_loss)
+        loss_sum = mean_loss * loss_scale
+        dist.all_reduce(loss_sum, group=process_group)
+        dist.all_reduce(loss_scale, group=process_group)
+        loss = loss_sum / loss_scale
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None, None
+
+
+def reduce_sequence_parallel_loss(mean_loss, loss_scale, sp_mesh=None):
+    if sp_mesh is None:
+        # avoid bc breaking
+        sp_mesh = get_sp_mesh()
+    if sp_mesh.size() == 1:
+        return mean_loss
+    sp_group = sp_mesh.get_group()
+    return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)
diff --git a/xtuner/_lite/parallel/setup.py b/xtuner/_lite/parallel/setup.py
new file mode 100644
index 000000000..7a02f5d03
--- /dev/null
+++ b/xtuner/_lite/parallel/setup.py
@@ -0,0 +1,110 @@
+import torch.distributed as dist
+from mmengine.dist import infer_launcher, init_dist
+from torch.distributed.device_mesh import init_device_mesh
+
+from xtuner._lite import get_device
+
+_SP_MESH = None
+_DP_MESH = None
+_SAME_DATA_MESH = None
+_TP_MESH = None
+_FSDP_MESH = None
+_WORLD_MESH = None
+
+_EP_MESH = None
+_EXPERTS_FSDP_MESH = None
+
+
+def setup_parallel(sp_size=1, tp_size=1, ep_size=1):
+
+    if not dist.is_initialized():
+        dist_launcher = infer_launcher()
+        init_dist(dist_launcher)
+
+    device = get_device()
+
+    world_size = dist.get_world_size()
+    assert world_size % sp_size == 0
+    assert world_size // sp_size % tp_size == 0
+    assert tp_size <= 8
+
+    dp_size = world_size // sp_size // tp_size
+    data_mesh = init_device_mesh(
+        device, (dp_size, sp_size, tp_size), mesh_dim_names=('dp', 'sp', 'tp'))
+
+    same_data_mesh = init_device_mesh(
+        device, (dp_size, sp_size * tp_size), mesh_dim_names=('dp', 'same_data'))
+
+    model_mesh = init_device_mesh(
+        device, (dp_size * sp_size, tp_size), mesh_dim_names=('fsdp', 'tp'))
+
+    world_mesh = init_device_mesh(
+        device, (world_size, ), mesh_dim_names=('world', ))
+
+    global _DP_MESH, _DP_GROUP, _DP_WORLD_SIZE
+    _DP_MESH = data_mesh['dp']
+    _DP_GROUP = data_mesh['dp'].get_group()
+    _DP_WORLD_SIZE = data_mesh['dp'].size()
+
+    global _SP_MESH, _SP_GROUP, _SP_WORLD_SIZE
+    _SP_MESH = data_mesh['sp']
+    _SP_GROUP = data_mesh['sp'].get_group()
+    _SP_WORLD_SIZE = data_mesh['sp'].size()
+
+    global _TP_MESH, _TP_GROUP, _TP_WORLD_SIZE
+    _TP_MESH 
= model_mesh['tp'] + _TP_GROUP = model_mesh['tp'].get_group() + _TP_WORLD_SIZE = model_mesh['tp'].size() + + global _WORLD_MESH, _FSDP_MESH + _WORLD_MESH = world_mesh['world'] + _FSDP_MESH = model_mesh['fsdp'] + + global _SAME_DATA_MESH + _SAME_DATA_MESH = same_data_mesh['same_data'] + + assert world_size % ep_size == 0 + fsdp_size = world_size // ep_size + + # faster in multi nodes + device_mesh = init_device_mesh( + device, (fsdp_size, ep_size), mesh_dim_names=('fsdp', 'ep')) + # slower in multi nodes + # device_mesh = init_device_mesh('cuda', (ep_size, fsdp_size), + # mesh_dim_names=('ep', 'fsdp')) + + global _EP_MESH + global _EXPERTS_FSDP_MESH + _EP_MESH = device_mesh['ep'] + _EXPERTS_FSDP_MESH = device_mesh['fsdp'] + + +def get_ep_mesh(): + return _EP_MESH + + +def get_experts_fsdp_mesh(): + return _EXPERTS_FSDP_MESH + + +def get_world_mesh(): + return _WORLD_MESH + + +def get_dp_mesh(): + return _DP_MESH + + +def get_fsdp_mesh(): + return _FSDP_MESH + + +def get_sp_mesh(): + return _SP_MESH + + +def get_tp_mesh(): + return _TP_MESH + +def get_same_data_mesh(): + return _SAME_DATA_MESH diff --git a/xtuner/_lite/parallel/utils.py b/xtuner/_lite/parallel/utils.py new file mode 100644 index 000000000..4c9033e4f --- /dev/null +++ b/xtuner/_lite/parallel/utils.py @@ -0,0 +1,16 @@ +from torch.distributed.checkpoint.stateful import Stateful + + +class MetaStateful(Stateful): + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def state_dict(self): + return self.kwargs + + def load_state_dict(self, state_dict) -> None: + self.kwargs = state_dict + + def __getitem__(self, key): + return self.kwargs[key] diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index e81ec7a3a..25444530e 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -97,6 +97,12 @@ DeepseekV2FlashAttention2=LazyObject( 'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_varlen_attn_forward'), + InternLM3FlashCrossAttention2=LazyObject( + 'xtuner.model.modules.dispatch.internlm3', + 'internlm3_cross_attn_varlen_forward'), + InternLM3FlashSelfAttention2=LazyObject( + 'xtuner.model.modules.dispatch.internlm3', + 'internlm3_self_attn_varlen_forward') ) VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict( @@ -181,24 +187,23 @@ def dispatch_varlen_attn_forward(model): return from mmengine import print_log - print_log = log_once(print_log) + # print_log = log_once(print_log) - varlen_attn_forward = None + for module in model.modules(): name = type(module).__name__ if (IS_LOW_VERSION_TRANSFORMERS and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING): - if varlen_attn_forward is None: - varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name] - varlen_attn_forward = varlen_attn_forward.build() + + varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name] + varlen_attn_forward = varlen_attn_forward.build() print_log( f'Dispatch legacy {name} varlen forward. ' f'{NO_ATTN_WEIGHTS_MSG}', 'current') module.forward = types.MethodType(varlen_attn_forward, module) elif name in VARLEN_ATTN_DISPATCH_MAPPING: - if varlen_attn_forward is None: - varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name] - varlen_attn_forward = varlen_attn_forward.build() + varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name] + varlen_attn_forward = varlen_attn_forward.build() print_log(f'Dispatch {name} varlen forward. 
{NO_ATTN_WEIGHTS_MSG}', 'current')
            module.forward = types.MethodType(varlen_attn_forward, module)
diff --git a/xtuner/model/modules/dispatch/internlm3.py b/xtuner/model/modules/dispatch/internlm3.py
new file mode 100644
index 000000000..f9f964a58
--- /dev/null
+++ b/xtuner/model/modules/dispatch/internlm3.py
@@ -0,0 +1,367 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from einops import rearrange
+from mmengine import MessageHub
+from transformers.cache_utils import StaticCache
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+from transformers.utils import logging
+from torch import distributed as dist
+from torch.nn import functional as F
+from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
+                                      post_process_for_sequence_parallel_attn,
+                                      pre_process_for_sequence_parallel_attn)
+from .attention import SUPPORT_FLASH2, flash_attn_wo_mask, varlen_flash_attn
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):  # pylint: disable=unused-argument
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    if k is None:
+        return q_embed, None
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """This is the equivalent of torch.repeat_interleave(x, dim=1,
+    repeats=n_rep).
+ + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) + to (batch, seqlen, num_attention_heads, head_dim)""" + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, + None, :].expand(batch, slen, + num_key_value_heads, n_rep, + head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, + head_dim) + + +def internlm3_self_attn_varlen_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501 + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with ' + '`attn_implementation==flash_attention_2` make sure to use `sdpa` ' + 'in the mean time, and open an issue at ' + 'https://github.com/huggingface/transformers') + + bsz, q_len, _ = hidden_states.size() + + message_hub = MessageHub.get_instance('varlen_attn_args') + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') + max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') + use_varlen_atten = (cumulative_len is not None) + + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
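+    # Layout note: after the rotary embedding and cache update, q/k/v are
+    # (bsz, num_heads, q_len, head_dim); the transposes below restore the
+    # (bsz, q_len, num_heads, head_dim) layout that repeat_kv_bshd and the
+    # flash-attention paths in this function expect.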
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # dropout_rate = self.attention_dropout if self.training else 0.0 + dropout_rate = 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (InternLM3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.wqkv.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + + if use_varlen_atten: + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + causal=True, + dropout_p=dropout_rate, + training=self.training) + elif attention_mask is None: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=self.training) + else: + + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and self.training) + if enable_sequence_parallel: + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) + # self.num_heads is used in self._upad_input method + # num_heads has been changed because of sequence parallel + ori_num_head = self.num_heads + self.num_heads = query_states.shape[-2] + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate) + + if enable_sequence_parallel: + attn_output = post_process_for_sequence_parallel_attn(attn_output) + self.num_heads = ori_num_head + + + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + # Due to the implementation of the PyTorch version of flash attention, + # even when the output_attentions flag is set to True, it is not possible + # to return the attn_weights. 
+    return attn_output, None, past_key_value
+
+
+
+
+def internlm3_cross_attn_varlen_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.LongTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+    output_attentions = False
+
+    bsz, q_len, _ = hidden_states.size()
+
+    message_hub = MessageHub.get_instance('varlen_attn_args')
+    if dist.is_initialized():
+        rank = dist.get_rank()
+    else:
+        rank = 0
+    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
+    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
+    use_varlen_atten = (cumulative_len is not None)
+
+    if self.config.pretraining_tp > 1:
+        # split the query projection by tp size
+        key_value_slicing = self.hidden_size // self.config.pretraining_tp
+        q_slices = self.wq.weight.split(key_value_slicing, dim=0)
+        query_states = torch.cat(
+            [F.linear(hidden_states, q_slice) for q_slice in q_slices], dim=-1  # pylint: disable=E1102
+        )
+    else:
+        query_states = self.wq(hidden_states)
+
+    query_states = rearrange(query_states, "b q (h d) -> b q h d", d=self.head_dim).transpose(1, 2)
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    # Only query_states are rotated in cross-attention
+    query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+    # to be able to avoid many of these transpose/reshape/view.
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    # dropout_rate = self.attention_dropout if self.training else 0.0
+    dropout_rate = 0.0
+
+    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+    # therefore the input hidden states gets silently casted in float32. Hence, we need
+    # cast them back in the correct dtype just to be sure everything works as expected.
+    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+    # in fp32. (InternLM3RMSNorm handles it correctly)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(self.config, "_pre_quantization_dtype"):
+            target_dtype = self.config._pre_quantization_dtype
+        else:
+            target_dtype = self.wq.weight.dtype
+
+        logger.warning_once(
+            f"The input hidden states seems to be silently casted in float32, this might be related to"
+            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+            f" {target_dtype}."
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + dropout_rate = 0.0 + if use_varlen_atten: + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + causal=True, + dropout_p=dropout_rate, + training=self.training) + elif attention_mask is None: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=self.training) + else: + + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and self.training) + if enable_sequence_parallel: + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) + # self.num_heads is used in self._upad_input method + # num_heads has been changed because of sequence parallel + ori_num_head = self.num_heads + self.num_heads = query_states.shape[-2] + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_states.shape[1], + dropout=dropout_rate) + + if enable_sequence_parallel: + attn_output = post_process_for_sequence_parallel_attn(attn_output) + self.num_heads = ori_num_head + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + return attn_output, None \ No newline at end of file diff --git a/xtuner/utils/handle_moe_load_and_save.py b/xtuner/utils/handle_moe_load_and_save.py index 88a3936a8..18764a82d 100644 --- a/xtuner/utils/handle_moe_load_and_save.py +++ b/xtuner/utils/handle_moe_load_and_save.py @@ -3,7 +3,6 @@ import re from collections import OrderedDict -import deepspeed import torch import torch.distributed as dist import torch.nn as nn @@ -145,6 +144,7 @@ def load(module: nn.Module, state_dict, unloaded_shard_files, prefix=''): if len(params_to_gather) > 0: args = (state_dict, prefix, {}, True, [], [], error_msgs) if is_deepspeed_zero3_enabled(): + import deepspeed with deepspeed.zero.GatheredParameters( params_to_gather, modifier_rank=0): if dist.get_rank() == 0:
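
The gradient-scaling argument in the `gather_forward_split_backward` docstring (xtuner/_lite/parallel/sequence/ops.py above) can be sanity-checked on a single process. The snippet below is not part of the patch: it uses toy numbers and emulates the two sequence-parallel ranks with `detach()` instead of real collectives, so it only illustrates why `grad_scale='up'` is needed, not the actual all_gather/all_to_all path.

import torch

w = torch.tensor([[1.5], [-0.5]], requires_grad=True)  # [[w11], [w21]]
x = torch.tensor([[2.0, 3.0], [4.0, 5.0]])             # [[x11, x12], [x21, x22]]

# -------- SP 1: the whole sequence lives on one rank --------
z = (x @ w).mean()
z.backward()
grad_sp1 = w.grad.clone()                # dz/dw11 = (x11 + x21) / 2 = 3.0

# -------- SP 2: each "rank" computes half the sequence --------
w.grad = None
y1 = x[0:1] @ w                          # produced on rank 0
y2 = x[1:2] @ w                          # produced on rank 1
# Step 2 (gather) as seen from rank 0: y2 arrives without autograd history.
z_rank0 = torch.cat([y1, y2.detach()]).mean()
z_rank0.backward()
grad_rank0 = w.grad.clone()              # x11 / 2 = 1.0

w.grad = None
z_rank1 = torch.cat([y1.detach(), x[1:2] @ w]).mean()
z_rank1.backward()
grad_rank1 = w.grad.clone()              # x21 / 2 = 2.0

# Step 4: gradient sync across the two ranks averages their grads.
grad_sp2 = (grad_rank0 + grad_rank1) / 2  # (x11 + x21) / 4 = 1.5
print(grad_sp1[0].item(), grad_sp2[0].item())  # 3.0 vs. 1.5 -> scale up by sp size

Multiplying the backward-path gradient by the sequence-parallel world size (grad_scale='up' in _GatherForwardSplitBackward) restores 1.5 * 2 = 3.0, matching the SP1 result, which is exactly the scaling the docstring derives.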