diff --git a/configs/accelerate_dsz3_config.yaml b/configs/accelerate_dsz3_config.yaml
new file mode 100644
index 000000000..7f2cf9600
--- /dev/null
+++ b/configs/accelerate_dsz3_config.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+gpu_ids: 4,5,6,7
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/dpov2_train.py b/examples/dpov2_train.py
new file mode 100644
index 000000000..ae4ca22d6
--- /dev/null
+++ b/examples/dpov2_train.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
+import logging
+import os
+import sys
+import copy
+
+from transformers import (
+    HfArgumentParser
+)
+
+from lmflow.datasets import Dataset
+from lmflow.models.auto_model import AutoModel
+from lmflow.pipeline.auto_pipeline import AutoPipeline
+from lmflow.args import (
+    ModelArguments,
+    DatasetArguments,
+    AutoArguments,
+)
+from lmflow.utils.common import remove_dataclass_attr_prefix, create_copied_dataclass
+
+
+logger = logging.getLogger(__name__)
+
+
+ReferenceModelArguments = create_copied_dataclass(
+    original_dataclass=ModelArguments,
+    field_prefix="reference_",
+    class_prefix="Reference"
+)
+
+
+def main():
+    # Parses arguments
+    pipeline_name = "dpov2_aligner"
+    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+    parser = HfArgumentParser((
+        ModelArguments,
+        ReferenceModelArguments,
+        DatasetArguments,
+        PipelineArguments
+    ))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, ref_model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, ref_model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()
+
+    ref_model_args_dict = remove_dataclass_attr_prefix(ref_model_args, "reference_")
+    ref_model_args = ModelArguments(**ref_model_args_dict)
+
+    train_dataset = Dataset(data_args)
+    eval_data_args = copy.deepcopy(data_args)
+    eval_data_args.dataset_path = pipeline_args.eval_dataset_path
+    eval_dataset = Dataset(eval_data_args)
+    model = AutoModel.get_model(model_args)
+    ref_model = AutoModel.get_model(ref_model_args)
+    aligner = AutoPipeline.get_pipeline(
+        pipeline_name=pipeline_name,
+        model_args=model_args,
+        data_args=data_args,
+        pipeline_args=pipeline_args,
+        ref_model_args=ref_model_args,
+    )
+
+    res = aligner.align(
+        model=model,
+        ref_model=ref_model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/run_dpo_align.sh b/scripts/run_dpo_align.sh
index d9c54f4fc..7d2ee00be 100644
--- a/scripts/run_dpo_align.sh
+++ b/scripts/run_dpo_align.sh
@@ -20,7 +20,7 @@ while [[ $# -ge 1 ]]; do
       dataset_path="$2"
       shift
       ;;
-    -o|--output_lora_path)
+    -o|--output_dir)
       output_dir="$2"
       shift
       ;;
diff --git a/scripts/run_dpov2_align.sh b/scripts/run_dpov2_align.sh
new file mode 100644
index 000000000..9a2b0faae
--- /dev/null
+++ b/scripts/run_dpov2_align.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+# Parses arguments
+run_name=dpov2_align
+model_name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
+reference_model_name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
+dataset_path=data/iterative-prompt/train
+eval_dataset_path=data/iterative-prompt/eval
+output_dir=output_models/${run_name}
+deepspeed_args="--master_port=11000 --include localhost:4,5,6,7"
+
+while [[ $# -ge 1 ]]; do
+  key="$1"
+  case ${key} in
+    -r|--run_name)
+      run_name="$2"
+      shift
+      ;;
+    --model_name_or_path)
+      model_name_or_path="$2"
+      shift
+      ;;
+    --reference_model_name_or_path)
+      reference_model_name_or_path="$2"
+      shift
+      ;;
+    --dataset_path)
+      dataset_path="$2"
+      shift
+      ;;
+    --eval_dataset_path)
+      eval_dataset_path="$2"
+      shift
+      ;;
+    -o|--output_dir)
+      output_dir="$2"
+      shift
+      ;;
+    --deepspeed_args)
+      deepspeed_args="$2"
+      shift
+      ;;
+    *)
+      echo "error: unknown option \"${key}\"" 1>&2
+      exit 1
+  esac
+  shift
+done
+
+project_dir=$(cd "$(dirname $0)"/..; pwd)
+log_dir=${project_dir}/log/${run_name}
+mkdir -p ${output_dir} ${log_dir}
+
+accelerate launch --config_file configs/accelerate_dsz3_config.yaml \
+  examples/dpov2_train.py \
+    --model_name_or_path ${model_name_or_path} \
+    --reference_model_name_or_path ${reference_model_name_or_path} \
+    --do_train True \
+    --dataset_path ${dataset_path} \
+    --eval_dataset_path ${eval_dataset_path} \
+    --bf16 True \
+    --learning_rate 5e-7 \
+    --lr_scheduler_type cosine \
+    --warmup_steps 100 \
+    --optim paged_adamw_32bit \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 16 \
+    --gradient_checkpointing True \
+    --margin_scale 1.0 \
+    --max_prompt_length 1000 \
+    --num_train_epochs 2 \
+    --logging_steps 2 \
+    --save_strategy epoch \
+    --save_steps 5000 \
+    --evaluation_strategy steps \
+    --eval_steps 100 \
+    --loss_type sigmoid \
+    --output_dir ${output_dir} \
+    --run_name ${run_name} \
+    --sampling_paired_method max_min \
+    --report_to wandb \
+    --mask_prompt True \
+    --length_penalty 0 \
+    | tee ${log_dir}/train.log \
+    2> ${log_dir}/train.err
\ No newline at end of file
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index 6a1aa438e..786f27a8a 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -1289,65 +1289,22 @@ class DPOAlignerArguments:
 
 
 @dataclass
-class DPOv2AlignerArguments(TrainingArguments):
+class DPOv2AlignerArguments(FinetunerArguments):
     """
     The arguments for the DPOv2 training script.
     """
-    
-    # data parameters, i.e., the KL penalty in the paper
-    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
-
-    # training parameters
-    eval_dir: Optional[str] = field(
-        default="/export/home/hanze/project/vllm-gen/uf_split0_offline_reward.json", # "/export/home/data/gemma_it_2b_3w_k8_with_pairrm_rewards.json",
-        metadata={"help": "the location of the evalset name or path"},
-    )
-    learning_rate: Optional[float] = field(default=5e-7, metadata={"help": "optimizer learning rate"})
-    lr_scheduler_type: Optional[str] = field(
-        default="constant_with_warmup", metadata={"help": "the lr scheduler type"}
-    )
-    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
-    weight_decay: Optional[float] = field(default=0.01, metadata={"help": "the weight decay"})
-
-    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
-    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
-    gradient_accumulation_steps: Optional[int] = field(
-        default=16, metadata={"help": "the number of gradient accumulation steps"}
-    )
-    gradient_checkpointing: Optional[bool] = field(
-        default=True, metadata={"help": "whether to use gradient checkpointing"}
-    )
-
-
-    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
-    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
-    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
-
+    # pair sampling args
     margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"})
-
-    max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"})
-    max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length"})
-    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "max number of training epochs"})
-    logging_steps: Optional[int] = field(default=2, metadata={"help": "the logging frequency"})
-    save_strategy: Optional[str] = field(default="epoch", metadata={"help": "the saving strategy"})
-    save_steps: Optional[int] = field(default=50000, metadata={"help": "the saving frequency"})
-    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
-    run_name: Optional[str] = field(default="dpo_soft", metadata={"help": "the run name"})
-    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"})
-    output_dir: Optional[str] = field(default="./dpo_soft", metadata={"help": "the output directory"})
-    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
-
-    # instrumentation
     sampling_paired_method: Optional[str] = field(default="max_random", metadata={"help": "the choose type"})
-
-    mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"})
     length_penalty: Optional[float] = field(default=0, metadata={"help": "the length penalty"})
+    # data collator args
+    max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length, prompt + output"})
+    max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"})
+    mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"})
+    # dpov2 aligner args
+    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
+    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"})
 
-    # need to add
-    evaluation_strategy: Optional[str] = field(
-        default="steps",
-        metadata={"help": "the evaluation strategy"}
-    )
 
 @dataclass
 class IterativeAlignerArguments(InferencerArguments):
@@ -1366,6 +1323,7 @@ class IterativeAlignerArguments(InferencerArguments):
     "raft_aligner": RaftAlignerArguments,
     "dpo_aligner": DPOAlignerArguments,
     "rm_tuner": RewardModelTunerArguments,
+    "dpov2_aligner": DPOv2AlignerArguments,
 }
 
 
diff --git a/src/lmflow/pipeline/auto_pipeline.py b/src/lmflow/pipeline/auto_pipeline.py
index a5e815636..82212802f 100644
--- a/src/lmflow/pipeline/auto_pipeline.py
+++ b/src/lmflow/pipeline/auto_pipeline.py
@@ -19,6 +19,7 @@ def is_package_version_at_least(package_name, min_version):
 from lmflow.pipeline.inferencer import Inferencer
 from lmflow.pipeline.vllm_inferencer import VLLMInferencer
 from lmflow.pipeline.dpo_aligner import DPOAligner
+from lmflow.pipeline.dpov2_aligner import DPOv2Aligner
 from lmflow.pipeline.rm_tuner import RewardModelTuner
 from lmflow.pipeline.rm_inferencer import RewardModelInferencer
 PIPELINE_MAPPING = {
@@ -28,6 +29,7 @@ def is_package_version_at_least(package_name, min_version):
     "vllm_inferencer": VLLMInferencer,
     "rm_inferencer": RewardModelInferencer,
     "dpo_aligner": DPOAligner,
+    "dpov2_aligner": DPOv2Aligner,
     "rm_tuner": RewardModelTuner,
 }
 
diff --git a/src/lmflow/pipeline/dpov2_aligner.py b/src/lmflow/pipeline/dpov2_aligner.py
index 151c1794e..81770da16 100644
--- a/src/lmflow/pipeline/dpov2_aligner.py
+++ b/src/lmflow/pipeline/dpov2_aligner.py
@@ -25,9 +25,9 @@ class DPOv2Aligner(BaseAligner):
     def __init__(
         self,
         model_args: ModelArguments,
-        ref_model_args: ModelArguments,
         data_args: DatasetArguments,
         aligner_args: DPOv2AlignerArguments,
+        ref_model_args: ModelArguments,
     ):
         self.model_args = model_args
         self.ref_model_args = ref_model_args
@@ -91,6 +91,7 @@ def align(
             max_length=self.aligner_args.max_length,
             mask_prompt=self.aligner_args.mask_prompt,
             len_penalty=self.aligner_args.length_penalty,
+            # preprocessing_num_workers=self.data_args.preprocessing_num_workers, # will trigger TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object
         )
 
         # step 3. train
diff --git a/src/lmflow/pipeline/utils/dpov2_trainer.py b/src/lmflow/pipeline/utils/dpov2_trainer.py
index 032c6e92e..735daf635 100644
--- a/src/lmflow/pipeline/utils/dpov2_trainer.py
+++ b/src/lmflow/pipeline/utils/dpov2_trainer.py
@@ -54,6 +54,7 @@ def __init__(
         compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
         mask_prompt: Optional[bool] = False,
         len_penalty: float = 0,
+        preprocessing_num_workers: int = 1,
     ):
 
         if data_collator is None:
@@ -93,6 +94,7 @@ def __init__(
             disable_dropout=disable_dropout,
             generate_during_eval=generate_during_eval,
             compute_metrics=compute_metrics,
+            dataset_num_proc=preprocessing_num_workers,
         )
         self.use_dpo_data_collator = True
         self.len_penalty = len_penalty
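Usage sketch (not part of the diff itself): the command below shows how the new launcher could be invoked. It only uses options handled by the argument parser in scripts/run_dpov2_align.sh, and the model names, dataset paths, and output directory are the script's own defaults, shown here as placeholders for local paths.

    bash scripts/run_dpov2_align.sh \
        --run_name dpov2_align \
        --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
        --reference_model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
        --dataset_path data/iterative-prompt/train \
        --eval_dataset_path data/iterative-prompt/eval \
        --output_dir output_models/dpov2_align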