Skip to content

Commit

Permalink
Llama3 dpo (#1610)
Browse files Browse the repository at this point in the history
* add dpo llama3

* fix dpo bos and eos

* bos token gets added automatically by the tokenizer

* explicit <|end_of_text|> not needed, as eot_id is sufficient

---------

Co-authored-by: Nero10578 <[email protected]>
  • Loading branch information
winglian and Nero10578 authored May 11, 2024
1 parent 50421c8 commit 2147cf6
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions src/axolotl/prompt_strategies/dpo/llama3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
DPO strategies for llama-3 chat template
"""


def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
return sample

return transform_fn


def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""

def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample

return transform_fn


def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample

return transform_fn


def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample

return transform_fn


def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample

return transform_fn


def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample

return transform_fn

1 comment on commit 2147cf6

@Nero10578
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see so the bos token gets added automatically by the tokenizer and the eos is not needed.

Please sign in to comment.