Skip to content

Commit

Permalink
refactor: Add get_memory_usage_mb function for memory monitoring
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 18, 2024
1 parent 30d2ad8 commit 80ffc3b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import inspect
import itertools
import resource
import subprocess

import numpy as np
Expand Down Expand Up @@ -228,6 +229,12 @@ def print_tensor_devices(scope="local"):
print(f"{var_name}: {var_value.device}")


def get_memory_usage_mb():
# Returns the peak memory usage in MB
usage = resource.getrusage(resource.RUSAGE_SELF)
return usage.ru_maxrss / 1024 # Convert from KB to MB


# Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
Expand Down

0 comments on commit 80ffc3b

Please sign in to comment.