
Transformers PaliGemma evaluate and compute_loss fail with tensors/device errors #35990

Open
BlGene opened this issue Jan 31, 2025 · 7 comments


@BlGene

BlGene commented Jan 31, 2025

System Info

My versions are:

Python Version: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0]
Torch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Count: 2
GPU Name: NVIDIA GeForce RTX 3090
Transformers Version: 4.48.1
Tokenizers Version: 0.21.0
Accelerate Version: 1.3.0
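A version block like the one above can be generated with a short script. This is a minimal sketch that uses only the standard library; the `collect_versions` helper and its package list are my own assumptions, and the torch-specific fields (CUDA availability, device count, GPU name) are left out because they need torch itself.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("torch", "transformers", "tokenizers", "accelerate")):
    """Return interpreter and package versions as a dict of display strings.

    The package names are assumptions for this report; a missing package is
    reported as "not installed" instead of raising.
    """
    info = {"Python Version": sys.version}
    for name in packages:
        label = f"{name.capitalize()} Version"
        try:
            info[label] = version(name)
        except PackageNotFoundError:
            info[label] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```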

Who can help?

@ArthurZucker , @amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm loading the PaliGemma 2 model google/paligemma2-3b-pt-224 and fine-tuning it with Trainer/Seq2SeqTrainer. As soon as I enable evaluation, this fails. After some digging, I found that it only fails when the model is in eval mode:

batch = [valid_dataset[i] for i in range(8)]
inputs = collate_fn(batch)
#generate_ids = model.generate(**inputs, max_length=286+30)
trainer.model.train()  # train mode
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("works")
trainer.model.train(False)  # eval mode, equivalent to trainer.model.eval()
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)  # raises the RuntimeError below
print("fails.")

I've worked around it by monkey-patching compute_loss_context_manager so that the model is temporarily switched back to train mode while the loss is computed:

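class TempTrainContext:
    """Wraps Trainer.compute_loss_context_manager and forces train mode during the loss computation."""

    def __init__(self, trainer):
        self.trainer = trainer
        # Keep a reference to the original (bound) context-manager factory
        self.orig_context_manager = trainer.compute_loss_context_manager

    def __enter__(self):
        # Enter the original context (e.g. autocast) first
        self.orig_context_inst = self.orig_context_manager()
        self.orig_context_inst.__enter__()
        # Remember the current mode, then switch to train mode
        self.training_enter = self.trainer.model.training
        self.trainer.model.train()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous mode, then leave the original context
        self.trainer.model.train(self.training_enter)
        return self.orig_context_inst.__exit__(exc_type, exc_value, traceback)

    def __call__(self):
        # Trainer calls self.compute_loss_context_manager(), so the instance must be callable
        return self

trainer.compute_loss_context_manager = TempTrainContext(trainer)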

(Bonus question: is this safe to do, or will I end up training on the test set?)
Error (from the eval-mode compute_loss call above):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 8
      6 print("works")
      7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
      9 print("fails.")
     12 orig_context_manager = trainer.compute_loss_context_manager

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3729         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3730     inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
   3732 # Save past state if it exists
   3733 # TODO: this needs to be fixed and made cleaner later.
   3734 if self.args.past_index >= 0:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
    525     labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
    527 causal_mask = self._update_causal_mask(
    528     attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
    529 )
--> 530 outputs = self.language_model(
    531     attention_mask=causal_mask,
    532     position_ids=position_ids,
    533     past_key_values=past_key_values,
    534     inputs_embeds=inputs_embeds,
    535     use_cache=use_cache,
    536     output_attentions=output_attentions,
    537     output_hidden_states=output_hidden_states,
    538     return_dict=return_dict,
    539     cache_position=cache_position,
    540     num_logits_to_keep=num_logits_to_keep,
    541 )
    543 logits = outputs.logits
    544 loss = None

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
    840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
    843     input_ids=input_ids,
    844     attention_mask=attention_mask,
    845     position_ids=position_ids,
    846     past_key_values=past_key_values,
    847     inputs_embeds=inputs_embeds,
    848     use_cache=use_cache,
    849     output_attentions=output_attentions,
    850     output_hidden_states=output_hidden_states,
    851     return_dict=return_dict,
    852     cache_position=cache_position,
    853 )
    855 hidden_states = outputs[0]
    856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
    617     layer_outputs = self._gradient_checkpointing_func(
    618         decoder_layer.__call__,
    619         hidden_states,
   (...)
    626         cache_position,
    627     )
    628 else:
--> 629     layer_outputs = decoder_layer(
    630         hidden_states,
    631         position_embeddings=position_embeddings,
    632         attention_mask=causal_mask,
    633         position_ids=position_ids,
    634         past_key_value=past_key_values,
    635         output_attentions=output_attentions,
    636         use_cache=use_cache,
    637         cache_position=cache_position,
    638         **flash_attn_kwargs,
    639     )
    641 hidden_states = layer_outputs[0]
    643 if output_attentions:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    296 hidden_states = self.input_layernorm(hidden_states)
    298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
    300     hidden_states=hidden_states,
    301     position_embeddings=position_embeddings,
    302     attention_mask=attention_mask,
    303     position_ids=position_ids,
    304     past_key_value=past_key_value,
    305     output_attentions=output_attentions,
    306     use_cache=use_cache,
    307     cache_position=cache_position,
    308 )
    309 hidden_states = self.post_attention_layernorm(hidden_states)
    310 hidden_states = residual + hidden_states

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    221 if past_key_value is not None:
    222     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    223     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    226 attention_interface: Callable = eager_attention_forward
    227 if self.config._attn_implementation != "eager":

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1714 else:
   1715     update_fn = self._static_update
-> 1717 return update_fn(
   1718     cache_position,
   1719     layer_idx,
   1720     key_states,
   1721     value_states,
   1722     k_out,
   1723     v_out,
   1724     k_out.shape[2],
   1725 )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694     k_out[:, :, cache_position] = key_states
   1695     v_out[:, :, cache_position] = value_states
   1697     self.key_cache[layer_idx] = k_out

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
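One reading of this traceback (an interpretation, not confirmed in the thread): in eval mode the Gemma 2 language model appears to build a HybridCache on the fly, that cache is allocated on a single device, and with the model split over cuda:0 and cuda:1 the layers on the second GPU then write their keys into a cache buffer on the first (the failing `k_out[:, :, cache_position] = key_states` line). The toy below reproduces that pattern without torch: devices are plain strings, and `ToyCache`, `ToyModel` and the `per_layer_cache` option are hypothetical names for illustration only. The per-layer path shows one possible fix direction, not necessarily what transformers does.

```python
class ToyCache:
    """Preallocated per-layer key storage; each layer's buffer lives on one 'device'."""

    def __init__(self, num_layers, device, per_layer_devices=None):
        # Hypothetical option: place each layer's buffer on that layer's device
        self.devices = list(per_layer_devices) if per_layer_devices else [device] * num_layers
        self.keys = [[] for _ in range(num_layers)]

    def update(self, layer_idx, key_device, key):
        # Stands in for `k_out[:, :, cache_position] = key_states`
        if self.devices[layer_idx] != key_device:
            raise RuntimeError(
                "Expected all tensors to be on the same device, but found at least two devices, "
                f"{self.devices[layer_idx]} and {key_device}!"
            )
        self.keys[layer_idx].append(key)


class ToyModel:
    def __init__(self, layer_devices):
        self.layer_devices = layer_devices  # like a device_map splitting layers over two GPUs
        self.training = True

    @property
    def device(self):
        # Like a module's .device: the device of its first parameter
        return self.layer_devices[0]

    def forward(self, use_cache=True, per_layer_cache=False):
        cache = None
        # Assumed condition: a cache is only created on the fly outside train mode
        if use_cache and not self.training:
            cache = ToyCache(
                len(self.layer_devices),
                self.device,
                per_layer_devices=self.layer_devices if per_layer_cache else None,
            )
        for idx, layer_device in enumerate(self.layer_devices):
            if cache is not None:
                cache.update(idx, layer_device, key=f"k{idx}")
        return cache


toy = ToyModel(["cuda:0", "cuda:0", "cuda:1", "cuda:1"])
toy.forward()                      # train mode: no cache is built, nothing fails
toy.training = False
toy.forward(per_layer_cache=True)  # eval mode, one buffer per layer device: works
try:
    toy.forward()                  # eval mode, single-device cache: fails
except RuntimeError as err:
    print(err)
```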

Error from the evaluator (bottom half of the file): https://gist.github.com/BlGene/607c7bee450e03835aa2bf0d2fd2959a

Expected behavior

Training runs with evaluation enabled.

@BlGene BlGene added the bug label Jan 31, 2025
@zucchini-nlp
Member

Hey @BlGene !

That should have been solved by #35164, which you can get by installing from main with !pip install --upgrade git+https://github.com/huggingface/transformers.git

@BlGene
Author

BlGene commented Jan 31, 2025

@zucchini-nlp

No, this did not fix the problem. I upgraded transformers (and also had to upgrade accelerate) and then got the following error:

Python Version: 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
Torch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Count: 2
GPU Name: NVIDIA GeForce RTX 3090
Transformers Version: 4.49.0.dev0
Tokenizers Version: 0.21.0
Accelerate Version: 1.4.0.dev0

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[28], line 18
      1 # orig_context_manager = trainer.compute_loss_context_manager
      2 # class TempTrainContext(object):
      3 #     def __init__(self, trainer):
   (...)
     15 #         return self
     16 # trainer.compute_loss_context_manager = TempTrainContext(trainer)
---> 18 trainer.train()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2184, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2182         hf_hub_utils.enable_progress_bars()
   2183 else:
-> 2184     return inner_training_loop(
   2185         args=args,
   2186         resume_from_checkpoint=resume_from_checkpoint,
   2187         trial=trial,
   2188         ignore_keys_for_eval=ignore_keys_for_eval,
   2189     )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2554, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2552     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2553     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2554     self._maybe_log_save_evaluate(
   2555         tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
   2556     )
   2557 else:
   2558     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3027, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
   3025 metrics = None
   3026 if self.control.should_evaluate:
-> 3027     metrics = self._evaluate(trial, ignore_keys_for_eval)
   3028     is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
   3030     if self.args.save_strategy == SaveStrategy.BEST:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2981, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2980 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2981     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2982     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2984     # Run delayed LR scheduler now that metrics are populated

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer_seq2seq.py:197, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    195 self.gather_function = self.accelerator.gather
    196 self._gen_kwargs = gen_kwargs
--> 197 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4001, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3998 start_time = time.time()
   4000 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4001 output = eval_loop(
   4002     eval_dataloader,
   4003     description="Evaluation",
   4004     # No point gathering the predictions if there are no metrics, otherwise we defer to
   4005     # self.args.prediction_loss_only
   4006     prediction_loss_only=True if self.compute_metrics is None else None,
   4007     ignore_keys=ignore_keys,
   4008     metric_key_prefix=metric_key_prefix,
   4009 )
   4011 total_batch_size = self.args.eval_batch_size * self.args.world_size
   4012 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4217, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   4215     labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
   4216 if logits is not None:
-> 4217     logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
   4218     if self.preprocess_logits_for_metrics is not None:
   4219         logits = self.preprocess_logits_for_metrics(logits, labels)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/accelerator.py:2640, in Accelerator.pad_across_processes(self, tensor, dim, pad_index, pad_first)
   2607 def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
   2608     """
   2609     Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
   2610     they can safely be gathered.
   (...)
   2638     ```
   2639     """
-> 2640     return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:408, in chained_operation.<locals>.wrapper(*args, **kwargs)
    405 @wraps(function)
    406 def wrapper(*args, **kwargs):
    407     try:
--> 408         return function(*args, **kwargs)
    409     except DistributedOperationException as e:
    410         operation = f"{function.__module__}.{function.__name__}"

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:678, in pad_across_processes(tensor, dim, pad_index, pad_first)
    675     new_tensor[indices] = tensor
    676     return new_tensor
--> 678 return recursively_apply(
    679     _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
    680 )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:107, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
     85 """
     86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
     87 
   (...)
    104     The same data structure as `data` with `func` applied to every object of type `main_type`.
    105 """
    106 if isinstance(data, (tuple, list)):
--> 107     return honor_type(
    108         data,
    109         (
    110             recursively_apply(
    111                 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
    112             )
    113             for o in data
    114         ),
    115     )
    116 elif isinstance(data, Mapping):
    117     return type(data)(
    118         {
    119             k: recursively_apply(
   (...)
    123         }
    124     )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:81, in honor_type(obj, generator)
     79     return type(obj)(*list(generator))
     80 else:
---> 81     return type(obj)(generator)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:110, in <genexpr>(.0)
     85 """
     86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
     87 
   (...)
    104     The same data structure as `data` with `func` applied to every object of type `main_type`.
    105 """
    106 if isinstance(data, (tuple, list)):
    107     return honor_type(
    108         data,
    109         (
--> 110             recursively_apply(
    111                 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
    112             )
    113             for o in data
    114         ),
    115     )
    116 elif isinstance(data, Mapping):
    117     return type(data)(
    118         {
    119             k: recursively_apply(
   (...)
    123         }
    124     )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:128, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
    126     return func(data, *args, **kwargs)
    127 elif error_on_other_type:
--> 128     raise TypeError(
    129         f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
    130         f"objects that are valid for `{test_type.__name__}` should be passed."
    131     )
    132 return data

TypeError: Unsupported types (<class 'transformers.cache_utils.HybridCache'>) passed to `_pad_across_processes`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.
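This second failure seems to come from the evaluation loop gathering everything the model returns, including the cache object, while accelerate's padding helper only accepts nested lists/tuples/dicts of tensors (as the error message says). Below is a pure-Python sketch of that mechanism under simplifying assumptions: `toy_pad` stands in for the recursive padding helper, `split_outputs` mimics the Trainer dropping `loss` and any ignored keys before gathering, and `FakeTensor`/`FakeCache` are placeholders; none of this is the library code. If the reading is right, passing `ignore_keys_for_eval=["past_key_values"]` to `trainer.train()` (the parameter appears in the traceback above) might keep the cache out of the gathered outputs, though I have not verified that.

```python
from collections.abc import Mapping


class FakeTensor:
    """Placeholder for a torch.Tensor."""

    def __init__(self, values):
        self.values = values


class FakeCache:
    """Placeholder for a cache object such as HybridCache (not a tensor, list, or dict)."""


def toy_pad(data):
    # Recurse into tuples/lists/dicts, accept tensors, reject anything else
    if isinstance(data, (tuple, list)):
        return type(data)(toy_pad(item) for item in data)
    if isinstance(data, Mapping):
        return type(data)({key: toy_pad(value) for key, value in data.items()})
    if isinstance(data, FakeTensor):
        return data  # the real helper would pad along a dimension here
    raise TypeError(f"Unsupported types ({type(data)}) passed to `toy_pad`.")


def split_outputs(outputs, ignore_keys=()):
    """Separate the loss from the remaining outputs, skipping ignored keys."""
    skipped = list(ignore_keys) + ["loss"]
    logits = tuple(value for key, value in outputs.items() if key not in skipped)
    return outputs.get("loss"), logits


outputs = {"loss": 0.5, "logits": FakeTensor([0.1, 0.9]), "past_key_values": FakeCache()}
_, kept = split_outputs(outputs, ignore_keys=["past_key_values"])
toy_pad(kept)  # fine: only tensors remain
_, everything = split_outputs(outputs)
try:
    toy_pad(everything)  # the cache object is rejected
except TypeError as err:
    print(err)
```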

@zucchini-nlp
Member

Now it's a new error 😆, and it's related to accelerate. I couldn't reproduce it with the provided script. For accelerate, pinging @SunMarc

@BlGene
Author

BlGene commented Jan 31, 2025

@zucchini-nlp

Just to clarify: it was trainer.train() that failed for me and produced the error above. The trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416) call now works :party:

My trainer and config are below. Can you let me know if this works for you?
(I'd also be interested to know whether train loss and eval loss both go to 0 in the small-dataset case.)

args_jax = Seq2SeqTrainingArguments(
    #num_train_epochs=1,
    max_steps=TRAIN_STEPS,
    remove_unused_columns=False,
    per_device_train_batch_size=BATCH_SIZE_DEV,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=3e-5,  #1e-5, 2e-5,
    #weight_decay=3e-7,
    lr_scheduler_type="cosine",
    warmup_ratio=.05,
    #gradient_checkpointing=True,
    generation_max_length=SEQLEN,
    #weight_decay=1e-6,
    #adam_beta2=0.999,
    logging_steps=10,
    optim="adafactor",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    output_dir=save_path,
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
    eval_strategy="steps",
    eval_steps=4,
    per_device_eval_batch_size=BATCH_SIZE_DEV,
    eval_accumulation_steps=GRAD_ACCUM
)

from torch.utils.data import Subset
valid_dataset_small = Subset(valid_dataset, range(1))
train_dataset_small = Subset(train_dataset, range(1))
    
trainer = Seq2SeqTrainer(
    model=model,
    train_dataset=train_dataset_small,
    eval_dataset=train_dataset_small,
    data_collator=collate_fn,
    args=args_jax,
    compute_metrics=compute_metrics
)
trainer.train()
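The step-related arguments above combine in a way that is easy to misread, so here is a small sketch of the arithmetic. It assumes a single training process: the accelerate hooks in the first traceback suggest the model is split across the two GPUs with a device map rather than trained data-parallel, so the world size would be 1. BATCH_SIZE_DEV, GRAD_ACCUM and TRAIN_STEPS are not given in the issue, so the numbers below are placeholders, and the schedule ignores eval_delay.

```python
def effective_train_batch_size(per_device_batch_size, grad_accum_steps, world_size=1):
    # Samples that contribute to one optimizer step
    return per_device_batch_size * grad_accum_steps * world_size


def eval_steps_schedule(max_steps, eval_steps):
    # Global steps at which eval_strategy="steps" triggers an evaluation (eval_delay ignored)
    return [step for step in range(1, max_steps + 1) if step % eval_steps == 0]


# Placeholder values, not taken from the issue
per_device, grad_accum, max_steps = 2, 4, 20
print(effective_train_batch_size(per_device, grad_accum))  # 8
print(eval_steps_schedule(max_steps, eval_steps=4))  # [4, 8, 12, 16, 20]
```

Note that eval_accumulation_steps is a different knob: it sets how many prediction steps are accumulated before the outputs are moved to the CPU, and it does not affect the optimizer.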

@zucchini-nlp
Member

Can you also share the compute_metrics function?

@BlGene
Author

BlGene commented Jan 31, 2025

Sure:

import numpy as np
def compute_metrics(eval_pred):
    predictions, label_tokens = eval_pred  # Extract predictions and labels
    if isinstance(predictions, tuple):  # Some models return tuples
        predictions = predictions[0]

    # Convert to token indices if necessary (e.g., for text generation models)
    pred_tokens = np.argmax(predictions, axis=-1)  # Assuming logits, take argmax

    pred_texts = processor.tokenizer.batch_decode(pred_tokens[:,-SEQLEN-1:], skip_special_tokens=True)
    label_text = processor.tokenizer.batch_decode(label_tokens[:,-SEQLEN-1:], skip_special_tokens=True)

    print(pred_tokens[:,-SEQLEN-1:])
    print(label_tokens[:,-SEQLEN-1:])
    print(label_text)
    print(pred_texts)
    print()
    return {"accuracy": 0}

(Should be harmless though, hopefully)
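Two details are likely to matter once compute_metrics does real work. First, labels are padded with -100 (the pad_index used for labels in the evaluation loop traceback above), which is not a valid token id, so it is commonly replaced with the pad token id before decoding. Second, for a causal LM the logit at position t predicts the token at position t + 1, so predictions and labels should be compared with a one-position shift. A NumPy sketch under those assumptions; `replace_ignore_index` and `masked_shifted_accuracy` are hypothetical helpers, and `pred_tokens` is the argmax of the logits as in the function above.

```python
import numpy as np

IGNORE_INDEX = -100  # label value the loss ignores


def replace_ignore_index(label_tokens, pad_token_id):
    # -100 is not a real token id; swap it for the pad id so batch_decode can handle it
    return np.where(label_tokens != IGNORE_INDEX, label_tokens, pad_token_id)


def masked_shifted_accuracy(pred_tokens, label_tokens):
    """Token accuracy for a causal LM: the prediction at t is compared with the label at t + 1."""
    preds = pred_tokens[:, :-1]
    labels = label_tokens[:, 1:]
    mask = labels != IGNORE_INDEX  # skip prompt and padding positions
    if not mask.any():
        return 0.0
    return float((preds[mask] == labels[mask]).mean())


# Possible use inside compute_metrics (processor as defined in the notebook):
# accuracy = masked_shifted_accuracy(pred_tokens, label_tokens)
# label_text = processor.tokenizer.batch_decode(
#     replace_ignore_index(label_tokens, processor.tokenizer.pad_token_id), skip_special_tokens=True)
```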

@zucchini-nlp
Copy link
Member

@BlGene thanks, I can confirm that the bug is reproducible, and it fails even with models that use the default DynamicCache. It seems we removed support for the tuple-format cache in v4.48, which causes such errors for other models as well.

Not sure if the accelerate team already has a workaround for that; let's wait for SunMarc.
