Add KTO support (#1640)
* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <[email protected]>
benredmond and winglian authored May 20, 2024
1 parent ba45531 commit 22ae21a
Showing 11 changed files with 435 additions and 18 deletions.
31 changes: 29 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
@@ -826,6 +826,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
    tag_names = ["axolotl", "orpo"]


class AxolotlKTOTrainer(KTOTrainer):
    """
    Extend the base KTOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "kto"]


class TrainerBuilderBase(abc.ABC):
    """
    Base class for trainer builder
@@ -1532,6 +1540,22 @@ def build_training_arguments(self, total_num_steps):
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        if self.cfg.rl == "kto":
            training_args_cls = KTOConfig

            training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
            training_args_kwargs["desirable_weight"] = (
                self.cfg.kto_desirable_weight or 1.0
            )
            training_args_kwargs["undesirable_weight"] = (
                self.cfg.kto_undesirable_weight or 1.0
            )

            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
            training_args_kwargs["max_length"] = self.cfg.sequence_len
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        training_args = training_args_cls(
            per_device_train_batch_size=self.cfg.micro_batch_size,
            max_steps=self.cfg.max_steps or total_num_steps,
@@ -1567,7 +1591,7 @@ def build(self, total_num_steps):
            ] = self.cfg.precompute_ref_log_probs
        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
            trainer_cls = AxolotlDPOTrainer
            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
            dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
            trainer_cls_args = [self.model, self.model_ref]

            # these aren't used for the ORPO trainer
@@ -1580,6 +1604,9 @@ elif self.cfg.rl == "orpo":
        elif self.cfg.rl == "orpo":
            trainer_cls = AxolotlORPOTrainer
            trainer_cls_args = [self.model]
        elif self.cfg.rl == "kto":
            trainer_cls = AxolotlKTOTrainer
            trainer_cls_args = [self.model]
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
        dpo_trainer = trainer_cls(
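Taken together, the new branch above is driven by a handful of axolotl config values. The sketch below lists those keys as a Python mapping purely for illustration (real axolotl configs are YAML, and the values shown are placeholders, not recommendations); the comments note the KTOConfig fields they feed per the diff above.

# Hypothetical config fragment, mirrored as a Python dict for illustration.
kto_example_cfg = {
    "rl": "kto",                    # selects KTOConfig and AxolotlKTOTrainer
    "rl_beta": 0.1,                 # -> KTOConfig beta (0.1 is the fallback above)
    "kto_desirable_weight": 1.0,    # -> KTOConfig desirable_weight (fallback 1.0)
    "kto_undesirable_weight": 1.0,  # -> KTOConfig undesirable_weight (fallback 1.0)
    "sequence_len": 2048,           # -> KTOConfig max_length
    "max_prompt_len": 512,          # -> KTOConfig max_prompt_length (only set when present)
    "dataset_processes": 4,         # -> KTOConfig dataset_num_proc
    "micro_batch_size": 2,          # -> per_device_train_batch_size
}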
9 changes: 9 additions & 0 deletions src/axolotl/prompt_strategies/kto/__init__.py
@@ -0,0 +1,9 @@
"""
module for KTO style dataset transform strategies
"""

from functools import partial

from ..base import load as load_base

load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
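
The functools.partial call pins module_base so that KTO strategy names resolve relative to axolotl.prompt_strategies.kto. The toy snippet below illustrates only that partial pattern; toy_load is a hypothetical stand-in, not the real load from ..base, whose signature this diff does not show.

from functools import partial

def toy_load(strategy, module_base=None):
    # Hypothetical stand-in: resolve a dotted strategy name under a module base.
    return f"{module_base}.{strategy}"

load = partial(toy_load, module_base="axolotl.prompt_strategies.kto")
print(load("chatml.argilla"))  # -> axolotl.prompt_strategies.kto.chatml.argilla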
105 changes: 105 additions & 0 deletions src/axolotl/prompt_strategies/kto/chatml.py
@@ -0,0 +1,105 @@
"""
KTO strategies for chatml
"""
# pylint: disable=duplicate-code


def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn


def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations
    """

    def transform_fn(sample):
        sample[
            "prompt"
        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
        return sample

    return transform_fn


def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto
    """

    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn


def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn


def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    """

    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn
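
To make the chatml strategy concrete, the sketch below runs the argilla transform on a hand-made sample; the sample values are invented for illustration, and the KTO label field simply passes through untouched.

from axolotl.prompt_strategies.kto.chatml import argilla  # module added in this commit

transform_fn = argilla(cfg=None)  # cfg is unused by the transform
sample = {
    "system": "You are a helpful assistant.",  # illustrative values
    "instruction": "Name a prime number.",
    "completion": "7 is prime.",
    "label": True,  # KTO desirability label, left unchanged by the transform
}
out = transform_fn(sample)
# out["prompt"] == "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
#                  "<|im_start|>user\nName a prime number.<|im_end|>\n<|im_start|>assistant\n"
# out["completion"] == "7 is prime.<|im_end|>"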
105 changes: 105 additions & 0 deletions src/axolotl/prompt_strategies/kto/llama3.py
@@ -0,0 +1,105 @@
"""
KTO strategies for llama-3 chat template
"""
# pylint: disable=duplicate-code


def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn


def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations
    """

    def transform_fn(sample):
        sample[
            "prompt"
        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
        return sample

    return transform_fn


def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto
    """

    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn


def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn


def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    """

    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        else:
            sample[
                "prompt"
            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn
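
The llama3 strategies mirror the chatml ones but emit llama-3 chat-template tokens. A small sketch of prompt_pairs on an invented sample with no system message:

from axolotl.prompt_strategies.kto.llama3 import prompt_pairs  # module added in this commit

transform_fn = prompt_pairs(cfg=None)  # cfg is unused by the transform
row = {"prompt": "What is 2 + 2?", "completion": "4", "label": False}  # illustrative values
out = transform_fn(row)
# out["prompt"] == "<|start_header_id|>user<|end_header_id|>\n\nWhat is 2 + 2?<|eot_id|>"
#                  "<|start_header_id|>assistant<|end_header_id|>\n\n"
# out["completion"] == "4<|eot_id|>"
# out["label"] stays False, i.e. the completion is marked undesirable for KTO.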
39 changes: 39 additions & 0 deletions src/axolotl/prompt_strategies/kto/user_defined.py
@@ -0,0 +1,39 @@
"""
User-defined KTO strategies
"""
# pylint: disable=duplicate-code


def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_completion = ds_cfg.get("field_completion", "completion")
    field_label = ds_cfg.get("field_label", "label")
    prompt_format = ds_cfg.get("prompt_format")
    if not prompt_format:
        prompt_format = "{" + field_prompt + "}"
    completion_format = ds_cfg.get("completion_format")
    if not completion_format:
        completion_format = "{" + field_completion + "}"

    def transform_fn(sample):
        if (
            "{" + field_system + "}" in prompt_format
            and field_system in sample
            and sample[field_system]
        ):
            sample["prompt"] = prompt_format.format(
                system=sample[field_system], prompt=sample[field_prompt]
            )
        else:
            sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
        sample["completion"] = completion_format.format(
            completion=sample[field_completion]
        )
        sample["label"] = sample[field_label]
        return sample

    return transform_fn
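
For the user-defined strategy, the dataset's type entry is a mapping of field names and optional format strings. A hedged sketch of how default wires it up, written as Python for illustration (real axolotl configs are YAML; the field names shown are the defaults and the row values are invented):

from axolotl.prompt_strategies.kto.user_defined import default  # module added in this commit

cfg = {
    "datasets": [
        {
            "type": {
                "field_prompt": "prompt",
                "field_completion": "completion",
                "field_label": "label",
                # prompt_format / completion_format fall back to "{prompt}" / "{completion}"
            }
        }
    ]
}
transform_fn = default(cfg)
row = {"prompt": "Say hi.", "completion": "Hi!", "label": True}
row = transform_fn(row)
# row["prompt"] == "Say hi.", row["completion"] == "Hi!", row["label"] is True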
