05_dpo.py
import os
import sys
import time

from datasets import load_dataset
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer
if len(sys.argv) != 5:
    print("Usage: python 05_dpo.py <base_model_name> <model_name> <dataset.json> <results_dir>")
    sys.exit(1)
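# Example invocation (model names and paths below are illustrative only, not from this repo):
# python 05_dpo.py meta-llama/Meta-Llama-3-8B ./sft-model ./sft-model/generated/preferences.jsonl results/dpo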
base_model_name = sys.argv[1]
model_name = sys.argv[2]
# use $MODEL/generated/preferences.jsonl for the dataset
dataset_file = sys.argv[3]
output_dir = sys.argv[4]
# load the training dataset
dataset = load_dataset("json", data_files={'train': dataset_file})
dataset = dataset['train'].shuffle(seed=42)
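# Each JSONL record is expected to carry non-empty "prompt", "chosen", and
# "rejected" string fields (enforced by the filter further down). An
# illustrative record:
# {"prompt": "Explain DPO.", "chosen": "a preferred answer...", "rejected": "a dispreferred answer..."}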
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,  # re-enabling this requires `import torch`
# )
device_map = "auto"
# note: this loads `model_name` (the model being DPO-tuned); `base_model_name`
# is only used for the tokenizer below
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # quantization_config=bnb_config,
    device_map=device_map,
    trust_remote_code=True,
)
base_model.config.use_cache = False
# when running with a ref model, this error occurs:
# ValueError: You passed both a ref_model and a peft_config.
# For training PEFT adapters with DPO there is no need to pass a reference model. Please pass `ref_model=None` in case you want to train PEFT adapters.
# ref_model = AutoModelForCausalLM.from_pretrained(
#     base_model_name,
#     quantization_config=bnb_config,
#     device_map=device_map,
#     trust_remote_code=True,
# )
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.chat_template = """<|begin_of_text|>
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{{- message['content'] + '<|eot_id|>' }}
{%- endfor %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
"""
tokenizer.eos_token = "<|eot_id|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
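# With this template, apply_chat_template on a system+user pair renders roughly
# (a sketch derived from the Jinja template above, newlines shown escaped):
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
# <|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n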
def get_prompt(example):
    prompt_sample = [
        {"role": "system", "content": "Respond to the following user query in a comprehensive and detailed way. But first write down your internal thoughts. This must include your draft response and its evaluation. After this, write your final response after \"<R>\"."},
        {"role": "user", "content": example['prompt']},
    ]
    prompt_for_model = tokenizer.apply_chat_template(prompt_sample, tokenize=False)
    example['prompt'] = prompt_for_model
    # append EOS so the chosen/rejected completions terminate cleanly
    example['chosen'] = example['chosen'] + tokenizer.eos_token
    example['rejected'] = example['rejected'] + tokenizer.eos_token
    return example
# filter out examples that do not have the required keys or are empty
dataset = dataset.filter(lambda x: 'prompt' in x and 'chosen' in x and 'rejected' in x and x['prompt'] and x['chosen'] and x['rejected'])
dataset = dataset.map(get_prompt)
# hyperparameters from https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
lora_dropout = 0.05
lora_alpha = 16
lora_r = 16
learning_rate = 5e-4
batch_size = 1
def create_peft_config(model):
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        lora_dropout=lora_dropout,
        lora_alpha=lora_alpha,
        r=lora_r,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    # model = prepare_model_for_kbit_training(model)  # only needed when loading in 4/8-bit
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config
model, lora_config = create_peft_config(base_model)
training_args = DPOConfig(
    output_dir=output_dir,
    logging_steps=1,
    per_device_train_batch_size=batch_size,
    learning_rate=learning_rate,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    warmup_steps=50,
    max_length=2048,
)
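# With these settings, the effective batch size per optimizer step is
# per_device_train_batch_size (1) * gradient_accumulation_steps (4) = 4 per device.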
trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    peft_config=lora_config,
    train_dataset=dataset,
)
# Time the training
start_time = time.time()
trainer.train()
end_time = time.time()
print(f"Training time: {end_time - start_time:.1f} seconds")

# save only the trained LoRA adapter weights under final_checkpoint
output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)
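# To reuse the trained adapter later, attach it to the base model with peft's
# PeftModel API (a minimal sketch; variable names refer to the script above):
# from peft import PeftModel
# base = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
# tuned = PeftModel.from_pretrained(base, output_dir)  # output_dir now ends in "final_checkpoint"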