diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 742a88633..2f38b12dc 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1526,6 +1526,9 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0a2442d50..f1c12b2ba 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -517,6 +517,9 @@ class Config:
 
     sequence_len: int = Field(default=512)
    min_sample_len: Optional[int] = None
+    max_prompt_len: int = Field(
+        default=512, metadata={"help": "maximum prompt length for RL training"}
+    )
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
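
For illustration, here is a minimal, self-contained sketch of how the two new kwargs land on TRL's `ORPOConfig`. This is not axolotl's actual code path: `cfg` is a `SimpleNamespace` standing in for the parsed axolotl config, and the values 2048/512 and `output_dir` are arbitrary demo choices. The `max_length` / `max_prompt_length` fields themselves are the ones the diff sets.

```python
# Sketch only: mirrors the kwarg mapping from the diff above, outside axolotl.
from types import SimpleNamespace

from trl import ORPOConfig

# Stand-in for self.cfg; demo values, not defaults.
cfg = SimpleNamespace(sequence_len=2048, max_prompt_len=512)

training_args_kwargs = {}
# sequence_len now caps the total ORPO sequence (prompt + completion).
training_args_kwargs["max_length"] = cfg.sequence_len
if cfg.max_prompt_len:
    # max_prompt_len caps just the prompt portion of each sample.
    training_args_kwargs["max_prompt_length"] = cfg.max_prompt_len

training_args = ORPOConfig(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    **training_args_kwargs,
)
print(training_args.max_length, training_args.max_prompt_length)  # 2048 512
```

Before this change, `sequence_len` was not forwarded at all for ORPO runs, so whatever default `ORPOConfig` ships with took effect regardless of the axolotl config; reusing `sequence_len` as `max_length` keeps a single knob for the total sequence budget, with `max_prompt_len` as the only new setting.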