Skip to content

Commit

Permalink
Fix pytorch tensor leak for reporter (#18)
Browse files Browse the repository at this point in the history
* Fix pytorch tensor leak for reporter

* Fix LGTM code-style

* Remove redundant line
  • Loading branch information
Stonesjtu authored Sep 15, 2020
1 parent 65a2439 commit ca60672
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
2 changes: 0 additions & 2 deletions pytorch_memlab/line_profiler/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""IPython & notebook extension interface"""
from tempfile import mkstemp

from IPython.core.magic import (
Magics,
magics_class,
Expand Down
1 change: 0 additions & 1 deletion pytorch_memlab/line_profiler/line_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch

from .line_records import LineRecords
from ..utils import readable_size

# Seaborn's `muted` color cycle
DEFAULT_COLUMNS = ['active_bytes.all.peak', 'reserved_bytes.all.peak']
Expand Down
16 changes: 9 additions & 7 deletions pytorch_memlab/mem_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MemReporter():
"""
def __init__(self, model=None):
self.tensor_name = defaultdict(list)
self.tensor_name = {}
self.device_mapping = defaultdict(list)
self.device_tensor_stat = {}
# to numbering the unknown tensors
Expand All @@ -31,18 +31,20 @@ def __init__(self, model=None):
assert isinstance(model, torch.nn.Module)
# for model with tying weight, multiple parameters may share
# the same underlying tensor
tensor_names = defaultdict(list)
for name, param in model.named_parameters():
self.tensor_name[param].append(name)
for param, name in self.tensor_name.items():
self.tensor_name[param] = '+'.join(name)
tensor_names[param].append(name)
for param, name in tensor_names.items():
self.tensor_name[id(param)] = '+'.join(name)

def _get_tensor_name(self, tensor):
if tensor in self.tensor_name:
name = self.tensor_name[tensor]
tensor_id = id(tensor)
if tensor_id in self.tensor_name:
name = self.tensor_name[tensor_id]
# use numbering if no name can be inferred
else:
name = type(tensor).__name__ + str(self.name_idx)
self.tensor_name[tensor] = name
self.tensor_name[tensor_id] = name
self.name_idx += 1
return name

Expand Down

0 comments on commit ca60672

Please sign in to comment.