From c86c32a627957a08d145dee48a1a7ed2f20e4437 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jul 2024 15:38:37 -0400 Subject: [PATCH] set the number of dataset processes on the DPO Config rather than the trainer (#1762) --- src/axolotl/core/trainer_builder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b0eea55b1..3952cd593 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1687,6 +1687,7 @@ def build_training_arguments(self, total_num_steps): # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_cls = AxolotlDPOConfig if self.cfg.rpo_alpha is not None: training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha @@ -1754,8 +1755,6 @@ def build(self, total_num_steps): dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["generate_during_eval"] = True - if self.cfg.rl == "dpo": - dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model]