diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 38cf1aaa..a48f0d83 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -21,7 +21,7 @@
 from zeroband.diloco import Diloco, DilocoConfig
 from zeroband.comms import ElasticDeviceMesh
 
-from zeroband.utils import PerfCounter, get_module_signature, get_sharding_strategy
+from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model
@@ -192,6 +192,9 @@ def train(config: Config):
         logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor
         metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)
 
+    gpu_mem_monitor = GPUMemoryMonitor()
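+    # CUDA peak-memory counters are reset when the monitor is constructed; peaks accumulate until reset_peak_stats() is called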
+
     train_dataloader_iterator = iter(train_dataloader)
 
     num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1
@@ -254,6 +257,10 @@
                 "Perplexity": torch.exp(loss_batch).item(),
                 "total_tokens": training_progress.total_tokens,
             }
+
+            peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
+            metrics.update(peak_gpu_stats)
+
             log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}"
 
             tokens_per_second = perf_counter.get_tokens_per_second()
@@ -303,6 +310,8 @@
             mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size
             logger.info(f"effective mfu: {mfu}")
 
+        logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")
+
         if training_progress.step >= config.optim.total_steps:
             # we only allow breaking outside of the inner loop.
             # This avoids ending the training in the middle of an inner loop
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index 8c9c8b4e..04d08ad2 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -1,8 +1,11 @@
 import hashlib
 import time
+from typing import Any
 import torch
 from torch.distributed.fsdp import ShardingStrategy
 
+from zeroband.utils.logging import get_logger
+
 
 __all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
 
@@ -121,3 +124,56 @@ def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str:
         return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest()
     else:
         return "\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items())
+
+
+class GPUMemoryMonitor:
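+    """Track peak CUDA memory usage (active and reserved bytes) on a single device."""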
+    # inspired by https://github.com/pytorch/torchtitan/blob/eef8bb2b1b6f0875ab0581079e1511d51654910e/torchtitan/metrics.py#L32
+    def __init__(self, device: str = "cuda"):
+        self.device = torch.device(device)  # normalize the device string to a torch.device
+        self.device_capacity = torch.cuda.get_device_properties(self.device).total_memory
+        self.device_capacity_gib = self._to_gib(self.device_capacity)
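+        # clear peaks and cached blocks recorded before monitoring begins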
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.empty_cache()
+
+        self._logger = get_logger()
+
+    def _to_gib(self, memory_in_bytes: int) -> float:
+        # NOTE: 1 GiB (gibibyte) = 1024**3 bytes, whereas 1 GB (gigabyte) = 1000**3 bytes
+        _gib_in_bytes = 1024 * 1024 * 1024
+        memory_in_gib = memory_in_bytes / _gib_in_bytes
+        return memory_in_gib
+
+    def _to_pct(self, memory: int) -> float:
+        return 100 * memory / self.device_capacity
+
+    def get_peak_stats(self) -> dict[str, Any]:
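+        # peak counters reflect the high-water mark since the last reset_peak_memory_stats() call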
+        cuda_info = torch.cuda.memory_stats(self.device)
+
+        max_active = cuda_info["active_bytes.all.peak"]
+        max_active_gib = self._to_gib(max_active)
+        max_active_pct = self._to_pct(max_active)
+
+        max_reserved = cuda_info["reserved_bytes.all.peak"]
+        max_reserved_gib = self._to_gib(max_reserved)
+        max_reserved_pct = self._to_pct(max_reserved)
+
+        return {
+            "gpu_max_active_gib": max_active_gib,
+            "gpu_max_active_pct": max_active_pct,
+            "gpu_max_reserved_gib": max_reserved_gib,
+            "gpu_max_reserved_pct": max_reserved_pct,
+        }
+
+    def reset_peak_stats(self) -> None:
+        torch.cuda.reset_peak_memory_stats()
+
+    def format_peak_stats(self, peak_stats: dict[str, Any] | None = None) -> str:
+        if peak_stats is None:
+            peak_stats = self.get_peak_stats()
+        return f"Active {peak_stats['gpu_max_active_gib']:.2f} GiB ({peak_stats['gpu_max_active_pct']:.2f}%), Reserved {peak_stats['gpu_max_reserved_gib']:.2f} GiB ({peak_stats['gpu_max_reserved_pct']:.2f}%)"