Description
Reproduction
```python
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from trl import SFTConfig, SFTTrainer
import numpy as np

model_id = "google/gemma-3-4b-pt"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cpu",
    dtype=torch.bfloat16,
    token="<token>",
)
processor = AutoProcessor.from_pretrained(
    model_id, token="<token>"
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(use_cpu=True, max_length=None, bf16=True, fp16=False),
    train_dataset=[
        {
            "images": [
                Image.fromarray(
                    np.random.randint(0, 255, size=(64, 64), dtype=np.uint8)
                )
            ],
            "prompt": processor.boi_token,
            "completion": "something",
        }
    ],
    processing_class=processor,
)
trainer.train()
```
I believe the issue stems from trl.trainer.sft_trainer.DataCollatorForVisionLanguageModeling._collate_prompt_completion only concatenating the "input_ids" and "attention_mask" from processed_prompts and processed_completions. A Gemma3Model also consumes "token_type_ids", which in this case only contains the output from processed_prompts, so it never covers the completion tokens.
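For illustration, here is a minimal sketch (not the actual TRL implementation; the helper name is mine) of the kind of concatenation I would expect, assuming processed_prompts and processed_completions are the batched encodings returned by the processor:

```python
# Hypothetical sketch, not TRL's code: also carry token_type_ids through the
# prompt/completion concatenation when the processor returns them (Gemma3 does).
import torch

def concat_prompt_completion(processed_prompts, processed_completions):
    batch = {
        "input_ids": torch.cat(
            [processed_prompts["input_ids"], processed_completions["input_ids"]], dim=1
        ),
        "attention_mask": torch.cat(
            [processed_prompts["attention_mask"], processed_completions["attention_mask"]], dim=1
        ),
    }
    if "token_type_ids" in processed_prompts:
        # Completion tokens are plain text, so default their type ids to 0
        # (Gemma3 uses 1 to mark image tokens).
        completion_type_ids = processed_completions.get(
            "token_type_ids", torch.zeros_like(processed_completions["input_ids"])
        )
        batch["token_type_ids"] = torch.cat(
            [processed_prompts["token_type_ids"], completion_type_ids], dim=1
        )
    return batch
```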
Running the reproduction script above outputs:
Traceback (most recent call last):
File "/home/karel/Projects/trl-bug/main.py", line 35, in <module>
trainer.train()
~~~~~~~~~~~~~^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 2328, in train
return inner_training_loop(
args=args,
...<2 lines>...
ignore_keys_for_eval=ignore_keys_for_eval,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 2672, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 1189, in training_step
return super().training_step(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 4009, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/trl/trainer/sft_trainer.py", line 1103, in compute_loss
(loss, outputs) = super().compute_loss(
~~~~~~~~~~~~~~~~~~~~^
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/trainer.py", line 4099, in compute_loss
outputs = model(**inputs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 818, in forward
return model_forward(*args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 806, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1077, in forward
outputs = self.model(
input_ids=input_ids,
...<12 lines>...
**lm_kwargs,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 933, in forward
"full_attention": create_causal_mask(**mask_kwargs),
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 822, in create_causal_mask
causal_mask = mask_interface(
batch_size=batch_size,
...<7 lines>...
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 392, in sdpa_mask_recent_torch
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
return vmap_impl(
func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
return _flat_vmap(
func,
...<6 lines>...
**kwargs,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
return vmap_impl(
func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
return _flat_vmap(
func,
...<6 lines>...
**kwargs,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
return vmap_impl(
func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
return _flat_vmap(
func,
...<6 lines>...
**kwargs,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
return vmap_impl(
func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
return _flat_vmap(
func,
...<6 lines>...
**kwargs,
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 54, in and_mask
result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/masking_utils.py", line 68, in or_mask
result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 743, in inner_mask
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 141, in __torch_function__
return mod_index(args[0], index_args)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/autograd/function.py", line 586, in apply
return custom_function_call(cls, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
return super().__call__(autograd_function, *args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
return wrapper()
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
return self.dispatch(
~~~~~~~~~~~~~^
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
return dispatch_functorch(self, args, kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
return interpreter.process(op, args, kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
return kernel(self, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
return custom_function_call_vmap_generate_rule(
interpreter, autograd_function, *operands
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
outputs = custom_function_call(vmapped_function, *unwrapped_operands)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
return super().__call__(autograd_function, *args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
return wrapper()
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
return self.dispatch(
~~~~~~~~~~~~~^
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
return dispatch_functorch(self, args, kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
return interpreter.process(op, args, kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
return kernel(self, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
return custom_function_call_vmap_generate_rule(
interpreter, autograd_function, *operands
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
outputs = custom_function_call(vmapped_function, *unwrapped_operands)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
return super().__call__(autograd_function, *args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
return wrapper()
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
return self.dispatch(
~~~~~~~~~~~~~^
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
return dispatch_functorch(self, args, kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
return interpreter.process(op, args, kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
return kernel(self, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
return custom_function_call_vmap_generate_rule(
interpreter, autograd_function, *operands
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
outputs = custom_function_call(vmapped_function, *unwrapped_operands)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
return super().__call__(autograd_function, *args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 524, in __call__
return wrapper()
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 520, in wrapper
return self.dispatch(
~~~~~~~~~~~~~^
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 383, in dispatch
return dispatch_functorch(self, args, kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 312, in dispatch_functorch
return interpreter.process(op, args, kwargs)
~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/pyfunctorch.py", line 139, in process
return kernel(self, *args, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 300, in custom_function_call_vmap
return custom_function_call_vmap_generate_rule(
interpreter, autograd_function, *operands
)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 384, in custom_function_call_vmap_generate_rule
outputs = custom_function_call(vmapped_function, *unwrapped_operands)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 50, in __call__
return autograd_function.apply(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/autograd/function.py", line 576, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
outputs, out_dims = restore_vmap(
~~~~~~~~~~~~~
autograd_function.forward, in_dims, batch_size, randomness
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*operands)
~^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
outputs, out_dims = restore_vmap(
~~~~~~~~~~~~~
autograd_function.forward, in_dims, batch_size, randomness
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*operands)
~^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
outputs, out_dims = restore_vmap(
~~~~~~~~~~~~~
autograd_function.forward, in_dims, batch_size, randomness
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*operands)
~^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/autograd_function.py", line 400, in forward
outputs, out_dims = restore_vmap(
~~~~~~~~~~~~~
autograd_function.forward, in_dims, batch_size, randomness
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*operands)
~^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 512, in inner
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 100, in forward
return torch.ops.aten.index(x, indices)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/karel/Projects/trl-bug/.venv/lib/python3.13/site-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
~~~~~~~~^^^^^^^^^^^^^^^^^
IndexError: index 260 is out of bounds for dimension 1 with size 260
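The final IndexError is consistent with this: token_type_ids only spans the prompt (size 260 here), while the mask function indexes it at positions of the full prompt+completion sequence. A tiny standalone illustration (the length is taken from the traceback; the variable names are mine):

```python
import torch

prompt_len = 260                                   # prompt-only token_type_ids length, as in the error
token_type_ids = torch.zeros(1, prompt_len, dtype=torch.long)
q_idx = prompt_len                                 # first completion position in the concatenated sequence
token_type_ids[0, q_idx]
# IndexError: index 260 is out of bounds for dimension 1 with size 260
```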
System Info
- Platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
- Python version: 3.13.7
- TRL version: 0.23.0
- PyTorch version: 2.8.0+cu129
- accelerator(s): NVIDIA RTX 1000 Ada Generation Laptop GPU
- Transformers version: 4.56.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- Datasets version: 4.1.1
- HF Hub version: 0.35.3
- bitsandbytes version: not installed
- DeepSpeed version: not installed
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: not installed
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
- Any traceback provided is complete