Skip to content

Commit

Permalink
add support for simpo via cpo trainer (#1772)
Browse files Browse the repository at this point in the history
* add support for simpo via cpo trainer

* add cpo_alpha / sft_weight from the paper

* make sure to use the right builder for simpo
  • Loading branch information
winglian committed Jul 24, 2024
1 parent fe250ad commit 6a9cfec
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 5 deletions.
47 changes: 43 additions & 4 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl import (
CPOConfig,
CPOTrainer,
DPOConfig,
DPOTrainer,
KTOConfig,
KTOTrainer,
ORPOConfig,
ORPOTrainer,
)
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
Expand Down Expand Up @@ -265,6 +274,18 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""


@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
"""
CPO config for CPO training
"""

simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)


class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
Expand Down Expand Up @@ -985,6 +1006,14 @@ class AxolotlKTOTrainer(KTOTrainer):
tag_names = ["axolotl", "kto"]


class AxolotlCPOTrainer(CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "cpo"]


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
Expand Down Expand Up @@ -1707,6 +1736,8 @@ def build_training_arguments(self, total_num_steps):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"

if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
Expand All @@ -1715,17 +1746,23 @@ def build_training_arguments(self, total_num_steps):
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig

training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
Expand Down Expand Up @@ -1771,7 +1808,6 @@ def build(self, total_num_steps):
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]

# these aren't used for the ORPO trainer
Expand All @@ -1785,6 +1821,9 @@ def build(self, total_num_steps):
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name


class ChatTemplate(str, Enum):
Expand Down Expand Up @@ -644,6 +645,8 @@ class Config:

orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None
simpo_gamma: Optional[float] = None
cpo_alpha: Optional[float] = None

kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def prepare_optim_env(cfg):


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
Expand Down

0 comments on commit 6a9cfec

Please sign in to comment.