
Commit

add memory profiler
samsja committed Oct 2, 2024
1 parent 58a9aca commit 4bc0200
Showing 2 changed files with 55 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/zeroband/train.py
@@ -25,6 +25,7 @@
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import get_model
from zeroband.utils.profiler import MemoryProfiler
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -47,13 +48,20 @@ class OptimConfig(BaseConfig):
    batch_size: int = 512


class MemoryProfilerConfig(BaseConfig):
    freq: int = 10
    snapshot_dir: str


class TrainConfig(BaseConfig):
    micro_bs: int
    torch_compile: bool = True
    sharding_strategy: str = "SHARD_GRAD_OP"

    log_model_hash: bool = False

    memory_profiler: MemoryProfilerConfig | None = None


class CkptConfig(BaseConfig):
    path: str
@@ -193,6 +201,8 @@ def train(config: Config):
    metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)

    gpu_mem_monitor = GPUMemoryMonitor()
    # Default to None so the later `if memory_profiler is not None` check
    # cannot hit an unbound name when profiling is disabled.
    memory_profiler = None
    if config.train.memory_profiler is not None:
        memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)

    train_dataloader_iterator = iter(train_dataloader)

@@ -280,6 +290,9 @@ def train(config: Config):

        logger.info(log)

        if memory_profiler is not None:
            memory_profiler.step()

        if config.diloco is not None:
            if config.train.log_model_hash:
                with FSDP.summon_full_params(model):
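For reference, a minimal sketch of how the new memory_profiler field could be wired up. TrainConfig and MemoryProfilerConfig come from the diff above, while the literal values and the standalone construction are illustrative assumptions, not part of this commit (the hunk elides other TrainConfig fields, which may also be required):

from zeroband.train import MemoryProfilerConfig, TrainConfig

# Hypothetical values: snapshot every 10 steps into ./mem_snapshots.
# micro_bs has no default in TrainConfig, so some value must be supplied.
train_cfg = TrainConfig(
    micro_bs=8,
    memory_profiler=MemoryProfilerConfig(freq=10, snapshot_dir="./mem_snapshots"),
)
print(train_cfg.memory_profiler)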
42 changes: 42 additions & 0 deletions src/zeroband/utils/profiler.py
@@ -0,0 +1,42 @@
import os
import pickle
import torch
from zeroband.utils.logging import get_logger

from zeroband.utils.world_info import get_world_info


_MAX_ENTRIES = 10000


class MemoryProfiler:
    """PyTorch memory profiler.

    The output is one pickle file per rank, which can be visualized at
    https://pytorch.org/memory_viz.
    """

    def __init__(self, freq: int, snapshot_dir: str):
        # Start recording allocator events as soon as the profiler is constructed.
        torch.cuda.memory._record_memory_history(max_entries=_MAX_ENTRIES)
        self.freq = freq

        self.world_info = get_world_info()
        self.logger = get_logger()
        self.step_num = 0

        os.makedirs(snapshot_dir, exist_ok=True)
        self.snapshot_dir = snapshot_dir

    def step(self):
        self.step_num += 1
        if self.step_num % self.freq != 0:
            return

        dir_name = f"iteration_{self.step_num}"
        curr_snapshot_dir = os.path.join(self.snapshot_dir, dir_name)
        os.makedirs(curr_snapshot_dir, exist_ok=True)

        with open(f"{curr_snapshot_dir}/rank{self.world_info.rank}_memory_snapshot.pickle", "wb") as output:
            pickle.dump(torch.cuda.memory._snapshot(), output)

        # Make sure every rank has finished dumping its snapshot before training resumes.
        torch.distributed.barrier()
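As a usage note, torch.cuda.memory._snapshot() returns a plain dict, so a dumped file can be inspected with nothing but pickle before being loaded into https://pytorch.org/memory_viz. A small sketch, assuming a snapshot was written under the hypothetical path below:

import pickle

# Hypothetical path following the naming scheme in MemoryProfiler.step().
path = "mem_snapshots/iteration_10/rank0_memory_snapshot.pickle"

with open(path, "rb") as f:
    snapshot = pickle.load(f)

# The snapshot is a dict; "segments" describes the allocator's CUDA memory
# segments and "device_traces" holds the recorded allocation/free events.
print(sorted(snapshot.keys()))
print(f"{len(snapshot['segments'])} segments recorded")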
