Commit 5907216: fix plots

hiyouga committed Mar 31, 2024
1 parent 68aaa49 commit 5907216

Showing 4 changed files with 5 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/llmtuner/train/dpo/trainer.py
@@ -142,11 +142,10 @@ def get_batch_loss_metrics(
             reference_chosen_logps,
             reference_rejected_logps,
         )
-        batch_loss = losses.mean()
         if self.ftx_gamma > 1e-6:
             batch_size = batch["input_ids"].size(0) // 2
             chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-            batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
@@ -160,4 +159,4 @@ def get_batch_loss_metrics(
         metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
         metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
 
-        return batch_loss, metrics
+        return losses.mean(), metrics
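
In the trainer, the optional SFT term (weighted by ftx_gamma) is now folded into the per-example losses tensor before averaging, rather than having its mean added to a separate scalar batch_loss. The returned scalar is the same either way, since the mean of a weighted sum equals the weighted sum of the means. A minimal sketch of that equivalence, with made-up tensors standing in for the real per-example DPO and SFT losses:

import torch

# Hypothetical per-example losses, stand-ins for what the trainer computes.
losses = torch.tensor([0.9, 1.1, 0.7, 1.3])    # per-example DPO losses
sft_loss = torch.tensor([2.0, 2.2, 1.8, 2.4])  # per-example SFT losses
ftx_gamma = 0.5

# Old form: accumulate the means into a scalar batch loss.
batch_loss = losses.mean() + ftx_gamma * sft_loss.mean()

# New form: add the weighted SFT term per example, then average once at return.
combined = losses + ftx_gamma * sft_loss

# Both paths produce the same scalar.
assert torch.allclose(batch_loss, combined.mean())

Keeping everything in losses also means any later per-example use of that tensor sees the combined objective, not just the DPO part.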
2 changes: 1 addition & 1 deletion src/llmtuner/train/dpo/workflow.py
@@ -63,7 +63,7 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
 
     # Evaluation
     if training_args.do_eval:
2 changes: 1 addition & 1 deletion src/llmtuner/train/orpo/workflow.py
@@ -56,7 +56,7 @@ def run_orpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies", "sft_loss"])
 
     # Evaluation
     if training_args.do_eval:
2 changes: 1 addition & 1 deletion src/llmtuner/train/rm/workflow.py
@@ -55,7 +55,7 @@ def run_rm(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
 
     # Evaluation
     if training_args.do_eval:
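
The three workflow edits point plot_loss at metric keys these trainers actually log: rewards/accuracies for DPO, rewards/accuracies plus sft_loss for ORPO, and eval_accuracy for reward modeling, replacing the earlier accuracy key, which presumably never appeared in the logs (hence the broken plots the commit title refers to). The commit does not show plot_loss itself; the sketch below is a hypothetical plot_metric helper illustrating how such a function could work, assuming the Trainer's log history has been saved to trainer_state.json in output_dir (which transformers' trainer.save_state() does):

import json
import os

import matplotlib.pyplot as plt


def plot_metric(output_dir: str, key: str) -> None:
    # Assumption: trainer.save_state() wrote trainer_state.json, whose
    # "log_history" is a list of dicts, one per logging step.
    with open(os.path.join(output_dir, "trainer_state.json"), encoding="utf-8") as f:
        log_history = json.load(f)["log_history"]

    # Keep only the steps at which this key was logged; a key like
    # "rewards/accuracies" shows up only in entries that record it.
    steps = [entry["step"] for entry in log_history if key in entry]
    values = [entry[key] for entry in log_history if key in entry]
    if not values:
        return  # key never logged, nothing to plot

    plt.figure()
    plt.plot(steps, values)
    plt.xlabel("step")
    plt.ylabel(key)
    plt.savefig(os.path.join(output_dir, key.replace("/", "_") + ".png"))
    plt.close()

Calling plot_metric(training_args.output_dir, "rewards/accuracies") would then chart whichever steps logged that key and silently skip keys that never appear, which is exactly why passing the right key names matters.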
