diff --git a/pytorch_memlab/mem_reporter.py b/pytorch_memlab/mem_reporter.py index 536d220..851e68b 100644 --- a/pytorch_memlab/mem_reporter.py +++ b/pytorch_memlab/mem_reporter.py @@ -20,14 +20,18 @@ class MemReporter(): Parameters: - model: an extra nn.Module can be passed to infer the name of Tensors + - pre_collect: do a garbage collection before getting remaining + Tensors, this gives cleaner outputs. + Caution: This is an intrusive change to your original code. """ - def __init__(self, model: Optional[torch.nn.Module] = None): + def __init__(self, model: Optional[torch.nn.Module] = None, pre_collect: bool = False): self.tensor_name = {} self.device_mapping = defaultdict(list) self.device_tensor_stat = {} # to numbering the unknown tensors self.name_idx = 0 + self.pre_collect = pre_collect tensor_names = defaultdict(list) if model is not None: @@ -51,6 +55,16 @@ def _get_tensor_name(self, tensor: torch.Tensor) -> str: self.name_idx += 1 return name + def add_optimizer(self, optimizer: torch.optim.Optimizer): + optimizer_name = optimizer.__class__.__name__ + for param, states in optimizer.state.items(): + param_name = self.tensor_name[id(param)] + for name, tensor in states.items(): + self.tensor_name[id(tensor)] = f'{optimizer_name}.{param_name}.{name}' + # self.tensor_name[id()] + # print(states) + + def collect_tensor(self): """Collect all tensor objects tracked by python @@ -61,6 +75,9 @@ def collect_tensor(self): I don't know why. """ #FIXME: make the grad tensor collected by gc + # Do a pre-garbage collect to eliminate python garbage objects + if self.pre_collect: + gc.collect() objects = gc.get_objects() tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)] for t in tensors: diff --git a/test/test_mem_reporter.py b/test/test_mem_reporter.py index b7727b6..3ef9428 100644 --- a/test/test_mem_reporter.py +++ b/test/test_mem_reporter.py @@ -1,4 +1,5 @@ import torch +import torch.optim from pytorch_memlab import MemReporter import pytest @@ -57,6 +58,23 @@ def test_reporter_tie_weight(): reporter = MemReporter(container) reporter.report() +def test_reporter_with_optimizer(): + linear = torch.nn.Linear(1024, 1024) + inp = torch.Tensor(512, 1024) + optimizer = torch.optim.Adam(linear.parameters()) + # reporter = MemReporter(linear) + + out = linear(inp*(inp+3)*(inp+2)).mean() + reporter = MemReporter(linear) + reporter.report() + out.backward() + # reporter.report() + optimizer.step() + + reporter.add_optimizer(optimizer) + reporter.report() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(concentrate_mode, reason='concentrate') def test_reporter_LSTM():