
SFTTrainer fails with DFT loss during evaluation #4096

@solenetarride


Reproduction

Hi!

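I tried to fine-tune a model with the new DFT loss (loss_type="dft"). Training crashes at the first evaluation step because num_items_in_batch is None there: on the evaluation path, prediction_step calls compute_loss without it, and dft_loss divides by it unconditionally (see the traceback below).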

Here is a minimal script to reproduce the error:

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train").train_test_split()

config = SFTConfig(
    eval_steps=1,
    eval_strategy="steps",
    loss_type="dft",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=config,
)
trainer.train()

Output:

Traceback (most recent call last):
  File "/lustre/fswork/projects/rech/yfq/ubz97wr/train_dft.py", line 18, in <module>
    trainer.train()
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 2754, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 3227, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 3176, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 4469, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 4665, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 4881, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 1103, in compute_loss
    (loss, outputs) = super().compute_loss(
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/transformers/trainer.py", line 4113, in compute_loss
    loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fsn1/projects/rech/yfq/ubz97wr/envs/test_trl/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 493, in dft_loss
    loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for /: 'Tensor' and 'NoneType'
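To make the failure mode concrete, here is a minimal plain-Python sketch of a DFT-style loss with a None-safe denominator. Only the final normalization, (per_token_loss * loss_mask).sum() / num_items_in_batch, is taken from the traceback above. The per-token term assumes DFT scales each token's negative log-likelihood by that token's probability (detached in a real autograd implementation). The function name dft_loss_reference, its flat-list inputs, and the fallback of dividing by the number of unmasked tokens when num_items_in_batch is None are all hypothetical illustrations, not TRL's code or its eventual fix.

```python
import math
from typing import Optional, Sequence


def dft_loss_reference(
    token_logprobs: Sequence[float],
    loss_mask: Sequence[bool],
    num_items_in_batch: Optional[float] = None,
) -> float:
    """Hypothetical plain-Python sketch of a DFT-style loss.

    token_logprobs: log p(y_t | y_<t) for each target position (flat list).
    loss_mask: True where the position contributes (label != -100).
    num_items_in_batch: token count passed in during training; None in eval.
    """
    if len(token_logprobs) != len(loss_mask):
        raise ValueError("token_logprobs and loss_mask must have the same length")

    # Assumed DFT weighting: -p * log p per token. In an autograd implementation
    # the weight p = exp(log p) would be detached, so gradients flow only
    # through the log-probability term.
    per_token_loss = [-math.exp(lp) * lp for lp in token_logprobs]
    masked_sum = sum(loss for loss, keep in zip(per_token_loss, loss_mask) if keep)

    # The failing line divides by num_items_in_batch unconditionally.
    # Hypothetical fallback: normalize by the unmasked tokens in this batch.
    if num_items_in_batch is None:
        num_items_in_batch = sum(1 for keep in loss_mask if keep)
    if num_items_in_batch == 0:
        # No target tokens at all: return 0.0 instead of dividing by zero.
        return 0.0
    return masked_sum / num_items_in_batch


# Usage: two target tokens (p = 0.5 and p = 0.25) plus one masked position.
logprobs = [math.log(0.5), math.log(0.25), math.log(0.9)]
mask = [True, True, False]
print(dft_loss_reference(logprobs, mask))                        # eval-like call, None
print(dft_loss_reference(logprobs, mask, num_items_in_batch=4))  # training-like call
```

Note that during training num_items_in_batch can count tokens across the whole (gradient-accumulated) batch, so a per-batch fallback gives a per-batch mean for evaluation rather than exactly the training normalization.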

System Info

  • Platform: Linux-5.14.0-427.76.1.el9_4.x86_64-x86_64-with-glibc2.34
  • Python version: 3.12.11
  • TRL version: 0.23.0
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA H100 80GB HBM3
  • Transformers version: 4.56.1
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 4.1.0
  • HF Hub version: 0.35.0
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: not installed
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete

Metadata
Labels

🏋 SFT (Related to SFT), 🐛 bug (Something isn't working)