[feature] iterative dpo
wheresmyhair committed Jul 19, 2024
1 parent 11816e1 commit 3d99d8a
Showing 12 changed files with 697 additions and 26 deletions.
21 changes: 21 additions & 0 deletions configs/accelerate_dsz0_config.yaml
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 16
  zero3_init_flag: false
  zero_stage: 0
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
gpu_ids:
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 12580
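A quick sanity check on the numbers above, combined with the DPO settings in configs/iterative_dpo.yaml below: with num_processes: 8, gradient_accumulation_steps: 16, and per_device_train_batch_size: 1, the effective global batch per optimizer update is roughly 1 × 16 × 8 = 128 samples. This file is the one referenced by accelerate_config_file in the iterative DPO config and, per the new argument's help text in src/lmflow/args.py, is only used for the memory-safe DPOv2 align step.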
81 changes: 81 additions & 0 deletions configs/iterative_dpo.yaml
@@ -0,0 +1,81 @@
# general
## model
model_name_or_path: "/home/yizhenjia/.cache/huggingface/hub/models--tensoropera--Fox-1-1.6B-Instruct-v0.1/snapshots/208e4eb147473dce763b889d5974aba38c1f03d3" # initial model
reference_model_name_or_path: "/home/yizhenjia/.cache/huggingface/hub/models--tensoropera--Fox-1-1.6B-Instruct-v0.1/snapshots/208e4eb147473dce763b889d5974aba38c1f03d3"
reward_model_name_or_path: /home/yizhenjia/models/sfairXC-FsfairX-LLaMA3-RM-v0.1
reward_arch_type: text_regression
trust_remote_code: True

## data
dataset_path_list:
- "data/iterative-prompt-3it-100/iter1"
- "data/iterative-prompt-3it-100/iter2"
- "data/iterative-prompt-3it-100/iter3"
conversation_template: chatml
preprocessing_num_workers: 16

## pipeline
output_dir: ./output_models/iterative_dpo_pipelinetest
run_name: iterative_dpo
random_seed: 42
use_accelerator: True
enable_distributed_inference: True
distributed_inference_num_instances: 8


# inference phase
## general
apply_chat_template: True
num_output_sequences: 3
use_beam_search: False
temperature: 1.0
top_p: 0.9
max_new_tokens: 4096
enable_decode_inference_result: True

## vllm
use_vllm: True
vllm_gpu_memory_utilization: 0.95
vllm_tensor_parallel_size: 1
vllm_inference_batch_size: 16


# reward model scoring phase
reward_model_inference_block_size: 2048
overwrite_cache: True
reward_model_inference_batch_size: 8 # the actual batch size for rm forward will be reward_model_inference_batch_size * num_output_sequences


# dpo phase
## model
do_train: True

## data
sampling_paired_method: max_min
margin_scale: 1.0
length_penalty: 0
max_prompt_length: 1000
mask_prompt: True

## pipeline
### training
accelerate_config_file: configs/accelerate_dsz0_config.yaml
bf16: True
num_train_epochs: 2
learning_rate: 2.0e-5
warmup_steps: 100
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
gradient_accumulation_steps: 16
gradient_checkpointing: True
loss_type: sigmoid
lr_scheduler_type: cosine
optim: paged_adamw_32bit

### logging
logging_steps: 2
save_strategy: steps
save_steps: 100
evaluation_strategy: steps
eval_steps: 100
report_to: wandb
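Two of the settings above deserve a note. The comment on reward_model_inference_batch_size means the reward model actually scores 8 × 3 = 24 sequences per forward pass, since each prompt contributes num_output_sequences candidates. And sampling_paired_method: max_min indicates that, for every prompt, the highest-scored candidate becomes the chosen response and the lowest-scored one the rejected response for DPO. The sketch below only illustrates that max/min pairing rule; the function name and data layout are assumptions for illustration, not LMFlow's implementation (the real pipeline also exposes margin_scale and length_penalty as pair-sampling knobs).

# Illustrative sketch of "max_min" pair sampling (not LMFlow's code): for one
# prompt with several sampled responses and their reward scores, keep the
# best-scored response as "chosen" and the worst-scored one as "rejected".
from typing import List, Tuple


def sample_pair_max_min(responses: List[str], rewards: List[float]) -> Tuple[str, str]:
    best_idx = max(range(len(rewards)), key=lambda i: rewards[i])
    worst_idx = min(range(len(rewards)), key=lambda i: rewards[i])
    return responses[best_idx], responses[worst_idx]


if __name__ == "__main__":
    responses = ["answer A", "answer B", "answer C"]  # num_output_sequences = 3
    rewards = [0.7, -0.2, 1.3]                        # reward-model scores
    chosen, rejected = sample_pair_max_min(responses, rewards)
    print(chosen, "|", rejected)                      # -> answer C | answer B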
83 changes: 83 additions & 0 deletions examples/iterative_dpo_train.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
import os
import sys
import copy

from transformers import (
    HfArgumentParser
)

from lmflow.datasets import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.args import (
    ModelArguments,
    DatasetArguments,
    AutoArguments,
)
from lmflow.utils.common import remove_dataclass_attr_prefix, create_copied_dataclass


logger = logging.getLogger(__name__)


# Prefixed copies of ModelArguments, so the reference and reward models can be
# configured alongside the policy model (e.g. `reference_model_name_or_path`).
ReferenceModelArguments = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reference_",
    class_prefix="Reference"
)

RewardModelArguments = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reward_",
    class_prefix="Reward"
)


def main():
    pipeline_name = "iterative_dpo_aligner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        ReferenceModelArguments,
        RewardModelArguments,
        DatasetArguments,
        PipelineArguments
    ))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, ref_model_args, reward_model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    # Strip the prefixes back off so both copies become plain ModelArguments.
    ref_model_args_dict = remove_dataclass_attr_prefix(ref_model_args, "reference_")
    ref_model_args = ModelArguments(**ref_model_args_dict)
    reward_model_args_dict = remove_dataclass_attr_prefix(reward_model_args, "reward_")
    reward_model_args = ModelArguments(**reward_model_args_dict)

    # Build one Dataset per iteration from the paths in `dataset_path_list`.
    dataset_list = []
    for dataset in pipeline_args.dataset_path_list:
        iter_data_args = copy.deepcopy(data_args)
        iter_data_args.dataset_path = dataset
        dataset_list.append(Dataset(iter_data_args))

    aligner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
        ref_model_args=ref_model_args,
        reward_model_args=reward_model_args,
    )

    aligner.align(dataset_list=dataset_list)


if __name__ == "__main__":
    main()
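The Reference/Reward dataclass copies above exist because HfArgumentParser cannot hold three sets of identically named ModelArguments fields at once; the prefix keeps them distinct while parsing, and remove_dataclass_attr_prefix turns them back into plain ModelArguments afterwards. The sketch below shows the general idea with stand-in helpers; it is an assumption about the behaviour, not the actual code in lmflow.utils.common.

# Stand-in sketch of the prefix-copy pattern (assumed behaviour, not LMFlow's code).
from dataclasses import dataclass, field, fields, make_dataclass


@dataclass
class ToyModelArguments:  # stand-in for lmflow.args.ModelArguments
    model_name_or_path: str = field(default=None, metadata={"help": "model path"})
    trust_remote_code: bool = field(default=False, metadata={"help": "allow custom model code"})


def copy_with_prefix(original, field_prefix, class_prefix):
    # Re-declare every field under a prefixed name, e.g. model_name_or_path
    # -> reference_model_name_or_path, so two copies can live in one parser.
    new_fields = [
        (field_prefix + f.name, f.type, field(default=f.default, metadata=dict(f.metadata)))
        for f in fields(original)
    ]
    return make_dataclass(class_prefix + original.__name__, new_fields)


def strip_prefix(instance, prefix):
    # Reverse step: drop the prefix so the values can rebuild the original dataclass.
    return {
        f.name[len(prefix):]: getattr(instance, f.name)
        for f in fields(instance)
        if f.name.startswith(prefix)
    }


ReferenceToyArguments = copy_with_prefix(ToyModelArguments, "reference_", "Reference")
ref = ReferenceToyArguments(reference_model_name_or_path="path/to/ref-model")
print(ToyModelArguments(**strip_prefix(ref, "reference_")))
# ToyModelArguments(model_name_or_path='path/to/ref-model', trust_remote_code=False)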
3 changes: 2 additions & 1 deletion requirements.txt
@@ -22,4 +22,5 @@ pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
-vllm>=0.4.1
+vllm>=0.4.3
+ray>=2.22.0
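Both new pins track the new inference path: the vLLM bump and the ray dependency are presumably what back the multi-instance options added in src/lmflow/args.py (enable_distributed_inference, distributed_inference_num_instances), since vLLM relies on Ray for running distributed engine instances.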
1 change: 1 addition & 0 deletions scripts/run_iterative_dpo.sh
@@ -0,0 +1 @@
python examples/iterative_dpo_train.py configs/iterative_dpo.yaml
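This one-liner is the whole entry point: the YAML carries the settings for all three phases, so a single run performs generation (per the config, 8 vLLM instances with vllm_tensor_parallel_size: 1), reward-model scoring, and DPO training (dispatched to accelerate/DeepSpeed via accelerate_config_file) for each dataset listed in dataset_path_list.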
46 changes: 46 additions & 0 deletions src/lmflow/args.py
@@ -991,6 +991,10 @@ class InferencerArguments:
        default=1,
        metadata={"help": "batch size for inference"},
    )
    vllm_inference_batch_size: int = field(
        default=1,
        metadata={"help": "The batch size for VLLM inference."}
    )
    temperature: float = field(
        default=0.0,
        metadata={"help": "Temperature during inference."},
@@ -1072,6 +1076,18 @@ class InferencerArguments:
        default=False,
        metadata={"help": "Whether to decode the inference results."},
    )
    tensor_parallel_size: Optional[int] = field(
        default=1,
        metadata={"help": "The tp size for distributed (multi-instance) inference."}
    )
    enable_distributed_inference: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use multi-instance VLLM inference."}
    )
    distributed_inference_num_instances: Optional[int] = field(
        default=1,
        metadata={"help": "The number of instances for multi-instance VLLM inference."}
    )

    # vllm inference args
    use_vllm: bool = field(
@@ -1351,6 +1367,12 @@ class DPOv2AlignerArguments(FinetunerArguments):
    """
    The arguments for the DPOv2 training script.
    """
    # general args
    random_seed: Optional[int] = field(default=42, metadata={"help": "the random seed"})
    accelerate_config_file: Optional[str] = field(
        default=None,
        metadata={"help": "file path for accelerate config file, only used in memory safe dpov2 align."}
    )
    # pair sampling args
    margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"})
    sampling_paired_method: Optional[str] = field(default="max_random", metadata={"help": "the choose type"})
@@ -1370,6 +1392,29 @@ class IterativeAlignerArguments(InferencerArguments):
    Arguments for iterative aligners.
    """
    pass


@dataclass
class IterativeDPOAlignerArguments(IterativeAlignerArguments, DPOv2AlignerArguments):
    """
    Arguments for iterative DPO aligners.
    """
    output_dir: Optional[str] = field(
        default="./runs",
        metadata={"help": "Output path for the inferenced results"},
    )
    dataset_path_list: List[str] = field(
        default_factory=list,
        metadata={"help": "The list of dataset paths for iterative aligners."}
    )
    reward_model_inference_batch_size: int = field(
        default=1,
        metadata={"help": "The batch size for reward model inference."}
    )
    reward_model_inference_block_size: int = field(
        default=2048,
        metadata={"help": "The block size for reward model inference."}
    )


PIPELINE_ARGUMENT_MAPPING = {
@@ -1382,6 +1427,7 @@ class IterativeAlignerArguments(InferencerArguments):
    "dpo_aligner": DPOAlignerArguments,
    "rm_tuner": RewardModelTunerArguments,
    "dpov2_aligner": DPOv2AlignerArguments,
    "iterative_dpo_aligner": IterativeDPOAlignerArguments,
}


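Registering "iterative_dpo_aligner" in PIPELINE_ARGUMENT_MAPPING is presumably what lets AutoArguments.get_pipeline_args_class(pipeline_name) in examples/iterative_dpo_train.py resolve to IterativeDPOAlignerArguments. Because that class inherits from both IterativeAlignerArguments (hence InferencerArguments) and DPOv2AlignerArguments (hence FinetunerArguments), one YAML file can hold the generation, scoring, and training settings together.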