From 6a9cfec2227935393bcfc0fbe324ef6232c520ec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 21:22:16 -0400 Subject: [PATCH] add support for simpo via cpo trainer (#1772) * add support for simpo via cpo trainer * add cpo_alpha / sft_weight from the paper * make sure to use the right builder for simpo --- src/axolotl/core/trainer_builder.py | 47 +++++++++++++++++-- .../config/models/input/v0_4_1/__init__.py | 3 ++ src/axolotl/utils/trainer.py | 2 +- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 616b1d4eb..9a12c5a06 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,16 @@ ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer +from trl import ( + CPOConfig, + CPOTrainer, + DPOConfig, + DPOTrainer, + KTOConfig, + KTOTrainer, + ORPOConfig, + ORPOTrainer, +) from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -265,6 +274,18 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): """ +@dataclass +class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): + """ + CPO config for CPO training + """ + + simpo_gamma: Optional[float] = field( + default=None, + metadata={"help": "simpo gamma parameter"}, + ) + + class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers @@ -985,6 +1006,14 @@ class AxolotlKTOTrainer(KTOTrainer): tag_names = ["axolotl", "kto"] +class AxolotlCPOTrainer(CPOTrainer): + """ + Extend the base CPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "cpo"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1707,6 +1736,8 @@ def build_training_arguments(self, total_num_steps): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" + if self.cfg.rl_beta: + training_args_kwargs["beta"] = self.cfg.rl_beta if self.cfg.orpo_alpha: # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha @@ -1715,9 +1746,16 @@ def build_training_arguments(self, total_num_steps): training_args_cls = AxolotlDPOConfig if self.cfg.rpo_alpha is not None: training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha + + if self.cfg.rl == "simpo": + training_args_cls = AxolotlCPOConfig + training_args_kwargs["loss_type"] = "simpo" + training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma + if self.cfg.cpo_alpha is not None: + training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha + if self.cfg.rl == "orpo": training_args_cls = AxolotlORPOConfig - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len @@ -1725,7 +1763,6 @@ def build_training_arguments(self, total_num_steps): if self.cfg.rl == "kto": training_args_cls = AxolotlKTOConfig - training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1 training_args_kwargs["desirable_weight"] = ( self.cfg.kto_desirable_weight or 1.0 ) @@ -1771,7 +1808,6 @@ def build(self, total_num_steps): ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo"]: trainer_cls = AxolotlDPOTrainer - dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 trainer_cls_args = [self.model, self.model_ref] # these aren't used for the ORPO trainer @@ -1785,6 +1821,9 @@ def build(self, total_num_steps): elif self.cfg.rl in ["kto"]: trainer_cls = AxolotlKTOTrainer trainer_cls_args = [self.model] + elif self.cfg.rl in ["simpo"]: + trainer_cls = AxolotlCPOTrainer + trainer_cls_args = [self.model] else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") dpo_trainer = trainer_cls( diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 7f30283af..7397c7c73 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -172,6 +172,7 @@ class RLType(str, Enum): ipo = "ipo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name + simpo = "simpo" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -644,6 +645,8 @@ class Config: orpo_alpha: Optional[float] = None rpo_alpha: Optional[float] = None + simpo_gamma: Optional[float] = None + cpo_alpha: Optional[float] = None kto_desirable_weight: Optional[float] = None kto_undesirable_weight: Optional[float] = None diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 65c2d424e..c5a71e689 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -425,7 +425,7 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "orpo", "kto"]: + if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2]