Hi Wenjie,
I was debugging an issue where StemGNN allocates too much VRAM during the backward() step, causing an OOM exception. Since this only happens with some model configurations, I wanted to handle it by finishing the training on the CPU. When I did so, the training would not run because something was still left on the CUDA device. Debugging the training function, I discovered that the optimizer state is not moved: the optimizer is not one of the model's modules, so .cpu() does not touch it.
1. System Info
- PyPOTS version 0.8.1
- CUDA 12.4
- Intel Xeon Gold 6338
- NVIDIA A40
2. Information
I solved the issue with the following:
```python
imputer.device = torch.device("cpu")
# move model to CPU
imputer.model = imputer.model.cpu()
_move_optimizer_state_to_device(imputer.optimizer.torch_optimizer, "cpu")
# move data to CPU
train_set = train_set.to(imputer.device)
imputer.fit(
    train_set={"X": train_set},
)
```
The _move_optimizer_state_to_device function is as follows:
```python
from typing import Union

import torch


def _move_optimizer_state_to_device(
    optimizer: torch.optim.Optimizer, device: Union[torch.device, int, str]
) -> None:
    """Moves the optimizer state tensors to the specified device, in place."""
    # optimizer.state maps each parameter to a dict holding its state
    # (e.g. exp_avg/exp_avg_sq for Adam, momentum_buffer for SGD)
    for param_state in optimizer.state.values():
        for name, value in list(param_state.items()):
            # Leave the step counter untouched
            if name == "step":
                continue
            if isinstance(value, torch.Tensor):
                param_state[name] = value.to(device)
            elif isinstance(value, tuple):
                # Move any tensors held inside tuple-valued entries
                param_state[name] = tuple(
                    item.to(device) if isinstance(item, torch.Tensor) else item
                    for item in value
                )
```
I wrote the function by looking at the Adam wrapper in PyPOTS and its state_dict, so I'm not sure it works in every case. I am using PyPOTS version 0.8.1.
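The following is a sketch of a possibly more general alternative that has not been tested with PyPOTS. In recent PyTorch releases, `Optimizer.load_state_dict()` casts every state tensor to the device of the parameter it belongs to, and also to that parameter's dtype when the parameter is floating-point. Reloading the optimizer's own state dict after the model has been moved therefore relocates the state without handling individual keys. This assumes the optimizer still references the model's parameter objects, which is the case when the model is moved in place with `.cpu()` or `.to()`. The `step` counter follows PyTorch's own rules: it is left unchanged unless the optimizer is `capturable` or `fused`. The helper name `relocate_optimizer_state` is hypothetical.

```python
import torch
from torch import nn


def relocate_optimizer_state(optimizer: torch.optim.Optimizer) -> None:
    """Hypothetical helper: move the optimizer state next to its parameters.

    Optimizer.load_state_dict() casts each state tensor to the device of the
    parameter it belongs to, so round-tripping the state dict relocates it.
    """
    optimizer.load_state_dict(optimizer.state_dict())


# Usage: train one step on the GPU (if any), then move everything to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 4, device=device)).pow(2).mean().backward()
optimizer.step()

model.cpu()  # moves the parameters in place, not the optimizer state
relocate_optimizer_state(optimizer)

# Every state tensor except "step" now sits on its parameter's device
for param, state in optimizer.state.items():
    for name, value in state.items():
        if name != "step" and isinstance(value, torch.Tensor):
            assert value.device == param.device
```

Some older PyTorch releases deep-copy the state dict inside `load_state_dict()`. That briefly needs extra memory on the original device, which matters if the move is triggered by an OOM.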
3. Reproduction
- Instantiate a StemGNN model on cuda:0;
- Move it to CPU;
- Run fit().
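The same mismatch can be shown with plain PyTorch, without PyPOTS. This sketch assumes a bare `torch.optim.Adam` and uses a toy `nn.Linear` in place of StemGNN. `optimizer_state_devices` is a hypothetical helper that lists where the state tensors live. The mismatch only appears on a machine with a GPU. On a CPU-only machine, both printed sets contain only the CPU.

```python
import torch
from torch import nn


def optimizer_state_devices(optimizer: torch.optim.Optimizer) -> set:
    """Return the set of devices used by the tensors in the optimizer state."""
    return {
        value.device
        for state in optimizer.state.values()
        for value in state.values()
        if isinstance(value, torch.Tensor)
    }


# Train one step on cuda:0 when available (the mismatch needs a GPU to show up)
start = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1).to(start)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 4, device=start)).pow(2).mean().backward()
optimizer.step()  # creates the Adam state: step, exp_avg, exp_avg_sq

model.cpu()  # nn.Module.cpu() moves parameters and buffers only
print("parameters:", {p.device for p in model.parameters()})
# On a CUDA machine exp_avg and exp_avg_sq are still reported on cuda:0
print("optimizer state:", optimizer_state_devices(optimizer))
```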
4. Expected behavior
The model correctly moves all of its parts, including the optimizer state, to the CPU, and .fit() runs without raising exceptions.
I'd like to know what you think about this. If you like the approach, could I test it on the latest PyPOTS version and open a PR? If so, where would you like this function to go? Since .cuda() and .cpu() are methods of PyTorch's nn.Module class, I think it would be nice to override them with PyPOTS versions that call PyTorch's implementation and then _move_optimizer_state_to_device. Also, I have only tested this on StemGNN so far, so it should be tested with other optimizers and models before we include it.
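This is a hypothetical sketch of the overriding idea. It assumes a wrapper that owns both the network and its torch optimizer and exposes `to()`, `cpu()` and `cuda()` like `nn.Module`. `TrainableModel` is an illustrative name, not an existing PyPOTS class. Its state loop is a simplified, generic stand-in for a function like `_move_optimizer_state_to_device`: it skips `step` and moves every other tensor entry, including SGD's `momentum_buffer`.

```python
from typing import Optional, Union

import torch
from torch import nn


class TrainableModel:
    """Hypothetical wrapper keeping a module and its optimizer on one device."""

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
        self.model = model
        self.optimizer = optimizer
        self.device = next(model.parameters()).device

    def to(self, device: Union[torch.device, str, int]) -> "TrainableModel":
        device = torch.device(device)
        # nn.Module.to() moves parameters and buffers in place, so the
        # optimizer keeps referencing the same Parameter objects
        self.model.to(device)
        # The optimizer is not an nn.Module, so its state is moved explicitly
        for state in self.optimizer.state.values():
            for name, value in list(state.items()):
                # Leave the step counter where PyTorch placed it
                if name != "step" and isinstance(value, torch.Tensor):
                    state[name] = value.to(device)
        self.device = device
        return self

    def cpu(self) -> "TrainableModel":
        return self.to("cpu")

    def cuda(self, index: Optional[int] = None) -> "TrainableModel":
        return self.to("cuda" if index is None else f"cuda:{index}")


# Usage with SGD + momentum, whose state holds a "momentum_buffer" tensor
net = nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
net(torch.randn(8, 4)).sum().backward()
opt.step()
wrapper = TrainableModel(net, opt)
if torch.cuda.is_available():
    wrapper.cuda()
wrapper.cpu()
```

Keeping the model and the optimizer in one object means a single call moves both. This avoids the partial move described above.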
Looking forward to your kind response.
Best Regards,
Giacomo Guiduzzi