diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0fdd126f5..9e79f8299 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,7 @@ ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOTrainer, ORPOConfig, ORPOTrainer +from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -826,6 +826,14 @@ class AxolotlORPOTrainer(ORPOTrainer): tag_names = ["axolotl", "orpo"] +class AxolotlKTOTrainer(KTOTrainer): + """ + Extend the base KTOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "kto"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1532,6 +1540,22 @@ def build_training_arguments(self, total_num_steps): if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + if self.cfg.rl == "kto": + training_args_cls = KTOConfig + + training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1 + training_args_kwargs["desirable_weight"] = ( + self.cfg.kto_desirable_weight or 1.0 + ) + training_args_kwargs["undesirable_weight"] = ( + self.cfg.kto_undesirable_weight or 1.0 + ) + + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + training_args_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.max_prompt_len: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, @@ -1567,7 +1591,7 @@ def build(self, total_num_steps): ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: trainer_cls = AxolotlDPOTrainer - dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 + dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 
trainer_cls_args = [self.model, self.model_ref] # these aren't used for the ORPO trainer @@ -1580,6 +1604,9 @@ def build(self, total_num_steps): elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] + elif self.cfg.rl == "kto": + trainer_cls = AxolotlKTOTrainer + trainer_cls_args = [self.model] else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") dpo_trainer = trainer_cls( diff --git a/src/axolotl/prompt_strategies/kto/__init__.py b/src/axolotl/prompt_strategies/kto/__init__.py new file mode 100644 index 000000000..9af6300eb --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/__init__.py @@ -0,0 +1,9 @@ +""" +module for KTO style dataset transform strategies +""" + +from functools import partial + +from ..base import load as load_base + +load = partial(load_base, module_base="axolotl.prompt_strategies.kto") diff --git a/src/axolotl/prompt_strategies/kto/chatml.py b/src/axolotl/prompt_strategies/kto/chatml.py new file mode 100644 index 000000000..46c305f83 --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/chatml.py @@ -0,0 +1,105 @@ +""" +KTO strategies for chatml +""" +# pylint: disable=duplicate-code + + +def argilla( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["completion"] = f"{sample['completion']}<|im_end|>" + return sample + + return transform_fn + + +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/kto-mix-15k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = 
f"<|im_start|>user\n{sample['completion'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>" + return sample + + return transform_fn + + +def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca KTO + ex: argilla/distilabel-intel-orca-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["completion"] = f"{sample['completion']}<|im_end|>" + return sample + + return transform_fn + + +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["completion"] = f"{sample['completion']}<|im_end|>" + return sample + + return transform_fn + + +def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["completion"] = f"{sample['completion']}<|im_end|>" + return sample + + 
return transform_fn diff --git a/src/axolotl/prompt_strategies/kto/llama3.py b/src/axolotl/prompt_strategies/kto/llama3.py new file mode 100644 index 000000000..795d343fe --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/llama3.py @@ -0,0 +1,105 @@ +""" +KTO strategies for llama-3 chat template +""" +# pylint: disable=duplicate-code + + +def argilla( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/kto-mix-15k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>" + return sample + + return transform_fn + + +def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca KTO + ex: argilla/distilabel-intel-orca-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = 
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/kto/user_defined.py b/src/axolotl/prompt_strategies/kto/user_defined.py new file mode 100644 index 000000000..7e5458bb7 --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/user_defined.py @@ -0,0 +1,39 @@ +""" +User-defined KTO strategies +""" +# pylint: 
disable=duplicate-code + + +def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument + ds_cfg = cfg["datasets"][dataset_idx]["type"] + if not isinstance(ds_cfg, dict): + raise ValueError( + f"User-defined dataset type must be a dictionary. Got: {ds_cfg}" + ) + field_prompt = ds_cfg.get("field_prompt", "prompt") + field_system = ds_cfg.get("field_system", "system") + field_completion = ds_cfg.get("field_completion", "completion") + field_label = ds_cfg.get("field_label", "label") + prompt_format = ds_cfg.get("prompt_format") + if not prompt_format: + prompt_format = "{" + field_prompt + "}" + completion_format = ds_cfg.get("completion_format") + if not completion_format: + completion_format = "{" + field_completion + "}" + + def transform_fn(sample): + if ( + "{" + field_system + "}" in prompt_format + and field_system in sample + and sample[field_system] + ): + sample["prompt"] = prompt_format.format( + system=sample[field_system], prompt=sample[field_prompt] + ) + else: + sample["prompt"] = prompt_format.format(prompt=sample[field_prompt]) + sample["completion"] = completion_format.format( + completion=sample[field_completion] + ) + sample["label"] = sample[field_label] + return sample + + return transform_fn diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 585fbd734..82db40e5a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel): max_packed_sequence_len: Optional[int] = None rope_scaling: Optional[Any] = None noisy_embedding_alpha: Optional[float] = None + dpo_beta: Optional[float] = None @field_validator("max_packed_sequence_len") @classmethod @@ -48,6 +49,13 @@ def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") return noisy_embedding_alpha + 
@field_validator("dpo_beta") + @classmethod + def validate_dpo_beta(cls, dpo_beta): + if dpo_beta is not None: + LOG.warning("dpo_beta is deprecated, use rl_beta instead") + return dpo_beta + class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" @@ -126,6 +134,26 @@ class DPODataset(BaseModel): data_files: Optional[List[str]] = None +class UserDefinedKTOType(BaseModel): + """User defined typing for KTO""" + + field_system: Optional[str] = None + field_prompt: Optional[str] = None + field_completion: Optional[str] = None + field_label: Optional[bool] = None + prompt_format: Optional[str] = None + completion_format: Optional[str] = None + + +class KTODataset(BaseModel): + """KTO configuration subset""" + + path: Optional[str] = None + split: Optional[str] = None + type: Optional[Union[UserDefinedKTOType, str]] = None + data_files: Optional[List[str]] = None + + class RLType(str, Enum): """RL trainer type configuration subset""" @@ -133,6 +161,7 @@ class RLType(str, Enum): ipo = "ipo" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name + kto = "kto" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -450,8 +479,8 @@ class Config: rl: Optional[RLType] = None - datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore - test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore + datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore + test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore shuffle_merged_datasets: Optional[bool] = True dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None @@ -585,6 +614,10 @@ class Config: orpo_alpha: Optional[float] = None + kto_desirable_weight: Optional[float] = None + kto_undesirable_weight: 
Optional[float] = None + rl_beta: Optional[float] = None + max_memory: Optional[ Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] ] = None @@ -884,6 +917,13 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha): raise ValueError("neftune_noise_alpha must be > 0.0") return neftune_noise_alpha + @model_validator(mode="after") + def check(self): + if self.dpo_beta and not self.rl_beta: + self.rl_beta = self.dpo_beta + del self.dpo_beta + return self + @model_validator(mode="before") @classmethod def check_frozen(cls, data): diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index ff5ca87dd..7416ca28b 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -10,6 +10,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.prompt_strategies.dpo import load as load_dpo +from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.utils.data.utils import md5 from axolotl.utils.dict import DictDefault @@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset): dataset.save_to_disk(str(prepared_ds_path)) +def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): + sig = inspect.signature(ds_transform_fn) + if "tokenizer" in sig.parameters: + if not tokenizer: + tokenizer = load_tokenizer(cfg) + ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) + + data_set = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + if isinstance(data_set, DatasetDict): + data_set = data_set["train"] + return data_set + + def load_prepare_dpo_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] @@ -76,6 +93,7 @@ def load_split(dataset_cfgs, _cfg): split_datasets.insert(i, ds) tokenizer = None + for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: @@ -83,21 +101,19 @@ def load_split(dataset_cfgs, _cfg): _type = "user_defined.default" if 
_cfg.rl == "orpo": ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) + elif _cfg.rl == "kto": + ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) else: ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) - sig = inspect.signature(ds_transform_fn) - if "tokenizer" in sig.parameters: - if not tokenizer: - tokenizer = load_tokenizer(_cfg) - ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) - - data_set = data_set.map( - ds_transform_fn, - desc="Mapping RL Dataset", + + split_datasets[i] = map_dataset( + cfg, data_set, ds_transform_fn, tokenizer + ) + elif _cfg.rl == "kto": + ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) + split_datasets[i] = map_dataset( + cfg, data_set, ds_transform_fn, tokenizer ) - if isinstance(data_set, DatasetDict): - data_set = data_set["train"] - split_datasets[i] = data_set else: # If no `type` is provided, assume the dataset is already in the expected format with # "prompt", "chosen" and "rejected" already preprocessed diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4f8388a0a..a8df4bbad 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -803,7 +803,11 @@ def load_model( if not reference_model or cfg.lora_model_dir: # if we're not loading the reference model, then we're loading the model for training # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: + if ( + cfg.adapter + and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"] + and not cfg.merge_lora + ): _, lora_config = load_lora(model, cfg, inference=False, config_only=True) else: model, lora_config = load_adapter(model, cfg, cfg.adapter) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index fdf86e567..83977ef06 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -428,7 +428,7 @@ def prepare_optim_env(cfg): def 
setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: + if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 9596b1873..ddd63d827 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -205,3 +205,66 @@ def test_orpo_lora(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + + @with_temp_dir + def test_kto_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "kto", + "rl_beta": 0.5, + "kto_desirable_weight": 1.0, + "kto_undesirable_weight": 1.0, + "remove_unused_columns": False, + "datasets": [ + # { + # "path": "argilla/kto-mix-15k", + # "type": "chatml.argilla_chat", + # "split": "train", + # }, + { + "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto", + "type": "chatml.ultra", + "split": "train", + }, + # { + # "path": "argilla/kto-mix-15k", + # "type": "llama3.argilla_chat", + # "split": "train", + # }, + { + "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto", + "type": "llama3.ultra", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": 
{"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() diff --git a/tests/test_validation.py b/tests/test_validation.py index 27824f288..35d0e265e 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1117,6 +1117,15 @@ def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): validate_config(cfg) assert len(self._caplog.records) == 0 + def test_dpo_beta_deprecation(self, minimal_cfg): + cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg + + with self._caplog.at_level(logging.WARNING): + new_cfg = validate_config(cfg) + assert new_cfg["rl_beta"] == 0.2 + assert new_cfg["dpo_beta"] is None + assert len(self._caplog.records) == 1 + class TestValidationCheckModelConfig(BaseValidation): """