[Feature] add dpo v2 example
wheresmyhair committed Jun 27, 2024
1 parent 35347d4 commit 1715212
Showing 8 changed files with 204 additions and 54 deletions.
23 changes: 23 additions & 0 deletions configs/accelerate_dsz3_config.yaml
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
gpu_ids: 4,5,6,7
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
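This Accelerate config enables DeepSpeed ZeRO stage 3 with bf16 mixed precision on four local GPUs (gpu_ids 4-7, matching num_processes: 4) and no optimizer or parameter offloading; it is consumed by the new scripts/run_dpov2_align.sh below through accelerate launch --config_file configs/accelerate_dsz3_config.yaml.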
78 changes: 78 additions & 0 deletions examples/dpov2_train.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
import os
import sys
import copy

from transformers import (
    HfArgumentParser
)

from lmflow.datasets import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.args import (
    ModelArguments,
    DatasetArguments,
    AutoArguments,
)
from lmflow.utils.common import remove_dataclass_attr_prefix, create_copied_dataclass


logger = logging.getLogger(__name__)


ReferenceModelArguments = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reference_",
    class_prefix="Reference"
)


def main():
    # Parses arguments
    pipeline_name = "dpov2_aligner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        ReferenceModelArguments,
        DatasetArguments,
        PipelineArguments
    ))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, ref_model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, ref_model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    ref_model_args_dict = remove_dataclass_attr_prefix(ref_model_args, "reference_")
    ref_model_args = ModelArguments(**ref_model_args_dict)

    train_dataset = Dataset(data_args)
    eval_data_args = copy.deepcopy(data_args)
    eval_data_args.dataset_path = pipeline_args.eval_dataset_path
    eval_dataset = Dataset(eval_data_args)
    model = AutoModel.get_model(model_args)
    ref_model = AutoModel.get_model(ref_model_args)
    aligner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
        ref_model_args=ref_model_args,
    )

    res = aligner.align(
        model=model,
        ref_model=ref_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )


if __name__ == "__main__":
main()
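The script above configures the policy model and the reference model from a single command line (or one JSON file) by cloning ModelArguments under a reference_ prefix and stripping that prefix again before building the reference ModelArguments. The following minimal, self-contained sketch illustrates the round trip; the names ModelArgs, ReferenceModelArgs, and strip_prefix are invented for illustration, and the real helpers (create_copied_dataclass and remove_dataclass_attr_prefix in lmflow.utils.common) may differ in detail.

# Illustrative sketch only; not part of this commit or of LMFlow.
from dataclasses import asdict, dataclass, field, fields, make_dataclass

@dataclass
class ModelArgs:
    model_name_or_path: str = "gpt2"
    trust_remote_code: bool = False

# Copy the dataclass with every field renamed under a "reference_" prefix so both
# argument groups can share one HfArgumentParser without name clashes.
ReferenceModelArgs = make_dataclass(
    "ReferenceModelArgs",
    [(f"reference_{f.name}", f.type, field(default=f.default)) for f in fields(ModelArgs)],
)

def strip_prefix(args, prefix: str) -> dict:
    """Drop the prefix so the collected values can rebuild a plain ModelArgs."""
    return {k[len(prefix):]: v for k, v in asdict(args).items() if k.startswith(prefix)}

ref = ReferenceModelArgs(reference_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct")
print(ModelArgs(**strip_prefix(ref, "reference_")))
# ModelArgs(model_name_or_path='meta-llama/Meta-Llama-3-8B-Instruct', trust_remote_code=False)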
2 changes: 1 addition & 1 deletion scripts/run_dpo_align.sh
@@ -20,7 +20,7 @@ while [[ $# -ge 1 ]]; do
      dataset_path="$2"
      shift
      ;;
    -o|--output_lora_path)
    -o|--output_dir)
      output_dir="$2"
      shift
      ;;
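This renames the existing option so that it matches the output_dir variable it sets; the -o short form is unchanged.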
86 changes: 86 additions & 0 deletions scripts/run_dpov2_align.sh
@@ -0,0 +1,86 @@
#!/bin/bash

# Parses arguments
run_name=dpov2_align
model_name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
reference_model_name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
dataset_path=data/iterative-prompt/train
eval_dataset_path=data/iterative-prompt/eval
output_dir=output_models/${run_name}
deepspeed_args="--master_port=11000 --include localhost:4,5,6,7"

while [[ $# -ge 1 ]]; do
  key="$1"
  case ${key} in
    -r|--run_name)
      run_name="$2"
      shift
      ;;
    --model_name_or_path)
      model_name_or_path="$2"
      shift
      ;;
    --reference_model_name_or_path)
      reference_model_name_or_path="$2"
      shift
      ;;
    --dataset_path)
      dataset_path="$2"
      shift
      ;;
    --eval_dataset_path)
      eval_dataset_path="$2"
      shift
      ;;
    -o|--output_dir)
      output_dir="$2"
      shift
      ;;
    --deepspeed_args)
      deepspeed_args="$2"
      shift
      ;;
    *)
      echo "error: unknown option \"${key}\"" 1>&2
      exit 1
  esac
  shift
done

project_dir=$(cd "$(dirname $0)"/..; pwd)
log_dir=${project_dir}/log/${run_name}
mkdir -p ${output_dir} ${log_dir}

accelerate launch --config_file configs/accelerate_dsz3_config.yaml \
  examples/dpov2_train.py \
    --model_name_or_path ${model_name_or_path} \
    --reference_model_name_or_path ${reference_model_name_or_path} \
    --do_train True \
    --dataset_path ${dataset_path} \
    --eval_dataset_path ${eval_dataset_path} \
    --bf16 True \
    --learning_rate 5e-7 \
    --lr_scheduler_type cosine \
    --warmup_steps 100 \
    --optim paged_adamw_32bit \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --gradient_checkpointing True \
    --margin_scale 1.0 \
    --max_prompt_length 1000 \
    --num_train_epochs 2 \
    --logging_steps 2 \
    --save_strategy epoch \
    --save_steps 5000 \
    --evaluation_strategy steps \
    --eval_steps 100 \
    --loss_type sigmoid \
    --output_dir ${output_dir} \
    --run_name ${run_name} \
    --sampling_paired_method max_min \
    --report_to wandb \
    --mask_prompt True \
    --length_penalty 0 \
    | tee ${log_dir}/train.log \
    2> ${log_dir}/train.err
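All of the flags passed to examples/dpov2_train.py here are consumed by its HfArgumentParser: the model and reference-model paths, the dataset paths, and the remaining training options covered by DPOv2AlignerArguments and its FinetunerArguments base. The same run could therefore also be driven by an equivalent JSON config passed as the script's single argument. Note that the --deepspeed_args option is parsed by this wrapper but not forwarded to the accelerate launch command, which instead relies on the ZeRO-3 config file above.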
62 changes: 10 additions & 52 deletions src/lmflow/args.py
@@ -1289,65 +1289,22 @@ class DPOAlignerArguments:


@dataclass
class DPOv2AlignerArguments(TrainingArguments):
class DPOv2AlignerArguments(FinetunerArguments):
    """
    The arguments for the DPOv2 training script.
    """

    # data parameters, i.e., the KL penalty in the paper
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    eval_dir: Optional[str] = field(
        default="/export/home/hanze/project/vllm-gen/uf_split0_offline_reward.json",  # "/export/home/data/gemma_it_2b_3w_k8_with_pairrm_rewards.json",
        metadata={"help": "the location of the evalset name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-7, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(
        default="constant_with_warmup", metadata={"help": "the lr scheduler type"}
    )
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.01, metadata={"help": "the weight decay"})

    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=16, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    # pair sampling args
    margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"})

    max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length"})
    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "max number of training epochs"})
    logging_steps: Optional[int] = field(default=2, metadata={"help": "the logging frequency"})
    save_strategy: Optional[str] = field(default="epoch", metadata={"help": "the saving strategy"})
    save_steps: Optional[int] = field(default=50000, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
    run_name: Optional[str] = field(default="dpo_soft", metadata={"help": "the run name"})
    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"})
    output_dir: Optional[str] = field(default="./dpo_soft", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # instrumentation
    sampling_paired_method: Optional[str] = field(default="max_random", metadata={"help": "the choose type"})

    mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"})
    length_penalty: Optional[float] = field(default=0, metadata={"help": "the length penalty"})
    # data collator args
    max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length, prompt + output"})
    max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"})
    mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"})
    # dpov2 aligner args
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"})

    # need to add
    evaluation_strategy: Optional[str] = field(
        default="steps",
        metadata={"help": "the evaluation strategy"}
    )

@dataclass
class IterativeAlignerArguments(InferencerArguments):
@@ -1366,6 +1323,7 @@ class IterativeAlignerArguments(InferencerArguments):
    "raft_aligner": RaftAlignerArguments,
    "dpo_aligner": DPOAlignerArguments,
    "rm_tuner": RewardModelTunerArguments,
    "dpov2_aligner": DPOv2AlignerArguments,
}


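With the registry entry above, the lookup in examples/dpov2_train.py resolves to the new argument class. A small sketch of that wiring, assuming get_pipeline_args_class reads names from this mapping:

from lmflow.args import AutoArguments, DPOv2AlignerArguments

PipelineArguments = AutoArguments.get_pipeline_args_class("dpov2_aligner")
assert PipelineArguments is DPOv2AlignerArguments
# Generic training options (learning_rate, per_device_train_batch_size, save_strategy, ...)
# now come from the FinetunerArguments base class instead of being redeclared on
# DPOv2AlignerArguments itself.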
2 changes: 2 additions & 0 deletions src/lmflow/pipeline/auto_pipeline.py
@@ -19,6 +19,7 @@ def is_package_version_at_least(package_name, min_version):
from lmflow.pipeline.inferencer import Inferencer
from lmflow.pipeline.vllm_inferencer import VLLMInferencer
from lmflow.pipeline.dpo_aligner import DPOAligner
from lmflow.pipeline.dpov2_aligner import DPOv2Aligner
from lmflow.pipeline.rm_tuner import RewardModelTuner
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
PIPELINE_MAPPING = {
@@ -28,6 +29,7 @@ def is_package_version_at_least(package_name, min_version):
    "vllm_inferencer": VLLMInferencer,
    "rm_inferencer": RewardModelInferencer,
    "dpo_aligner": DPOAligner,
    "dpov2_aligner": DPOv2Aligner,
    "rm_tuner": RewardModelTuner,
}

3 changes: 2 additions & 1 deletion src/lmflow/pipeline/dpov2_aligner.py
@@ -25,9 +25,9 @@ class DPOv2Aligner(BaseAligner):
    def __init__(
        self,
        model_args: ModelArguments,
        ref_model_args: ModelArguments,
        data_args: DatasetArguments,
        aligner_args: DPOv2AlignerArguments,
        ref_model_args: ModelArguments,
    ):
        self.model_args = model_args
        self.ref_model_args = ref_model_args
@@ -91,6 +91,7 @@ def align(
            max_length=self.aligner_args.max_length,
            mask_prompt=self.aligner_args.mask_prompt,
            len_penalty=self.aligner_args.length_penalty,
            # preprocessing_num_workers=self.data_args.preprocessing_num_workers,  # will trigger TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object
        )

        # step 3. train
2 changes: 2 additions & 0 deletions src/lmflow/pipeline/utils/dpov2_trainer.py
@@ -54,6 +54,7 @@ def __init__(
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        mask_prompt: Optional[bool] = False,
        len_penalty: float = 0,
        preprocessing_num_workers: int = 1,
    ):

        if data_collator is None:
@@ -93,6 +94,7 @@ def __init__(
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            compute_metrics=compute_metrics,
            dataset_num_proc=preprocessing_num_workers,
        )
        self.use_dpo_data_collator = True
        self.len_penalty = len_penalty
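The new preprocessing_num_workers parameter defaults to 1 and is forwarded to the parent trainer as dataset_num_proc. The call site in dpov2_aligner.py keeps the corresponding data_args.preprocessing_num_workers pass-through commented out for now because, as the inline note there records, it triggers a TypeError about pickling a torch distributed ProcessGroup.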
