
sft_trainer.DataCollatorForVisionLanguageModeling does not account for "non-standard" processor outputs (e.g. Gemma 3) #4189

@KarelKenens


Reproduction

import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from trl import SFTConfig, SFTTrainer
import numpy as np

model_id = "google/gemma-3-4b-pt"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cpu",
    dtype=torch.bfloat16,
    token="<token>",
)
processor = AutoProcessor.from_pretrained(
    model_id, token="<token>"
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(use_cpu=True, max_length=None, bf16=True, fp16=False),
    train_dataset=[
        {
            "images": [
                Image.fromarray(
                    np.random.randint(0, 255, size=(64, 64), dtype=np.uint8)
                )
            ],
            "prompt": processor.boi_token,
            "completion": "something",
        }
    ],
    processing_class=processor,
)
trainer.train()
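
As a possible stopgap until the collator handles this, the collated batch could be post-processed so that `token_type_ids` is rebuilt from the final `input_ids`. Both tensors then have the same length by construction, and they stay aligned regardless of how the batch was padded. This is a sketch under two assumptions: that Gemma 3's `token_type_ids` are 1 exactly at image placeholder tokens and 0 elsewhere (the `is_image_block` check in the traceback below treats 1 as "image"), and that the placeholder's token id is available to you (for example via the processor; verify the attribute name on your processor). `wrap_collator` is a hypothetical helper, and how the wrapped collator is handed to the trainer is not shown.

```python
from typing import Any, Callable

import torch

Batch = dict[str, Any]


def wrap_collator(
    collator: Callable[[list], Batch], image_token_id: int
) -> Callable[[list], Batch]:
    """Hypothetical wrapper that recomputes token_type_ids after collation.

    Assumption: token_type_ids is 1 at image placeholder tokens and 0 elsewhere,
    so it can be derived position by position from input_ids.
    """

    def collate(examples: list) -> Batch:
        batch = collator(examples)
        if "token_type_ids" in batch:
            # Derived from input_ids, so the shapes always match.
            batch["token_type_ids"] = (batch["input_ids"] == image_token_id).long()
        return batch

    return collate
```

Recomputing from `input_ids`, rather than padding the prompt-only tensor, avoids having to reason about where the prompt padding ended up.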

I believe the issue stems from trl.trainer.sft_trainer.DataCollatorForVisionLanguageModeling._collate_prompt_completion, which only concatenates "input_ids" and "attention_mask" from processed_prompts and processed_completions. A Gemma 3 model also takes "token_type_ids", but in the collated batch these still come from processed_prompts alone, so they cover only the prompt tokens and end up shorter than "input_ids".
Output:
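
A minimal sketch of the kind of change this suggests, written as a standalone helper rather than as a patch to TRL. The name `concat_prompt_completion` and the `PER_TOKEN_KEYS` tuple are hypothetical. The sketch assumes that, as the traceback's `is_image_block` check indicates, Gemma 3 uses `token_type_ids == 1` for image tokens, so zeros are a safe fill for completion text when the completion encoding has no `token_type_ids` of its own.

```python
import torch

# Hypothetical list of keys holding one value per token; these must be
# concatenated along the sequence dimension (dim=1). Not TRL's actual list.
PER_TOKEN_KEYS = ("input_ids", "attention_mask", "token_type_ids")


def concat_prompt_completion(processed_prompts, processed_completions):
    """Join prompt and completion encodings token-wise (illustrative sketch).

    Entries that are not per-token, such as pixel_values, are taken from the
    prompt side unchanged. If token_type_ids is missing from the completion
    encoding, it is filled with zeros (0 marks a text token in Gemma 3).
    """
    output = dict(processed_prompts)  # shallow copy; prompt dict is not mutated
    completion_ids = processed_completions["input_ids"]
    for key in PER_TOKEN_KEYS:
        if key not in processed_prompts:
            continue  # e.g. processors that never emit token_type_ids
        completion_part = processed_completions.get(key)
        if completion_part is None:
            completion_part = torch.zeros_like(completion_ids)
        output[key] = torch.cat((processed_prompts[key], completion_part), dim=1)
    # 1 where a token belongs to the completion, used to restrict the loss.
    output["completion_mask"] = torch.cat(
        (
            torch.zeros_like(processed_prompts["attention_mask"]),
            processed_completions["attention_mask"],
        ),
        dim=1,
    )
    return output
```

Any later step that reorders or trims per-token tensors (removing padding, truncating to max_length) would have to apply the same operation to token_type_ids as to input_ids, or the two would drift out of alignment again.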

Traceback (most recent call last):
  File "/home/karel/Projects/trl-bug/main.py", line 35, in <module>
    trainer.train()
    ~~~~~~~~~~~~~^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
        args=args,
    ...<2 lines>...
        ignore_keys_for_eval=ignore_keys_for_eval,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 2672, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 1189, in training_step
    return super().training_step(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 4009, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 1103, in compute_loss
    (loss, outputs) = super().compute_loss(
                      ~~~~~~~~~~~~~~~~~~~~^
        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 4099, in compute_loss
    outputs = model(**inputs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 818, in forward
    return model_forward(*args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 806, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1077, in forward
    outputs = self.model(
        input_ids=input_ids,
    ...<12 lines>...
        **lm_kwargs,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/utils/generic.py", line 940, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 933, in forward
    "full_attention": create_causal_mask(**mask_kwargs),
                      ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 822, in create_causal_mask
    causal_mask = mask_interface(
        batch_size=batch_size,
    ...<7 lines>...
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 392, in sdpa_mask_recent_torch
    causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 54, in and_mask
    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
                      ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 68, in or_mask
    result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
                      ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 743, in inner_mask
    is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
                      ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 141, in __torch_function__
    return mod_index(args[0], index_args)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/autograd/function.py", line 586, in apply
    return custom_function_call(cls, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
    return super().__call__(autograd_function, *args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
    return wrapper()
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ~~~~~~~~~~~~~^
        dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
    return dispatch_functorch(self, args, kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
    return interpreter.process(op, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
    return kernel(self, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
    return custom_function_call_vmap_generate_rule(
        interpreter, autograd_function, *operands
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
    outputs = custom_function_call(vmapped_function, *unwrapped_operands)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
    return super().__call__(autograd_function, *args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
    return wrapper()
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ~~~~~~~~~~~~~^
        dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
    return dispatch_functorch(self, args, kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
    return interpreter.process(op, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
    return kernel(self, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
    return custom_function_call_vmap_generate_rule(
        interpreter, autograd_function, *operands
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
    outputs = custom_function_call(vmapped_function, *unwrapped_operands)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
    return super().__call__(autograd_function, *args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
    return wrapper()
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ~~~~~~~~~~~~~^
        dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
    return dispatch_functorch(self, args, kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
    return interpreter.process(op, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
    return kernel(self, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
    return custom_function_call_vmap_generate_rule(
        interpreter, autograd_function, *operands
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
    outputs = custom_function_call(vmapped_function, *unwrapped_operands)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
    return super().__call__(autograd_function, *args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
    return wrapper()
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
    return self.dispatch(
           ~~~~~~~~~~~~~^
        dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
    return dispatch_functorch(self, args, kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
    return interpreter.process(op, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
    return kernel(self, *args, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
    return custom_function_call_vmap_generate_rule(
        interpreter, autograd_function, *operands
    )
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
    outputs = custom_function_call(vmapped_function, *unwrapped_operands)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 50, in __call__
    return autograd_function.apply(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
    outputs, out_dims = restore_vmap(
                        ~~~~~~~~~~~~~
        autograd_function.forward, in_dims, batch_size, randomness
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*operands)
    ~^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
    outputs, out_dims = restore_vmap(
                        ~~~~~~~~~~~~~
        autograd_function.forward, in_dims, batch_size, randomness
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*operands)
    ~^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
    outputs, out_dims = restore_vmap(
                        ~~~~~~~~~~~~~
        autograd_function.forward, in_dims, batch_size, randomness
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*operands)
    ~^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
    outputs, out_dims = restore_vmap(
                        ~~~~~~~~~~~~~
        autograd_function.forward, in_dims, batch_size, randomness
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*operands)
    ~^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 100, in forward
    return torch.ops.aten.index(x, indices)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
IndexError: index 260 is out of bounds for dimension 1 with size 260
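
The last frames show Gemma 3's mask function indexing token_type_ids[batch_idx, q_idx] for query positions that span the whole input_ids sequence. The pure-PyTorch sketch below illustrates why that fails when token_type_ids is shorter than input_ids. The prompt length of 260 matches the size in the error message; the completion length of 3 is an arbitrary assumption, and the exact wording of PyTorch's error differs here because this uses plain integer indexing rather than vmap.

```python
import torch

PROMPT_LEN = 260     # size reported in the IndexError above
COMPLETION_LEN = 3   # arbitrary, for illustration only

# After collation, input_ids covers prompt + completion...
input_ids = torch.zeros(1, PROMPT_LEN + COMPLETION_LEN, dtype=torch.long)
# ...but token_type_ids still covers only the prompt.
token_type_ids = torch.zeros(1, PROMPT_LEN, dtype=torch.long)


def first_bad_query_index(input_ids, token_type_ids):
    """Return the first query position whose token_type_ids lookup fails, or None."""
    for q_idx in range(input_ids.shape[1]):
        try:
            token_type_ids[0, q_idx]
        except IndexError as err:
            print(err)
            return q_idx
    return None


print(first_bad_query_index(input_ids, token_type_ids))
```

The first failing position is the first completion token, which is why the reported index equals the prompt length.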

System Info

  • Platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Python version: 3.13.7
  • TRL version: 0.23.0
  • PyTorch version: 2.8.0+cu129
  • accelerator(s): NVIDIA RTX 1000 Ada Generation Laptop GPU
  • Transformers version: 4.56.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 4.1.1
  • HF Hub version: 0.35.3
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: not installed
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete

Metadata

Assignees: no one assigned

Labels: 🐛 bug (Something isn't working)
