Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tensor name alias for optimizer #62

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion pytorch_memlab/mem_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ class MemReporter():
Parameters:
- model: an extra nn.Module can be passed to infer the name
of Tensors
- pre_collect: do a garbage collection before getting remaining
Tensors, this gives cleaner outputs.
Caution: This is an intrusive change to your original code.

"""
def __init__(self, model: Optional[torch.nn.Module] = None):
def __init__(self, model: Optional[torch.nn.Module] = None, pre_collect: bool = False):
self.tensor_name = {}
self.device_mapping = defaultdict(list)
self.device_tensor_stat = {}
# used to number the unknown tensors
self.name_idx = 0
self.pre_collect = pre_collect

tensor_names = defaultdict(list)
if model is not None:
Expand All @@ -51,6 +55,16 @@ def _get_tensor_name(self, tensor: torch.Tensor) -> str:
self.name_idx += 1
return name

def add_optimizer(self, optimizer: torch.optim.Optimizer):
    """Register an optimizer so its per-parameter state tensors get readable names.

    For every parameter tracked by ``optimizer``, each tensor in its state
    dict (e.g. Adam's ``exp_avg``/``exp_avg_sq``, SGD's ``momentum_buffer``)
    is aliased in ``self.tensor_name`` as
    ``<OptimizerName>.<param_name>.<state_name>`` so subsequent reports show
    a meaningful label instead of an auto-generated one.

    Parameters:
        - optimizer: the optimizer whose state tensors should be named.

    NOTE(review): parameters must already be known to the reporter (i.e. the
    model owning them was passed to ``__init__``); otherwise the lookup in
    ``self.tensor_name`` raises ``KeyError`` — confirm this is the intended
    contract.
    """
    optimizer_name = optimizer.__class__.__name__
    for param, states in optimizer.state.items():
        param_name = self.tensor_name[id(param)]
        for name, tensor in states.items():
            # Only alias real tensors: optimizer state may also hold
            # scalars (e.g. Adam's integer `step`), whose ids must not
            # pollute the tensor-name mapping.
            if isinstance(tensor, torch.Tensor):
                self.tensor_name[id(tensor)] = f'{optimizer_name}.{param_name}.{name}'


def collect_tensor(self):
"""Collect all tensor objects tracked by python

Expand All @@ -61,6 +75,9 @@ def collect_tensor(self):
I don't know why.
"""
#FIXME: make the grad tensor collected by gc
# Do a pre-garbage collect to eliminate python garbage objects
if self.pre_collect:
gc.collect()
objects = gc.get_objects()
tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
for t in tensors:
Expand Down
18 changes: 18 additions & 0 deletions test/test_mem_reporter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.optim
from pytorch_memlab import MemReporter

import pytest
Expand Down Expand Up @@ -57,6 +58,23 @@ def test_reporter_tie_weight():
reporter = MemReporter(container)
reporter.report()

def test_reporter_with_optimizer():
    """Smoke test: report() succeeds before and after registering optimizer state."""
    linear = torch.nn.Linear(1024, 1024)
    # torch.Tensor(512, 1024) returns *uninitialized* memory which may hold
    # NaN/inf and make the forward/backward pass flaky; use randn instead.
    inp = torch.randn(512, 1024)
    optimizer = torch.optim.Adam(linear.parameters())

    out = linear(inp * (inp + 3) * (inp + 2)).mean()
    reporter = MemReporter(linear)
    reporter.report()
    out.backward()
    # step() materializes Adam's per-parameter state (exp_avg, exp_avg_sq)
    optimizer.step()

    # Register the optimizer so its state tensors appear under named aliases.
    reporter.add_optimizer(optimizer)
    reporter.report()


@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.skipif(concentrate_mode, reason='concentrate')
def test_reporter_LSTM():
Expand Down