diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 3350ec4..4cae8d4 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -468,8 +468,7 @@ def train(config: Config): monitor.set_stage("outer_loop") # todo we could skip this is we don't have live recovery enabled - # disable because of potential memory leak - # ckpt_manager.cache_inner_optimizer() + ckpt_manager.cache_inner_optimizer() time_start_inner = time.perf_counter() diloco.step(model=model, flag=training_progress.outer_step, num_effective_peers=num_effective_peers)