From 7fea5822f058ea8be48fbcf266e122582551f539 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 08:56:15 -0400
Subject: [PATCH 1/8] add support for SPPO

---
 docs/config.qmd                             |  2 +-
 src/axolotl/core/trainer_builder.py         | 70 ++++++++++++++++++-
 .../config/models/input/v0_4_1/__init__.py  |  1 +
 src/axolotl/utils/models.py                 |  6 +-
 src/axolotl/utils/trainer.py                |  2 +-
 5 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 570a173f9..800293535 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
 rl:
 
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 742a88633..cc53fb79b 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,7 +14,7 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Tuple, Type, Union
 
 import torch
 import transformers
@@ -817,6 +817,70 @@ def tokenize_row(
                 res[key] = res[key][1:]
         return res
 
+    def dpo_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Compute the DPO loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
+            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
+            The losses tensor contains the DPO loss for each example in the batch.
+            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+        """
+        if self.loss_type not in ["sigmoid", "hinge", "ipo", "kto_pair"]:
+            return super().dpo_loss(
+                policy_chosen_logps,
+                policy_rejected_logps,
+                reference_chosen_logps,
+                reference_rejected_logps,
+            )
+
+        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
+        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
+        # calculates a conservative DPO loss.
+        if self.loss_type == "sppo":
+            # Calculate a and b
+            a = self.beta * (  # pylint: disable=invalid-name
+                policy_chosen_logps - reference_chosen_logps
+            )
+            b = self.beta * (  # pylint: disable=invalid-name
+                policy_rejected_logps - reference_rejected_logps
+            )
+
+            # Compute the SPPO loss
+            losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
+        else:
+            raise ValueError(
+                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
+            )
+
+        chosen_rewards = (
+            self.beta
+            * (
+                policy_chosen_logps.to(self.accelerator.device)
+                - reference_chosen_logps.to(self.accelerator.device)
+            ).detach()
+        )
+        rejected_rewards = (
+            self.beta
+            * (
+                policy_rejected_logps.to(self.accelerator.device)
+                - reference_rejected_logps.to(self.accelerator.device)
+            ).detach()
+        )
+
+        return losses, chosen_rewards, rejected_rewards
+
 
 class AxolotlORPOTrainer(ORPOTrainer):
     """
@@ -1552,6 +1616,8 @@ def build(self, total_num_steps):
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
+        elif self.cfg.rl == "sppo":
+            dpo_trainer_kwargs["loss_type"] = "sppo"
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
@@ -1560,7 +1626,7 @@
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
             trainer_cls = AxolotlDPOTrainer
             dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 419deee58..53d60e76c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -133,6 +133,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
+    sppo = "sppo"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8537b7e75..f0ae55a73 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -789,7 +789,11 @@ def load_model(
     if not reference_model or cfg.lora_model_dir:
         # if we're not loading the reference model, then we're loading the model for training
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
+            and not cfg.merge_lora
+        ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
         else:
             model, lora_config = load_adapter(model, cfg, cfg.adapter)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2e3728cc8..fe1f6e0bd 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]

From df645906eb55a60be9c6b1cbe9cf42caa3f80299 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 09:31:43 -0400
Subject: [PATCH 2/8] invert check

---
 src/axolotl/core/trainer_builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index cc53fb79b..b1554f0c0 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -837,7 +837,7 @@ def dpo_loss(
             The losses tensor contains the DPO loss for each example in the batch.
             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
         """
-        if self.loss_type not in ["sigmoid", "hinge", "ipo", "kto_pair"]:
+        if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
             return super().dpo_loss(
                 policy_chosen_logps,
                 policy_rejected_logps,

From b301068098f1ae476e254562f68e5102396b4ad5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 11:01:20 -0400
Subject: [PATCH 3/8] remove override

---
 src/axolotl/core/trainer_builder.py | 66 +----------------------------
 1 file changed, 1 insertion(+), 65 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b1554f0c0..55eecf839 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,7 +14,7 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
@@ -817,70 +817,6 @@ def tokenize_row(
                 res[key] = res[key][1:]
         return res
 
-    def dpo_loss(
-        self,
-        policy_chosen_logps: torch.FloatTensor,
-        policy_rejected_logps: torch.FloatTensor,
-        reference_chosen_logps: torch.FloatTensor,
-        reference_rejected_logps: torch.FloatTensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-        """Compute the DPO loss for a batch of policy and reference model log probabilities.
-
-        Args:
-            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
-            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
-
-        Returns:
-            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
-            The losses tensor contains the DPO loss for each example in the batch.
-            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
-        """
-        if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
-            return super().dpo_loss(
-                policy_chosen_logps,
-                policy_rejected_logps,
-                reference_chosen_logps,
-                reference_rejected_logps,
-            )
-
-        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
-        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
-        # calculates a conservative DPO loss.
-        if self.loss_type == "sppo":
-            # Calculate a and b
-            a = self.beta * (  # pylint: disable=invalid-name
-                policy_chosen_logps - reference_chosen_logps
-            )
-            b = self.beta * (  # pylint: disable=invalid-name
-                policy_rejected_logps - reference_rejected_logps
-            )
-
-            # Compute the SPPO loss
-            losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
-        else:
-            raise ValueError(
-                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
-            )
-
-        chosen_rewards = (
-            self.beta
-            * (
-                policy_chosen_logps.to(self.accelerator.device)
-                - reference_chosen_logps.to(self.accelerator.device)
-            ).detach()
-        )
-        rejected_rewards = (
-            self.beta
-            * (
-                policy_rejected_logps.to(self.accelerator.device)
-                - reference_rejected_logps.to(self.accelerator.device)
-            ).detach()
-        )
-
-        return losses, chosen_rewards, rejected_rewards
-
 
 class AxolotlORPOTrainer(ORPOTrainer):
     """

From 60fecac367ea2b13721fa6ad567fc2ffa8446d38 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 12:12:53 -0400
Subject: [PATCH 4/8] bump trl

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 26525be15..19eb13d19 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl==0.8.5
+trl @ git+https://github.com/huggingface/trl.git@adf17a5a269a0bc59162597f81e3d489a8c144e5
 zstandard==0.22.0
 fastcore

From f58fcd09ec84a706947ad77d5f49ee72c4c80460 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 13:44:26 -0400
Subject: [PATCH 5/8] use DPOConfig

---
 requirements.txt                    | 2 +-
 src/axolotl/core/trainer_builder.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 19eb13d19..4ec2aec89 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@adf17a5a269a0bc59162597f81e3d489a8c144e5
+trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 55eecf839..576c303b7 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -1526,6 +1526,9 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
+            training_args_cls = DPOConfig
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,

From 0554105baab6fdda13d225f456ec5e5b9e3561b5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 2 May 2024 23:02:03 -0400
Subject: [PATCH 6/8] add mistral instruct strategy and fix dpo_loss input

---
 src/axolotl/prompt_strategies/dpo/mistral.py | 30 +++++++++++++++++++
 .../config/models/input/v0_4_1/__init__.py   |  1 +
 2 files changed, 31 insertions(+)
 create mode 100644 src/axolotl/prompt_strategies/dpo/mistral.py

diff --git a/src/axolotl/prompt_strategies/dpo/mistral.py b/src/axolotl/prompt_strategies/dpo/mistral.py
new file mode 100644
index 000000000..49e948fcd
--- /dev/null
+++ b/src/axolotl/prompt_strategies/dpo/mistral.py
@@ -0,0 +1,30 @@
+"""
+DPO strategies for mistral instruct
+"""
+
+
+def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
+        sample["chosen"] = f"{sample['chosen']}"
+        sample["rejected"] = f"{sample['rejected']}"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/dpo-mix-7k conversations
+    """
+
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}"
+        return sample
+
+    return transform_fn
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 53d60e76c..7c00500b9 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -575,6 +575,7 @@ class Config:
     neftune_noise_alpha: Optional[float] = None
 
     orpo_alpha: Optional[float] = None
+    dpo_beta: Optional[float] = None
 
     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]

From 027f7d54f0b4195eaf6f0eb0e126e4158038798b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 3 May 2024 08:41:59 -0400
Subject: [PATCH 7/8] update for sppo

---
 docs/config.qmd                                          | 2 +-
 requirements.txt                                         | 2 +-
 src/axolotl/core/trainer_builder.py                      | 8 ++++----
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 +-
 src/axolotl/utils/models.py                              | 2 +-
 src/axolotl/utils/trainer.py                             | 2 +-
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 800293535..7cc4a712f 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
 rl:
 
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
diff --git a/requirements.txt b/requirements.txt
index 4ec2aec89..39c1623c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
+trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 576c303b7..0974f6f61 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1526,7 +1526,7 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
+        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
             training_args_cls = DPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
@@ -1555,8 +1555,8 @@ def build(self, total_num_steps):
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
-        elif self.cfg.rl == "sppo":
-            dpo_trainer_kwargs["loss_type"] = "sppo"
+        elif self.cfg.rl == "sppo_hard":
+            dpo_trainer_kwargs["loss_type"] = "sppo_hard"
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
@@ -1565,7 +1565,7 @@
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
             trainer_cls = AxolotlDPOTrainer
             dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 7c00500b9..bb4476563 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -133,7 +133,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
-    sppo = "sppo"  # pylint: disable=invalid-name
+    sppo = "sppo_hard"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f0ae55a73..fc8a67acf 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -791,7 +791,7 @@ def load_model(
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
         if (
             cfg.adapter
-            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
             and not cfg.merge_lora
         ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index fe1f6e0bd..1a0e55010 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]

From 6a9ac4ad276e5b0e2d7c2c47299c58b7e308b78d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 6 May 2024 16:58:58 -0400
Subject: [PATCH 8/8] consistency w sppo -> sppo_hard

---
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index bb4476563..78a36232c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -133,7 +133,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
-    sppo = "sppo_hard"  # pylint: disable=invalid-name
+    sppo_hard = "sppo_hard"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):