diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml
index 4b9f51976..9c4080fae 100644
--- a/.pre-commit-config-zh-cn.yaml
+++ b/.pre-commit-config-zh-cn.yaml
@@ -1,4 +1,4 @@
-exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/
+exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/|^xtuner/_lite/modelings/|^xtuner/_lite/accelerate/dispatches/huggingface/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f6bbfd633..245f17c69 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,4 +1,4 @@
-exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/
+exclude: ^tests/data/|^xtuner/tools/model_converters/modeling_internlm2_reward/|^xtuner/_lite/modelings/|^xtuner/_lite/accelerate/dispatches/huggingface/
repos:
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
diff --git a/requirements/lmdeploy.txt b/requirements/lmdeploy.txt
new file mode 100644
index 000000000..25ef3916f
--- /dev/null
+++ b/requirements/lmdeploy.txt
@@ -0,0 +1 @@
+lmdeploy>=0.6.2 --no-deps
\ No newline at end of file
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 3a4d2f84e..a90dae7b7 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,16 +1,11 @@
-# Minimum 0.40.0.post4 to fix some 4-bit precision bugs
-bitsandbytes>=0.40.0.post4
# Minimum 2.16.0 to fix some bugs, see https://github.com/huggingface/datasets/pull/6444
datasets>=2.16.0
-einops
-# Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44
-lagent>=0.1.2
+einops
+# Avoid `import cv2` failures
+opencv-python==4.7.0.72
# Minimum 0.10.3 to support distributed evaluation for MMBench
# see https://github.com/open-mmlab/mmengine/pull/1469
mmengine>=0.10.3
-openpyxl
-# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476
-peft>=0.4.0
scikit-image
scipy
SentencePiece
@@ -23,5 +18,7 @@ torchvision
# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
-# transformers >= 4.43.0 use _flash_attention_forward but not self._flash_attention_forward
-# to calculate attn output which lead to bc braeking
+# transformers >= 4.43.0 uses _flash_attention_forward instead of self._flash_attention_forward
+# to calculate the attn output, which leads to BC breaking
-transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4
+transformers>=4.45
transformers_stream_generator
+loguru
+pydantic
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7a95dfab4..fe4d1b4f1 100644
--- a/setup.py
+++ b/setup.py
@@ -117,10 +117,12 @@ def gen_packages_items():
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
+ 'Programming Language :: Python :: 3.12',
'Topic :: Utilities',
],
- # Python maximum version <3.11, to support mpi4py-mpich
- python_requires='>=3.8, <3.11',
+ # Python maximum version <3.13
+ python_requires='>=3.8, <3.13',
license='Apache License 2.0',
install_requires=parse_requirements('requirements/runtime.txt'),
extras_require={
diff --git a/tools/fsdp_sft.py b/tools/fsdp_sft.py
new file mode 100644
index 000000000..d0d6bd00d
--- /dev/null
+++ b/tools/fsdp_sft.py
@@ -0,0 +1,1036 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import math
+import os
+import sys
+import re
+import time
+import shutil
+import requests
+import gc
+from collections import OrderedDict
+from contextlib import nullcontext
+from concurrent.futures import wait
+from datetime import datetime, timedelta
+from functools import partial
+
+import torch
+import torch.distributed as dist
+from torch.nn import functional as F
+import torch.distributed.checkpoint as dcp
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from accelerate.utils import set_module_tensor_to_device
+from mmengine import mkdir_or_exist
+from mmengine.runner import set_random_seed
+from mmengine.utils import get_git_hash
+from mmengine.utils.dl_utils import collect_env
+
+from peft import LoraConfig, get_peft_model
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+ apply_activation_checkpointing
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+ get_model_state_dict,
+ get_state_dict, set_state_dict)
+from torch.distributed.checkpoint.stateful import Stateful
+
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp.api import CPUOffload, ShardingStrategy
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import _or_policy
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+from xtuner._lite import (AutoTokenizer, get_device, get_logger,
+ get_torch_device_module)
+from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_hf_code, LoadWoInit,
+ packed_sequence, varlen_attn_is_available, profile_time_and_memory)
+from xtuner._lite.algorithms.sft import SftCollator, SftTokenizeFunction
+from xtuner._lite.chat import CHAT_TEMPLATE_MAP
+from xtuner._lite.datasets import (DATASET_CLS_MAP, OPENAI_CONVERT_MAP,
+ SoftPackDataset, HardPackDataset, load_datasets)
+from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
+ get_dp_mesh, get_fsdp_mesh, get_same_data_mesh,
+ get_sp_mesh, get_tp_mesh, get_world_mesh,
+ pad_for_sequence_parallel,
+ reduce_sequence_parallel_loss,
+ setup_parallel, split_for_sequence_parallel)
+from xtuner._lite.parallel.megatron import megatron_parallelize
+from xtuner._lite.parallel.fsdp import clip_grad_norm_
+
+gc.disable()
+logger = get_logger()
+
+DEVICE = get_device()
+DEVICE_MODULE = get_torch_device_module()
+
+SUPPORT_DATA_FORMATS = OPENAI_CONVERT_MAP.keys()
+
+def log_format(rank, debug=False):
+
+ sp_rank = get_sp_mesh().get_local_rank()
+ dp_rank = get_dp_mesh().get_local_rank()
+ tp_rank = get_tp_mesh().get_local_rank()
+ fsdp_rank = get_fsdp_mesh().get_local_rank()
+
+ formatter = f'[XTuner][RANK {rank}][DP {dp_rank}][SP {sp_rank}][TP {tp_rank}]'
+ formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]'
+
+ if debug:
+ formatter += '[{name}:'
+ formatter += '{function}:'
+ formatter += '{line}]'
+
+ formatter += ' {message}'
+ return formatter
+
+def send_to_feishu(web_hook, msg):
+
+ header = {
+ "Content-Type" : "application/json;charset=UTF-8"
+ }
+
+ body = {
+ "msg_type" : "text",
+ "content" : { "text" : f"所有人{msg}"}
+ }
+
+ try:
+ requests.post(url=web_hook, json=body, headers=header, timeout=1)
+ except requests.exceptions.RequestException:
+ pass
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train LLM')
+
+ model_args = parser.add_argument_group('model', 'Model Related Settings')
+ model_args.add_argument('--llm', help='repo id or local path of the model')
+ model_args.add_argument(
+ '-t',
+ '--tokenizer',
+ help=('repo id or local path of the tokenizer. '
+ 'Defaults to the same as `model`'))
+ model_args.add_argument(
+ '--chat-template',
+ choices=CHAT_TEMPLATE_MAP.keys(),
+ help=('the chat template used to format the training data; '
+ 'must be one of the keys of `CHAT_TEMPLATE_MAP`.'))
+ model_args.add_argument(
+ '--use-lora', action='store_true', help='Apply the adapter to LLM.')
+ model_args.add_argument(
+ '--lora-targets',
+ default=None,
+ nargs='*',
+ help='The names of the modules to apply the adapter to. ')
+ model_args.add_argument(
+ '--lora-r', default=64, type=int, help='The rank of the LoRA adapter.')
+ model_args.add_argument(
+ '--lora-alpha',
+ default=16,
+ type=int,
+ help='The alpha parameter for Lora scaling.')
+ model_args.add_argument(
+ '--lora-dropout',
+ default=0.1,
+ type=float,
+ help='The dropout probability for Lora layers.')
+ model_args.add_argument(
+ '--lora-bias',
+ default='none',
+ help="The bias type for LoRA layers: 'none', 'all' or 'lora_only'.")
+ model_args.add_argument(
+ '--dtype',
+ default='auto',
+ choices=['fp16', 'bf16', 'auto'],
+ help=("the dtype of the model forward. When set to 'auto', it will "
+ 'automatically determine whether bf16 is available, '
+ 'prioritizing the use of bf16.'))
+
+ model_args.add_argument(
+ '--selective-recompute',
+ default=1.0,
+ type=float,
+ help=('the ratio of re-computation for transformer layers. '
+ 'The maximum is 1; the larger the value, the less memory '
+ 'required for training. The default is 1, meaning all layers '
+ 'need to be re-computed.'))
+ model_args.add_argument(
+ '--shard-strategy',
+ default='full',
+ choices=['full', 'hybrid'],
+ help=('The sharding strategy to be used for distributed training.'))
+ model_args.add_argument(
+ '--cpu-offload', action='store_true',
+ help='whether to offload the model parameters to CPU.')
+ model_args.add_argument(
+ '--sp-size', type=int, default=1, help='Sequence Parallel Size')
+
+ data_args = parser.add_argument_group('data', 'Dataset Related Settings')
+ data_args.add_argument(
+ '--datasets',
+ nargs='*',
+ help=('repo id or local path or dir of the datasets. For repo ids, '
+ 'the `dset-sources` needs to be appropriately set to '
+ '`modelscope` or `huggingface`. For local dir, all json and '
+ 'jsonl files will be loaded by default. The type of loaded '
+ 'files can be controlled by setting `dset-file-type`'))
+ data_args.add_argument(
+ '--dset-file-types',
+ nargs='*',
+ default=DATASET_CLS_MAP.keys(),
+ choices=DATASET_CLS_MAP.keys(),
+ help='the file type that needs to be loaded')
+ data_args.add_argument(
+ '--dset-sources',
+ nargs='*',
+ default=['local'],
+ choices=['local', 'huggingface', 'modelscope'],
+ help=('the source of each dataset; it can accept one or the same '
+ 'number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets come from the same source. '
+ '`local` represents the local path, `huggingface` represents '
+ 'the open-source data in the Huggingface Hub, `modelscope` '
+ 'indicates the open-source data in the Modelscope Hub.'))
+ data_args.add_argument(
+ '--dset-formats',
+ nargs='*',
+ default=['openai'],
+ help=('the format of each dataset; it can accept one or the same '
+ 'number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets are the same format.'))
+ data_args.add_argument(
+ '--dset-sample-ratios',
+ nargs='*',
+ type=float,
+ default=[1.0],
+ help=('the sample ratio of each dataset; it can accept one or the '
+ 'same number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets use the same sample ratio.'))
+ data_args.add_argument(
+ '--dset-cache-dir',
+ help=('the cache dir of the loaded datasets. When the `datasets` is '
+ 'set, the loaded datasets will be cached to this dir. If the '
+ '`datasets` are not set, the cached dataset in this dir will be '
+ 'loaded.'))
+ data_args.add_argument(
+ '--dset-pack-level',
+ choices=['hard', 'soft'],
+ help=('the level of data packing. When `hard`, multiple data will be '
+ 'packed to `max_length`, potentially causing some data to be '
+ 'truncated, and the length of the packed data will always '
+ 'be `max_length`; When `soft`, it will pack multiple data '
+ 'into nearly `max_length` without truncating the data.'))
+ data_args.add_argument(
+ '--global-pack',
+ action='store_true',
+ help='A subsequence in the packed data comes from different files.')
+ data_args.add_argument(
+ '--max-length',
+ type=int,
+ default=2048,
+ help=('the maximum length of each piece of data, any excess will be '
+ 'truncated.'))
+ data_args.add_argument(
+ '--num-workers',
+ type=int,
+ default=8,
+ help='how many subprocesses to use for data loading.')
+ data_args.add_argument('--file-pattern', type=str, default=None)
+ data_args.add_argument('--group-by-length', action='store_true')
+
+ optim_args = parser.add_argument_group('optim', 'Optim Related Settings')
+ optim_args.add_argument(
+ '--mirco-batch-size',
+ type=int,
+ default=1,
+ help='batch size for each forward + backward pass')
+ optim_args.add_argument(
+ '--global-batch-size',
+ type=int,
+ default=16,
+ help='batch size for each optimizer step')
+
+ optim_args.add_argument(
+ '--lr', default=4e-5, type=float, help='learning rate.')
+ optim_args.add_argument(
+ '--lr-min', default=6e-6, type=float, help='min learning rate.')
+ optim_args.add_argument(
+ '--wd', default=0.01, type=float, help='weight decay.')
+ optim_args.add_argument(
+ '--max-grad-norm', default=1, type=float, help='gradient clipping')
+ optim_args.add_argument(
+ '-e', '--epochs', default=1, type=int, help='total training epochs.')
+ optim_args.add_argument(
+ '--warmup-ratio',
+ default=0.03,
+ type=float,
+ help=('the proportion of training steps for learning rate warm-up in '
+ 'relation to the total training steps.'))
+
+ parser.add_argument('-c', '--config', default=None)
+ parser.add_argument(
+ '--work-dir',
+ default='work_dirs',
+ help='the dir to save logs and checkpoints')
+ parser.add_argument(
+ '--feishu-webhook', default=None, help='Webhook of Feishu Group Chat Bot')
+ parser.add_argument('--gc-interval', default=100, type=int)
+ parser.add_argument(
+ '--checkpoint-interval',
+ default=-1,
+ type=float,
+ help=('how many steps to save a checkpoint; it can be a floating '
+ 'point number less than 1, or an integer greater than or equal '
+ "to 1. When it's a floating point, it will be multiplied by the "
+ 'total number of training steps.'))
+ parser.add_argument(
+ '--checkpoint-max-keep',
+ default=1,
+ type=int,
+ help=('Maximum number of saved checkpoints.'))
+ parser.add_argument(
+ '--checkpoint-drop-optimizer',
+ action='store_true',
+ help=('only model parameters are saved when saving a checkpoint. '
+ 'This can significantly reduce the size of checkpoint files, '
+ 'but the saved checkpoints cannot be resumed.'))
+ parser.add_argument(
+ '--log-interval', default=1, type=int, help='log interval')
+ parser.add_argument(
+ '--resume',
+ action='store_true',
+ help='resume training from the latest checkpoint in `work-dir`.')
+ parser.add_argument(
+ '--seed', type=int, default=0, help='random seed for the training')
+ parser.add_argument(
+ '--debug', action='store_true', help='Set logger level to `DEBUG`')
+ args = parser.parse_args()
+ return args
+
+
+def is_interval(step, total_steps, interval):
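+ # True every `interval` steps (counting steps from 1) and always on the
+ # final step.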
+ return (step + 1) % interval == 0 or (step + 1) == total_steps
+
+
+def map_meta_modules(model, meta_model):
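+ # Build a mapping from every module of the meta-device model to the
+ # module with the same name in the materialized model, so that real
+ # weights can later be copied into the meta modules.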
+ modules = {name: mod for name, mod in model.named_modules()}
+ meta_module_map = {
+ mod: modules[name]
+ for name, mod in meta_model.named_modules()
+ }
+ return meta_module_map
+
+
+def build_llm_model(args, config, world_size, dtype=torch.float32):
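+ # `LoadWoInit` skips the default random weight initialization, so
+ # `from_pretrained` only needs to copy the pretrained weights.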
+ with LoadWoInit():
+ llm = AutoModelForCausalLM.from_pretrained(
+ args.llm, config=config, attn_implementation='flash_attention_2',
+ trust_remote_code=True)
+
+ # Ensure all numerical values in the optimizer are fp32.
+ # FSDP will use low precision during forward.
+ llm.to(dtype)
+
+ if args.use_lora:
+ llm.requires_grad_(False)
+ if world_size > 1:
+ llm.to(dtype)
+
+ if args.lora_targets is None:
+ llm_cls = llm.__class__.__name__
+ args.lora_targets = LORA_TARGET_MAP[llm_cls]
+ llm_lora_cfg = LoraConfig(
+ target_modules=args.lora_targets,
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ task_type='CAUSAL_LM')
+ llm = get_peft_model(llm, llm_lora_cfg)
+
+ return llm
+
+
+class TrainState(Stateful):
+
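+ # Implements `Stateful` so that `torch.distributed.checkpoint` can save
+ # and restore the training progress (current step, nan-skip count)
+ # together with the model and optimizer states.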
+ def __init__(self, total_steps, seed):
+ super().__init__()
+
+ self.seed = seed
+ self.cur_step = -1
+ self.total_steps = total_steps
+ self.if_nan_skip_steps = 0
+
+ def load_state_dict(self, state_dict):
+ assert self.total_steps == state_dict['total_steps']
+ self.cur_step = state_dict['current_step']
+ self.if_nan_skip_steps = state_dict['if_nan_skip_steps']
+
+ def state_dict(self):
+ return {
+ 'seed': self.seed, 'current_step': self.cur_step,
+ 'total_steps': self.total_steps,
+ 'if_nan_skip_steps': self.if_nan_skip_steps
+ }
+
+ def step(self):
+ self.cur_step = self.cur_step + 1
+
+ def found_nan(self):
+ self.if_nan_skip_steps += 1
+
+
+def find_latest_timestamp(work_dir):
+ # Initialize variables to keep track of the latest timestamp and its corresponding directory
+ latest_timestamp = None
+
+ # Iterate over all files and directories in the specified directory
+ for entry in os.listdir(work_dir):
+ full_path = os.path.join(work_dir, entry)
+
+ # Check if the entry is a directory
+ if os.path.isdir(full_path):
+ try:
+ # Try to interpret the directory name as a timestamp
+ timestamp = datetime.strptime(entry, '%Y%m%d%H%M%S')
+
+ # Update the latest timestamp and directory if this one is more recent
+ if latest_timestamp is None or timestamp > latest_timestamp:
+ latest_timestamp = timestamp
+ except ValueError:
+ # If conversion fails, skip this entry
+ continue
+
+ if latest_timestamp is not None:
+ latest_timestamp = latest_timestamp.strftime('%Y%m%d%H%M%S')
+
+ return latest_timestamp
+
+
+def find_checkpoints(directory, prefix='ckpt'):
+
+ if prefix == 'ckpt':
+ pattern = r'^ckpt-(\d+)$'
+ elif prefix == 'hf':
+ pattern = r'^hf-(\d+)$'
+ else:
+ raise ValueError
+
+ all_folders = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
+
+ checkpoints = []
+ for folder in all_folders:
+ match = re.match(pattern, folder)
+ if match:
+ # store the folder name and the matched step number (as an int) as a tuple
+ checkpoints.append((folder, int(match.group(1))))
+
+ checkpoints.sort(key=lambda x: x[1])
+
+ return [os.path.join(directory, folder[0]) for folder in checkpoints]
+
+
+# @logger.catch
+def sft(args):
+ ###########################################################################
+ # 1. Environment #
+ ###########################################################################
+ setup_parallel(sp_size=args.sp_size, tp_size=1)
+ set_random_seed(args.seed)
+
+ dp_mesh = get_dp_mesh()
+ tp_mesh = get_tp_mesh()
+ sp_mesh = get_sp_mesh()
+ fsdp_mesh = get_fsdp_mesh() # dp_size * sp_size
+ world_mesh = get_world_mesh() # dp_size * sp_size * tp_size
+
+ dp_size = dp_mesh.size()
+ tp_size = tp_mesh.size()
+ sp_size = sp_mesh.size()
+ world_size = world_mesh.size()
+
+ cpu_comm_timeout = timedelta(minutes=60)
+ gloo_group = dist.new_group(backend='gloo', timeout=cpu_comm_timeout)
+
+ if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ 'should be divisible by the '
+ f'dp_size({dp_size}).')
+
+ if (args.global_batch_size / dp_size) % args.mirco_batch_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the dp_size({dp_size})'
+ f' * `mirco_batch_size`({args.mirco_batch_size})')
+
+ rank = dist.get_rank()
+
+ if args.resume:
+ mkdir_or_exist(args.work_dir)
+ timestamp = find_latest_timestamp(args.work_dir)
+
+ if timestamp is None:
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+ else:
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+
+ objects = [timestamp]
+ dist.broadcast_object_list(objects, src=0)
+ timestamp = objects[0]
+
+ args.work_dir = os.path.join(args.work_dir, timestamp)
+ mkdir_or_exist(args.work_dir)
+
+ log_file = os.path.join(args.work_dir, f'rank{rank}.log')
+
+ logger.remove()
+ # Change the log format printed in the terminal
+ lvl = 'DEBUG' if args.debug else 'INFO'
+ logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
+ # Change the format saved in the log file
+ logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)
+
+ if args.feishu_webhook and rank == 0:
+ def log_handler(record):
+ if record['level'].name == "WARNING":
+ send_to_feishu(args.feishu_webhook, f"[WARNING] {record['message']}\n{args.work_dir}")
+ elif record['level'].name == "TRACE":
+ send_to_feishu(args.feishu_webhook, f"[TRACE] {record['message']}\n{args.work_dir}")
+ elif record['level'].name == "ERROR":
+ send_to_feishu(args.feishu_webhook, f"[ERROR] 任务失败\n{args.work_dir}")
+
+ logger.add(sys.stderr, level='TRACE', filter=log_handler, catch=True)
+
+ logger.trace('Task started')
+
+ logger.info(args)
+ if rank == 0:
+ env = collect_env()
+ import transformers
+
+ import xtuner
+ env['Transformers'] = transformers.__version__
+ env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
+ runtime_env = OrderedDict()
+ runtime_env.update(env)
+ runtime_env['Seed'] = args.seed
+ runtime_env['World Size'] = world_size
+ runtime_env['DP Size'] = dp_size
+ runtime_env['SP Size'] = sp_size
+ runtime_env['TP Size'] = tp_size
+ # runtime_env['Distributed launcher'] = dist_launcher
+
+ runtime_env_info = '\n ' + '\n '.join(
+ f'{k}: {v}' for k, v in runtime_env.items())
+ dash_line = '-' * 60
+ logger.info('\n' + dash_line + '\nRuntime environment:' +
+ runtime_env_info + '\n' + dash_line + '\n')
+ # ------------------- Environment End ------------------------------ #
+
+ ###########################################################################
+ # 2. Dataset & Dataloader #
+ ###########################################################################
+
+ start_load_data_t = time.time()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer if args.tokenizer else args.llm,
+ trust_remote_code=True,
+ use_fast=False,
+ padding_side='right')
+
+ if args.chat_template:
+ chat_template = CHAT_TEMPLATE_MAP[args.chat_template]
+ else:
+ chat_template = None
+
+ tokenize_fns = []
+ for dset_format in args.dset_formats:
+ # If your data format is not in `SUPPORT_DATA_FORMATS`, you should
+ # redefine a `tokenize_fn`, defining how to convert a piece of raw
+ # data into tokenized data.
+ # The tokenized data must include `input_ids`, `labels`,
+ # and `num_tokens`.
+ tokenize_fn = SftTokenizeFunction(tokenizer, chat_template,
+ dset_format)
+ tokenize_fns.append(tokenize_fn)
+
+ _datasets = load_datasets(
+ paths=args.datasets,
+ cache_dir=args.dset_cache_dir,
+ file_types=args.dset_file_types,
+ sources=args.dset_sources,
+ sample_ratios=args.dset_sample_ratios,
+ map_fns=tokenize_fns,
+ file_pattern=args.file_pattern,
+ max_length=args.max_length
+ )
+
+ if args.dset_pack_level and rank == 0 and args.debug:
+ # Only the tokenized datasets can count the number of tokens
+ num_tokens = sum(dset.num_tokens.sum() for dset in _datasets)
+ logger.debug(f'[Dataset] {num_tokens} tokens.')
+
+ if args.dset_pack_level == 'soft':
+ train_dataset = SoftPackDataset(_datasets, target=args.max_length, blend=args.global_pack)
+ elif args.dset_pack_level == 'hard':
+ raise NotImplementedError
+ else:
+ train_dataset = ConcatDataset(_datasets)
+
+ if args.dset_pack_level and rank == 0:
+ ori_samples = sum([len(dset) for dset in _datasets])
+ packed_samples = len(train_dataset)
+ logger.info(f'[Dataset] (Original) {ori_samples} samples.')
+ logger.info(f'[Dataset] (Packed) {packed_samples} samples.')
+
+ assert varlen_attn_is_available()
+ collator = SftCollator(
+ pack_batch=varlen_attn_is_available(),
+ max_length=args.max_length)
+
+ if args.group_by_length:
+ sampler = LengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size)
+ else:
+ sampler = ParallelSampler(
+ train_dataset, dp_mesh, args.global_batch_size, shuffle=True)
+
+ gc.collect()
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.mirco_batch_size,
+ num_workers=args.num_workers,
+ # If you replace the sampler with a custom one, make sure it rounds up
+ # or drops the last batch according to the `global_batch_size`.
+ sampler=sampler,
+ collate_fn=collator,
+ persistent_workers=args.num_workers > 0)
+
+ if rank == 0:
+ logger.info(f'[Dataloader] {len(train_dataloader)} batches.')
+ _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)]
+ _first_batch = collator(_first_batch)
+ _decoded = tokenizer.batch_decode(_first_batch['input_ids'])
+ logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}')
+ logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}')
+ dist.barrier()
+
+ gc.collect()
+ load_data_cost_time = time.time() - start_load_data_t
+ logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')
+ # ------------------- Dataset & Dataloader End --------------------- #
+
+ ###########################################################################
+ # 3. FSDP #
+ ###########################################################################
+
+ start_model_t = time.time()
+
+ if args.dtype == 'auto':
+ args.dtype = 'bf16' if DEVICE_MODULE.is_bf16_supported() else 'fp16'
+
+ if args.dtype == 'fp16':
+ dtype = torch.float16
+ autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype)
+ scaler = ShardedGradScaler()
+ elif args.dtype == 'bf16':
+ if DEVICE_MODULE.is_bf16_supported():
+ dtype = torch.bfloat16
+ autocast = torch.amp.autocast(DEVICE, enabled=True, dtype=dtype)
+ scaler = None
+ else:
+ raise RuntimeError('The device does not support `bf16`, '
+ 'please set `dtype` to `fp16`.')
+ else:
+ raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
+ f'but found {args.dtype}.')
+
+ llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True)
+ if is_flash_attn_2_available():
+ llm_cfg.attn_implementation = 'flash_attention_2'
+
+ llm_cfg.use_cache = False
+ llm_cfg.torch_dtype = dtype
+
+ # Only load parameters on rank 0 to avoid each rank repeatedly loading the
+ # same model into the CPU, wasting memory
+ if rank == 0:
+ with torch.device('cpu'):
+ rank0_llm = build_llm_model(args, llm_cfg, world_size, dtype)
+ else:
+ rank0_llm = None
+
+ dist.monitored_barrier(group=gloo_group, timeout=cpu_comm_timeout)
+
+ with torch.device('meta'):
+ # Ensure all numerical values in the optimizer are fp32.
+ # FSDP will use low precision during forward.
+ llm = build_llm_model(args, llm_cfg, world_size, dtype)
+ dispatch_hf_code(llm)
+ for module in llm.modules():
+ for p_name, param in module.named_parameters(recurse=False):
+ if param.requires_grad:
+ param_fp32 = torch.nn.Parameter(
+ param.to(dtype=torch.float32))
+ setattr(module, p_name, param_fp32)
+
+ mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype)
+
+ with profile_time_and_memory('[Parallelize LLM]'):
+ megatron_parallelize(
+ llm,
+ rank0_llm,
+ dp_mesh=fsdp_mesh,
+ tp_mesh=tp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=args.selective_recompute,
+ reshard_after_forward=True)
+
+ llm.train()
+
+ dist.barrier()
+ gc.collect()
+ # -------------------------- FSDP End ------------------------------ #
+
+ ###########################################################################
+ # 4. Optimizer & Scheduler #
+ ###########################################################################
+ requried_grad_params = [
+ param for param in llm.parameters() if param.requires_grad
+ ]
+ optimizer = AdamW(
+ requried_grad_params,
+ lr=args.lr,
+ weight_decay=args.wd,
+ betas=(0.9, 0.95))
+
+ global_batch_size = args.global_batch_size
+ mirco_batch_size = args.mirco_batch_size
+
+ # `iter` means once forward+backward
+ # `step` means once optimizer step
+ # `iters_per_step` means gradient accumulative counts
+ iters_per_step = global_batch_size // mirco_batch_size // dp_size
+ iters_per_epoch = len(train_dataloader)
+ steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step)
+
+ total_epochs = args.epochs
+ total_steps = steps_per_epoch * total_epochs
+ train_state = TrainState(total_steps, args.seed)
+
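+ # `checkpoint_interval == -1` saves only once, at the end of training;
+ # a float < 1 is interpreted as a fraction of the total steps; an
+ # integer >= 1 is used as a step interval directly.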
+ if args.checkpoint_interval == -1:
+ checkpoint_interval = total_steps
+ elif args.checkpoint_interval < 1:
+ checkpoint_interval = int(total_steps * args.checkpoint_interval)
+ else:
+ checkpoint_interval = int(args.checkpoint_interval)
+
+ warmup_steps = int(args.warmup_ratio * total_steps)
+
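+ # Linear warmup: the lr factor ramps from 0 to 1 over `warmup_steps`;
+ # afterwards the cosine schedule below anneals the lr towards `lr_min`.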
+ def warmup_fn(x):
+ return x / warmup_steps if x < warmup_steps else 1
+
+ warmup_scheduler = LambdaLR(optimizer, warmup_fn)
+
+ cosine_scheduler = CosineAnnealingLR(
+ optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)
+
+ start_step = 0
+ gc.collect()
+ # ---------------- Optimizer & Scheduler End ----------------------- #
+
+ ###########################################################################
+ # 5. (Optional) Resume #
+ ###########################################################################
+ if args.resume:
+ _checkpoints = find_checkpoints(args.work_dir)
+
+ latest_checkpoint = None
+
+ for _ckpt_dir in reversed(_checkpoints):
+ if os.path.exists(os.path.join(_ckpt_dir, '.metadata')):
+ latest_checkpoint = _ckpt_dir
+ break
+
+ if latest_checkpoint:
+
+ with profile_time_and_memory('[Resume]'):
+ _options = StateDictOptions(
+ cpu_offload=True, ignore_frozen_params=True)
+ (shard_model_state_dict,
+ shard_optimizer_state_dict) = get_state_dict(
+ llm, optimizer, options=_options)
+ state_dict = {
+ 'model': shard_model_state_dict,
+ 'optimizer': shard_optimizer_state_dict,
+ 'train_state': train_state,
+ 'warmup_scheduler': warmup_scheduler,
+ 'cosine_scheduler': cosine_scheduler
+ }
+
+ # `dcp.load` restores the checkpoint into `state_dict` in place
+ dcp.load(
+ state_dict=state_dict,
+ checkpoint_id=latest_checkpoint,
+ )
+
+ _options = StateDictOptions(
+ cpu_offload=True, strict=False)
+ set_state_dict(
+ llm,
+ optimizer,
+ model_state_dict=state_dict["model"],
+ optim_state_dict=state_dict["optimizer"],
+ options=_options
+ )
+
+ start_step = train_state.cur_step + 1
+
+ else:
+ logger.warning(f'There is no checkpoint available for resuming training in {args.work_dir}.')
+
+ ###########################################################################
+ # 6. Training #
+ ###########################################################################
+ ckpt_handle = None
+ start_train_t = time.time()
+ DEVICE_MODULE.empty_cache()
+ DEVICE_MODULE.reset_peak_memory_stats()
+ max_memory = DEVICE_MODULE.max_memory_allocated()
+ logger.info('[Train] Begin Train Loop. The current GPU memory is '
+ f'{(max_memory / 1024**3):.1f}GB')
+
+ for step in range(start_step, total_steps):
+
+ if is_interval(step + 1, total_steps, args.gc_interval):
+ gc.collect()
+
+ epoch = step // steps_per_epoch
+ epoch_inner_step = step % steps_per_epoch
+ if epoch_inner_step == 0 or step == start_step:
+ # For the first step of each epoch, the data order needs to be
+ # readjusted.
+ # Or after resuming, for the first step, the dataloader needs to
+ # be adjusted to the position before resume.
+ train_dataloader.sampler.set_epoch(epoch, epoch_inner_step * iters_per_step)
+ data_iterator = iter(train_dataloader)
+
+ train_state.step()
+
+ if step <= warmup_steps:
+ warmup_scheduler.step()
+ cur_lr = warmup_scheduler.get_last_lr()[0]
+ else:
+ cosine_scheduler.step()
+ cur_lr = cosine_scheduler.get_last_lr()[0]
+
+ DEVICE_MODULE.reset_peak_memory_stats()
+
+ step_loss = 0
+ step_data_time = 0
+ step_start_t = time.time()
+ step_consumed_tokens = 0
+
+ _data_start_t = time.time()
+
+ step_data_list = [next(data_iterator) for _ in range(iters_per_step)]
+ rank_grad_tokens = 0
+ for _iter in range(iters_per_step):
+ _iter_data = step_data_list[_iter]
+ _iter_labels = _iter_data['labels'][:, 1:]
+ rank_grad_tokens += (_iter_labels >= 0).sum()
+ rank_grad_tokens = rank_grad_tokens.to(DEVICE)
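+ # Sum the number of supervised tokens (labels >= 0) over all ranks.
+ # Ranks in the same sequence/tensor parallel group see the same data,
+ # so dividing by `sp_size * tp_size` yields the global number of tokens
+ # that contribute to the loss in this optimizer step.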
+ dist.all_reduce(rank_grad_tokens)
+ global_grad_tokens = rank_grad_tokens / sp_size / tp_size
+
+
+ step_data_time = time.time() - _data_start_t
+
+ for _iter in range(iters_per_step):
+
+ data = step_data_list[_iter]
+ input_ids = data['input_ids'][:, :-1].to(DEVICE)
+
+ labels = data['labels'][:, 1:].to(DEVICE)
+ num_tokens = data['num_tokens'].to(DEVICE)
+
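+ # `input_ids` and `labels` were shifted by one token above, so the last
+ # packed subsequence is one token shorter; if it contained only a single
+ # token, it disappears entirely and its entry is dropped.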
+ if num_tokens[-1] == 1:
+ num_tokens = num_tokens[:-1]
+ else:
+ num_tokens[-1] = num_tokens[-1] - 1
+
+ if sp_size > 1:
+ # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
+ input_ids = pad_for_sequence_parallel(input_ids, 0, sp_mesh, dim=1)
+ _num_pad = input_ids.numel() - num_tokens.sum()
+ if _num_pad > 0:
+ _num_pad = torch.IntTensor([_num_pad]).to(DEVICE)
+ num_tokens = torch.cat([num_tokens, _num_pad], dim=-1)
+
+ input_ids = split_for_sequence_parallel(
+ input_ids, dim=1, sp_mesh=sp_mesh)
+
+ labels = pad_for_sequence_parallel(labels, -100, sp_mesh, dim=1)
+ labels = split_for_sequence_parallel(
+ labels, dim=1, sp_mesh=sp_mesh)
+
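+ # `packed_sequence` switches attention to the variable-length (packed)
+ # code path: the samples concatenated into one row only attend within
+ # their own subsequence, whose boundaries are given by `num_tokens`.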
+ packed_ctx = packed_sequence(num_tokens, sp_mesh=sp_mesh)
+
+ with packed_ctx, autocast if args.use_lora else nullcontext():
+ loss = llm(input_ids=input_ids, labels=labels, label_shifted=True, use_cache=False).loss
+
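+ # The model returns the token-mean loss of this micro batch. Rescale it
+ # by the micro batch's supervised tokens over the global token count so
+ # that every token is weighted equally across micro batches and ranks;
+ # the `dp_size` factor offsets the gradient averaging performed across
+ # data-parallel ranks during the FSDP reduce.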
+ loss = loss * (labels >= 0).sum() / global_grad_tokens * dp_size
+
+ if scaler and args.use_lora:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ step_loss += loss.item()
+ step_consumed_tokens += num_tokens.sum() / sp_size / tp_size
+
+ step_reduced_loss = torch.Tensor([step_loss]).to(DEVICE)
+ dist.all_reduce(step_reduced_loss)
+ step_reduced_loss = step_reduced_loss.item() / world_size
+
+ grad_norm = clip_grad_norm_(
+ requried_grad_params, fsdp_mesh, args.max_grad_norm)
+
+ if grad_norm.isnan() or grad_norm.isinf():
+ train_state.found_nan()
+ logger.warning(f"[Step {step}] The grad norm is NaN or Inf, skip this step. Skipped {train_state.if_nan_skip_steps} steps in total.")
+ optimizer.zero_grad()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ step_time = time.time() - step_start_t
+ eta = step_time * (total_steps - step)
+ eta = timedelta(seconds=int(eta))
+ tgs = int(step_consumed_tokens / step_time)
+ max_memory = DEVICE_MODULE.max_memory_allocated()
+ if is_interval(step, total_steps, args.log_interval):
+ logger.info(f'[Train] (Epoch {epoch + 1}) Step '
+ f'{step + 1}/{total_steps} '
+ f'lr: {cur_lr:.6f} loss: {step_loss:.3f} '
+ f'loss(reduced): {step_reduced_loss:.3f} '
+ f'grad_norm: {grad_norm:.2f} '
+ f'if_nan_skip: {train_state.if_nan_skip_steps} '
+ f'max_memory: {(max_memory / 1024**3):.1f}GB '
+ f'text_tokens: {step_consumed_tokens} '
+ f'tgs: {tgs} data_time: {step_data_time:.2f}s '
+ f'time: {step_time:.2f}s '
+ f'eta: {eta}')
+
+ if is_interval(step, total_steps, max(1, int(total_steps * 0.1))):
+ logger.trace(f'Step {step}/{total_steps}, loss {step_loss:.3f}, tgs {tgs}')
+
+ if is_interval(step, total_steps, checkpoint_interval):
+
+ num_digits = len(str(abs(total_steps)))
+ work_dir = args.work_dir
+ ckpt_dir = os.path.join(work_dir, f'ckpt-{step+1:0{num_digits}}')
+ hf_dir = os.path.join(work_dir, f'hf-{step+1:0{num_digits}}')
+
+ with profile_time_and_memory('[HF Checkpoint]'):
+
+ from torch.distributed._tensor import DTensor
+
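+ # `full_tensor()` is a collective that gathers each sharded parameter
+ # (DTensor) into a full tensor on every rank; only rank 0 keeps the
+ # result to assemble a plain HuggingFace-style state dict.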
+ if rank == 0:
+ llm_state_dict = {}
+
+ for name, param in llm.state_dict().items():
+ if isinstance(param, DTensor):
+ with torch.no_grad():
+ full_param = param.full_tensor().cpu()
+ else:
+ full_param = param.cpu()
+
+ if rank == 0:
+ llm_state_dict[name] = full_param
+
+ if rank == 0:
+ rank0_llm.load_state_dict(llm_state_dict)
+ rank0_llm.save_pretrained(hf_dir)
+ tokenizer.save_pretrained(hf_dir)
+
+ dist.barrier()
+
+ saved_hf_checkpoints = find_checkpoints(args.work_dir, prefix='hf')
+
+ if len(saved_hf_checkpoints) > args.checkpoint_max_keep:
+ for _ckpt in saved_hf_checkpoints[:-args.checkpoint_max_keep]:
+ if rank == 0:
+ shutil.rmtree(_ckpt)
+ logger.info('[HF Checkpoint] Delete the oldest checkpoint.')
+
+
+ if args.checkpoint_drop_optimizer:
+ logger.warning('The saved checkpoint cannot be resumed. '
+ 'If you want to save a resumable checkpoint, '
+ 'please remove `--checkpoint-drop-optimizer` '
+ 'from the command.')
+ else:
+
+ with profile_time_and_memory('[PT Checkpoint]'):
+ if ckpt_handle is not None:
+ wait([ckpt_handle])
+
+ # FSDP cannot be saved via torch.save
+ # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501
+ _options = StateDictOptions(
+ cpu_offload=True, ignore_frozen_params=True)
+ (shard_model_state_dict,
+ shard_optimizer_state_dict) = get_state_dict(
+ llm, optimizer, options=_options)
+
+ state_dict = {
+ 'model': shard_model_state_dict,
+ 'optimizer': shard_optimizer_state_dict,
+ 'train_state': train_state.state_dict(),
+ 'warmup_scheduler': warmup_scheduler.state_dict(),
+ 'cosine_scheduler': cosine_scheduler.state_dict()
+ }
+
+ mkdir_or_exist(ckpt_dir)
+ ckpt_handle = dcp.async_save(state_dict, checkpoint_id=ckpt_dir, process_group=gloo_group)
+
+ saved_checkpoints = find_checkpoints(args.work_dir)
+
+ if len(saved_checkpoints) > args.checkpoint_max_keep:
+ for _ckpt in saved_checkpoints[:-args.checkpoint_max_keep]:
+ if rank == 0:
+ shutil.rmtree(_ckpt)
+ logger.info('[PT Checkpoint] Delete the oldest checkpoint.')
+
+ if ckpt_handle is not None:
+ wait([ckpt_handle])
+
+ logger.trace('Task Finished')
+
+ train_cost_time = time.time() - start_train_t
+ logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}')
+ # ------------------------ Training End ---------------------------- #
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ sft(args)
diff --git a/tools/fsdp_tp_sft.py b/tools/fsdp_tp_sft.py
new file mode 100644
index 000000000..321829496
--- /dev/null
+++ b/tools/fsdp_tp_sft.py
@@ -0,0 +1,786 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import math
+import os
+import sys
+import time
+from collections import OrderedDict
+from datetime import datetime, timedelta
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+from accelerate.utils import set_module_tensor_to_device
+from mmengine import mkdir_or_exist
+from mmengine.dist import infer_launcher, init_dist
+from mmengine.runner import set_random_seed
+from mmengine.utils import get_git_hash
+from mmengine.utils.dl_utils import collect_env
+from torch.distributed._tensor import Replicate
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+ apply_activation_checkpointing
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+ get_model_state_dict,
+ get_state_dict)
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp.api import CPUOffload, ShardingStrategy
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module)
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.utils.import_utils import (is_flash_attn_2_available,
+ is_torch_sdpa_available)
+
+from xtuner._lite import AutoTokenizer, get_logger
+from xtuner._lite.accelerate import dispatch_modules, packed_sequence
+from xtuner._lite.chat import CHAT_TEMPLATE_MAP
+from xtuner._lite.datasets import (OPENAI_FORMAT_MAP, SoftPackerForText,
+ TextCollator, TextOnlineTokenizeDataset,
+ TextTokenizedDataset, TextTokenizeFunction)
+from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets,
+ load_from_cache)
+from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
+ get_dp_mesh, get_tp_mesh, setup_parallel)
+from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
+ checkpoint_check_fn, dp_tp_lazy_init,
+ layer_auto_wrap_policy)
+
+logger = get_logger()
+
+SUPPORT_DATA_FORMATS = OPENAI_FORMAT_MAP.keys()
+
+
+def log_format(rank, debug=False):
+
+ formatter = f'[XTuner][RANK {rank}]'
+ formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]'
+
+ if debug:
+ formatter += '[{name}:'
+ formatter += '{function}:'
+ formatter += '{line}]'
+
+ formatter += ' {message}'
+ return formatter
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train LLM')
+
+ model_args = parser.add_argument_group('model', 'Model Related Settings')
+ model_args.add_argument('--llm', help='repo id or local path of the model')
+ model_args.add_argument(
+ '-t',
+ '--tokenizer',
+ help=('repo id or local path of the tokenizer. '
+ 'Defaults to the same as `model`'))
+ model_args.add_argument(
+ '--chat-template',
+ choices=CHAT_TEMPLATE_MAP.keys(),
+ help=('the chat template used to format the training data; '
+ 'must be one of the keys of `CHAT_TEMPLATE_MAP`.'))
+ model_args.add_argument(
+ '--dtype',
+ default='auto',
+ choices=['fp16', 'bf16', 'auto'],
+ help=("the dtype of the model forward. When set to 'auto', it will "
+ 'automatically determine whether bf16 is available, '
+ 'prioritizing the use of bf16.'))
+
+ model_args.add_argument(
+ '--selective-recompute',
+ default=1.0,
+ type=float,
+ help=('the ratio of re-computation for transformer layers. '
+ 'The maximum is 1; the larger the value, the less memory '
+ 'required for training. The default is 1, meaning all layers '
+ 'need to be re-computed.'))
+ model_args.add_argument(
+ '--shard-strategy',
+ default='full',
+ choices=['full', 'hybrid'],
+ help=('The sharding strategy to be used for distributed training.'))
+ model_args.add_argument(
+ '--cpu-offload', action='store_true',
+ help='whether to offload the model parameters to CPU.')
+ model_args.add_argument(
+ '--tp-size', type=int, default=1, help='Tensor Parallel Size')
+
+ data_args = parser.add_argument_group('data', 'Dataset Related Settings')
+ data_args.add_argument(
+ '--datasets',
+ nargs='*',
+ help=('repo id or local path or dir of the datasets. For repo ids, '
+ 'the `dset-sources` needs to be appropriately set to '
+ '`modelscope` or `huggingface`. For local dir, all json and '
+ 'jsonl files will be loaded by default. The type of loaded '
+ 'files can be controlled by setting `dset-file-type`'))
+ data_args.add_argument(
+ '--dset-file-types',
+ nargs='*',
+ default=LOAD_FN_MAP.keys(),
+ choices=LOAD_FN_MAP.keys(),
+ help='the file type that needs to be loaded')
+ data_args.add_argument(
+ '--dset-sources',
+ nargs='*',
+ default=['local'],
+ choices=['local', 'huggingface', 'modelscope'],
+ help=('the source of each dataset; it can accept one or the same '
+ 'number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets come from the same source. '
+ '`local` represents the local path, `huggingface` represents '
+ 'the open-source data in the Huggingface Hub, `modelscope` '
+ 'indicates the open-source data in the Modelscope Hub.'))
+ data_args.add_argument(
+ '--dset-formats',
+ nargs='*',
+ default=['openai'],
+ help=('the format of each dataset; it can accept one or the same '
+ 'number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets are the same format.'))
+ data_args.add_argument(
+ '--dset-sample-ratios',
+ nargs='*',
+ default=[1.0],
+ help=('the sample ratio of each dataset; it can accept one or the '
+ 'same number of args as the number of `datasets`, with one arg '
+ 'indicating that all datasets use the same sample ratio.'))
+ data_args.add_argument(
+ '--dset-cache-dir',
+ help=('the cache dir of the loaded datasets. When the `datasets` is '
+ 'set, the loaded datasets will be cached to this dir. If the '
+ '`datasets` are not set, the cached dataset in this dir will be '
+ 'loaded.'))
+ data_args.add_argument(
+ '--dset-from-cache',
+ action='store_true',
+ help=('Load data directly from `dset-cache-dir`. This can save time '
+ 'on online tokenization, but if the tokenizer changed, '
+ 'recaching is needed.'))
+ data_args.add_argument(
+ '--dset-pack-level',
+ choices=['hard', 'soft'],
+ help=('the level of data packing. When `hard`, multiple data will be '
+ 'packed to `max_length`, potentially causing some data to be '
+ 'truncated, and the length of the packed data will always '
+ 'be `max_length`; When `soft`, it will pack multiple data '
+ 'into nearly `max_length` without truncating the data.'))
+ data_args.add_argument(
+ '--max-length',
+ type=int,
+ default=2048,
+ help=('the maximum length of each piece of data, any excess will be '
+ 'truncated.'))
+ data_args.add_argument(
+ '--num-workers',
+ type=int,
+ default=0,
+ help='how many subprocesses to use for data loading.')
+ data_args.add_argument(
+ '--num-proc',
+ type=int,
+ default=8,
+ help='how many subprocesses to use for data mapping.')
+ data_args.add_argument('--file-pattern', type=str, default=None)
+ data_args.add_argument('--group-by-length', action='store_true')
+
+ optim_args = parser.add_argument_group('optim', 'Optim Related Settings')
+ optim_args.add_argument(
+ '--mirco-batch-size',
+ type=int,
+ default=1,
+ help='batch size for each forward + backward pass')
+ optim_args.add_argument(
+ '--global-batch-size',
+ type=int,
+ default=16,
+ help='batch size for each optimizer step')
+
+ optim_args.add_argument(
+ '--lr', default=4e-5, type=float, help='learning rate.')
+ optim_args.add_argument(
+ '--lr-min', default=6e-6, type=float, help='min learning rate.')
+ optim_args.add_argument(
+ '--wd', default=0.01, type=float, help='weight decay.')
+ optim_args.add_argument(
+ '--max-grad-norm', default=1, type=float, help='gradient clipping')
+ optim_args.add_argument(
+ '-e', '--epochs', default=1, type=int, help='total training epochs.')
+ optim_args.add_argument(
+ '--warmup-ratio',
+ default=0.03,
+ type=float,
+ help=('the proportion of training steps for learning rate warm-up in '
+ 'relation to the total training steps.'))
+
+ parser.add_argument('-c', '--config', default=None)
+ parser.add_argument(
+ '--work-dir',
+ default='work_dirs',
+ help='the dir to save logs and checkpoints')
+ parser.add_argument(
+ '--checkpoint-interval',
+ default=-1,
+ type=float,
+ help=('how many steps to save a checkpoint; it can be a floating '
+ 'point number less than 1, or an integer greater than or equal '
+ "to 1. When it's a floating point, it will be multiplied by the "
+ 'total number of training steps.'))
+ parser.add_argument(
+ '--checkpoint-drop-optimizer',
+ action='store_true',
+ help=('only model parameters are saved when saving a checkpoint. '
+ 'This can significantly reduce the size of checkpoint files, '
+ 'but the saved checkpoints cannot be resumed.'))
+ parser.add_argument(
+ '--log-interval', default=1, type=int, help='log interval')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='specify checkpoint path to be resumed from.')
+ parser.add_argument(
+ '--seed', type=int, default=0, help='random seed for the training')
+ parser.add_argument(
+ '--debug', action='store_true', help='Set logger level to `DEBUG`')
+ args = parser.parse_args()
+ return args
+
+
+def is_interval(step, total_steps, interval):
+ return (step + 1) % interval == 0 or (step + 1) == total_steps
+
+
+def map_meta_modules(model, meta_model):
+ modules = {name: mod for name, mod in model.named_modules()}
+ meta_module_map = {
+ mod: modules[name]
+ for name, mod in meta_model.named_modules()
+ }
+ return meta_module_map
+
+
+def build_llm_model(args, config, world_size, dtype=torch.float32):
+ with LoadWoInit():
+ llm = AutoModelForCausalLM.from_pretrained(
+ args.llm, config=config, trust_remote_code=True)
+
+ # Ensure all numerical values in the optimizer are fp32.
+ # FSDP will use low precision during forward.
+ llm.to(dtype)
+ return llm
+
+
+# @logger.catch
+def sft(args):
+ ###########################################################################
+ # 1. Environment #
+ ###########################################################################
+ dist_launcher = infer_launcher()
+ init_dist(dist_launcher)
+ set_random_seed(args.seed)
+
+ world_size = int(os.environ['WORLD_SIZE'])
+ tp_size = args.tp_size
+ dp_size = world_size // tp_size
+
+ if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ 'should be divisible by the '
+ f'dp_size({dp_size}).')
+
+ if (args.global_batch_size / dp_size) % args.mirco_batch_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the dp_size({dp_size})'
+ f' * `mirco_batch_size`({args.mirco_batch_size})')
+
+ if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir):
+ if len(os.listdir(args.dset_cache_dir)) and not args.dset_from_cache:
+ raise RuntimeError(f'`{args.dset_cache_dir}` is not an empty '
+ 'folder, which may lead to inaccurate '
+ 'cache results.')
+
+ setup_parallel(tp_size=tp_size)
+ dp_mesh = get_dp_mesh()
+ tp_mesh = get_tp_mesh()
+
+ rank = dist.get_rank()
+
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+
+ objects = [timestamp]
+ dist.broadcast_object_list(objects, src=0)
+ timestamp = objects[0]
+
+ args.work_dir = os.path.join(args.work_dir, timestamp)
+ mkdir_or_exist(args.work_dir)
+
+ log_file = os.path.join(args.work_dir, f'rank{rank}.log')
+
+ # Change the log format printed in the terminal
+ lvl = 'DEBUG' if args.debug else 'INFO'
+ logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
+ # Change the format saved in the log file
+ logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)
+
+ logger.info(args)
+ if rank == 0:
+ env = collect_env()
+ import transformers
+
+ import xtuner
+ env['Transformers'] = transformers.__version__
+ env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
+ runtime_env = OrderedDict()
+ runtime_env.update(env)
+ runtime_env['Seed'] = args.seed
+ runtime_env['World Size'] = world_size
+ runtime_env['Distributed launcher'] = dist_launcher
+
+ runtime_env_info = '\n ' + '\n '.join(
+ f'{k}: {v}' for k, v in runtime_env.items())
+ dash_line = '-' * 60
+ logger.info('\n' + dash_line + '\nRuntime environment:' +
+ runtime_env_info + '\n' + dash_line + '\n')
+ # ------------------- Environment End ------------------------------ #
+
+ ###########################################################################
+ # 2. Dataset & Dataloader #
+ ###########################################################################
+
+ start_load_data_t = time.time()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer if args.tokenizer else args.llm,
+ trust_remote_code=True,
+ padding_side='right')
+
+ if args.dset_from_cache:
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForText.from_cache, max_length=args.max_length)
+ elif args.dset_pack_level == 'hard':
+ raise NotImplementedError
+ else:
+ init_fn = partial(
+ TextTokenizeFunction.from_cache, max_length=args.max_length)
+ _datasets = load_from_cache(args.dset_cache_dir, init_fn)
+ dist.barrier()
+ else:
+ chat_template = CHAT_TEMPLATE_MAP[args.chat_template]
+ tokenize_fns = []
+ init_fns = []
+ for dset_format in args.dset_formats:
+ # If your data format is not in `SUPPORT_DATA_FORMATS`, you should
+ # redefine a `tokenize_fn`, defining how to convert a piece of raw
+ # data into tokenized data.
+ # The tokenized data must include `input_ids`, `labels`,
+ # and `num_tokens`.
+ tokenize_fn = TextTokenizeFunction(tokenizer, chat_template,
+ dset_format)
+
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForText, max_length=args.max_length)
+ elif args.dset_cache_dir:
+ init_fn = partial(
+ TextTokenizedDataset, max_length=args.max_length)
+ else:
+ init_fn = partial(
+ TextOnlineTokenizeDataset, tokenize_fn=tokenize_fn)
+ # Online tokenization is used when not using a pack dataset,
+ # saving startup time.
+ tokenize_fn = None
+
+ tokenize_fns.append(tokenize_fn)
+ init_fns.append(init_fn)
+
+ _datasets = load_datasets(
+ paths=args.datasets,
+ cache_dir=args.dset_cache_dir,
+ file_types=args.dset_file_types,
+ sources=args.dset_sources,
+ sample_ratios=args.dset_sample_ratios,
+ num_proc=args.num_proc,
+ map_fns=tokenize_fns,
+ init_fns=init_fns,
+ file_pattern=args.file_pattern)
+
+ if (args.dset_pack_level or args.dset_cache_dir) and rank == 0 and args.debug:
+ # Only the tokenized datasets can count the number of tokens
+ num_tokens = sum(sum(dset['num_tokens']) for dset in _datasets)
+ logger.debug(f'[Dataset] {num_tokens} tokens.')
+
+ train_dataset = ConcatDataset(_datasets)
+
+ if args.dset_pack_level and rank == 0:
+ ori_samples = sum([len(dset) for dset in _datasets])
+ packed_samples = len(train_dataset)
+ logger.info(f'[Dataset] (Original) {ori_samples} samples.')
+ logger.info(f'[Dataset] (Packed) {packed_samples} samples.')
+
+ pack_batch = is_flash_attn_2_available()
+ collator = TextCollator(pack_batch=pack_batch)
+
+ if args.group_by_length:
+ sampler = LengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size)
+ else:
+ sampler = ParallelSampler(
+ train_dataset, dp_mesh, args.global_batch_size, shuffle=True)
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.mirco_batch_size,
+ num_workers=args.num_workers,
+ # If you replace the sampler with a custom one, make sure it rounds up
+ # or drops the last batch according to the `global_batch_size`.
+ sampler=sampler,
+ collate_fn=collator,
+ persistent_workers=args.num_workers > 0)
+
+ if rank == 0:
+ logger.info(f'[Dataloader] {len(train_dataloader)} batches.')
+ _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)]
+ _first_batch = collator(_first_batch)
+ _decoded = tokenizer.batch_decode(_first_batch['input_ids'])
+ logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}')
+ logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}')
+ dist.barrier()
+
+ load_data_cost_time = time.time() - start_load_data_t
+ logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')
+ # ------------------- Dataset & Dataloader End --------------------- #
+
+ ###########################################################################
+ # 3. FSDP #
+ ###########################################################################
+
+ start_model_t = time.time()
+
+ if args.dtype == 'auto':
+ args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'
+
+ if args.dtype == 'fp16':
+ dtype = torch.float16
+ elif args.dtype == 'bf16':
+ if torch.cuda.is_bf16_supported():
+ dtype = torch.bfloat16
+ else:
+ raise RuntimeError('The device does not support `bf16`, '
+ 'please set `dtype` to `fp16`.')
+ else:
+ raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
+ f'but found {args.dtype}.')
+
+ llm_cfg = AutoConfig.from_pretrained(args.llm, trust_remote_code=True)
+ if is_flash_attn_2_available():
+ llm_cfg.attn_implementation = 'flash_attention_2'
+ elif is_torch_sdpa_available():
+ llm_cfg.attn_implementation = 'sdpa'
+
+ llm_cfg.use_cache = False
+ llm_cfg.torch_dtype = dtype
+
+ with torch.device('meta'):
+ # Ensure all numerical values in the optimizer are fp32.
+ # FSDP will use low precision during forward.
+ meta_llm = build_llm_model(args, llm_cfg, world_size, torch.float32)
+
+ if pack_batch:
+ dispatch_modules(meta_llm)
+
+ # Only load parameters on rank 0 to avoid each rank repeatedly loading the
+ # same model into the CPU, wasting memory
+ if rank == 0:
+ with torch.device('cpu'):
+ llm = build_llm_model(args, llm_cfg, world_size, dtype)
+ rank0_meta_llm = copy.deepcopy(meta_llm)
+ meta_llm_map = map_meta_modules(llm, meta_llm)
+ else:
+ meta_llm_map = None
+
+ dist.barrier()
+
+ if args.tp_size > 1:
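+ # Megatron-style tensor parallelism for InternLM2-style module names:
+ # the fused QKV (`wqkv`) and the MLP up/gate projections (`w1`, `w3`)
+ # are split column-wise, the output projections (`wo`, `w2`) row-wise,
+ # so each pair needs a single all-reduce at the row-wise output.
+ # The per-rank head count and hidden size are reduced accordingly below.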
+ layer_tp_plan = {
+ 'attention.wqkv': ColwiseParallel(),
+ 'attention.wo': RowwiseParallel(),
+ 'feed_forward.w1': ColwiseParallel(),
+ 'feed_forward.w2': RowwiseParallel(),
+ 'feed_forward.w3': ColwiseParallel(),
+ }
+
+ for layer in meta_llm.model.layers:
+ attention = layer.attention
+ attention.num_heads = attention.num_heads // tp_mesh.size()
+ attention.hidden_size = attention.hidden_size // tp_mesh.size()
+ parallelize_module(
+ module=layer,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_tp_plan,
+ )
+
+ meta_llm = parallelize_module(
+ module=meta_llm,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'model.tok_embeddings':
+ RowwiseParallel(input_layouts=Replicate(), ),
+ 'output': ColwiseParallel(output_layouts=Replicate(), ),
+ })
+
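+ # `dp_tp_lazy_init` materializes each meta-device module while FSDP
+ # wraps the model: rank 0 copies the real weights via `meta_llm_map` and
+ # the other ranks receive them through `sync_module_states=True`, so the
+ # full checkpoint only ever lives in rank 0's CPU memory.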
+ param_init_fn = partial(
+ dp_tp_lazy_init,
+ module_map=meta_llm_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh)
+
+ if args.shard_strategy == 'full':
+ fsdp_device_mesh = dp_mesh
+ strategy = ShardingStrategy.FULL_SHARD
+ elif args.shard_strategy == 'hybrid':
+ fsdp_device_mesh = init_device_mesh('cuda', (dp_size // 8, 8))
+ strategy = ShardingStrategy.HYBRID_SHARD
+ else:
+ raise ValueError
+
+ torch.cuda.reset_peak_memory_stats()
+ shard_llm = FSDP(
+ meta_llm,
+ device_mesh=fsdp_device_mesh,
+ sharding_strategy=strategy,
+ cpu_offload=CPUOffload(offload_params=args.cpu_offload),
+ auto_wrap_policy=layer_auto_wrap_policy,
+ mixed_precision=MixedPrecision(
+ param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype),
+ device_id=torch.cuda.current_device(),
+ use_orig_params=True,
+ param_init_fn=param_init_fn,
+ sync_module_states=True,
+ )
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Model] During building the FSDP model, the peak GPU memory '
+ f'is {max_memory/1024**3:.1f}GB.')
+
+ if args.selective_recompute:
+ check_fn = partial(
+ checkpoint_check_fn,
+ target=RECOMPUTE_MODULES,
+ selective=args.selective_recompute)
+ apply_activation_checkpointing(shard_llm, check_fn=check_fn)
+
+ fsdp_cost_time = time.time() - start_model_t
+ logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s')
+ # -------------------------- FSDP End ------------------------------ #
+
+ ###########################################################################
+ # 4. Optimizer & Scheduler #
+ ###########################################################################
+ required_grad_params = [
+ param for param in shard_llm.parameters() if param.requires_grad
+ ]
+ optimizer = AdamW(
+ required_grad_params,
+ lr=args.lr,
+ weight_decay=args.wd,
+ betas=(0.9, 0.95))
+
+ global_batch_size = args.global_batch_size
+ mirco_batch_size = args.mirco_batch_size
+
+ # `iter` means one forward+backward pass
+ # `step` means one optimizer update
+ # `iters_per_step` is the number of gradient accumulation iterations
+ iters_per_step = global_batch_size // mirco_batch_size // dp_size
+ iters_per_epoch = len(train_dataloader)
+ steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step)
+
+ total_epochs = args.epochs
+ total_steps = steps_per_epoch * total_epochs
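+ # For example (hypothetical numbers): with global_batch_size=64,
+ # mirco_batch_size=1 and dp_size=8, each rank accumulates
+ # iters_per_step = 64 // 1 // 8 = 8 forward+backward passes per
+ # optimizer step.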
+
+ if args.checkpoint_interval == -1:
+ checkpoint_interval = total_steps
+ elif args.checkpoint_interval < 1:
+ checkpoint_interval = int(total_steps * args.checkpoint_interval)
+ else:
+ checkpoint_interval = int(args.checkpoint_interval)
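+ # e.g. (hypothetical values): `--checkpoint-interval 0.25` with
+ # total_steps=1000 saves every 250 steps, `--checkpoint-interval 500`
+ # saves every 500 steps, and the default `-1` saves only at the end
+ # of training.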
+
+ warmup_steps = int(args.warmup_ratio * total_steps)
+
+ def warmup_fn(x):
+ return x / warmup_steps if x < warmup_steps else 1
+
+ warmup_scheduler = LambdaLR(optimizer, warmup_fn)
+
+ cosine_scheduler = CosineAnnealingLR(
+ optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)
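+ # The intended schedule: the learning rate ramps linearly from ~0 to
+ # `args.lr` over `warmup_steps` steps, then follows a cosine decay down
+ # to `args.lr_min` over the remaining `total_steps - warmup_steps` steps.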
+
+ start_step = 0
+
+ # ---------------- Optimizer & Scheduler End ----------------------- #
+
+ ###########################################################################
+ # 5. Training #
+ ###########################################################################
+
+ start_train_t = time.time()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Train] Begin Train Loop. The current GPU memory is '
+ f'{(max_memory / 1024**3):.1f}GB')
+ for step in range(start_step, total_steps):
+
+ epoch = step // steps_per_epoch
+ epoch_inner_step = step % steps_per_epoch
+ if epoch_inner_step == 0 or step == start_step:
+ # For the first step of each epoch, the data order needs to be
+ # readjusted.
+ # Or after resuming, for the first step, the dataloader needs to
+ # be adjusted to the position before resume.
+
+ train_dataloader.sampler.set_epoch(epoch, epoch_inner_step)
+ data_iterator = iter(train_dataloader)
+
+ if step < warmup_steps:
+ warmup_scheduler.step()
+ cur_lr = warmup_scheduler.get_last_lr()[0]
+ else:
+ cosine_scheduler.step()
+ cur_lr = cosine_scheduler.get_last_lr()[0]
+
+ torch.cuda.reset_peak_memory_stats()
+
+ step_loss = 0
+ step_data_time = 0
+ step_start_t = time.time()
+ step_consumed_tokens = 0
+ for _ in range(iters_per_step):
+
+ _data_start_t = time.time()
+ data = next(data_iterator)
+ step_data_time += time.time() - _data_start_t
+
+ input_ids = data['input_ids'].cuda()
+ labels = data['labels'].cuda()
+ attention_mask = data['attention_mask'].cuda()
+ num_tokens = data['num_tokens'].cuda()
+
+ packed_ctx = packed_sequence(num_tokens, enable=pack_batch)
+
+ with packed_ctx:
+
+ outputs = shard_llm(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=attention_mask)
+
+ avg_iter_loss = outputs.loss / iters_per_step
+ avg_iter_loss.backward()
+
+ step_loss += avg_iter_loss.item()
+ if args.dset_pack_level == 'soft':
+ # With soft packing, a packed sequence that is still shorter than
+ # the max length is padded up to it. The last element of
+ # `num_tokens` is the number of pad tokens, so it is excluded here.
+ step_consumed_tokens += num_tokens[:-1].sum() / tp_size
+ else:
+ step_consumed_tokens += num_tokens.sum() / tp_size
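+ # e.g. for a hypothetical soft-packed sample with
+ # num_tokens = [1024, 896, 128], the trailing 128 pad tokens are not
+ # counted towards `step_consumed_tokens`.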
+
+ grad_norm = shard_llm.clip_grad_norm_(args.max_grad_norm)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ step_time = time.time() - step_start_t
+ eta = step_time * (total_steps - step)
+ eta = timedelta(seconds=int(eta))
+ tgs = int(step_consumed_tokens / step_time)
+ max_memory = torch.cuda.max_memory_allocated()
+ if is_interval(step, total_steps, args.log_interval):
+ logger.info(f'[Train] (Epoch {epoch + 1}) Step '
+ f'{step + 1}/{total_steps} '
+ f'lr: {cur_lr:.6f} loss: {step_loss:.3f} '
+ f'grad_norm: {grad_norm:.2f} '
+ f'max_memory: {(max_memory / 1024**3):.1f}GB '
+ f'text_tokens: {step_consumed_tokens} '
+ f'tgs: {tgs} data_time: {step_data_time:.2f}s '
+ f'time: {step_time:.2f}s '
+ f'eta: {eta}')
+
+ if is_interval(step, total_steps, checkpoint_interval):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Checkpoint] Before saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ num_digits = len(str(abs(total_steps)))
+ work_dir = args.work_dir
+ ckpt_dir = os.path.join(work_dir, f'ckpt-{step+1:0{num_digits}}')
+ hf_dir = os.path.join(work_dir, f'hf-{step+1:0{num_digits}}')
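+ # e.g. (hypothetical): with total_steps=1000 and step=49, this writes
+ # the sharded checkpoint to `<work_dir>/ckpt-0050` and the HF export
+ # to `<work_dir>/hf-0050`.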
+ _options = StateDictOptions(cpu_offload=True, full_state_dict=True)
+
+ full_model_state_dict = get_model_state_dict(
+ shard_llm, options=_options)
+ if rank == 0:
+ saved_llm = copy.deepcopy(rank0_meta_llm)
+ saved_llm.to(dtype)
+ for name, param in full_model_state_dict.items():
+ set_module_tensor_to_device(saved_llm, name, 'cpu', param)
+
+ saved_llm.save_pretrained(hf_dir)
+ tokenizer.save_pretrained(hf_dir)
+ del saved_llm
+
+ dist.barrier()
+ del full_model_state_dict
+
+ if args.checkpoint_drop_optimizer:
+ logger.warning('The saved checkpoint cannot be resumed. '
+ 'If you want to save a resumable checkpoint, '
+ 'please remove `--checkpoint-drop-optimizer` '
+ 'from the command.')
+ else:
+ # FSDP cannot be saved via torch.save
+ # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501
+ _options = StateDictOptions(
+ cpu_offload=True, ignore_frozen_params=True)
+ (shard_model_state_dict,
+ shard_optimizer_state_dict) = get_state_dict(
+ shard_llm, optimizer, options=_options)
+
+ state_dict = {
+ 'model': shard_model_state_dict,
+ 'optimizer': shard_optimizer_state_dict,
+ 'step': step,
+ 'total_steps': total_steps,
+ 'warmup_scheduler': warmup_scheduler.state_dict(),
+ 'cosine_scheduler': cosine_scheduler.state_dict()
+ }
+
+ writer = dcp.FileSystemWriter(ckpt_dir)
+ mkdir_or_exist(ckpt_dir)
+ dcp.save(state_dict, writer)
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Checkpoint] During saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ train_cost_time = time.time() - start_train_t
+ logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}')
+ # ------------------------ Training End ---------------------------- #
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ sft(args)
diff --git a/tools/llava/convert_xtuner_weights_to_hf.py b/tools/llava/convert_xtuner_weights_to_hf.py
new file mode 100644
index 000000000..e9f3ece9e
--- /dev/null
+++ b/tools/llava/convert_xtuner_weights_to_hf.py
@@ -0,0 +1,147 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/llava/convert_llava_weights_to_hf.py # noqa: E501
+import argparse
+
+from mmengine import Config, print_log
+from transformers import CLIPImageProcessor, CLIPVisionModel
+
+from xtuner._lite.modelings.llava import (EnhancedLlavaConfig,
+ LlavaForConditionalGeneration,
+ LlavaProcessor)
+from xtuner._lite.parallel.fsdp import LoadWoInit
+from xtuner.registry import BUILDER
+
+
+LLM_PREFIX = 'language_model'
+VIT_PREFIX = 'vision_tower'
+PROJECTOR_MAPPING = {
+ 'model.0': 'multi_modal_projector.linear_1',
+ 'model.2': 'multi_modal_projector.linear_2',
+}
+
+
+def convert_state_dict_to_hf(state_dict, mapping):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if key.endswith('.inv_freq'):
+ continue
+ for key_to_modify, new_key in mapping.items():
+ if key_to_modify in key:
+ key = key.replace(key_to_modify, new_key)
+
+ new_state_dict[key] = value
+ return new_state_dict
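+# For instance, with `PROJECTOR_MAPPING` above, a projector key such as
+# 'model.0.weight' is renamed to 'multi_modal_projector.linear_1.weight'
+# and 'model.2.bias' to 'multi_modal_projector.linear_2.bias'.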
+
+
+def add_prefix(state_dict, prefix):
+
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ new_state_dict[f'{prefix}.{key}'] = value
+
+ return new_state_dict
+
+def convert_to_hf(cfg, save_dir):
+
+ print_log('Loading XTuner Checkpoint...', 'current')
+ model = BUILDER.build(cfg.model)
+ # get state_dict
+ llm = model.llm
+ if model.use_llm_lora:
+ llm = model.llm.merge_and_unload()
+ llm.config.use_cache = True
+
+
+ llm_state_dict = llm.state_dict()
+ llm_state_dict = add_prefix(llm_state_dict, LLM_PREFIX)
+
+
+ visual_encoder = model.visual_encoder
+ if model.use_visual_encoder_lora:
+ visual_encoder = model.visual_encoder.merge_and_unload()
+ assert isinstance(visual_encoder, CLIPVisionModel),\
+ 'This conversion format only supports CLIPVisionModel.'
+
+ visual_encoder_state_dict = visual_encoder.state_dict()
+ visual_encoder_state_dict = add_prefix(
+ visual_encoder_state_dict, VIT_PREFIX)
+
+ projector_state_dict = model.projector.state_dict()
+ projector_state_dict = convert_state_dict_to_hf(
+ projector_state_dict, PROJECTOR_MAPPING)
+
+ state_dict = {
+ **projector_state_dict,
+ **llm_state_dict,
+ **visual_encoder_state_dict
+ }
+
+ tokenizer = BUILDER.build(cfg.tokenizer)
+
+ # init model
+ text_config = llm.config
+ vision_config = visual_encoder.config
+
+ img_token = '<image>'
+ need_resize = False
+ if len(tokenizer.encode(img_token, add_special_tokens=False)) > 1:
+ tokenizer.add_tokens([img_token], special_tokens=True)
+ img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+
+ print_log(f'[Tokenizer] Added a new token `{img_token}`, '
+ f'token id is {img_token_id}, the new vocab size is '
+ f'{len(tokenizer)}', 'current')
+
+ llm_vocab_size = text_config.vocab_size
+ if llm_vocab_size < len(tokenizer):
+ # We add an image token so we need to resize the model
+ need_resize = True
+ else:
+ img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+
+ print_log('Building an empty HF Llava...', 'current')
+ config = EnhancedLlavaConfig(
+ text_config=text_config,
+ vision_config=vision_config,
+ image_token_index=img_token_id,
+ attn_implementation='eager')
+
+ with LoadWoInit():
+ llava = LlavaForConditionalGeneration(config)
+
+ print_log('Loading HF format state dict...', 'current')
+ llava.load_state_dict(state_dict, strict=True, assign=True)
+
+ if need_resize:
+ ori_emb_shape = llava.get_input_embeddings().weight.shape
+ llava.resize_token_embeddings(len(tokenizer))
+ new_emb_shape = llava.get_input_embeddings().weight.shape
+ print_log('Pad the parameters of `embeddings` and `output` from '
+ f'shape {ori_emb_shape} to shape {new_emb_shape}',
+ 'current')
+
+
+ # processor
+ image_processor = BUILDER.build(cfg.image_processor)
+ assert isinstance(image_processor, CLIPImageProcessor),\
+ 'This conversion format only supports CLIPImageProcessor.'
+
+ processor = LlavaProcessor(
+ tokenizer=tokenizer, image_processor=image_processor)
+
+ # save
+ print_log('Saving HF Llava...', 'current')
+ llava.save_pretrained(save_dir)
+ processor.save_pretrained(save_dir)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('config')
+ parser.add_argument('save_dir')
+ args = parser.parse_args()
+
+ cfg = Config.fromfile(args.config)
+ convert_to_hf(cfg, args.save_dir)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/llava/llava_data_examples.json b/tools/llava/llava_data_examples.json
new file mode 100644
index 000000000..b559264ed
--- /dev/null
+++ b/tools/llava/llava_data_examples.json
@@ -0,0 +1,6 @@
+{
+ "med": {
+ "annotations": "/mnt/hwfile/gmai/litianbin/QA_Pairs_new/QA_Pairs/internvl_format/fundus_data/",
+ "sample_ratio": 1.0
+ }
+}
diff --git a/tools/llava/llava_pretrain.py b/tools/llava/llava_pretrain.py
new file mode 100644
index 000000000..166f8f539
--- /dev/null
+++ b/tools/llava/llava_pretrain.py
@@ -0,0 +1,934 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import math
+import os
+import shutil
+import sys
+import time
+from collections import OrderedDict
+from contextlib import nullcontext
+from datetime import datetime, timedelta
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+from accelerate.utils import set_module_tensor_to_device
+from datasets import Dataset
+from mmengine import load, mkdir_or_exist
+from mmengine.dist import infer_launcher, init_dist
+from mmengine.runner import set_random_seed
+from mmengine.utils import get_git_hash
+from mmengine.utils.dl_utils import collect_env
+from peft import LoraConfig, get_peft_model
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+ apply_activation_checkpointing
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+ get_model_state_dict,
+ get_state_dict)
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import _or_policy
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
+ CLIPVisionModel)
+from transformers.utils.import_utils import (is_flash_attn_2_available,
+ is_torch_sdpa_available)
+
+from xtuner._lite import AutoTokenizer, get_logger
+from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_modules,
+ packed_sequence)
+from xtuner._lite.chat import CHAT_TEMPLATE_MAP
+from xtuner._lite.datasets import (LlavaCollator, LlavaRawDataset,
+ LlavaTokenizedDataset, LlavaTokenizeFunction,
+ SoftPackerForLlava)
+from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets,
+ load_from_cache)
+from xtuner._lite.modelings import (EnhancedLlavaConfig,
+ LlavaForConditionalGeneration,
+ LlavaProcessor, register_remote_code)
+from xtuner._lite.parallel import LengthGroupedSampler, ParallelSampler
+from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
+ all_required_grad_wrap_policy,
+ checkpoint_check_fn, dp_lazy_init,
+ layer_auto_wrap_policy)
+
+logger = get_logger()
+
+
+def log_format(rank, debug=False):
+
+ formatter = f'[XTuner][RANK {rank}]'
+ formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]'
+
+ if debug:
+ formatter += '[{name}:'
+ formatter += '{function}:'
+ formatter += '{line}]'
+
+ formatter += ' {message}'
+ return formatter
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train LLM')
+
+ model_args = parser.add_argument_group('model', 'Model Related Settings')
+ model_args.add_argument('--llm', help='repo id or local path of the model')
+ model_args.add_argument(
+ '--vit',
+ default='openai/clip-vit-large-patch14-336',
+ help='repo id or local path of the model')
+ model_args.add_argument(
+ '-t',
+ '--tokenizer',
+ help=('repo id or local path of the tokenizer. '
+ 'Defaults to the same as `model`'))
+ model_args.add_argument(
+ '--chat-template',
+ choices=CHAT_TEMPLATE_MAP.keys(),
+ help=('the chat template used during training. '
+ 'Must be one of the keys of `CHAT_TEMPLATE_MAP`'))
+ model_args.add_argument(
+ '--freeze-llm',
+ action='store_true',
+ help="Not updating LLM's parameters")
+ model_args.add_argument(
+ '--freeze-vit',
+ action='store_true',
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--llm-use-lora',
+ action='store_true',
+ help='Apply the adapter to LLM.')
+ model_args.add_argument(
+ '--llm-lora-targets',
+ default=None,
+ nargs='*',
+ help='The names of the modules to apply the adapter to. ')
+ model_args.add_argument(
+ '--llm-lora-r',
+ default=64,
+ type=int,
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--llm-lora-alpha',
+ default=16,
+ type=int,
+ help='The alpha parameter for Lora scaling.')
+ model_args.add_argument(
+ '--llm-lora-dropout',
+ default=0.1,
+ type=float,
+ help='The dropout probability for Lora layers.')
+ model_args.add_argument(
+ '--llm-lora-bias',
+ default='none',
+ help='The bias type for the LLM LoRA layers.')
+ model_args.add_argument(
+ '--vit-use-lora',
+ action='store_true',
+ help='Apply the adapter to Vit.')
+ model_args.add_argument(
+ '--vit-lora-targets',
+ default=None,
+ type=str,
+ help='The names of the modules to apply the adapter to. ')
+ model_args.add_argument(
+ '--vit-lora-r',
+ default=64,
+ type=int,
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--vit-lora-alpha',
+ default=16,
+ type=int,
+ help='The alpha parameter for vit Lora scaling.')
+ model_args.add_argument(
+ '--vit-lora-dropout',
+ default=0.1,
+ type=float,
+ help='The dropout probability for vit Lora layers.')
+ model_args.add_argument(
+ '--vit-lora-bias',
+ default='none',
+ help='The bias type for the ViT LoRA layers.')
+ model_args.add_argument(
+ '--dtype',
+ default='auto',
+ choices=['fp16', 'bf16', 'auto'],
+ help=("the dtype of the model forward. When set to 'auto', it will "
+ 'automatically determine whether bf16 is available, '
+ 'prioritizing the use of bf16.'))
+
+ model_args.add_argument(
+ '--selective-recompute',
+ default=1.0,
+ type=float,
+ help=('the ratio of re-computation for transformer layers. '
+ 'The maximum is 1; the larger the value, the less memory '
+ 'required for training. The default is 1, meaning all layers '
+ 'need to be re-computed.'))
+
+ data_args = parser.add_argument_group('data', 'Dataset Related Settings')
+ data_args.add_argument(
+ '--datasets',
+ help=('repo id or local path or dir of the datasets. For repo ids, '
+ 'the `dset-sources` needs to be appropriately set to '
+ '`modelscope` or `huggingface`. For local dir, all json and '
+ 'jsonl files will be loaded by default. The type of loaded '
+ 'files can be controlled by setting `dset-file-type`'))
+ data_args.add_argument(
+ '--dset-file-types',
+ nargs='*',
+ default=LOAD_FN_MAP.keys(),
+ choices=LOAD_FN_MAP.keys(),
+ help='the file type that needs to be loaded')
+ data_args.add_argument(
+ '--dset-cache-dir',
+ help=('the cache dir of the loaded datasets. When the `datasets` is '
+ 'set, the loaded datasets will be cached to this dir. If the '
+ '`datasets` are not set, the cached dataset in this dir will be '
+ 'loaded.'))
+ data_args.add_argument(
+ '--dset-from-cache',
+ action='store_true',
+ help=('Load data directly from `dset-cache-dir`. This can save time '
+ 'on online tokenization, but if the tokenizer changed, '
+ 'recaching is needed.'))
+ data_args.add_argument(
+ '--dset-pack-level',
+ choices=['soft'],
+ help=('the level of data packing. When `hard`, multiple data will be '
+ 'packed to `max_length`, potentially causing some data to be '
+ 'truncated, and the length of the packed data will always '
+ 'be `max_length`; When `soft`, it will pack multiple data '
+ 'into nearly `max_length` without truncating the data.'))
+ data_args.add_argument('--group-by-length', action='store_true')
+ data_args.add_argument(
+ '--max-length',
+ type=int,
+ default=2048,
+ help=('the maximum length of each piece of data, any excess will be '
+ 'truncated.'))
+ data_args.add_argument(
+ '--num-workers',
+ type=int,
+ default=0,
+ help='how many subprocesses to use for data loading.')
+ data_args.add_argument(
+ '--num-proc',
+ type=int,
+ default=8,
+ help='how many subprocesses to use for data mapping.')
+
+ optim_args = parser.add_argument_group('optim', 'Optim Related Settings')
+ optim_args.add_argument(
+ '--mirco-batch-size',
+ type=int,
+ default=1,
+ help='batch size for each forward + backward pass')
+ optim_args.add_argument(
+ '--global-batch-size',
+ type=int,
+ default=16,
+ help='batch size for each parameter update')
+
+ optim_args.add_argument(
+ '--lr', default=4e-5, type=float, help='learning rate.')
+ optim_args.add_argument(
+ '--wd', default=0.01, type=float, help='weight decay.')
+ optim_args.add_argument(
+ '--max-grad-norm', default=1, type=float, help='gradient clipping')
+ optim_args.add_argument(
+ '-e', '--epochs', default=1, type=int, help='total training epochs.')
+ optim_args.add_argument(
+ '--warmup-ratio',
+ default=0.03,
+ type=float,
+ help=('the proportion of training steps for learning rate warm-up in '
+ 'relation to the total training steps.'))
+
+ parser.add_argument('-c', '--config', default=None)
+ parser.add_argument(
+ '--work-dir',
+ default='work_dirs',
+ help='the dir to save logs and checkpoints')
+ parser.add_argument(
+ '--checkpoint-interval',
+ default=-1,
+ type=float,
+ help=('how many steps to save a checkpoint; it can be a floating '
+ 'point number less than 1, or an integer greater than or equal '
+ "to 1. When it's a floating point, it will be multiplied by the "
+ 'total number of training steps.'))
+ parser.add_argument(
+ '--checkpoint-drop-optimizer',
+ action='store_true',
+ help=('only model parameters are saved when saving a checkpoint. '
+ 'This can significantly reduce the size of checkpoint files, '
+ 'but the saved checkpoints cannot be resumed.'))
+ parser.add_argument(
+ '--log-interval', default=1, type=int, help='log interval')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='specify checkpoint path to be resumed from.')
+ parser.add_argument(
+ '--seed', type=int, default=0, help='random seed for the training')
+ parser.add_argument(
+ '--debug', action='store_true', help='Set logger level to `DEBUG`')
+ args = parser.parse_args()
+ return args
+
+
+def is_interval(step, total_steps, interval):
+ return (step + 1) % interval == 0 or (step + 1) == total_steps
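+# e.g. is_interval(step=49, total_steps=1000, interval=50) -> True; the
+# final step is always treated as an interval boundary.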
+
+
+def map_meta_modules(model, meta_model):
+ modules = {name: mod for name, mod in model.named_modules()}
+ meta_module_map = {
+ mod: modules[name]
+ for name, mod in meta_model.named_modules()
+ }
+ return meta_module_map
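+# The returned mapping (meta module -> initialized CPU module on rank 0) is
+# presumably what `dp_lazy_init` below consumes as `module_map`, copying the
+# real rank-0 weights onto the meta modules when FSDP initializes each shard.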
+
+
+def build_llava_model(args,
+ config,
+ tokenizer,
+ world_size,
+ device='cpu',
+ dtype=torch.float32,
+ resize_emb=False):
+
+ with torch.device(device):
+ _cfg = copy.deepcopy(config)
+ llava = LlavaForConditionalGeneration(_cfg)
+
+ # llava has not loaded the pre-trained parameters of llm and vit
+ if device != 'meta':
+ del llava.language_model
+ del llava.vision_tower
+ with LoadWoInit():
+ llm = AutoModelForCausalLM.from_pretrained(
+ args.llm, config=_cfg.text_config)
+ vit = CLIPVisionModel.from_pretrained(
+ args.vit, config=_cfg.vision_config)
+ llava.language_model = llm
+ llava.vision_tower = vit
+
+ llava.to(dtype)
+
+ if resize_emb:
+ ori_emb_shape = llava.get_input_embeddings().weight.shape
+ llava.resize_token_embeddings(len(tokenizer))
+ new_emb_shape = llava.get_input_embeddings().weight.shape
+ logger.info('Pad the parameters of `embeddings` and `output` from '
+ f'shape {ori_emb_shape} to shape {new_emb_shape}')
+
+ if args.freeze_llm or args.llm_use_lora:
+ llava.language_model.requires_grad_(False)
+ if world_size > 1:
+ llava.language_model.to(dtype)
+
+ if args.freeze_vit or args.vit_use_lora:
+ llava.vision_tower.requires_grad_(False)
+ if world_size > 1:
+ llava.vision_tower.to(dtype)
+
+ if args.llm_use_lora:
+ llm = llava.language_model
+ if args.llm_lora_targets is None:
+ llm_cls = llm.__class__.__name__
+ args.llm_lora_targets = LORA_TARGET_MAP[llm_cls]
+ llm_lora_cfg = LoraConfig(
+ target_modules=args.llm_lora_targets,
+ r=args.llm_lora_r,
+ lora_alpha=args.llm_lora_alpha,
+ lora_dropout=args.llm_lora_dropout,
+ bias=args.llm_lora_bias,
+ task_type='CAUSAL_LM')
+ lora_llm = get_peft_model(llm, llm_lora_cfg)
+ llava.language_model = lora_llm
+
+ if args.vit_use_lora:
+ vit = llava.vision_tower
+ if args.vit_lora_targets is None:
+ vit_cls = vit.__class__.__name__
+ args.vit_lora_targets = LORA_TARGET_MAP[vit_cls]
+ vit_lora_cfg = LoraConfig(
+ target_modules=args.vit_lora_targets,
+ r=args.vit_lora_r,
+ lora_alpha=args.vit_lora_alpha,
+ lora_dropout=args.vit_lora_dropout,
+ bias=args.vit_lora_bias,
+ )
+ llava.vision_tower = get_peft_model(vit, vit_lora_cfg)
+
+ return llava
+
+
+# @logger.catch
+def llava_pretrain(args):
+ ###########################################################################
+ # 1. Environment #
+ ###########################################################################
+ if args.llm_use_lora:
+ args.freeze_llm = True
+
+ if args.vit_use_lora:
+ args.freeze_vit = True
+
+ dist_launcher = infer_launcher()
+ init_dist(dist_launcher)
+ set_random_seed(args.seed)
+
+ world_size = int(os.environ['WORLD_SIZE'])
+ dp_size = world_size
+
+ if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the world_size({world_size}).')
+
+ if (args.global_batch_size / dp_size) % args.mirco_batch_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the world_size({world_size}) * '
+ f'`mirco_batch_size`({args.mirco_batch_size}).')
+
+ # When packing data, it is essential to tokenize the data in advance
+ # and cache the tokenized result, so that subsequent runs can load it
+ # quickly without re-tokenizing.
+ if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir):
+ if len(os.listdir(args.dset_cache_dir)):
+ logger.warning(f'`{args.dset_cache_dir}` is not an empty '
+ 'folder, which may lead to inaccurate '
+ 'cache results.')
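+ # A typical two-phase workflow (hypothetical command sketch): the first
+ # run passes `--datasets ... --dset-cache-dir ./dset_cache` to tokenize
+ # and cache the data, and later runs add `--dset-from-cache` so the
+ # cached tokenized data in `./dset_cache` is loaded directly.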
+
+ device_mesh = init_device_mesh(
+ 'cuda', (dp_size, ), mesh_dim_names=('dp', ))
+
+ dp_mesh = device_mesh['dp']
+
+ rank = dp_mesh.get_local_rank()
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+
+ objects = [timestamp]
+ dist.broadcast_object_list(objects, src=0)
+ timestamp = objects[0]
+
+ args.work_dir = os.path.join(args.work_dir, timestamp)
+ mkdir_or_exist(args.work_dir)
+
+ log_file = os.path.join(args.work_dir, f'rank{rank}.log')
+
+ # Change the log format printed in the terminal
+ lvl = 'DEBUG' if args.debug else 'INFO'
+ logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
+ # Change the format saved in the log file
+ logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)
+
+ logger.info(args)
+ if rank == 0:
+ env = collect_env()
+ import transformers
+
+ import xtuner
+ env['Transformers'] = transformers.__version__
+ env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
+ runtime_env = OrderedDict()
+ runtime_env.update(env)
+ runtime_env['Seed'] = args.seed
+ runtime_env['World Size'] = world_size
+ runtime_env['Distributed launcher'] = dist_launcher
+
+ runtime_env_info = '\n ' + '\n '.join(
+ f'{k}: {v}' for k, v in runtime_env.items())
+ dash_line = '-' * 60
+ logger.info('\n' + dash_line + '\nRuntime environment:' +
+ runtime_env_info + '\n' + dash_line + '\n')
+
+ shutil.copy(__file__, args.work_dir)
+
+ # ------------------- Environment End ------------------------------ #
+
+ ###########################################################################
+ # 2. Dataset & Dataloader #
+ ###########################################################################
+
+ start_load_data_t = time.time()
+
+ chat_template = CHAT_TEMPLATE_MAP[args.chat_template]
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer if args.tokenizer else args.llm,
+ trust_remote_code=True,
+ padding_side='right')
+
+ register_remote_code()
+ _text_config = AutoConfig.from_pretrained(args.llm)
+ _vision_config = AutoConfig.from_pretrained(args.vit).vision_config
+ _img_processor = AutoProcessor.from_pretrained(args.vit).image_processor
+ processor = LlavaProcessor(_img_processor, tokenizer)
+ img_processor = processor.image_processor
+
+ _crop_size = img_processor.crop_size
+ patch_size = _vision_config.patch_size
+ img_size = (_crop_size['height'], _crop_size['width'])
+ per_img_tokens = (img_size[0] // patch_size) * (img_size[1] // patch_size)
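+ # e.g. for the default `openai/clip-vit-large-patch14-336` ViT
+ # (336x336 crops, patch size 14): per_img_tokens = (336 // 14) ** 2 = 576.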
+
+ img_token = chat_template.image_token
+ need_resize_emb = False
+ if len(tokenizer.encode(img_token, add_special_tokens=False)) > 1:
+ tokenizer.add_tokens([img_token], special_tokens=True)
+ img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+ logger.info(f'[Tokenizer] Added a new token `{img_token}`, '
+ f'token id is {img_token_id}, the new vocab size is '
+ f'{len(tokenizer)}')
+
+ _llm_vocab_size = _text_config.vocab_size
+
+ if _llm_vocab_size < len(tokenizer):
+ need_resize_emb = True
+ else:
+ img_token_id = tokenizer.convert_tokens_to_ids([img_token])[0]
+
+ if args.dset_from_cache:
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForLlava.from_cache,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ else:
+ init_fn = partial(
+ LlavaTokenizedDataset.from_cache,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ _datasets = load_from_cache(args.dset_cache_dir, init_fn)
+ dist.barrier()
+ else:
+ dset_infos = load(args.datasets)
+
+ sample_ratios = []
+ annotations = []
+ init_fns = []
+ tokenize_fns = []
+ for _, info in dset_infos.items():
+ if 'format' in info:
+ dset_format = info['format']
+ else:
+ dset_format = 'llava'
+
+ if 'image_dir' in info:
+ image_dir = info['image_dir']
+ else:
+ image_dir = None
+
+ # If your data format is not in `SUPPORT_DATA_FORMATS`, you should
+ # redefine a `tokenize_fn`, defining how to convert a piece of raw
+ # data into tokenized data.
+ # The tokenized data must include `input_ids`, `labels`,
+ # and `num_tokens`.
+ tokenize_fn = LlavaTokenizeFunction(tokenizer, chat_template,
+ per_img_tokens, image_dir,
+ dset_format)
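+ # The output of `tokenize_fn` for one raw sample is therefore a dict
+ # that contains at least (hypothetical length value):
+ # {'input_ids': [...], 'labels': [...], 'num_tokens': 601, ...}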
+
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForLlava,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ elif args.dset_cache_dir:
+ init_fn = partial(
+ LlavaTokenizedDataset,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ else:
+ init_fn = partial(
+ LlavaRawDataset,
+ image_processor=processor.image_processor,
+ tokenize_fn=tokenize_fn)
+ # Online tokenization is used when not using a pack dataset,
+ # saving startup time.
+ tokenize_fn = None
+
+ init_fns.append(init_fn)
+ tokenize_fns.append(tokenize_fn)
+ sample_ratios.append(info['sample_ratio'])
+ annotations.append(info['annotations'])
+
+ _datasets = load_datasets(
+ paths=annotations,
+ sources='local',
+ cache_dir=args.dset_cache_dir,
+ file_types=args.dset_file_types,
+ sample_ratios=sample_ratios,
+ num_proc=args.num_proc,
+ map_fns=tokenize_fns,
+ init_fns=init_fns)
+
+ if args.dset_pack_level and rank == 0:
+ # Only the tokenized datasets can count the number of tokens
+ total_tokens = sum(dset.total_tokens for dset in _datasets)
+ logger.debug(f'[Dataset] {total_tokens} tokens.')
+
+ train_dataset = ConcatDataset(_datasets)
+
+ if args.dset_pack_level and rank == 0:
+ ori_samples = sum([len(dset) for dset in _datasets])
+ packed_samples = len(train_dataset)
+ logger.info(f'[Dataset] (Original) {ori_samples} samples.')
+ logger.info(f'[Dataset] (Packed) {packed_samples} samples.')
+
+ pack_batch = is_flash_attn_2_available()
+ collator = LlavaCollator(pack_batch=pack_batch)
+
+ if args.group_by_length:
+ sampler = LengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size)
+ else:
+ sampler = ParallelSampler(
+ train_dataset, dp_mesh, args.global_batch_size, shuffle=True)
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.mirco_batch_size,
+ num_workers=args.num_workers,
+ sampler=sampler,
+ collate_fn=collator,
+ persistent_workers=args.num_workers > 0)
+
+ if rank == 0:
+ logger.info(f'[Dataloader] {len(train_dataloader)} batches.')
+ _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)]
+ _first_batch = collator(_first_batch)
+ _decoded = tokenizer.batch_decode(_first_batch['input_ids'])
+ logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}')
+ logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}')
+ dist.barrier()
+
+ load_data_cost_time = time.time() - start_load_data_t
+ logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')
+ # ------------------- Dataset & Dataloader End --------------------- #
+
+ ###########################################################################
+ # 3. FSDP #
+ ###########################################################################
+
+ start_model_t = time.time()
+
+ if args.dtype == 'auto':
+ args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'
+
+ if args.dtype == 'fp16':
+ dtype = torch.float16
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+ scaler = ShardedGradScaler()
+ elif args.dtype == 'bf16':
+ if torch.cuda.is_bf16_supported():
+ dtype = torch.bfloat16
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+ scaler = None
+ else:
+ raise RuntimeError('The device does not support `bf16`, '
+ 'please set `dtype` to `fp16`.')
+ else:
+ raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
+ f'but found {args.dtype}.')
+
+ use_lora = (args.llm_use_lora or args.vit_use_lora)
+ if not use_lora:
+ autocast = nullcontext()
+ scaler = None
+
+
+ if is_flash_attn_2_available():
+ _text_config.attn_implementation = 'flash_attention_2'
+ elif is_torch_sdpa_available():
+ _text_config.attn_implementation = 'sdpa'
+
+ _text_config.use_cache = False
+
+ llava_config = EnhancedLlavaConfig(
+ _vision_config, _text_config, image_token_index=img_token_id)
+
+ # model parameters must be in fp32.
+ # this ensures that all numerical values in the optimizer are in fp32.
+ # FSDP will use low precision during forward.
+ meta_llava = build_llava_model(args, llava_config, tokenizer, world_size,
+ 'meta', torch.float32, need_resize_emb)
+
+ if pack_batch or args.dset_pack_level:
+ dispatch_modules(meta_llava)
+
+ # Only load parameters on rank 0 to avoid each rank repeatedly loading the
+ # same model into the CPU, wasting memory
+ if rank == 0:
+
+ llava = build_llava_model(args, llava_config, tokenizer, world_size,
+ 'cpu', dtype, need_resize_emb)
+ rank0_meta_llava = copy.deepcopy(meta_llava)
+ meta_llava_map = map_meta_modules(llava, meta_llava)
+ else:
+ meta_llava_map = None
+
+ dist.barrier()
+
+ param_init_fn = partial(
+ dp_lazy_init, module_map=meta_llava_map, dp_mesh=dp_mesh)
+
+ policies = [layer_auto_wrap_policy]
+ if args.llm_use_lora or args.vit_use_lora:
+ policies.append(all_required_grad_wrap_policy)
+
+ torch.cuda.reset_peak_memory_stats()
+ shard_llava = FSDP(
+ meta_llava,
+ device_mesh=dp_mesh,
+ auto_wrap_policy=partial(_or_policy, policies=policies),
+ mixed_precision=MixedPrecision(
+ param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype),
+ device_id=torch.cuda.current_device(),
+ use_orig_params=True,
+ param_init_fn=param_init_fn,
+ sync_module_states=True,
+ )
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('The peak GPU memory when building the FSDP model is '
+ f'{max_memory/1024**3:.1f}GB.')
+
+ if args.selective_recompute:
+ check_fn = partial(
+ checkpoint_check_fn,
+ target=RECOMPUTE_MODULES,
+ selective=args.selective_recompute)
+ apply_activation_checkpointing(shard_llava, check_fn=check_fn)
+
+ fsdp_cost_time = time.time() - start_model_t
+ logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s')
+ # -------------------------- FSDP End ------------------------------ #
+
+ ###########################################################################
+ # 4. Optimizer & Scheduler #
+ ###########################################################################
+ required_grad_params = [
+ param for param in shard_llava.parameters() if param.requires_grad
+ ]
+ optimizer = AdamW(
+ required_grad_params, lr=args.lr, weight_decay=args.wd, fused=True)
+
+ global_batch_size = args.global_batch_size
+ mirco_batch_size = args.mirco_batch_size
+
+ # `iter` means one forward+backward pass
+ # `step` means one optimizer update
+ # `iters_per_step` is the number of gradient accumulation iterations
+ iters_per_step = global_batch_size // mirco_batch_size // dp_size
+ iters_per_epoch = len(train_dataloader)
+ steps_per_epoch = math.ceil(iters_per_epoch / iters_per_step)
+
+ total_epochs = args.epochs
+ total_steps = steps_per_epoch * total_epochs
+
+ if args.checkpoint_interval == -1:
+ checkpoint_interval = total_steps
+ elif args.checkpoint_interval < 1:
+ checkpoint_interval = int(total_steps * args.checkpoint_interval)
+ else:
+ checkpoint_interval = int(args.checkpoint_interval)
+
+ warmup_steps = int(args.warmup_ratio * total_steps)
+
+ def warmup_fn(x):
+ return x / warmup_steps if x < warmup_steps else 1
+
+ warmup_scheduler = LambdaLR(optimizer, warmup_fn)
+
+ cosine_scheduler = CosineAnnealingLR(
+ optimizer, T_max=total_steps - warmup_steps, eta_min=0)
+
+ start_step = 0
+
+ # ---------------- Optimizer & Scheduler End ----------------------- #
+
+ ###########################################################################
+ # 5. Training #
+ ###########################################################################
+
+ start_train_t = time.time()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Train] Begin Train Loop. The current GPU memory is '
+ f'{(max_memory / 1024**3):.1f}GB')
+ for step in range(start_step, total_steps):
+
+ epoch = step // steps_per_epoch
+ epoch_inner_step = step % steps_per_epoch
+ if epoch_inner_step == 0 or step == start_step:
+ # For the first step of each epoch, the data order needs to be
+ # readjusted.
+ # Or after resuming, for the first step, the dataloader needs to
+ # be adjusted to the position before resume.
+ # train_dataloader.sampler.set_epoch(epoch, inner_step)
+ # train_dataloader.sampler.set_epoch(epoch, epoch_inner_step)
+ train_dataloader.sampler.set_epoch(epoch)
+ data_iterator = iter(train_dataloader)
+
+ if step < warmup_steps:
+ warmup_scheduler.step()
+ cur_lr = warmup_scheduler.get_last_lr()[0]
+ else:
+ cosine_scheduler.step()
+ cur_lr = cosine_scheduler.get_last_lr()[0]
+
+ torch.cuda.reset_peak_memory_stats()
+
+ step_loss = 0
+ step_data_time = 0
+ step_start_t = time.time()
+ step_consumed_tokens = 0
+ step_consumed_img_tokens = 0
+ for _ in range(iters_per_step):
+
+ _data_start_t = time.time()
+ data = next(data_iterator)
+ step_data_time += time.time() - _data_start_t
+
+ input_ids = data['input_ids'].cuda()
+ pixel_values = data['pixel_values']
+ if isinstance(pixel_values, torch.Tensor):
+ pixel_values = pixel_values.cuda()
+ labels = data['labels'].cuda()
+ attention_mask = data['attention_mask'].cuda()
+ num_tokens = data['num_tokens'].cuda()
+ num_img_tokens = data['num_img_tokens'].cuda()
+
+ packed_ctx = packed_sequence(num_tokens, enable=pack_batch)
+
+ with packed_ctx:
+ with autocast if use_lora else nullcontext():
+ outputs = shard_llava(
+ input_ids=input_ids,
+ labels=labels,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask)
+ avg_iter_loss = outputs.loss / iters_per_step
+
+ if scaler and use_lora:
+ scaler.scale(avg_iter_loss).backward()
+ else:
+ avg_iter_loss.backward()
+
+ step_loss += avg_iter_loss.item()
+ step_consumed_tokens += num_tokens.sum()
+ step_consumed_img_tokens += num_img_tokens.sum()
+
+ grad_norm = shard_llava.clip_grad_norm_(args.max_grad_norm)
+ if scaler and use_lora:
+ # step via the grad scaler so the fp16 gradients are unscaled first
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ step_text_tokens = step_consumed_tokens - step_consumed_img_tokens
+ step_img_tokens = step_consumed_img_tokens
+ step_time = time.time() - step_start_t
+ eta = step_time * (total_steps - step)
+ eta = timedelta(seconds=int(eta))
+ tgs = int(step_consumed_tokens / step_time)
+ max_memory = torch.cuda.max_memory_allocated()
+ if is_interval(step, total_steps, args.log_interval):
+ logger.info(
+ f'[Train] (Epoch {epoch}) Step {step+1}/{total_steps} ' # noqa: E501
+ f'lr: {cur_lr:.6f} loss: {step_loss:.3f} '
+ f'grad_norm: {grad_norm:.2f} '
+ f'max_memory: {(max_memory / 1024**3):.1f}GB '
+ f'text_tokens: {step_text_tokens} '
+ f'image_tokens: {step_img_tokens} '
+ f'tgs: {tgs} data_time: {step_data_time:.2f}s '
+ f'time: {step_time:.2f}s '
+ f'eta: {eta}')
+
+ if is_interval(step, total_steps, checkpoint_interval):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Checkpoint] Before saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ num_digits = len(str(abs(total_steps)))
+ work_dir = args.work_dir
+ ckpt_dir = os.path.join(work_dir, f'ckpt-{step:0{num_digits}}')
+ hf_dir = os.path.join(work_dir, f'hf-{step:0{num_digits}}')
+ _options = StateDictOptions(cpu_offload=True, full_state_dict=True)
+
+ full_model_state_dict = get_model_state_dict(
+ shard_llava, options=_options)
+ if rank == 0:
+ saved_llava = copy.deepcopy(rank0_meta_llava)
+ saved_llava.to(dtype)
+ for name, param in full_model_state_dict.items():
+ set_module_tensor_to_device(saved_llava, name, 'cpu',
+ param)
+
+ if args.llm_use_lora:
+ merged_llm = saved_llava.language_model.merge_and_unload()
+ saved_llava.language_model = merged_llm
+
+ if args.vit_use_lora:
+ merged_vit = saved_llava.vision_tower.merge_and_unload()
+ saved_llava.vision_tower = merged_vit
+
+ saved_llava.save_pretrained(hf_dir)
+ tokenizer.save_pretrained(hf_dir)
+ processor.save_pretrained(hf_dir)
+ del saved_llava
+
+ dist.barrier()
+ del full_model_state_dict
+
+ if args.checkpoint_drop_optimizer:
+ logger.warning('[Checkpoint] The saved checkpoint cannot be '
+ 'resumed. If you want to save a resumable '
+ 'checkpoint, please remove '
+ '`--checkpoint-drop-optimizer` '
+ 'from the command.')
+ else:
+ # FSDP cannot be saved via torch.save
+ # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501
+ _options = StateDictOptions(
+ cpu_offload=True, ignore_frozen_params=True)
+ (shard_model_state_dict,
+ shard_optimizer_state_dict) = get_state_dict(
+ shard_llava, optimizer, options=_options)
+
+ state_dict = {
+ 'model': shard_model_state_dict,
+ 'optimizer': shard_optimizer_state_dict,
+ 'step': step,
+ 'total_steps': total_steps,
+ 'warmup_scheduler': warmup_scheduler.state_dict(),
+ 'cosine_scheduler': cosine_scheduler.state_dict()
+ }
+
+ writer = dcp.FileSystemWriter(ckpt_dir)
+ mkdir_or_exist(ckpt_dir)
+ dcp.save(state_dict, writer)
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info(
+ '[Checkpoint] During saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ train_cost_time = time.time() - start_train_t
+ logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}')
+ # ------------------------ Training End ---------------------------- #
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ llava_pretrain(args)
diff --git a/tools/llava/llava_sft.py b/tools/llava/llava_sft.py
new file mode 100644
index 000000000..c207dabfc
--- /dev/null
+++ b/tools/llava/llava_sft.py
@@ -0,0 +1,890 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import math
+import os
+import shutil
+import sys
+import time
+from collections import OrderedDict
+from contextlib import nullcontext
+from datetime import datetime, timedelta
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+from accelerate.utils import set_module_tensor_to_device
+from datasets import Dataset
+from mmengine import load, mkdir_or_exist
+from mmengine.dist import infer_launcher, init_dist
+from mmengine.runner import set_random_seed
+from mmengine.utils import get_git_hash
+from mmengine.utils.dl_utils import collect_env
+from peft import LoraConfig, get_peft_model
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+ apply_activation_checkpointing
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+ get_model_state_dict,
+ get_state_dict)
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import _or_policy
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import (AutoConfig, AutoProcessor)
+from transformers.utils.import_utils import (is_flash_attn_2_available,
+ is_torch_sdpa_available)
+from xtuner._lite.modelings.llava import LlavaForConditionalGeneration
+from xtuner._lite import AutoTokenizer, get_logger
+from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_modules,
+ packed_sequence)
+from xtuner._lite.chat import CHAT_TEMPLATE_MAP
+from xtuner._lite.datasets import (LlavaCollator,
+ LlavaRawDataset, LlavaTokenizedDataset,
+ LlavaTokenizeFunction, SoftPackerForLlava)
+from xtuner._lite.datasets.load import (LOAD_FN_MAP, load_datasets,
+ load_from_cache)
+from xtuner._lite.modelings import register_remote_code
+from xtuner._lite.parallel import LengthGroupedSampler, ParallelSampler
+from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
+ all_required_grad_wrap_policy,
+ checkpoint_check_fn, dp_lazy_init,
+ layer_auto_wrap_policy)
+
+logger = get_logger()
+
+
+def log_format(rank, debug=False):
+
+ formatter = f'[XTuner][RANK {rank}]'
+ formatter += '[{time:YYYY-MM-DD HH:mm:ss}][{level}]'
+
+ if debug:
+ formatter += '[{name}:'
+ formatter += '{function}:'
+ formatter += '{line}]'
+
+ formatter += ' {message}'
+ return formatter
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train LLM')
+
+ model_args = parser.add_argument_group('model', 'Model Related Settings')
+ model_args.add_argument(
+ '--llava', help='repo id or local path of the model')
+ model_args.add_argument(
+ '-t',
+ '--tokenizer',
+ help=('repo id or local path of the tokenizer. '
+ 'Defaults to the same as `model`'))
+ model_args.add_argument(
+ '--chat-template',
+ choices=CHAT_TEMPLATE_MAP.keys(),
+ help=('the chat template used during training. '
+ 'Must be one of the keys of `CHAT_TEMPLATE_MAP`'))
+ model_args.add_argument(
+ '--freeze-llm',
+ action='store_true',
+ help="Not updating LLM's parameters")
+ model_args.add_argument(
+ '--freeze-vit',
+ action='store_true',
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--llm-use-lora',
+ action='store_true',
+ help='Apply the adapter to LLM.')
+ model_args.add_argument(
+ '--llm-lora-targets',
+ default=None,
+ nargs='*',
+ help='The names of the modules to apply the adapter to. ')
+ model_args.add_argument(
+ '--llm-lora-r',
+ default=64,
+ type=int,
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--llm-lora-alpha',
+ default=16,
+ type=int,
+ help='The alpha parameter for Lora scaling.')
+ model_args.add_argument(
+ '--llm-lora-dropout',
+ default=0.1,
+ type=float,
+ help='The dropout probability for Lora layers.')
+ model_args.add_argument(
+ '--llm-lora-bias',
+ default='none',
+ help='The bias type for the LLM LoRA layers.')
+ model_args.add_argument(
+ '--vit-use-lora',
+ action='store_true',
+ help='Apply the adapter to Vit.')
+ model_args.add_argument(
+ '--vit-lora-targets',
+ default=None,
+ type=str,
+ help='The names of the modules to apply the adapter to. ')
+ model_args.add_argument(
+ '--vit-lora-r',
+ default=64,
+ type=int,
+ help="Not updating vit's parameters")
+ model_args.add_argument(
+ '--vit-lora-alpha',
+ default=16,
+ type=int,
+ help='The alpha parameter for vit Lora scaling.')
+ model_args.add_argument(
+ '--vit-lora-dropout',
+ default=0.1,
+ type=float,
+ help='The dropout probability for vit Lora layers.')
+ model_args.add_argument(
+ '--vit-lora-bias',
+ default='none',
+ help='The bias type for the ViT LoRA layers.')
+ model_args.add_argument(
+ '--dtype',
+ default='auto',
+ choices=['fp16', 'bf16', 'auto'],
+ help=("the dtype of the model forward. When set to 'auto', it will "
+ 'automatically determine whether bf16 is available, '
+ 'prioritizing the use of bf16.'))
+
+ model_args.add_argument(
+ '--selective-recompute',
+ default=1.0,
+ type=float,
+ help=('the ratio of re-computation for transformer layers. '
+ 'The maximum is 1; the larger the value, the less memory '
+ 'required for training. The default is 1, meaning all layers '
+ 'need to be re-computed.'))
+
+ data_args = parser.add_argument_group('data', 'Dataset Related Settings')
+ data_args.add_argument(
+ '--datasets',
+ help=('repo id or local path or dir of the datasets. For repo ids, '
+ 'the `dset-sources` needs to be appropriately set to '
+ '`modelscope` or `huggingface`. For local dir, all json and '
+ 'jsonl files will be loaded by default. The type of loaded '
+ 'files can be controlled by setting `dset-file-type`'))
+ data_args.add_argument(
+ '--dset-file-types',
+ nargs='*',
+ default=LOAD_FN_MAP.keys(),
+ choices=LOAD_FN_MAP.keys(),
+ help='the file type that needs to be loaded')
+ data_args.add_argument(
+ '--dset-cache-dir',
+ help=('the cache dir of the loaded datasets. When the `datasets` is '
+ 'set, the loaded datasets will be cached to this dir. If the '
+ '`datasets` are not set, the cached dataset in this dir will be '
+ 'loaded.'))
+ data_args.add_argument(
+ '--dset-from-cache',
+ action='store_true',
+ help=('Load data directly from `dset-cache-dir`. This can save time '
+ 'on online tokenization, but if the tokenizer changed, '
+ 'recaching is needed.'))
+ data_args.add_argument(
+ '--dset-pack-level',
+ choices=['soft'],
+ help=('the level of data packing. When `hard`, multiple data will be '
+ 'packed to `max_length`, potentially causing some data to be '
+ 'truncated, and the length of the packed data will always '
+ 'be `max_length`; When `soft`, it will pack multiple data '
+ 'into nearly `max_length` without truncating the data.'))
+ data_args.add_argument('--group-by-length', action='store_true')
+ data_args.add_argument(
+ '--max-length',
+ type=int,
+ default=2048,
+ help=('the maximum length of each piece of data, any excess will be '
+ 'truncated.'))
+ data_args.add_argument(
+ '--num-workers',
+ type=int,
+ default=1,
+ help='how many subprocesses to use for data loading.')
+ data_args.add_argument(
+ '--num-proc',
+ type=int,
+ default=8,
+ help='how many subprocesses to use for data mapping.')
+
+ optim_args = parser.add_argument_group('optim', 'Optim Related Settings')
+ optim_args.add_argument(
+ '--mirco-batch-size',
+ type=int,
+ default=1,
+ help='batch size for each forward + backward pass')
+ optim_args.add_argument(
+ '--global-batch-size',
+ type=int,
+ default=16,
+ help='batch size for each parameter update')
+
+ optim_args.add_argument(
+ '--lr', default=4e-5, type=float, help='learning rate.')
+ optim_args.add_argument(
+ '--wd', default=0.01, type=float, help='weight decay.')
+ optim_args.add_argument(
+ '--max-grad-norm', default=1, type=float, help='gradient clipping')
+ optim_args.add_argument(
+ '-e', '--epochs', default=1, type=int, help='total training epochs.')
+ optim_args.add_argument(
+ '--warmup-ratio',
+ default=0.03,
+ type=float,
+ help=('the proportion of training steps for learning rate warm-up in '
+ 'relation to the total training steps.'))
+
+ parser.add_argument('-c', '--config', default=None)
+ parser.add_argument(
+ '--work-dir',
+ default='work_dirs',
+ help='the dir to save logs and checkpoints')
+ parser.add_argument(
+ '--checkpoint-interval',
+ default=0.25,
+ type=float,
+ help=('how many steps to save a checkpoint; it can be a floating '
+ 'point number less than 1, or an integer greater than or equal '
+ "to 1. When it's a floating point, it will be multiplied by the "
+ 'total number of training steps.'))
+ parser.add_argument(
+ '--checkpoint-drop-optimizer',
+ action='store_true',
+ help=('only model parameters are saved when saving a checkpoint. '
+ 'This can significantly reduce the size of checkpoint files, '
+ 'but the saved checkpoints cannot be resumed.'))
+ parser.add_argument(
+ '--log-interval', default=1, type=int, help='log interval')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='specify checkpoint path to be resumed from.')
+ parser.add_argument(
+ '--seed', type=int, default=0, help='random seed for the training')
+ parser.add_argument(
+ '--debug', action='store_true', help='Set logger level to `DEBUG`')
+ args = parser.parse_args()
+ return args
+
+
+def is_interval(step, total_steps, interval):
+ return (step + 1) % interval == 0 or (step + 1) == total_steps
+
+
+def map_meta_modules(model, meta_model):
+ modules = {name: mod for name, mod in model.named_modules()}
+ meta_module_map = {
+ mod: modules[name]
+ for name, mod in meta_model.named_modules()
+ }
+ return meta_module_map
+
+
+def build_llava_model(args, config, world_size, dtype=torch.float32):
+ _cfg = copy.deepcopy(config)
+
+ with LoadWoInit():
+ llava = LlavaForConditionalGeneration.from_pretrained(
+ args.llava, config=_cfg)
+
+ llava.to(dtype)
+
+ if args.freeze_llm or args.llm_use_lora:
+ llava.language_model.requires_grad_(False)
+ if world_size > 1:
+ llava.language_model.to(dtype)
+
+ if args.freeze_vit or args.vit_use_lora:
+ llava.vision_tower.requires_grad_(False)
+ if world_size > 1:
+ llava.vision_tower.to(dtype)
+
+ if args.llm_use_lora:
+ llm = llava.language_model
+ if args.llm_lora_targets is None:
+ llm_cls = llm.__class__.__name__
+ args.llm_lora_targets = LORA_TARGET_MAP[llm_cls]
+ llm_lora_cfg = LoraConfig(
+ target_modules=args.llm_lora_targets,
+ r=args.llm_lora_r,
+ lora_alpha=args.llm_lora_alpha,
+ lora_dropout=args.llm_lora_dropout,
+ bias=args.llm_lora_bias,
+ task_type='CAUSAL_LM')
+ lora_llm = get_peft_model(llm, llm_lora_cfg)
+ llava.language_model = lora_llm
+
+ if args.vit_use_lora:
+ vit = llava.vision_tower
+ if args.vit_lora_targets is None:
+ vit_cls = vit.__class__.__name__
+ args.vit_lora_targets = LORA_TARGET_MAP[vit_cls]
+ vit_lora_cfg = LoraConfig(
+ target_modules=args.vit_lora_targets,
+ r=args.vit_lora_r,
+ lora_alpha=args.vit_lora_alpha,
+ lora_dropout=args.vit_lora_dropout,
+ bias=args.vit_lora_bias,
+ )
+ llava.vision_tower = get_peft_model(vit, vit_lora_cfg)
+
+ return llava
+
+
+# @logger.catch
+def llava_sft(args):
+ ###########################################################################
+ # 1. Environment #
+ ###########################################################################
+ if args.llm_use_lora:
+ args.freeze_llm = True
+
+ if args.vit_use_lora:
+ args.freeze_vit = True
+
+ dist_launcher = infer_launcher()
+ init_dist(dist_launcher)
+ set_random_seed(args.seed)
+
+ world_size = int(os.environ['WORLD_SIZE'])
+ dp_size = world_size
+
+ if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the world_size({world_size}).')
+
+ if (args.global_batch_size / dp_size) % args.mirco_batch_size:
+ raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
+ f'should be divisible by the world_size({world_size}) * '
+ f'`mirco_batch_size`({args.mirco_batch_size}).')
+
+ # When packing data, it is essential to tokenize the data in advance
+ # and cache the tokenized result, so that subsequent runs can load it
+ # quickly without re-tokenizing.
+ if args.dset_cache_dir and os.path.isdir(args.dset_cache_dir):
+ if len(os.listdir(args.dset_cache_dir)):
+ logger.warning(f'`{args.dset_cache_dir}` is not an empty '
+ 'folder, which may lead to inaccurate '
+ 'cache results.')
+
+ device_mesh = init_device_mesh(
+ 'cuda', (dp_size, ), mesh_dim_names=('dp', ))
+
+ dp_mesh = device_mesh['dp']
+
+ rank = dp_mesh.get_local_rank()
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
+
+ objects = [timestamp]
+ dist.broadcast_object_list(objects, src=0)
+ timestamp = objects[0]
+
+ args.work_dir = os.path.join(args.work_dir, timestamp)
+ mkdir_or_exist(args.work_dir)
+
+ log_file = os.path.join(args.work_dir, f'rank{rank}.log')
+
+ # Change the log format printed in the terminal
+ lvl = 'DEBUG' if args.debug else 'INFO'
+ logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
+ # Change the format saved in the log file
+ logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)
+
+ logger.info(args)
+ if rank == 0:
+ env = collect_env()
+ import transformers
+
+ import xtuner
+ env['Transformers'] = transformers.__version__
+ env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
+ runtime_env = OrderedDict()
+ runtime_env.update(env)
+ runtime_env['Seed'] = args.seed
+ runtime_env['World Size'] = world_size
+ runtime_env['Distributed launcher'] = dist_launcher
+
+ runtime_env_info = '\n ' + '\n '.join(
+ f'{k}: {v}' for k, v in runtime_env.items())
+ dash_line = '-' * 60
+ logger.info('\n' + dash_line + '\nRuntime environment:' +
+ runtime_env_info + '\n' + dash_line + '\n')
+
+ shutil.copy(__file__, args.work_dir)
+ # ------------------- Environment End ------------------------------ #
+
+ ###########################################################################
+ # 2. Dataset & Dataloader #
+ ###########################################################################
+
+ start_load_data_t = time.time()
+
+ chat_template = CHAT_TEMPLATE_MAP[args.chat_template]
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer if args.tokenizer else args.llava,
+ trust_remote_code=True,
+ padding_side='right')
+
+ llava_config = AutoConfig.from_pretrained(args.llava)
+
+ processor = AutoProcessor.from_pretrained(
+ args.llava, trust_remote_code=True)
+ img_processor = processor.image_processor
+
+ _crop_size = processor.image_processor.crop_size
+ patch_size = llava_config.vision_config.patch_size
+ img_size = (_crop_size['height'], _crop_size['width'])
+ per_img_tokens = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+
+ img_token = chat_template.image_token
+ assert len(tokenizer.convert_tokens_to_ids([img_token])) == 1
+
+ if args.dset_from_cache:
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForLlava.from_cache,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ else:
+ init_fn = partial(
+ LlavaTokenizedDataset.from_cache,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ _datasets = load_from_cache(args.dset_cache_dir, init_fn)
+ dist.barrier()
+ else:
+ dset_infos = load(args.datasets)
+
+ sample_ratios = []
+ annotations = []
+ init_fns = []
+ tokenize_fns = []
+ for _, info in dset_infos.items():
+ if 'format' in info:
+ dset_format = info['format']
+ else:
+ dset_format = 'llava'
+
+ if 'image_dir' in info:
+ image_dir = info['image_dir']
+ else:
+ image_dir = None
+
+ # If your data format is not in `SUPPORT_DATA_FORMATS`, you should
+ # redefine a `tokenize_fn`, defining how to convert a piece of raw
+ # data into tokenized data.
+            # The tokenized data must include `input_ids`, `labels`,
+            # and `num_tokens`.
+ tokenize_fn = LlavaTokenizeFunction(tokenizer, chat_template,
+ per_img_tokens, image_dir,
+ dset_format)
+
+ if args.dset_pack_level == 'soft':
+ init_fn = partial(
+ SoftPackerForLlava,
+ image_processor=img_processor,
+ max_length=args.max_length)
+ else:
+ init_fn = partial(
+ LlavaTokenizedDataset,
+ image_processor=img_processor,
+ max_length=args.max_length)
+
+ init_fns.append(init_fn)
+ tokenize_fns.append(tokenize_fn)
+ sample_ratios.append(info['sample_ratio'])
+ annotations.append(info['annotations'])
+
+ _datasets = load_datasets(
+ paths=annotations,
+ sources='local',
+ cache_dir=args.dset_cache_dir,
+ file_types=args.dset_file_types,
+ sample_ratios=sample_ratios,
+ num_proc=args.num_proc,
+ map_fns=tokenize_fns,
+ init_fns=init_fns)
+
+ if args.dset_pack_level and rank == 0:
+ # Only the tokenized datasets can count the number of tokens
+ total_tokens = sum(dset.total_tokens for dset in _datasets)
+ logger.debug(f'[Dataset] {total_tokens} tokens.')
+
+ train_dataset = ConcatDataset(_datasets)
+
+ if args.dset_pack_level and rank == 0:
+ ori_samples = sum([dset.num_samples for dset in _datasets])
+ packed_samples = len(train_dataset)
+ logger.info(f'[Dataset] (Original) {ori_samples} samples.')
+ logger.info(f'[Dataset] (Packed) {packed_samples} samples.')
+
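+    # Packing several samples into one sequence at collate time relies on
+    # varlen flash attention, so it is only enabled when flash_attn 2 is
+    # available.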
+ pack_batch = is_flash_attn_2_available()
+ collator = LlavaCollator(pack_batch=pack_batch)
+
+ if args.group_by_length:
+ sampler = LengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size)
+ else:
+ sampler = ParallelSampler(
+ train_dataset, dp_mesh, args.global_batch_size, shuffle=False)
+
+ dist.barrier()
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.mirco_batch_size,
+ num_workers=args.num_workers,
+ sampler=sampler,
+ collate_fn=collator,
+ persistent_workers=args.num_workers > 0)
+
+ if rank == 0:
+ logger.info(f'[Dataloader] {len(train_dataloader)} batches.')
+ _first_batch = [train_dataset[i] for i in range(args.mirco_batch_size)]
+ _first_batch = collator(_first_batch)
+ _decoded = tokenizer.batch_decode(_first_batch['input_ids'])
+ logger.debug(f'[Dataloader] Training Batch:\n{_first_batch}')
+ logger.debug(f'[Dataloader] Training Batch(Decoded):\n{_decoded}')
+ dist.barrier()
+
+ load_data_cost_time = time.time() - start_load_data_t
+ logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')
+ # ------------------- Dataset & Dataloader End --------------------- #
+
+ ###########################################################################
+ # 3. FSDP #
+ ###########################################################################
+
+ start_model_t = time.time()
+
+ if args.dtype == 'auto':
+ args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'
+
+ if args.dtype == 'fp16':
+ dtype = torch.float16
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+ scaler = ShardedGradScaler()
+ elif args.dtype == 'bf16':
+ if torch.cuda.is_bf16_supported():
+ dtype = torch.bfloat16
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+ scaler = None
+ else:
+ raise RuntimeError('The device does not support `bf16`, '
+ 'please set `dtype` to `fp16`.')
+ else:
+        raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
+                           f'but found {args.dtype}.')
+
+ use_lora = args.llm_use_lora or args.vit_use_lora
+ if not use_lora:
+ autocast = nullcontext()
+ scaler = None
+
+ if is_flash_attn_2_available():
+ llava_config.text_config.attn_implementation = 'flash_attention_2'
+ elif is_torch_sdpa_available():
+ llava_config.text_config.attn_implementation = 'sdpa'
+ llava_config.text_config.use_cache = False
+
+ with torch.device('meta'):
+ # Ensure all numerical values in the optimizer are fp32.
+ # FSDP will use low precision during forward.
+ meta_llava = build_llava_model(args, llava_config, world_size,
+ torch.float32)
+
+ if pack_batch or args.dset_pack_level:
+ dispatch_modules(meta_llava)
+
+    # Only rank 0 loads the real weights, so that each rank does not keep a
+    # full copy of the model in CPU memory at the same time.
+ if rank == 0:
+ with torch.device('cpu'):
+ llava = build_llava_model(args, llava_config, world_size, dtype)
+ rank0_meta_llava = copy.deepcopy(meta_llava)
+ meta_llava_map = map_meta_modules(llava, meta_llava)
+ else:
+ meta_llava_map = None
+
+ dist.barrier()
+
+ param_init_fn = partial(
+ dp_lazy_init, module_map=meta_llava_map, dp_mesh=dp_mesh)
+
+ policies = [layer_auto_wrap_policy]
+ if args.llm_use_lora or args.vit_use_lora:
+ policies.append(all_required_grad_wrap_policy)
+
+ torch.cuda.reset_peak_memory_stats()
+ shard_llava = FSDP(
+ meta_llava,
+ device_mesh=dp_mesh,
+ auto_wrap_policy=partial(_or_policy, policies=policies),
+ mixed_precision=MixedPrecision(
+ param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype),
+ device_id=torch.cuda.current_device(),
+ use_orig_params=True,
+ param_init_fn=param_init_fn,
+ sync_module_states=True,
+ )
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Model] The peak GPU memory when building the FSDP model is '
+ f'{max_memory/1024**3:.1f}GB.')
+
+ if args.selective_recompute:
+ check_fn = partial(
+ checkpoint_check_fn,
+ target=RECOMPUTE_MODULES,
+ selective=args.selective_recompute)
+ apply_activation_checkpointing(shard_llava, check_fn=check_fn)
+
+ fsdp_cost_time = time.time() - start_model_t
+ logger.info(f'[Model] Cost {fsdp_cost_time:.2f}s')
+ # -------------------------- FSDP End ------------------------------ #
+
+ ###########################################################################
+ # 4. Optimizer & Scheduler #
+ ###########################################################################
+    required_grad_params = [
+        param for param in shard_llava.parameters() if param.requires_grad
+    ]
+    optimizer = AdamW(
+        required_grad_params, lr=args.lr, weight_decay=args.wd, fused=True)
+
+ global_batch_size = args.global_batch_size
+ mirco_batch_size = args.mirco_batch_size
+
+    # An `iter` is one forward+backward pass on a micro batch; a `step` is
+    # one optimizer update. `per_step_iters` is therefore the number of
+    # gradient-accumulation iters per optimizer step.
+ per_step_iters = global_batch_size // mirco_batch_size // dp_size
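+    # e.g. with `global_batch_size` 64, `mirco_batch_size` 4 and 4 data
+    # parallel ranks, each rank accumulates 64 // 4 // 4 = 4 iters per step.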
+ per_epoch_iters = len(train_dataloader)
+ per_epoch_steps = math.ceil(per_epoch_iters / per_step_iters)
+
+ total_epochs = args.epochs
+ total_steps = per_epoch_steps * total_epochs
+
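+    # `checkpoint_interval` can be -1 (save only at the last step), a
+    # fraction in (0, 1) (a ratio of `total_steps`), or an absolute number
+    # of steps.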
+ if args.checkpoint_interval == -1:
+ checkpoint_interval = total_steps
+ elif args.checkpoint_interval < 1:
+ checkpoint_interval = int(total_steps * args.checkpoint_interval)
+ else:
+ checkpoint_interval = int(args.checkpoint_interval)
+
+ warmup_steps = int(args.warmup_ratio * total_steps)
+
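+    # Linear warmup: ramp the LR factor from 0 to 1 over `warmup_steps`;
+    # afterwards the cosine schedule below takes over.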
+ def warmup_fn(x):
+ return x / warmup_steps if x < warmup_steps else 1
+
+ warmup_scheduler = LambdaLR(optimizer, warmup_fn)
+
+ cosine_scheduler = CosineAnnealingLR(
+ optimizer, T_max=total_steps - warmup_steps, eta_min=0)
+
+ start_step = 0
+
+ # ---------------- Optimizer & Scheduler End ----------------------- #
+
+ ###########################################################################
+ # 5. Training #
+ ###########################################################################
+
+ start_train_t = time.time()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Train] Begin Train Loop. The current GPU memory is '
+ f'{(max_memory / 1024**3):.1f}GB')
+ for step in range(start_step, total_steps):
+
+ epoch = step // per_epoch_steps
+ epoch_inner_step = step % per_epoch_steps
+ if epoch_inner_step == 0 or step == start_step:
+            # At the first step of each epoch, reshuffle the data order.
+            # When resuming, the sampler also needs to be fast-forwarded to
+            # the position reached before the interruption.
+            train_dataloader.sampler.set_epoch(epoch, epoch_inner_step)
+ data_iterator = iter(train_dataloader)
+
+        if step < warmup_steps:
+            warmup_scheduler.step()
+            cur_lr = warmup_scheduler.get_last_lr()[0]
+        else:
+            cosine_scheduler.step()
+            cur_lr = cosine_scheduler.get_last_lr()[0]
+
+ torch.cuda.reset_peak_memory_stats()
+
+ step_loss = 0
+ step_data_time = 0
+ step_start_t = time.time()
+ step_consumed_tokens = 0
+ step_consumed_img_tokens = 0
+ for _ in range(per_step_iters):
+
+ _data_start_t = time.time()
+ data = next(data_iterator)
+ step_data_time += time.time() - _data_start_t
+
+ input_ids = data['input_ids'].cuda()
+ pixel_values = data['pixel_values']
+ if isinstance(pixel_values, torch.Tensor):
+ pixel_values = pixel_values.cuda()
+ labels = data['labels'].cuda()
+ attention_mask = data['attention_mask'].cuda()
+ num_tokens = data['num_tokens'].cuda()
+ num_img_tokens = data['num_img_tokens'].cuda()
+
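+            # When batch packing is enabled, this context publishes the
+            # per-sample token counts (via the `packed_sequence` MessageHub)
+            # so that the dispatched varlen attention kernels can read them.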
+ packed_ctx = packed_sequence(num_tokens, enable=pack_batch)
+
+ with packed_ctx:
+ with autocast if use_lora else nullcontext():
+ outputs = shard_llava(
+ input_ids=input_ids,
+ labels=labels,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask)
+ avg_iter_loss = outputs.loss / per_step_iters
+
+ if scaler and use_lora:
+ scaler.scale(avg_iter_loss).backward()
+ else:
+ avg_iter_loss.backward()
+
+ step_loss += avg_iter_loss.item()
+ step_consumed_img_tokens += num_img_tokens.sum()
+
+ if args.dset_pack_level == 'soft':
+                # With soft packing, a packed sample that is still shorter
+                # than `max_length` is padded up to it, and the last entry of
+                # `num_tokens` records the number of pad tokens, so it is
+                # excluded from the consumed-token count.
+ step_consumed_tokens += num_tokens[:-1].sum()
+ else:
+ step_consumed_tokens += num_tokens.sum()
+
+        if scaler:
+            # fp16 + LoRA path: unscale before clipping so the threshold
+            # applies to the true gradients, then let the scaler apply or
+            # skip the update on overflow.
+            scaler.unscale_(optimizer)
+        grad_norm = shard_llava.clip_grad_norm_(args.max_grad_norm)
+        if scaler:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()
+        optimizer.zero_grad()
+
+ step_text_tokens = step_consumed_tokens - step_consumed_img_tokens
+ step_img_tokens = step_consumed_img_tokens
+ step_time = time.time() - step_start_t
+ eta = step_time * (total_steps - step)
+ eta = timedelta(seconds=int(eta))
+ tgs = int(step_consumed_tokens / step_time)
+ max_memory = torch.cuda.max_memory_allocated()
+ if is_interval(step, total_steps, args.log_interval):
+ logger.info(
+ f'[Train] (Epoch {epoch}) Step {step+1}/{total_steps} ' # noqa: E501
+ f'lr: {cur_lr:.6f} loss: {step_loss:.3f} '
+ f'grad_norm: {grad_norm:.2f} '
+ f'max_memory: {(max_memory / 1024**3):.1f}GB '
+ f'text_tokens: {step_text_tokens} '
+ f'image_tokens: {step_img_tokens} '
+ f'tgs: {tgs} data_time: {step_data_time:.2f}s '
+ f'time: {step_time:.2f}s '
+ f'eta: {eta}')
+
+ if is_interval(step, total_steps, checkpoint_interval):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info('[Checkpoint] Before saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ digits = len(str(abs(total_steps)))
+ work_dir = args.work_dir
+
+ ckpt_id = f'{(step+1):0{digits}}-of-{total_steps:0{digits}}'
+ ckpt_dir = os.path.join(work_dir, f'ckpt-{ckpt_id}')
+ hf_dir = os.path.join(work_dir, f'hf-{ckpt_id}')
+ _options = StateDictOptions(cpu_offload=True, full_state_dict=True)
+
+ full_model_state_dict = get_model_state_dict(
+ shard_llava, options=_options)
+ if rank == 0:
+ saved_llava = copy.deepcopy(rank0_meta_llava)
+ saved_llava.to(dtype)
+ for name, param in full_model_state_dict.items():
+ set_module_tensor_to_device(saved_llava, name, 'cpu',
+ param)
+
+ if args.llm_use_lora:
+ merged_llm = saved_llava.language_model.merge_and_unload()
+ saved_llava.language_model = merged_llm
+
+ if args.vit_use_lora:
+ merged_vit = saved_llava.vision_tower.merge_and_unload()
+ saved_llava.vision_tower = merged_vit
+
+ saved_llava.save_pretrained(hf_dir)
+ tokenizer.save_pretrained(hf_dir)
+ processor.save_pretrained(hf_dir)
+ del saved_llava
+
+ dist.barrier()
+ del full_model_state_dict
+
+ if args.checkpoint_drop_optimizer:
+ logger.warning('[Checkpoint] The saved checkpoint cannot be '
+ 'resumed. If you want to save a resumable '
+ 'checkpoint, please remove '
+ '`--checkpoint-drop-optimizer` '
+ 'from the command.')
+ else:
+ # FSDP cannot be saved via torch.save
+ # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501
+ _options = StateDictOptions(
+ cpu_offload=True, ignore_frozen_params=True)
+ (shard_model_state_dict,
+ shard_optimizer_state_dict) = get_state_dict(
+ shard_llava, optimizer, options=_options)
+
+ state_dict = {
+ 'model': shard_model_state_dict,
+ 'optimizer': shard_optimizer_state_dict,
+ 'step': step,
+ 'total_steps': total_steps,
+ 'warmup_scheduler': warmup_scheduler.state_dict(),
+ 'cosine_scheduler': cosine_scheduler.state_dict()
+ }
+
+ writer = dcp.FileSystemWriter(ckpt_dir)
+ mkdir_or_exist(ckpt_dir)
+ dcp.save(state_dict, writer)
+
+ max_memory = torch.cuda.max_memory_allocated()
+ logger.info(
+ '[Checkpoint] During saving checkpoint, the peak GPU '
+ f'memory is {max_memory/1024**3:.1f}GB.')
+
+ train_cost_time = time.time() - start_train_t
+    logger.info(f'[Train] Cost {train_cost_time:.2f}s')
+ # ------------------------ Training End ---------------------------- #
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ llava_sft(args)
diff --git a/tools/llava_pretrain.json b/tools/llava_pretrain.json
new file mode 100644
index 000000000..727907d74
--- /dev/null
+++ b/tools/llava_pretrain.json
@@ -0,0 +1,7 @@
+{
+ "llava_pretrain": {
+ "image_dir": "/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/LLaVA-Pretrain/images",
+ "annotations": "/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json",
+ "sample_ratio": 1.0
+ }
+}
diff --git a/xtuner/_lite/__init__.py b/xtuner/_lite/__init__.py
new file mode 100644
index 000000000..192fd51f8
--- /dev/null
+++ b/xtuner/_lite/__init__.py
@@ -0,0 +1,65 @@
+import os
+import subprocess
+import sys
+
+from loguru import logger
+
+from .auto import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from .device import get_device, get_torch_device_module
+
+
+_LOGGER = None
+
+def log_format(debug=False):
+
+ formatter = '[XTuner][{time:YYYY-MM-DD HH:mm:ss}][{level}]'
+
+ if debug:
+ formatter += '[{name}:'
+ formatter += '{function}:'
+ formatter += '{line}]'
+
+ formatter += ' {message}'
+ return formatter
+
+
+def get_logger(level="INFO"):
+ global _LOGGER
+ if _LOGGER is None:
+ # Remove the original logger in Python to prevent duplicate printing.
+ logger.remove()
+        logger.add(
+            sys.stderr, level=level,
+            format=log_format(debug=(level == 'DEBUG')))
+ _LOGGER = logger
+ return _LOGGER
+
+
+def get_repo_git_info(repo_path):
+ original_directory = os.getcwd()
+ os.chdir(repo_path)
+
+ try:
+ branch = subprocess.check_output(
+ ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+ stderr=subprocess.STDOUT
+ ).strip().decode('utf-8')
+
+ commit_id = subprocess.check_output(
+ ['git', 'rev-parse', 'HEAD'],
+ stderr=subprocess.STDOUT
+ ).strip().decode('utf-8')
+
+ remote_url = subprocess.check_output(
+ ['git', 'remote', 'get-url', 'origin'],
+ stderr=subprocess.STDOUT
+ ).strip().decode('utf-8')
+
+ return branch, commit_id, remote_url
+    except subprocess.CalledProcessError:
+ return None, None, None
+ finally:
+ os.chdir(original_directory)
+
+
+__all__ = [
+    'AutoConfig', 'AutoModelForCausalLM', 'AutoTokenizer', 'get_device',
+    'get_torch_device_module', 'get_logger'
+]
diff --git a/xtuner/_lite/accelerate/__init__.py b/xtuner/_lite/accelerate/__init__.py
new file mode 100644
index 000000000..15316a252
--- /dev/null
+++ b/xtuner/_lite/accelerate/__init__.py
@@ -0,0 +1,14 @@
+from .dispatches import dispatch_hf_code
+from .generate import contiguous_batching_generate
+from .load import LoadWoInit
+from .lora import LORA_TARGET_MAP
+from .packed import pack_sequence, packed_sequence, unpack_sequence
+from .utils import (liger_kernel_is_available, lmdeploy_is_available,
+                    npu_is_available, profile_time_and_memory,
+                    varlen_attn_is_available)
+
+__all__ = [
+    'dispatch_hf_code', 'contiguous_batching_generate', 'LoadWoInit',
+    'LORA_TARGET_MAP', 'pack_sequence', 'packed_sequence', 'unpack_sequence',
+    'varlen_attn_is_available', 'lmdeploy_is_available', 'npu_is_available',
+    'liger_kernel_is_available', 'profile_time_and_memory'
+]
diff --git a/xtuner/_lite/accelerate/dispatches/__init__.py b/xtuner/_lite/accelerate/dispatches/__init__.py
new file mode 100644
index 000000000..6f470c258
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .huggingface import dispatch_hf_code
+
+__all__ = ['dispatch_hf_code']
diff --git a/xtuner/_lite/accelerate/dispatches/_attention.py b/xtuner/_lite/accelerate/dispatches/_attention.py
new file mode 100644
index 000000000..c723e549d
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/_attention.py
@@ -0,0 +1,179 @@
+import math
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from xtuner._lite import get_device
+from xtuner._lite.parallel import sequence_parallel_wrapper
+
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
+ unpad_input)
+except ImportError:
+ pass
+
+
+def _get_unpad_data(attention_mask):
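+    # Derive varlen metadata from a padding mask: the flattened indices of
+    # non-pad tokens, cumulative sequence lengths and the longest sequence
+    # length, as expected by flash_attn_varlen_func.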
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def upad_qkv(query_layer, key_layer, value_layer, attention_mask,
+ query_length):
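+    # Strip padding from q/k/v so flash_attn_varlen_func only attends over
+    # real tokens; the returned indices let `pad_input` restore the original
+    # (batch, seqlen) layout afterwards.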
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+ attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ if query_length == kv_seq_len:
+ # Different from the origin version as sequence parallel change
+ # the number of attention heads.
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
+ indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \
+ unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+@sequence_parallel_wrapper
+def flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ window_size=(-1, -1), # -1 means infinite context window
+):
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def flash_attn_w_mask(
+ query_states, # bs, q_len, nhead, h_dim
+ key_states,
+ value_states,
+ attention_mask,
+ causal=True,
+ dropout_p=0.0,
+ window_size=(-1, -1), # -1 means infinite context window
+):
+ batch_size, q_len = query_states.shape[:2]
+ query_states, key_states, value_states, indices_q, \
+ cu_seq_lens, max_seq_lens = upad_qkv(
+ query_states, key_states, value_states, attention_mask, q_len)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout_p,
+ causal=causal,
+ window_size=window_size)
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ dropout_p=0.,
+ causal=True,
+ window_size=(-1, -1), # -1 means infinite context window
+):
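+    # All sequences arrive concatenated as (1, total_tokens, num_heads,
+    # head_dim); flatten the batch dim so the varlen kernels receive
+    # (total_tokens, num_heads, head_dim) together with `cumulative_len`.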
+ q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
+ 0, 1), value_states.flatten(0, 1)
+
+ device = get_device()
+ if device == 'cuda':
+ attn_output = flash_attn_varlen_func(
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ cumulative_len,
+ cumulative_len,
+ max_seqlen,
+ max_seqlen,
+ dropout_p=dropout_p,
+ return_attn_probs=False,
+ causal=causal,
+ window_size=window_size)
+ attn_output = attn_output.unsqueeze(0)
+ elif device == 'npu':
+ import torch_npu
+ atten_mask_npu = torch.from_numpy(
+ np.triu(np.ones([max_seqlen, max_seqlen]), k=1)).bool().to(device)
+ head_num = q_unpad.shape[1]
+ attn_output = torch_npu.npu_fusion_attention(
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ head_num,
+ pse=None,
+ padding_mask=None,
+ atten_mask=atten_mask_npu,
+ scale=1.0 / math.sqrt(q_unpad.shape[-1]),
+ keep_prob=1,
+ input_layout='TND',
+ actual_seq_qlen=tuple(cumulative_len[1:].cpu().numpy().tolist()),
+ actual_seq_kvlen=tuple(cumulative_len[1:].cpu().numpy().tolist()),
+ pre_tockens=2147483647,
+ next_tockens=0,
+ inner_precise=0)[0]
+ attn_output = attn_output.unsqueeze(0)
+ return attn_output
diff --git a/xtuner/_lite/accelerate/dispatches/_fused/__init__.py b/xtuner/_lite/accelerate/dispatches/_fused/__init__.py
new file mode 100644
index 000000000..60b38809b
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/_fused/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .layer_norm import layer_norm_forward
+from .rms_norm import rms_norm_forward
+
+# from .rotary import apply_rotary_emb
+
+__all__ = ['rms_norm_forward', 'layer_norm_forward']
diff --git a/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py b/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py
new file mode 100644
index 000000000..010ff07c5
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/_fused/layer_norm.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from torch.distributed._tensor import DTensor
+
+def layer_norm_forward(self, hidden_states):
+
+ if isinstance(self.weight, DTensor):
+ weight = self.weight.full_tensor()
+ else:
+ weight = self.weight
+
+ if isinstance(self.bias, DTensor):
+ bias = self.bias.full_tensor()
+ else:
+ bias = self.bias
+
+    if isinstance(hidden_states, DTensor):
+        hidden_states = hidden_states.full_tensor()
+
+ return F.layer_norm(
+ hidden_states, self.normalized_shape, weight, bias, self.eps
+ )
\ No newline at end of file
diff --git a/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py b/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py
new file mode 100644
index 000000000..0ef772f9e
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/_fused/rms_norm.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from transformers.utils.import_utils import is_flash_attn_2_available
+from torch.distributed._tensor import DTensor
+from xtuner._lite.accelerate import lmdeploy_is_available, npu_is_available
+
+
+def rms_norm_forward(self, hidden_states):
+
+ from torch.distributed._functional_collectives import AsyncCollectiveTensor
+ if isinstance(hidden_states, AsyncCollectiveTensor):
+ hidden_states = hidden_states.wait()
+ if (hidden_states.device == torch.device('cpu')
+ or self.weight.device == torch.device('cpu')):
+ raise RuntimeError(
+ 'Can not use triton kernels on cpu. Please set `USE_TRITON_KERNEL`'
+ ' environment variable to 0 before training.')
+
+ if isinstance(self.weight, DTensor):
+ weight = self.weight.full_tensor()
+ else:
+ weight = self.weight
+
+ if lmdeploy_is_available() and not self.training:
+ from lmdeploy.pytorch.kernels import rms_norm
+ ret = rms_norm(hidden_states, weight, eps=self.variance_epsilon)
+    elif is_flash_attn_2_available():
+        # The module name differs across flash_attn versions
+        # (`layernorm` vs `layer_norm`); try both.
+        try:
+            from flash_attn.ops.triton.layernorm import rms_norm_fn
+        except ImportError:
+            try:
+                from flash_attn.ops.triton.layer_norm import rms_norm_fn
+            except ImportError:
+                import flash_attn
+                raise ImportError(
+                    '`rms_norm_fn` was not found in the installed '
+                    f'flash_attn ({flash_attn.__version__}).')
+        ret = rms_norm_fn(
+            hidden_states, weight, None, eps=self.variance_epsilon)
+    elif npu_is_available():
+        import torch_npu
+        ret = torch_npu.npu_rms_norm(
+            hidden_states, weight, epsilon=self.variance_epsilon)[0]
+    else:
+        # No fused kernel available; fall back to the reference RMSNorm
+        # computation in plain PyTorch.
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(
+            variance + self.variance_epsilon)
+        ret = weight * hidden_states.to(input_dtype)
+    return ret
diff --git a/xtuner/_lite/accelerate/dispatches/_fused/rotary.py b/xtuner/_lite/accelerate/dispatches/_fused/rotary.py
new file mode 100644
index 000000000..1e09c1662
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/_fused/rotary.py
@@ -0,0 +1,327 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py # noqa:E501
+from typing import Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def rotary_kernel(
+ OUT, # Pointers to matrices
+ X,
+ COS,
+ SIN,
+ CU_SEQLENS,
+ SEQLEN_OFFSETS, # this could be int or a pointer
+ # Matrix dimensions
+ seqlen,
+ rotary_dim,
+ seqlen_ro,
+ # strides
+ stride_out_batch,
+ stride_out_seqlen,
+ stride_out_nheads,
+ stride_out_headdim,
+ stride_x_batch,
+ stride_x_seqlen,
+ stride_x_nheads,
+ stride_x_headdim,
+ # Meta-parameters
+ BLOCK_K: tl.constexpr,
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ INTERLEAVED: tl.constexpr,
+ CONJUGATE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+ rotary_dim_half = rotary_dim // 2
+
+ if not IS_VARLEN:
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+ else:
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+ OUT = OUT + start_idx * stride_out_seqlen + \
+ pid_head * stride_out_nheads
+
+ if pid_m * BLOCK_M >= seqlen:
+ return
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ if not IS_SEQLEN_OFFSETS_TENSOR:
+ rm_cs = rm + SEQLEN_OFFSETS
+ else:
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
+ rk = tl.arange(0, BLOCK_K)
+ rk_half = tl.arange(0, BLOCK_K // 2)
+
+ if not INTERLEAVED:
+ # Load the 1st and 2nd halves of X, do calculation,
+ # then store to 1st and 2nd halves of OUT
+ X = X + (
+ rm[:, None] * stride_x_seqlen +
+ rk_half[None, :] * stride_x_headdim)
+ # This is different from the official implementation as the shapes of
+ # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of
+ # (seqlen_ro, rotary_dim // 2).
+ COS = COS + (rm_cs[:, None] * rotary_dim + rk_half[None, :])
+ SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_half[None, :])
+ cos = tl.load(
+ COS,
+ mask=(rm_cs[:, None] < seqlen_ro) &
+ (rk_half[None, :] < rotary_dim_half),
+ other=1.0).to(tl.float32)
+ sin = tl.load(
+ SIN,
+ mask=(rm_cs[:, None] < seqlen_ro) &
+ (rk_half[None, :] < rotary_dim_half),
+ other=0.0).to(tl.float32)
+ x0 = tl.load(
+ X,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ other=0.0).to(tl.float32)
+ x1 = tl.load(
+ X + rotary_dim_half * stride_x_headdim,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ o0 = x0 * cos - x1 * sin
+ o1 = x0 * sin + x1 * cos
+ # write back result
+ OUT = OUT + (
+ rm[:, None] * stride_out_seqlen +
+ rk_half[None, :] * stride_out_headdim)
+ tl.store(
+ OUT,
+ o0,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
+ tl.store(
+ OUT + rotary_dim_half * stride_out_headdim,
+ o1,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ )
+ else:
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately
+ # since both are slow.
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
+ # Loading x0 will be fast but x1 will be slow.
+ # Then we load cos = COS[0, 0, 1, 1, ...] and
+ # sin = SIN[0, 0, 1, 1, ...].
+        # Then we do the calculation and use tl.where to pick out the right
+        # outputs for the even and for the odd indices.
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
+ # This is different from the official implementation as the shapes of
+ # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of
+ # (seqlen_ro, rotary_dim // 2).
+ X0 = X + (
+ rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
+ X1 = X + (
+ rm[:, None] * stride_x_seqlen +
+ rk_swap[None, :] * stride_x_headdim)
+ COS = COS + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :])
+ SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :])
+ cos = tl.load(
+ COS,
+ mask=(rm_cs[:, None] < seqlen_ro) &
+ (rk_repeat[None, :] < rotary_dim_half),
+ other=1.0,
+ ).to(tl.float32)
+ sin = tl.load(
+ SIN,
+ mask=(rm_cs[:, None] < seqlen_ro) &
+ (rk_repeat[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ x0 = tl.load(
+ X0,
+ mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim),
+ other=0.0).to(tl.float32)
+ x1 = tl.load(
+ X1,
+ mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim),
+ other=0.0).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ x0_cos = x0 * cos
+ x1_sin = x1 * sin
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+ OUT = OUT + (
+ rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
+ tl.store(
+ OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ interleaved=False,
+ inplace=False,
+ conjugate=False,
+) -> torch.Tensor:
+ """
+ Arguments:
+ x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim).
+ cos: (seqlen_ro, rotary_dim)
+ sin: (seqlen_ro, rotary_dim)
+ seqlen_offsets: integer or integer tensor of size (batch,)
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Returns:
+ y: (batch, seqlen, nheads, headdim)
+ """
+ is_varlen = cu_seqlens is not None
+ if not is_varlen:
+ batch, seqlen, nheads, headdim = x.shape
+ else:
+ assert max_seqlen is not None, ('If cu_seqlens is passed in, '
+ 'then max_seqlen must be passed')
+ total_seqlen, nheads, headdim = x.shape
+ batch_p_1 = cu_seqlens.shape[0]
+ batch = batch_p_1 - 1
+ seqlen = max_seqlen
+ seqlen_ro, rotary_dim = cos.shape
+ assert sin.shape == cos.shape
+ # rotary_dim *= 2
+ assert rotary_dim <= headdim, 'rotary_dim must be <= headdim'
+ assert headdim <= 256, 'Only support headdim <= 256'
+ assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen'
+
+ assert (
+ cos.dtype == sin.dtype
+ ), f'cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}'
+ assert (x.dtype == cos.dtype), (
+ f'Input and cos/sin must have the same dtype, '
+ f'got {x.dtype} and {cos.dtype}')
+
+ cos, sin = cos.contiguous(), sin.contiguous()
+ if isinstance(seqlen_offsets, torch.Tensor):
+ assert seqlen_offsets.shape == (batch, )
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+ seqlen_offsets = seqlen_offsets.contiguous()
+ else:
+ assert seqlen_offsets + seqlen <= seqlen_ro
+
+ output = torch.empty_like(x) if not inplace else x
+ if rotary_dim < headdim and not inplace:
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+ BLOCK_K = (32 if rotary_dim <= 32 else
+ (64 if rotary_dim <= 64 else
+ (128 if rotary_dim <= 128 else 256)))
+
+ def grid(META):
+ return (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads)
+
+ BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton
+ # (cpu tensor?)
+ with torch.cuda.device(x.device.index):
+ rotary_kernel[grid](
+ output, # data ptrs
+ x,
+ cos,
+ sin,
+ cu_seqlens,
+ seqlen_offsets,
+ seqlen, # shapes
+ rotary_dim,
+ seqlen_ro,
+ output.stride(0)
+ if not is_varlen else 0, # batch_strides if not varlen else 0
+ output.stride(-3), # seqlen_stride or total_seqlen_stride
+ output.stride(-2), # nheads_stride
+ output.stride(-1), # headdim_stride
+ x.stride(0)
+ if not is_varlen else 0, # batch_strides if not varlen else 0
+ x.stride(-3), # seqlen stride or total_seqlen_stride
+ x.stride(-2), # nheads stride
+ x.stride(-1), # headdim stride
+ BLOCK_K,
+ isinstance(seqlen_offsets, torch.Tensor),
+ is_varlen,
+ interleaved,
+ conjugate,
+ BLOCK_M,
+ )
+ return output
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
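+    """Autograd wrapper around the Triton rotary kernel; the backward pass
+    re-applies the kernel with ``conjugate=True``, i.e. rotates by the
+    negative angle."""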
+
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ cos,
+ sin,
+ interleaved=False,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ out = apply_rotary(
+ x,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ interleaved=interleaved,
+ inplace=inplace,
+ )
+ if isinstance(seqlen_offsets, int):
+ ctx.save_for_backward(
+ cos, sin, cu_seqlens) # Can't save int with save_for_backward
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.interleaved = interleaved
+ ctx.inplace = inplace
+ ctx.max_seqlen = max_seqlen
+ return out if not inplace else x
+
+ @staticmethod
+ def backward(ctx, do):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why.
+ # Triton 2.1.0 works.
+ if not ctx.interleaved and not ctx.inplace:
+ do = do.clone()
+ dx = apply_rotary(
+ do,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ interleaved=ctx.interleaved,
+ inplace=ctx.inplace,
+ conjugate=True,
+ )
+ return dx, None, None, None, None, None, None, None
+
+
+apply_rotary_emb = ApplyRotaryEmb.apply
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py b/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py
new file mode 100644
index 000000000..4fb4b5c36
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/__init__.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import types
+
+from xtuner._lite import get_logger
+
+logger = get_logger()
+
+
+def _dispatch_forward_fn(module, dispatch_fn):
+ module.forward = types.MethodType(dispatch_fn, module)
+
+
+def _dispatch_qwen2_attn_flash_forward(module):
+ assert module.__class__.__name__ in ['Qwen2FlashAttention2', 'Qwen2Attention', 'Qwen2SdpaAttention']
+ from .qwen2 import qwen2_attn_flash_forward
+ from xtuner._lite.accelerate import varlen_attn_is_available
+    if varlen_attn_is_available():
+        _dispatch_forward_fn(module, qwen2_attn_flash_forward)
+        return qwen2_attn_flash_forward.__name__
+
+def _dispatch_qwen2_casual_forward(module):
+ assert module.__class__.__name__ in ['Qwen2ForCausalLM']
+ from .qwen2 import qwen2_casual_forward
+ _dispatch_forward_fn(module, qwen2_casual_forward)
+ return qwen2_casual_forward.__name__
+
+
+def _dispatch_internlm2_varlen_attn_forward(module):
+ assert module.__class__.__name__ in ['InternLM2FlashAttention2', 'InternLM2Attention', 'InternLM2SdpaAttention']
+ from .internlm2 import internlm2_varlen_attn_forward
+ from xtuner._lite.accelerate import varlen_attn_is_available
+    if varlen_attn_is_available():
+        _dispatch_forward_fn(module, internlm2_varlen_attn_forward)
+        return internlm2_varlen_attn_forward.__name__
+
+def _dispatch_internlm2_casual_forward(module):
+ assert module.__class__.__name__ in ['InternLM2ForCausalLM']
+ from .internlm2 import internlm2_causal_forward
+ _dispatch_forward_fn(module, internlm2_causal_forward)
+ return internlm2_causal_forward.__name__
+
+
+def _dispatch_internlm3_varlen_self_attn_forward(module):
+ assert module.__class__.__name__ in ['InternLM3FlashSelfAttention2']
+ from .internlm3 import internlm3_self_attn_forward
+ from xtuner._lite.accelerate import varlen_attn_is_available
+    if varlen_attn_is_available():
+        _dispatch_forward_fn(module, internlm3_self_attn_forward)
+        return internlm3_self_attn_forward.__name__
+
+def _dispatch_internlm3_varlen_cross_attn_forward(module):
+ assert module.__class__.__name__ in ['InternLM3FlashCrossAttention2']
+ from .internlm3 import internlm3_cross_attn_forward
+ from xtuner._lite.accelerate import varlen_attn_is_available
+    if varlen_attn_is_available():
+        _dispatch_forward_fn(module, internlm3_cross_attn_forward)
+        return internlm3_cross_attn_forward.__name__
+
+def _dispatch_internlm3_cross_decoder_forward(module):
+ assert module.__class__.__name__ == 'InternLM3CrossDecoder'
+ from .internlm3 import internlm3_cross_decoder_forward
+ _dispatch_forward_fn(module, internlm3_cross_decoder_forward)
+ return internlm3_cross_decoder_forward.__name__
+
+
+def _dispatch_internlm2_reward_forward(module):
+ assert module.__class__.__name__ == 'InternLM2ForRewardModel'
+ from .internlm2 import internlm2_reward_forward
+ _dispatch_forward_fn(module, internlm2_reward_forward)
+ return internlm2_reward_forward.__name__
+
+
+# HACK
+def _dispatch_qwen2_reward_forward(module):
+ assert module.__class__.__name__ == 'Qwen2ForRewardModel'
+ from .internlm2 import internlm2_reward_forward
+ _dispatch_forward_fn(module, internlm2_reward_forward)
+ return internlm2_reward_forward.__name__
+
+
+def _dispatch_clip_attn_forward(module):
+ assert module.__class__.__name__ == 'CLIPAttention'
+ from .clip import clip_flash_attn_forward
+ _dispatch_forward_fn(module, clip_flash_attn_forward)
+ return clip_flash_attn_forward.__name__
+
+
+def _dispatch_rms_norm_forward(module):
+ from .._fused import rms_norm_forward
+ _dispatch_forward_fn(module, rms_norm_forward)
+ return rms_norm_forward.__name__
+
+
+def _dispatch_internvl2_forward(module):
+ assert module.__class__.__name__ == 'InternVLChatModel'
+ from .internvl2 import internvl2_forward
+ _dispatch_forward_fn(module, internvl2_forward)
+ return internvl2_forward.__name__
+
+
+def _dispatch_llama_varlen_attn_forward(module):
+ assert module.__class__.__name__ == 'LlamaFlashAttention2'
+ from .llama import llama_flash_attn_forward
+ _dispatch_forward_fn(module, llama_flash_attn_forward)
+ return llama_flash_attn_forward.__name__
+
+
+def _dispatch_llama_casual_forward(module):
+ assert module.__class__.__name__ in ['LlamaForCausalLM']
+ from .llama import llama_casual_forward
+ _dispatch_forward_fn(module, llama_casual_forward)
+ return llama_casual_forward.__name__
+
+
+def _dispatch_minicpmv_forward(module):
+ assert module.__class__.__name__ == 'MiniCPMV'
+ from .minicpmv import minicpmv_forward
+ _dispatch_forward_fn(module, minicpmv_forward)
+ return minicpmv_forward.__name__
+
+
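+# Maps HF module class names to the dispatcher that monkey-patches their
+# `forward`. A dispatcher returns the name of the patched-in forward, or
+# None when the patch is skipped (e.g. varlen attention is unavailable).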
+DISPATCH_MAP = {
+ 'Qwen2RMSNorm': _dispatch_rms_norm_forward,
+ 'Qwen2FlashAttention2': _dispatch_qwen2_attn_flash_forward,
+ 'Qwen2Attention': _dispatch_qwen2_attn_flash_forward,
+ 'Qwen2SdpaAttention': _dispatch_qwen2_attn_flash_forward,
+ 'Qwen2ForCausalLM': _dispatch_qwen2_casual_forward,
+ 'InternLM2Attention': _dispatch_internlm2_varlen_attn_forward,
+ 'InternLM2SdpaAttention': _dispatch_internlm2_varlen_attn_forward,
+ 'InternLM2FlashAttention2': _dispatch_internlm2_varlen_attn_forward,
+ 'InternLM2ForCausalLM': _dispatch_internlm2_casual_forward,
+ 'CLIPAttention': _dispatch_clip_attn_forward,
+ 'InternLM2ForRewardModel': _dispatch_internlm2_reward_forward,
+ 'Qwen2ForRewardModel': _dispatch_qwen2_reward_forward,
+ 'InternLM2RMSNorm': _dispatch_rms_norm_forward,
+ 'InternVLChatModel': _dispatch_internvl2_forward, # to support sp and liger
+ 'LlamaFlashAttention2': _dispatch_llama_varlen_attn_forward,
+ 'LlamaForCausalLM': _dispatch_llama_casual_forward,
+ 'LlamaRMSNorm': _dispatch_rms_norm_forward,
+ 'MiniCPMV': _dispatch_minicpmv_forward, # to support sp and liger
+}
+
+
+def dispatch_hf_code(model):
+
+ for name, module in model.named_modules():
+ cls_name = module.__class__.__name__
+ if cls_name in DISPATCH_MAP:
+ dispatched = DISPATCH_MAP[cls_name](module)
+ if dispatched is not None:
+ logger.debug(
+ f'Dispatch {name}({cls_name}) forward to `{dispatched}`')
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/clip.py b/xtuner/_lite/accelerate/dispatches/huggingface/clip.py
new file mode 100644
index 000000000..ed99bf771
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/clip.py
@@ -0,0 +1,98 @@
+from typing import Optional, Tuple
+
+import torch
+
+from .._attention import flash_attn_wo_mask
+
+
+def clip_flash_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel."""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(bsz, tgt_len,
+ self.num_heads, -1)
+ key_states = self.k_proj(hidden_states).view(bsz, tgt_len, self.num_heads,
+ -1)
+ value_states = self.v_proj(hidden_states).view(bsz, tgt_len,
+ self.num_heads, -1)
+
+    # The original eager implementation (bmm + softmax with explicit causal
+    # and padding masks) has been replaced by the single FlashAttention call
+    # below.
+
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ self.dropout if self.training else 0,
+ causal=causal_attention_mask is not None).view(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py b/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py
new file mode 100644
index 000000000..fa81f0936
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/internlm2.py
@@ -0,0 +1,677 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+import inspect
+
+import torch
+from einops import rearrange
+from mmengine import MessageHub
+from transformers.cache_utils import StaticCache, Cache
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast
+
+from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available
+from .._attention import flash_attn_wo_mask, varlen_flash_attn
+
+
+class InternLM2RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self,
+ dim,
+ max_position_embeddings=2048,
+ base=1000000,
+ device=None):
+ super().__init__()
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.inv_freq = 1.0 / (
+ base**(torch.arange(0, dim, 2).float().to(device) / dim))
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(
+ self.max_seq_len_cached,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = emb.cos()
+ self.sin_cached = emb.sin()
+
+ def forward(self, x, seq_len):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if (seq_len > self.max_seq_len_cached
+ or self.cos_cached.device != x.device
+ or self.cos_cached.dtype != x.dtype):
+ self.max_seq_len_cached = seq_len
+ assert self.inv_freq.dtype == torch.float32
+ t = torch.arange(
+ self.max_seq_len_cached,
+ device=x.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.cos_cached = emb.cos().to(x.dtype)
+ self.sin_cached = emb.sin().to(x.dtype)
+ return (
+ self.cos_cached[:seq_len, ...],
+ self.sin_cached[:seq_len, ...],
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_old(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors."""
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """This is the equivalent of torch.repeat_interleave(x, dim=1,
+ repeats=n_rep).
+
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
+ (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :,
+ None, :, :].expand(batch,
+ num_key_value_heads,
+ n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
+ head_dim)
+
+
+def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
+ to (batch, seqlen, num_attention_heads, head_dim)"""
+ batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, :,
+ None, :].expand(batch, slen,
+ num_key_value_heads, n_rep,
+ head_dim)
+ return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
+ head_dim)
+
+
+def _internlm2_varlen_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ '`static` cache implementation is not compatible with '
+ '`attn_implementation==flash_attention_2` make sure to use `sdpa` '
+ 'in the mean time, and open an issue at '
+ 'https://github.com/huggingface/transformers')
+
+ bsz, q_len, _ = hidden_states.size()
+ attn_context = MessageHub.get_instance('packed_sequence')
+
+ position_ids = attn_context.get_info('position_ids')
+ sp_mesh = attn_context.get_info('sp_mesh')
+ assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}'
+
+ qkv_states = self.wqkv(hidden_states)
+
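+    # InternLM2 fuses q/k/v into a single `wqkv` projection: each KV head
+    # serves `num_key_value_groups` query heads plus one key and one value
+    # head, hence the `2 + num_key_value_groups` groups below.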
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ signature = inspect.signature(self.rotary_emb.forward)
+ if 'seq_len' in signature.parameters:
+ # old
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
+ query_states, key_states = apply_rotary_pos_emb_old(query_states, key_states, cos, sin, position_ids)
+ else:
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+ cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models;
+ # cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons therefore the input hidden states gets silently
+ # casted in float32. Hence, we need cast them back in the correct dtype
+ # just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not
+ # cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.wqkv.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
+ cumulative_lengths = attn_context.get_info('cumulative_lengths')
+ if cumulative_lengths is not None and bsz == 1:
+ max_seqlen = attn_context.get_info('max_seqlen')
+        attn_output = varlen_flash_attn(
+            query_states, key_states, value_states, cumulative_lengths,
+            max_seqlen, training=self.training, sp_mesh=sp_mesh)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=self.training,
+ sp_mesh=sp_mesh)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ # Due to the implementation of the PyTorch version of flash attention,
+ # even when the output_attentions flag is set to True, it is not possible
+ # to return the attn_weights.
+ return attn_output, None, past_key_value
+
+
+def _contiguous_batching_forward_impl(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ """Rewrite implementation of LlamaAttention.forward.
+
+ Add continuous batching support. Add paged attention support. TP support.
+ """
+ from lmdeploy.pytorch.kernels import \
+ apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy
+ from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd
+ attn_ctx = MessageHub.get_instance('paged_attention')
+ kv_seq_length = attn_ctx.get_info('kv_seq_length')
+ q_seq_length = attn_ctx.get_info('q_seq_length')
+ q_start_loc = attn_ctx.get_info('q_start_loc')
+ block_offsets = attn_ctx.get_info('block_offsets')
+ max_q_seq_length = attn_ctx.get_info('max_q_seq_length')
+ max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length')
+ position_ids = attn_ctx.get_info('position_ids')
+
+ # position_ids
+ def __qkv_proj(hidden_states):
+ """qkv_proj."""
+ qkv_states = self.wqkv(hidden_states[0]).unsqueeze(0)
+
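+        # The fused wqkv projection packs, for each KV head,
+        # `num_key_value_groups` query heads followed by one key head and one
+        # value head; the rearrange below exposes that grouping.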
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> (b q) h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = query_states.flatten(1, 2)
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+ return query_states, key_states, value_states
+
+ def __rotary_emb_fn(query_states, key_states, value_states):
+ """rotary embedding func."""
+        # The rotary cos/sin are shared by every layer; compute them once on
+        # layer 0 and reuse the cached values for the remaining layers.
+        if self.layer_idx == 0:
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ attn_ctx.update_info('rotary_cos_sin', (cos, sin))
+ else:
+ cos, sin = attn_ctx.get_info('rotary_cos_sin')
+
+ query_states, key_states = apply_rotary_pos_emb_lmdeploy(
+ query_states,
+ key_states,
+ cos,
+ sin,
+ q_embed=query_states,
+ k_embed=key_states)
+
+ return query_states, key_states, value_states
+
+ query_states, key_states, value_states = __qkv_proj(hidden_states)
+
+ query_states, key_states, value_states = __rotary_emb_fn(
+ query_states, key_states, value_states)
+
+ fill_kv_cache(
+ key_states,
+ value_states,
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ q_start_loc,
+ q_seq_length,
+ kv_seq_length=kv_seq_length,
+ max_q_seq_length=max_q_seq_length,
+ block_offsets=block_offsets,
+ )
+
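+    # paged_attention_fwd writes its result into the tensor passed as the
+    # output argument, so the query buffer is reused as the output here.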
+ attn_output = query_states
+ paged_attention_fwd(
+ query_states,
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ attn_output,
+ block_offsets,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seq_length,
+ kv_seqlens=kv_seq_length,
+ max_seqlen=max_q_seq_length,
+ # max_kv_seq_length=max_kv_seq_length,
+ )
+ attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
+
+ attn_output = self.wo(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+def _flash_att_infer(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ """Rewrite implementation of LlamaAttention.forward.
+
+ Add continuous batching support. Add paged attention support. TP support.
+ """
+ from lmdeploy.pytorch.kernels import \
+ apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy
+ from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd
+ attn_ctx = MessageHub.get_instance('paged_attention')
+ kv_seq_length = attn_ctx.get_info('kv_seq_length')
+ q_seq_length = attn_ctx.get_info('q_seq_length')
+ q_start_loc = attn_ctx.get_info('q_start_loc')
+ block_offsets = attn_ctx.get_info('block_offsets')
+ max_q_seq_length = attn_ctx.get_info('max_q_seq_length')
+ max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length')
+ cumulative_length = attn_ctx.get_info('cumulative_length')
+ is_prefilling = attn_ctx.get_info('is_prefilling')
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> (b q) h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = query_states.flatten(1, 2)
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+    # cos/sin are shared across layers; compute them on layer 0 and cache.
+    if self.layer_idx == 0:
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ attn_ctx.update_info('rotary_cos_sin', (cos, sin))
+ else:
+ cos, sin = attn_ctx.get_info('rotary_cos_sin')
+
+ query_states, key_states = apply_rotary_pos_emb_lmdeploy(
+ query_states,
+ key_states,
+ cos,
+ sin,
+ q_embed=query_states,
+ k_embed=key_states)
+
+ fill_kv_cache(
+ key_states,
+ value_states,
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ q_start_loc,
+ q_seq_length,
+ kv_seq_length=kv_seq_length,
+ max_q_seq_length=max_q_seq_length,
+ block_offsets=block_offsets,
+ )
+
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+ if is_prefilling:
+ key_states = repeat_kv_bshd(
+ key_states.unsqueeze(0), self.num_key_value_groups).squeeze(0)
+ value_states = repeat_kv_bshd(
+ value_states.unsqueeze(0), self.num_key_value_groups).squeeze(0)
+ attn_output = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_length,
+ cumulative_length,
+ max_q_seq_length,
+ max_kv_seq_length,
+ causal=True)
+ else:
+ query_states = query_states.unsqueeze(1)
+ attn_output = flash_attn_with_kvcache(
+ query_states,
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ cache_seqlens=kv_seq_length,
+ block_table=block_offsets,
+ causal=True)
+ attn_output = attn_output.squeeze(1)
+ attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
+
+ attn_output = self.wo(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+def internlm2_varlen_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+
+ lmdeploy_ctx = MessageHub.get_instance('paged_attention')
+
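+    # A populated 'paged_attention' context means we are generating with the
+    # lmdeploy KV-cache engine; otherwise fall back to the packed-sequence
+    # training forward.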
+ if lmdeploy_is_available() and len(lmdeploy_ctx.runtime_info) > 0:
+
+ # return _contiguous_batching_forward_impl(
+ # self, hidden_states, position_ids, past_key_value)
+ return _flash_att_infer(self, hidden_states, position_ids,
+ past_key_value)
+ else:
+ return _internlm2_varlen_attn_forward(self, hidden_states,
+ attention_mask, position_ids,
+ past_key_value,
+ output_attentions, use_cache)
+
+
+def internlm2_reward_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ """labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels
+ for computing the sequence classification/regression loss.
+
+ Returns:
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ reward_scores = self.v_head(hidden_states).squeeze(-1)
+
+    loss = None
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=reward_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def internlm2_causal_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ label_shifted: bool = False,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ Returns:
+ Example:
+ ```python
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
+ >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+
+ loss = None
+ if labels is None:
+ logits = self.output(hidden_states)
+ else:
+
+ if liger_kernel_is_available():
+ # unable to return logits when using Liger Kernel
+ logits = None
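+            # The fused linear + cross-entropy kernel computes the loss
+            # directly from the hidden states and the output head weights,
+            # so the full logits tensor is never materialized.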
+
+ if label_shifted:
+ shift_hidden_states = hidden_states
+ shift_labels = labels
+ else:
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_hidden_states.device)
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+ loss_fct = LigerFusedLinearCrossEntropyLoss()
+ loss = loss_fct(self.output.weight, shift_hidden_states, shift_labels, self.output.bias)
+
+ else:
+ logits = self.output(hidden_states)
+
+ if label_shifted:
+ shift_logits = logits
+ shift_labels = labels
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_logits.device)
+
+ loss_fct = torch.nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py b/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py
new file mode 100644
index 000000000..eefc77991
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/internvl2.py
@@ -0,0 +1,259 @@
+from typing import List, Optional, Tuple, Union
+
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+import torch.distributed as dist
+from torch.distributed.nn.functional import all_gather
+from mmengine.logging import MessageHub
+import copy
+from xtuner._lite.parallel.setup import get_sp_mesh
+import math
+import os
+from xtuner._lite.parallel.sequence import split_for_sequence_parallel
+
+try:
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+except ImportError:
+ LigerFusedLinearCrossEntropyLoss = None
+
+
+def rescale_sp_loss(loss_per_sp_rank,
+ labels_per_sp_rank,
+ sp_mesh = None,
+ ignore_index=-100):
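+    """Rescale a sequence-parallel rank's loss by its share of non-ignored
+    tokens, so that averaging the rescaled losses across SP ranks reproduces
+    the token-weighted loss of the full (unsplit) sequence."""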
+ if sp_mesh is None:
+ sp_group = get_sp_mesh().get_group()
+ else:
+ sp_group = sp_mesh.get_group()
+
+ if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
+ return loss_per_sp_rank
+
+ shift_labels = labels_per_sp_rank
+ active_tokens = (shift_labels != ignore_index).long().sum()
+ global_active_tokens = copy.deepcopy(active_tokens)
+ dist.all_reduce(global_active_tokens, group=sp_group)
+ loss_weight = active_tokens / global_active_tokens * dist.get_world_size(
+ group=sp_group)
+
+ if active_tokens == 0:
+ # convert nan to 0 just for logging
+ loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)
+
+ return loss_per_sp_rank * loss_weight
+
+
+def internvl2_forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ image_flags: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ sp_mesh = get_sp_mesh()
+ sp_size = sp_mesh.size()
+ if sp_size > 1:
+ sp_group = sp_mesh.get_group()
+ sp_rank = dist.get_rank(sp_group)
+
+ no_split_input_ids = os.environ.get('NO_SPLIT_INPUT_IDS')
+ split_input_ids = not no_split_input_ids
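+        # Each SP rank embeds only its own shard of the (padded) input ids;
+        # the shards are then all-gathered so every rank holds the embeddings
+        # of the full sequence before the image features are merged in.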
+ if split_input_ids:
+ pad_id = 0
+ orig_len_input_ids = input_ids.shape[1]
+ image_flags = image_flags.squeeze(-1)
+ assert input_ids.shape[0] == 1, 'batch size must be 1 for sequence parallel'
+
+ if orig_len_input_ids % sp_size != 0:
+ max_inputs_len = math.ceil(orig_len_input_ids / sp_size) * sp_size
+ _temp = input_ids.new_full((1, max_inputs_len - orig_len_input_ids), pad_id)
+ input_ids_new = torch.cat([input_ids, _temp], dim=-1)
+ else:
+ input_ids_new = input_ids
+ input_ids_list = torch.split(input_ids_new, input_ids_new.shape[1] // sp_size, dim=-1)
+ input_ids_rank_pre = input_ids_list[sp_rank].contiguous()
+ input_embeds_rank_pre = self.language_model.get_input_embeddings()(input_ids_rank_pre).clone()
+
+ input_embeds = all_gather(input_embeds_rank_pre, group=sp_group)
+
+ input_embeds = torch.cat(input_embeds, dim=1)
+ input_embeds = input_embeds[:, :orig_len_input_ids]
+ else:
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
+
+ no_split_pixel_values = os.environ.get('NO_SPLIT_PIXEL_VALUES')
+ split_pixel_values = not no_split_pixel_values
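+        # Likewise, the image batch can be sharded across SP ranks so the ViT
+        # forward runs in parallel, with the features all-gathered afterwards.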
+ if split_pixel_values:
+ orig_img_batch = pixel_values.shape[0]
+ if orig_img_batch % sp_size != 0:
+ max_inputs_len = math.ceil(orig_img_batch / sp_size) * sp_size
+ pad_img_batch = max_inputs_len - orig_img_batch
+ pad_pixel_values_ = pixel_values.new_zeros(pad_img_batch, 3,
+ pixel_values.shape[2],
+ pixel_values.shape[3])
+ pixel_values = torch.cat([pixel_values, pad_pixel_values_], dim=0)
+ pixel_values = torch.split(pixel_values, len(pixel_values) // sp_size, dim=0)
+ pixel_values = pixel_values[sp_rank].contiguous()
+
+ vit_embeds = self.extract_feature(pixel_values)
+
+ vit_embeds = all_gather(vit_embeds, group=sp_group)
+
+ vit_embeds = torch.cat(vit_embeds, dim=0)[:orig_img_batch]
+ else:
+ vit_embeds = self.extract_feature(pixel_values)
+ vit_embeds = vit_embeds[image_flags == 1]
+ else:
+ image_flags = image_flags.squeeze(-1)
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
+
+ vit_embeds = self.extract_feature(pixel_values)
+ vit_embeds = vit_embeds[image_flags == 1]
+
+ B, N, C = input_embeds.shape
+ input_embeds = input_embeds.reshape(B * N, C)
+
+ input_ids = input_ids.reshape(B * N)
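+    # Replace the image-context placeholder tokens (`img_context_token_id`)
+    # in the text embeddings with the projected ViT features.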
+ selected = (input_ids == self.img_context_token_id)
+ try:
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
+ except Exception as e:
+ vit_embeds = vit_embeds.reshape(-1, C)
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
+ f'vit_embeds.shape={vit_embeds.shape}')
+ n_token = selected.sum()
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
+
+ input_embeds = input_embeds.reshape(B, N, C)
+
+ if sp_size > 1:
+ attn_context = MessageHub.get_instance('packed_sequence')
+ position_ids = attn_context.get_info('position_ids')
+
+ is_ref_forward = attn_context.get_info('is_ref_forward')
+
+ # TODO: phi3 attention
+ attn_context.update_info('global_position_ids', position_ids)
+ attention_mask = None
+
+ if is_ref_forward is not None:
+ input_embeds = split_for_sequence_parallel(
+ input_embeds, dim=1, sp_mesh=sp_mesh)
+ if labels is not None:
+ labels = split_for_sequence_parallel(
+ labels, dim=1, sp_mesh=sp_mesh)
+ else:
+ if labels is not None:
+ assert position_ids.size(1) == input_embeds.shape[1] == labels.shape[1], \
+ f'{position_ids.size(1)} {input_embeds.shape[1]} {labels.shape[1]}'
+ else:
+ assert position_ids.size(1) == input_embeds.shape[1], \
+ f'{position_ids.size(1)} {input_embeds.shape[1]}'
+
+ assert position_ids.size(1) % sp_size == 0
+ # `dim` is 1 as the shape of tensor is (bs, seq_len)
+ position_ids = split_for_sequence_parallel(
+ position_ids, dim=1, sp_mesh=sp_mesh)
+ input_embeds = split_for_sequence_parallel(
+ input_embeds, dim=1, sp_mesh=sp_mesh)
+ if labels is not None:
+ labels = split_for_sequence_parallel(
+ labels, dim=1, sp_mesh=sp_mesh)
+
+ attn_context.update_info('position_ids', position_ids)
+
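+    # When the USE_LIGER_KERNEL env var is set during training, compute the
+    # loss with the fused linear cross-entropy kernel on the decoder hidden
+    # states instead of materializing the full logits.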
+ use_liger_kernel = os.environ.get('USE_LIGER_KERNEL')
+ if use_liger_kernel and labels is not None and self.training:
+ output_attentions = output_attentions if output_attentions is not None else self.language_model.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.language_model.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.language_model.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.language_model.model(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # Flatten tokens
+ shift_hidden_states = shift_hidden_states.view(-1, self.language_model.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+
+ if LigerFusedLinearCrossEntropyLoss is None:
+ raise ImportError('LigerFusedLinearCrossEntropyLoss is not available, '
+ 'please install liger-kernel by "pip install liger_kernel".')
+ lce = LigerFusedLinearCrossEntropyLoss()
+ if hasattr(self.language_model, 'lm_head'):
+ loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ else:
+ loss = lce(self.language_model.output.weight, shift_hidden_states, shift_labels)
+ if sp_size > 1:
+ loss = rescale_sp_loss(loss, shift_labels, sp_mesh=sp_mesh)
+ logits = None
+ else:
+ outputs = self.language_model(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if sp_size > 1:
+ loss = rescale_sp_loss(loss, shift_labels, sp_mesh=sp_mesh)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/llama.py b/xtuner/_lite/accelerate/dispatches/huggingface/llama.py
new file mode 100644
index 000000000..0630f6f58
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/llama.py
@@ -0,0 +1,257 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from mmengine import MessageHub
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available
+
+from .._attention import flash_attn_wo_mask, varlen_flash_attn
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def llama_flash_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs # FlashAttentionKwargs in the future
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+           Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ attn_context = MessageHub.get_instance('packed_sequence')
+ sp_mesh = attn_context.get_info('sp_mesh')
+
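+    # In packed-sequence training, the per-sample position ids come from the
+    # MessageHub and override the ones passed in by the model wrapper.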
+ position_ids = attn_context.get_info('position_ids')
+ assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}'
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+    # In PEFT, the layer norms are usually cast to float32 for training
+    # stability, so the input hidden states may be silently cast to float32.
+    # Cast them back to the correct dtype to be sure everything works as
+    # expected. This might slow down training and inference, so it is
+    # recommended not to cast the LayerNorms to fp32.
+    # (LlamaRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ cumulative_lengths = attn_context.get_info('cumulative_lengths')
+ if cumulative_lengths is not None and bsz == 1:
+ max_seqlen = attn_context.get_info('max_seqlen')
+ attn_output = varlen_flash_attn(query_states, key_states, value_states,
+ cumulative_lengths, max_seqlen, dropout_p=dropout_rate,
+ training=self.training, sp_mesh=sp_mesh)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ dropout_p=dropout_rate,
+ training=self.training)
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+def llama_casual_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ label_shifted = False,
+ **kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ if labels is None:
+ loss = None
+ logits = self.lm_head(hidden_states)
+ else:
+ if liger_kernel_is_available():
+ # unable to return logits when using Liger Kernel
+ logits = None
+
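+            # `label_shifted=True` means the caller already aligned labels
+            # with next-token targets, so no additional shift is applied here.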
+ if label_shifted:
+ shift_hidden_states = hidden_states
+ shift_labels = labels
+ else:
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_hidden_states.device)
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+ loss_fct = LigerFusedLinearCrossEntropyLoss()
+ loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels, self.lm_head.bias)
+
+ else:
+ logits = self.lm_head(hidden_states)
+
+ if label_shifted:
+ shift_logits = logits
+ shift_labels = labels
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_logits.device)
+
+ loss_fct = torch.nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py b/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py
new file mode 100644
index 000000000..5e326ea5f
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/minicpmv.py
@@ -0,0 +1,56 @@
+import os
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+try:
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+except ImportError:
+ LigerFusedLinearCrossEntropyLoss = None
+
+
+def minicpmv_forward(self, data, **kwargs):
+ vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+ use_liger_kernel = os.environ.get('USE_LIGER_KERNEL')
+ labels = data.get('labels')
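+    # When the USE_LIGER_KERNEL env var is set during training, run only the
+    # decoder and compute the loss with the fused linear cross-entropy kernel;
+    # otherwise delegate to the LLM's own forward (which returns logits).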
+ if use_liger_kernel and labels is not None and self.training:
+ output_attentions = self.config.output_attentions
+ output_hidden_states = self.config.output_hidden_states
+ return_dict = self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.llm.model(
+ inputs_embeds=vllm_embedding,
+ use_cache=False,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten tokens
+ shift_hidden_states = shift_hidden_states.view(-1, self.llm.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+
+ if LigerFusedLinearCrossEntropyLoss is None:
+ raise ImportError('LigerFusedLinearCrossEntropyLoss is not available, '
+ 'please install liger-kernel by "pip install liger_kernel".')
+ lce = LigerFusedLinearCrossEntropyLoss()
+ loss = lce(self.llm.lm_head.weight, shift_hidden_states, shift_labels)
+ if not return_dict:
+ output = (None,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ return self.llm(
+ input_ids=None,
+ inputs_embeds=vllm_embedding,
+ labels=labels,
+ **kwargs
+ )
diff --git a/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py b/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py
new file mode 100644
index 000000000..1b85e37f1
--- /dev/null
+++ b/xtuner/_lite/accelerate/dispatches/huggingface/qwen2.py
@@ -0,0 +1,455 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from mmengine import MessageHub
+from transformers.cache_utils import Cache
+from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb,
+ repeat_kv)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from xtuner._lite.accelerate import lmdeploy_is_available, liger_kernel_is_available
+from .._attention import flash_attn_wo_mask, varlen_flash_attn
+
+SUPPORT_FLASH2 = False
+
+try:
+ from flash_attn import flash_attn_func
+ _flash_supports_window_size = 'window_size' in list(
+ inspect.signature(flash_attn_func).parameters)
+ SUPPORT_FLASH2 = True
+except ImportError:
+ pass
+
+
+def _qwen2_attn_varlen_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+):
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37'
+ ' Please make sure use `attention_mask` instead.`')
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
+ bsz, q_len, _ = hidden_states.size()
+
+ attn_context = MessageHub.get_instance('packed_sequence')
+ position_ids = attn_context.get_info('position_ids')
+ assert position_ids.size(1) == q_len, f'{position_ids.size(1)} {q_len}'
+ sp_mesh = attn_context.get_info('sp_mesh')
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+ self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+ self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ 'The cache structure has changed since version v4.36. '
+ f'If you are using {self.__class__.__name__} '
+ 'for auto-regressive decoding with k/v caching, '
+ 'please make sure to initialize the attention class '
+ 'with a layer index.')
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
+ self.layer_idx)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+ cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value
+ # `sliding_windows` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ 'past key must have a shape of (`batch_size, num_heads, '
+ 'self.config.sliding_window-1, head_dim`), got'
+ f' {past_key.shape}')
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat(
+ [attention_mask,
+ torch.ones_like(attention_mask[:, -1:])],
+ dim=-1)
+
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+    # In PEFT, the layer norms are usually cast to float32 for training
+    # stability, so the input hidden states may be silently cast to float32.
+    # Cast them back to the correct dtype to be sure everything works as
+    # expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+    # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # ----------------- flash attention forward ------------------------#
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ causal = self.is_causal and q_len != 1
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and self.layer_idx < self.config.max_window_layers
+ and self.config.use_sliding_window)
+
+    window_size = ((self.config.sliding_window, self.config.sliding_window)
+                   if use_sliding_windows else (-1, -1))
+
+ assert SUPPORT_FLASH2
+ cumulative_lengths = attn_context.get_info('cumulative_lengths')
+ if cumulative_lengths is not None and SUPPORT_FLASH2 and bsz == 1:
+ max_seqlen = attn_context.get_info('max_seqlen')
+ attn_output = varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_lengths,
+ max_seqlen,
+ causal=causal,
+ dropout_p=dropout_rate,
+ window_size=window_size,
+ training=self.training,
+ sp_mesh=sp_mesh)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=causal,
+ dropout_p=dropout_rate,
+ window_size=window_size,
+ training=self.training,
+ sp_mesh=sp_mesh)
+
+ # ---------------- flash attention forward end ------------------- #
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def _qwen2_attn_contiguous_batching_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+):
+ from lmdeploy.pytorch.kernels import \
+ apply_rotary_pos_emb as apply_rotary_pos_emb_lmdeploy
+ from lmdeploy.pytorch.kernels import fill_kv_cache, paged_attention_fwd
+ attn_ctx = MessageHub.get_instance('paged_attention')
+ kv_seq_length = attn_ctx.get_info('kv_seq_length')
+ q_seq_length = attn_ctx.get_info('q_seq_length')
+ q_start_loc = attn_ctx.get_info('q_start_loc')
+ block_offsets = attn_ctx.get_info('block_offsets')
+ max_q_seq_length = attn_ctx.get_info('max_q_seq_length')
+ max_kv_seq_length = attn_ctx.get_info('max_kv_seq_length')
+ cumulative_length = attn_ctx.get_info('cumulative_length')
+ is_prefilling = attn_ctx.get_info('is_prefilling')
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
+ self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
+ self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+ cos, sin, position_ids)
+
+ fill_kv_cache(
+ key_states.transpose(1, 2),
+ value_states.transpose(1, 2),
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ q_start_loc,
+ q_seq_length,
+ kv_seq_length=kv_seq_length,
+ max_q_seq_length=max_q_seq_length,
+ block_offsets=block_offsets,
+ )
+
+ # ----------------- flash attention forward ------------------------#
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ causal = self.is_causal and q_len != 1
+
+ # TODO support sliding window attention
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+ if is_prefilling:
+
+ key_states = repeat_kv(
+ key_states, self.num_key_value_groups)
+ value_states = repeat_kv(
+ value_states, self.num_key_value_groups)
+
+ attn_output = flash_attn_varlen_func(
+            query_states.transpose(1, 2).squeeze(0),
+            key_states.transpose(1, 2).squeeze(0),
+            value_states.transpose(1, 2).squeeze(0),
+ cumulative_length,
+ cumulative_length,
+ max_q_seq_length,
+ max_kv_seq_length,
+ causal=True)
+ else:
+        query_states = query_states.transpose(1, 2).transpose(0, 1)
+
+ attn_output = flash_attn_with_kvcache(
+ query_states,
+ past_key_value[self.layer_idx][0],
+ past_key_value[self.layer_idx][1],
+ cache_seqlens=kv_seq_length,
+ block_table=block_offsets,
+ causal=True)
+ attn_output = attn_output.squeeze(1)
+
+ # ---------------- flash attention forward end ------------------- #
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def qwen2_attn_flash_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+):
+
+ lmdeploy_ctx = MessageHub.get_instance('paged_attention')
+
+ if lmdeploy_is_available() and len(lmdeploy_ctx.runtime_info) > 0:
+
+        return _qwen2_attn_contiguous_batching_forward(
+            self, hidden_states, attention_mask, position_ids,
+            past_key_value, output_attentions, use_cache)
+ else:
+ return _qwen2_attn_varlen_forward(self, hidden_states,
+ attention_mask, position_ids,
+ past_key_value,
+ output_attentions, use_cache)
+
+
+def qwen2_casual_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ label_shifted = False,
+ **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+
+ if labels is None:
+ loss = None
+ logits = self.lm_head(hidden_states)
+ else:
+ if liger_kernel_is_available():
+ # unable to return logits when using Liger Kernel
+ logits = None
+
+ if label_shifted:
+ shift_hidden_states = hidden_states
+ shift_labels = labels
+ else:
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_hidden_states.device)
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+ loss_fct = LigerFusedLinearCrossEntropyLoss()
+ loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels, self.lm_head.bias)
+
+ else:
+ logits = self.lm_head(hidden_states)
+
+ if label_shifted:
+ shift_logits = logits
+ shift_labels = labels
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ shift_labels = shift_labels.to(shift_logits.device)
+
+ loss_fct = torch.nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/xtuner/_lite/accelerate/generate.py b/xtuner/_lite/accelerate/generate.py
new file mode 100644
index 000000000..d75dffeb8
--- /dev/null
+++ b/xtuner/_lite/accelerate/generate.py
@@ -0,0 +1,204 @@
+import torch
+from mmengine import MessageHub
+
+from xtuner._lite import get_logger
+from .packed import pack_sequence, packed_cumulative_length
+
+logger = get_logger()
+
+
+@torch.no_grad()
+def sample(logits, do_sample=True, top_k=0, top_p=0.9, temperature=1.0):
+
+ if not do_sample:
+ return logits.argmax(-1)
+
+ # Apply temperature if necessary
+ if temperature != 1.0:
+ logits = logits / temperature
+
+ # Apply top-k if necessary
+ if top_k > 0:
+ top_k = min(top_k, logits.size(-1))
+        _, topk_indices = logits.topk(top_k, dim=-1)
+ mask = torch.ones_like(logits, dtype=torch.bool)
+ mask.scatter_(-1, topk_indices, False)
+ logits.masked_fill_(mask, -torch.inf)
+
+ # Apply top-p (nucleus sampling) if necessary
+ if top_p < 1.0:
+
+ sorted_logits, sorted_indices = torch.sort(logits, dim=-1)
+ cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        mask = cum_probs <= (1 - top_p)
+        mask[:, -1] = False
+        sorted_logits.masked_fill_(mask, -torch.inf)
+
+        logits.scatter_(-1, sorted_indices, sorted_logits)
+
+ probs = logits.softmax(-1)
+
+ return torch.multinomial(probs, 1).squeeze(-1)
+
+
+@torch.no_grad()
+def contiguous_batching_generate(model,
+ input_ids,
+ stop_token_ids=[],
+ max_batch_size=64,
+ max_new_tokens=128,
+ max_length=2048,
+ do_sample=False,
+ top_k=0,
+ top_p=1.0,
+ temperature=1.0,
+ tp_size=1,
+ device='cuda'):
+
+ model.config.use_cache = True
+
+ from lmdeploy.pytorch.config import CacheConfig, ModelConfig
+ from lmdeploy.pytorch.engine.cache_engine import CacheEngine
+ from lmdeploy.logger import get_logger
+ get_logger('lmdeploy').setLevel('ERROR')
+
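+    # Allocate enough paged KV-cache blocks for `max_batch_size` concurrent
+    # sessions of up to `max_length` tokens each.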
+ block_size = 256
+ max_batch_size = min(max_batch_size, len(input_ids))
+ num_blocks = max_length // block_size * max_batch_size
+ cache_config = CacheConfig(max_batch_size, block_size, num_blocks,
+ num_blocks)
+
+
+ if model.config.architectures[0] == 'InternLM3ForCausalLM':
+ model.config.num_hidden_layers = model.config.num_self_decoder_layers + 1
+
+ model_config = ModelConfig.from_hf_config(model.config)
+ cache_engine = CacheEngine(cache_config, model_config, world_size=tp_size)
+
+ block_table = torch.arange(num_blocks).reshape(max_batch_size, -1)
+
+ _packed_ids, _num_tokens = pack_sequence(input_ids[:max_batch_size])
+ _position_ids = [
+ torch.arange(seq.numel()) for seq in input_ids[:max_batch_size]
+ ]
+ _packed_pos_ids = torch.cat(_position_ids, dim=0).unsqueeze(0)
+ _cumulative_length = packed_cumulative_length(_num_tokens)
+
+ next_input_ids = _packed_ids.to(device)
+ next_position_ids = _packed_pos_ids.to(device)
+ next_start_pos = _cumulative_length[:-1].to(device)
+ next_end_pos = (_cumulative_length[1:] - 1).to(device)
+ next_query_length = _num_tokens.to(device)
+ next_cache_length = _num_tokens.to(device)
+ next_block_table = block_table.to(device).to(torch.int32)
+ next_cumulative_length = _cumulative_length.to(device)
+ next_is_prefilling = True
+
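+    # The first step prefills the prompts of the initial batch; later steps
+    # decode one token per active session, admitting waiting sessions as
+    # finished ones release their cache blocks.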
+ num_sessions = len(input_ids)
+ computing = [i for i in range(max_batch_size)]
+ waiting = [i for i in range(max_batch_size, num_sessions)]
+
+ responses = [[] for _ in range(num_sessions)]
+
+ while len(waiting) or len(computing):
+
+ attn_ctx = MessageHub.get_instance('paged_attention')
+ attn_ctx.update_info('block_offsets', next_block_table)
+ attn_ctx.update_info('kv_seq_length', next_cache_length)
+ attn_ctx.update_info('q_seq_length', next_query_length)
+ attn_ctx.update_info('position_ids', next_position_ids)
+ attn_ctx.update_info('max_kv_seq_length', next_cache_length.max())
+ attn_ctx.update_info('max_q_seq_length', next_query_length.max())
+ attn_ctx.update_info('q_start_loc', next_start_pos)
+ attn_ctx.update_info('cumulative_length', next_cumulative_length)
+ attn_ctx.update_info('is_prefilling', next_is_prefilling)
+
+ outputs = model(
+ input_ids=next_input_ids,
+ position_ids=next_position_ids,
+ past_key_values=cache_engine.gpu_cache,
+ cache_position=next_position_ids,
+ )
+
+ for key in list(attn_ctx.runtime_info.keys()):
+ attn_ctx.pop_info(key)
+
+ # TODO (pppppM) support sampling
+        sampled = sample(
+            outputs.logits[0, next_end_pos],
+            do_sample=do_sample,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature)
+
+ _next_input_ids = []
+ _next_position_ids = []
+ _next_computing = []
+ _next_query_length = []
+ _next_cache_length = []
+ _next_block_table = []
+
+ for i, sess_id in enumerate(computing):
+ token_id = sampled[i]
+ responses[sess_id].append(token_id.item())
+
+ _sess_new_tokens = len(responses[sess_id])
+ _sess_len = _sess_new_tokens + input_ids[sess_id].numel()
+
+ stop = (
+ _sess_new_tokens >= max_new_tokens or _sess_len >= max_length
+ or token_id in stop_token_ids)
+
+ if stop:
+ # session ended
+ if len(waiting):
+ # next step is prefilling
+ new_sess_id = waiting.pop(0)
+ new_sess = input_ids[new_sess_id].to(device)
+
+ _new_sess_len = new_sess.size(-1)
+ # new session override the cache of the stopped session
+ _next_block_table.append(next_block_table[i])
+ _next_computing.append(new_sess_id)
+ _next_input_ids.append(new_sess)
+ _next_position_ids.append(torch.arange(_new_sess_len))
+ _next_query_length.append(_new_sess_len)
+ _next_cache_length.append(_new_sess_len)
+ else:
+ # next step is decoding
+ _next_computing.append(sess_id)
+ _next_block_table.append(next_block_table[i])
+ _next_input_ids.append(token_id.reshape(1, -1))
+ _next_position_ids.append(
+ torch.arange(_sess_len - 1, _sess_len))
+ _next_query_length.append(1)
+ _next_cache_length.append(_sess_len)
+
+ computing = _next_computing
+ if len(computing) == 0:
+ # All sessions have ended.
+ assert len(waiting) == 0
+ break
+
+ _packed_ids, _num_tokens = pack_sequence(_next_input_ids)
+ _cumulative_length = packed_cumulative_length(_num_tokens)
+
+ next_input_ids = _packed_ids.to(device)
+ next_position_ids = torch.cat(_next_position_ids, dim=0).unsqueeze(0)
+ next_position_ids = next_position_ids.to(device)
+ next_start_pos = _cumulative_length[:-1].to(device)
+ next_end_pos = (_cumulative_length[1:] - 1).to(device)
+ next_query_length = torch.IntTensor(_next_query_length).to(device)
+ next_cache_length = torch.IntTensor(_next_cache_length).to(device)
+ next_block_table = torch.stack(_next_block_table).to(device)
+
+ next_cumulative_length = _cumulative_length.to(device)
+ next_is_prefilling = False
+
+ for i in range(len(cache_engine.gpu_cache)):
+ cache_engine.gpu_cache.pop()
+
+ del cache_engine
+ torch.cuda.empty_cache()
+
+ model.config.use_cache = False
+
+ return responses
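+
+
+# Illustrative-only sketch of how the batched generation helper above might be
+# driven; `generate` is a placeholder for the helper's actual name, which is
+# defined earlier in this file. Prompts are 1 x seq_len LongTensors, and
+# decoding stops on a stop token, `max_new_tokens`, or `max_length`, whichever
+# comes first.
+#
+#   prompts = [tokenizer.encode(p, return_tensors='pt') for p in prompt_texts]
+#   replies = generate(model, prompts, stop_token_ids=[tokenizer.eos_token_id])
+#   texts = [tokenizer.decode(ids) for ids in replies]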
diff --git a/xtuner/_lite/accelerate/load.py b/xtuner/_lite/accelerate/load.py
new file mode 100644
index 000000000..bf8a35513
--- /dev/null
+++ b/xtuner/_lite/accelerate/load.py
@@ -0,0 +1,32 @@
+import torch
+
+
+class LoadWoInit:
+ """Context manager that disable parameter initialization."""
+
+ def __init__(self):
+ self.constant_ = torch.nn.init.constant_
+ self.zeros_ = torch.nn.init.zeros_
+ self.ones_ = torch.nn.init.ones_
+ self.uniform_ = torch.nn.init.uniform_
+ self.normal_ = torch.nn.init.normal_
+ self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+ self.kaiming_normal_ = torch.nn.init.kaiming_normal_
+
+ def __enter__(self, *args, **kwargs):
+ torch.nn.init.constant_ = lambda *args, **kwargs: None
+ torch.nn.init.zeros_ = lambda *args, **kwargs: None
+ torch.nn.init.ones_ = lambda *args, **kwargs: None
+ torch.nn.init.uniform_ = lambda *args, **kwargs: None
+ torch.nn.init.normal_ = lambda *args, **kwargs: None
+ torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
+ torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
+
+ def __exit__(self, *args, **kwargs):
+ torch.nn.init.constant_ = self.constant_
+ torch.nn.init.zeros_ = self.zeros_
+ torch.nn.init.ones_ = self.ones_
+ torch.nn.init.uniform_ = self.uniform_
+ torch.nn.init.normal_ = self.normal_
+ torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
+ torch.nn.init.kaiming_normal_ = self.kaiming_normal_
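+
+
+# Usage sketch (illustrative): skip random weight initialization when the
+# parameters are immediately overwritten by a pretrained checkpoint, which
+# noticeably speeds up model construction.
+#
+#   with LoadWoInit():
+#       model = AutoModelForCausalLM.from_pretrained(model_path)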
diff --git a/xtuner/_lite/accelerate/lora.py b/xtuner/_lite/accelerate/lora.py
new file mode 100644
index 000000000..ad3c9a3f5
--- /dev/null
+++ b/xtuner/_lite/accelerate/lora.py
@@ -0,0 +1,5 @@
+LORA_TARGET_MAP = {
+ 'InternLM2ForCausalLM': ['wqkv', 'wo', 'w1', 'w2', 'w3'],
+ 'CLIPVisionModel':
+ ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2']
+}
diff --git a/xtuner/_lite/accelerate/packed.py b/xtuner/_lite/accelerate/packed.py
new file mode 100644
index 000000000..773e5822a
--- /dev/null
+++ b/xtuner/_lite/accelerate/packed.py
@@ -0,0 +1,72 @@
+from contextlib import contextmanager
+from typing import List, Union
+
+import torch
+
+from xtuner._lite import get_device
+from xtuner._lite.parallel import get_sp_mesh, split_for_sequence_parallel
+
+
+def unpack_sequence(packed: torch.Tensor,
+ num_tokens: Union[torch.Tensor, List],
+ dim=1):
+
+ if isinstance(num_tokens, torch.Tensor):
+ num_tokens = num_tokens.tolist()
+ sequences = torch.split(packed, num_tokens, dim=dim)
+ return sequences
+
+
+def pack_sequence(sequences, dim=1):
+ num_tokens = torch.IntTensor([seq.size(dim) for seq in sequences])
+ packed = torch.cat(sequences, dim=dim)
+ return packed, num_tokens.to(packed.device)
+
+
+def packed_cumulative_length(num_tokens: torch.Tensor):
+
+ device = num_tokens.device
+ _zero_pad = torch.zeros(1, device=device)
+ _pad_length = torch.cat([_zero_pad, num_tokens]).int()
+ return torch.cumsum(_pad_length, 0).int()
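+
+# Worked example (dim=1): packing two sequences of 3 and 5 tokens gives a
+# (1, 8) tensor, num_tokens == [3, 5] and cumulative lengths [0, 3, 8].
+#
+#   seqs = [torch.ones(1, 3, dtype=torch.long), torch.ones(1, 5, dtype=torch.long)]
+#   packed, num_tokens = pack_sequence(seqs)    # packed.shape == (1, 8)
+#   packed_cumulative_length(num_tokens)        # tensor([0, 3, 8], dtype=torch.int32)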
+
+
+@contextmanager
+def packed_sequence(num_tokens, enable=True, sp_mesh=None):
+ from mmengine import MessageHub
+ ctx = MessageHub.get_instance('packed_sequence')
+
+ device = get_device()
+ if enable:
+ num_tokens = num_tokens.to(device)
+ device = num_tokens.device
+ _zero_length = torch.zeros(1, device=device)
+ _pad_length = torch.cat([_zero_length, num_tokens]).int()
+ cumulative_lengths = torch.cumsum(_pad_length, 0).int()
+ position_ids = [torch.arange(num.item()) for num in num_tokens]
+ position_ids = torch.cat(position_ids, dim=0).to(device)
+ position_ids = position_ids.unsqueeze(0)
+ if sp_mesh:
+ # `dim` is 1 as the shape of tensor is (bs, seq_len)
+ position_ids = split_for_sequence_parallel(
+ position_ids, dim=1, sp_mesh=sp_mesh)
+
+ # ctx.update_info('num_tokens', num_tokens)
+ ctx.update_info('position_ids', position_ids)
+ ctx.update_info('cumulative_lengths', cumulative_lengths)
+ ctx.update_info('max_seqlen', num_tokens.max())
+ ctx.update_info('sp_mesh', sp_mesh)
+
+ else:
+ # ctx.update_info('num_tokens', None)
+ ctx.update_info('position_ids', None)
+ ctx.update_info('cumulative_lengths', None)
+ ctx.update_info('max_seqlen', None)
+ ctx.update_info('sp_mesh', None)
+ yield
+
+ # ctx.update_info('num_tokens', None)
+ ctx.update_info('position_ids', None)
+ ctx.update_info('cumulative_lengths', None)
+ ctx.update_info('max_seqlen', None)
+ ctx.update_info('sp_mesh', None)
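+
+
+# Usage sketch (caller-side names are placeholders): wrap a forward pass on a
+# packed batch so that dispatched varlen-attention modules can read the
+# per-sequence metadata from the 'packed_sequence' MessageHub instance.
+#
+#   packed_ids, num_tokens = pack_sequence(batch_input_ids)
+#   with packed_sequence(num_tokens, enable=True):
+#       loss = model(input_ids=packed_ids, labels=packed_labels).loss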
diff --git a/xtuner/_lite/accelerate/utils.py b/xtuner/_lite/accelerate/utils.py
new file mode 100644
index 000000000..86629fef5
--- /dev/null
+++ b/xtuner/_lite/accelerate/utils.py
@@ -0,0 +1,56 @@
+import time
+from contextlib import contextmanager
+from transformers.utils.import_utils import is_flash_attn_2_available
+from xtuner._lite import get_device, get_logger, get_torch_device_module
+
+logger = get_logger()
+
+
+def npu_is_available():
+ return get_device() == 'npu'
+
+
+def varlen_attn_is_available():
+
+ return is_flash_attn_2_available() or npu_is_available()
+
+
+def lmdeploy_is_available():
+
+ available = False
+ try:
+ import lmdeploy # noqa: F401
+ available = True
+ except ImportError:
+ available = False
+
+ return available
+
+
+def liger_kernel_is_available():
+
+ available = False
+ try:
+ import liger_kernel # noqa: F401
+ available = True
+ except ImportError:
+ available = False
+
+ return available
+
+
+@contextmanager
+def profile_time_and_memory(desc):
+
+ torch_device = get_torch_device_module()
+ start_t = time.time()
+ torch_device.reset_peak_memory_stats()
+
+ yield
+
+ max_memory = torch_device.max_memory_allocated()
+ cost_time = time.time() - start_t
+
+ logger.success(f'{desc} Elapsed time {cost_time:.2f} seconds, '
+ f'peak gpu memory {max_memory/1024**3:.1f}G')
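+
+
+# Usage sketch:
+#   with profile_time_and_memory('[Load Model]'):
+#       model = build_model(...)  # any expensive block
+# Logs the elapsed wall-clock time and the peak device memory allocated inside
+# the block via logger.success().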
diff --git a/xtuner/_lite/algorithms/__init__.py b/xtuner/_lite/algorithms/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/xtuner/_lite/algorithms/ppo/__init__.py b/xtuner/_lite/algorithms/ppo/__init__.py
new file mode 100644
index 000000000..2a76959bf
--- /dev/null
+++ b/xtuner/_lite/algorithms/ppo/__init__.py
@@ -0,0 +1,10 @@
+from .dataset import (InferDataset, PPOTokenizeFunction, RewardBuffer,
+                      RewardBufferCollator)
+from .loss import (CriticLoss, PPOPolicyLoss, compute_advantages_and_returns,
+                   compute_kl_rewards, gather_logprobs)
+from .model import build_actor_model, build_reward_model
+
+__all__ = [
+    'InferDataset', 'PPOTokenizeFunction', 'RewardBuffer',
+    'RewardBufferCollator', 'CriticLoss', 'PPOPolicyLoss',
+    'compute_advantages_and_returns', 'compute_kl_rewards', 'gather_logprobs',
+    'build_actor_model', 'build_reward_model'
+]
diff --git a/xtuner/_lite/algorithms/ppo/dataset.py b/xtuner/_lite/algorithms/ppo/dataset.py
new file mode 100644
index 000000000..4a40d75b4
--- /dev/null
+++ b/xtuner/_lite/algorithms/ppo/dataset.py
@@ -0,0 +1,170 @@
+import torch
+import json
+import numpy as np
+from xtuner._lite.chat.messages.chat import ChatMsg
+from xtuner._lite.datasets import OPENAI_CONVERT_MAP
+from torch import nn
+from ..sft import SftCollator, SftTokenizeFunction
+
+
+class InferDataset(torch.utils.data.Dataset):
+
+ def __init__(self, prompts, responses):
+ super().__init__()
+
+ assert len(prompts) == len(responses)
+ self.prompts = prompts
+ self.responses = responses
+ self.policies = None
+
+ def __len__(self):
+ return len(self.prompts)
+
+
+ def __getitem__(self, item):
+
+ prompt = self.prompts[item]
+ response = self.responses[item]
+ num_prefill_tokens = len(prompt)
+
+ input_ids = prompt + response
+ labels = [-100] * (num_prefill_tokens - 1) + response + [-100]
+
+ return {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'num_tokens': len(input_ids)
+ }
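+
+    # Example (illustrative): with prompt = [1, 2, 3] and response = [4, 5],
+    # input_ids = [1, 2, 3, 4, 5] and labels = [-100, -100, 4, 5, -100];
+    # labels are pre-shifted so that position t is supervised by token t + 1
+    # and only response tokens contribute to the loss.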
+
+
+
+
+FASTER = False
+
+
+class RewardBuffer(torch.utils.data.Dataset):
+
+    def __init__(self,
+                 clip_min=-5,
+                 clip_max=5,
+                 normalize=True,
+                 faster=False):
+ super().__init__()
+
+
+ self.clip_min = clip_min
+ self.clip_max = clip_max
+
+ self.normalize = normalize
+
+ if self.normalize:
+ self.bn = nn.BatchNorm1d(1, momentum=None, affine=False)
+ else:
+ self.bn = None
+
+ self._num_action_tokens = 0
+ self._num_total_tokens = 0
+ self._trajectories = []
+
+ self._current_mean = 0
+
+ @property
+ def running_mean(self):
+ return self.bn.running_mean.item()
+
+ @property
+ def current_mean(self):
+ return self._current_mean
+
+ @property
+ def num_action_tokens(self):
+ return self._num_action_tokens.item()
+
+ @property
+ def num_total_tokens(self):
+ return self._num_total_tokens
+
+ def update(self, trajectories):
+
+ rewards = [data['reward'] for data in trajectories]
+
+ for i in range(len(trajectories)):
+ trajectories[i]['ori_reward'] = trajectories[i]['reward']
+
+ rewards = torch.tensor(rewards)
+
+ self._current_mean = rewards.mean().item()
+
+ rewards = rewards.clip(self.clip_min, self.clip_max)
+
+ if self.normalize:
+ self.bn.train()
+ _ = self.bn(rewards.unsqueeze(-1))
+ self.bn.eval()
+ rewards = self.bn(rewards.unsqueeze(-1))
+
+ for i in range(len(trajectories)):
+ trajectories[i]['reward'] = rewards[i].item()
+
+ num_total_tokens = 0
+ num_action_tokens = 0
+ for data in trajectories:
+ labels = np.array(data['labels'])
+ num_total_tokens += labels.size
+ num_action_tokens += (labels >= 0).sum()
+
+ self._num_action_tokens = num_action_tokens
+ self._num_total_tokens = num_total_tokens
+
+ self._trajectories = trajectories
+
+ def dump_jsonl(self, path, tokenizer, debug=False):
+
+ with open(path, 'w', encoding='utf8') as f:
+ for data in self._trajectories:
+ json_line = {
+ 'num_tokens': data['num_tokens'],
+ 'reward': data['ori_reward'],
+ 'sequence': tokenizer.decode(data['input_ids']),
+ }
+
+ if debug:
+ json_line['input_ids'] = data['input_ids']
+ json_line['labels'] = data['labels']
+
+ json_str = json.dumps(json_line, ensure_ascii=False)
+ f.write(json_str + '\n')
+
+ def __len__(self):
+ return len(self._trajectories)
+
+
+ def __getitem__(self, item):
+
+ return self._trajectories[item]
+
+
+class PPOTokenizeFunction(SftTokenizeFunction):
+
+ def __init__(self,
+ tokenizer,
+ chat_template,
+ raw_format='openai',
+ sys_prompt=None):
+ super().__init__(tokenizer, chat_template, raw_format)
+ self.sys_prompt = sys_prompt
+
+ def __call__(self, item):
+
+ formatter = OPENAI_CONVERT_MAP[self.raw_format]
+ msg = formatter(item)
+ if self.sys_prompt is not None:
+ sys_msg = ChatMsg(role='system', content=self.sys_prompt)
+ msg.messages = [sys_msg] + msg.messages
+ tokenized = msg.tokenize(self.tokenizer, self.chat_template)
+
+ return tokenized
+
+
+class RewardBufferCollator(SftCollator):
+
+ def __call__(self, instances):
+
+ data = super().__call__(instances)
+ data['rewards'] = [item['reward'] for item in instances]
+
+ return data
diff --git a/xtuner/_lite/algorithms/ppo/loss.py b/xtuner/_lite/algorithms/ppo/loss.py
new file mode 100644
index 000000000..167379719
--- /dev/null
+++ b/xtuner/_lite/algorithms/ppo/loss.py
@@ -0,0 +1,125 @@
+import torch
+from torch.nn import functional as F
+
+from xtuner._lite import get_logger
+
+logger = get_logger()
+
+
+
+def gather_logprobs(logits, labels):
+ log_probs = F.log_softmax(logits, dim=-1)
+ log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+ return log_probs_labels.squeeze(-1)
+
+@torch.no_grad()
+def compute_kl_rewards(logprobs, ref_logprobs, reward_score, kl_coef=0.01):
+
+ assert logprobs.ndim == 1
+ last_mask = torch.zeros_like(logprobs, dtype=torch.int)
+ last_mask[-1] = 1
+
+ kl = (ref_logprobs - logprobs)
+ kl_reward = kl_coef * kl * (1 - last_mask)
+
+ last_reward = reward_score * last_mask
+
+ rewards = kl_reward + last_reward
+
+ return rewards
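+
+# Example (kl_coef=0.01): for a 3-token response, the per-token rewards are
+#   [0.01 * (ref_lp[0] - lp[0]), 0.01 * (ref_lp[1] - lp[1]), reward_score]
+# i.e. a KL penalty on every token except the last one, which carries the
+# scalar reward-model score instead.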
+
+@torch.no_grad()
+def compute_advantages_and_returns(values, rewards, gamma=1.0, gae_lambda=0.99):
+ # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501
+ """Function that computes advantages and returns from rewards and
+ values. Calculated as in the original PPO paper:
+ https://arxiv.org/abs/1707.06347 Note that rewards may include a KL
+ divergence loss term.
+
+ Advantages looks like this:
+ Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
+ - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
+
+ Returns looks like this:
+ Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
+ + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
+ """
+ lastgaelam = 0
+ advantages_reversed = []
+
+ assert values.numel() == rewards.numel(), f'{values.numel()}, {rewards.numel()}'
+ length = rewards.numel()
+
+ for t in reversed(range(0, length)):
+ nextvalues = values[t + 1] if t < length - 1 else 0.0
+ # Since old_rewards and old_values are masked with action_mask,
+ # i.e. they have 0's at pad tokens,
+ # delta will be 0 if current t is at a pad token,
+ # so will lastgaelam
+ delta = rewards[t] + gamma * nextvalues - values[t]
+ lastgaelam = delta + gamma * gae_lambda * lastgaelam
+ advantages_reversed.append(lastgaelam)
+
+ advantages = torch.stack(advantages_reversed[::-1], dim=0)
+ returns = advantages + values
+ return advantages.detach(), returns
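+
+# Worked example (gamma=1.0, gae_lambda=0.99):
+#   values  = [1.0, 2.0], rewards = [0.0, 3.0]
+#   t=1: delta = 3.0 - 2.0 = 1.0        -> adv[1] = 1.0
+#   t=0: delta = 0.0 + 2.0 - 1.0 = 1.0  -> adv[0] = 1.0 + 0.99 * 1.0 = 1.99
+#   returns = advantages + values = [2.99, 3.0]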
+
+
+class CriticLoss(torch.nn.Module):
+ """Loss function for critic model."""
+
+ def __init__(self,
+ cliprange_value: float = 0.5,
+ loss_type: str = 'per_seq'):
+ super().__init__()
+ self.cliprange_value = cliprange_value
+ self.loss_type = loss_type
+
+ assert self.loss_type in ['per_token', 'per_seq']
+
+ def critic_loss_fn(self, values, old_values, returns, loss_factor=None):
+ values_clipped = old_values + (values - old_values).clamp(
+ -self.cliprange_value, self.cliprange_value)
+ vf_loss1 = (values_clipped - returns)**2
+ vf_loss2 = (values - returns)**2
+ if self.loss_type == 'per_seq':
+ vf_loss = torch.max(vf_loss1, vf_loss2).mean(-1)
+ elif self.loss_type == 'per_token':
+ assert loss_factor is not None
+ vf_loss = torch.sum(torch.max(vf_loss1, vf_loss2) * loss_factor)
+ return 0.5 * vf_loss
+
+ def forward(self,
+ values: torch.Tensor,
+ old_values,
+ returns,
+ loss_factor=None):
+
+ loss = self.critic_loss_fn(
+ values=values,
+ old_values=old_values,
+ returns=returns,
+ loss_factor=loss_factor)
+ return loss
+
+
+class PPOPolicyLoss(torch.nn.Module):
+ """Loss function for policy model."""
+
+ def __init__(self, cliprange: float = 0.2, loss_type: str = 'per_seq'):
+ super().__init__()
+ self.cliprange = cliprange
+ self.loss_type = loss_type
+ assert self.loss_type in ['per_token', 'per_seq']
+
+ def forward(self, logprobs, old_logprobs, advantages, loss_factor=None):
+ ratio = (logprobs - old_logprobs).exp()
+ pg_loss1 = -ratio * advantages
+ pg_loss2 = -ratio.clamp(1 - self.cliprange,
+ 1 + self.cliprange) * advantages
+ if self.loss_type == 'per_seq':
+ pg_loss = torch.max(pg_loss1, pg_loss2).mean(dim=-1)
+ elif self.loss_type == 'per_token':
+ assert loss_factor is not None
+ pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2)) * loss_factor
+ return pg_loss
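+
+
+# Usage sketch (shapes are illustrative): for a packed batch where `logprobs`,
+# `old_logprobs` and `advantages` all share the same shape,
+#
+#   policy_loss = PPOPolicyLoss(cliprange=0.2)(logprobs, old_logprobs, advantages)
+#   critic_loss = CriticLoss(cliprange_value=0.5)(values, old_values, returns)
+#
+# both default to the 'per_seq' reduction, averaging over the last dimension.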
diff --git a/xtuner/_lite/algorithms/ppo/model.py b/xtuner/_lite/algorithms/ppo/model.py
new file mode 100644
index 000000000..2f90e810f
--- /dev/null
+++ b/xtuner/_lite/algorithms/ppo/model.py
@@ -0,0 +1,46 @@
+import torch
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.utils.import_utils import (is_flash_attn_2_available,
+ is_torch_sdpa_available)
+
+from xtuner._lite.accelerate import LoadWoInit
+
+
+def build_actor_model(model_path, dtype=torch.float32, trust_remote_code=True):
+
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if is_flash_attn_2_available():
+        config.attn_implementation = 'flash_attention_2'
+    elif is_torch_sdpa_available():
+        config.attn_implementation = 'sdpa'
+    else:
+        config.attn_implementation = 'eager'
+
+    with LoadWoInit():
+        policy = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            attn_implementation=config.attn_implementation,
+            torch_dtype=dtype,
+            trust_remote_code=trust_remote_code)
+
+ return policy
+
+
+def build_reward_model(model_path, dtype=torch.float32, trust_remote_code=True):
+
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if is_flash_attn_2_available():
+        config.attn_implementation = 'flash_attention_2'
+    elif is_torch_sdpa_available():
+        config.attn_implementation = 'sdpa'
+    else:
+        config.attn_implementation = 'eager'
+
+    config.use_cache = False
+    config.torch_dtype = dtype
+    with LoadWoInit():
+        reward = AutoModel.from_pretrained(
+            model_path,
+            attn_implementation=config.attn_implementation,
+            torch_dtype=dtype,
+            trust_remote_code=trust_remote_code)
+
+ reward.model.use_cache = False
+
+ return reward
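+
+
+# Usage sketch (the model ids below are placeholders, not requirements):
+#   actor = build_actor_model('internlm/internlm2-chat-1_8b',
+#                             dtype=torch.bfloat16)
+#   reward = build_reward_model('internlm/internlm2-1_8b-reward',
+#                               dtype=torch.bfloat16)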
diff --git a/xtuner/_lite/algorithms/sft/__init__.py b/xtuner/_lite/algorithms/sft/__init__.py
new file mode 100644
index 000000000..01a3a63a2
--- /dev/null
+++ b/xtuner/_lite/algorithms/sft/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import SftCollator, SftTokenizeFunction
+
+__all__ = ['SftCollator', 'SftTokenizeFunction']
diff --git a/xtuner/_lite/algorithms/sft/dataset.py b/xtuner/_lite/algorithms/sft/dataset.py
new file mode 100644
index 000000000..bbb9e9608
--- /dev/null
+++ b/xtuner/_lite/algorithms/sft/dataset.py
@@ -0,0 +1,109 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from xtuner._lite import get_logger
+from xtuner._lite.datasets import OPENAI_CONVERT_MAP
+
+logger = get_logger()
+
+
+class SftTokenizeFunction():
+
+ def __init__(self, tokenizer, chat_template, raw_format='openai'):
+
+ self.tokenizer = tokenizer
+ self.chat_template = chat_template
+ self.raw_format = raw_format
+
+ def __call__(self, item):
+
+ formatter = OPENAI_CONVERT_MAP[self.raw_format]
+ msg = formatter(item)
+ tokenized = msg.tokenize(self.tokenizer, self.chat_template)
+ return tokenized
+
+
+class SftCollator():
+
+    def __init__(self,
+                 pad_token_id=0,
+                 ignore_id=-100,
+                 pack_batch=False,
+                 max_length=None):
+ self.pack_batch = pack_batch
+ self.pad_token_id = pad_token_id
+ self.ignore_id = ignore_id
+ self.max_length = max_length
+
+ def __call__(self, instances):
+
+ _instances = []
+ for ins in instances:
+ if isinstance(ins, list):
+ _instances.extend(ins)
+ else:
+ _instances.append(ins)
+
+ instances = _instances
+
+ input_ids = []
+ labels = []
+ num_tokens = []
+
+ for data in instances:
+
+ _input_ids = data['input_ids']
+ _labels = data['labels']
+ _num_tokens = data['num_tokens']
+
+ # TODO remove list
+ if isinstance(_num_tokens, list):
+ assert len(_num_tokens) == 1
+ _num_tokens = _num_tokens[0]
+
+ assert isinstance(_num_tokens, int)
+
+ if self.max_length:
+ _input_ids = _input_ids[:self.max_length]
+ _labels = _labels[:self.max_length]
+ _num_tokens = min(_num_tokens, self.max_length)
+
+ input_ids.append(torch.LongTensor(_input_ids))
+ labels.append(torch.LongTensor(_labels))
+ num_tokens.append(_num_tokens)
+
+ attention_mask = [torch.ones_like(ids) for ids in input_ids]
+ num_tokens = torch.IntTensor(num_tokens)
+
+ if len(instances) > 1 and self.pack_batch:
+
+ input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
+ labels = torch.cat(labels, dim=0).unsqueeze(0)
+ attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
+
+ elif len(instances) > 1 and not self.pack_batch:
+
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=self.pad_token_id)
+ labels = pad_sequence(
+ labels, batch_first=True, padding_value=self.ignore_id)
+ attention_mask = pad_sequence(
+ attention_mask, batch_first=True, padding_value=0)
+ else:
+ input_ids = torch.stack(input_ids)
+ labels = torch.stack(labels)
+ attention_mask = torch.stack(attention_mask)
+
+ if input_ids.shape != labels.shape:
+ logger.error(f'[instances] {instances}')
+ logger.error(f'[num_tokens] {num_tokens}')
+ logger.error(f'[input_ids] {input_ids}')
+ logger.error(f'[labels] {labels}')
+ raise RuntimeError('The shape of input_ids and labels must be '
+ f'equal, but found {input_ids.shape} and '
+ f'{labels.shape}.')
+
+ data_dict = {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'num_tokens': num_tokens,
+ 'attention_mask': attention_mask.bool()
+ }
+
+ return data_dict
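+
+
+# Example (illustrative): with pack_batch=True, two samples of 3 and 2 tokens
+# are concatenated into a single (1, 5) batch and num_tokens == [3, 2]; with
+# pack_batch=False they are right-padded into a (2, 3) batch instead.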
diff --git a/xtuner/_lite/auto.py b/xtuner/_lite/auto.py
new file mode 100644
index 000000000..d6e931836
--- /dev/null
+++ b/xtuner/_lite/auto.py
@@ -0,0 +1,185 @@
+import math
+import os
+from typing import Literal, Optional
+
+import torch
+from transformers import BitsAndBytesConfig, PretrainedConfig
+
+from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
+
+if os.environ.get('XTUNER_USE_MODELSCOPE'):
+ from modelscope import AutoTokenizer # noqa: F401
+ from modelscope import AutoConfig
+ from modelscope import AutoModelForCausalLM as OriAutoModelForCausalLM
+else:
+ from transformers import AutoTokenizer # noqa: F401
+ from transformers import AutoConfig
+ from transformers import AutoModelForCausalLM as OriAutoModelForCausalLM
+
+
+def download_model_from_hub(
+ model_name_or_path: str,
+ from_hub: Literal['huggingface', 'modelscope'] = 'huggingface',
+ cache_dir: Optional[str] = None,
+) -> str:
+ """Automatically download model from the HUB.
+
+ Note:
+ If `model_name_or_path` is a local path, it will return the path
+ directly without downloading it again.
+
+ Args:
+ model_name_or_path (str): The model name, model path or repo id.
+ config (str | None): The config path. Default is None.
+ from_hub (str): The model hosting hub, modelscope, or huggingface.
+ Default is huggingface.
+ cache_dir (str | None):
+ The save path when downloading the model. If it is None, it
+ will be stored in the default location of the HUB. For
+ Huggingface, it's ~/.cache/huggingface/hub, for ModelScope,
+ it's ~/.cache/modelscope/hub.
+
+ Returns:
+ str: The local path of the model.
+ """
+ if os.path.isdir(model_name_or_path):
+ model_path = model_name_or_path
+ elif from_hub == 'huggingface':
+ from huggingface_hub import snapshot_download
+ model_path = snapshot_download(
+ repo_id=model_name_or_path, cache_dir=cache_dir)
+ elif from_hub == 'modelscope':
+ from modelscope import snapshot_download
+ model_path = snapshot_download(
+ model_id=model_name_or_path, cache_dir=cache_dir)
+ else:
+ # TODO support openxlab
+ raise NotImplementedError('The model does not support downloading '
+ f'from {from_hub}, it only supports '
+ '`huggingface` and `modelscope`.')
+
+ return model_path
+
+
+class AutoModelForCausalLM:
+ """Enhanced version of Huggingface's `AutoModelForCausalLM`.
+
+ Compared to HuggingFace's `AutoModelForCausalLM`, the following three
+ features have been added:
+
+ 1. Load the model from either HuggingFace or ModelScope based on the
+ environment variable `XTUNER_USE_MODELSCOPE` (bool).
+ 2. Automatically enables Flash Attention. If `flash-attn` is already
+ installed, Flash Attention 2 will be used. If there is no
+ `flash-attn`, use Flash Attention 1 when torch version is less than
+ 2.2. When torch version is greater than or equal to 2.2, use Flash
+ Attention 2.
+    3. When the length of the target sequence during training exceeds the
+       maximum length of the original model, the rope scaling is
+       automatically set to the `linear` type, with the factor rounded up so
+       that the scaled context covers the target length.
+
+ Note:
+ If the model is built through `from_config`, it will not automatically
+ enable flash attention or modify rope scaling.
+
+ Note:
+ If you want to load the model on ModelScope, please set the
+ environment variable `XTUNER_USE_MODELSCOPE=1`.
+ """
+
+ @classmethod
+ def from_config(cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: bool = True,
+ **kwargs):
+ """Consistent with the usage of HuggingFace's AutoModelForCausalLM."""
+ return AutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: bool = True,
+ quantization_config: Optional[BitsAndBytesConfig] = None,
+ max_position_embeddings: Optional[int] = None,
+ **kwargs):
+ """Consistent with the usage of HuggingFace's AutoModelForCausalLM."""
+ config = cls.from_config(
+ pretrained_model_name_or_path, trust_remote_code=True)
+
+ attn_kwargs = cls._flash_attn_kwargs(config)
+ kwargs.update(attn_kwargs)
+
+ if max_position_embeddings:
+ long_ctx_kwargs = cls._long_ctx_kwargs(config,
+ max_position_embeddings)
+ kwargs.update(long_ctx_kwargs)
+
+ if 'torch_dtype' not in kwargs:
+ if torch.cuda.is_bf16_supported():
+ kwargs.update(torch_dtype=torch.bfloat16)
+ else:
+ kwargs.update(torch_dtype=torch.float16)
+
+ model = OriAutoModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ quantization_config=quantization_config,
+ **kwargs)
+
+ from xtuner._lite.accelerate import dispatch_modules
+ dispatch_modules(model, use_varlen_attn=True)
+
+ return model
+
+ @staticmethod
+ def _flash_attn_kwargs(config: PretrainedConfig) -> dict:
+ """Arguments Required to Enable Flash Attention."""
+ cls_name = type(config).__name__
+        _built_in_flash_attn_1 = ('LlamaConfig', 'GemmaConfig',
+                                  'MistralConfig', 'MixtralConfig',
+                                  'Qwen2Config', 'Starcoder2Config')
+
+        _built_in_flash_attn_2 = ('InternLMConfig', 'InternLM2Config',
+                                  'LlamaConfig', 'GemmaConfig',
+                                  'MistralConfig', 'MixtralConfig',
+                                  'Qwen2Config', 'Starcoder2Config')
+
+ attn_kwargs = {}
+ if SUPPORT_FLASH2 and cls_name in _built_in_flash_attn_2:
+ attn_kwargs.update(attn_implementation='flash_attention_2')
+ elif SUPPORT_FLASH1 and cls_name in _built_in_flash_attn_1:
+ attn_kwargs.update(attn_implementation='sdpa')
+
+ return attn_kwargs
+
+ @staticmethod
+ def _long_ctx_kwargs(config: PretrainedConfig,
+ max_position_embeddings: int) -> dict:
+ """Arguments Required for Long Context Training."""
+ ori_rope_scaling = getattr(config, 'rope_scaling', None)
+ if ori_rope_scaling is None:
+ ori_rope_scaling = {'factor': 1}
+
+ if 'factor' in ori_rope_scaling.keys():
+ ori_rope_scaling_factor = ori_rope_scaling['factor']
+ else:
+ ori_rope_scaling_factor = 1
+
+ ori_ctx_len = getattr(config, 'max_position_embeddings', None)
+
+ long_text_kwargs = {}
+ if ori_ctx_len:
+ ori_ctx_len *= ori_rope_scaling_factor
+ if max_position_embeddings > ori_ctx_len:
+ scaling_factor = float(
+ math.ceil(max_position_embeddings / ori_ctx_len))
+
+ new_rope_scaling = {'type': 'linear', 'factor': scaling_factor}
+ long_text_kwargs.update(dict(rope_scaling=new_rope_scaling))
+ return long_text_kwargs
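+
+
+# Usage sketch (the model id is a placeholder):
+#   model = AutoModelForCausalLM.from_pretrained(
+#       'internlm/internlm2-chat-1_8b', max_position_embeddings=32768)
+# Flash Attention is enabled when supported, and linear rope scaling is added
+# automatically when the requested context exceeds the checkpoint's native
+# maximum.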
diff --git a/xtuner/_lite/chat/__init__.py b/xtuner/_lite/chat/__init__.py
new file mode 100644
index 000000000..6443e50b4
--- /dev/null
+++ b/xtuner/_lite/chat/__init__.py
@@ -0,0 +1,6 @@
+from .messages import ChatMessages
+from .templates import CHAT_TEMPLATE_MAP, ChatTemplate, HybridChatTemplate
+
+__all__ = [
+ 'ChatMessages', 'CHAT_TEMPLATE_MAP', 'ChatTemplate', 'HybridChatTemplate'
+]
diff --git a/xtuner/_lite/chat/backends/__init__.py b/xtuner/_lite/chat/backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/xtuner/_lite/chat/messages/__init__.py b/xtuner/_lite/chat/messages/__init__.py
new file mode 100644
index 000000000..b8c75b45d
--- /dev/null
+++ b/xtuner/_lite/chat/messages/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseMessages
+from .chat import ChatMessages
+
+__all__ = ['BaseMessages', 'ChatMessages']
diff --git a/xtuner/_lite/chat/messages/base.py b/xtuner/_lite/chat/messages/base.py
new file mode 100644
index 000000000..0266986ac
--- /dev/null
+++ b/xtuner/_lite/chat/messages/base.py
@@ -0,0 +1,31 @@
+from abc import abstractmethod
+from typing import Dict
+
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizer
+
+from ..templates import ChatTemplate
+
+
+class BaseMessages(BaseModel):
+
+ @abstractmethod
+ def add(self, role: str, content):
+ pass
+
+ @abstractmethod
+ def pop(self):
+ pass
+
+ @abstractmethod
+ def get_prompt(self, chat_template: ChatTemplate) -> str:
+ pass
+
+ @abstractmethod
+ def tokenize(self, tokenizer: PreTrainedTokenizer,
+ chat_template: ChatTemplate) -> Dict:
+ pass
+
+    @classmethod
+    @abstractmethod
+    def from_dict(cls, item: Dict) -> 'BaseMessages':
+        pass
diff --git a/xtuner/_lite/chat/messages/chat.py b/xtuner/_lite/chat/messages/chat.py
new file mode 100644
index 000000000..af1756e8e
--- /dev/null
+++ b/xtuner/_lite/chat/messages/chat.py
@@ -0,0 +1,213 @@
+import copy
+from typing import Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizer
+
+from xtuner._lite import get_logger
+from xtuner.utils import IGNORE_INDEX
+from ..templates import ChatTemplate, HybridChatTemplate
+from .base import BaseMessages
+
+logger = get_logger()
+
+
+class TextContentItem(BaseModel):
+ type: Literal['text'] = 'text'
+ text: str
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return self.text
+
+
+class ImageContentItem(BaseModel):
+ type: Literal['image_url'] = 'image_url'
+ image_url: str
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return chat_template.image_token
+
+
+MultModalContentType = Union[TextContentItem, ImageContentItem]
+ContentType = Union[str, List[MultModalContentType]]
+
+
+class ChatMsg(BaseModel):
+
+ role: Literal['assistant', 'user', 'system']
+ content: ContentType
+ loss: Optional[bool] = None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.loss is None:
+ if self.role == 'system':
+ self.loss = False
+ elif self.role == 'user':
+ self.loss = False
+ elif self.role == 'assistant':
+ self.loss = True
+ else:
+ raise NotImplementedError
+
+ def collect_img_urls(self) -> List[str]:
+ img_urls = []
+ if isinstance(self.content, list):
+ for item in self.content:
+ if isinstance(item, ImageContentItem):
+ img_urls.append(item.image_url)
+ return img_urls
+
+ def get_prompt(self, chat_template: ChatTemplate) -> str:
+
+ if isinstance(self.content, str):
+ text = self.content
+ elif isinstance(self.content, list):
+ text = ''
+ for i, item in enumerate(self.content):
+ if i == 0:
+ text += item.apply_chat_template(chat_template)
+ else:
+ text += '\n' + item.apply_chat_template(chat_template)
+ else:
+ raise NotImplementedError
+
+ if self.role == 'system':
+ prompt = chat_template.decorate_system(text)
+ elif self.role == 'user':
+ prompt = chat_template.decorate_user(text)
+ elif self.role == 'assistant':
+ prompt = chat_template.decorate_assistant(text)
+ else:
+ raise NotImplementedError
+
+ return prompt
+
+ def tokenize(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ chat_template: ChatTemplate,
+ ):
+
+ decorated = self.get_prompt(chat_template)
+
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ return {
+ 'input_ids': token_ids,
+ 'labels': label_ids,
+ }
+
+
+class ChatMessages(BaseMessages):
+
+ messages: List[ChatMsg]
+
+ def add(self, role, content, loss=False):
+ self.messages.append(ChatMsg(role=role, content=content, loss=loss))
+
+ def pop(self):
+ return self.messages.pop()
+
+ def get_prompt(self, chat_template: ChatTemplate) -> str:
+
+ prompt = ''
+
+ for msg in self.messages:
+ prompt += msg.get_prompt(chat_template)
+ if msg.role == 'assistant':
+ prompt += chat_template.sep
+ return prompt
+
+ def tokenize(self, tokenizer: PreTrainedTokenizer,
+ chat_template: ChatTemplate) -> Dict:
+
+ input_ids = tokenizer.encode('', add_special_tokens=True)
+ labels = [IGNORE_INDEX for _ in input_ids]
+ image_urls = []
+
+
+
+ for msg in self.messages:
+ res = msg.tokenize(tokenizer, chat_template)
+ token_ids, label_ids = res['input_ids'], res['labels']
+
+ input_ids.extend(token_ids)
+ labels.extend(label_ids)
+
+ image_urls.extend(msg.collect_img_urls())
+
+ if msg.role == 'assistant':
+ sep = chat_template.sep
+ sep_tokens = tokenizer.encode(sep, add_special_tokens=False)
+ input_ids.extend(sep_tokens)
+ labels.extend([IGNORE_INDEX] * len(sep_tokens))
+
+ if len(input_ids) != len(labels):
+ logger.error(f'[messages] {self.messages}')
+ logger.error(f'[input_ids] {input_ids}')
+ logger.error(f'[labels] {labels}')
+ raise RuntimeError('The lengths of input_ids and labels must be '
+ f'equal, but found {len(input_ids)} and '
+ f'{len(labels)}.')
+
+ training_data = {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'num_tokens': len(input_ids),
+ }
+
+ if len(image_urls) > 0:
+ training_data['image_urls'] = image_urls
+
+ return training_data
+
+ @classmethod
+ def from_str(cls, prompt: str) -> 'ChatMessages':
+
+ msg = ChatMsg(role='user', content=prompt)
+ return cls(messages=[msg])
+
+ @classmethod
+ def from_dict(cls, item: dict) -> 'ChatMessages':
+ '''
+ item
+ {
+ 'messages':[
+ {'role':'user', 'content':'hello'},
+ {'role':'assistant', 'content':'hello!'},
+ ],
+ }
+ '''
+ return cls(**item)
+
+
+if __name__ == '__main__':
+
+ data = {
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'hello'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'hello!'
+ },
+ ]
+ }
+
+ messages = ChatMessages.from_dict(data)
+ chat_template = ChatTemplate(
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>\n',
+ stop_words=['<|im_end|>'],
+ )
+
+ print(messages.get_prompt(chat_template))
diff --git a/xtuner/_lite/chat/templates/__init__.py b/xtuner/_lite/chat/templates/__init__.py
new file mode 100644
index 000000000..7ed468e20
--- /dev/null
+++ b/xtuner/_lite/chat/templates/__init__.py
@@ -0,0 +1,29 @@
+from .chat import ChatTemplate
+from .hybrid import HybridChatTemplate
+
+CHAT_TEMPLATE_MAP = {
+ 'internlm2':
+ HybridChatTemplate(
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>',
+ stop_words=['<|im_end|>']),
+ 'qwen2':
+ HybridChatTemplate(
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>',
+ stop_words=['<|im_end|>', '<|endoftext|>']),
+ 'llama3':
+ HybridChatTemplate(
+ system=('<|start_header_id|>system<|end_header_id|>\n\n{system}'
+ '<|eot_id|>'),
+ user=('<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>'
+ '<|start_header_id|>assistant<|end_header_id|>\n\n'),
+ assistant='{assistant}<|eot_id|>',
+ sep='',
+ stop_words=['<|eot_id|>']),
+
+}
+
+__all__ = ['CHAT_TEMPLATE_MAP', 'ChatTemplate', 'HybridChatTemplate']
diff --git a/xtuner/_lite/chat/templates/chat.py b/xtuner/_lite/chat/templates/chat.py
new file mode 100644
index 000000000..9ce574fef
--- /dev/null
+++ b/xtuner/_lite/chat/templates/chat.py
@@ -0,0 +1,59 @@
+from typing import List
+
+from pydantic import BaseModel, field_validator
+
+
+class ChatTemplate(BaseModel):
+ """Define a Pydantic data model for a hybrid chat with attributes for
+ system, user and assistant chat as well as function and interpreter calls
+ and results."""
+
+ # Normal Chat
+ system: str # System message format
+ user: str # User message format
+ assistant: str # Assistant message format
+ stop_words: List[str] # List of stop words
+ sep: str = '\n'
+
+ def decorate_system(self, text: str) -> str:
+ """Decorate text with the `system` template."""
+ return self.system.format(system=text)
+
+ def decorate_assistant(self, text: str) -> str:
+ """Decorate text with the `assistant` template."""
+ return self.assistant.format(assistant=text)
+
+ def decorate_user(self, text: str) -> str:
+ """Decorate text with the `user` template."""
+ return self.user.format(user=text)
+
+ @field_validator('system')
+ def check_system(cls, v: str) -> str:
+ """Validate that `system` contains '{system}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{system}' not in v:
+ raise ValueError("system must contain the keyword '{system}'")
+ return v
+
+ @field_validator('user')
+ def check_user(cls, v: str) -> str:
+ """Validate that `user` contains '{user}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{user}' not in v:
+ raise ValueError("user must contain the keyword '{user}'")
+ return v
+
+ @field_validator('assistant')
+ def check_assistant(cls, v: str) -> str:
+ """Validate that `assistant` contains '{assistant}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{assistant}' not in v:
+ raise ValueError(
+ "assistant must contain the keyword '{assistant}'")
+ return v
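+
+
+# Example (internlm2-style template, shown for illustration):
+#   tpl = ChatTemplate(
+#       system='<|im_start|>system\n{system}<|im_end|>\n',
+#       user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+#       assistant='{assistant}<|im_end|>',
+#       stop_words=['<|im_end|>'])
+#   tpl.decorate_user('hi')
+#   # -> '<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n'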
diff --git a/xtuner/_lite/chat/templates/hybrid.py b/xtuner/_lite/chat/templates/hybrid.py
new file mode 100644
index 000000000..d9f563c0d
--- /dev/null
+++ b/xtuner/_lite/chat/templates/hybrid.py
@@ -0,0 +1,196 @@
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, field_validator
+
+
+class HybridChatTemplate(BaseModel):
+ """Define a Pydantic data model for a hybrid chat with attributes for
+ system, user and assistant chat as well as function and interpreter calls
+ and results."""
+
+ # Normal Chat
+ system: str # System message format
+ user: str # User message format
+ assistant: str # Assistant message format
+ stop_words: List[str] # List of stop words
+ sep: str = '\n'
+
+ # Multimodal Chat
+ # Predefined token and index for images
+    image_token: str = '<image>'
+ image_token_index: int = -100
+
+ # Agent Chat
+
+ # Interpreter and function related strings
+ files: Optional[str] = None
+
+ functions: Optional[str] = None # Function description format
+ function_call: Optional[str] = None # Function call format
+ function_result: Optional[str] = None # Function result format
+
+ code_interpreter: Optional[str] = None
+ code_interpreter_call: Optional[str] = None # Interpreter call format
+ code_interpreter_result: Optional[str] = None # Interpreter result format
+
+ function_token: Optional[str] = None
+ code_interpreter_token: Optional[str] = None
+ action_start_token: Optional[str] = None
+ action_end_token: Optional[str] = None
+
+ @property
+ def mm_token_maps(self) -> Dict[str, int]:
+ """Return a dictionary that maps multimodal tokens to corresponding
+ token indexes."""
+ return {self.image_token: self.image_token_index}
+
+ def decorate_system(self, text: str) -> str:
+ """Decorate text with the `system` template."""
+ return self.system.format(system=text)
+
+ def decorate_assistant(self, text: str) -> str:
+ """Decorate text with the `assistant` template."""
+ return self.assistant.format(assistant=text)
+
+ def decorate_user(self, text: str) -> str:
+ """Decorate text with the `user` template."""
+ return self.user.format(user=text)
+
+ def decorate_files(self, text: str) -> str:
+ """Decorate text with the `functions` template."""
+ return self.files.format(files=text)
+
+ def decorate_functions(self, text: str) -> str:
+ """Decorate text with the `functions` template."""
+ return self.functions.format(functions=text)
+
+ def decorate_function_call(self, text: str, func: str) -> str:
+ """Decorate text with the `function_call` template."""
+ return self.function_call.format(assistant=text, function_call=func)
+
+ def decorate_function_result(self, text: str) -> str:
+ """Decorate text with the `function_result` template."""
+ return self.function_result.format(function_result=text)
+
+ def decorate_code_interpreter(self, text: str) -> str:
+ """Decorate text with the `code_interpreter` template."""
+ return self.code_interpreter.format(code_interpreter=text)
+
+ def decorate_code_interpreter_call(self, text: str, func: str) -> str:
+ """Decorate text with the `code_interpreter_call` template."""
+ return self.code_interpreter_call.format(
+ assistant=text, code_interpreter_call=func)
+
+ def decorate_code_interpreter_result(self, text: str) -> str:
+ """Decorate text with the `code_interpreter_result` template."""
+ return self.code_interpreter_result.format(
+ code_interpreter_result=text)
+
+ @field_validator('system')
+ def check_system(cls, v: str) -> str:
+ """Validate that `system` contains '{system}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{system}' not in v:
+ raise ValueError("system must contain the keyword '{system}'")
+ return v
+
+ @field_validator('user')
+ def check_user(cls, v: str) -> str:
+ """Validate that `user` contains '{user}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{user}' not in v:
+ raise ValueError("user must contain the keyword '{user}'")
+ return v
+
+ @field_validator('assistant')
+ def check_assistant(cls, v: str) -> str:
+ """Validate that `assistant` contains '{assistant}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{assistant}' not in v:
+ raise ValueError(
+ "assistant must contain the keyword '{assistant}'")
+ return v
+
+ @field_validator('function_call')
+ def check_function_call(cls, v: str) -> str:
+ """Validate that `function_call` contains '{function_call}'.
+
+ If not, raises a ValueError.
+ """
+        if (v is not None and '{function_call}' not in v
+                and '{assistant}' not in v):
+            raise ValueError(
+                "function_call must contain the keywords '{assistant}' and "
+                "'{function_call}'")
+        if v is not None and '{assistant}' not in v:
+            raise ValueError(
+                "function_call must contain the keywords '{assistant}' and "
+                "'{function_call}'")
+        return v
+
+ @field_validator('function_result')
+ def check_function_result(cls, v: str) -> str:
+ """Validate that `function_result` contains '{function_result}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{function_result}' not in v:
+ raise ValueError(
+ "function_result must contain the keyword '{function_result}'")
+ return v
+
+ @field_validator('functions')
+ def check_functions(cls, v: str) -> str:
+ """Validate that `functions` contains '{functions}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{functions}' not in v:
+ raise ValueError(
+ "functions must contain the keyword '{functions}'")
+ return v
+
+ @field_validator('code_interpreter')
+ def check_code_interpreter(cls, v: str) -> str:
+ """Validate that `code_interpreter` contains '{code_interpreter}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{code_interpreter}' not in v:
+ raise ValueError('code_interpreter must contain the keyword '
+ "'{code_interpreter}'")
+ return v
+
+ @field_validator('code_interpreter_call')
+ def check_code_interpreter_call(cls, v: str) -> str:
+ """Validate that `code_interpreter_call` contains
+ '{code_interpreter_call}'.
+
+ If not, raises a ValueError.
+ """
+ if (v is not None and '{code_interpreter_call}' not in v
+ and '{assistant}' not in v):
+ raise ValueError('code_interpreter_call must contain the keywords '
+ "'{assistant}' and '{code_interpreter_call}'")
+ if v is not None and '{assistant}' not in v:
+ raise ValueError('code_interpreter_call must contain the keywords '
+ "'{assistant}' and '{code_interpreter_call}'")
+ return v
+
+ @field_validator('code_interpreter_result')
+ def check_code_interpreter_result(cls, v: str) -> str:
+ """Validate that `code_interpreter_result` contains
+ '{code_interpreter_result}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{code_interpreter_result}' not in v:
+ raise ValueError(
+ 'code_interpreter_result must contain the keyword '
+ "'{code_interpreter_result}'")
+ return v
diff --git a/xtuner/_lite/datasets/__init__.py b/xtuner/_lite/datasets/__init__.py
new file mode 100644
index 000000000..f7dc3b293
--- /dev/null
+++ b/xtuner/_lite/datasets/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .json import JsonDataset
+from .jsonl import JsonlDataset
+from .pack import SoftPackDataset, HardPackDataset
+from .utils import DATASET_CLS_MAP, OPENAI_CONVERT_MAP, load_datasets
+
+__all__ = [
+    'JsonDataset', 'JsonlDataset', 'SoftPackDataset', 'HardPackDataset',
+    'DATASET_CLS_MAP', 'OPENAI_CONVERT_MAP', 'load_datasets'
+]
diff --git a/xtuner/_lite/datasets/internvl2/__init__.py b/xtuner/_lite/datasets/internvl2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/xtuner/_lite/datasets/internvl2/conversation.py b/xtuner/_lite/datasets/internvl2/conversation.py
new file mode 100644
index 000000000..10f7f6b14
--- /dev/null
+++ b/xtuner/_lite/datasets/internvl2/conversation.py
@@ -0,0 +1,393 @@
+"""
+Conversation prompt templates.
+
+We kindly request that you import fastchat instead of copying this file if you wish to use it.
+If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
+"""
+
+import dataclasses
+from enum import IntEnum, auto
+from typing import Dict, List, Tuple, Union
+
+
+class SeparatorStyle(IntEnum):
+ """Separator styles."""
+
+ ADD_COLON_SINGLE = auto()
+ ADD_COLON_TWO = auto()
+ ADD_COLON_SPACE_SINGLE = auto()
+ NO_COLON_SINGLE = auto()
+ NO_COLON_TWO = auto()
+ ADD_NEW_LINE_SINGLE = auto()
+ LLAMA2 = auto()
+ CHATGLM = auto()
+ CHATML = auto()
+ CHATINTERN = auto()
+ DOLLY = auto()
+ RWKV = auto()
+ PHOENIX = auto()
+ ROBIN = auto()
+ FALCON_CHAT = auto()
+ CHATGLM3 = auto()
+ INTERNVL_ZH = auto()
+ MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that manages prompt templates and keeps all conversation history."""
+
+ # The name of this template
+ name: str
+ # The template of the system prompt
+ system_template: str = '{system_message}'
+ # The system message
+ system_message: str = ''
+ # The names of two roles
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
+ # All messages. Each item is (role, message).
+ messages: List[List[str]] = ()
+ # The number of few shot examples
+ offset: int = 0
+ # The separator style and configurations
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+ sep: str = '\n'
+ sep2: str = None
+ # Stop criteria (the default one is EOS token)
+ stop_str: Union[str, List[str]] = None
+ # Stops generation if meeting any token in this list
+ stop_token_ids: List[int] = None
+
+ def get_prompt(self) -> str:
+ """Get the prompt for generation."""
+ system_prompt = self.system_template.format(system_message=self.system_message)
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ': ' + message + seps[i % 2]
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ': ' # must be end with a space
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+ ret = '' if system_prompt == '' else system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + message + self.sep
+ else:
+ ret += role + '\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+ ret = system_prompt
+ for role, message in self.messages:
+ if message:
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + message + seps[i % 2]
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.RWKV:
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += (
+ role
+ + ': '
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
+ )
+ ret += '\n\n'
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.LLAMA2:
+ seps = [self.sep, self.sep2]
+ if self.system_message:
+ ret = system_prompt
+ else:
+ ret = '[INST] '
+ for i, (role, message) in enumerate(self.messages):
+ tag = self.roles[i % 2]
+ if message:
+ if i == 0:
+ ret += message + ' '
+ else:
+ ret += tag + ' ' + message + seps[i % 2]
+ else:
+ ret += tag
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATGLM:
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+ round_add_n = 1 if self.name == 'chatglm2' else 0
+ if system_prompt:
+ ret = system_prompt + self.sep
+ else:
+ ret = ''
+
+ for i, (role, message) in enumerate(self.messages):
+ if i % 2 == 0:
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
+
+ if message:
+ ret += f'{role}:{message}{self.sep}'
+ else:
+ ret += f'{role}:'
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATML:
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + message + self.sep + '\n'
+ else:
+ ret += role + '\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
+ ret = ''
+ if self.system_message:
+ ret += system_prompt
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + ' ' + message
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ # if i % 2 == 0:
+ # ret += ""
+ if message:
+ ret += role + ':' + message + seps[i % 2] + '\n'
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.DOLLY:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ':\n' + message + seps[i % 2]
+ if i % 2 == 1:
+ ret += '\n\n'
+ else:
+ ret += role + ':\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.PHOENIX:
+ ret = system_prompt
+ for role, message in self.messages:
+ if message:
+                    ret += role + ': ' + '<s>' + message + '</s>'
+                else:
+                    ret += role + ': ' + '<s>'
+ return ret
+ elif self.sep_style == SeparatorStyle.ROBIN:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ':\n' + message + self.sep
+ else:
+ ret += role + ':\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
+ ret = ''
+ if self.system_message:
+ ret += system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+
+ return ret
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
+ seps = [self.sep, self.sep2]
+ ret = self.system_message + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ': ' + message + seps[i % 2]
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f'Invalid style: {self.sep_style}')
+
+ def set_system_message(self, system_message: str):
+ """Set the system message."""
+ self.system_message = system_message
+
+ def append_message(self, role: str, message: str):
+ """Append a new message."""
+ self.messages.append([role, message])
+
+ def update_last_message(self, message: str):
+ """Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+ """
+ self.messages[-1][1] = message
+
+ def to_gradio_chatbot(self):
+ """Convert the conversation to gradio chatbot format."""
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def to_openai_api_messages(self):
+ """Convert the conversation to OpenAI chat completion format."""
+ ret = [{'role': 'system', 'content': self.system_message}]
+
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append({'role': 'user', 'content': msg})
+ else:
+ if msg is not None:
+ ret.append({'role': 'assistant', 'content': msg})
+ return ret
+
+ def copy(self):
+ return Conversation(
+ name=self.name,
+ system_template=self.system_template,
+ system_message=self.system_message,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ stop_str=self.stop_str,
+ stop_token_ids=self.stop_token_ids,
+ )
+
+ def dict(self):
+ return {
+ 'template_name': self.name,
+ 'system_message': self.system_message,
+ 'roles': self.roles,
+ 'messages': self.messages,
+ 'offset': self.offset,
+ }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+ """Register a new conversation template."""
+ if not override:
+ assert (
+ template.name not in conv_templates
+ ), f'{template.name} has been registered.'
+
+ conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+ """Get a conversation template."""
+ return conv_templates[name].copy()
+
+
+# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
+# is that during training, the preprocessing function for the Hermes-2 template doesn't add
+# the bos token ('<s>') at the beginning of the tokenized sequence, while the
+# internlm2-chat template does.
+# Therefore, they are completely equivalent during inference.
+register_conv_template(
+ Conversation(
+ name='Hermes-2',
+ system_template='<|im_start|>system\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|im_end|>',
+ stop_token_ids=[
+ 2,
+ 6,
+ 7,
+ 8,
+ ],
+ stop_str='<|endoftext|>',
+ )
+)
+
+
+register_conv_template(
+ Conversation(
+ name='internlm2-chat',
+ system_template='<|im_start|>system\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|im_end|>',
+ stop_token_ids=[
+ 2,
+ 92543,
+ 92542
+ ]
+ )
+)
+
+
+register_conv_template(
+ Conversation(
+ name='phi3-chat',
+ system_template='<|system|>\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|user|>\n', '<|assistant|>\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|end|>',
+ stop_token_ids=[
+ 2,
+ 32000,
+ 32007
+ ]
+ )
+)
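+
+
+# Usage sketch: templates are retrieved by name and filled in turn by turn.
+#   conv = get_conv_template('internlm2-chat')
+#   conv.append_message(conv.roles[0], 'Describe this image.')
+#   conv.append_message(conv.roles[1], None)
+#   prompt = conv.get_prompt()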
diff --git a/xtuner/_lite/datasets/internvl2/dataset.py b/xtuner/_lite/datasets/internvl2/dataset.py
new file mode 100644
index 000000000..bc7bf9653
--- /dev/null
+++ b/xtuner/_lite/datasets/internvl2/dataset.py
@@ -0,0 +1,295 @@
+import os
+import torch.distributed as dist
+from mmengine.utils import mkdir_or_exist
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
+import numpy as np
+import json
+import math
+from concurrent.futures import ProcessPoolExecutor
+from tqdm import tqdm
+import copy
+import random
+
+from ..json import calculate_json_sha256
+from ..jsonl import calculate_jsonl_sha256
+from ..pack import SoftPackDataset
+
+from xtuner._lite import get_logger
+from xtuner._lite.parallel import get_dp_mesh, VLMLengthGroupedSampler, ParallelSampler
+
+logger = get_logger()
+
+
+def _load_json_or_jsonl(json_path):
+ if json_path.endswith('.json'):
+ with open(json_path) as f:
+ data = json.load(f)
+ elif json_path.endswith('.jsonl'):
+ with open(json_path) as f:
+ data = f.readlines()
+ else:
+ raise ValueError(f'Unsupported file format: {json_path}, '
+ f'only support .json and .jsonl.')
+ return data
+
+
+class BaseOrigDataset(Dataset):
+ def __init__(self,
+ data_name,
+ data,
+ chat_template,
+ tokenizer,
+ max_length,
+                 image_token_str='<image>',
+ group_by_length=False,
+ pack_data=False,
+ pack_data_cache_dir=None,
+ random_sample=False):
+ self.data_name = data_name
+ self.max_length = max_length
+ self.group_by_length = group_by_length
+ self.pack_data = pack_data
+ self.pack_data_cache_dir = pack_data_cache_dir
+ self.chat_template = chat_template
+ self.image_token_str = image_token_str
+ self.tokenizer = tokenizer
+ self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8))
+
+ try:
+ self.root = data['media_root']
+ except KeyError:
+ self.root = data.get('root', '')
+ logger.info(f"{dist.get_rank()} ======= Start to process dataset: {os.path.basename(data['annotation'])}")
+
+ self.annotation = data['annotation']
+ self._is_jsonl = self.annotation.endswith('.jsonl')
+ self.raw_data = _load_json_or_jsonl(self.annotation)
+
+ # -------------------pack---------------------------------------
+ self.num_tokens = None
+ self.pack_data_cache_dir = pack_data_cache_dir
+ if pack_data:
+ assert pack_data_cache_dir is not None, 'pack_data_cache_dir must be provided when pack_data is True'
+ self.num_tokens = self.calc_packing_info()
+ assert len(self.num_tokens) == len(
+ self.raw_data), f'===={len(self.num_tokens)} neq {len(self.raw_data)}===='
+
+ repeat_time = data.get('repeat_time', 1)
+ if repeat_time < 1:
+ # If repeat_time is less than 1, select a portion of the data
+ if random_sample:
+ num_samples = int(len(self.raw_data) * repeat_time)
+ sampled = random.sample([i for i in range(len(self.raw_data))], num_samples)
+ self.raw_data = [self.raw_data[index] for index in sampled]
+ if pack_data:
+ self.num_tokens = self.num_tokens[sampled]
+ else:
+ num_samples = int(len(self.raw_data) * repeat_time)
+ self.raw_data = self.raw_data[:num_samples]
+ if pack_data:
+ self.num_tokens = self.num_tokens[:num_samples]
+
+ if repeat_time > 1:
+ assert isinstance(repeat_time, int)
+ # Repeat the list if repeat_time is greater than 1
+ self.raw_data = self.raw_data * repeat_time
+ if pack_data:
+ self.num_tokens = np.tile(self.num_tokens, repeat_time)
+
+ if pack_data:
+ assert len(self.num_tokens) == len(self.raw_data), f' {len(self.num_tokens)} neq {len(self.raw_data)}'
+
+ self.group_length = []
+ if self.group_by_length and not pack_data:
+ self.group_length = self.calc_group_len()
+
+ def __len__(self):
+ return len(self.raw_data)
+
+ def calc_group_len(self):
+ raise NotImplementedError
+
+ def calc_packing_info(self):
+ if os.path.exists(self.pack_data_cache_dir):
+ assert os.path.isdir(self.pack_data_cache_dir)
+ else:
+ mkdir_or_exist(self.pack_data_cache_dir)
+
+        # TODO: a more robust way to calculate the hash
+ if self._is_jsonl:
+ file_hash = calculate_jsonl_sha256(self.annotation)
+ else:
+ file_hash = calculate_json_sha256(self.annotation)
+ file_cache_dir = os.path.join(self.pack_data_cache_dir, file_hash)
+ if not os.path.exists(file_cache_dir):
+ mkdir_or_exist(file_cache_dir)
+
+ if 'num_tokens.npy' in os.listdir(file_cache_dir):
+ _cached_file = os.path.join(file_cache_dir, 'num_tokens.npy')
+ num_tokens = np.load(_cached_file)
+ logger.info(f"Load num_tokens from cache: {os.path.basename(self.annotation)}")
+ else:
+ logger.info(f"Start calculating the cache of num_tokens: {os.path.basename(self.annotation)}")
+ num_tokens = self.count_tokens_for_pack(file_cache_dir)
+ return num_tokens
+
+ def count_tokens_for_pack(self, cache_dir=None):
+ num_samples = len(self.raw_data)
+
+ if dist.is_available():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ world_size = 1
+ rank = 0
+
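+        # Each rank tokenizes an equal shard of the annotations in parallel
+        # worker processes; the per-sample token counts are then all-gathered
+        # so every rank ends up with the full array.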
+ num_per_rank = math.ceil(num_samples / world_size)
+
+ start = rank * num_per_rank
+ end = (rank + 1) * num_per_rank
+ dataset_shard = self.raw_data[start:end]
+
+ desc = f'[Rank {rank}] {os.path.basename(self.annotation)}'
+ with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor:
+ tokenized = list(
+ tqdm(
+ executor.map(self.pre_tokenize_fn_for_pack, dataset_shard,
+ chunksize=min(max(1, len(dataset_shard) // self.tokenizer_workers), 500)),
+ desc=desc,
+ total=len(dataset_shard)))
+
+ _num_tokens = [data['num_tokens'] for data in tokenized]
+ _num_tokens = np.array(_num_tokens)
+
+ if dist.is_available():
+ num_tokens = [None] * world_size
+ dist.all_gather_object(num_tokens, _num_tokens)
+ num_tokens = np.concatenate(num_tokens, axis=0)
+ else:
+ num_tokens = _num_tokens
+
+ if rank == 0 and cache_dir:
+ save_path = os.path.join(cache_dir, 'num_tokens.npy')
+ np.save(save_path, num_tokens)
+
+ return num_tokens
+
+ def pre_tokenize_fn_for_pack(self, data):
+ raise NotImplementedError
+
+ def process_text(self, conversations, media_type='image', image_grids=None):
+ while conversations and conversations[0]['from'] == 'gpt':
+ # Skip the first one if it is from gpt
+ conversations = conversations[1:]
+
+ assert len(conversations) % 2 == 0, f'Invalid conversation length: {len(conversations)}'
+
+ input_ = ''
+ out_conversation = []
+ for msg in conversations:
+ if msg['from'] == 'human':
+ input_ += msg['value'].strip()
+ elif msg['from'] == 'gpt':
+ out_conversation.append({
+ 'input': input_,
+ 'output': msg['value'].strip()
+ })
+ input_ = ''
+ else:
+ raise NotImplementedError(f'Unsupported message type: {msg}')
+
+ input_ids, labels = [], []
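+        # Prompt tokens are labeled -100 so that the loss is computed only on
+        # the assistant responses.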
+ for i, single_turn_conversation in enumerate(out_conversation):
+ input_ = single_turn_conversation.get('input', '')
+ if input_ is None:
+ input_ = ''
+ input_ = self.chat_template['user'].format(user=input_)
+
+ if i == 0:
+ input_ = self._process_media_format_first_round(input_, media_type, image_grids)
+ # TODO: support system prompt
+ # input_ = self.chat_template['system'] + input_
+ input_encode = self.tokenizer.encode(input_, add_special_tokens=True)
+ else:
+ input_encode = self.tokenizer.encode(input_, add_special_tokens=False)
+
+ input_ids += input_encode
+ labels += [-100] * len(input_encode)
+
+ output_text = single_turn_conversation.get('output', '')
+ output_encode = self.chat_template['assistant'].format(assistant=output_text)
+ output_encode = self.tokenizer.encode(output_encode, add_special_tokens=False)
+ input_ids += output_encode
+ labels += copy.deepcopy(output_encode)
+
+ if len(input_ids) > self.max_length:
+ input_ids = input_ids[:self.max_length]
+ labels = labels[:self.max_length]
+ logger.info(
+ f'Warning: input_ids length({len(input_ids)}) '
+ f'is longer than max_length, cut to {self.max_length}')
+ return {'input_ids': input_ids, 'labels': labels}
+
+ def _process_media_format_first_round(self, input_, media_type, image_grids):
+ raise NotImplementedError
+
+ @property
+ def modality_length(self):
+ return self.group_length
+
+ @property
+ def length(self):
+ group_length = np.array(self.group_length)
+ group_length = np.abs(group_length).tolist()
+ return group_length
+
+
+def build_dataset(args, datasets):
+ assert len(datasets) > 0, 'No dataset found.'
+ if args.dset_pack:
+ train_dataset = SoftPackDataset(datasets,
+ target=args.pack_max_length,
+ blend=args.concat_before_pack)
+ else:
+ train_dataset = ConcatDataset(datasets)
+ if dist.get_rank() == 0:
+ logger.info(f'[Dataset] (Original) {len(train_dataset)} samples.')
+ return train_dataset
+
+
+def build_train_dataloader(args, train_dataset, collate_fn):
+ dp_mesh = get_dp_mesh()
+ if args.group_by_length:
+ if args.dset_pack:
+ length_property = 'longest'
+ else:
+ length_property = 'length'
+ sampler = VLMLengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size,
+ seed=args.seed,
+ length_property=length_property)
+ elif args.group_by_modality_length:
+ if args.dset_pack:
+ raise NotImplementedError
+ else:
+ sampler = VLMLengthGroupedSampler(train_dataset, dp_mesh,
+ args.global_batch_size,
+ seed=args.seed,
+ length_property='modality_length')
+ else:
+ sampler = ParallelSampler(
+ train_dataset, dp_mesh, args.global_batch_size, seed=args.seed, shuffle=True)
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.mirco_batch_size,
+ num_workers=args.num_workers,
+ sampler=sampler,
+ collate_fn=collate_fn,
+ persistent_workers=args.num_workers > 0)
+
+ if dist.get_rank() == 0:
+ logger.info(f'[Dataloader] {len(train_dataloader)} batches.')
+
+ dist.barrier()
+ return train_dataloader
diff --git a/xtuner/_lite/datasets/internvl2/process.py b/xtuner/_lite/datasets/internvl2/process.py
new file mode 100644
index 000000000..f0c48d752
--- /dev/null
+++ b/xtuner/_lite/datasets/internvl2/process.py
@@ -0,0 +1,705 @@
+import io
+
+from transformers.trainer_pt_utils import LabelSmoother
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+from typing import Dict
+import torch
+import torchvision.transforms as T
+import transformers
+from .conversation import get_conv_template
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+import sys
+
+
+IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+IMG_START_TOKEN = '<img>'
+IMG_END_TOKEN = '</img>'
+QUAD_START_TOKEN = ''
+QUAD_END_TOKEN = ''
+REF_START_TOKEN = '['
+REF_END_TOKEN = ']'
+BOX_START_TOKEN = ''
+BOX_END_TOKEN = ''
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
+CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
+SIGLIP_MEAN = (0.5, 0.5, 0.5)
+SIGLIP_STD = (0.5, 0.5, 0.5)
+IGNORE_INDEX = -100
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def simulate_jpeg_degradation(quality):
+ def jpeg_degrade(img):
+ with io.BytesIO() as output:
+ img.convert('RGB').save(output, format='JPEG', quality=quality)
+ output.seek(0) # Move the reading cursor to the start of the stream
+ img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
+ return img_jpeg
+ return jpeg_degrade
+
+
+# Define the JPEG compression quality range, pre-create all JPEG compression functions
+qualities = list(range(75, 101))
+jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
+
+
+def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
+ if normalize_type == 'imagenet':
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ elif normalize_type == 'clip':
+ MEAN, STD = CLIP_MEAN, CLIP_STD
+ elif normalize_type == 'siglip':
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
+ else:
+ raise NotImplementedError
+    if is_train:  # use data augmentation
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ else:
+ if pad2square is False: # now we use this transform function by default
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ else:
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+
+ return transform
+
+
+def preprocess(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] + ': '
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(turns):
+ if turn == '':
+ break
+ turn_len = len(tokenizer(turn).input_ids)
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ instruction_len -= 1
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ cur_len += turn_len
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ cur_len -= 1
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+            print(repr(tokenizer.decode(z)))
+ exit()
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_mpt(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep)
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
+ for conv_idx in range(3, len(turns), 2):
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(re_turns):
+ if turn == '':
+ break
+ turn_len = len(tokenizer(turn).input_ids) + 1
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ instruction_len = len(tokenizer(parts[0]).input_ids)
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
+ cur_len += turn_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_phi3(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ tokenizer.padding_side = 'right'
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
+
+ turns = conversation.split(conv.sep)
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
+ for conv_idx in range(3, len(turns), 2):
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ target[target == endoftext_id] = IGNORE_TOKEN_ID
+
+ for i, turn in enumerate(re_turns):
+ if turn == '':
+ break
+ if i == 0:
+ turn_len = len(tokenizer(turn).input_ids)
+ else:
+ turn_len = len(tokenizer(turn).input_ids) - 1
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if i == 0:
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+ else:
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
+ cur_len += turn_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ print(repr(tokenizer.decode(z)))
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_phi3_fast(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+
+ assert len(conv.messages) % 2 == 0, f'{ds_name}, {len(conv.messages)}, {conv.messages}'
+ inputs = conv.messages[::2]
+ outputs = conv.messages[1::2]
+
+ input_ids, labels = [], []
+ # input_texts = ''
+ system_prompt = conv.system_template.format(system_message=conv.system_message)
+ input_text = system_prompt + conv.sep
+ # input_texts += input_text
+ input_encode = tokenizer.encode(input_text, add_special_tokens=True)
+ input_ids += input_encode
+ labels += [IGNORE_INDEX] * len(input_encode)
+
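+    # Fast path: encode the conversation turn by turn and build the label mask
+    # directly, instead of tokenizing the full rendered prompt and re-deriving
+    # the mask from token offsets as the other preprocess_* variants do.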
+ real_num_images = 0
+ for input_, output_ in zip(inputs, outputs):
+ # output_[0] = '<|assistant|>\n'
+        # Put it in the input rather than the output to stay aligned with the official implementation
+ input_text = ''.join(input_) + conv.sep + output_[0]
+
+ if not text_only:
+            real_num_images += input_text.count('<image>')
+            for i in range(num_image):
+                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                input_text = input_text.replace('<image>', image_tokens, 1)
+            assert '<image>' not in input_text, f'error: {ds_name}, {input_text}'
+ output_text = output_[1] + conv.sep
+
+ input_encode = tokenizer.encode(input_text, add_special_tokens=False)
+ output_encode = tokenizer.encode(output_text, add_special_tokens=False)
+ input_ids += input_encode
+ input_ids += output_encode
+ labels += [IGNORE_INDEX] * len(input_encode)
+ labels += output_encode
+
+ # input_texts += input_text
+ # input_texts += output_text
+
+ if not text_only:
+ assert real_num_images == num_image, f'{ds_name} data error: {real_num_images} vs. {num_image}'
+ # print(input_texts)
+ # assert input_ids.count(32013) == num_image_token_list[
+ # 0], f'error1: {input_ids}, {num_image_token_list[0]}, {input_texts}'
+ if len(input_ids) > tokenizer.model_max_length:
+ print(f'WARNING: input_ids length {len(input_ids)} exceeds '
+ f'model_max_length {tokenizer.model_max_length}. truncated!')
+ input_ids = input_ids[:tokenizer.model_max_length]
+ labels = labels[:tokenizer.model_max_length]
+
+ # if not text_only:
+ # if input_ids.count(32013) != num_image_token_list[0]:
+ # print(f'WARNING: IMG_CONTEXT_TOKEN is broken. {input_ids.count(32013)} vs. {num_image_token_list[0]}')
+
+ input_ids = torch.tensor(input_ids, dtype=torch.long)[None]
+ labels = torch.tensor(labels, dtype=torch.long)[None]
+ assert input_ids.size() == labels.size()
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_internlm(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ sentence['value'] = sentence['value'].strip()
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
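+    # Mask targets: tokens belonging to the system/user prompt are set to
+    # IGNORE_TOKEN_ID so that only the assistant replies contribute to the loss.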
+ for conversation, target in zip(conversations, targets):
+        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # in InternLM, pad_token_id equals eos_token_id
+        cur_len = 1
+        target[:cur_len] = IGNORE_TOKEN_ID  # mask the leading <s> (BOS) token
+        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n
+        info = parts[0] + conv.roles[1]
+        temp_len = len(tokenizer(info).input_ids) - 1  # minus 1 to drop the <s> added by the tokenizer
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
+ cur_len = cur_len + temp_len
+
+ for index in range(1, len(parts) - 1):
+ info = parts[index]
+ part1, part2 = info.split(conv.roles[0])
+ temp_len = len(tokenizer(part1).input_ids) - 1
+ cur_len = cur_len + temp_len
+ part = conv.roles[0] + part2 + conv.roles[1]
+ temp_len = len(tokenizer(part).input_ids) - 1
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
+ cur_len = cur_len + temp_len
+ last_info = parts[-1]
+ temp_len = len(tokenizer(last_info).input_ids) - 1
+ cur_len = cur_len + temp_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ print(repr(tokenizer.decode(z)))
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
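+    # Prefer the candidate grid whose aspect ratio is closest to the image's;
+    # on a tie, switch to the larger grid only if the image covers more than
+    # half of that grid's total pixel area.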
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+ return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
+
+def dynamic_num_patch(size, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ if use_thumbnail and blocks > 1:
+ blocks += 1
+ return blocks
+
+
+def packing_collate(features, pack_batch=True, pad_id=0):
+ input_ids = []
+ labels = []
+ pixel_values = []
+ num_tokens = []
+ num_img_tokens = []
+ image_flags = []
+
+ for data in features:
+ input_ids.append(torch.LongTensor(data['input_ids']))
+ labels.append(torch.LongTensor(data['labels']))
+ num_tokens.extend(data['num_tokens'])
+ num_img_tokens.extend(data['num_img_tokens'])
+ pixel_values.append(data['pixel_values'])
+ image_flags.append(data['image_flags'])
+
+ attention_mask = [ids.ne(pad_id) for ids in input_ids]
+ num_tokens = torch.IntTensor(num_tokens)
+ num_img_tokens = torch.IntTensor(num_img_tokens)
+
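+    # Pack mode: concatenate every sample along the sequence dimension into a
+    # single row; num_tokens / num_img_tokens keep the per-sample lengths so the
+    # packed row can be split again downstream.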
+ if len(features) > 1 and pack_batch:
+ # batch packing
+ input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
+ labels = torch.cat(labels, dim=0).unsqueeze(0)
+ attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
+ image_flags = torch.cat(image_flags, dim=0)
+ pixel_values = torch.cat(pixel_values, dim=0)
+ elif len(features) > 1 and not pack_batch:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ data_dict = {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'attention_mask': attention_mask.bool(),
+ 'pixel_values': pixel_values,
+ 'image_flags': image_flags,
+ 'num_tokens': num_tokens,
+ 'num_img_tokens': num_img_tokens,
+ }
+
+ return data_dict
\ No newline at end of file
diff --git a/xtuner/_lite/datasets/json.py b/xtuner/_lite/datasets/json.py
new file mode 100644
index 000000000..3efd91a1d
--- /dev/null
+++ b/xtuner/_lite/datasets/json.py
@@ -0,0 +1,173 @@
+import hashlib
+import inspect
+import json
+import math
+import os
+import random
+from concurrent.futures import ProcessPoolExecutor
+from mmengine import mkdir_or_exist
+import numpy as np
+import torch
+from torch import distributed as dist
+from tqdm import tqdm
+from xtuner._lite import get_logger
+
+
+logger = get_logger()
+
+def calculate_json_sha256(file_path):
+ with open(file_path, 'rb') as f:
+ data = f.read()
+
+ hash_object = hashlib.sha256(data)
+ hash_hex = hash_object.hexdigest()
+ return hash_hex
+
+
+def calculate_tokenize_fn_sha256(tokenize_fn):
+ """Calculate SHA-256 hash for an instance method's source code."""
+ # Get the source code of the method
+ fn_source = inspect.getsource(tokenize_fn.__call__)
+ return hashlib.sha256(fn_source.encode('utf-8')).hexdigest()
+
+
+class JsonDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ path,
+ sample_ratio=1.0,
+ tokenize_fn=None,
+ cache_dir=None,
+ max_length=None):
+ super().__init__()
+
+ self.tokenize_fn = tokenize_fn
+ self.path = path
+ self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8))
+
+ if cache_dir:
+ if os.path.exists(cache_dir):
+ assert os.path.isdir(cache_dir)
+ else:
+ mkdir_or_exist(cache_dir)
+
+ file_hash = calculate_json_sha256(path)
+ file_cache_dir = os.path.join(cache_dir, file_hash)
+
+ if file_hash not in os.listdir(cache_dir):
+ mkdir_or_exist(file_cache_dir)
+
+ if self.tokenize_fn:
+ tok_hash = calculate_tokenize_fn_sha256(tokenize_fn)
+ tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
+ if tok_hash not in os.listdir(file_cache_dir):
+ mkdir_or_exist(tok_cache_dir)
+
+ if 'num_tokens.npy' in os.listdir(tok_cache_dir):
+ _cached_file = os.path.join(tok_cache_dir,
+ 'num_tokens.npy')
+ num_tokens = np.load(_cached_file)
+ else:
+ num_tokens = self.count_tokens(tok_cache_dir)
+ else:
+ num_tokens = None
+
+ else:
+ num_tokens = None
+
+ with open(self.path) as f:
+ dataset = json.load(f)
+
+ _sampled = [i for i in range(len(dataset))]
+
+ if max_length is not None:
+ assert isinstance(max_length, int)
+ _filtered = [x for i, x in enumerate(_sampled) if num_tokens[i] < max_length]
+
+ if len(_filtered) < len(_sampled):
+ missed_num = len(_sampled) - len(_filtered)
+ logger.warning(f"{path} has {missed_num} prompt length>{max_length}, discard.")
+
+ _sampled = _filtered
+
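+        # Apply sample_ratio: repeat the full index list floor(sample_ratio)
+        # times, then randomly top it up with the fractional remainder.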
+ _target_num_samples = int(len(_sampled) * sample_ratio)
+ self.sampled = _sampled * int(sample_ratio)
+ self.sampled.extend(random.sample(_sampled, _target_num_samples - len(self.sampled)))
+
+ if num_tokens is not None:
+ num_tokens = num_tokens[self.sampled]
+
+ self.num_tokens = num_tokens
+ self.dataset = None
+
+ def count_tokens(self, cache_dir=None):
+
+ dataset = []
+
+ with open(self.path) as f:
+ dataset = json.load(f)
+
+ num_samples = len(dataset)
+
+ if dist.is_available():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ world_size = 1
+ rank = 0
+
+ num_per_rank = math.ceil(num_samples / world_size)
+
+ start = rank * num_per_rank
+ end = (rank + 1) * num_per_rank
+ dataset_shard = dataset[start:end]
+
+ desc = f'[Rank {rank}] {self.path}'
+ chunk_size = min(1024, max(1, len(dataset_shard) // self.tokenizer_workers))
+ with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor:
+ tokenized = list(
+ tqdm(
+ executor.map(self.tokenize_fn, dataset_shard,
+ chunksize=chunk_size),
+ desc=desc,
+ total=len(dataset_shard)))
+
+ _num_tokens = [data['num_tokens'] for data in tokenized]
+ _num_tokens = np.array(_num_tokens)
+
+ if dist.is_available():
+ num_tokens = [None] * world_size
+ dist.all_gather_object(num_tokens, _num_tokens)
+ num_tokens = np.concatenate(num_tokens, axis=0)
+ else:
+ num_tokens = _num_tokens
+
+ if rank == 0 and cache_dir:
+ save_path = os.path.join(cache_dir, 'num_tokens.npy')
+ np.save(save_path, num_tokens)
+
+ return num_tokens
+
+ def __len__(self):
+ return len(self.sampled)
+
+ def __getitem__(self, item):
+ """Returns a dict containing packed data in the given item.
+
+ Args:
+ item: An index to retrieve packed data.
+
+ Returns:
+ A dict including packed input_ids, labels, and cumulative_len.
+ """
+ if self.dataset is None:
+ with open(self.path) as f:
+ self.dataset = json.load(f)
+
+ raw_data = self.dataset[self.sampled[item]]
+
+ if self.tokenize_fn:
+ tokenized_data = self.tokenize_fn(raw_data)
+ return tokenized_data
+ else:
+ return raw_data
diff --git a/xtuner/_lite/datasets/jsonl.py b/xtuner/_lite/datasets/jsonl.py
new file mode 100644
index 000000000..3bfc2c4bb
--- /dev/null
+++ b/xtuner/_lite/datasets/jsonl.py
@@ -0,0 +1,205 @@
+import hashlib
+import inspect
+import json
+import math
+import os
+import random
+from concurrent.futures import ProcessPoolExecutor
+from mmengine import mkdir_or_exist
+import numpy as np
+import torch
+from torch import distributed as dist
+from tqdm import tqdm
+from xtuner._lite import get_logger
+
+logger = get_logger()
+
+
+def calculate_jsonl_sha256(path):
+ with open(path, 'rb') as f:
+ file_hash = hashlib.sha256()
+ file_hash.update(f.read())
+ return file_hash.hexdigest()
+
+
+def calculate_tokenize_fn_sha256(tokenize_fn):
+ """Calculate SHA-256 hash for an instance method's source code."""
+ # Get the source code of the method
+ fn_source = inspect.getsource(tokenize_fn.__call__)
+ return hashlib.sha256(fn_source.encode('utf-8')).hexdigest()
+
+
+class JsonlDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ path,
+ sample_ratio=1.0,
+ tokenize_fn=None,
+ cache_dir=None,
+ max_length=None,):
+ super().__init__()
+
+ self.tokenize_fn = tokenize_fn
+ self.path = path
+ self.tokenizer_workers = int(os.environ.get('XTUNER_TOKENIZE_WORKERS', 8))
+
+ if cache_dir:
+ if os.path.exists(cache_dir):
+ assert os.path.isdir(cache_dir)
+ else:
+ mkdir_or_exist(cache_dir)
+
+ file_hash = calculate_jsonl_sha256(path)
+ file_cache_dir = os.path.join(cache_dir, file_hash)
+
+ if file_hash not in os.listdir(cache_dir):
+ mkdir_or_exist(file_cache_dir)
+
+ if 'offsets.npy' in os.listdir(file_cache_dir):
+ _cached_file = os.path.join(file_cache_dir, 'offsets.npy')
+ offsets = np.load(_cached_file)
+ else:
+ offsets = self.count_offsets(file_cache_dir)
+
+ if self.tokenize_fn:
+ tok_hash = calculate_tokenize_fn_sha256(tokenize_fn)
+ tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
+ if tok_hash not in os.listdir(file_cache_dir):
+ mkdir_or_exist(tok_cache_dir)
+
+ if 'num_tokens.npy' in os.listdir(tok_cache_dir):
+ _cached_file = os.path.join(tok_cache_dir,
+ 'num_tokens.npy')
+ num_tokens = np.load(_cached_file)
+ else:
+ num_tokens = self.count_tokens(offsets, tok_cache_dir)
+ else:
+ num_tokens = None
+
+
+ else:
+ offsets = self.count_offsets()
+ num_tokens = None
+ if max_length is not None:
+ assert self.tokenize_fn
+ num_tokens = self.count_tokens(offsets)
+
+ _sampled = [i for i in range(len(offsets))]
+
+ if max_length is not None:
+ assert isinstance(max_length, int)
+ _filtered = [x for i, x in enumerate(_sampled) if num_tokens[i] < max_length]
+
+ if len(_filtered) < len(_sampled):
+ missed_num = len(_sampled) - len(_filtered)
+ logger.warning(f"{path} has {missed_num} prompt length>{max_length}, discard.")
+
+ _sampled = _filtered
+
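+        # Same sample_ratio handling as JsonDataset: repeat whole copies of the
+        # index list, then randomly sample the fractional remainder.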
+ _target_num_samples = int(len(_sampled) * sample_ratio)
+ self.sampled = _sampled * int(sample_ratio)
+ self.sampled.extend(random.sample(_sampled, _target_num_samples - len(self.sampled)))
+
+ if num_tokens is not None:
+ num_tokens = num_tokens[self.sampled]
+
+ self.num_tokens = num_tokens
+ self.offsets = offsets[self.sampled]
+
+
+ def count_offsets(self, cache_dir=None):
+
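+        # Record the starting byte offset of every line so that __getitem__ and
+        # _tokenize_by_offset can seek() straight to a sample without loading
+        # the whole file into memory.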
+ offsets = [0]
+ with open(self.path) as f:
+
+ lines = f.readlines()
+ for line in lines[:-1]:
+ offsets.append(offsets[-1]+len(line.encode()))
+
+ offsets = np.array(offsets)
+
+ if dist.get_rank() == 0 and cache_dir:
+ save_path = os.path.join(cache_dir, 'offsets.npy')
+ np.save(save_path, offsets)
+
+ return offsets
+
+ def _tokenize_by_offset(self, offset):
+
+ with open(self.path, 'r') as f:
+ f.seek(offset)
+ data = json.loads(f.readline())
+ return self.tokenize_fn(data)
+
+ def count_tokens(self, offsets, cache_dir=None):
+
+ num_samples = len(offsets)
+
+ if dist.is_available():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ world_size = 1
+ rank = 0
+
+ num_per_rank = math.ceil(num_samples / world_size)
+
+ start = rank * num_per_rank
+ end = (rank + 1) * num_per_rank
+ offsets_shard = offsets[start:end]
+
+
+ desc = f'[Rank {rank}] {self.path}'
+ chunk_size = min(1024, max(1, len(offsets_shard) // self.tokenizer_workers))
+
+ with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor:
+ tokenized = list(
+ tqdm(
+ executor.map(
+ self._tokenize_by_offset,
+ offsets_shard,
+ chunksize=chunk_size),
+ desc=desc,
+ total=len(offsets_shard)))
+
+ _num_tokens = [data['num_tokens'] for data in tokenized]
+ _num_tokens = np.array(_num_tokens)
+
+ if dist.is_available():
+ num_tokens = [None] * world_size
+ dist.all_gather_object(num_tokens, _num_tokens)
+ num_tokens = np.concatenate(num_tokens, axis=0)
+ else:
+ num_tokens = _num_tokens
+
+ if rank == 0 and cache_dir:
+ save_path = os.path.join(cache_dir, 'num_tokens.npy')
+ np.save(save_path, num_tokens)
+
+ return num_tokens
+
+ def __len__(self):
+ return len(self.offsets)
+
+ def __getitem__(self, item):
+ """Returns a dict containing packed data in the given item.
+
+ Args:
+ item: An index to retrieve packed data.
+
+ Returns:
+ A dict including packed input_ids, labels, and cumulative_len.
+ """
+ with open(self.path, 'r') as f:
+ f.seek(self.offsets[item])
+ line = f.readline()
+
+ raw_data = json.loads(line)
+
+ if self.tokenize_fn:
+ tokenized_data = self.tokenize_fn(raw_data)
+ return tokenized_data
+ else:
+ return raw_data
diff --git a/xtuner/_lite/datasets/pack.py b/xtuner/_lite/datasets/pack.py
new file mode 100644
index 000000000..2cbfce863
--- /dev/null
+++ b/xtuner/_lite/datasets/pack.py
@@ -0,0 +1,267 @@
+import random
+
+import numpy as np
+import torch
+from datasets import Dataset, concatenate_datasets
+from torch.utils.data import ConcatDataset
+import bisect
+import itertools
+
+
+class SoftPackDataset(torch.utils.data.Dataset):
+
+ def __init__(self, datasets, target=2048, blend=False, sort=False):
+
+ if blend:
+ num_tokens = [
+ np.concatenate([dset.num_tokens for dset in datasets])
+ ]
+ datasets = [ConcatDataset(datasets)]
+ else:
+ num_tokens = [dset.num_tokens for dset in datasets]
+ self.datasets = datasets
+ self.target = target
+
+ pack_infos = []
+ for i, dataset in enumerate(self.datasets):
+ _infos = self.get_pack_infos(dataset, i, num_tokens[i])
+ pack_infos.append(_infos)
+ self.pack_infos = concatenate_datasets(pack_infos)
+
+ @property
+ def longest(self):
+ return self.pack_infos['longest']
+
+ def get_pack_infos(self, dataset, dataset_id, num_tokens):
+ # _ori_lens = dataset['num_tokens']
+ inds = [i for i in range(len(dataset))]
+ random.shuffle(inds)
+
+ item_buffer = []
+ length_buffer = []
+ longest = 0
+
+ pack_infos = []
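+        # Greedy soft packing: walk the shuffled samples and keep adding them
+        # to the current pack while the accumulated length stays within the
+        # target; once a sample would overflow, flush the buffer as one pack
+        # and start a new one with that sample.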
+ for shfl_i in inds:
+ if num_tokens[shfl_i] + sum(length_buffer) <= self.target:
+ item_buffer.append(shfl_i)
+ length_buffer.append(num_tokens[shfl_i])
+ longest = max(longest, num_tokens[shfl_i])
+ else:
+ if len(item_buffer) > 0:
+ info = {
+ 'dataset_id': dataset_id,
+ 'indices': item_buffer,
+ 'longest': int(longest)
+ }
+ pack_infos.append(info)
+
+ item_buffer = [shfl_i]
+ length_buffer = [num_tokens[shfl_i]]
+ longest = num_tokens[shfl_i]
+
+ if len(item_buffer) > 0:
+ info = {
+ 'dataset_id': dataset_id,
+ 'indices': item_buffer,
+ 'longest': int(longest)
+ }
+
+ pack_infos.append(info)
+
+ pack_infos = Dataset.from_list(pack_infos)
+
+ return pack_infos
+
+ def __len__(self):
+ return len(self.pack_infos)
+
+ def __getitem__(self, item):
+ indices = self.pack_infos[item]['indices']
+ dataset_id = self.pack_infos[item]['dataset_id']
+ return [self.datasets[dataset_id][i] for i in indices]
+
+
+class HardPackDataset(torch.utils.data.Dataset):
+
+
+ def __init__(self, datasets, target=2048, blend=True, sort=False):
+
+ if blend:
+ num_tokens = [
+ np.concatenate([dset.num_tokens for dset in datasets])
+ ]
+ datasets = [ConcatDataset(datasets)]
+ else:
+ num_tokens = [dset.num_tokens for dset in datasets]
+ self.datasets = datasets
+ self.target = target
+
+ pack_infos = []
+ for i, dataset in enumerate(self.datasets):
+ _info = self.get_pack_info(dataset, i, num_tokens[i])
+ pack_infos.append(_info)
+
+ _ranges_left = []
+ _ranges_right = []
+ _num_packed_samples = []
+ _indices = []
+ _max_length_per_pack = []
+ _dataset_id = []
+ for info in pack_infos:
+ _ranges_left.extend(info['ranges_left'])
+ _ranges_right.extend(info['ranges_right'])
+ _num_packed_samples.append(info['num_packed_samples'])
+ _indices.extend(info['indices'])
+ _max_length_per_pack.extend(info['max_length_per_pack'])
+ _dataset_id.extend(info['dataset_id'])
+
+ self.pack_infos = {
+ 'ranges_left': _ranges_left,
+ 'ranges_right': _ranges_right,
+ 'num_packed_samples': _num_packed_samples,
+ 'indices': _indices,
+ 'max_length_per_pack':_max_length_per_pack,
+ 'dataset_id': _dataset_id
+ }
+
+
+ @classmethod
+ def _cal_max_length(cls, begin, end, shfl_item_rngs_left,
+ shfl_item_rngs_right):
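+        # Bisect the cumulative ranges to find which original samples overlap
+        # the window [begin, end), then return the length of the longest
+        # truncated segment inside this pack.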
+ left = bisect.bisect(shfl_item_rngs_right, begin)
+ right = bisect.bisect(shfl_item_rngs_left, end)
+ max_length = 0
+ for i in range(left, right):
+ item_begin = shfl_item_rngs_left[i]
+ item_end = shfl_item_rngs_right[i]
+ inner_l = max(begin, item_begin) - item_begin
+ inner_r = min(end, item_end) - item_begin
+ trunc_size = inner_r - inner_l
+ max_length = max(max_length, trunc_size)
+ return max_length
+
+ def get_pack_info(self, dataset, dataset_id, num_tokens):
+
+ # The number of data items after packing
+ num_packed_samples = int(num_tokens.sum() / self.target)
+
+ # Shuffle the order of the original dataset
+ # The packing will proceed according to the order after shuffle.
+ # Assume the following conditions hold:
+ # (1) shfl_inds = [3, 1, 2, 0]
+ # (2) self._ori_lens[3] + self._ori_lens[1] = max_length
+ # (3) self._ori_lens[2] + self._ori_lens[0] = max_length
+ # Ultimately, dataset[3] and dataset[1] will be combined into a new
+ # data, and dataset[2] and dataset[0] will be combined into a new data.
+ inds = [i for i in range(len(dataset))]
+ # if seed is not None:
+ # random.seed(seed)
+ random.shuffle(inds)
+ shfl_inds = inds
+
+ # shuffled cumulative lengths
+ shfl_lens = [num_tokens[i] for i in shfl_inds]
+ shfl_acc_lens = list(itertools.accumulate(shfl_lens))
+
+ shfl_item_rngs_left = [0] + shfl_acc_lens[:-1]
+ shfl_item_rngs_right = shfl_acc_lens
+
+ max_length_per_pack = []
+ belong_dataset_ids = []
+ for i in range(num_packed_samples):
+ begin = i * self.target
+ end = (i + 1) * self.target
+ max_length_per_pack.append(
+ self._cal_max_length(begin, end, shfl_item_rngs_left,
+ shfl_item_rngs_right))
+ belong_dataset_ids.append(dataset_id)
+
+ pack_infos = {
+ 'ranges_left': shfl_item_rngs_left,
+ 'ranges_right': shfl_item_rngs_right,
+ 'num_packed_samples': num_packed_samples,
+ 'indices': shfl_inds,
+ 'dataset_id': belong_dataset_ids,
+ 'max_length_per_pack': max_length_per_pack
+ }
+
+ # pack_infos = Dataset.from_list(pack_infos)
+
+ return pack_infos
+
+ def _pack_ids_and_labels_in_range(self, begin: int, end: int):
+ """Packs ids and labels in a given range using bisection method.
+
+ Args:
+ begin: Index indicating the beginning of the range.
+ end: Index indicating the end of the range.
+
+ Returns:
+ A tuple containing packed ids, labels, and cumulative lengths.
+ """
+
+ # Use binary search to find dataset positions that fall within begin
+ # and end range
+ left = bisect.bisect(self.pack_infos['ranges_left'], begin)
+ right = bisect.bisect(self.pack_infos['ranges_right'], end)
+
+ trunc_input_ids = []
+ trunc_labels = []
+ trunc_sizes = []
+
+ for i in range(left, right):
+
+ # Determine the real range we will cut in current original item
+ item_begin = self.pack_infos['ranges_left'][i]
+ item_end = self.pack_infos['ranges_right'][i]
+
+ # Calculate exact positions within current dataset item
+ inner_l = max(begin, item_begin) - item_begin
+ inner_r = min(end, item_end) - item_begin
+
+ # Get original data and labels
+ ori_idx = self.pack_infos['indices'][i]
+ ori_dataset_id = self.pack_infos['dataset_id'][i]
+ ori_input_ids = self.datasets[ori_dataset_id][ori_idx]['input_ids']
+ ori_labels = self.datasets[ori_dataset_id][ori_idx]['labels']
+
+ # Add original data and labels from calculated positions
+ # to trunc_ids and trunc_labels
+ trunc_input_ids.extend(ori_input_ids[inner_l:inner_r])
+ trunc_labels.extend(ori_labels[inner_l:inner_r])
+ trunc_sizes.append(inner_r - inner_l)
+
+ # return populated lists of truncated ids, labels and their cumulative
+ # lengths
+ return trunc_input_ids, trunc_labels, trunc_sizes
+
+ def __len__(self):
+ return len(self.pack_infos['indices'])
+
+ def __getitem__(self, item):
+ """Returns a dict containing packed data in the given item.
+
+ Args:
+ item: An index to retrieve packed data.
+
+ Returns:
+            A dict including packed input_ids, labels, and num_tokens.
+ """
+ # The cumulative length from the start position of this data
+ begin = item * self.target
+ # The cumulative length from the end position of this data
+ end = (item + 1) * self.target
+
+ # Extract data within the range from the shuffled original dataset.
+ _res = self._pack_ids_and_labels_in_range(begin, end)
+ packed_input_ids, packed_labels, num_tokens = _res
+ assert self.target == len(packed_input_ids) == len(packed_labels)
+
+ packed = {
+ 'input_ids': packed_input_ids,
+ 'labels': packed_labels,
+ 'num_tokens': num_tokens,
+ }
+
+ return packed
\ No newline at end of file
diff --git a/xtuner/_lite/datasets/streaming.py b/xtuner/_lite/datasets/streaming.py
new file mode 100644
index 000000000..bb567e780
--- /dev/null
+++ b/xtuner/_lite/datasets/streaming.py
@@ -0,0 +1,159 @@
+import json
+
+import numpy as np
+from torch.utils.data import IterableDataset
+
+
+class Streaming:
+
+ def __init__(self, file, max_epoch=1):
+ self.file = file
+ self.offset = 0
+ self.epoch = 1
+ self.max_epoch = max_epoch
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+
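+        # Read one line starting from the saved byte offset; when the file is
+        # exhausted, rewind and start a new epoch until max_epoch is reached.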
+ with open(self.file) as f:
+ f.seek(self.offset)
+ line = f.readline()
+
+ if not line and self.epoch < self.max_epoch:
+ self.offset = 0
+ self.epoch += 1
+ return next(self)
+
+ elif not line and self.epoch == self.max_epoch:
+ raise StopIteration
+
+ self.offset = f.tell()
+ return line
+
+
+# import torch
+
+# class MultiStreamingDataset(torch.utils.data.IterableDataset):
+
+# def __init__(self, streamings, weights, max_length, tokenize_fn, seed, dp_rank, dp_world_size, crossover = False):
+
+# assert len(streamings) == len(weights)
+# self.streamings = streamings
+# self.activated = [True for _ in self.streamings]
+# for sid, stream in enumerate(self.streamings):
+# stream.offset = 0
+# try:
+# for _ in range(self.dp_rank):
+# next(stream)
+# except StopIteration:
+# self.activated[sid] = False
+
+# self.random_state = np.random.RandomState(seed + dp_rank)
+# self.weights = weights
+
+# self.max_length = max_length
+# self.tokenize_fn = tokenize_fn
+# self.dp_rank = dp_rank
+# self.dp_world_size = dp_world_size
+# self.crossover = crossover
+
+# def reactivate(self):
+# self.activated = [True for _ in self.streamings]
+# for stream in self.streamings:
+# stream.offset = 0
+# for _ in range(self.dp_rank):
+# next(stream)
+
+# @property
+# def probabilities(self):
+# if sum(self.activated) == 0:
+# self.reactivate()
+
+# probs = (np.array(self.weights) * self.activated) / sum(self.weights[self.activated])
+# return probs
+
+# @property
+# def num_streamings(self):
+# assert len(self.iterators) == len(self.weights)
+# return len(self.weights)
+
+# def per_rank_next(self, streaming_id):
+
+# sid = streaming_id
+# streaming = self.streamings[sid]
+
+# try:
+# data = next(streaming)
+# except StopIteration:
+# self.activated[sid] = False
+# sid = self.random_state.choice(
+# self.num_streamings, p=self.probabilities)
+# return self.per_rank_next(sid)
+
+# try:
+# for _ in range(self.dp_world_size):
+# next(streaming)
+# except StopIteration:
+# self.activated[sid] = False
+
+# return data, sid
+
+# def __iter__(self):
+# worker_info = torch.utils.data.get_worker_info()
+
+# if worker_info and worker_info.num_workers > 1:
+# raise NotImplementedError
+
+# input_ids = []
+# labels = []
+# num_tokens = []
+# while True:
+# sid = self.random_state.choice(
+# self.num_streamings, p=self.probabilities)
+
+# while len(input_ids) < self.max_length:
+# if self.crossover:
+# sid = self.random_state.choice(
+# self.num_streamings, p=self.probabilities)
+
+# line, sid = self.per_rank_next(sid)
+
+# tokenized = self.tokenize_fn(json.loads(line))
+
+# input_ids.extend(tokenized['input_ids'])
+# labels.extend(tokenized['labels'])
+# num_tokens.extend(tokenized['num_tokens'])
+
+# remain_tokens = max(sum(num_tokens) - self.max_length, 0)
+# num_tokens[-1] = num_tokens[-1] - remain_tokens
+
+# packed_ids = input_ids[:self.max_length]
+# packed_labels = labels[:self.max_length]
+# packed_tokens = num_tokens
+
+# if remain_tokens:
+# input_ids = input_ids[self.max_length:]
+# labels = labels[self.max_length:]
+# num_tokens = [remain_tokens]
+
+# yield {'input_ids': packed_ids, 'labels': packed_labels, 'num_tokens': packed_tokens}
+
+if __name__ == '__main__':
+ import json
+ streaming = Streaming(
+ '/mnt/hwfile/xtuner/huanghaian/data/databricks-dolly-15k/databricks-dolly-15k.jsonl'
+ )
+
+ data = next(streaming)
+ print(json.loads(data))
+
+ data = next(streaming)
+ print(json.loads(data))
+
+ data = next(streaming)
+ print(json.loads(data))
+
+ data = next(streaming)
+ print(json.loads(data))
diff --git a/xtuner/_lite/datasets/utils/__init__.py b/xtuner/_lite/datasets/utils/__init__.py
new file mode 100644
index 000000000..b16dc419d
--- /dev/null
+++ b/xtuner/_lite/datasets/utils/__init__.py
@@ -0,0 +1,6 @@
+from .convert import OPENAI_CONVERT_MAP
+from .load import DATASET_CLS_MAP, load_datasets
+from .utils import apply_exif_orientation, move_data_to_device
+
+__all__ = ['OPENAI_CONVERT_MAP', 'DATASET_CLS_MAP', 'load_datasets',
+ 'apply_exif_orientation', 'move_data_to_device']
diff --git a/xtuner/_lite/datasets/utils/convert.py b/xtuner/_lite/datasets/utils/convert.py
new file mode 100644
index 000000000..b78a12d51
--- /dev/null
+++ b/xtuner/_lite/datasets/utils/convert.py
@@ -0,0 +1,234 @@
+import re
+
+from xtuner._lite.chat import ChatMessages
+
+
+class XTunerFormat2Openai():
+
+ @classmethod
+ def source_format(cls):
+ data = {
+ 'conversation': [{
+ 'system': 'SYSTEM',
+ 'input': 'INPUT',
+ 'output': 'OUTPUT'
+ }, {
+ 'input': 'INPUT',
+ 'output': 'OUTPUT'
+ }]
+ }
+ return data
+
+ @classmethod
+ def target_format(cls):
+ data = {
+ 'messages': [
+ {
+ 'role': 'system',
+ 'content': 'SYSTEM'
+ },
+ {
+ 'role': 'user',
+ 'content': 'INPUT'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'OUTPUT'
+ },
+ {
+ 'role': 'user',
+ 'content': 'INPUT'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'OUTPUT'
+ },
+ ]
+ }
+ return data
+
+ @staticmethod
+ def convert(data):
+ ROLE_MAPPING = {
+ 'system': 'system',
+ 'input': 'user',
+ 'output': 'assistant'
+ }
+ messages = []
+ for single_turn_conversation in data['conversation']:
+ for role, content in single_turn_conversation.items():
+ messages.append({
+ 'role': ROLE_MAPPING[role],
+ 'content': content
+ })
+ return ChatMessages.from_dict({'messages': messages})
+
+
+class Alpaca2Openai():
+
+ @classmethod
+ def source_format(cls):
+ data = {
+ 'instruction': 'INSTRUCTION',
+ 'input': 'INPUT',
+ 'output': 'OUTPUT',
+ }
+ return data
+
+ @classmethod
+ def target_format(cls):
+ data = {
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'INSTRUCTION\nINPUT'
+ },
+ {
+ 'role': 'assistant',
+ 'content': 'OUTPUT'
+ },
+ ]
+ }
+ return data
+
+ @staticmethod
+ def convert(data):
+ if data.get('output') == '':
+ return ChatMessages.from_dict({'messages': []})
+ else:
+ return ChatMessages.from_dict({
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': f"{data['instruction']}\n{data['input']}"
+ },
+ {
+ 'role': 'assistant',
+ 'content': f"{data['output']}"
+ },
+ ]
+ })
+
+
+def llava_to_openai(data):
+
+    # LLaVA data marks image positions with the '<image>' placeholder.
+    image_token = '<image>'
+ conversations = data['conversations']
+ messages = []
+
+ if 'image' in data:
+ image_urls = data['image']
+ if isinstance(image_urls, str):
+ image_urls = [image_urls]
+ else:
+ image_urls = None
+
+ while conversations and conversations[0]['from'] == 'gpt':
+ # Skip the first one if it is from gpt
+ conversations = conversations[1:]
+
+ image_id = 0
+ for convs in conversations:
+ if convs['from'] == 'human':
+ pattern = f'({image_token})'
+ chunks = re.split(pattern, convs['value'])
+
+ text_content = []
+ img_content = []
+
+ for chunk in chunks:
+ if chunk == image_token:
+ url = image_urls[image_id]
+                    # each '<image>' must map to a string path/URL
+                    if not isinstance(url, str):
+                        raise TypeError(data)
+                    item = dict(type='image_url', image_url=url)
+ img_content.append(item)
+ image_id += 1
+ elif len(chunk.strip()):
+ item = dict(type='text', text=chunk.strip())
+ text_content.append(item)
+
+ msg = {'role': 'user', 'content': img_content + text_content}
+ messages.append(msg)
+
+ elif convs['from'] == 'gpt':
+ msg = {'role': 'assistant', 'content': convs['value']}
+ messages.append(msg)
+ else:
+ raise NotImplementedError
+
+ return ChatMessages.from_dict({'messages': messages})
+
+
+def llava_to_openai_interleave(data):
+
+    image_token = '<image>'
+ conversations = data['conversations']
+ messages = []
+
+ if 'image' in data:
+ image_urls = data['image']
+ if isinstance(image_urls, str):
+ image_urls = [image_urls]
+ else:
+ image_urls = None
+
+ while conversations and conversations[0]['from'] == 'gpt':
+ # Skip the first one if it is from gpt
+ conversations = conversations[1:]
+
+ image_id = 0
+ for convs in conversations:
+ if convs['from'] == 'human':
+ pattern = f'({image_token})'
+ chunks = re.split(pattern, convs['value'])
+
+ content = []
+
+ for chunk in chunks:
+ if chunk == image_token:
+ url = image_urls[image_id]
+ if not isinstance(url, str):
+ raise TypeError(data)
+ item = dict(type='image_url', image_url=url)
+ content.append(item)
+ image_id += 1
+ elif len(chunk.strip()):
+ item = dict(type='text', text=chunk.strip())
+ content.append(item)
+
+ msg = {'role': 'user', 'content': content}
+ messages.append(msg)
+
+ elif convs['from'] == 'gpt':
+ msg = {'role': 'assistant', 'content': convs['value']}
+ messages.append(msg)
+ else:
+ raise NotImplementedError
+
+ return ChatMessages.from_dict({'messages': messages})
+
+
+def official_openai(data):
+ if 'messages' in data:
+ return ChatMessages.from_dict(data)
+ elif 'message_data' in data:
+ return ChatMessages.from_dict({'messages': data['message_data']})
+ elif 'dialogs' in data:
+ return ChatMessages.from_dict({'messages': data['dialogs']})
+ else:
+ return ChatMessages.from_dict({'messages': data})
+
+
+OPENAI_CONVERT_MAP = {
+    'llava': llava_to_openai,
+    'llava_interleave': llava_to_openai_interleave,
+    'alpaca': Alpaca2Openai.convert,
+    'xtuner': XTunerFormat2Openai.convert,
+    'openai': official_openai,
+}
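+
+# Illustrative usage (a minimal sketch; the sample record is hypothetical):
+# a converter is looked up by its format key and applied to one raw record,
+# returning a `ChatMessages` object in the OpenAI chat schema.
+#
+#   convert_fn = OPENAI_CONVERT_MAP['alpaca']
+#   chat = convert_fn({'instruction': 'Translate to French',
+#                      'input': 'Hello', 'output': 'Bonjour'})
+#   # -> user: 'Translate to French\nHello', assistant: 'Bonjour'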
diff --git a/xtuner/_lite/datasets/utils/load.py b/xtuner/_lite/datasets/utils/load.py
new file mode 100644
index 000000000..7a39f3c75
--- /dev/null
+++ b/xtuner/_lite/datasets/utils/load.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import math
+import os
+import random
+import re
+
+from torch import distributed as dist
+from tqdm import tqdm
+
+from xtuner._lite import get_logger
+from ..json import JsonDataset
+from ..jsonl import JsonlDataset
+
+logger = get_logger()
+
+DATASET_CLS_MAP = {'.jsonl': JsonlDataset, '.json': JsonDataset}
+
+
+def load_hf_dataset(path,
+                    split='train',
+                    sample_ratio=1.0,
+                    cache_dir=None,
+                    map_fn=None,
+                    max_length=None):
+ from datasets import load_dataset
+ dataset = load_dataset(path)[split]
+
+ if map_fn:
+ dataset = dataset.map(map_fn, num_proc=8)
+
+ if sample_ratio != 1:
+ ori_samples = len(dataset)
+ target_samples = int(sample_ratio * ori_samples)
+ indices = random.choices([i for i in range(ori_samples)],
+ k=target_samples)
+ dataset = dataset.select(indices)
+
+ dataset = dataset.to_list()
+
+ # if init_fn:
+ # dataset = init_fn(dataset)
+
+ # if cache_dir and isinstance(dataset, CacheDataset):
+ # dataset.cache(cache_dir)
+
+ return dataset
+
+
+def load_from_cache(cache_dir, init_fn):
+
+    if dist.is_available() and dist.is_initialized():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ world_size = 1
+ rank = 0
+
+ sub_cache_dirs = []
+ for _path in tqdm(os.listdir(cache_dir)):
+ path = os.path.join(cache_dir, _path)
+ if os.path.isdir(path):
+ sub_cache_dirs.append(path)
+
+ num_dsets = len(sub_cache_dirs)
+ avg_num = math.ceil(num_dsets / world_size)
+ start = rank * avg_num
+ end = min((rank + 1) * avg_num, num_dsets)
+ desc = f'[Rank {rank}] Loading Cached Dataset'
+
+ rank_datasets = []
+ for ind in tqdm(range(start, end), desc=desc):
+ dset = init_fn(sub_cache_dirs[ind])
+ rank_datasets.append(dset)
+
+ if dist.is_available() and world_size > 1:
+ dist.barrier()
+ buffers = [None] * world_size
+ dist.all_gather_object(buffers, rank_datasets)
+ world_datasets = []
+ for dsets_per_rank in buffers:
+ world_datasets.extend(dsets_per_rank)
+
+ assert len(world_datasets) == num_dsets
+ else:
+ world_datasets = rank_datasets
+
+ return world_datasets
+
+
+def load_local_datasets(paths,
+ file_types,
+ file_pattern=None,
+ cache_dir=None,
+ sample_ratios=1.0,
+ map_fns=None,
+ max_length=None):
+
+ if isinstance(paths, str):
+ paths = [paths]
+
+ if isinstance(sample_ratios, (tuple, list)):
+
+ if len(sample_ratios) == 1:
+ sample_ratios = list(sample_ratios) * len(paths)
+
+ if len(sample_ratios) != len(paths):
+ raise RuntimeError(f'There are {len(paths)} paths, but only '
+ f'{len(sample_ratios)} sample ratios were set.')
+
+ if map_fns is None:
+ map_fns = [None] * len(paths)
+
+ if isinstance(map_fns, (tuple, list)):
+
+ if len(map_fns) == 1:
+ map_fns = list(map_fns) * len(paths)
+
+ if len(map_fns) != len(paths):
+            raise RuntimeError(f'There are {len(paths)} paths, but only '
+                               f'{len(map_fns)} map fns were set.')
+
+ files = []
+ file_sample_ratios = []
+ file_map_fns = []
+
+ for pid, path in enumerate(paths):
+ if os.path.isdir(path):
+ dir_files = []
+ for root, dirs, _files in os.walk(path, followlinks=True):
+ dirs.sort()
+ for relative_path in sorted(_files):
+ suffix = os.path.splitext(relative_path)[-1]
+ absolute_path = os.path.join(root, relative_path)
+ if file_pattern is not None:
+ if bool(re.match(file_pattern, absolute_path)):
+ dir_files.append(absolute_path)
+ elif suffix in file_types:
+ dir_files.append(absolute_path)
+
+ _num_dir_files = len(dir_files)
+ if _num_dir_files == 0:
+ raise RuntimeError(
+                    f'There are no files with the suffix {file_types} '
+                    f'in `{path}`.')
+
+ logger.info(f'Found {len(dir_files)} files in {path}')
+ files.extend(dir_files)
+ file_sample_ratios.extend([sample_ratios[pid]] * _num_dir_files)
+ file_map_fns.extend([map_fns[pid]] * _num_dir_files)
+
+ elif os.path.isfile(path):
+ files.append(path)
+ file_sample_ratios.append(sample_ratios[pid])
+ file_map_fns.append(map_fns[pid])
+
+ else:
+ raise RuntimeError(f'`{path}` not found.')
+
+ num_files = len(files)
+
+ datasets = []
+ for i in range(num_files):
+ _path = files[i]
+ _ratio = file_sample_ratios[i]
+ _map_fn = file_map_fns[i]
+ _suffix = os.path.splitext(_path)[-1]
+
+ dataset_cls = DATASET_CLS_MAP[_suffix]
+ _dataset = dataset_cls(_path, _ratio, _map_fn, cache_dir, max_length)
+ datasets.append(_dataset)
+
+ return datasets
+
+
+def load_datasets(paths,
+ sources='local',
+ sample_ratios=1.0,
+ file_types=DATASET_CLS_MAP.keys(),
+ file_pattern=None,
+ cache_dir=None,
+ map_fns=None,
+ max_length=None):
+
+ if isinstance(paths, str):
+ paths = [paths]
+
+ num_paths = len(paths)
+
+ if isinstance(sample_ratios, (float, int)):
+ sample_ratios = [sample_ratios] * num_paths
+
+ if isinstance(sample_ratios, (tuple, list)):
+
+ if len(sample_ratios) == 1:
+ sample_ratios = list(sample_ratios) * num_paths
+
+ if len(sample_ratios) != num_paths:
+ raise RuntimeError(f'There are {num_paths} paths, but only '
+ f'{len(sample_ratios)} sample ratios were set.')
+
+ if isinstance(sources, str):
+ sources = [sources]
+
+ if isinstance(sources, (tuple, list)):
+
+ if len(sources) == 1:
+ sources = list(sources) * num_paths
+
+ if len(sources) != num_paths:
+ raise RuntimeError(f'There are {num_paths} paths, but only '
+ f'{len(sources)} sources were set.')
+
+ if not isinstance(map_fns, (tuple, list)):
+ map_fns = [map_fns] * num_paths
+
+ if isinstance(map_fns, (tuple, list)):
+
+ if len(map_fns) == 1:
+ map_fns = list(map_fns) * num_paths
+
+ if len(map_fns) != num_paths:
+            raise RuntimeError(f'There are {num_paths} paths, but only '
+                               f'{len(map_fns)} map fns were set.')
+
+ local_inds = [i for i, src in enumerate(sources) if src == 'local']
+ local_paths = [paths[ind] for ind in local_inds]
+ local_map_fns = [map_fns[ind] for ind in local_inds]
+ local_sample_ratios = [sample_ratios[ind] for ind in local_inds]
+
+ hf_inds = [i for i, src in enumerate(sources) if src == 'huggingface']
+ hf_paths = [paths[ind] for ind in hf_inds]
+ hf_map_fns = [map_fns[ind] for ind in hf_inds]
+ hf_sample_ratios = [sample_ratios[ind] for ind in hf_inds]
+
+ datasets = []
+ if len(local_inds):
+ local_datasets = load_local_datasets(local_paths, file_types,
+ file_pattern, cache_dir,
+ local_sample_ratios,
+ local_map_fns, max_length)
+ datasets.extend(local_datasets)
+
+ if len(hf_inds):
+ cached_infos = {}
+ for i in range(len(hf_inds)):
+ if cache_dir:
+ digits = len(str(abs(len(hf_inds))))
+ cache_id = (f'cache-hf-{i+1:0{digits}}-of-'
+ f'{len(hf_inds):0{digits}}')
+ sub_cache_dir = os.path.join(cache_dir, cache_id)
+ else:
+ sub_cache_dir = None
+            dset = load_hf_dataset(
+                hf_paths[i],
+                sample_ratio=hf_sample_ratios[i],
+                map_fn=hf_map_fns[i],
+                cache_dir=sub_cache_dir,
+                max_length=max_length)
+            datasets.append(dset)
+            if cache_dir:
+ infos = {
+ 'path': hf_paths[i],
+ 'num_samples': dset.num_samples,
+ 'num_tokens': dset.total_tokens
+ }
+ cached_infos[cache_id] = infos
+
+ if cache_dir:
+ _path = os.path.join(cache_dir, 'hf_infos.json')
+ with open(_path, 'w') as f:
+ json.dump(cached_infos, f)
+
+ return datasets
+
+
+def load_ms_dataset():
+ pass
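+
+# Illustrative usage (a hedged sketch; the paths and converters below are
+# hypothetical): mix a local JSONL directory with a HuggingFace dataset,
+# sample the latter at 50%, and map every record to the OpenAI chat schema.
+#
+#   from xtuner._lite.datasets.utils import OPENAI_CONVERT_MAP
+#   datasets = load_datasets(
+#       paths=['/data/sft_jsonl_dir', 'tatsu-lab/alpaca'],
+#       sources=['local', 'huggingface'],
+#       sample_ratios=[1.0, 0.5],
+#       map_fns=[OPENAI_CONVERT_MAP['openai'], OPENAI_CONVERT_MAP['alpaca']])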
diff --git a/xtuner/_lite/datasets/utils/utils.py b/xtuner/_lite/datasets/utils/utils.py
new file mode 100644
index 000000000..19f572ac0
--- /dev/null
+++ b/xtuner/_lite/datasets/utils/utils.py
@@ -0,0 +1,66 @@
+from PIL import Image
+import torch
+from collections.abc import Mapping
+
+_EXIF_ORIENT = 274 # exif 'Orientation' tag
+
+
+def apply_exif_orientation(image):
+ """
+ Applies the exif orientation correctly.
+
+ This code exists per the bug:
+ https://github.com/python-pillow/Pillow/issues/3973
+ with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
+ various methods, especially `tobytes`
+
+ Function based on:
+ https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
+ https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
+
+ Args:
+ image (PIL.Image): a PIL image
+
+ Returns:
+ (PIL.Image): the PIL image with exif orientation applied, if applicable
+ """
+ if not hasattr(image, "getexif"):
+ return image
+
+ try:
+ exif = image.getexif()
+ except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
+ exif = None
+
+ if exif is None:
+ return image
+
+ orientation = exif.get(_EXIF_ORIENT)
+
+ method = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90,
+ }.get(orientation)
+
+ if method is not None:
+ return image.transpose(method)
+ return image
+
+
+def move_data_to_device(data, device='cuda'):
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+    if isinstance(data, Mapping):
+        return type(data)(
+            {k: move_data_to_device(v, device) for k, v in data.items()})
+    elif isinstance(data, (tuple, list)):
+        return type(data)(move_data_to_device(v, device) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = {"device": device}
+ return data.to(non_blocking=True, **kwargs)
+ return data
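+
+# Illustrative usage (a minimal sketch; shapes are arbitrary): nested
+# containers are traversed recursively and every tensor is copied with
+# `non_blocking=True`, while non-tensor leaves are returned unchanged.
+#
+#   batch = {'input_ids': torch.zeros(2, 8, dtype=torch.long),
+#            'meta': {'num_tokens': [8, 8]}}
+#   batch = move_data_to_device(batch, device='cuda')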
diff --git a/xtuner/_lite/device.py b/xtuner/_lite/device.py
new file mode 100644
index 000000000..ab6890c5f
--- /dev/null
+++ b/xtuner/_lite/device.py
@@ -0,0 +1,32 @@
+import torch
+
+
+def get_device():
+ device = None
+ if torch.cuda.is_available():
+ device = 'cuda'
+ else:
+ try:
+ import torch_npu # noqa: F401
+ device = 'npu'
+ except ImportError:
+ pass
+
+ if device is None:
+ raise NotImplementedError(
+ 'Supports only CUDA or NPU. If your device is CUDA or NPU, '
+ 'please make sure that your environmental settings are '
+ 'configured correctly.')
+
+ return device
+
+
+def get_torch_device_module():
+
+ device = get_device()
+ if device == 'cuda':
+ return torch.cuda
+ elif device == 'npu':
+ return torch.npu
+ else:
+ raise NotImplementedError
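+
+# Illustrative usage (a minimal sketch): the two helpers let callers stay
+# device-agnostic, e.g. releasing cached memory on CUDA and NPU alike.
+#
+#   DEVICE = get_device()
+#   DEVICE_MODULE = get_torch_device_module()
+#   DEVICE_MODULE.empty_cache()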
diff --git a/xtuner/_lite/modelings/__init__.py b/xtuner/_lite/modelings/__init__.py
new file mode 100644
index 000000000..de930bff1
--- /dev/null
+++ b/xtuner/_lite/modelings/__init__.py
@@ -0,0 +1,10 @@
+from .internlm2 import InternLM2Config, InternLM2ForCausalLM
+from .llava.modeling_llava import LlavaForConditionalGeneration
+from .llava.configuration_llava import EnhancedLlavaConfig
+from .llava.processing_llava import LlavaProcessor
+
+def register_remote_code():
+ from transformers import AutoConfig, AutoModelForCausalLM
+ AutoConfig.register('internlm2', InternLM2Config, exist_ok=True)
+ AutoModelForCausalLM.register(
+ InternLM2Config, InternLM2ForCausalLM, exist_ok=True)
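+
+# Illustrative usage (a minimal sketch; the checkpoint name is an example):
+# after registration, InternLM2 checkpoints resolve to the local modeling code
+# above through the Auto classes, without `trust_remote_code=True`.
+#
+#   from transformers import AutoModelForCausalLM
+#   register_remote_code()
+#   model = AutoModelForCausalLM.from_pretrained('internlm/internlm2_5-7b-chat')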
diff --git a/xtuner/_lite/modelings/internlm2/__init__.py b/xtuner/_lite/modelings/internlm2/__init__.py
new file mode 100644
index 000000000..e43d72d4a
--- /dev/null
+++ b/xtuner/_lite/modelings/internlm2/__init__.py
@@ -0,0 +1,2 @@
+from .configuration_internlm2 import InternLM2Config
+from .modeling_internlm2 import InternLM2ForCausalLM
diff --git a/xtuner/_lite/modelings/internlm2/configuration_internlm2.py b/xtuner/_lite/modelings/internlm2/configuration_internlm2.py
new file mode 100644
index 000000000..8b8107947
--- /dev/null
+++ b/xtuner/_lite/modelings/internlm2/configuration_internlm2.py
@@ -0,0 +1,175 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" InternLM2 model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+# Modified from transformers.model.llama.configuration_llama.LlamaConfig
+class InternLM2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`InternLM2Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
+ to understand more about it. This value is necessary to ensure exact reproducibility
+ of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ """
+ _auto_class = 'AutoConfig'
+ model_type = 'internlm2'
+ keys_to_ignore_at_inference = ['past_key_values']
+
+ def __init__( # pylint: disable=W0102
+ self,
+ vocab_size=103168,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act='silu',
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ bias=True,
+ rope_theta=10000,
+ rope_scaling=None,
+ attn_implementation=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.bias = bias
+
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attn_implementation = attn_implementation
+ if self.attn_implementation is None:
+ self.attn_implementation = 'eager'
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling,
+ dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+                '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
+ f'got {self.rope_scaling}')
+ rope_scaling_type = self.rope_scaling.get('type', None)
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
+ if rope_scaling_type is None or rope_scaling_type not in [
+ 'linear', 'dynamic'
+ ]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if (rope_scaling_factor is None
+ or not isinstance(rope_scaling_factor,
+ (float, int)) or rope_scaling_factor < 1.0):
+ raise ValueError(
+ f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
+ f'of type {type(rope_scaling_factor)}')
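+
+# Illustrative usage (a minimal sketch): a `rope_scaling` value that passes the
+# validation above provides exactly the `type` and `factor` fields, e.g.
+#
+#   config = InternLM2Config(rope_scaling={'type': 'dynamic', 'factor': 2.0})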
diff --git a/xtuner/_lite/modelings/internlm2/modeling_internlm2.py b/xtuner/_lite/modelings/internlm2/modeling_internlm2.py
new file mode 100644
index 000000000..69ddc6196
--- /dev/null
+++ b/xtuner/_lite/modelings/internlm2/modeling_internlm2.py
@@ -0,0 +1,1899 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch InternLM2.5 model."""
+import math
+import queue
+import threading
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10, logging,
+ replace_return_docstrings)
+
+try:
+ from transformers.generation.streamers import BaseStreamer
+except Exception:
+ BaseStreamer = None
+
+from .configuration_internlm2 import InternLM2Config
+
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
+ unpad_input)
+except Exception:
+    # flash-attn is optional; eager/SDPA attention works without it.
+    pass
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = 'InternLM2Config'
+
+
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
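+
+# Illustrative example: for an attention mask [[1, 1, 1, 0], [1, 1, 0, 0]] the
+# helper returns the flat indices of the 5 real tokens, cu_seqlens [0, 3, 5]
+# and a max sequence length of 3, the layout flash-attn's varlen kernels expect.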
+
+
+class InternLM2RMSNorm(nn.Module):
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
+
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance +
+ self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
+
+
+class InternLM2RotaryEmbedding(nn.Module):
+ """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
+
+ def __init__(self,
+ dim,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base
+ **(torch.arange(0, self.dim, 2,
+ dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
+ position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(
+ device_type, str) and device_type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float()
+ @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
+ Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * ((self.scaling_factor * seq_len /
+ self.max_position_embeddings) -
+ (self.scaling_factor - 1))**(
+ self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base
+ **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(
+ x.device) / self.dim))
+ self.register_buffer(
+ 'inv_freq', inv_freq,
+ persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
+ """Applies Rotary Position Embedding to the query and key tensors.
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class InternLM2MLP(nn.Module):
+ """MLP for InternLM2 model."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.w1 = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.w3 = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.w2 = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :,
+ None, :, :].expand(batch,
+ num_key_value_heads,
+ n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
+ head_dim)
+
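+# Illustrative example (shapes only): with 8 KV heads and n_rep=4, repeat_kv
+# turns a (batch=2, kv_heads=8, seq=128, head_dim=64) tensor into
+# (2, 32, 128, 64), matching the query head count for grouped-query attention.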
+
+class InternLM2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ config: InternLM2Config,
+ layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
+ 'when creating this class.')
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
+ f' and `num_heads`: {self.num_heads}).')
+
+ self.wqkv = nn.Linear(
+ self.hidden_size,
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
+ bias=config.bias,
+ )
+ self.wo = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
+
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = InternLM2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling['type']
+ scaling_factor = self.config.rope_scaling['factor']
+ if scaling_type == 'linear':
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == 'dynamic':
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False, # pylint: disable=unused-argument
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ # split qkv_states by tp size
+ key_value_slicing = (self.num_key_value_heads *
+ self.head_dim) // self.config.pretraining_tp
+ qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
+ qkv_states = torch.cat(
+ [
+ F.linear(hidden_states, qkv_slice)
+ for qkv_slice in qkv_slices
+ ],
+ dim=-1 # pylint: disable=E1102
+ )
+ else:
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states,
+ 'b q h gs d -> b q (h gs) d').transpose(1, 2)
+ key_states = qkv_states[..., -2, :].transpose(1, 2)
+ value_states = qkv_states[..., -1, :].transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(
+ 2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}')
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(
+ self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.wo.weight.split(
+ self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([
+ F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
+ for i in range(self.config.pretraining_tp)
+ ])
+ else:
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class InternLM2FlashAttention2(InternLM2Attention):
+ """
+    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
+ # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
+ # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
+ # produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
+ 'make sure to use `sdpa` in the mean time, and open an issue at '
+ 'https://github.com/huggingface/transformers')
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # dropout_rate = self.attention_dropout if self.training else 0.0
+ dropout_rate = 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (InternLM2RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.wqkv.weight.dtype
+
+ logger.warning_once(
+ f'The input hidden states seems to be silently casted in float32, this might be related to'
+ f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
+ f' {target_dtype}.')
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate)
+
+ attn_output = attn_output.reshape(bsz, q_len,
+ self.hidden_size).contiguous()
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value # pylint: disable=E0606
+
+ def _flash_attention_forward(self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
+ # For details, please see the comment in InternLM2FlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask,
+ query_length)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
+ query_length) # pylint: disable=E0606
+ else:
+ attn_output = flash_attn_func( # pylint: disable=E0606
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal)
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
+ query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+ attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis( # pylint: disable=E0606
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ value_layer = index_first_axis( # pylint: disable=E0606
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis( # pylint: disable=E0606
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
+ head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
+ query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->InternLM2
+class InternLM2SdpaAttention(InternLM2Attention):
+ """
+ InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `InternLM2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
+ to adapt to SDPA API.
+ """
+
+ # Adapted from InternLM2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
+ # once this is implemented.
+ logger.warning_once(
+ 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` '
+ 'does not support `output_attentions=True`. Falling back to the manual attention implementation, '
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
+ # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == 'cuda' and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
+ # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
+ # options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = bool(causal_mask is None and q_len > 1)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+INTERNLM2_ATTENTION_CLASSES = {
+ 'eager': InternLM2Attention,
+ 'flash_attention_2': InternLM2FlashAttention2,
+ 'sdpa': InternLM2SdpaAttention,
+}
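+
+# Illustrative note: `InternLM2DecoderLayer` below indexes this mapping with
+# `config.attn_implementation` ('eager', 'sdpa' or 'flash_attention_2'), so the
+# attention backend is selected purely from the configuration.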
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
+class InternLM2DecoderLayer(nn.Module):
+ """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
+
+ def __init__(self, config: InternLM2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.attention = INTERNLM2_ATTENTION_CLASSES[
+ config.attn_implementation](
+ config=config, layer_idx=layer_idx)
+
+ self.feed_forward = InternLM2MLP(config)
+ self.attention_norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+ self.ffn_norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
+ torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+ residual = hidden_states
+
+ hidden_states = self.attention_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.attention(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.ffn_norm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states, )
+
+ if output_attentions:
+ outputs += (self_attn_weights, )
+
+ if use_cache:
+ outputs += (present_key_value, )
+
+ return outputs
+
+
+InternLM2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+ Parameters:
+ config ([`InternLM2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2PreTrainedModel(PreTrainedModel):
+ """
+    InternLM2 pretrained model's base class.
+ """
+
+ config_class = InternLM2Config
+ base_model_prefix = 'model'
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['InternLM2DecoderLayer']
+ _skip_keys_device_placement = ['past_key_values']
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+InternLM2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
+ cache format.
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
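+
+# Illustrative sketch (not part of the modeling code): how the legacy tuple cache
+# format described in the docstring above maps onto a `Cache` object. Shapes are
+# placeholders: (batch_size=1, num_heads=8, sequence_length=4, head_dim=64), 2 layers.
+def _legacy_cache_example():  # pragma: no cover
+    legacy = tuple(
+        (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))  # (key, value) per layer
+        for _ in range(2))
+    cache = DynamicCache.from_legacy_cache(legacy)
+    assert cache.get_seq_length() == 4
+    return cache.to_legacy_cache()  # back to tuple(tuple(key, value), ...)
+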
+
+# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2Model(InternLM2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
+ Args:
+ config: InternLM2Config
+ """
+
+ _auto_class = 'AutoModel'
+
+ def __init__(self, config: InternLM2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ self.tok_embeddings = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx)
+
+ self.layers = nn.ModuleList([
+ InternLM2DecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+ self.norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.tok_embeddings(input_ids)
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(
+ past_key_values,
+ Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length(
+ ) if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
+ cache_position, past_key_values,
+ output_attentions)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states, )
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[
+ 2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1], )
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states, )
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v for v in
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
+ # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
+ # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
+ # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
+ # See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config.attn_implementation == 'flash_attention_2':
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length(
+ ) if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1] if isinstance(
+ attention_mask, torch.Tensor) else past_seen_tokens +
+ sequence_length + 1)
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError(
+                    'Custom 4D attention mask should be passed in inverted form with max==0'
+ )
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full((sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone(
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :
+ mask_length] + attention_mask[:,
+ None,
+ None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :
+ mask_length] = causal_mask[:, :, :, :
+ mask_length].masked_fill(
+ padding_mask,
+ min_dtype)
+ if (self.config.attn_implementation == 'sdpa'
+ and attention_mask is not None
+ and attention_mask.device.type == 'cuda'
+ and not output_attentions):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype) # pylint: disable=E1120
+
+ return causal_mask
+
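+
+# Illustrative sketch (not part of the modeling code): the 4D additive mask that
+# `_update_causal_mask` above builds in the eager/SDPA path for a short, unpadded
+# sequence. `min_dtype` marks positions that must not be attended to; 0.0 marks
+# visible positions.
+def _toy_causal_mask(sequence_length=4, dtype=torch.float32):  # pragma: no cover
+    min_dtype = torch.finfo(dtype).min
+    cache_position = torch.arange(sequence_length)
+    mask = torch.full((sequence_length, sequence_length), fill_value=min_dtype, dtype=dtype)
+    mask = torch.triu(mask, diagonal=1)  # only future positions stay masked
+    mask *= torch.arange(sequence_length) > cache_position.reshape(-1, 1)
+    return mask[None, None, :, :]  # (1, 1, seq_len, seq_len); row i attends to tokens 0..i
+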
+
+# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
+class InternLM2ForCausalLM(InternLM2PreTrainedModel):
+ """Causal language model (CLM) for InternLM2."""
+
+ _auto_class = 'AutoModelForCausalLM'
+ _tied_weights_keys = ['output.weight']
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = InternLM2Model(config)
+ self.vocab_size = config.vocab_size
+ self.output = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ def get_output_embeddings(self):
+ return self.output
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ Returns:
+ Example:
+ ```python
+        >>> from transformers import AutoTokenizer
+        >>> model = InternLM2ForCausalLM.from_pretrained("internlm/internlm2-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-7b", trust_remote_code=True)
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ output_slices = self.output.weight.split(
+ self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [
+ F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
+ for i in range(self.config.pretraining_tp)
+ ]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.output(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits, ) + outputs[1:]
+ return (loss, ) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ past_length = 0
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ past_length = cache_position[
+ 0] if cache_position is not None else past_key_values.get_seq_length(
+ )
+ max_cache_length = (
+ torch.tensor(
+ past_key_values.get_max_length(),
+ device=input_ids.device)
+ if past_key_values.get_max_length() is not None else None)
+ cache_length = past_length if max_cache_length is None else torch.min(
+ max_cache_length, past_length)
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+ if attention_mask is not None and attention_mask.shape[
+ 1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] -
+ past_length):]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (max_cache_length is not None and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length):
+ attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
+
+ position_ids = kwargs.get('position_ids', None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {'inputs_embeds': inputs_embeds}
+ else:
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+ # recompiles graphs as the stride of the inputs is a guard.
+ # Ref: https://github.com/huggingface/transformers/pull/29114
+ # TODO: use `next_tokens` directly instead.
+ model_inputs = {'input_ids': input_ids.contiguous()}
+
+ input_length = position_ids.shape[
+ -1] if position_ids is not None else input_ids.shape[-1]
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_length,
+ past_length + input_length,
+ device=input_ids.device)
+ elif use_cache:
+ cache_position = cache_position[-input_length:]
+
+ model_inputs.update({
+ 'position_ids': position_ids,
+ 'cache_position': cache_position,
+ 'past_key_values': past_key_values,
+ 'use_cache': use_cache,
+ 'attention_mask': attention_mask,
+ })
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past), )
+ return reordered_past
+
+ def build_inputs(self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ meta_instruction=''):
+ if history is None:
+ history = []
+ if tokenizer.add_bos_token:
+ prompt = ''
+ else:
+ prompt = tokenizer.bos_token
+ if meta_instruction:
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
+ for record in history:
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
+ return tokenizer([prompt], return_tensors='pt')
+
+ @torch.no_grad()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ streamer: Optional[BaseStreamer] = None,
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ meta_instruction:
+ str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory '
+ '(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such '
+ 'as English and 中文.',
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
+ inputs = {
+ k: v.to(self.device)
+ for k, v in inputs.items() if torch.is_tensor(v)
+ }
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
+ eos_token_id = [
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]
+ ]
+ outputs = self.generate(
+ **inputs,
+ streamer=streamer,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
+ response = response.split('<|im_end|>')[0]
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ **kwargs,
+ ):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
+        ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
+        """
+        if history is None:
+            history = []
+ if BaseStreamer is None:
+ raise ModuleNotFoundError(
+ 'The version of `transformers` is too low. Please make sure '
+ 'that you have installed `transformers>=4.28.0`.')
+
+ response_queue = queue.Queue(maxsize=20)
+
+ class ChatStreamer(BaseStreamer):
+ """
+ Streamer used in generate to print words one by one.
+ """
+
+ def __init__(self, tokenizer) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.queue = response_queue
+ self.query = query
+ self.history = history
+ self.response = ''
+ self.cache = []
+ self.received_inputs = False
+ self.queue.put(
+ (self.response, history + [(self.query, self.response)]))
+
+ def put(self, value):
+ if len(value.shape) > 1 and value.shape[0] > 1:
+ raise ValueError('ChatStreamer only supports batch size 1')
+ elif len(value.shape) > 1:
+ value = value[0]
+
+ if not self.received_inputs:
+ # The first received value is input_ids, ignore here
+ self.received_inputs = True
+ return
+
+ self.cache.extend(value.tolist())
+ token = self.tokenizer.decode(
+ self.cache, skip_special_tokens=True)
+ if token.strip() != '<|im_end|>':
+ self.response = self.response + token
+ history = self.history + [(self.query, self.response)]
+ self.queue.put((self.response, history))
+ self.cache = []
+ else:
+ self.end()
+
+ def end(self):
+ self.queue.put(None)
+
+ def stream_producer():
+ return self.chat(
+ tokenizer=tokenizer,
+ query=query,
+ streamer=ChatStreamer(tokenizer=tokenizer),
+ history=history,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ **kwargs,
+ )
+
+ def consumer():
+ producer = threading.Thread(target=stream_producer)
+ producer.start()
+ while True:
+ res = response_queue.get()
+ if res is None:
+ return
+ yield res
+
+ return consumer()
+
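+# Illustrative usage sketch (not part of the modeling code). The checkpoint name is a
+# placeholder for any InternLM2 chat model; `trust_remote_code=True` is needed because
+# the tokenizer (and, via AutoModel, this module) is distributed as remote code.
+def _example_chat_usage():  # pragma: no cover
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    path = 'internlm/internlm2-chat-7b'  # placeholder checkpoint
+    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()
+    response, history = model.chat(tokenizer, 'Hello!', history=[])
+    # `stream_chat` returns a generator of (partial_response, history) tuples.
+    for partial_response, _ in model.stream_chat(tokenizer, 'How are you?', history=history):
+        print(partial_response, end='\r')
+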
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
+@add_start_docstrings(
+ """
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
+ """Sequence Classification Head for InternLM2 Model."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = InternLM2Model(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ 'Cannot handle batch sizes > 1 if no padding token is defined.'
+ )
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(
+ input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device),
+ sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = 'regression'
+ elif self.num_labels > 1 and (labels.dtype
+ in (torch.long, torch.int)):
+ self.config.problem_type = 'single_label_classification'
+ else:
+ self.config.problem_type = 'multi_label_classification'
+
+ if self.config.problem_type == 'regression':
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == 'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits, ) + transformer_outputs[1:]
+ return ((loss, ) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
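+# Illustrative sketch (not part of the modeling code): how the classification head above
+# locates the last non-padding token of each row via the argmax/modulo trick.
+def _toy_last_token_index(pad_token_id=0):  # pragma: no cover
+    input_ids = torch.tensor([[11, 12, 13, 0, 0],     # 3 real tokens, 2 pads
+                              [21, 22, 23, 24, 25]])  # no padding
+    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
+    sequence_lengths = sequence_lengths % input_ids.shape[-1]
+    return sequence_lengths  # tensor([2, 4]): index of the last real token per row
+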
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
+@add_start_docstrings(
+ """
+The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
+ """Question Answering model for InternLM2."""
+
+ base_model_prefix = 'transformer'
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = InternLM2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.transformer.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(
+ start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss, ) +
+ output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
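+# Illustrative sketch (not part of the modeling code): greedily decoding one answer span
+# from the start/end logits produced by `InternLM2ForQuestionAnswering` above.
+def _toy_decode_span(start_logits, end_logits):  # pragma: no cover
+    # batch size 1: pick the best start, then the best end at or after that start
+    start = int(start_logits[0].argmax())
+    end = int(end_logits[0, start:].argmax()) + start
+    return start, end  # inclusive token indices of the predicted span
+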
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
+@add_start_docstrings(
+ """
+ The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
+ """Token classification model for InternLM2."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = InternLM2Model(config)
+ if getattr(config, 'classifier_dropout', None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, 'hidden_dropout', None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits, ) + outputs[2:]
+ return ((loss, ) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/xtuner/_lite/modelings/internvl2/__init__.py b/xtuner/_lite/modelings/internvl2/__init__.py
new file mode 100644
index 000000000..8652be2d9
--- /dev/null
+++ b/xtuner/_lite/modelings/internvl2/__init__.py
@@ -0,0 +1,3 @@
+from .modeling_intern_vit import InternVisionModel
+
+__all__ = ['InternVisionModel']
diff --git a/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py b/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py
new file mode 100644
index 000000000..32f469c4b
--- /dev/null
+++ b/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py
@@ -0,0 +1,119 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class InternVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of color channels in the input images (e.g., 3 for RGB).
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+        qkv_bias (`bool`, *optional*, defaults to `False`):
+            Whether to add a bias to the queries, keys and values in the self-attention layers.
+ hidden_size (`int`, *optional*, defaults to 3200):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_attention_heads (`int`, *optional*, defaults to 25):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 12800):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ qk_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the queries and keys in the self-attention layers.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
+ Whether to use flash attention mechanism.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Dropout rate for stochastic depth.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 0.1):
+ A factor for layer scale.
+ """
+
+ model_type = 'intern_vit_6b'
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=14,
+ image_size=224,
+ qkv_bias=False,
+ hidden_size=3200,
+ num_attention_heads=25,
+ intermediate_size=12800,
+ qk_normalization=True,
+ num_hidden_layers=48,
+ use_flash_attn=True,
+ hidden_act='gelu',
+ norm_type='rms_norm',
+ layer_norm_eps=1e-6,
+ dropout=0.0,
+ drop_path_rate=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=0.1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.drop_path_rate = drop_path_rate
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.norm_type = norm_type
+ self.qkv_bias = qkv_bias
+ self.qk_normalization = qk_normalization
+ self.use_flash_attn = use_flash_attn
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if 'vision_config' in config_dict:
+ config_dict = config_dict['vision_config']
+
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
\ No newline at end of file
diff --git a/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py b/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py
new file mode 100644
index 000000000..a8d36d9e3
--- /dev/null
+++ b/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py
@@ -0,0 +1,432 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutput,
+ BaseModelOutputWithPooling)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_intern_vit import InternVisionConfig
+
+try:
+ from flash_attn.bert_padding import pad_input, unpad_input
+ from flash_attn.flash_attn_interface import \
+ flash_attn_varlen_qkvpacked_func
+ has_flash_attn = True
+except Exception:
+ print('FlashAttention2 is not installed.')
+ has_flash_attn = False
+
+logger = logging.get_logger(__name__)
+
+
+class FlashAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+ max_s=None, need_weights=False):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+ if unpadded: (nnz, 3, h, d)
+ key_padding_mask: a bool tensor of shape (B, S)
+ """
+ assert not need_weights
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
+ assert qkv.is_cuda
+
+ if cu_seqlens is None:
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = seqlen
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, batch_size, seqlen),
+ 'b s (h d) -> b s h d', h=nheads)
+ else:
+ assert max_s is not None
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+
+ return output, None
+
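+# Illustrative sketch (not part of the modeling code): the packed qkv layout the
+# FlashAttention wrapper above expects when `key_padding_mask` and `cu_seqlens` are None.
+def _toy_packed_qkv(batch_size=2, seq_len=16, num_heads=8, head_dim=64):  # pragma: no cover
+    q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+    k = torch.randn_like(q)
+    v = torch.randn_like(q)
+    return torch.stack([q, k, v], dim=2)  # (B, S, 3, H, D), see FlashAttention.forward
+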
+
+class InternRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+ from apex.normalization import FusedRMSNorm
+
+ InternRMSNorm = FusedRMSNorm # noqa
+
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
+except ImportError:
+ # using the normal InternRMSNorm
+ pass
+except Exception:
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
+ pass
+
+
+NORM2FN = {
+ 'rms_norm': InternRMSNorm,
+ 'layer_norm': nn.LayerNorm,
+}
+
+
+class InternVisionEmbeddings(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.randn(1, 1, self.embed_dim),
+ )
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
+ ], dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+class InternAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
+ if config.use_flash_attn and not has_flash_attn:
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+ f' {self.num_heads}).'
+ )
+
+ self.scale = self.head_dim ** -0.5
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+ self.attn_drop = nn.Dropout(config.attention_dropout)
+ self.proj_drop = nn.Dropout(config.dropout)
+
+ self.qk_normalization = config.qk_normalization
+
+ if self.qk_normalization:
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ if self.use_flash_attn:
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
+ )
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+ return x
+
+
+class InternMLP(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.act = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.norm_type = config.norm_type
+
+ self.attn = InternAttention(config)
+ self.mlp = InternMLP(config)
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
+
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
+
+ return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`InternEncoderLayer`].
+
+ Args:
+ config (`InternConfig`):
+ The corresponding vision configuration for the `InternEncoder`.
+ """
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+        # stochastic depth decay rule
+        # NOTE: the original torch.linspace-based rule is kept below for reference;
+        # np.linspace is used here instead.
+        # dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+        dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers))
+ self.layers = nn.ModuleList([
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = True
+
+ def forward(
+ self,
+ inputs_embeds,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ hidden_states = inputs_embeds
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ encoder_layer,
+ hidden_states)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ )
+ hidden_states = layer_outputs
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states
+ )
+
+
+class InternVisionModel(PreTrainedModel):
+ main_input_name = 'pixel_values'
+ _supports_flash_attn_2 = True
+ config_class = InternVisionConfig
+ _no_split_modules = ['InternVisionEncoderLayer']
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = InternVisionEmbeddings(config)
+ self.encoder = InternVisionEncoder(config)
+
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
+ pos_emb = self.embeddings.position_embedding
+ _, num_positions, embed_dim = pos_emb.shape
+ cls_emb = pos_emb[:, :1, :]
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
+ self.embeddings.image_size = new_size
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_embeds: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None and pixel_embeds is None:
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
+
+ if pixel_embeds is not None:
+ hidden_states = pixel_embeds
+ else:
+ if len(pixel_values.shape) == 4:
+ hidden_states = self.embeddings(pixel_values)
+ else:
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state
+ pooled_output = last_hidden_state[:, 0, :]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
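
A standalone sketch of the interpolation performed by `resize_pos_embeddings` above (toy sizes, not the real InternViT dimensions): the CLS slot is kept as-is and the patch grid is bicubic-resized to the new resolution.

```python
import torch
import torch.nn.functional as F

old_size, new_size, patch, dim = 224, 448, 14, 32
pos_emb = torch.randn(1, 1 + (old_size // patch) ** 2, dim)   # CLS + 16x16 patch grid
cls_emb, grid = pos_emb[:, :1, :], pos_emb[:, 1:, :]
grid = grid.reshape(1, old_size // patch, old_size // patch, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid.float(), size=new_size // patch, mode='bicubic', align_corners=False)
grid = grid.to(cls_emb.dtype).reshape(1, dim, -1).permute(0, 2, 1)
resized = torch.cat([cls_emb, grid], dim=1)
print(resized.shape)   # torch.Size([1, 1025, 32]) -> CLS + 32x32 grid
```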
diff --git a/xtuner/_lite/modelings/llava/__init__.py b/xtuner/_lite/modelings/llava/__init__.py
new file mode 100644
index 000000000..036324005
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/__init__.py
@@ -0,0 +1,3 @@
+from .configuration_llava import EnhancedLlavaConfig
+from .modeling_llava import LlavaForConditionalGeneration
+from .processing_llava import LlavaProcessor
diff --git a/xtuner/_lite/modelings/llava/configuration_internlm2.py b/xtuner/_lite/modelings/llava/configuration_internlm2.py
new file mode 100644
index 000000000..8b8107947
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/configuration_internlm2.py
@@ -0,0 +1,175 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" InternLM2 model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+# Modified from transformers.model.llama.configuration_llama.LlamaConfig
+class InternLM2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`InternLM2Model`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
+ to understand more about it. This value is necessary to ensure exact reproducibility
+ of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ """
+ _auto_class = 'AutoConfig'
+ model_type = 'internlm2'
+ keys_to_ignore_at_inference = ['past_key_values']
+
+ def __init__( # pylint: disable=W0102
+ self,
+ vocab_size=103168,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act='silu',
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ bias=True,
+ rope_theta=10000,
+ rope_scaling=None,
+ attn_implementation=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.bias = bias
+
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attn_implementation = attn_implementation
+ if self.attn_implementation is None:
+ self.attn_implementation = 'eager'
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling,
+ dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+                '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
+ f'got {self.rope_scaling}')
+ rope_scaling_type = self.rope_scaling.get('type', None)
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
+ if rope_scaling_type is None or rope_scaling_type not in [
+ 'linear', 'dynamic'
+ ]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if (rope_scaling_factor is None
+ or not isinstance(rope_scaling_factor,
+ (float, int)) or rope_scaling_factor < 1.0):
+ raise ValueError(
+ f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
+ f'of type {type(rope_scaling_factor)}')
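
A small usage sketch of the config above (assuming the import path this diff introduces, `xtuner._lite.modelings.llava.configuration_internlm2`): `_rope_scaling_validation` runs at construction time, so a malformed `rope_scaling` dict fails immediately.

```python
from xtuner._lite.modelings.llava.configuration_internlm2 import InternLM2Config

# Valid dynamic-NTK scaling: both `type` and `factor` are required and the
# factor must be a number >= 1.
cfg = InternLM2Config(
    max_position_embeddings=32768,
    rope_scaling={'type': 'dynamic', 'factor': 2.0},
)
print(cfg.num_key_value_heads)   # 32 -- falls back to num_attention_heads
print(cfg.attn_implementation)   # 'eager' when nothing is specified

# A factor below 1 is rejected by _rope_scaling_validation.
try:
    InternLM2Config(rope_scaling={'type': 'linear', 'factor': 0.5})
except ValueError as err:
    print(err)
```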
diff --git a/xtuner/_lite/modelings/llava/configuration_llava.py b/xtuner/_lite/modelings/llava/configuration_llava.py
new file mode 100644
index 000000000..f5ec7bbfa
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/configuration_llava.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Llava model configuration"""
+import os
+from typing import Union
+from transformers.configuration_utils import PretrainedConfig, custom_object_save
+from transformers.utils import logging
+from transformers import CONFIG_MAPPING, AutoModelForCausalLM, AutoConfig
+
+logger = logging.get_logger(__name__)
+
+class EnhancedLlavaConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate a
+    Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Llava-9B.
+
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ ignore_index (`int`, *optional*, defaults to -100):
+ The ignore index for the loss function.
+ image_token_index (`int`, *optional*, defaults to 32000):
+ The image token index to encode the image prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`.
+ vision_feature_layer (`int`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature.
+
+ Example:
+
+ ```python
+    >>> from transformers import CLIPVisionConfig, LlamaConfig
+    >>> from xtuner._lite.modelings.llava import EnhancedLlavaConfig, LlavaForConditionalGeneration
+
+ >>> # Initializing a CLIP-vision config
+ >>> vision_config = CLIPVisionConfig()
+
+ >>> # Initializing a Llama config
+ >>> text_config = LlamaConfig()
+
+ >>> # Initializing a Llava llava-1.5-7b style configuration
+    >>> configuration = EnhancedLlavaConfig(vision_config, text_config)
+
+ >>> # Initializing a model from the llava-1.5-7b style configuration
+ >>> model = LlavaForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ _auto_class = 'AutoConfig'
+ model_type = "enhanced_llava"
+ is_composition = False
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ ignore_index=-100,
+ image_token_index=32000,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="default",
+ vision_feature_layer=-2,
+ **kwargs,
+ ):
+ self.ignore_index = ignore_index
+ self.image_token_index = image_token_index
+ self.projector_hidden_act = projector_hidden_act
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+ "vision_feature_select_strategy should be one of 'default', 'full'."
+ f"Got: {vision_feature_select_strategy}"
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = (
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
+ )
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["clip_vision_model"](
+ intermediate_size=4096,
+ hidden_size=1024,
+ patch_size=14,
+ image_size=336,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ vocab_size=32000,
+ projection_dim=768,
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
+
+ if text_config["model_type"] == 'internlm2':
+ from .configuration_internlm2 import InternLM2Config
+ from .modeling_internlm2 import InternLM2ForCausalLM
+ AutoConfig.register('internlm2', InternLM2Config)
+ AutoModelForCausalLM.register(
+ InternLM2Config, InternLM2ForCausalLM)
+                # `auto_map` may be absent when the config dict was built in code rather than loaded from a checkpoint
+                auto_map = text_config.setdefault('auto_map', {})
+                auto_map['AutoConfig'] = 'configuration_internlm2.InternLM2Config'
+                auto_map['AutoModel'] = 'modeling_internlm2.InternLM2ForCausalLM'
+                auto_map['AutoModelForCausalLM'] = 'modeling_internlm2.InternLM2ForCausalLM'
+ text_config = InternLM2Config(**text_config)
+ else:
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["llama"]()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~PretrainedConfig.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ super().save_pretrained(save_directory, push_to_hub, **kwargs)
+
+ if self.text_config._auto_class is not None:
+ custom_object_save(self.text_config, save_directory, config=self.text_config)
+
+AutoConfig.register('enhanced_llava', EnhancedLlavaConfig, exist_ok=True)
\ No newline at end of file
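
A hedged example of the enhanced config above (import path assumed from this diff): with no arguments it falls back to the hard-coded CLIP-L/14-336 + Llama defaults, and plain dicts are routed through `CONFIG_MAPPING` based on their `model_type`.

```python
from transformers import CLIPVisionConfig, LlamaConfig
from xtuner._lite.modelings.llava.configuration_llava import EnhancedLlavaConfig

# Defaults: CLIP vision tower + Llama text backbone.
config = EnhancedLlavaConfig()
print(type(config.vision_config).__name__)   # CLIPVisionConfig
print(type(config.text_config).__name__)     # LlamaConfig

# Dict configs are accepted too; `model_type` decides which config class is built.
config = EnhancedLlavaConfig(
    vision_config=CLIPVisionConfig().to_dict(),
    text_config=LlamaConfig(num_hidden_layers=2).to_dict(),
)
print(config.text_config.num_hidden_layers)   # 2
```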
diff --git a/xtuner/_lite/modelings/llava/modeling_internlm2.py b/xtuner/_lite/modelings/llava/modeling_internlm2.py
new file mode 100644
index 000000000..69ddc6196
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/modeling_internlm2.py
@@ -0,0 +1,1899 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch InternLM2.5 model."""
+import math
+import queue
+import threading
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10, logging,
+ replace_return_docstrings)
+
+try:
+ from transformers.generation.streamers import BaseStreamer
+except Exception:
+ BaseStreamer = None
+
+from .configuration_internlm2 import InternLM2Config
+
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
+ unpad_input)
+except Exception:
+ pass
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = 'InternLM2Config'
+
+
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
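
A worked example of the bookkeeping `_get_unpad_data` performs (plain tensor ops, nothing model-specific): for a padded batch it yields the flat indices of real tokens and the cumulative sequence lengths that `flash_attn_varlen_func` expects.

```python
import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 1, padded to length 4 (0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(indices)     # tensor([0, 1, 2, 4]) -> flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 4])   -> per-sequence boundaries in the packed layout
```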
+
+
+class InternLM2RMSNorm(nn.Module):
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
+
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance +
+ self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
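
A quick numerical check of the normalization above (a standalone sketch, not part of the diff): each hidden vector is scaled by the reciprocal root-mean-square of its elements before the learned weight is applied.

```python
import torch

x = torch.randn(2, 5, 8)
eps = 1e-6
variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + eps)   # what InternLM2RMSNorm computes before the weight
print(normed.pow(2).mean(-1))              # ~1.0 for every token, up to eps
```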
+
+
+class InternLM2RotaryEmbedding(nn.Module):
+ """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
+
+ def __init__(self,
+ dim,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base
+ **(torch.arange(0, self.dim, 2,
+ dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
+ position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(
+ device_type, str) and device_type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float()
+ @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
+ Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * ((self.scaling_factor * seq_len /
+ self.max_position_embeddings) -
+ (self.scaling_factor - 1))**(
+ self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base
+ **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(
+ x.device) / self.dim))
+ self.register_buffer(
+ 'inv_freq', inv_freq,
+ persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
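
Worked numbers for the dynamic-NTK branch above (illustrative values, not taken from this diff): when the prompt exceeds the training window, the RoPE base is inflated before `inv_freq` is rebuilt.

```python
# dim=128, base=10000, 2048-token training window, 4096-token prompt, scaling_factor=1.0
dim, base, max_pos, scaling_factor, seq_len = 128, 10000.0, 2048, 1.0, 4096
new_base = base * ((scaling_factor * seq_len / max_pos)
                   - (scaling_factor - 1)) ** (dim / (dim - 2))
print(round(new_base, 1))   # ~20221.3 -- roughly double the original base
```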
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
+ """Applies Rotary Position Embedding to the query and key tensors.
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
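
Shape sketch for the two helpers above (imports assume the module path this diff adds): the rotary module emits `cos`/`sin` of shape `(batch, seq, head_dim)` and `apply_rotary_pos_emb` unsqueezes them on the heads axis so they broadcast against `(batch, heads, seq, head_dim)` queries and keys.

```python
import torch
from xtuner._lite.modelings.llava.modeling_internlm2 import (
    InternLM2RotaryEmbedding, apply_rotary_pos_emb)

bsz, heads, seq, head_dim = 1, 4, 6, 16
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)

rope = InternLM2RotaryEmbedding(head_dim, max_position_embeddings=2048)
position_ids = torch.arange(seq).unsqueeze(0)        # (1, seq)
cos, sin = rope(q, position_ids)                     # each (1, seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged
print(q_rot.shape, k_rot.shape)                      # (1, 4, 6, 16) twice
```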
+
+
+class InternLM2MLP(nn.Module):
+ """MLP for InternLM2 model."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.w1 = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.w3 = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.w2 = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
+
+ return down_proj
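
The block above is a standard SwiGLU feed-forward; a minimal standalone equivalent (toy sizes, not the real hidden dimensions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(2, 4, hidden)
w1 = nn.Linear(hidden, inter, bias=False)   # gate projection
w3 = nn.Linear(hidden, inter, bias=False)   # up projection
w2 = nn.Linear(inter, hidden, bias=False)   # down projection
out = w2(F.silu(w1(x)) * w3(x))             # same form as InternLM2MLP with hidden_act='silu'
print(out.shape)                            # torch.Size([2, 4, 8])
```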
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :,
+ None, :, :].expand(batch,
+ num_key_value_heads,
+ n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
+ head_dim)
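
A shape-only illustration of `repeat_kv` for grouped-query attention (toy sizes):

```python
import torch

kv = torch.randn(2, 8, 16, 64)   # (batch, num_kv_heads=8, seq, head_dim)
n_rep = 4                        # 32 query heads / 8 kv heads
expanded = kv[:, :, None, :, :].expand(2, 8, n_rep, 16, 64).reshape(2, 8 * n_rep, 16, 64)
print(expanded.shape)            # torch.Size([2, 32, 16, 64]) -- now matches the query heads
```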
+
+
+class InternLM2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ config: InternLM2Config,
+ layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
+ 'when creating this class.')
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
+ f' and `num_heads`: {self.num_heads}).')
+
+ self.wqkv = nn.Linear(
+ self.hidden_size,
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
+ bias=config.bias,
+ )
+ self.wo = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
+
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = InternLM2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling['type']
+ scaling_factor = self.config.rope_scaling['factor']
+ if scaling_type == 'linear':
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == 'dynamic':
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False, # pylint: disable=unused-argument
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ # split qkv_states by tp size
+ key_value_slicing = (self.num_key_value_heads *
+ self.head_dim) // self.config.pretraining_tp
+ qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
+ qkv_states = torch.cat(
+ [
+ F.linear(hidden_states, qkv_slice)
+ for qkv_slice in qkv_slices
+ ],
+ dim=-1 # pylint: disable=E1102
+ )
+ else:
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states,
+ 'b q h gs d -> b q (h gs) d').transpose(1, 2)
+ key_states = qkv_states[..., -2, :].transpose(1, 2)
+ value_states = qkv_states[..., -1, :].transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(
+ 2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}')
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(
+ self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.wo.weight.split(
+ self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([
+ F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
+ for i in range(self.config.pretraining_tp)
+ ])
+ else:
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class InternLM2FlashAttention2(InternLM2Attention):
+ """
+    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
+ # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
+ # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
+ # produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
+ 'make sure to use `sdpa` in the mean time, and open an issue at '
+ 'https://github.com/huggingface/transformers')
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # dropout_rate = self.attention_dropout if self.training else 0.0
+ dropout_rate = 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (InternLM2RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.wqkv.weight.dtype
+
+ logger.warning_once(
+                f'The input hidden states seem to have been silently cast to float32; this might be related to'
+                f' the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to'
+ f' {target_dtype}.')
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate)
+
+ attn_output = attn_output.reshape(bsz, q_len,
+ self.hidden_size).contiguous()
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value # pylint: disable=E0606
+
+ def _flash_attention_forward(self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
+ # For details, please see the comment in InternLM2FlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask,
+ query_length)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
+ query_length) # pylint: disable=E0606
+ else:
+ attn_output = flash_attn_func( # pylint: disable=E0606
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal)
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
+ query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+ attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis( # pylint: disable=E0606
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ value_layer = index_first_axis( # pylint: disable=E0606
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis( # pylint: disable=E0606
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
+ head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
+ query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->InternLM2
+class InternLM2SdpaAttention(InternLM2Attention):
+ """
+ InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `InternLM2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
+    to adapt to the SDPA API.
+ """
+
+ # Adapted from InternLM2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
+ # once this is implemented.
+ logger.warning_once(
+ 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` '
+ 'does not support `output_attentions=True`. Falling back to the manual attention implementation, '
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
+ # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == 'cuda' and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
+ # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
+ # options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = bool(causal_mask is None and q_len > 1)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ return attn_output, None, past_key_value
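
A standalone check of the `is_causal` shortcut used above (plain PyTorch, no model code): with no explicit mask and more than one query position, letting SDPA build the causal mask is equivalent to passing an additive lower-triangular mask.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 6, 16)
k = torch.randn(1, 4, 6, 16)
v = torch.randn(1, 4, 6, 16)

implicit = F.scaled_dot_product_attention(q, k, v, is_causal=True)
mask = torch.full((6, 6), float('-inf')).triu(1)      # additive causal mask
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(implicit, explicit, atol=1e-6))  # True
```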
+
+
+INTERNLM2_ATTENTION_CLASSES = {
+ 'eager': InternLM2Attention,
+ 'flash_attention_2': InternLM2FlashAttention2,
+ 'sdpa': InternLM2SdpaAttention,
+}
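
How the mapping above is consumed (import paths assumed from this diff): the decoder layer indexes it with `config.attn_implementation`, which the config defaults to `'eager'`.

```python
from xtuner._lite.modelings.llava.configuration_internlm2 import InternLM2Config
from xtuner._lite.modelings.llava.modeling_internlm2 import INTERNLM2_ATTENTION_CLASSES

cfg = InternLM2Config(attn_implementation='sdpa')
print(INTERNLM2_ATTENTION_CLASSES[cfg.attn_implementation].__name__)                 # InternLM2SdpaAttention
print(INTERNLM2_ATTENTION_CLASSES[InternLM2Config().attn_implementation].__name__)   # InternLM2Attention
```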
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
+class InternLM2DecoderLayer(nn.Module):
+ """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
+
+ def __init__(self, config: InternLM2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.attention = INTERNLM2_ATTENTION_CLASSES[
+ config.attn_implementation](
+ config=config, layer_idx=layer_idx)
+
+ self.feed_forward = InternLM2MLP(config)
+ self.attention_norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+ self.ffn_norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
+ torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+ residual = hidden_states
+
+ hidden_states = self.attention_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.attention(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.ffn_norm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states, )
+
+ if output_attentions:
+ outputs += (self_attn_weights, )
+
+ if use_cache:
+ outputs += (present_key_value, )
+
+ return outputs
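
A tiny smoke test of the pre-norm layer above (toy config; note that with `attention_mask=None` the standalone layer applies no causal mask, the full model builds one in `_update_causal_mask`):

```python
import torch
from xtuner._lite.modelings.llava.configuration_internlm2 import InternLM2Config
from xtuner._lite.modelings.llava.modeling_internlm2 import InternLM2DecoderLayer

cfg = InternLM2Config(hidden_size=64, intermediate_size=128, num_attention_heads=4,
                      num_key_value_heads=2, num_hidden_layers=1, vocab_size=128)
layer = InternLM2DecoderLayer(cfg, layer_idx=0)
x = torch.randn(1, 6, 64)
position_ids = torch.arange(6).unsqueeze(0)
out = layer(x, position_ids=position_ids)   # tuple: (hidden_states,)
print(out[0].shape)                         # torch.Size([1, 6, 64])
```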
+
+
+InternLM2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+ Parameters:
+ config ([`InternLM2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2PreTrainedModel(PreTrainedModel):
+ """
+    InternLM2 pretrained model's base class.
+ """
+
+ config_class = InternLM2Config
+ base_model_prefix = 'model'
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['InternLM2DecoderLayer']
+ _skip_keys_device_placement = ['past_key_values']
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+InternLM2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2Model(InternLM2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
+ Args:
+ config: InternLM2Config
+ """
+
+ _auto_class = 'AutoModel'
+
+ def __init__(self, config: InternLM2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ self.tok_embeddings = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx)
+
+ self.layers = nn.ModuleList([
+ InternLM2DecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+ self.norm = InternLM2RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.tok_embeddings(input_ids)
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(
+ past_key_values,
+ Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length(
+ ) if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
+ cache_position, past_key_values,
+ output_attentions)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states, )
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[
+ 2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1], )
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states, )
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v for v in
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
+ # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
+ # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
+ # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
+ # See more context in https://github.com/huggingface/transformers/pull/29114
+
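+ # Flash Attention 2 consumes the 2D padding mask directly (or no mask at all
+ # for fully unpadded inputs); the dense 4D causal mask below is only built
+ # for the eager and SDPA paths.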
+ if self.config.attn_implementation == 'flash_attention_2':
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length(
+ ) if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1] if isinstance(
+ attention_mask, torch.Tensor) else past_seen_tokens +
+ sequence_length + 1)
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError(
+ 'Custom 4D attention mask should be passed in inverted form with max==0'
+ )
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full((sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
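+ # Mask out only the positions after each query token's own position; everything
+ # at or before it (including already-cached tokens) stays attendable.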
+ causal_mask *= torch.arange(
+ target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone(
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :
+ mask_length] + attention_mask[:,
+ None,
+ None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :
+ mask_length] = causal_mask[:, :, :, :
+ mask_length].masked_fill(
+ padding_mask,
+ min_dtype)
+ if (self.config.attn_implementation == 'sdpa'
+ and attention_mask is not None
+ and attention_mask.device.type == 'cuda'
+ and not output_attentions):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype) # pylint: disable=E1120
+
+ return causal_mask
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
+class InternLM2ForCausalLM(InternLM2PreTrainedModel):
+ """Causal language model (CLM) for InternLM2."""
+
+ _auto_class = 'AutoModelForCausalLM'
+ _tied_weights_keys = ['output.weight']
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = InternLM2Model(config)
+ self.vocab_size = config.vocab_size
+ self.output = nn.Linear(
+ config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ def get_output_embeddings(self):
+ return self.output
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ Returns:
+ Example:
+ ```python
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
+ >>> model = InternLM2ForCausalLM.from_pretrained("internlm/internlm2-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-7b")
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
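+ # Checkpoints trained with tensor parallelism (`pretraining_tp > 1`) apply the
+ # output projection shard by shard and concatenate the resulting logits.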
+ if self.config.pretraining_tp > 1:
+ output_slices = self.output.weight.split(
+ self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [
+ F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
+ for i in range(self.config.pretraining_tp)
+ ]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.output(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits, ) + outputs[1:]
+ return (loss, ) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ past_length = 0
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ past_length = cache_position[
+ 0] if cache_position is not None else past_key_values.get_seq_length(
+ )
+ max_cache_length = (
+ torch.tensor(
+ past_key_values.get_max_length(),
+ device=input_ids.device)
+ if past_key_values.get_max_length() is not None else None)
+ cache_length = past_length if max_cache_length is None else torch.min(
+ max_cache_length, past_length)
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+ if attention_mask is not None and attention_mask.shape[
+ 1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] -
+ past_length):]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (max_cache_length is not None and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length):
+ attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
+
+ position_ids = kwargs.get('position_ids', None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {'inputs_embeds': inputs_embeds}
+ else:
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+ # recompiles graphs as the stride of the inputs is a guard.
+ # Ref: https://github.com/huggingface/transformers/pull/29114
+ # TODO: use `next_tokens` directly instead.
+ model_inputs = {'input_ids': input_ids.contiguous()}
+
+ input_length = position_ids.shape[
+ -1] if position_ids is not None else input_ids.shape[-1]
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_length,
+ past_length + input_length,
+ device=input_ids.device)
+ elif use_cache:
+ cache_position = cache_position[-input_length:]
+
+ model_inputs.update({
+ 'position_ids': position_ids,
+ 'cache_position': cache_position,
+ 'past_key_values': past_key_values,
+ 'use_cache': use_cache,
+ 'attention_mask': attention_mask,
+ })
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
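+ # Reorder the cached key/value states so they follow the beams selected
+ # during beam search.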
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past), )
+ return reordered_past
+
+ def build_inputs(self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ meta_instruction=''):
+ if history is None:
+ history = []
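+ # If the tokenizer already adds BOS automatically, do not prepend it again;
+ # otherwise insert it manually so the prompt always starts with BOS.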
+ if tokenizer.add_bos_token:
+ prompt = ''
+ else:
+ prompt = tokenizer.bos_token
+ if meta_instruction:
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
+ for record in history:
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
+ return tokenizer([prompt], return_tensors='pt')
+
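+ # Illustrative usage of `chat` (assumes `tokenizer` is the matching InternLM2 tokenizer):
+ #   response, history = model.chat(tokenizer, 'Hello', history=[])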
+ @torch.no_grad()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ streamer: Optional[BaseStreamer] = None,
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ meta_instruction:
+ str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory '
+ '(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such '
+ 'as English and 中文.',
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
+ inputs = {
+ k: v.to(self.device)
+ for k, v in inputs.items() if torch.is_tensor(v)
+ }
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
+ eos_token_id = [
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]
+ ]
+ outputs = self.generate(
+ **inputs,
+ streamer=streamer,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
+ response = response.split('<|im_end|>')[0]
+ history = history + [(query, response)]
+ return response, history
+
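+ # Illustrative usage of `stream_chat` (yields partial responses as they are generated):
+ #   for response, history in model.stream_chat(tokenizer, query):
+ #       print(response)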
+ @torch.no_grad()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ **kwargs,
+ ):
+ """
+ Return a generator that yields `(response, history)` tuples, e.g.
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
+ """
+ if history is None:
+ history = []
+ if BaseStreamer is None:
+ raise ModuleNotFoundError(
+ 'The version of `transformers` is too low. Please make sure '
+ 'that you have installed `transformers>=4.28.0`.')
+
+ response_queue = queue.Queue(maxsize=20)
+
+ class ChatStreamer(BaseStreamer):
+ """
+ Streamer used in generate to print words one by one.
+ """
+
+ def __init__(self, tokenizer) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.queue = response_queue
+ self.query = query
+ self.history = history
+ self.response = ''
+ self.cache = []
+ self.received_inputs = False
+ self.queue.put(
+ (self.response, history + [(self.query, self.response)]))
+
+ def put(self, value):
+ if len(value.shape) > 1 and value.shape[0] > 1:
+ raise ValueError('ChatStreamer only supports batch size 1')
+ elif len(value.shape) > 1:
+ value = value[0]
+
+ if not self.received_inputs:
+ # The first received value is input_ids, ignore here
+ self.received_inputs = True
+ return
+
+ self.cache.extend(value.tolist())
+ token = self.tokenizer.decode(
+ self.cache, skip_special_tokens=True)
+ if token.strip() != '<|im_end|>':
+ self.response = self.response + token
+ history = self.history + [(self.query, self.response)]
+ self.queue.put((self.response, history))
+ self.cache = []
+ else:
+ self.end()
+
+ def end(self):
+ self.queue.put(None)
+
+ def stream_producer():
+ return self.chat(
+ tokenizer=tokenizer,
+ query=query,
+ streamer=ChatStreamer(tokenizer=tokenizer),
+ history=history,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ **kwargs,
+ )
+
+ def consumer():
+ producer = threading.Thread(target=stream_producer)
+ producer.start()
+ while True:
+ res = response_queue.get()
+ if res is None:
+ return
+ yield res
+
+ return consumer()
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
+@add_start_docstrings(
+ """
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
+ """Sequence Classification Head for InternLM2 Model."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = InternLM2Model(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ 'Cannot handle batch sizes > 1 if no padding token is defined.'
+ )
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(
+ input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device),
+ sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = 'regression'
+ elif self.num_labels > 1 and (labels.dtype
+ in (torch.long, torch.int)):
+ self.config.problem_type = 'single_label_classification'
+ else:
+ self.config.problem_type = 'multi_label_classification'
+
+ if self.config.problem_type == 'regression':
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == 'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits, ) + transformer_outputs[1:]
+ return ((loss, ) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
+@add_start_docstrings(
+ """
+The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
+ """Question Answering model for InternLM2."""
+
+ base_model_prefix = 'transformer'
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = InternLM2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.transformer.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache,
+ List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, splitting may add an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(
+ start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss, ) +
+ output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
+@add_start_docstrings(
+ """
+ The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
+ """Token classification model for InternLM2."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = InternLM2Model(config)
+ if getattr(config, 'classifier_dropout', None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, 'hidden_dropout', None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits, ) + outputs[2:]
+ return ((loss, ) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/xtuner/_lite/modelings/llava/modeling_llava.py b/xtuner/_lite/modelings/llava/modeling_llava.py
new file mode 100644
index 000000000..b987db7b5
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/modeling_llava.py
@@ -0,0 +1,573 @@
+# coding=utf-8
+# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Llava model."""
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import ModelOutput
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers import AutoModel, AutoModelForCausalLM
+from .configuration_llava import EnhancedLlavaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlavaConfig"
+
+
+
+@dataclass
+# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
+class LlavaCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for Llava causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+ sequence_length, hidden_size)`.
+
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class LlavaMultiModalProjector(nn.Module):
+ def __init__(self, config: EnhancedLlavaConfig):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
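+ # A two-layer MLP that maps vision-encoder features to the language model's hidden size.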
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+LLAVA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAVA_START_DOCSTRING,
+)
+class LlavaPreTrainedModel(PreTrainedModel):
+ config_class = EnhancedLlavaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlavaVisionAttention"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ # important: this ported version of Llava isn't meant for training from scratch - only
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
+ std = (
+ self.config.initializer_range
+ if hasattr(self.config, "initializer_range")
+ else self.config.text_config.initializer_range
+ )
+
+ if hasattr(module, "class_embedding"):
+ module.class_embedding.data.normal_(mean=0.0, std=std)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @property
+ def _supports_sdpa(self):
+ """
+ Retrieve language_model's attribute to check whether the model supports
+ SDPA or not.
+ """
+ return self.language_model._supports_sdpa
+
+
+LLAVA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
+ [`CLIPImageProcessor`] for processing images).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ vision_feature_layer (`int`, *optional*, defaults to -2):
+ The index of the layer to select the vision feature.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ """The LLAVA model which consists of a vision backbone and a language model.""",
+ LLAVA_START_DOCSTRING,
+)
+class LlavaForConditionalGeneration(LlavaPreTrainedModel):
+
+ _auto_class = 'AutoModel'
+
+ def __init__(self, config: EnhancedLlavaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModelForCausalLM.from_config(
+ config.text_config,
+ attn_implementation=config._attn_implementation)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def tie_weights(self):
+ return self.language_model.tie_weights()
+
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+ num_images, num_image_patches, embed_dim = image_features.shape
+ batch_size, sequence_length = input_ids.shape
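+ # Treat the batch as left-padded when no sequence ends with the pad token.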
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
+ # 1. Create a mask to know where special image tokens are
+ special_image_token_mask = input_ids == self.config.image_token_index
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+ # Compute the maximum embed dimension
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged image-text sequence.
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+ )
+ if labels is not None:
+ final_labels = torch.full(
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+ )
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ batch_indices, non_image_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_image_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+ attention_mask = attention_mask.to(target_device)
+
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+ if labels is not None:
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
+
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+ image_to_overwrite = torch.full(
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+ )
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+ raise ValueError(
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ final_attention_mask |= image_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+ final_embedding[batch_indices, indices_to_mask] = 0
+
+ if labels is None:
+ final_labels = None
+
+ return final_embedding, final_attention_mask, final_labels, position_ids
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[int] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # ------------- start add this ----------------
+ if pixel_values is None and self.training:
+ # The whole batch is text-only. Still run the vision tower on a dummy
+ # image so every rank executes the same modules during training;
+ # otherwise distributed training can deadlock.
+ image_size = self.config.vision_config.image_size
+ pixel_values = torch.zeros(input_ids.shape[0], 3, image_size, image_size,
+ dtype=torch.float32,
+ device=input_ids.device)
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+ # Not memory efficient: `output_hidden_states=True` keeps the hidden states of every vision layer.
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ )
+ image_features = self.multi_modal_projector(selected_image_feature)
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
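+ # Merge an empty slice of image features: the text embeddings stay unchanged,
+ # but the dummy vision forward above has still been executed.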
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+ image_features[0:0], inputs_embeds, input_ids, attention_mask, labels
+ )
+ # ------------- end add this ----------------
+ # 2. Merge text and images
+ elif pixel_values is not None and input_ids.shape[1] != 1:
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+ # Not memory efficient: `output_hidden_states=True` keeps the hidden states of every vision layer.
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ )
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+ # generation with cache
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
+ # that are set to 0
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+ # Get the target length
+ target_length = input_ids.shape[1]
+ past_length = first_layer_past_key_value.shape[-1]
+
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], past_length),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+
+ # Filter out only the tokens that can be un-attended, this can happen
+ # if one uses Llava + Fused modules where the cache on the
+ # first iteration is already big enough, or if one passes custom cache
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+ # Zero-out the places where we don't need to attend
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs[0]
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:]
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return LlavaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ elif self.config.image_token_index in input_ids:
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+ # older attention values, as their corresponding values are not part of the input.
+ if cache_length < past_length and attention_mask is not None:
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ }
+ )
+ return model_inputs
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
+
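+# Register the enhanced config with the auto classes so that `AutoModel` and
+# `AutoModelForCausalLM` (`from_config` / `from_pretrained`) resolve it to this model.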
+AutoModel.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True)
+AutoModelForCausalLM.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True)
\ No newline at end of file
diff --git a/xtuner/_lite/modelings/llava/processing_llava.py b/xtuner/_lite/modelings/llava/processing_llava.py
new file mode 100644
index 000000000..230975575
--- /dev/null
+++ b/xtuner/_lite/modelings/llava/processing_llava.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Llava.
+"""
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType
+
+
+class LlavaProcessor(ProcessorMixin):
+ r"""
+ Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
+
+ [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`CLIPImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ images: ImageInput = None,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length=None,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ if images is not None:
+ image_inputs = self.image_processor(images, return_tensors=return_tensors)
+ else:
+ image_inputs = {}
+ text_inputs = self.tokenizer(
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
+ )
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
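+
+    # Example (illustrative sketch, not part of the original file; the
+    # checkpoint paths and the `<image>` placeholder token are assumptions):
+    #
+    #   from PIL import Image
+    #   from transformers import AutoTokenizer, CLIPImageProcessor
+    #
+    #   processor = LlavaProcessor(
+    #       image_processor=CLIPImageProcessor.from_pretrained(
+    #           'openai/clip-vit-large-patch14-336'),
+    #       tokenizer=AutoTokenizer.from_pretrained('path/to/llava/tokenizer'))
+    #   batch = processor(text='<image>\nDescribe this picture.',
+    #                     images=Image.open('demo.jpg'))
+    #   # -> `input_ids`, `attention_mask` and `pixel_values`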
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
\ No newline at end of file
diff --git a/xtuner/_lite/parallel/__init__.py b/xtuner/_lite/parallel/__init__.py
new file mode 100644
index 000000000..e631d05f9
--- /dev/null
+++ b/xtuner/_lite/parallel/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .comm import all_to_all, all_to_all_list, barrier
+from .sampler import LengthGroupedSampler, ParallelSampler, VLMLengthGroupedSampler
+from .sequence import * # noqa: F401, F403
+from .setup import (get_dp_mesh, get_fsdp_mesh, get_sp_mesh, get_tp_mesh,
+ get_world_mesh, get_same_data_mesh, setup_parallel,
+ get_ep_mesh, get_experts_fsdp_mesh)
+from .utils import MetaStateful
+
+__all__ = [
+ 'ParallelSampler',
+ 'LengthGroupedSampler',
+ 'VLMLengthGroupedSampler',
+ 'all_to_all',
+ 'all_to_all_list',
+ 'get_dp_mesh',
+ 'get_same_data_mesh',
+ 'get_fsdp_mesh',
+ 'get_sp_mesh',
+ 'get_tp_mesh',
+ 'get_world_mesh',
+ 'setup_parallel',
+ 'MetaStateful',
+ 'get_ep_mesh',
+ 'get_experts_fsdp_mesh',
+ 'barrier'
+]
diff --git a/xtuner/_lite/parallel/comm.py b/xtuner/_lite/parallel/comm.py
new file mode 100644
index 000000000..47daf4fb6
--- /dev/null
+++ b/xtuner/_lite/parallel/comm.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed.distributed_c10d import (_get_pg_default_device,
+ _object_to_tensor,
+ _tensor_to_object)
+
+
+# Modified from https://github.com/microsoft/DeepSpeed/blob/ffd0a0e3ef24bfd00c2e5f35019d2674cc01ec14/deepspeed/sequence/layer.py#L15 # noqa: E501
+def _all_to_all(
+ input: Tensor,
+ world_size: int,
+ group: dist.ProcessGroup,
+ scatter_dim: int,
+ gather_dim: int,
+):
+ input_list = [
+ t.contiguous()
+ for t in torch.tensor_split(input, world_size, scatter_dim)
+ ]
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+ dist.all_to_all(output_list, input_list, group=group)
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+class _AllToAll(torch.autograd.Function):
+ """All-to-all communication.
+
+ Args:
+ input: Input tensor
+ sp_group: Sequence parallel process group
+ scatter_dim: Scatter dimension
+ gather_dim: Gather dimension
+ """
+
+ @staticmethod
+ def forward(ctx: Any, input: Tensor, sp_group: dist.ProcessGroup,
+ scatter_dim: int, gather_dim: int):
+ ctx.sp_group = sp_group
+ ctx.scatter_dim = scatter_dim
+ ctx.gather_dim = gather_dim
+ ctx.world_size = dist.get_world_size(sp_group)
+ output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim,
+ gather_dim)
+ return output
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple:
+ grad_output = _all_to_all(
+ grad_output,
+ ctx.world_size,
+ ctx.sp_group,
+ ctx.gather_dim,
+ ctx.scatter_dim,
+ )
+ return (
+ grad_output,
+ None,
+ None,
+ None,
+ )
+
+
+def all_to_all(
+ input: Tensor,
+ sp_group: dist.ProcessGroup,
+ scatter_dim: int = 2,
+ gather_dim: int = 1,
+):
+ """Convenience function to apply the all-to-all operation with scatter and
+ gather dimensions.
+
+ Notes:
+ We have wrapped the `torch.distributed.all_to_all` function to
+ enable automatic differentiation of the all-to-all operation.
+
+ Args:
+ input: The input tensor for which all-to-all communication is performed
+ sp_group: The sequence parallel process group.
+ scatter_dim: The dimension along which the input tensor is scattered
+ (default: 2).
+ gather_dim: The dimension along which the output tensor is gathered
+ (default: 1).
+
+ Returns:
+ The output tensor after the all-to-all communication.
+ """
+ return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
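+
+
+# Example (sketch, not part of the original module): in DeepSpeed-Ulysses
+# style sequence parallelism an attention input of shape
+# (batch, seq_len // sp, num_heads, head_dim) is re-partitioned so that each
+# rank sees the full sequence but only a slice of the heads, and back again
+# after attention. `sp_group` is assumed to come from this package's
+# `get_sp_mesh().get_group()`:
+#
+#   # (b, s // sp, h, d) -> (b, s, h // sp, d)
+#   q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
+#   ...
+#   # (b, s, h // sp, d) -> (b, s // sp, h, d)
+#   out = all_to_all(attn_out, sp_group, scatter_dim=1, gather_dim=2)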
+
+
+def all_to_all_list(object_list, group=None):
+ current_device = _get_pg_default_device(group)
+ rank = dist.get_rank(group)
+ world_size = dist.get_world_size(group)
+    tensor_list, size_list = zip(*[
+        _object_to_tensor(obj, current_device, group) for obj in object_list
+    ])
+    tensor_list = list(tensor_list)
+    size_list = torch.cat(size_list)
+    buffer = [None] * world_size
+
+    dist.all_gather_object(buffer, size_list, group=group)
+    size_this_rank = []
+    for sizes in buffer:
+        size_this_rank.append(sizes[rank])
+
+ target_tensor_list = [
+ torch.empty(size.item(), dtype=torch.uint8, device=current_device)
+ for size in size_this_rank
+ ]
+ dist.all_to_all(target_tensor_list, tensor_list, group=group)
+
+ for i in range(len(target_tensor_list)):
+ obj_view = target_tensor_list[i].type(torch.uint8)
+ target_tensor_list[i] = _tensor_to_object(obj_view, size_this_rank[i],
+ group)
+
+ return target_tensor_list
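+
+
+# Example (sketch, assuming the default process group was initialised via
+# `torchrun`): every rank sends one picklable object to every other rank and
+# receives one object from each of them.
+#
+#   rank, world = dist.get_rank(), dist.get_world_size()
+#   send = [{'src': rank, 'dst': dst} for dst in range(world)]
+#   recv = all_to_all_list(send)
+#   # recv[i] == {'src': i, 'dst': rank}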
+
+
+def barrier():
+    """Synchronize all ranks by broadcasting a tiny object from rank 0."""
+    if not (dist.is_available() and dist.is_initialized()):
+        return
+
+    if dist.get_rank() == 0:
+        objects = [1]
+    else:
+        objects = [None]
+
+    dist.broadcast_object_list(objects, src=0)
diff --git a/xtuner/_lite/parallel/device.py b/xtuner/_lite/parallel/device.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/xtuner/_lite/parallel/fsdp/__init__.py b/xtuner/_lite/parallel/fsdp/__init__.py
new file mode 100644
index 000000000..4f5d2b80f
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/__init__.py
@@ -0,0 +1,11 @@
+from .checkpointing import RECOMPUTE_MODULES, checkpoint_check_fn, checkpoint
+from .lazy import LoadWoInit, dp_lazy_init, dp_sp_lazy_init, lazy_init_megatron
+from .wrap import (all_required_grad_wrap_policy, layer_and_emb_wrap_policy,
+ layer_auto_wrap_policy, token_embedding_wrap_policy)
+from .clip_grad import clip_grad_norm_
+__all__ = [
+ 'RECOMPUTE_MODULES', 'checkpoint_check_fn', 'LoadWoInit', 'dp_lazy_init',
+ 'all_required_grad_wrap_policy', 'layer_auto_wrap_policy',
+ 'token_embedding_wrap_policy', 'lazy_init_megatron', 'dp_sp_lazy_init',
+ 'layer_and_emb_wrap_policy', 'checkpoint'
+]
diff --git a/xtuner/_lite/parallel/fsdp/checkpointing.py b/xtuner/_lite/parallel/fsdp/checkpointing.py
new file mode 100644
index 000000000..1fe998987
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/checkpointing.py
@@ -0,0 +1,92 @@
+import random
+from typing import Any, Tuple
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import (
+ _checkpoint_without_reentrant_generator,
+ _DEFAULT_DETERMINISM_MODE,
+)
+from contextlib import nullcontext, contextmanager
+from torch.distributed._composable.contract import contract
+
+
+RECOMPUTE_MODULES = ('InternLM2DecoderLayer', 'CLIPEncoderLayer')
+
+
+def checkpoint_check_fn(submodule, target=RECOMPUTE_MODULES, selective=1.0):
+ ret = False
+ if type(submodule).__name__ in target:
+ if random.uniform(0, 1) < selective:
+ ret = True
+ return ret
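+
+
+# Example (sketch): `checkpoint_check_fn` matches the `check_fn` signature of
+# PyTorch's activation-checkpoint wrapper; the 0.5 ratio below is only an
+# illustration of recomputing roughly half of the matched layers.
+#
+#   from functools import partial
+#   from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+#       apply_activation_checkpointing)
+#
+#   apply_activation_checkpointing(
+#       model, check_fn=partial(checkpoint_check_fn, selective=0.5))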
+
+
+@contextmanager
+def _no_hook(module: nn.Module):
+ r"""
+ Disable hooks installed by checkpoint to avoid unintentional recursion
+ during backward recomputation.
+ """
+ orig_enable_hook = checkpoint.state(module).enable_hook
+ checkpoint.state(module).enable_hook = False
+ try:
+ yield
+ finally:
+ checkpoint.state(module).enable_hook = orig_enable_hook
+
+
+# Support **kwargs
+@contract()
+def checkpoint(module: nn.Module) -> nn.Module:
+ torch._C._log_api_usage_once("torch.distributed.checkpoint")
+
+ def forward_pre_hook(module: nn.Module, *args) -> None:
+ if checkpoint.state(module).enable_hook:
+ def context_fns():
+ return nullcontext(), _no_hook(module)
+
+ checkpoint.state(
+ module
+ )._ac_generator = _checkpoint_without_reentrant_generator(
+ module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *args[0], **args[1]
+ )
+ next(checkpoint.state(module)._ac_generator)
+
+ def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
+ if checkpoint.state(module).enable_hook:
+ try:
+ next(checkpoint.state(module)._ac_generator)
+ except StopIteration:
+ pass
+ else:
+ raise RuntimeError(
+ "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
+ )
+
+ # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
+ # clear this even in the case of exception in fwd pass.
+ checkpoint.state(module)._ac_generator = None
+
+ checkpoint.state(module).enable_hook = True
+ module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+ module.register_forward_hook(forward_hook, prepend=True, always_call=True)
+ return module
+
+
+if __name__ == '__main__':
+
+ class MyModel(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.l1 = nn.Linear(10, 10)
+ self.l2 = nn.Linear(10, 10)
+
+ def forward(self, x, b, a=4, c=4):
+ print(b, a, c)
+ return self.l2(self.l1(x))
+
+ # from torch.distributed._composable.checkpoint_activation import checkpoint
+ model = MyModel()
+    checkpoint(model)  # apply activation checkpointing to the whole model
+ model(torch.zeros(2, 10), 2, a=5, c=6).sum().backward()
diff --git a/xtuner/_lite/parallel/fsdp/clip_grad.py b/xtuner/_lite/parallel/fsdp/clip_grad.py
new file mode 100644
index 000000000..a6cecbc95
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/clip_grad.py
@@ -0,0 +1,88 @@
+from torch.nn.utils.clip_grad import _no_grad
+import torch
+from typing import List, Optional, Tuple, Union, Dict
+from torch import Tensor
+from torch import distributed as dist
+from torch.utils._foreach_utils import (
+ _device_has_foreach_support,
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
+@_no_grad
+def clip_grad_norm_(
+ parameters,
+ fsdp_mesh,
+ max_norm: float,
+ norm_type: float = 2.0,
+ error_if_nonfinite: bool = False,
+    foreach=None,
+) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.0)
+ first_device = grads[0].device
+
+ grouped_grads: Dict[
+ Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+ ] = _group_tensors_by_device_and_dtype(
+ [grads]
+ ) # type: ignore[assignment]
+
+ norms: List[Tensor] = []
+ for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
+ if (foreach is None and _has_foreach_support(device_grads, device)) or (
+ foreach and _device_has_foreach_support(device)
+ ):
+ # for grouped_device_grads in group_tensors_by_device_mesh(device_grads).values():
+ norms.extend(torch._foreach_norm(device_grads, norm_type))
+ elif foreach:
+ raise RuntimeError(
+ f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+ )
+ else:
+ norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+
+ local_sharded_norm = torch.linalg.vector_norm(
+ torch.stack([norm.to_local().to(first_device) for norm in norms]), norm_type, dtype=torch.float32
+ )
+
+ if norm_type == 2:
+ total_norm = local_sharded_norm**norm_type
+ dist.all_reduce(total_norm, group=fsdp_mesh.get_group(mesh_dim=0))
+ total_norm = total_norm ** (1 / norm_type)
+ else:
+ raise NotImplementedError
+
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f"The total norm of order {norm_type} for gradients from "
+ "`parameters` is non-finite, so it cannot be clipped. To disable "
+ "this error and scale the gradients by the non-finite norm anyway, "
+ "set `error_if_nonfinite=False`"
+ )
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
+ if (foreach is None and _has_foreach_support(device_grads, device)) or (
+ foreach and _device_has_foreach_support(device)
+ ):
+ torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+ elif foreach:
+ raise RuntimeError(
+ f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+ )
+ else:
+ clip_coef_clamped_device = clip_coef_clamped.to(device)
+ for g in device_grads:
+ g.mul_(clip_coef_clamped_device.to(g.dtype))
+
+ return total_norm
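+
+
+# Example (sketch, names are placeholders): clip the gradients of a model
+# whose parameters were sharded with `fully_shard` (so the gradients are
+# DTensors over `fsdp_mesh`), then step the optimizer.
+#
+#   grad_norm = clip_grad_norm_(model.parameters(), fsdp_mesh, max_norm=1.0)
+#   optimizer.step()
+#   optimizer.zero_grad()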
diff --git a/xtuner/_lite/parallel/fsdp/lazy.py b/xtuner/_lite/parallel/fsdp/lazy.py
new file mode 100644
index 000000000..149f141fb
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/lazy.py
@@ -0,0 +1,153 @@
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, distribute_tensor
+
+from xtuner._lite import get_logger, get_torch_device_module
+
+logger = get_logger()
+
+DEVICE_MODULE = get_torch_device_module()
+
+
+@torch.no_grad
+def dp_lazy_init(module, module_map, dp_mesh):
+ device = DEVICE_MODULE.current_device()
+ module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False)
+
+ if dp_mesh.get_local_rank() == 0:
+ master_module = module_map[module]
+ master_params = {
+ name: param
+ for name, param in master_module.named_parameters(recurse=False)
+ }
+ master_buffers = {
+ name: buffer
+ for name, buffer in master_module.named_buffers(recurse=False)
+ }
+
+ for name, param in module.named_parameters(recurse=False):
+
+ p_copy = master_params[name].to(device).to(param.dtype)
+ # if param.requires_grad:
+ # p_copy = p_copy.to(device).to(param.dtype)
+ # else:
+ # p_copy = p_copy.to(device)
+ param.data.copy_(p_copy)
+
+ for name, buffer in module.named_buffers(recurse=False):
+
+ b_copy = master_buffers[name].to(device).to(buffer.dtype)
+ # b_copy = b_copy.to(device)
+ buffer.data.copy_(b_copy)
+
+
+@torch.no_grad
+def dp_sp_lazy_init(module, module_map, dp_mesh, sp_mesh):
+ device = DEVICE_MODULE.current_device()
+ module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False)
+
+ if dp_mesh.get_local_rank() == 0 and sp_mesh.get_local_rank() == 0:
+ master_module = module_map[module]
+ master_params = {
+ name: param
+ for name, param in master_module.named_parameters(recurse=False)
+ }
+ master_buffers = {
+ name: buffer
+ for name, buffer in master_module.named_buffers(recurse=False)
+ }
+
+ for name, param in module.named_parameters(recurse=False):
+ p_copy = master_params[name].to(device).to(param.dtype)
+ param.data.copy_(p_copy)
+
+ for name, buffer in module.named_buffers(recurse=False):
+ b_copy = master_buffers[name].to(device).to(buffer.dtype)
+ buffer.data.copy_(b_copy)
+
+
+@torch.no_grad
+def lazy_init_megatron(module, rank0_map, dp_mesh, tp_mesh=None, pp_mesh=None):
+ device = DEVICE_MODULE.current_device()
+
+ if dp_mesh.get_rank() == 0:
+ rank0_module = rank0_map[module]
+ rank0_params = {
+ name: param
+ for name, param in rank0_module.named_parameters(recurse=False)
+ }
+ rank0_buffers = {
+ name: buffer
+ for name, buffer in rank0_module.named_buffers(recurse=False)
+ }
+ else:
+ rank0_params = None
+ rank0_buffers = None
+
+ param_shapes = {
+ name: param.full_tensor().shape
+ if isinstance(param, DTensor) else param.shape
+ for name, param in module.named_parameters(recurse=False)
+ }
+
+ module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False)
+
+ for name, param in module.named_parameters(recurse=False):
+ dtype = param.dtype
+ if dp_mesh.get_rank() == 0:
+ rank0_param = rank0_params[name].to(device).to(dtype)
+ else:
+ full_shape = param_shapes[name]
+ rank0_param = torch.zeros(full_shape, dtype=dtype, device=device)
+
+ dist.broadcast(rank0_param, src=0)
+
+ if isinstance(param, DTensor):
+ mesh = param.device_mesh
+ assert mesh == tp_mesh
+ placements = param.placements
+ rank0_param = distribute_tensor(rank0_param, mesh, placements)
+
+ param.data.copy_(rank0_param)
+ dist.barrier()
+
+ # TP does not shard buffers
+ for name, buffer in module.named_buffers(recurse=False):
+ if dp_mesh.get_rank() == 0:
+ rank0_buffer = rank0_buffers[name].to(device).to(buffer.dtype)
+ else:
+ rank0_buffer = torch.empty_like(buffer).to(device)
+
+ dist.broadcast(rank0_buffer, src=0)
+ buffer.data.copy_(rank0_buffer)
+
+
+class LoadWoInit:
+ """Context manager that disable parameter initialization."""
+
+ def __init__(self):
+ self.constant_ = torch.nn.init.constant_
+ self.zeros_ = torch.nn.init.zeros_
+ self.ones_ = torch.nn.init.ones_
+ self.uniform_ = torch.nn.init.uniform_
+ self.normal_ = torch.nn.init.normal_
+ self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+ self.kaiming_normal_ = torch.nn.init.kaiming_normal_
+
+ def __enter__(self, *args, **kwargs):
+ torch.nn.init.constant_ = lambda *args, **kwargs: None
+ torch.nn.init.zeros_ = lambda *args, **kwargs: None
+ torch.nn.init.ones_ = lambda *args, **kwargs: None
+ torch.nn.init.uniform_ = lambda *args, **kwargs: None
+ torch.nn.init.normal_ = lambda *args, **kwargs: None
+ torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
+ torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
+
+ def __exit__(self, *args, **kwargs):
+ torch.nn.init.constant_ = self.constant_
+ torch.nn.init.zeros_ = self.zeros_
+ torch.nn.init.ones_ = self.ones_
+ torch.nn.init.uniform_ = self.uniform_
+ torch.nn.init.normal_ = self.normal_
+ torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
+ torch.nn.init.kaiming_normal_ = self.kaiming_normal_
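+
+
+# Example (sketch; the checkpoint name is a placeholder): skip the pointless
+# random weight initialisation when the weights are loaded from a checkpoint
+# immediately afterwards.
+#
+#   from transformers import AutoModelForCausalLM
+#
+#   with LoadWoInit():
+#       model = AutoModelForCausalLM.from_pretrained('internlm/internlm2-7b')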
diff --git a/xtuner/_lite/parallel/fsdp/precision.py b/xtuner/_lite/parallel/fsdp/precision.py
new file mode 100644
index 000000000..b8d1e7249
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/precision.py
@@ -0,0 +1,23 @@
+import torch
+from torch import nn
+
+
+def set_require_grad_param_to_fp32(model: nn.Module):
+
+ def traverse(module: nn.Module):
+
+ for m_name, child in module.named_children():
+
+ all_require_grad = True
+ for p_name, param in child.named_parameters():
+
+ if not param.requires_grad:
+ all_require_grad = False
+ break
+
+ if all_require_grad:
+ child.to(torch.float32)
+
+ traverse(child)
+
+ traverse(model)
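+
+
+# Example (sketch): with partially frozen models (e.g. LoRA-style fine-tuning)
+# only the fully trainable sub-modules are promoted to fp32, while the frozen
+# ones stay in their half-precision dtype.
+#
+#   set_require_grad_param_to_fp32(model)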
diff --git a/xtuner/_lite/parallel/fsdp/wrap.py b/xtuner/_lite/parallel/fsdp/wrap.py
new file mode 100644
index 000000000..f1d97c071
--- /dev/null
+++ b/xtuner/_lite/parallel/fsdp/wrap.py
@@ -0,0 +1,80 @@
+from torch import nn
+
+from xtuner._lite import get_logger
+
+logger = get_logger()
+_LAYERS = [
+ 'InternLM2DecoderLayer', 'CLIPVisionModel', 'LlavaMultiModalProjector'
+]
+
+
+def layer_auto_wrap_policy(
+ module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ layer_cls=_LAYERS,
+) -> bool:
+ if recurse:
+ # always recurse
+ return True
+ else:
+        # if not recursing, decide whether we should wrap
+        # the leaf node or the remainder
+ return module.__class__.__name__ in layer_cls
+
+
+def layer_and_emb_wrap_policy(
+ module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ vocab_size,
+ layer_cls=_LAYERS,
+) -> bool:
+ if recurse:
+ # always recurse
+ return True
+ else:
+        # if not recursing, decide whether we should wrap
+        # the leaf node or the remainder
+ if module.__class__.__name__ in layer_cls or isinstance(
+ module, nn.Embedding):
+ return True
+ elif isinstance(module, nn.Linear):
+ return module.weight.size(0) == vocab_size
+ else:
+ return False
+
+
+def token_embedding_wrap_policy(
+ module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ vocab_size: int,
+) -> bool:
+ if recurse:
+ # always recurse
+ return True
+
+ if isinstance(module, (nn.Embedding, nn.Linear)):
+ if module.weight.size(0) == vocab_size:
+ return True
+
+ return False
+
+
+def all_required_grad_wrap_policy(
+ module,
+ recurse: bool,
+ nonwrapped_numel: int,
+) -> bool:
+ if recurse:
+ # always recurse
+ return True
+
+ requires_grads = [p.requires_grad for p in module.parameters()]
+
+ if len(requires_grads) and all(requires_grads):
+ logger.debug(module)
+ return True
+
+ return False
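+
+
+# Example (sketch): these callables follow the FSDP `auto_wrap_policy`
+# signature `(module, recurse, nonwrapped_numel)`; extra arguments such as the
+# vocab size are bound with `functools.partial`. The value 92544 is only an
+# illustration.
+#
+#   from functools import partial
+#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+#
+#   policy = partial(layer_and_emb_wrap_policy, vocab_size=92544)
+#   model = FSDP(model, auto_wrap_policy=policy)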
diff --git a/xtuner/_lite/parallel/loss.py b/xtuner/_lite/parallel/loss.py
new file mode 100644
index 000000000..44fd286f8
--- /dev/null
+++ b/xtuner/_lite/parallel/loss.py
@@ -0,0 +1,18 @@
+from torch import distributed as dist
+from torch.distributed.nn.functional import all_reduce
+
+
+def dist_softmax(logits, mesh, dim=-1, temperature=1):
+    """Softmax over a dimension of `logits` that is sharded across `mesh`."""
+    logits = logits / temperature
+
+    # `Tensor.max(dim, ...)` returns a (values, indices) namedtuple; only the
+    # values are needed here. The autograd-aware `all_reduce` from
+    # `torch.distributed.nn.functional` returns a new tensor instead of
+    # reducing in place, so the result must be assigned back.
+    max_values = logits.max(dim, keepdim=True).values
+    max_values = all_reduce(max_values, dist.ReduceOp.MAX, mesh.group())
+
+    exps = (logits - max_values).exp()
+    sums = exps.sum(dim, keepdim=True)
+    sums = all_reduce(sums, dist.ReduceOp.SUM, mesh.group())
+
+    probs = exps / sums
+
+    return probs
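+
+
+# Example (sketch; `sp_mesh` is a placeholder): vocabulary-parallel softmax
+# where `dim` of the logits is sharded across the ranks of the mesh; each rank
+# passes its local shard and gets back the matching shard of the globally
+# normalised probabilities.
+#
+#   probs_local = dist_softmax(logits_local, sp_mesh, dim=-1, temperature=1.0)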
diff --git a/xtuner/_lite/parallel/megatron/__init__.py b/xtuner/_lite/parallel/megatron/__init__.py
new file mode 100644
index 000000000..87e746247
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/__init__.py
@@ -0,0 +1,52 @@
+from .internlm2 import (megatron_internlm2, megatron_internlm2_casual,
+ megatron_internlm2_reward)
+
+from .qwen2 import (megatron_qwen2_casual, megatron_qwen2, megatron_qwen2_reward)
+from .internvl2 import megatron_internvl2_casual
+from .minicpmv import megatron_minicpmv_casual
+from .llama import megatron_llama, megatron_llama_casual
+from .janus import megatron_janus_casual
+
+MEGATRON_MAP = {
+ 'InternLM2ForCausalLM': megatron_internlm2_casual,
+ 'InternLM2ForRewardModel': megatron_internlm2_reward,
+ 'InternLM2Model': megatron_internlm2,
+ 'Qwen2ForCausalLM': megatron_qwen2_casual,
+ 'Qwen2Model': megatron_qwen2,
+ 'Qwen2ForRewardModel': megatron_qwen2_reward,
+ 'InternVLChatModel': megatron_internvl2_casual,
+ 'MiniCPMV': megatron_minicpmv_casual,
+ 'MultiModalityCausalLM': megatron_janus_casual,
+ 'LlamaModel': megatron_llama,
+ 'LlamaForCausalLM': megatron_llama_casual
+}
+
+
+def megatron_parallelize(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True,
+ **kwargs):
+
+    cls_name = model.__class__.__name__
+    if cls_name not in MEGATRON_MAP:
+        raise NotImplementedError(
+            f'`megatron_parallelize` does not support `{cls_name}` yet. '
+            f'Supported classes: {list(MEGATRON_MAP)}')
+
+ parallel_fn = MEGATRON_MAP[cls_name]
+
+ model = parallel_fn(
+ model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward,
+ **kwargs)
+
+ return model
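+
+
+# Example (sketch; every name below is illustrative): `meta_model` is built on
+# the `meta` device on all ranks, `rank0_model` holds the real weights on data
+# parallel rank 0 only, and the meshes come from `setup_parallel()` in this
+# package.
+#
+#   import torch
+#   from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+#   from xtuner._lite.parallel import get_fsdp_mesh, get_tp_mesh
+#
+#   model = megatron_parallelize(
+#       meta_model, rank0_model,
+#       dp_mesh=get_fsdp_mesh(), tp_mesh=get_tp_mesh(),
+#       mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16,
+#                                      reduce_dtype=torch.float32),
+#       recompute_ratio=1.0)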
diff --git a/xtuner/_lite/parallel/megatron/internlm2.py b/xtuner/_lite/parallel/megatron/internlm2.py
new file mode 100644
index 000000000..0da9437ff
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/internlm2.py
@@ -0,0 +1,303 @@
+from functools import partial
+from packaging import version
+
+import torch
+from torch import nn
+from torch.distributed._tensor import Replicate, distribute_tensor
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+ PrepareModuleInput,
+ PrepareModuleOutput,
+ RowwiseParallel,
+ parallelize_module)
+
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+
+logger = get_logger()
+
+
+def _tp_internlm2(model, tp_mesh):
+
+ layer_tp_plan = {
+ # by default ColwiseParallel input layouts is replicated
+ # and RowwiseParallel output layouts is replicated
+ 'attention.wqkv':
+ ColwiseParallel(),
+ 'attention.wo':
+ RowwiseParallel(),
+ 'attention_norm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ ),
+ 'feed_forward.w1':
+ ColwiseParallel(),
+ 'feed_forward.w2':
+ RowwiseParallel(),
+ 'feed_forward.w3':
+ ColwiseParallel(),
+ 'ffn_norm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ )
+ }
+
+ tp_size = tp_mesh.size()
+ for layer in model.layers:
+ attention = layer.attention
+ num_key_value_heads = attention.num_key_value_heads
+ num_heads = attention.num_heads
+ hidden_size = attention.hidden_size
+
+ attention.num_heads = num_heads // tp_size
+ attention.num_key_value_heads = num_key_value_heads // tp_size
+ attention.hidden_size = hidden_size // tp_size
+
+ attn_norm = layer.attention_norm
+ attn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()])))
+
+ ffn_norm = layer.ffn_norm
+ ffn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()])))
+
+ parallelize_module(
+ module=layer,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_tp_plan,
+ )
+ norm = model.norm
+ dist_norm_w = nn.Parameter(
+ distribute_tensor(norm.weight, tp_mesh, [Replicate()]))
+ norm.register_parameter('weight', dist_norm_w)
+
+ # emb = model.tok_embeddings
+ # dist_emb_w = nn.Parameter(
+ # distribute_tensor(emb.weight, tp_mesh, [Replicate()]))
+ # emb.register_parameter('weight', dist_emb_w)
+
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'model.tok_embeddings':
+ RowwiseParallel(input_layouts=Replicate(), ),
+ 'model.norm':PrepareModuleInput(
+ input_layouts=(Replicate(),),
+ desired_input_layouts=(Replicate(),),
+ use_local_output=True
+ ),
+ })
+
+
+def megatron_internlm2(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ if tp_mesh and tp_mesh.size() > 1:
+ _tp_internlm2(model, tp_mesh)
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+ from torch.distributed._composable import checkpoint
+ from torch.distributed._composable.fsdp import fully_shard
+ num_layers = len(model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+
+ for i, block in enumerate(model.layers):
+
+ block.apply(param_init_fn)
+
+ # # As an optimization, do not reshard after forward for the last
+ # # transformer block since FSDP would prefetch it immediately
+ # if i < num_layers - 1:
+ # _reshard = reshard_after_forward
+ # else:
+ # _reshard = False
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ if version.parse(torch.__version__) >= version.parse("2.5.0"):
+ for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]):
+ layer_cur.set_modules_to_forward_prefetch([layer_next])
+
+ model.tok_embeddings.apply(param_init_fn)
+ model.norm.apply(param_init_fn)
+
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+
+def megatron_internlm2_casual(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+ megatron_internlm2(
+ model.model,
+ rank0_model.model if dp_mesh.get_rank() == 0 else None,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward)
+
+ if tp_mesh and tp_mesh.size() > 1:
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'output': ColwiseParallel(output_layouts=Replicate(), ),
+ })
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+ model.output.apply(param_init_fn)
+
+ from torch.distributed._composable.fsdp import fully_shard
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+
+def megatron_internlm2_reward(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+ megatron_internlm2(
+ model.model,
+ rank0_model.model if dp_mesh.get_rank() == 0 else None,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward)
+
+ if tp_mesh and tp_mesh.size() > 1:
+
+ head_0 = model.v_head[0]
+ dist_head_0 = nn.Parameter(
+ distribute_tensor(head_0.weight, tp_mesh, [Replicate()]))
+ head_0.register_parameter('weight', dist_head_0)
+
+ head_norm = model.v_head[1]
+ dist_head_norm = nn.Parameter(
+ distribute_tensor(head_norm.weight, tp_mesh, [Replicate()]))
+ dist_head_bias = nn.Parameter(
+ distribute_tensor(head_norm.bias, tp_mesh, [Replicate()]))
+ head_norm.register_parameter('weight', dist_head_norm)
+ head_norm.register_parameter('bias', dist_head_bias)
+
+ head_1 = model.v_head[3]
+ dist_head_1 = nn.Parameter(
+ distribute_tensor(head_1.weight, tp_mesh, [Replicate()]))
+ head_1.register_parameter('weight', dist_head_1)
+
+        # Note: a dict literal cannot register two plans under the same key
+        # (duplicate keys silently overwrite each other), so each head
+        # sub-module only gets its output preparation here.
+        parallelize_module(
+            module=model,
+            device_mesh=tp_mesh,
+            parallelize_plan={
+                'v_head.0': PrepareModuleOutput(
+                    output_layouts=(Replicate(),),
+                    desired_output_layouts=(Replicate(),),
+                    use_local_output=True
+                ),
+                'v_head.1': PrepareModuleOutput(
+                    output_layouts=(Replicate(),),
+                    desired_output_layouts=(Replicate(),),
+                    use_local_output=True
+                ),
+                'v_head.3': PrepareModuleOutput(
+                    output_layouts=(Replicate(),),
+                    desired_output_layouts=(Replicate(),),
+                    use_local_output=True
+                ),
+            })
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+ model.v_head.apply(param_init_fn)
+
+ from torch.distributed._composable.fsdp import fully_shard
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
diff --git a/xtuner/_lite/parallel/megatron/internvl2.py b/xtuner/_lite/parallel/megatron/internvl2.py
new file mode 100644
index 000000000..0796a033b
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/internvl2.py
@@ -0,0 +1,122 @@
+from functools import partial
+from torch.distributed._composable.fsdp import fully_shard
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+from ..fsdp import checkpoint
+
+logger = get_logger()
+
+
+def megatron_internvl2_casual(meta_model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True,
+ small_model=True): # if 70b model, set to False
+    if tp_mesh and tp_mesh.size() > 1:
+ raise NotImplementedError
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(meta_model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+ num_layers = len(meta_model.language_model.model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+ for i, block in enumerate(meta_model.language_model.model.layers):
+ block.apply(param_init_fn)
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ has_forward_prefetch = hasattr(meta_model.language_model.model.layers[0], 'set_modules_to_forward_prefetch')
+ if has_forward_prefetch:
+ for layer_cur, layer_next in zip(meta_model.language_model.model.layers[:-1],
+ meta_model.language_model.model.layers[1:]):
+ layer_cur.set_modules_to_forward_prefetch([layer_next])
+
+ if small_model:
+ meta_model.vision_model.apply(param_init_fn)
+ fully_shard(
+ meta_model.vision_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ for i, layers in enumerate(meta_model.vision_model.encoder.layers):
+ checkpoint(layers)
+
+ if has_forward_prefetch:
+ meta_model.vision_model.set_modules_to_forward_prefetch([meta_model.language_model.model.layers[0]])
+ else:
+ # visual
+ for i, block in enumerate(meta_model.vision_model.encoder.layers):
+ block.apply(param_init_fn)
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ checkpoint(block)
+
+ if has_forward_prefetch:
+ for layer_cur, layer_next in zip(meta_model.vision_model.encoder.layers[:-1],
+ meta_model.vision_model.encoder.layers[1:]):
+ layer_cur.set_modules_to_forward_prefetch([layer_next])
+
+ meta_model.vision_model.encoder.layers[-1].set_modules_to_forward_prefetch([meta_model.language_model.model.layers[0]])
+
+ meta_model.vision_model.embeddings.apply(param_init_fn)
+ fully_shard(
+ meta_model.vision_model.embeddings,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ if has_forward_prefetch:
+ meta_model.vision_model.embeddings.set_modules_to_forward_prefetch([meta_model.vision_model.encoder.layers[0]])
+
+ meta_model.mlp1.apply(param_init_fn)
+ try:
+ meta_model.language_model.model.tok_embeddings.apply(param_init_fn)
+ except AttributeError:
+ meta_model.language_model.model.embed_tokens.apply(param_init_fn)
+
+ meta_model.language_model.model.norm.apply(param_init_fn)
+ try:
+ meta_model.language_model.output.apply(param_init_fn)
+ except AttributeError:
+ meta_model.language_model.lm_head.apply(param_init_fn)
+
+ model = fully_shard(
+ meta_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3
+
+ if has_forward_prefetch and small_model:
+ model.set_modules_to_forward_prefetch([model.vision_model])
+ elif has_forward_prefetch:
+ model.set_modules_to_forward_prefetch([model.vision_model.embeddings])
+ return model
diff --git a/xtuner/_lite/parallel/megatron/janus.py b/xtuner/_lite/parallel/megatron/janus.py
new file mode 100644
index 000000000..e8eec0477
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/janus.py
@@ -0,0 +1,125 @@
+from functools import partial
+from torch.distributed._composable.fsdp import fully_shard
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+from ..fsdp import checkpoint
+
+logger = get_logger()
+
+
+def megatron_janus_casual(meta_model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True,
+ freeze_style='mode1'):
+    if tp_mesh and tp_mesh.size() > 1:
+ raise NotImplementedError
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(meta_model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+ num_layers = len(meta_model.language_model.model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+ for i, block in enumerate(meta_model.language_model.model.layers):
+ block.apply(param_init_fn)
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ if freeze_style == 'mode1':
+ meta_model.language_model.lm_head.apply(param_init_fn)
+ fully_shard(
+ meta_model.language_model.lm_head,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ meta_model.gen_head.apply(param_init_fn)
+ fully_shard(
+ meta_model.gen_head,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ meta_model.aligner.apply(param_init_fn)
+ fully_shard(
+ meta_model.aligner,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ meta_model.gen_aligner.apply(param_init_fn)
+ fully_shard(
+ meta_model.gen_aligner,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ meta_model.vision_model.apply(param_init_fn)
+ meta_model.gen_vision_model.apply(param_init_fn)
+ meta_model.gen_embed.apply(param_init_fn)
+ meta_model.language_model.model.embed_tokens.apply(param_init_fn)
+ meta_model.language_model.model.norm.apply(param_init_fn)
+
+ model = fully_shard(
+ meta_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3
+
+ # TODO: Bug
+ # model.set_reshard_after_backward(False)
+
+ elif freeze_style == 'mode2':
+ meta_model.gen_vision_model.apply(param_init_fn)
+ fully_shard(
+ meta_model.gen_vision_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+ meta_model.gen_vision_model.set_reshard_after_backward(False)
+
+ meta_model.vision_model.apply(param_init_fn)
+ fully_shard(
+ meta_model.vision_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+ meta_model.vision_model.set_reshard_after_backward(False)
+
+ meta_model.gen_head.apply(param_init_fn)
+ meta_model.aligner.apply(param_init_fn)
+ meta_model.gen_aligner.apply(param_init_fn)
+ meta_model.gen_embed.apply(param_init_fn)
+ meta_model.language_model.model.embed_tokens.apply(param_init_fn)
+ meta_model.language_model.model.norm.apply(param_init_fn)
+ meta_model.language_model.lm_head.apply(param_init_fn)
+ model = fully_shard(
+ meta_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+ return model
diff --git a/xtuner/_lite/parallel/megatron/llama.py b/xtuner/_lite/parallel/megatron/llama.py
new file mode 100644
index 000000000..e100e88f0
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/llama.py
@@ -0,0 +1,216 @@
+from functools import partial
+from packaging import version
+
+import torch
+from torch import nn
+from torch.distributed._tensor import Replicate, distribute_tensor
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+ PrepareModuleInput,
+ RowwiseParallel,
+ parallelize_module)
+
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+
+logger = get_logger()
+
+
+def _tp_llama(model, tp_mesh):
+
+ layer_tp_plan = {
+
+ # by default ColwiseParallel input layouts is replicated
+ # and RowwiseParallel output layouts is replicated
+ 'self_attn.q_proj':
+ ColwiseParallel(),
+ 'self_attn.k_proj':
+ ColwiseParallel(),
+ 'self_attn.v_proj':
+ ColwiseParallel(),
+ 'self_attn.o_proj':
+ RowwiseParallel(),
+ 'input_layernorm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ ),
+ 'mlp.up_proj':
+ ColwiseParallel(),
+ 'mlp.down_proj':
+ RowwiseParallel(),
+ 'mlp.gate_proj':
+ ColwiseParallel(),
+ 'post_attention_layernorm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ )
+ }
+
+ tp_size = tp_mesh.size()
+ for layer in model.layers:
+ attention = layer.self_attn
+ num_key_value_heads = attention.num_key_value_heads
+ num_heads = attention.num_heads
+ hidden_size = attention.hidden_size
+
+ attention.num_heads = num_heads // tp_size
+ attention.num_key_value_heads = num_key_value_heads // tp_size
+ attention.hidden_size = hidden_size // tp_size
+
+ attn_norm = layer.input_layernorm
+ attn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()])))
+
+ ffn_norm = layer.post_attention_layernorm
+ ffn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()])))
+
+ parallelize_module(
+ module=layer,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_tp_plan,
+ )
+
+ norm = model.norm
+ dist_norm_w = nn.Parameter(
+ distribute_tensor(norm.weight, tp_mesh, [Replicate()]))
+ norm.register_parameter('weight', dist_norm_w)
+
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'model.embed_tokens':
+ RowwiseParallel(input_layouts=Replicate(), ),
+ 'model.norm':PrepareModuleInput(
+ input_layouts=(Replicate(),),
+ desired_input_layouts=(Replicate(),),
+ use_local_output=True
+ ),
+ })
+
+
+def megatron_llama(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+
+ if tp_mesh and tp_mesh.size() > 1:
+ _tp_llama(model, tp_mesh)
+
+ from torch.distributed._composable import checkpoint
+ from torch.distributed._composable.fsdp import fully_shard
+ num_layers = len(model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+
+ for i, block in enumerate(model.layers):
+
+ block.apply(param_init_fn)
+
+ # # As an optimization, do not reshard after forward for the last
+ # # transformer block since FSDP would prefetch it immediately
+ # if i < num_layers - 1:
+ # _reshard = reshard_after_forward
+ # else:
+ # _reshard = False
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ if version.parse(torch.__version__) >= version.parse("2.5.0"):
+ for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]):
+ layer_cur.set_modules_to_forward_prefetch([layer_next])
+
+
+ model.embed_tokens.apply(param_init_fn)
+ model.norm.apply(param_init_fn)
+ if hasattr(model, 'rotary_emb'):
+ model.rotary_emb.apply(param_init_fn)
+
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+
+def megatron_llama_casual(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+ megatron_llama(
+ model.model,
+ rank0_model.model if dp_mesh.get_rank() == 0 else None,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward)
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+
+ if tp_mesh and tp_mesh.size() > 1:
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'lm_head': ColwiseParallel(output_layouts=Replicate(), ),
+ })
+
+ model.lm_head.apply(param_init_fn)
+
+ from torch.distributed._composable.fsdp import fully_shard
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
diff --git a/xtuner/_lite/parallel/megatron/minicpmv.py b/xtuner/_lite/parallel/megatron/minicpmv.py
new file mode 100644
index 000000000..2e667cca0
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/minicpmv.py
@@ -0,0 +1,79 @@
+from functools import partial
+from torch.distributed._composable.fsdp import fully_shard
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+from ..fsdp import checkpoint
+
+logger = get_logger()
+
+
+def megatron_minicpmv_casual(meta_model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+    if tp_mesh and tp_mesh.size() > 1:
+ raise NotImplementedError
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(meta_model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+ # visual
+ meta_model.vpm.apply(param_init_fn)
+ fully_shard(
+ meta_model.vpm,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+ for i, layers in enumerate(meta_model.vpm.encoder.layers):
+ checkpoint(layers)
+
+ # resampler
+ meta_model.resampler.apply(param_init_fn)
+ fully_shard(
+ meta_model.resampler,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ # llm
+ num_layers = len(meta_model.llm.model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+ for i, block in enumerate(meta_model.llm.model.layers):
+ block.apply(param_init_fn)
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ meta_model.llm.model.embed_tokens.apply(param_init_fn)
+ meta_model.llm.model.norm.apply(param_init_fn)
+ meta_model.llm.lm_head.apply(param_init_fn)
+
+ model = fully_shard(
+ meta_model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward) # False is zero2, True is zero3
+ return model
diff --git a/xtuner/_lite/parallel/megatron/qwen2.py b/xtuner/_lite/parallel/megatron/qwen2.py
new file mode 100644
index 000000000..90c65309b
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/qwen2.py
@@ -0,0 +1,270 @@
+from functools import partial
+from packaging import version
+
+import torch
+from torch import nn
+from torch.distributed._tensor import Replicate, distribute_tensor
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+ PrepareModuleInput,
+ RowwiseParallel,
+ parallelize_module)
+
+from xtuner._lite import get_logger
+from ..fsdp.lazy import lazy_init_megatron
+from .utils import map_rank0_modules
+
+logger = get_logger()
+
+
+def _tp_qwen2(model, tp_mesh):
+
+ layer_tp_plan = {
+
+ # by default ColwiseParallel input layouts is replicated
+ # and RowwiseParallel output layouts is replicated
+ 'self_attn.q_proj':
+ ColwiseParallel(),
+ 'self_attn.k_proj':
+ ColwiseParallel(),
+ 'self_attn.v_proj':
+ ColwiseParallel(),
+ 'self_attn.o_proj':
+ RowwiseParallel(),
+ 'input_layernorm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ ),
+ 'mlp.up_proj':
+ ColwiseParallel(),
+ 'mlp.down_proj':
+ RowwiseParallel(),
+ 'mlp.gate_proj':
+ ColwiseParallel(),
+ 'post_attention_layernorm':
+ PrepareModuleInput(
+ input_layouts=(Replicate(), ),
+ desired_input_layouts=(Replicate(), ),
+ use_local_output=True
+ )
+ }
+
+ tp_size = tp_mesh.size()
+ for layer in model.layers:
+ attention = layer.self_attn
+ num_key_value_heads = attention.num_key_value_heads
+ num_heads = attention.num_heads
+ hidden_size = attention.hidden_size
+
+ attention.num_heads = num_heads // tp_size
+ attention.num_key_value_heads = num_key_value_heads // tp_size
+ attention.hidden_size = hidden_size // tp_size
+
+ attn_norm = layer.input_layernorm
+ attn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(attn_norm.weight, tp_mesh, [Replicate()])))
+
+ ffn_norm = layer.post_attention_layernorm
+ ffn_norm.register_parameter(
+ 'weight',
+ nn.Parameter(
+ distribute_tensor(ffn_norm.weight, tp_mesh, [Replicate()])))
+
+ parallelize_module(
+ module=layer,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_tp_plan,
+ )
+
+ norm = model.norm
+ dist_norm_w = nn.Parameter(
+ distribute_tensor(norm.weight, tp_mesh, [Replicate()]))
+ norm.register_parameter('weight', dist_norm_w)
+
+ # emb = model.embed_tokens
+ # dist_emb_w = nn.Parameter(
+ # distribute_tensor(emb.weight, tp_mesh, [Replicate()]))
+ # emb.register_parameter('weight', dist_emb_w)
+
+ # model.norm.apply(param_init_fn)
+ # model.embed_tokens.apply(param_init_fn)
+
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'model.embed_tokens':
+ RowwiseParallel(input_layouts=Replicate(), ),
+ 'model.norm':PrepareModuleInput(
+ input_layouts=(Replicate(),),
+ desired_input_layouts=(Replicate(),),
+ use_local_output=True
+ ),
+ })
+
+
+def megatron_qwen2(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+
+ if tp_mesh and tp_mesh.size() > 1:
+ _tp_qwen2(model, tp_mesh)
+
+ from torch.distributed._composable import checkpoint
+ from torch.distributed._composable.fsdp import fully_shard
+ num_layers = len(model.layers)
+ num_recompute_layers = int(num_layers * recompute_ratio)
+
+ for i, block in enumerate(model.layers):
+
+ block.apply(param_init_fn)
+
+ # # As an optimization, do not reshard after forward for the last
+ # # transformer block since FSDP would prefetch it immediately
+ # if i < num_layers - 1:
+ # _reshard = reshard_after_forward
+ # else:
+ # _reshard = False
+
+ fully_shard(
+ block,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if i < num_recompute_layers:
+ checkpoint(block)
+
+ if version.parse(torch.__version__) >= version.parse("2.5.0"):
+ for layer_cur, layer_next in zip(model.layers[:-1], model.layers[1:]):
+ layer_cur.set_modules_to_forward_prefetch([layer_next])
+
+ model.embed_tokens.apply(param_init_fn)
+ model.norm.apply(param_init_fn)
+
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+
+def megatron_qwen2_casual(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+ megatron_qwen2(
+ model.model,
+ rank0_model.model if dp_mesh.get_rank() == 0 else None,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward)
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+
+ if tp_mesh and tp_mesh.size() > 1:
+ model = parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'lm_head': ColwiseParallel(output_layouts=Replicate(), ),
+ })
+
+ model.lm_head.apply(param_init_fn)
+
+ from torch.distributed._composable.fsdp import fully_shard
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
+
+
+def megatron_qwen2_reward(model,
+ rank0_model,
+ dp_mesh,
+ tp_mesh=None,
+ pp_mesh=None,
+ mp_policy=None,
+ recompute_ratio=1.0,
+ reshard_after_forward=True):
+ megatron_qwen2(
+ model.model,
+ rank0_model.model if dp_mesh.get_rank() == 0 else None,
+ dp_mesh,
+ tp_mesh=tp_mesh,
+ pp_mesh=pp_mesh,
+ mp_policy=mp_policy,
+ recompute_ratio=recompute_ratio,
+ reshard_after_forward=reshard_after_forward)
+
+ if dp_mesh.get_rank() == 0:
+ rank0_map = map_rank0_modules(model, rank0_model)
+ else:
+ rank0_map = None
+
+ if tp_mesh and tp_mesh.size() > 1:
+ parallelize_module(
+ module=model,
+ device_mesh=tp_mesh,
+ parallelize_plan={
+ 'score.0': ColwiseParallel(),
+ 'score.2': RowwiseParallel(),
+ })
+
+ param_init_fn = partial(
+ lazy_init_megatron,
+ rank0_map=rank0_map,
+ dp_mesh=dp_mesh,
+ tp_mesh=tp_mesh,
+ )
+
+ model.v_head.apply(param_init_fn)
+
+ from torch.distributed._composable.fsdp import fully_shard
+ fully_shard(
+ model,
+ mesh=dp_mesh,
+ mp_policy=mp_policy,
+ reshard_after_forward=reshard_after_forward)
diff --git a/xtuner/_lite/parallel/megatron/utils.py b/xtuner/_lite/parallel/megatron/utils.py
new file mode 100644
index 000000000..5e4ed2dd2
--- /dev/null
+++ b/xtuner/_lite/parallel/megatron/utils.py
@@ -0,0 +1,7 @@
+def map_rank0_modules(model, rank0_model):
+ rank0_modules = {name: mod for name, mod in rank0_model.named_modules()}
+ rank0_map = {
+ mod: rank0_modules[name]
+ for name, mod in model.named_modules()
+ }
+ return rank0_map
diff --git a/xtuner/_lite/parallel/sampler.py b/xtuner/_lite/parallel/sampler.py
new file mode 100644
index 000000000..91b286f86
--- /dev/null
+++ b/xtuner/_lite/parallel/sampler.py
@@ -0,0 +1,398 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import random
+from typing import Iterator, Optional, Sized
+
+import torch
+from mmengine.dist import sync_random_seed
+from torch.distributed.device_mesh import DeviceMesh
+from torch.utils.data import ConcatDataset as TorchConcatDataset
+from torch.utils.data import Sampler
+
+
+class ParallelSampler(Sampler):
+ """The default data sampler for both distributed and non-distributed
+ environment.
+
+ It has several differences from the PyTorch ``DistributedSampler`` as
+ below:
+
+ 1. This sampler supports non-distributed environment.
+
+ 2. The round up behaviors are a little different.
+
+ - If ``round_up=True``, this sampler will add extra samples to make the
+ number of samples is evenly divisible by the world size. And
+ this behavior is the same as the ``DistributedSampler`` with
+ ``drop_last=False``.
+ - If ``round_up=False``, this sampler won't remove or add any samples
+ while the ``DistributedSampler`` with ``drop_last=True`` will remove
+ tail samples.
+
+ Args:
+ dataset (Sized): The dataset.
+ shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
+ seed (int, optional): Random seed used to shuffle the sampler if
+ :attr:`shuffle=True`. This number should be identical across all
+ processes in the distributed group. Defaults to None.
+ round_up (bool): Whether to add extra samples to make the number of
+ samples evenly divisible by the world size. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ dataset: Sized,
+ dp_mesh: DeviceMesh,
+ global_batch_size: int,
+ shuffle: bool = True,
+ seed: Optional[int] = None,
+ round_up: bool = True,
+ ) -> None:
+ rank = dp_mesh.get_local_rank()
+ world_size = dp_mesh.size()
+
+ assert global_batch_size % world_size == 0
+ self.global_batch_size = global_batch_size
+ self.rank = rank
+ self.world_size = world_size
+
+ self.dataset = dataset
+ self.shuffle = shuffle
+ if seed is None:
+ seed = sync_random_seed()
+ self.seed = seed
+ self.epoch = 0
+ self.step = 0
+ self.round_up = round_up
+
+ if self.round_up:
+ self.num_samples = math.ceil(
+ len(self.dataset) /
+ global_batch_size) * global_batch_size // world_size
+ self.total_size = self.num_samples * self.world_size
+ else:
+ self.num_samples = math.ceil(
+ (len(self.dataset) - rank) / world_size)
+ self.total_size = len(self.dataset)
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate the indices."""
+ # deterministically shuffle based on epoch and seed
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices = (
+ indices *
+ int(self.total_size / len(indices) + 1))[:self.total_size]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.world_size]
+
+ return iter(indices[self.step:])
+
+ def __len__(self) -> int:
+ """The number of samples in this rank."""
+ return self.num_samples - self.step
+
+ def set_epoch(self, epoch: int, step=0) -> None:
+ """Sets the epoch for this sampler.
+
+ When :attr:`shuffle=True`, this ensures all replicas use a different
+ random ordering for each epoch. Otherwise, the next iteration of this
+ sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+            step (int): The step within the epoch from which to resume.
+                Defaults to 0.
+        """
+ self.epoch = epoch
+ self.step = step
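+
+# Usage sketch for `ParallelSampler` (illustrative only; `train_dataset`,
+# `collate_fn` and `max_epochs` are placeholders, and `setup_parallel()` is
+# assumed to have been called so `get_dp_mesh()` returns a valid mesh):
+#
+#   from torch.utils.data import DataLoader
+#   from xtuner._lite.parallel.setup import get_dp_mesh
+#
+#   dp_mesh = get_dp_mesh()
+#   global_batch_size = 64  # must be divisible by dp_mesh.size()
+#   sampler = ParallelSampler(
+#       train_dataset, dp_mesh, global_batch_size, shuffle=True)
+#   loader = DataLoader(
+#       train_dataset,
+#       batch_size=global_batch_size // dp_mesh.size(),
+#       sampler=sampler,
+#       collate_fn=collate_fn)
+#   for epoch in range(max_epochs):
+#       sampler.set_epoch(epoch)
+#       for batch in loader:
+#           ...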
+
+
+def get_length_grouped_indices(max_lengths,
+ group_batch_size,
+ dp_size,
+ seed=None):
+ if seed is not None:
+ torch.manual_seed(seed)
+ random.seed(seed)
+
+ assert all(leng != 0
+ for leng in max_lengths), 'Should not have zero length.'
+ indices = torch.randperm(len(max_lengths))
+ megabatches = [
+ indices[i:i + group_batch_size].tolist()
+ for i in range(0, len(max_lengths), group_batch_size)
+ ]
+ output = []
+ for megabatch in megabatches:
+ megabatch = sorted(
+ megabatch, key=lambda i: max_lengths[i], reverse=True)
+ grouped_megabatch = [
+ megabatch[i:i + dp_size] for i in range(0, len(megabatch), dp_size)
+ ]
+ random.shuffle(grouped_megabatch)
+ for group in grouped_megabatch:
+ output.extend(group)
+
+ return output
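+
+# Illustrative example for `get_length_grouped_indices` (the final group
+# order depends on the seed): with max_lengths=[3, 9, 1, 7, 5, 2],
+# group_batch_size=6 and dp_size=2, all six indices fall into one megabatch,
+# are sorted by length in descending order ([1, 3, 4, 0, 5, 2]) and split
+# into groups of dp_size ([1, 3], [4, 0], [5, 2]); the groups are then
+# shuffled before being flattened, so every dp_size consecutive indices have
+# similar lengths.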
+
+
+class LengthGroupedSampler(Sampler):
+
+ def __init__(self,
+ dataset: Sized,
+ dp_mesh: DeviceMesh,
+ global_batch_size: int,
+ length_attr: str = 'longest',
+ mega_batch_mult: Optional[int] = None,
+ seed: Optional[int] = None,
+ round_up: bool = True) -> None:
+ rank = dp_mesh.get_local_rank()
+ world_size = dp_mesh.size()
+ self.rank = rank
+ self.world_size = world_size
+ assert global_batch_size % world_size == 0
+
+ self.dataset = dataset
+ if seed is None:
+ seed = sync_random_seed()
+ self.seed = seed
+ self.epoch = 0
+ self.step = 0
+ self.round_up = round_up
+
+ if self.round_up:
+ self.num_samples = math.ceil(
+ len(self.dataset) /
+ global_batch_size) * global_batch_size // world_size
+ self.total_size = self.num_samples * self.world_size
+ else:
+ self.num_samples = math.ceil(
+ (len(self.dataset) - rank) / world_size)
+ self.total_size = len(self.dataset)
+
+ if mega_batch_mult is None:
+ # Default for mega_batch_mult: 50 or the number to get 4
+ # megabatches, whichever is smaller.
+ mega_batch_mult = min(
+ len(self.dataset) // (global_batch_size * 4), 50)
+ # Just in case, for tiny datasets
+ if mega_batch_mult == 0:
+ mega_batch_mult = 1
+ self.group_batch_size = mega_batch_mult * global_batch_size
+
+ if isinstance(self.dataset, TorchConcatDataset):
+ max_lengths = []
+ for sub_dataset in self.dataset.datasets:
+ if hasattr(sub_dataset, length_attr):
+ max_lengths.extend(getattr(sub_dataset, length_attr))
+                else:
+                    raise ValueError(
+                        f'{type(sub_dataset)} has no attribute '
+                        f'`{length_attr}`.')
+ self.max_lengths = max_lengths
+        else:
+            if hasattr(self.dataset, length_attr):
+                self.max_lengths = getattr(self.dataset, length_attr)
+            else:
+                raise ValueError(
+                    f'{type(self.dataset)} has no attribute '
+                    f'`{length_attr}`.')
+            assert isinstance(self.max_lengths, (list, tuple))
+
+ self.global_batch_size = global_batch_size
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate the indices."""
+        seed = self.seed + self.epoch
+ indices = get_length_grouped_indices(
+ max_lengths=self.max_lengths,
+ group_batch_size=self.group_batch_size,
+ dp_size=self.world_size,
+ seed=seed)
+ assert len(set(indices)) == len(indices)
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices = (
+ indices *
+ int(self.total_size / len(indices) + 1))[:self.total_size]
+ # subsample
+ assert len(indices) == self.total_size
+ indices = indices[self.rank:self.total_size:self.world_size]
+ assert len(indices) == self.num_samples
+ return iter(indices[self.step:])
+
+ def __len__(self) -> int:
+ """The number of samples in this rank."""
+ return self.num_samples - self.step
+
+ def set_epoch(self, epoch: int, step=0) -> None:
+ """Sets the epoch for this sampler.
+
+        The epoch is mixed into the shuffling seed, so each epoch uses a
+        different length grouping; otherwise, the next iteration of this
+        sampler would yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+            step (int): The step within the epoch from which to resume.
+                Defaults to 0.
+        """
+ self.epoch = epoch
+ self.step = step
+
+
+def vlm_get_length_grouped_indices(max_lengths, group_batch_size, generator=None, **kwargs):
+
+ def process(lengths, group_batch_size, generator=None):
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatches = [
+ indices[i:i + group_batch_size].tolist()
+ for i in range(0, len(lengths), group_batch_size)
+ ]
+ megabatches = [
+ sorted(megabatch, key=lambda i: lengths[i], reverse=True)
+ for megabatch in megabatches
+ ]
+ return megabatches
+
+ lengths = max_lengths
+ assert all(leng != 0 for leng in lengths), 'Should not have zero length.'
+ if all(leng > 0 for leng in lengths) or all(leng < 0 for leng in lengths):
+ # all samples are in the same modality
+ megabatches = process(lengths, group_batch_size, generator=generator)
+ else:
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths)
+ if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l)
+ for i, l in enumerate(lengths)
+ if l < 0])
+ mm_megabatches = []
+ for mm_megabatch in process(
+ mm_lengths, group_batch_size, generator=generator):
+ mm_megabatches.append([mm_indices[i] for i in mm_megabatch])
+ lang_megabatches = []
+ for lang_megabatch in process(
+ lang_lengths, group_batch_size, generator=generator):
+ lang_megabatches.append([lang_indices[i] for i in lang_megabatch])
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ last_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+
+ megabatch_indices = torch.randperm(
+ len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(last_batch) > 0:
+ megabatches.append(
+ sorted(
+ last_batch, key=lambda i: abs(lengths[i]), reverse=True))
+
+ # The rest is to get the biggest batch first.
+ # Since each megabatch is sorted by descending length,
+ # the longest element is the first
+ megabatch_maximums = [
+ abs(lengths[megabatch[0]]) for megabatch in megabatches
+ ]
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+ # Switch to put the longest element in first position
+ megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][
+ 0], megabatches[0][0]
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+class VLMLengthGroupedSampler(Sampler):
+
+ def __init__(self,
+ dataset: Sized,
+ dp_mesh: DeviceMesh,
+ global_batch_size: int,
+ mega_batch_mult: Optional[int] = None,
+ seed: Optional[int] = None,
+ round_up: bool = True,
+ length_property='length') -> None:
+ rank = dp_mesh.get_local_rank()
+ world_size = dp_mesh.size()
+ self.rank = rank
+ self.world_size = world_size
+ assert global_batch_size % world_size == 0
+
+ self.dataset = dataset
+ if seed is None:
+ seed = sync_random_seed()
+ self.seed = seed
+ self.epoch = 0
+ self.step = 0
+ self.round_up = round_up
+
+ if self.round_up:
+ self.num_samples = math.ceil(
+ len(self.dataset) /
+ global_batch_size) * global_batch_size // world_size
+ self.total_size = self.num_samples * self.world_size
+ else:
+ self.num_samples = math.ceil(
+ (len(self.dataset) - rank) / world_size)
+ self.total_size = len(self.dataset)
+
+ if mega_batch_mult is None:
+ # Default for mega_batch_mult: 50 or the number to get 4
+ # megabatches, whichever is smaller.
+ mega_batch_mult = min(
+ len(self.dataset) // (global_batch_size * 4), 50)
+ # Just in case, for tiny datasets
+ if mega_batch_mult == 0:
+ mega_batch_mult = 1
+ self.group_batch_size = mega_batch_mult * global_batch_size
+
+ if isinstance(self.dataset, TorchConcatDataset):
+ max_lengths = []
+ for sub_dataset in self.dataset.datasets:
+ max_lengths.extend(getattr(sub_dataset, length_property))
+ self.max_lengths = max_lengths
+ else:
+ self.max_lengths = getattr(self.dataset, length_property)
+ assert isinstance(self.max_lengths, (list, tuple))
+
+ self.global_batch_size = global_batch_size
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate the indices."""
+ generator = torch.Generator()
+ generator.manual_seed(self.seed + self.epoch)
+ indices = vlm_get_length_grouped_indices(
+ max_lengths=self.max_lengths,
+ group_batch_size=self.group_batch_size,
+ dp_size=self.world_size,
+ generator=generator)
+ assert len(set(indices)) == len(indices)
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices = (
+ indices *
+ int(self.total_size / len(indices) + 1))[:self.total_size]
+ # subsample
+ assert len(indices) == self.total_size
+ indices = indices[self.rank:self.total_size:self.world_size]
+ assert len(indices) == self.num_samples
+ return iter(indices[self.step:])
+
+ def __len__(self) -> int:
+ """The number of samples in this rank."""
+ return self.num_samples - self.step
+
+ def set_epoch(self, epoch: int, step=0) -> None:
+ """Sets the epoch for this sampler.
+
+        The epoch is mixed into the shuffling seed, so each epoch uses a
+        different length grouping; otherwise, the next iteration of this
+        sampler would yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+            step (int): The step within the epoch from which to resume.
+                Defaults to 0.
+        """
+ self.epoch = epoch
+ self.step = step
diff --git a/xtuner/_lite/parallel/sequence/__init__.py b/xtuner/_lite/parallel/sequence/__init__.py
new file mode 100644
index 000000000..2dbdcaacf
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dist import init_dist
+
+from .attention import (post_process_for_sequence_parallel_attn,
+ pre_process_for_sequence_parallel_attn,
+ sequence_parallel_wrapper)
+from .data_collate import (pad_cumulative_len_for_sequence_parallel,
+ pad_for_sequence_parallel)
+from .ops import (gather_for_sequence_parallel, gather_forward_split_backward,
+ split_for_sequence_parallel, split_forward_gather_backward)
+from .reduce_loss import reduce_sequence_parallel_loss
+
+__all__ = [
+ 'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn',
+ 'post_process_for_sequence_parallel_attn', 'split_for_sequence_parallel',
+ 'init_dist', 'gather_for_sequence_parallel',
+ 'split_forward_gather_backward', 'gather_forward_split_backward',
+ 'pad_cumulative_len_for_sequence_parallel', 'pad_for_sequence_parallel',
+ 'reduce_sequence_parallel_loss'
+]
diff --git a/xtuner/_lite/parallel/sequence/attention.py b/xtuner/_lite/parallel/sequence/attention.py
new file mode 100644
index 000000000..5866a9f33
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/attention.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from ..comm import all_to_all
+from ..setup import get_sp_mesh
+
+
+def pre_process_for_sequence_parallel_attn(query_states,
+ key_states,
+ value_states,
+ sp_mesh,
+ scatter_dim=2,
+ gather_dim=1):
+ sp_size = sp_mesh.size()
+ n_head = query_states.shape[2]
+ assert n_head % sp_size == 0, \
+ ('The number of attention heads should be divisible by '
+ f'sequence_parallel_world_size. But got n_head = {n_head} and '
+ f'sequence_parallel_world_size = {sp_size}.')
+
+ # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim)
+ sp_group = sp_mesh.get_group()
+ query_states = all_to_all(
+ query_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim)
+ key_states = all_to_all(
+ key_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim)
+ value_states = all_to_all(
+ value_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim)
+
+ return query_states, key_states, value_states
+
+
+def post_process_for_sequence_parallel_attn(attn_output,
+ sp_mesh,
+ scatter_dim=1,
+ gather_dim=2):
+ # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
+ sp_group = sp_mesh.get_group()
+ output = all_to_all(
+ attn_output, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim)
+ return output
+
+
+def sequence_parallel_wrapper(local_attn):
+
+ def sequence_parallel_attn(query_states, key_states, value_states, *args,
+ **kwargs):
+        # `training` is stripped from kwargs so that it is not forwarded to
+        # the wrapped local attention implementation.
+        kwargs.pop('training', True)
+        sp_mesh = kwargs.pop('sp_mesh', None)
+
+        if sp_mesh is None:
+            # fall back to the globally configured sequence parallel mesh
+            sp_mesh = get_sp_mesh()
+        sp_size = sp_mesh.size()
+
+        enable_sequence_parallel = sp_size > 1
+ if enable_sequence_parallel:
+ query_states, key_states, value_states = \
+ pre_process_for_sequence_parallel_attn(
+ query_states, key_states, value_states, sp_mesh)
+
+ out = local_attn(query_states, key_states, value_states, *args,
+ **kwargs)
+
+        if enable_sequence_parallel:
+            out = post_process_for_sequence_parallel_attn(
+                out, sp_mesh).contiguous()
+
+ return out
+
+ return sequence_parallel_attn
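+
+
+# Usage sketch (illustrative only): `sequence_parallel_wrapper` is meant to
+# decorate a local attention function that takes q/k/v in (b, s, nh, d)
+# layout; `flash_attention` below is a placeholder for such a function.
+#
+#   @sequence_parallel_wrapper
+#   def flash_attention(query_states, key_states, value_states, **kwargs):
+#       ...
+#
+#   # q/k/v hold this rank's sequence shard of shape (b, s // sp, nh, d);
+#   # the wrapper redistributes heads before and after the local attention.
+#   out = flash_attention(q, k, v, sp_mesh=get_sp_mesh())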
diff --git a/xtuner/_lite/parallel/sequence/data_collate.py b/xtuner/_lite/parallel/sequence/data_collate.py
new file mode 100644
index 000000000..32c8f161b
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/data_collate.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..setup import get_sp_mesh
+
+
+def pad_for_sequence_parallel(tensor, padding_value, sp_mesh, dim=-1):
+
+ sp_size = sp_mesh.size()
+ length = tensor.shape[dim]
+ if length % sp_size == 0:
+ return tensor
+
+ pad_num = sp_size - (length % sp_size)
+ pad_shape = (*tensor.shape[:dim], pad_num,
+ *tensor.shape[dim + 1:]) if dim != -1 else (
+ *tensor.shape[:dim], pad_num)
+ pad = torch.full(
+ pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
+ tensor = torch.cat([tensor, pad], dim=dim)
+ return tensor
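+
+# Illustrative example for `pad_for_sequence_parallel`: with sp_size = 4 and
+# `input_ids` of shape (1, 10), 10 % 4 != 0, so 2 padding tokens are appended
+# along dim=-1 and the returned tensor has shape (1, 12), which can then be
+# split evenly across the 4 sequence parallel ranks, e.g.
+#
+#   input_ids = torch.randint(0, 100, (1, 10))
+#   input_ids = pad_for_sequence_parallel(input_ids, 0, sp_mesh, dim=-1)
+#   assert input_ids.shape[-1] % sp_mesh.size() == 0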
+
+
+# This function is only needed when both of the following conditions hold:
+# 1. use_varlen_attn = True
+# 2. pack_to_max_length = True and the packed sequences have different lengths
+def pad_cumulative_len_for_sequence_parallel(cumulative_len):
+ assert len(cumulative_len) == 1
+ seqlen = cumulative_len[0][-1]
+ sp_size = get_sp_mesh().size()
+ if seqlen % sp_size == 0:
+ return cumulative_len, None
+
+ bs = len(cumulative_len)
+ pad_len = sp_size - (seqlen % sp_size)
+ seqlen_new = seqlen + pad_len
+ attention_mask = torch.zeros(
+ bs, seqlen_new, dtype=torch.bool, device=cumulative_len[0].device)
+ attention_mask[:, :seqlen] = True
+
+ for i, cu_len in enumerate(cumulative_len):
+ pad = torch.tensor([seqlen_new],
+ device=cu_len.device,
+ dtype=cu_len.dtype)
+ cumulative_len[i] = torch.cat([cu_len, pad], dim=0)
+
+ return cumulative_len, attention_mask
diff --git a/xtuner/_lite/parallel/sequence/ops.py b/xtuner/_lite/parallel/sequence/ops.py
new file mode 100644
index 000000000..fb0ba0d86
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/ops.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+
+
+def split_for_sequence_parallel(input, dim: int, sp_mesh):
+    """Splits the input tensor along a given dimension for sequence parallel.
+
+    Args:
+        input: The input tensor to be split.
+        dim: The dimension along which the tensor should be split.
+        sp_mesh: The sequence parallel device mesh (a process group is also
+            accepted for backward compatibility).
+
+    Returns:
+        The split tensor corresponding to the current rank's chunk.
+    """
+    if isinstance(sp_mesh, dist.ProcessGroup):
+        sp_group = sp_mesh
+    else:
+        sp_group = sp_mesh.get_group()
+    sp_size = dist.get_world_size(sp_group)
+    if sp_size == 1:
+        return input
+
+    rank = dist.get_rank(sp_group)
+ dim_size = input.size(dim)
+ assert dim_size % sp_size == 0, (
+ f'The dimension to split ({dim_size}) is not a multiple of '
+ f'sp size ({sp_size}), cannot split tensor evenly')
+
+ tensor_list = torch.split(input, dim_size // sp_size, dim=dim)
+ output = tensor_list[rank].contiguous()
+
+ return output
+
+
+def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
+ """Gathers the input tensor along a given dimension for sequence parallel.
+
+ Args:
+ input: The input tensor to be gathered.
+ dim: The dimension along which the tensor should be gathered.
+ sp_group: The sequence parallel process group.
+
+ Returns:
+ The gathered tensor concatenated along the specified dimension.
+ """
+    input = input.contiguous()
+    world_size = dist.get_world_size(sp_group)
+
+ if world_size == 1:
+ return input
+
+ tensor_list = [torch.empty_like(input) for _ in range(world_size)]
+ assert input.device.type == 'cuda'
+ dist.all_gather(tensor_list, input, group=sp_group)
+
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """Gather the input during forward.
+
+    Scale and split the grad, keeping only the chunk corresponding to this
+    rank during backward.
+ """
+
+ @staticmethod
+ def forward(ctx, input, dim, sp_group, grad_scale):
+ ctx.dim = dim
+ ctx.sp_group = sp_group
+ ctx.grad_scale = grad_scale
+ return gather_for_sequence_parallel(input, dim, sp_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.grad_scale == 'up':
+ grad_output = grad_output * dist.get_world_size(ctx.sp_group)
+ elif ctx.grad_scale == 'down':
+ grad_output = grad_output / dist.get_world_size(ctx.sp_group)
+
+ return (split_for_sequence_parallel(grad_output, ctx.dim,
+ ctx.sp_group), None, None, None)
+
+
+class _SplitForwardGatherBackward(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank during
+ forward.
+
+ Scale and gather the grad during backward.
+ """
+
+ @staticmethod
+ def forward(ctx, input, dim, sp_group, grad_scale):
+ ctx.dim = dim
+ ctx.sp_group = sp_group
+ ctx.grad_scale = grad_scale
+ return split_for_sequence_parallel(input, dim, sp_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.grad_scale == 'up':
+ grad_output = grad_output * dist.get_world_size(ctx.sp_group)
+ elif ctx.grad_scale == 'down':
+ grad_output = grad_output / dist.get_world_size(ctx.sp_group)
+ return (gather_for_sequence_parallel(grad_output, ctx.dim,
+ ctx.sp_group), None, None, None)
+
+
+def split_forward_gather_backward(input, dim, sp_group, grad_scale=None):
+ """Split tensors according to the sp rank during forward propagation and
+ gather the grad from the whole sp group during backward propagation.
+
+    1. When do we need this? When ``input.requires_grad`` is True.
+
+    2. Why do we need grad scaling?
+
+    We have to scale down the grads because `gather_forward_split_backward`
+    scales up the grads.
+ """
+ return _SplitForwardGatherBackward.apply(input, dim, sp_group, grad_scale)
+
+
+def gather_forward_split_backward(input, dim, sp_group, grad_scale=None):
+ """Gather tensors from the whole sp group during forward propagation and
+ split the grad according to the sp rank during backward propagation.
+
+ 1. When do we need this?
+
+ When sp is greater than 1, we need to slice the input `x` along
+ sequence length dimension before it is passed into the model and get
+ `sub_seq_x`. We then pass `sub_seq_x` into model and get output
+ `sub_seq_out`. If the loss calculation process needs to use the complete
+ output, we have to gather the `sub_seq_out` in all sp ranks during forward
+ propagation and split the grad during backward propagation.
+
+    2. Why do we need grad scaling?
+    Here is a simple case.
+
+ -------- SP 1 -----------
+ Suppose here is a toy model with only one linear module
+ (in_features = 2, out_features = 1) and the input x has shape(2, 2).
+ Y = [[y1], = [[w11x11 + w21x12], = [[x11, x12], dot [[w11],
+ [y2]] [w11x21 + w21x22]] [x21, x22]] [w21]]
+ z = mean(Y) = (y1 + y2) / 2
+ Here is the partial derivative of z with respect to w11:
+ ∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 + ∂z / ∂y2 * ∂y2 / ∂w11
+ = 1/2 * x11 + 1/2 * x21 = (x11 + x21) / 2
+
+ -------- SP 2 -----------
+    When the sequence parallel world size is set to 2, we split the input x
+    and scatter the chunks to the two ranks in the same sequence parallel
+    group.
+ ```Step 1
+ Y_rank0 = [[y1]] = [[w11x11 + w21x12]] = [[x11, x12]] dot [[w11, w21]]^T
+ Y_rank1 = [[y2]] = [[w11x21 + w21x22]] = [[x21, x22]] dot [[w11, w21]]^T
+ ```
+
+ Then, we have to gather them:
+ ```Step 2
+ Y_rank0 = [[y1],
+ detach([y2])]
+ Y_rank1 = [detach([y1]),
+ [y2]]
+ ```
+ Note that y2 in Y_rank0 does not have grad, neither does y1 in Y_rank1.
+
+ Similarly, we calculate the loss in each rank:
+ ```Step 3
+ z_rank0 = mean(Y_rank0) = (y1 + detach(y2)) / 2
+ z_rank1 = mean(Y_rank1) = (detach(y1) + y2) / 2
+ ```
+ So the partial derivative of loss_rank0 with respect to w11:
+ ```∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 = x11 / 2```
+ The same for rank1:
+ ```∂z / ∂w11 = ∂z / ∂y2 * ∂y2 / ∂w11 = x21 / 2```
+
+ Finally, we need to all_reduce them:
+ ```Step 4
+ In both rank:
+ ∂z / ∂w11 = (x11 / 2 + x21 / 2) / 2 = (x11 + x21) / 4
+ ```
+
+ In SP2, the gradient of each param is only half of that in SP1.
+ So we should scale up the grad during the backward process in Step 2.
+ """ # noqa: E501
+ return _GatherForwardSplitBackward.apply(input, dim, sp_group, grad_scale)
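+
+
+# Illustrative round trip (a sketch; `sp_group` is assumed to be the sequence
+# parallel process group, e.g. obtained via `get_sp_mesh().get_group()` from
+# the setup module):
+#
+#   # keep only this rank's (seq_len // sp_size) slice of the input; per the
+#   # note above, grads are scaled down here ...
+#   local_x = split_forward_gather_backward(x, 1, sp_group, grad_scale='down')
+#   local_out = model(local_x)
+#   # ... and scaled up here, where the full sequence is rebuilt before the
+#   # loss is computed on the complete output.
+#   out = gather_forward_split_backward(local_out, 1, sp_group,
+#                                       grad_scale='up')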
diff --git a/xtuner/_lite/parallel/sequence/reduce_loss.py b/xtuner/_lite/parallel/sequence/reduce_loss.py
new file mode 100644
index 000000000..11e204806
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/reduce_loss.py
@@ -0,0 +1,33 @@
+import torch
+import torch.distributed as dist
+
+from ..setup import get_sp_mesh
+
+
+class _ReduceLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, mean_loss, loss_scale, process_group):
+ ctx.mode = process_group
+ if loss_scale == 0:
+ # convert nan to 0 just for logging
+ mean_loss = torch.nan_to_num(mean_loss)
+ loss_sum = mean_loss * loss_scale
+ dist.all_reduce(loss_sum, group=process_group)
+ dist.all_reduce(loss_scale, group=process_group)
+ loss = loss_sum / loss_scale
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None, None
+
+
+def reduce_sequence_parallel_loss(mean_loss, loss_scale, sp_mesh=None):
+    if sp_mesh is None:
+        # avoid bc breaking
+        sp_mesh = get_sp_mesh()
+    if sp_mesh.size() == 1:
+        return mean_loss
+    sp_group = sp_mesh.get_group()
+    return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)
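+
+
+# Usage sketch (illustrative only; `labels` is a placeholder): with sequence
+# parallel enabled, every rank computes a mean loss over its own sequence
+# shard, so the per-rank means are re-weighted by the number of valid tokens
+# before being averaged across the sp group, e.g.
+#
+#   num_tokens = (labels != -100).sum()
+#   loss = reduce_sequence_parallel_loss(loss, num_tokens, sp_mesh)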
diff --git a/xtuner/_lite/parallel/setup.py b/xtuner/_lite/parallel/setup.py
new file mode 100644
index 000000000..7a02f5d03
--- /dev/null
+++ b/xtuner/_lite/parallel/setup.py
@@ -0,0 +1,110 @@
+import torch.distributed as dist
+from mmengine.dist import infer_launcher, init_dist
+from torch.distributed.device_mesh import init_device_mesh
+
+from xtuner._lite import get_device
+
+_SP_MESH = None
+_DP_MESH = None
+_SAME_DATA_MESH = None
+_TP_MESH = None
+_FSDP_MESH = None
+_WORLD_MESH = None
+
+_EP_MESH = None
+_EXPERTS_FSDP_MESH = None
+
+
+def setup_parallel(sp_size=1, tp_size=1, ep_size=1):
+
+ if not dist.is_initialized():
+ dist_launcher = infer_launcher()
+ init_dist(dist_launcher)
+
+ device = get_device()
+
+ world_size = dist.get_world_size()
+ assert world_size % sp_size == 0
+    assert (world_size // sp_size) % tp_size == 0
+ assert tp_size <= 8
+
+ dp_size = world_size // sp_size // tp_size
+ data_mesh = init_device_mesh(
+ device, (dp_size, sp_size, tp_size), mesh_dim_names=('dp', 'sp', 'tp'))
+
+ same_data_mesh = init_device_mesh(
+ device, (dp_size, sp_size * tp_size), mesh_dim_names=('dp', 'same_data'))
+
+ model_mesh = init_device_mesh(
+ device, (dp_size * sp_size, tp_size), mesh_dim_names=('fsdp', 'tp'))
+
+ world_mesh = init_device_mesh(
+ device, (world_size, ), mesh_dim_names=('world', ))
+
+ global _DP_MESH, _DP_GROUP, _DP_WORLD_SIZE
+ _DP_MESH = data_mesh['dp']
+ _DP_GROUP = data_mesh['dp'].get_group()
+ _DP_WORLD_SIZE = data_mesh['dp'].size()
+
+ global _SP_MESH, _SP_GROUP, _SP_WORLD_SIZE
+ _SP_MESH = data_mesh['sp']
+ _SP_GROUP = data_mesh['sp'].get_group()
+ _SP_WORLD_SIZE = data_mesh['sp'].size()
+
+ global _TP_MESH, _TP_GROUP, _TP_WORLD_SIZE
+ _TP_MESH = model_mesh['tp']
+ _TP_GROUP = model_mesh['tp'].get_group()
+ _TP_WORLD_SIZE = model_mesh['tp'].size()
+
+ global _WORLD_MESH, _FSDP_MESH
+ _WORLD_MESH = world_mesh['world']
+ _FSDP_MESH = model_mesh['fsdp']
+
+ global _SAME_DATA_MESH
+ _SAME_DATA_MESH = same_data_mesh['same_data']
+
+ assert world_size % ep_size == 0
+ fsdp_size = world_size // ep_size
+
+ # faster in multi nodes
+ device_mesh = init_device_mesh(
+ device, (fsdp_size, ep_size), mesh_dim_names=('fsdp', 'ep'))
+ # slower in multi nodes
+ # device_mesh = init_device_mesh('cuda', (ep_size, fsdp_size),
+ # mesh_dim_names=('ep', 'fsdp'))
+
+ global _EP_MESH
+ global _EXPERTS_FSDP_MESH
+ _EP_MESH = device_mesh['ep']
+ _EXPERTS_FSDP_MESH = device_mesh['fsdp']
+
+
+def get_ep_mesh():
+ return _EP_MESH
+
+
+def get_experts_fsdp_mesh():
+ return _EXPERTS_FSDP_MESH
+
+
+def get_world_mesh():
+ return _WORLD_MESH
+
+
+def get_dp_mesh():
+ return _DP_MESH
+
+
+def get_fsdp_mesh():
+ return _FSDP_MESH
+
+
+def get_sp_mesh():
+ return _SP_MESH
+
+
+def get_tp_mesh():
+ return _TP_MESH
+
+
+def get_same_data_mesh():
+ return _SAME_DATA_MESH
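+
+
+# Illustrative example (a sketch): on 16 GPUs, `setup_parallel(sp_size=2,
+# tp_size=2)` builds a (dp=4, sp=2, tp=2) data mesh, an (fsdp=8, tp=2) model
+# mesh and a flat world mesh, so afterwards:
+#
+#   get_dp_mesh().size()    # 4
+#   get_sp_mesh().size()    # 2
+#   get_tp_mesh().size()    # 2
+#   get_fsdp_mesh().size()  # 8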
diff --git a/xtuner/_lite/parallel/utils.py b/xtuner/_lite/parallel/utils.py
new file mode 100644
index 000000000..4c9033e4f
--- /dev/null
+++ b/xtuner/_lite/parallel/utils.py
@@ -0,0 +1,16 @@
+from torch.distributed.checkpoint.stateful import Stateful
+
+
+class MetaStateful(Stateful):
+    """A minimal ``Stateful`` wrapper around arbitrary metadata (e.g. the
+    current training step) so that it can be saved and restored through the
+    ``torch.distributed.checkpoint`` APIs."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.kwargs = kwargs
+
+ def state_dict(self):
+ return self.kwargs
+
+ def load_state_dict(self, state_dict) -> None:
+ self.kwargs = state_dict
+
+ def __getitem__(self, key):
+ return self.kwargs[key]
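+
+
+# Usage sketch (illustrative only; `cur_step` and `work_dir` are
+# placeholders): MetaStateful lets plain metadata ride along with
+# `torch.distributed.checkpoint` saves and loads, e.g.
+#
+#   import torch.distributed.checkpoint as dcp
+#
+#   meta = MetaStateful(step=cur_step)
+#   dcp.save({'meta': meta}, checkpoint_id=work_dir)
+#
+#   # later, when resuming:
+#   meta = MetaStateful()
+#   dcp.load({'meta': meta}, checkpoint_id=work_dir)
+#   resumed_step = meta['step']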
diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index e81ec7a3a..25444530e 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -97,6 +97,12 @@
DeepseekV2FlashAttention2=LazyObject(
'xtuner.model.modules.dispatch.deepseek_v2',
'deepseek_varlen_attn_forward'),
+ InternLM3FlashCrossAttention2=LazyObject(
+ 'xtuner.model.modules.dispatch.internlm3',
+ 'internlm3_cross_attn_varlen_forward'),
+ InternLM3FlashSelfAttention2=LazyObject(
+ 'xtuner.model.modules.dispatch.internlm3',
+ 'internlm3_self_attn_varlen_forward')
)
VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict(
@@ -181,24 +187,23 @@ def dispatch_varlen_attn_forward(model):
return
from mmengine import print_log
- print_log = log_once(print_log)
+ # print_log = log_once(print_log)
- varlen_attn_forward = None
+
for module in model.modules():
name = type(module).__name__
if (IS_LOW_VERSION_TRANSFORMERS
and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING):
- if varlen_attn_forward is None:
- varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name]
- varlen_attn_forward = varlen_attn_forward.build()
+            varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name]
+            varlen_attn_forward = varlen_attn_forward.build()
print_log(
f'Dispatch legacy {name} varlen forward. '
f'{NO_ATTN_WEIGHTS_MSG}', 'current')
module.forward = types.MethodType(varlen_attn_forward, module)
elif name in VARLEN_ATTN_DISPATCH_MAPPING:
- if varlen_attn_forward is None:
- varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name]
- varlen_attn_forward = varlen_attn_forward.build()
+ varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name]
+ varlen_attn_forward = varlen_attn_forward.build()
print_log(f'Dispatch {name} varlen forward. {NO_ATTN_WEIGHTS_MSG}',
'current')
module.forward = types.MethodType(varlen_attn_forward, module)
diff --git a/xtuner/model/modules/dispatch/internlm3.py b/xtuner/model/modules/dispatch/internlm3.py
new file mode 100644
index 000000000..f9f964a58
--- /dev/null
+++ b/xtuner/model/modules/dispatch/internlm3.py
@@ -0,0 +1,367 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from mmengine import MessageHub
+from torch import distributed as dist
+from transformers.cache_utils import StaticCache
+from transformers.utils import logging
+
+from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
+                                      post_process_for_sequence_parallel_attn,
+                                      pre_process_for_sequence_parallel_attn)
+from .attention import flash_attn_wo_mask, varlen_flash_attn
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ if k is None:
+ return q_embed, None
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """This is the equivalent of torch.repeat_interleave(x, dim=1,
+ repeats=n_rep).
+
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
+ (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :,
+ None, :, :].expand(batch,
+ num_key_value_heads,
+ n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
+ head_dim)
+
+
+def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
+ to (batch, seqlen, num_attention_heads, head_dim)"""
+ batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, :,
+ None, :].expand(batch, slen,
+ num_key_value_heads, n_rep,
+ head_dim)
+ return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
+ head_dim)
+
+
+def internlm3_self_attn_varlen_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ '`static` cache implementation is not compatible with '
+ '`attn_implementation==flash_attention_2` make sure to use `sdpa` '
+ 'in the mean time, and open an issue at '
+ 'https://github.com/huggingface/transformers')
+
+ bsz, q_len, _ = hidden_states.size()
+
+ message_hub = MessageHub.get_instance('varlen_attn_args')
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+ cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
+ max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
+ use_varlen_atten = (cumulative_len is not None)
+
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+ cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # dropout_rate = self.attention_dropout if self.training else 0.0
+ dropout_rate = 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (InternLM3RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.wqkv.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
+ if use_varlen_atten:
+ attn_output = varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ causal=True,
+ dropout_p=dropout_rate,
+ training=self.training)
+ elif attention_mask is None:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=self.training)
+ else:
+
+ enable_sequence_parallel = (
+ dist.is_initialized() and get_sequence_parallel_world_size() > 1
+ and self.training)
+ if enable_sequence_parallel:
+ query_states, key_states, value_states = \
+ pre_process_for_sequence_parallel_attn(
+ query_states, key_states, value_states)
+ # self.num_heads is used in self._upad_input method
+ # num_heads has been changed because of sequence parallel
+ ori_num_head = self.num_heads
+ self.num_heads = query_states.shape[-2]
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_states.shape[1],
+ dropout=dropout_rate)
+
+ if enable_sequence_parallel:
+ attn_output = post_process_for_sequence_parallel_attn(attn_output)
+ self.num_heads = ori_num_head
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ # Due to the implementation of the PyTorch version of flash attention,
+ # even when the output_attentions flag is set to True, it is not possible
+ # to return the attn_weights.
+ return attn_output, None, past_key_value
+
+
+def internlm3_cross_attn_varlen_forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ message_hub = MessageHub.get_instance('varlen_attn_args')
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+ cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
+ max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
+ use_varlen_atten = (cumulative_len is not None)
+
+ if self.config.pretraining_tp > 1:
+        # split the q projection weight by pretraining_tp
+ key_value_slicing = self.hidden_size // self.config.pretraining_tp
+ q_slices = self.wq.weight.split(key_value_slicing, dim=0)
+ query_states = torch.cat(
+ [F.linear(hidden_states, q_slice) for q_slice in q_slices], dim=-1 # pylint: disable=E1102
+ )
+ else:
+ query_states = self.wq(hidden_states)
+
+ query_states = rearrange(query_states, "b q (h d) -> b q h d", d=self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ # Only query_states are rotated in cross-attention
+ query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # dropout_rate = self.attention_dropout if self.training else 0.0
+ dropout_rate = 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (InternLM3RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+            target_dtype = self.wq.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ if use_varlen_atten:
+ attn_output = varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ causal=True,
+ dropout_p=dropout_rate,
+ training=self.training)
+ elif attention_mask is None:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=self.training)
+ else:
+
+ enable_sequence_parallel = (
+ dist.is_initialized() and get_sequence_parallel_world_size() > 1
+ and self.training)
+ if enable_sequence_parallel:
+ query_states, key_states, value_states = \
+ pre_process_for_sequence_parallel_attn(
+ query_states, key_states, value_states)
+ # self.num_heads is used in self._upad_input method
+ # num_heads has been changed because of sequence parallel
+ ori_num_head = self.num_heads
+ self.num_heads = query_states.shape[-2]
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_states.shape[1],
+ dropout=dropout_rate)
+
+ if enable_sequence_parallel:
+ attn_output = post_process_for_sequence_parallel_attn(attn_output)
+ self.num_heads = ori_num_head
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ return attn_output, None
diff --git a/xtuner/utils/handle_moe_load_and_save.py b/xtuner/utils/handle_moe_load_and_save.py
index 88a3936a8..18764a82d 100644
--- a/xtuner/utils/handle_moe_load_and_save.py
+++ b/xtuner/utils/handle_moe_load_and_save.py
@@ -3,7 +3,6 @@
import re
from collections import OrderedDict
-import deepspeed
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -145,6 +144,7 @@ def load(module: nn.Module, state_dict, unloaded_shard_files, prefix=''):
if len(params_to_gather) > 0:
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled():
+ import deepspeed
with deepspeed.zero.GatheredParameters(
params_to_gather, modifier_rank=0):
if dist.get_rank() == 0: