flag to enable/disable memory monitor
Sami jaghouar authored and samsja committed Oct 5, 2024
1 parent 753fb78 commit 503cdd1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/zeroband/train.py
@@ -61,6 +61,7 @@ class TrainConfig(BaseConfig):
     ac_ckpt: bool | int = False
     log_model_hash: bool = False
 
+    memory_monitor: bool = False
     memory_profiler: MemoryProfilerConfig | None = None
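For context on how the new field behaves, here is a minimal sketch of a pydantic-style config with the same defaults as the hunk above. TrainConfigSketch and its use of pydantic.BaseModel are illustrative stand-ins, not zeroband's real TrainConfig (a BaseConfig subclass with additional fields not shown in this hunk).

from pydantic import BaseModel

class TrainConfigSketch(BaseModel):
    # Stand-in mirroring only the fields visible in the hunk above.
    ac_ckpt: bool | int = False
    log_model_hash: bool = False
    memory_monitor: bool = False  # new flag from this commit, off by default

print(TrainConfigSketch().memory_monitor)                     # False
print(TrainConfigSketch(memory_monitor=True).memory_monitor)  # True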


@@ -205,7 +206,8 @@ def train(config: Config):
     logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor
     metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)
 
-    gpu_mem_monitor = GPUMemoryMonitor()
+    if config.train.memory_monitor:
+        gpu_mem_monitor = GPUMemoryMonitor()
     if config.train.memory_profiler is not None:
         memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)

@@ -271,9 +273,9 @@ def train(config: Config):
                 "Perplexity": torch.exp(loss_batch).item(),
                 "total_tokens": training_progress.total_tokens,
             }
-
-            peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
-            metrics.update(peak_gpu_stats)
+            if config.train.memory_monitor:
+                peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
+                metrics.update(peak_gpu_stats)
 
             log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}"

@@ -327,7 +329,8 @@ def train(config: Config):
         mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size
         logger.info(f"effective mfu: {mfu}")
 
-        logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")
+        if config.train.memory_monitor:
+            logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")
 
         if training_progress.step >= config.optim.total_steps:
             # we only allow to break outisde of the inner loop.
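Taken together, the three train() hunks follow one pattern: the monitor is constructed only when train.memory_monitor is set, and every later use is guarded by the same flag, so the gpu_mem_monitor name is never referenced while undefined. Below is a minimal, self-contained sketch of that pattern; _StubMonitor stands in for zeroband's GPUMemoryMonitor (only the two methods used in this diff are mimicked), and the surrounding names are illustrative.

from dataclasses import dataclass

class _StubMonitor:
    # Stand-in for GPUMemoryMonitor; the real one reads CUDA memory stats.
    def get_peak_stats(self) -> dict:
        return {"peak_memory_gib": 0.0}

    def format_peak_states(self) -> str:
        return "peak memory: 0.0 GiB"

@dataclass
class TrainFlags:
    memory_monitor: bool = False  # mirrors config.train.memory_monitor

def log_step(flags: TrainFlags) -> dict:
    # Construct the monitor only when the flag is on (hunk around line 206).
    if flags.memory_monitor:
        gpu_mem_monitor = _StubMonitor()

    metrics = {"Loss": 1.234}

    # Guard every use with the same flag, so the possibly-undefined name is
    # never touched when monitoring is disabled (hunks around lines 273 and 329).
    if flags.memory_monitor:
        metrics.update(gpu_mem_monitor.get_peak_stats())
        print(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")

    return metrics

print(log_step(TrainFlags(memory_monitor=False)))  # {'Loss': 1.234}
print(log_step(TrainFlags(memory_monitor=True)))   # also includes peak_memory_gib

An alternative design would initialize gpu_mem_monitor to None and check "gpu_mem_monitor is not None" at the use sites; the commit instead re-checks the flag, which keeps every guard reading directly from the config.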