Skip to content

Commit

Permalink
set the number of dataset processes on the DPO Config rather than the trainer (#1762)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Jul 17, 2024
1 parent 8731b95 commit c86c32a
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,7 @@ def build_training_arguments(self, total_num_steps):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
Expand Down Expand Up @@ -1754,8 +1755,6 @@ def build(self, total_num_steps):
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
Expand Down

0 comments on commit c86c32a

Please sign in to comment.