From a217dfb409d942b3d2f532071430a76ef3c368a6 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Mon, 30 Sep 2024 20:13:48 -0700 Subject: [PATCH] clean gc after ckpt to avoid oom (#28) --- src/zeroband/checkpoint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index 9bc0f003..2455a380 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import gc import multiprocessing import os import time @@ -162,6 +163,8 @@ def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") + gc.collect() # because we are badass engineer + if remote_ckpt_path is not None: self._async_save_remote(ckpt_path, remote_ckpt_path)