Model is not completely moved to a different device when using .cuda() or .cpu() #599

@giacomoguiduzzi

Hi Wenjie,

I was debugging an issue with StemGNN allocating too much VRAM during the backward() step, which causes an OOM exception. Since this only happens with some model configurations, I wanted to finish the training on CPU in those cases. When I tried, training could not run because something was left on the CUDA device. While debugging the training function, I discovered that the optimizer state is not moved, because the optimizer is not one of the model's modules.

1. System Info

  • PyPOTS version 0.8.1
  • CUDA 12.4
  • Intel Xeon Gold 6338
  • NVIDIA A40

2. Information

  • The official example scripts
  • My own created scripts

I solved the issue with the following:

imputer.device = torch.device("cpu")
# move model to CPU
imputer.model = imputer.model.cpu()
_move_optimizer_state_to_device(imputer.optimizer.torch_optimizer, 'cpu')
# move data to CPU
train_set = train_set.to(imputer.device)

imputer.fit(
    train_set={"X": train_set},
)

The _move_optimizer_state_to_device function is as follows:

from typing import Union

import torch


def _move_optimizer_state_to_device(optimizer, device: Union[torch.device, int, str]):
    """Recursively moves optimizer state tensors to the specified device."""
    for param_group in optimizer.state_dict()["param_groups"]:
        for param_id in param_group["params"]:
            if param_id in optimizer.state.keys():
                param_state = optimizer.state[param_id]
            else:
                param_state = optimizer.state
            # Handle momentum_buffer outside the param loop, as a special case
            if "momentum_buffer" in list(
                {key for param in param_state.values() for key in param.keys()}
            ):
                for param_state_dict in param_state.values():
                    for param_name, param in param_state_dict.items():
                        if param_name == "momentum_buffer":
                            if isinstance(param, tuple):
                                param_state_dict[param_name] = tuple(
                                    (
                                        item.to(device)
                                        if isinstance(item, torch.Tensor)
                                        else item
                                    )
                                    for item in param
                                )
                            elif isinstance(param, torch.Tensor):
                                param_state_dict[param_name] = param.to(device)

            # Iterate over the other keys in param_state
            for param_state_dict in param_state.values():
                for param_name, param in param_state_dict.items():
                    if param_name == "step":
                        continue

                    if param_name != "momentum_buffer" and isinstance(
                        param, torch.Tensor
                    ):
                        param_state_dict[param_name] = param.to(device)

I wrote the function looking at the Adam wrapper I found in PyPOTS and its state_dict, so I'm not sure it works in every case. I am using PyPOTS version 0.8.1.

3. Reproduction

  1. Instantiate a StemGNN model on cuda:0;
  2. Move it to CPU;
  3. Run fit().
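The same situation can be inspected outside PyPOTS with plain PyTorch. The sketch below is not the StemGNN reproduction itself: it uses a small `torch.nn.Linear` stand-in, and the helper names `optimizer_state_devices` and `demo` are hypothetical. It reports which device types hold optimizer state after the model has been moved to CPU. On a machine with a GPU, the second printed set is expected to still contain "cuda", because `model.cpu()` moves parameters but does not touch the optimizer's state tensors; on a CPU-only machine both sets only contain "cpu".

```python
import torch


def optimizer_state_devices(optimizer: torch.optim.Optimizer) -> set:
    """Return the device types that hold tensors in the optimizer state."""
    devices = set()
    for param_state in optimizer.state.values():
        for value in param_state.values():
            if isinstance(value, torch.Tensor):
                devices.add(value.device.type)
    return devices


def demo() -> None:
    # Fall back to CPU so the script still runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(4, 2).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # One training step so that Adam allocates its state tensors.
    model(torch.randn(8, 4, device=device)).sum().backward()
    optimizer.step()

    model = model.cpu()  # moves parameters only
    print({p.device.type for p in model.parameters()})
    print(optimizer_state_devices(optimizer))
```

Calling `demo()` before and after applying a state-moving helper gives a quick check that nothing is left on the GPU.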

4. Expected behavior

The model correctly moves all of its parts, including the optimizer state, to the CPU and runs the .fit() function without raising exceptions.


I wanted to know what you think about this, and to ask whether I could test it on the latest PyPOTS version and open a PR, if you like the approach. If so, where should this function live? I noticed that .cuda() and .cpu() are methods of the PyTorch nn.Module class, so I think it would be nice to override them with a PyPOTS version that calls PyTorch's implementation and then _move_optimizer_state_to_device. I have only tested this on StemGNN so far, so it should be tested with other optimizers and models before being included.
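One possible shape for that idea is sketched below. `DeviceAwareTrainer` is a hypothetical class, not an existing PyPOTS API; it only assumes an object that holds a `torch.nn.Module` and a `torch.optim.Optimizer`, and it exposes `.to()`, `.cpu()`, and `.cuda()` methods that move both together. As in the function above, the "step" entry is left untouched, which is an assumption that should be revisited for fused or capturable optimizers.

```python
import torch


class DeviceAwareTrainer:
    """Hypothetical holder that keeps a model and its optimizer on one device."""

    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
        self.model = model
        self.optimizer = optimizer
        self.device = next(model.parameters()).device

    def to(self, device) -> "DeviceAwareTrainer":
        self.device = torch.device(device)
        # nn.Module.to() moves parameters and buffers in place.
        self.model = self.model.to(self.device)
        # The optimizer state is not part of the module, so move it explicitly.
        for param_state in self.optimizer.state.values():
            for name, value in list(param_state.items()):
                if name != "step" and isinstance(value, torch.Tensor):
                    param_state[name] = value.to(self.device)
        return self

    def cpu(self) -> "DeviceAwareTrainer":
        return self.to("cpu")

    def cuda(self, index: int = 0) -> "DeviceAwareTrainer":
        return self.to(f"cuda:{index}")
```

Keeping the move logic in one `.to()` method means `.cpu()` and `.cuda()` stay thin wrappers, so any later fix to the state handling applies to every entry point.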

Looking forward to your kind response.

Best Regards,
Giacomo Guiduzzi

Labels: bug (Something isn't working), completed (The issue has been completed)
