FIX: max_length and max_prompt_length was not being sent to ORPOTrainer (#1584)

* FIX: TRL trainer preprocessing step was running in one process

* FIX: max_length and max_prompt_length were not being sent to ORPOTrainer

* FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour

* FIX: Removed change that belongs to a different PR

* FIX: Black formatting fix

* Explicitly set max prompt length for ORPO config
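The kwargs-forwarding pattern the fix relies on can be sketched as follows. This is a minimal illustration, not axolotl's actual code: `ORPOConfig` below is a stand-in dataclass rather than the real `trl.ORPOConfig`, and the `cfg` values are made up for the example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ORPOConfig:  # stand-in for trl.ORPOConfig (illustrative only)
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    dataset_num_proc: Optional[int] = None


# values that would come from the axolotl cfg (illustrative)
cfg = {"sequence_len": 2048, "max_prompt_len": 512, "dataset_processes": 8}

training_args_kwargs = {}
training_args_kwargs["dataset_num_proc"] = cfg["dataset_processes"]
# the fix: forward the lengths explicitly, so the trainer does not
# silently fall back to its own defaults
training_args_kwargs["max_length"] = cfg["sequence_len"]
if cfg["max_prompt_len"]:
    training_args_kwargs["max_prompt_length"] = cfg["max_prompt_len"]

training_args = ORPOConfig(**training_args_kwargs)
```

Before this fix, the two length kwargs were simply never placed in `training_args_kwargs`, so the trainer was constructed with its library defaults regardless of the user's `sequence_len`.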

---------

Co-authored-by: Ali Mosavian <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
3 people authored May 14, 2024
1 parent 1634ac8 commit 1e1921b
Showing 2 changed files with 6 additions and 0 deletions.
src/axolotl/core/trainer_builder.py (3 additions, 0 deletions)

@@ -1526,6 +1526,9 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (3 additions, 0 deletions)

@@ -517,6 +517,9 @@ class Config:

     sequence_len: int = Field(default=512)
     min_sample_len: Optional[int] = None
+    max_prompt_len: int = Field(
+        default=512, metadata={"help": "maximum prompt length for RL training"}
+    )
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
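The effect of the new field's default can be sketched with a plain dataclass; this stand-in mimics the pydantic model above but is not the real `AxolotlInputConfig`, and the `sequence_len` value is illustrative.

```python
from dataclasses import dataclass


@dataclass
class AxolotlInputConfig:  # simplified stand-in for the pydantic model (illustrative)
    sequence_len: int = 512
    max_prompt_len: int = 512  # new field: maximum prompt length for RL training


# with a common 2048-token sequence length, the 512 default for
# max_prompt_len corresponds to 1/4 of max length, matching the
# "1/4 of max length" behaviour mentioned in the commit message
cfg = AxolotlInputConfig(sequence_len=2048)
```

Users who need a longer prompt budget can override `max_prompt_len` explicitly; it no longer depends on being forwarded implicitly.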
