ORPO Trainer replacement (#1551)
* WIP use trl ORPOTrainer

* fixes to make orpo work with trl

* fix the chat template loading

* make sure to handle the special tokens and add_generation_prompt for the assistant turn too
winglian authored Apr 19, 2024
1 parent 0e8f340 commit 7d1d22f
Showing 10 changed files with 151 additions and 26 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs

trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
trl==0.8.5
zstandard==0.22.0
fastcore
2 changes: 1 addition & 1 deletion src/axolotl/cli/preprocess.py
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

if parsed_cfg.rl and parsed_cfg.rl != "orpo":
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
2 changes: 1 addition & 1 deletion src/axolotl/cli/train.py
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
register_chatml_template()

if cfg.rl and cfg.rl != "orpo":
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
47 changes: 36 additions & 11 deletions src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ def tokenize_row(
return res


class AxolotlORPOTrainer(ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "orpo"]


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1404,7 +1412,7 @@ def build_collator(
)


class HFDPOTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
"""
@@ -1497,7 +1505,15 @@ def build_training_arguments(self, total_num_steps):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"

training_args = TrainingArguments(
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

training_args_cls = TrainingArguments
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig

training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1530,17 +1546,26 @@ def build(self, total_num_steps):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer = AxolotlDPOTrainer(
self.model,
self.model_ref,
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]

# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
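
A minimal standalone sketch of the orpo_alpha-to-beta pass-through above, assuming trl==0.8.5; the helper name, the output_dir default, and the fallback values are illustrative and not part of the commit:

from trl import ORPOConfig

def build_orpo_training_args(cfg, output_dir="./outputs"):
    # ORPOConfig subclasses transformers.TrainingArguments, so the usual
    # training arguments sit alongside the ORPO-specific fields. trl reuses
    # the name `beta` for ORPO's alpha weight, hence the direct pass-through
    # from cfg.orpo_alpha above.
    return ORPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        max_steps=cfg.max_steps or 100,
        beta=cfg.orpo_alpha or 0.1,
        max_length=cfg.sequence_len,
        max_prompt_length=cfg.sequence_len,
    )
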
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/orpo/__init__.py
@@ -6,4 +6,4 @@

from ..base import load as load_base

load = partial(load_base, module="axolotl.prompt_strategies.orpo")
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
84 changes: 84 additions & 0 deletions src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -78,6 +78,57 @@ def get_rejected_conversation_thread(self, prompt) -> MessageList:
)
return MessageList(messages=messages)

def get_prompt(self, prompt) -> MessageList:
"""Map the data to extract everything up to the last turn"""
total_msg_len = len(prompt["chosen"])
total_msg_turns, remainder = divmod(total_msg_len, 2)
assert remainder == 0, "invalid number of turns"

messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
for i in range(total_msg_turns):
if "prompt" in prompt:
messages.append(
Message(role="user", content=prompt["prompt"], label=False)
)
else:
messages.append(
Message(
role="user",
content=prompt["chosen"][i * 2]["content"],
label=False,
)
)
if i < total_msg_turns - 1:
messages.append(
Message(
role="assistant",
content=prompt["chosen"][i * 2 + 1]["content"],
label=False,
)
)

return MessageList(messages=messages)

def get_chosen(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["chosen"][-1]["content"], label=True
)
)
return res

def get_rejected(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["rejected"][-1]["content"], label=True
)
)
return res


class ORPOTokenizingStrategy(PromptTokenizingStrategy):
"""
@@ -186,3 +237,36 @@ def build_prompt(
chat_template=self.chat_template,
tokenize=False,
), True


def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()

chat_template_str = chat_templates(cfg.chat_template)

def transform_fn(sample, tokenizer=None):
res = {}

res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]

return res

return transform_fn
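
A hedged usage sketch of the argilla() transform above; the tokenizer checkpoint, cfg contents, and sample values are placeholders, and in the real pipeline the tokenizer is bound by the RL data loader rather than passed by hand:

from transformers import AutoTokenizer

from axolotl.prompt_strategies.orpo.chat_template import argilla
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"chat_template": "chatml"})
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder checkpoint

sample = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

transform_fn = argilla(cfg)
row = transform_fn(sample, tokenizer=tokenizer)
# row["prompt"]   renders the user turn and ends with the assistant generation prompt
# row["chosen"]   keeps only the preferred completion (the shared prompt prefix is sliced off)
# row["rejected"] keeps only the rejected completion
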
2 changes: 1 addition & 1 deletion src/axolotl/utils/data/__init__.py
@@ -1,11 +1,11 @@
"""
Data processing modules
"""
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,
24 changes: 20 additions & 4 deletions src/axolotl/utils/data/dpo.py → src/axolotl/utils/data/rl.py
@@ -1,17 +1,20 @@
"""data handling specific to DPO"""

import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List

import yaml
from datasets import concatenate_datasets, load_dataset, load_from_disk
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer

LOG = logging.getLogger("axolotl")

@@ -72,16 +75,29 @@ def load_split(dataset_cfgs, _cfg):
)
split_datasets.insert(i, ds)

tokenizer = None
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
if _cfg.rl == "orpo":
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(_cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)

data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
split_datasets[i] = data_set
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
6 changes: 3 additions & 3 deletions src/axolotl/utils/trainer.py
@@ -13,7 +13,7 @@
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
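
For reference, a minimal config sketch that takes the new "orpo" branch above; the values are illustrative, and the equivalent YAML keys behave the same way:

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "orpo",               # routes setup_trainer to HFRLTrainerBuilder
        "orpo_alpha": 0.1,          # forwarded to trl's ORPOConfig as `beta`
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "sequence_len": 2048,
    }
)
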
6 changes: 3 additions & 3 deletions tests/core/test_trainer_builder.py
@@ -4,7 +4,7 @@

import pytest

from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)


class TestHFDPOTrainerBuilder:
class TestHFRLTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""

def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9