88 changes: 88 additions & 0 deletions examples/data_preprocess/mnist_multiturn_sft.py
@@ -0,0 +1,88 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the MNIST multi-turn SFT dataset to parquet format
"""

import argparse
import io
import os

import datasets
import numpy as np
import pandas as pd
from PIL import Image

from verl.utils.hdfs_io import copy, makedirs

if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default=None)
parser.add_argument("--hdfs_dir", default=None)
parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
    parser.add_argument(
        "--local_save_dir",
        default="~/data/mnist_multiturn_sft",
        help="The save directory for the preprocessed dataset.",
    )

args = parser.parse_args()
local_dataset_path = args.local_dataset_path
hdfs_dir = args.hdfs_dir
    local_save_dir = os.path.expanduser(args.local_save_dir)

data_source = "vermouth1992/mnist_multiturn_sft"

if local_dataset_path is not None:
dataset = datasets.load_dataset(
local_dataset_path,
)
else:
dataset = datasets.load_dataset(
data_source,
)

train_dataset = dataset["train"]
test_dataset = dataset["test"]

# def image_to_bytes(image: Image.Image) -> bytes:
# img_byte_arr = io.BytesIO()
# image.save(img_byte_arr)
# return img_byte_arr.getvalue()

# def process_row(row):
# messages = row['messages']
# messages_new = []
# for message in messages:
# for idx, content in enumerate(message['content']):
# if content['type'] == 'image':
# content = {"image": int(content['image']), "type": "image"}
# if content['type'] == 'text':
# content = {"text": content['text'], "type": "text"}
# message['content'][idx] = content
# messages_new.append(message)
# row['messages'] = messages_new
# return row

    # Print one sample for a quick sanity check of the multi-turn message structure.
    print(train_dataset[0])

    os.makedirs(local_save_dir, exist_ok=True)
    train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet"))
    test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet"))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_save_dir, dst=hdfs_dir)


# df_loaded = pd.read_parquet(os.path.join(local_save_dir, "test.parquet"))
# messages = df_loaded['messages'][0]
# # print("is numpy.ndarray?:", isinstance(messages, np.ndarray))
# print(messages)
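
For a quick check of the output, the commented-out verification at the end of the script can be run standalone. A minimal sketch, assuming the default --local_save_dir and the "messages" column written above; pandas is the only dependency:

import os

import pandas as pd

# Default output location used by the preprocessing script above.
save_dir = os.path.expanduser("~/data/mnist_multiturn_sft")

# Load the generated parquet and inspect the first multi-turn conversation.
df = pd.read_parquet(os.path.join(save_dir, "test.parquet"))
messages = df["messages"][0]

# Each message carries a role plus a list of text/image content parts.
print(type(messages))
print(messages)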
102 changes: 102 additions & 0 deletions tests/special_e2e/sft/run_sft_engine_mnist.sh
@@ -0,0 +1,102 @@
#!/usr/bin/env bash
set -xeuo pipefail

ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}

NUM_GPUS=${NUM_GPUS:-1}

TRAIN_FILES=~/data/mnist_multiturn_sft/train.parquet
VAL_FILES=~/data/mnist_multiturn_sft/test.parquet

backend=${BACKEND:-fsdp}

project_name=verl_vlm_sft_test

RESUME_MODE=disable

ckpts_home=${ckpts_home:-~/verl/test/mnist-sft-${backend}}

MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-VL-3B-Instruct}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

SP_SIZE=${SP_SIZE:-1}
FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}}
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"}

TP_SIZE=${TP_SIZE:-1}
PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}

FSDP_ENGINE_CONFIG="\
engine=${backend} \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_ratio=0.1 \
optim.warmup_style=cosine \
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
engine.strategy=${FSDP_STRATEGY} \
engine.fsdp_size=${FSDP_SIZE}"


MEGATRON_ENGINE_CONFIG="\
engine=${backend} \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.lr_warmup_init=0 \
optim.lr_decay_style=cosine \
optim.min_lr=1e-6 \
engine.tensor_model_parallel_size=${TP_SIZE} \
engine.pipeline_model_parallel_size=${PP_SIZE} \
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
engine.context_parallel_size=${CP_SIZE}"

if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
echo "Using fsdp engine"
    exp_name=mnist-sft-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}
else
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
    exp_name=mnist-sft-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}
fi

mkdir -p "${ckpts_home}"

torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
data.train_files="${TRAIN_FILES}" \
data.val_files="${VAL_FILES}" \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=128 \
data.pad_mode=left_right \
data.truncation=error \
data.use_dynamic_bsz=True \
data.max_token_len_per_gpu=8192 \
data.messages_key=messages \
model.path=$MODEL_PATH \
${ENGINE_CONFIG} \
trainer.test_freq=after_each_epoch \
trainer.save_freq=-1 \
trainer.logger=['console','file'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.total_epochs=2 \
trainer.total_training_steps=2 \
trainer.default_local_dir="${ckpts_home}" \
    trainer.resume_mode=${RESUME_MODE}

# trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
# trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
# trainer.max_ckpt_to_keep=1 \

rm -rf "${ckpts_home:?}"/*
29 changes: 22 additions & 7 deletions verl/trainer/sft_trainer.py
@@ -32,7 +32,7 @@

from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint import CheckpointHandler
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset, collate_fn
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
from verl.utils.distributed import destroy_global_process_group
from verl.utils.flops_counter import FlopsCounter
@@ -144,8 +144,10 @@ def _init_engine(self):
def _build_dataset(self):
config = self.config
tokenizer = self.model_config.tokenizer
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
processor = self.model_config.processor
print(f"processor: {processor}")
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer, processor)

self.train_dataset, self.val_dataset = train_dataset, val_dataset

@@ -176,6 +178,7 @@ def _build_dataloader(self):
pin_memory=True,
drop_last=True,
pin_memory_device=device_name,
collate_fn=collate_fn,
)

self.val_sampler = DistributedSampler(
@@ -189,6 +192,7 @@
pin_memory=True,
drop_last=True,
pin_memory_device=device_name,
collate_fn=collate_fn,
)

def fit(self):
@@ -232,6 +236,7 @@ def fit(self):
"micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu,
"temperature": 1.0,
"global_batch_size": self.global_batch_size,
# "use_remove_padding": True, # TODO(caiyunke.astra): check this branch
}

train_time = 0
@@ -249,8 +254,11 @@
):
global_step += 1

tensor_data = {k: v for k, v in data.items() if isinstance(v, torch.Tensor)}
non_tensor_data = {k: v for k, v in data.items() if not isinstance(v, torch.Tensor)}
non_tensor_data.update(meta_info)
# construct tensordict
data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)
data = tu.get_tensordict(tensor_dict=tensor_data, non_tensor_dict=non_tensor_data)

with self.engine.train_mode():
with Timer(name="update_policy", logger=None) as timer:
@@ -308,8 +316,11 @@ def fit(self):
val_losses = []
for val_data in self.val_dataloader:
with self.engine.eval_mode():
tensor_data = {k: v for k, v in val_data.items() if isinstance(v, torch.Tensor)}
non_tensor_data = {k: v for k, v in val_data.items() if not isinstance(v, torch.Tensor)}
non_tensor_data.update(meta_info)
# construct tensordict
val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)
val_data = tu.get_tensordict(tensor_dict=tensor_data, non_tensor_dict=non_tensor_data)
output = self.engine.infer_batch(data=val_data, loss_function=self.loss_fn)
if self.engine.is_mp_src_rank_with_outputs():
val_losses.extend(output["metrics"]["loss"])
@@ -351,7 +362,7 @@ def main(config):
run_sft(config)


def create_sft_dataset(data_paths, data_config, tokenizer):
def create_sft_dataset(data_paths, data_config, tokenizer, processor):
"""Create a dataset."""
# build dataset
# First check if a custom dataset class is specified
@@ -364,9 +375,13 @@ def create_sft_dataset(data_paths, data_config, tokenizer):
dataset_cls = MultiTurnSFTDataset

# Create datasets based on the selected class
dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, processor=processor)
return dataset


if __name__ == "__main__":
# import debugpy
# debugpy.listen(("localhost", 5678))
# print("Waiting for debugger attach...")
# debugpy.wait_for_client()
main()
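
To make the intent of the new batch handling in fit() and validation explicit: with the processor-backed dataset and collate_fn in the loop, a collated batch may contain entries that are not torch.Tensors, so the batch is now partitioned before tu.get_tensordict is called. A minimal, self-contained sketch of that split; the toy batch keys below are illustrative, not verl's actual schema:

from typing import Any

import torch


def split_batch(batch: dict[str, Any], meta_info: dict[str, Any]) -> tuple[dict, dict]:
    """Split a collated batch into its tensor and non-tensor parts.

    Mirrors the change in the trainer: only torch.Tensor values go on the
    tensor side of the tensordict; everything else (e.g. multimodal inputs
    produced by the processor) is carried as non-tensor data together with
    the per-step meta info.
    """
    tensor_data = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
    non_tensor_data = {k: v for k, v in batch.items() if not isinstance(v, torch.Tensor)}
    non_tensor_data.update(meta_info)
    return tensor_data, non_tensor_data


# Toy usage with a hypothetical batch.
batch = {
    "input_ids": torch.zeros(2, 8, dtype=torch.long),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
    "multi_modal_inputs": [{}, {}],
}
tensors, non_tensors = split_batch(batch, {"temperature": 1.0})
print(sorted(tensors), sorted(non_tensors))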