Add Mixtral MoE and Qwen-vl (#105)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 committed Dec 27, 2023
1 parent 8fa82d0 commit 84e0355
Showing 48 changed files with 11,783 additions and 117 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -4,6 +4,8 @@ English | [简体中文](./README_zh-CN.md)
Pai-Megatron-Patch (https://github.com/alibaba/Pai-Megatron-Patch) is a deep learning training toolkit that lets developers easily train and run inference on large language models (LLMs) with the Megatron-LM framework. As LLMs continue to develop, model structures and scales are evolving rapidly. Although these models can be built conveniently with the Transformers or DeepSpeed training frameworks, their training efficiency is comparatively low, and the gap becomes even more pronounced once the model scale exceeds 10 billion parameters. The primary objective of Pai-Megatron-Patch is to make effective use of GPU compute for LLMs: it allows convenient training of commonly used LLMs with all of the acceleration techniques provided by Megatron-LM.

What's New:
- **Support fine-tuning the mixtral-8x7b MoE model using Megatron-LM.** [🔥🔥 2023.12.27]
- **Support fine-tuning the qwen-vl multimodal model using Megatron-LM.** [🔥🔥 2023.12.15]
- **Support fine-tuning the LLava multimodal model using Megatron-LM.** [🔥🔥 2023.12.01]
- **Support fine-tuning the deepseek model using Megatron-LM.** [🔥🔥 2023.11.24]
- **Support fine-tuning the qwen-72B model using Megatron-LM.** [🔥🔥 2023.11.23]
2 changes: 2 additions & 0 deletions README_zh-CN.md
@@ -17,6 +17,8 @@ The Pai-Megatron-Patch toolkit is developed by the algorithm team of Alibaba Cloud's machine learning platform PAI,
- [Alibaba Cloud PAI wins dual championships in the FewCLUE few-shot learning benchmark with large models](https://developer.aliyun.com/article/788081?spm=a2c6h.12873639.article-detail.17.11c5383cHpFZks&tlog=yuekan_8)

What's New:
- **Support training the mixtral-8x7b MoE sparse model with the MegatronLM framework.** [🔥🔥 2023.12.27]
- **Support fine-tuning the multimodal large model qwen-vl with the MegatronLM framework.** [🔥🔥 2023.12.15]
- **Support fine-tuning the multimodal large model LLava with the MegatronLM framework.** [🔥🔥 2023.12.01]
- **Support training the deepseek series of models with the MegatronLM framework.** [🔥🔥 2023.11.24]
- **Support fine-tuning the qwen-72B model with the MegatronLM framework.** [🔥🔥 2023.11.23]
14 changes: 4 additions & 10 deletions examples/mistral/evaluate_huggingface_mistral.py
@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron.core.enums import ModelType
from megatron import get_args
from megatron import print_rank_0
from megatron import is_last_rank
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.p2p_communication import send_forward
from megatron.initialize import initialize_megatron
@@ -28,23 +26,19 @@
from megatron.utils import unwrap_model
from megatron.arguments import core_transformer_config_from_args

from megatron_patch.data.evaluate_dataset import build_evaluation_dataset
from megatron_patch.data import build_evaluation_dataset
from megatron_patch.finetune_utils import build_data_loader
from megatron_patch.tokenizer import build_tokenizer
from megatron_patch.tokenizer import get_tokenizer
from megatron_patch.training import get_model
from megatron_patch.arguments import get_tasks_args
from megatron_patch.model.mistral.modeling_mistral import MistralForCausalLM

from transformers import AutoModelForCausalLM

def get_model_provider():
    """Based on evaluation metric set the parallel-output flag and
    return the model provider."""
    def model_provider(pre_process=True, post_process=True):
        args = get_args()
        tokenizer = build_tokenizer(args)
        model = MistralForCausalLM.from_pretrained(args.load,
                                                   trust_remote_code=False)
        model = AutoModelForCausalLM.from_pretrained(args.load, device_map="auto")
        return model

    return model_provider
@@ -56,7 +50,7 @@ def forward_step(batch, model):
    # Get the batch.
    input_ids = batch['input_ids'].long().cuda()
    labels = batch['labels'].long().cuda()
    labels[labels == -1] = -100
    labels[labels == 0] = -100
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    # Tell the model what our actual batch size will be
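A note on the label masking changed above: Hugging Face causal-LM heads compute their loss with `CrossEntropyLoss(ignore_index=-100)`, so any label set to `-100` is excluded from the loss. The patch switches the remapped sentinel from `-1` to `0`, which suggests the evaluation dataset pads labels with token id `0`. A minimal, self-contained sketch of the same masking; the tensors, pad id, and vocabulary size are illustrative, not taken from the patch:

```python
import torch

# Illustrative labels; 0 is assumed to be the pad token id, as in the patched line.
labels = torch.tensor([[5, 17, 42, 0, 0],
                       [7, 99,  3, 8, 0]])

# Remap pad positions to -100 so that CrossEntropyLoss(ignore_index=-100) -- the
# default used inside Hugging Face causal-LM heads -- skips them.
labels[labels == 0] = -100

logits = torch.randn(2, 5, 32000)  # (batch, seq_len, vocab_size), random for the sketch
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
print(loss)
```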
12 changes: 4 additions & 8 deletions examples/mistral/evaluate_megatron_mistral.py
@@ -13,9 +13,6 @@
# limitations under the License.

import torch
from megatron_patch.data import \
    build_pretrain_dataset_from_original, build_pretrain_dataset_from_idxmap

from megatron.core.enums import ModelType
from megatron import get_args
from megatron import print_rank_0
@@ -27,6 +24,7 @@
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.arguments import core_transformer_config_from_args

from megatron_patch.data import build_evaluation_dataset
from megatron_patch.checkpointing import load_checkpoint
from megatron_patch.finetune_utils import build_data_loader
from megatron_patch.model.mistral.gpt_model import GPTModel
@@ -67,8 +65,8 @@ def get_batch(batch):

    tokens = tokens[:, :-1].contiguous()
    labels = labels[:, 1:].contiguous()

    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    attention_mask = tokens.ne(tokenizer.pad_token_id)
    _, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        labels,
        tokenizer.pad_token_id,
        args.reset_position_ids,
@@ -145,9 +143,7 @@ def main():
        exit()

    # Data stuff.
    #dataset = build_evaluation_dataset(args.dataset)
    dataset, _, _ = \
        build_pretrain_dataset_from_original(args.dataset)
    dataset = build_evaluation_dataset(args.dataset)
    dataloader = build_data_loader(dataset,
                                   args.micro_batch_size,
                                   args.num_workers,
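For context on the `get_batch` change above: the attention mask handed to the model is now a per-token padding mask built with `tokens.ne(tokenizer.pad_token_id)`, while the loss mask and position ids still come from Megatron's `get_ltor_masks_and_position_ids` helper (whose own attention mask is derived from a lower-triangular, left-to-right pattern). A small illustration of the two mask shapes, with an assumed pad id of 0 purely for the sketch:

```python
import torch

pad_id = 0  # assumed pad token id for the sketch
tokens = torch.tensor([[11, 12, 13, pad_id, pad_id]])

# Per-token padding mask: True where a real token is present. This is the
# 2D mask the patched get_batch now passes to the model.
padding_mask = tokens.ne(pad_id)          # shape (batch, seq_len)

# By contrast, a left-to-right (causal) pattern over the sequence dimension
# is lower-triangular; Megatron's helper builds its mask from this idea.
seq_len = tokens.size(1)
causal_pattern = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(padding_mask)
print(causal_pattern)
```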
2 changes: 1 addition & 1 deletion examples/mistral/run_evaluate_huggingface_mistral.sh
@@ -74,7 +74,7 @@ fi

megatron_options=" \
--transformer-type huggingface \
--data-path ${DATASET_PATH}
--valid-data-path ${DATASET_PATH}
--micro-batch-size ${BATCH_SIZE} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
10 changes: 5 additions & 5 deletions examples/mistral/run_evaluate_megatron_mistral.sh
@@ -1,18 +1,18 @@
#!/bin/bash
#sh run_evaluate_megatron_mistral.sh dsw /workspace/Pai-Megatron-Patch 7B 1 80 80 0 bf16 2 1 sel true false true false /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp2-pp1/
#sh run_evaluate_megatron_mistral.sh dsw ../.. 7B 1 81 81 0 bf16 2 1 sel true false true false /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp2-pp1/
set -e
ENV=$1
MEGATRON_PATCH_PATH=$2
MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main
export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH
export CUDA_DEVICE_MAX_CONNECTIONS=1
if [ $ENV = dsw ]; then
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=0
MASTER_ADDR=localhost
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=8
GPUS_PER_NODE=1

elif [ $ENV = dlc ]; then

@@ -125,7 +125,7 @@ fi


megatron_options=" \
--train-data-path ${DATASET_PATH}
--valid-data-path ${DATASET_PATH}
--micro-batch-size ${BATCH_SIZE} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
@@ -145,7 +145,7 @@ megatron_options=" \
--max-padding-length ${PAD_LEN} \
--extra-vocab-size ${EXTRA_VOCAB_SIZE} \
--patch-tokenizer-type MistralTokenizer \
--dataset LLama-Pretrain-Raw \
--dataset Mistral-SFT \
--sliding-window ${SLW} \
--swiglu \
--normalization RMSNorm \
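The updated header comment at the top of this script also serves as its usage example; the positional arguments are consumed in order by the variables defined at the top of the file (`ENV`, `MEGATRON_PATCH_PATH`, and so on), and their full meaning is not spelled out in this hunk. Reproduced from the script's own comment:

```bash
# Example invocation copied from the script header (dsw environment).
sh run_evaluate_megatron_mistral.sh dsw ../.. 7B 1 81 81 0 bf16 2 1 sel true false true false \
  /mnt/llama2-datasets/alpaca_data.json \
  /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp2-pp1/
```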
136 changes: 136 additions & 0 deletions examples/mixtral/evaluate_huggingface_mixtral.py
@@ -0,0 +1,136 @@
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron.core.enums import ModelType
from megatron import get_args
from megatron import print_rank_0
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.p2p_communication import send_forward
from megatron.initialize import initialize_megatron
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
from megatron.arguments import core_transformer_config_from_args

from megatron_patch.data import build_evaluation_dataset
from megatron_patch.finetune_utils import build_data_loader
from megatron_patch.tokenizer import get_tokenizer
from megatron_patch.training import get_model
from megatron_patch.arguments import get_tasks_args
from transformers import AutoModelForCausalLM

def get_model_provider():
    """Based on evaluation metric set the parallel-output flag and
    return the model provider."""
    def model_provider(pre_process=True, post_process=True):
        args = get_args()
        """
        from accelerate import load_checkpoint_and_dispatch
        from accelerate import init_empty_weights
        with init_empty_weights():
            config = MixtralConfig()
            model = MixtralForCausalLM(config=config)
        model = load_checkpoint_and_dispatch(model, checkpoint=args.load, device_map=device_map, dtype=torch.bfloat16)
        """
        model = AutoModelForCausalLM.from_pretrained(args.load, torch_dtype=torch.bfloat16, device_map="auto")
        return model

    return model_provider


def forward_step(batch, model):
    """Forward step."""
    tokenizer = get_tokenizer()
    # Get the batch.
    input_ids = batch['input_ids'].long().cuda()
    labels = batch['labels'].long().cuda()
    labels[labels == 0] = -100
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    # Tell the model what our actual batch size will be
    args = get_args()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    output = unwrapped_model(input_ids=input_ids,
                             labels=labels,
                             attention_mask=attention_mask)
    config = core_transformer_config_from_args(args)
    send_forward(output, config)
    if parallel_state.is_pipeline_last_stage():
        print_rank_0(output.loss)
        return output.loss

    return None


def evaluate(data_loader, model):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_output = 0.0
    with torch.no_grad():
        # For all the batches in the dataset.
        for iteration, batch in enumerate(data_loader):
            if iteration % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(iteration))
            # Forward evaluation.
            output = forward_step(batch, model)

            # Reduce across processes.
            if parallel_state.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    output, group=parallel_state.get_data_parallel_group())

                total_output += output

    return total_output

def main():
    """Main program."""
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print('Interleaved pipeline schedule '
              'is not yet supported for text generation.')
        exit()

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(),
                      model_type=ModelType.encoder_or_decoder,
                      wrap_with_ddp=False)

    assert len(model) == 1, 'Above condition should have caught this'
    model = model[0]

    # Data stuff.
    dataset = build_evaluation_dataset(args.dataset)
    dataloader = build_data_loader(dataset,
                                   args.micro_batch_size,
                                   args.num_workers,
                                   drop_last=False)

    # Run evaluation.
    evaluate(dataloader, model)
    print_rank_0('done :-)')


if __name__ == '__main__':
    initialize_megatron(extra_args_provider=get_tasks_args)
    main()
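The model provider in this new file loads the checkpoint with `AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.bfloat16, device_map="auto")`, letting accelerate place the Mixtral weights across the visible GPUs; the commented-out block sketches the lower-level `init_empty_weights` / `load_checkpoint_and_dispatch` route. For readers who want to sanity-check a Mixtral checkpoint outside Megatron, a minimal sketch along the same lines follows; the model path and prompt are illustrative, whereas the committed script gets them from `args.load` and the evaluation dataset:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint path; the committed script takes this from args.load.
model_path = "mistralai/Mixtral-8x7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()

batch = tokenizer("Mixtral evaluation sanity check.", return_tensors="pt")
input_ids = batch["input_ids"].to(model.device)

with torch.no_grad():
    # Passing labels makes the Hugging Face causal-LM head return the mean
    # cross-entropy loss directly, which is what forward_step above relies on.
    output = model(input_ids=input_ids, labels=input_ids)

print(output.loss)
```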