
When training from saved checkpoint: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #395

Open · label: bug

KadriMufti opened this issue Apr 1, 2024 · 1 comment

@KadriMufti:

Version

Name: pyabsa
Version: 2.4.1.post1
Summary: This tool provides the state-of-the-art models for aspect term extraction (ATE), aspect polarity classification (APC), and text classification (TC).
Home-page: https://github.com/yangheng95/PyABSA
Author: Yang, Heng
Author-email: [email protected]
License: MIT
Location: /usr/local/lib/python3.8/dist-packages
Requires: metric-visualizer, boostaug, networkx, seqeval, torch, sentencepiece, protobuf, update-checker, pytorch-warmup, transformers, tqdm, findfile, pandas, typing-extensions, gitpython, termcolor, spacy, autocuda
Required-by: boostaug
Name: torch
Version: 2.2.1+cu121
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /usr/local/lib/python3.8/dist-packages
Requires: networkx, sympy, nvidia-cuda-nvrtc-cu12, triton, nvidia-cusolver-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, typing-extensions, fsspec, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, jinja2, nvidia-curand-cu12, filelock, nvidia-cuda-runtime-cu12, nvidia-cusparse-cu12
Required-by: trl, torchvision, torchaudio, timm, pytorch-warmup, pyabsa, peft, OpenNMT-py, flash-attn, deepspeed, accelerate
Name: transformers
Version: 4.40.0.dev0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: [email protected]
License: Apache 2.0 License
Location: /usr/local/lib/python3.8/dist-packages
Requires: regex, requests, packaging, pyyaml, filelock, tokenizers, numpy, safetensors, tqdm, huggingface-hub
Required-by: trl, pyabsa, peft

Describe the bug
When training of a model finishes, the checkpoint is saved. When I try to continue training from that saved checkpoint, I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[54], line 1
----> 1 trainer = ATEPC.ATEPCTrainer(
      2     config=config,
      3     dataset=my_dataset,
      4     from_checkpoint=checkpoint,  # if you want to resume training from our pretrained checkpoints, you can pass the checkpoint name here
      5     auto_device=DeviceTypeOption.AUTO,  # use cuda if available
      6     checkpoint_save_mode=2, # ModelSaveOption.SAVE_MODEL_STATE_DICT,  # save state dict only instead of the whole model
      7     load_aug=False,  # there are some augmentation dataset for integrated datasets, you use them by setting load_aug=True to improve performance
      8     path_to_save=f'/app/aspect/code4_pyabsa/NEW_ATEPC_MULTILINGUAL_CHECKPOINT_{model}_4/'
      9 )

File /usr/local/lib/python3.8/dist-packages/pyabsa/tasks/AspectTermExtraction/trainer/atepc_trainer.py:69, in ATEPCTrainer.__init__(self, config, dataset, from_checkpoint, checkpoint_save_mode, auto_device, path_to_save, load_aug)
     64 self.config.task_code = TaskCodeOption.Aspect_Term_Extraction_and_Classification
     65 self.config.task_name = TaskNameOption().get(
     66     TaskCodeOption.Aspect_Term_Extraction_and_Classification
     67 )
---> 69 self._run()

File /usr/local/lib/python3.8/dist-packages/pyabsa/framework/trainer_class/trainer_template.py:240, in Trainer._run(self)
    238 self.config.seed = s
    239 if self.config.checkpoint_save_mode:
--> 240     model_path.append(self.training_instructor(self.config).run())
    241 else:
    242     # always return the last trained model if you don't save trained model
    243     model = self.inference_model_class(
    244         checkpoint=self.training_instructor(self.config).run()
    245     )

File /usr/local/lib/python3.8/dist-packages/pyabsa/tasks/AspectTermExtraction/instructor/atepc_instructor.py:795, in ATEPCTrainingInstructor.run(self)
    794 def run(self):
--> 795     return self._train(criterion=None)

File /usr/local/lib/python3.8/dist-packages/pyabsa/framework/instructor_class/instructor_template.py:368, in BaseTrainingInstructor._train(self, criterion)
    365     return self._k_fold_train_and_evaluate(criterion)
    366 # Train and evaluate the model if there is only one validation dataloader
    367 else:
--> 368     return self._train_and_evaluate(criterion)

File /usr/local/lib/python3.8/dist-packages/pyabsa/tasks/AspectTermExtraction/instructor/atepc_instructor.py:334, in ATEPCTrainingInstructor._train_and_evaluate(self, criterion)
    322         loss_ate, loss_apc = self.model(
    323             input_ids_spc,
    324             token_type_ids=segment_ids,
   (...)
    331             lcf_cdw_vec=lcf_cdw_vec,
    332         )
    333 else:
--> 334     loss_ate, loss_apc = self.model(
    335         input_ids_spc,
    336         token_type_ids=segment_ids,
    337         attention_mask=input_mask,
    338         labels=label_ids,
    339         polarity=polarity,
    340         valid_ids=valid_ids,
    341         attention_mask_label=l_mask,
    342         lcf_cdm_vec=lcf_cdm_vec,
    343         lcf_cdw_vec=lcf_cdw_vec,
    344     )
    345 # for multi-gpu, average loss by gpu instance number
    346 if self.config.auto_device == DeviceTypeOption.ALL_CUDA:

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /usr/local/lib/python3.8/dist-packages/pyabsa/tasks/AspectTermExtraction/models/__lcf__/fast_lcf_atepc.py:75, in FAST_LCF_ATEPC.forward(self, input_ids_spc, token_type_ids, attention_mask, labels, polarity, valid_ids, attention_mask_label, lcf_cdm_vec, lcf_cdw_vec)
     73     input_ids = self.get_ids_for_local_context_extractor(input_ids_spc)
     74     labels = self.get_batch_token_labels_bert_base_indices(labels)
---> 75     global_context_out = self.bert4global(
     76         input_ids=input_ids, attention_mask=attention_mask
     77     )["last_hidden_state"]
     78 else:
     79     global_context_out = self.bert4global(
     80         input_ids=input_ids_spc, attention_mask=attention_mask
     81     )["last_hidden_state"]

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /transformers/src/transformers/models/deberta_v2/modeling_deberta_v2.py:1058, in DebertaV2Model.forward(self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, output_attentions, output_hidden_states, return_dict)
   1055 if token_type_ids is None:
   1056     token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
-> 1058 embedding_output = self.embeddings(
   1059     input_ids=input_ids,
   1060     token_type_ids=token_type_ids,
   1061     position_ids=position_ids,
   1062     mask=attention_mask,
   1063     inputs_embeds=inputs_embeds,
   1064 )
   1066 encoder_outputs = self.encoder(
   1067     embedding_output,
   1068     attention_mask,
   (...)
   1071     return_dict=return_dict,
   1072 )
   1073 encoded_layers = encoder_outputs[1]

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /transformers/src/transformers/models/deberta_v2/modeling_deberta_v2.py:896, in DebertaV2Embeddings.forward(self, input_ids, token_type_ids, position_ids, mask, inputs_embeds)
    893         mask = mask.unsqueeze(2)
    894     mask = mask.to(embeddings.dtype)
--> 896     embeddings = embeddings * mask
    898 embeddings = self.dropout(embeddings)
    899 return embeddings

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
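
For reference, this is the generic message PyTorch raises whenever two tensors in one operation live on different devices; in the traceback it surfaces at embeddings * mask inside DebertaV2Embeddings, i.e. the embedding output and the attention mask ended up on different devices after the checkpoint was reloaded. A minimal illustration, independent of PyABSA (assumes a CUDA-capable machine):

import torch

layer = torch.nn.Linear(4, 4)            # parameters are created on the CPU
x = torch.randn(2, 4, device='cuda:0')   # input lives on the GPU
layer(x)                                 # raises the same RuntimeError as above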

Code To Reproduce

# Imports assumed for this snippet (PyABSA v2); config and my_dataset are created earlier
# in the notebook (e.g. via ATEPC.ATEPCConfigManager and a DatasetItem).
from pyabsa import AspectTermExtraction as ATEPC, DeviceTypeOption

config.model = ATEPC.ATEPCModelList.FAST_LCF_ATEPC
config.evaluate_begin = 0
config.max_seq_len = 512
config.num_epoch = 50
config.batch_size = 16
config.patience = 10
config.log_step = -1
config.seed = [1]
config.show_metric = True
config.gradient_accumulation_steps = 4
config.verbose = False  # If verbose == True, PyABSA will output the model structure and several processed data examples

config.pretrained_bert = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7" 
config.notice = (
    f"This is a finetuned aspect term extraction model, based on {config.pretrained_bert}, using combined Arabic and English data from various sources."  # for memos usage
)

model = config.pretrained_bert.split('/')[-1]
base_path = f'/app/aspect/code4_pyabsa/NEW_ATEPC_MULTILINGUAL_CHECKPOINT_{model}_3/'
checkpoint = base_path + get_latest_checkpoint(base_path)  # user-defined helper returning the newest checkpoint folder name
trainer = ATEPC.ATEPCTrainer(
    config=config,
    dataset=my_dataset,
    from_checkpoint=checkpoint,  # resume training from the previously saved checkpoint (as in the traceback above)
    auto_device=DeviceTypeOption.AUTO,
    checkpoint_save_mode=2,  # ModelSaveOption.SAVE_MODEL_STATE_DICT: save the state dict only
    load_aug=False,
    path_to_save=base_path
)
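
Outside of PyABSA, the standard way to avoid this class of error is to move both the (re)loaded module and every batch tensor to one explicit device before the forward pass; the suggestion in the comment below follows the same idea. A generic sketch, not PyABSA-specific:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(4, 4).to(device)   # move the reloaded weights to the chosen device
batch = torch.randn(2, 4).to(device)       # and the inputs
out = model(batch)                         # same device everywhere, no RuntimeError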

Expected behavior

I expect training to resume from the saved checkpoint without problems. How do I solve this issue when training from a saved checkpoint?

@yangheng95 (Owner) commented:

Can you set auto_device to a specific device, e.g., auto_device='cuda:0'?
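
Applied to the snippet above, that suggestion would look like this (a sketch only; whether it resolves the mismatch is not confirmed in this thread):

trainer = ATEPC.ATEPCTrainer(
    config=config,
    dataset=my_dataset,
    from_checkpoint=checkpoint,
    auto_device='cuda:0',        # pin everything to one explicit device instead of DeviceTypeOption.AUTO
    checkpoint_save_mode=2,
    load_aug=False,
    path_to_save=base_path
)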
