[Feature] Support ChatGLM3-6B (#222)
* add chatglm cfgs

* add chatglm3 template

* fix bos_token_id bug

* add encode_special_tokens

* update readme
LZHgrla authored Nov 14, 2023
1 parent 51ae023 commit bda70b1
Showing 37 changed files with 2,922 additions and 25 deletions.
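The functional core of the change is small: every tokenizer XTuner builds for ChatGLM3 is now created with `encode_special_tokens=True` (see the `xtuner/apis/model.py` diff and the new configs below). As a minimal sketch, not part of this commit, the flag is passed through to the ChatGLM3 remote-code tokenizer and is understood to make it map dialogue markers such as `<|user|>` and `<|assistant|>` to their special token ids instead of tokenizing them as ordinary text:

```python
# Sketch only (not part of this commit): load the ChatGLM3 tokenizer the same
# way the updated XTuner code does. The behaviour described below is an
# assumption about THUDM's remote tokenizer code, not something this diff adds.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'THUDM/chatglm3-6b',
    trust_remote_code=True,
    encode_special_tokens=True)

# With the flag on, markers like <|user|> should be encoded as single special
# ids rather than split into ordinary sub-word tokens.
print(tokenizer.encode('<|user|>\nHello'))
```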
4 changes: 2 additions & 2 deletions README.md
@@ -14,7 +14,7 @@ English | [简体中文](README_zh-CN.md)

## 🎉 News

-- **\[2023/10\]** Support [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base) model!
+- **\[2023/10\]** Support [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) model!
- **\[2023/10\]** Support [MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench) dataset, and the fine-tuned LLMs can be applied by [Lagent](https://github.com/InternLM/lagent)!
- **\[2023/10\]** Optimize the data processing to accommodate `system` context. More information can be found on [Docs](docs/en/user_guides/dataset_format.md)!
- **\[2023/09\]** Support [InternLM-20B](https://huggingface.co/internlm) models!
@@ -79,7 +79,7 @@ XTuner is a toolkit for efficiently fine-tuning LLM, developed by the [MMRazor](
<li><a href="https://huggingface.co/meta-llama">Llama</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
-<li><a href="https://huggingface.co/THUDM/chatglm3-6b-base">ChatGLM3</a></li>
+<li><a href="https://huggingface.co/THUDM/chatglm3-6b">ChatGLM3</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-7B">Baichuan</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2</a></li>
4 changes: 2 additions & 2 deletions README_zh-CN.md
@@ -14,7 +14,7 @@

## 🎉 更新

-- **\[2023/10\]** 支持 [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base) 模型!
+- **\[2023/10\]** 支持 [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) 模型!
- **\[2023/10\]** 支持 [MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench) 数据集,并且微调所得大语言模型可应用至 [Lagent](https://github.com/InternLM/lagent) 框架!
- **\[2023/10\]** 优化数据处理逻辑以兼容 `system` 字段,相关细节请查阅[文档](docs/zh_cn/user_guides/dataset_format.md)
- **\[2023/09\]** 支持 [InternLM-20B](https://huggingface.co/internlm) 系列模型!
@@ -79,7 +79,7 @@ XTuner 是一个轻量级微调大语言模型的工具库,由 [MMRazor](https
<li><a href="https://huggingface.co/meta-llama">Llama</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
-<li><a href="https://huggingface.co/THUDM/chatglm3-6b-base">ChatGLM3</a></li>
+<li><a href="https://huggingface.co/THUDM/chatglm3-6b">ChatGLM3</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-7B">Baichuan</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2</a></li>
12 changes: 9 additions & 3 deletions xtuner/apis/model.py
@@ -41,7 +41,9 @@ def build_qlora_model(model_name_or_path,
 
     if return_tokenizer:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path, trust_remote_code=True)
+            model_name_or_path,
+            trust_remote_code=True,
+            encode_special_tokens=True)
         return model.llm, tokenizer
     else:
         return model.llm
@@ -65,7 +67,9 @@ def build_lora_model(model_name_or_path,
 
     if return_tokenizer:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path, trust_remote_code=True)
+            model_name_or_path,
+            trust_remote_code=True,
+            encode_special_tokens=True)
         return model.llm, tokenizer
     else:
         return model.llm
@@ -77,7 +81,9 @@ def build_model(model_name_or_path, return_tokenizer=True):
 
     if return_tokenizer:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path, trust_remote_code=True)
+            model_name_or_path,
+            trust_remote_code=True,
+            encode_special_tokens=True)
         return model, tokenizer
     else:
         return model
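With the three builders patched, a tokenizer obtained through `xtuner.apis` is ready for ChatGLM3 out of the box. A hedged usage sketch follows; the import path is inferred from the file location `xtuner/apis/model.py`, and the remaining builder arguments are assumed to have usable defaults:

```python
# Sketch under assumptions: build_qlora_model is the helper patched in
# xtuner/apis/model.py above; default quantization/LoRA settings are assumed.
from xtuner.apis.model import build_qlora_model

llm, tokenizer = build_qlora_model(
    'THUDM/chatglm3-6b',    # model_name_or_path
    return_tokenizer=True)  # tokenizer is now built with encode_special_tokens=True
```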
183 changes: 183 additions & 0 deletions xtuner/configs/chatglm/chatglm3_6b/chatglm3_6b_qlora_alpaca_e3.py
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'THUDM/chatglm3-6b'

# Data
alpaca_en_path = 'tatsu-lab/alpaca'
prompt_template = PROMPT_TEMPLATE.chatglm3
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    encode_special_tokens=True,
    padding_side='left')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=lr * 0.1,
    by_epoch=True,
    T_max=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
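The new file is an MMEngine-style Python config, so it can be loaded and inspected before a run; training itself would presumably be launched through XTuner's CLI, e.g. `xtuner train chatglm3_6b_qlora_alpaca_e3` (command name assumed, not shown in this diff). A small sketch, assuming the dependencies imported at the top of the config (torch, bitsandbytes, peft, transformers, xtuner) are installed:

```python
# Sketch only: load the committed config with MMEngine and check a few fields.
# Config.fromfile executes the config, so its imports must be installed.
from mmengine.config import Config

cfg = Config.fromfile(
    'xtuner/configs/chatglm/chatglm3_6b/chatglm3_6b_qlora_alpaca_e3.py')
print(cfg.pretrained_model_name_or_path)  # THUDM/chatglm3-6b
print(cfg.max_length, cfg.batch_size)     # 2048 1
```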