diff --git a/scripts/skip_data.py b/scripts/skip_data.py
index 4301ffc5..4b32b55e 100644
--- a/scripts/skip_data.py
+++ b/scripts/skip_data.py
@@ -71,7 +71,7 @@ def skip_data(config: Config):
         if total_steps >= config.optim.total_steps:
             break
 
-    CkptManager.save_data_v2(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
+    CkptManager.save_data(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
 
     logger.info("skipped data up to step: %d", config.optim.total_steps)
 
diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index 64d27513..3a40afbe 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -6,7 +6,7 @@ import shutil
 import threading
 import time
-from typing import Any, Literal
+from typing import Any
 import uuid
 import fsspec
 from fsspec.generic import rsync as rsync_fsspec
@@ -164,7 +164,6 @@ class CkptConfig(BaseConfig):
 
     live_recovery_rank_src: int | None = None
 
-    data_version: Literal["v1", "v2"] = "v2"
     data_path: str | None = None
     token_count: int | None = None
 
@@ -180,12 +179,6 @@ def validate_path_and_interval(self):
 
     @model_validator(mode="after")
     def validate_remote_data_path(self):
-        if self.remote_data_path is not None and self.data_version == "v1":
-            raise ValueError("remote_data_path is only available with data_version v2")
-
-        if self.remote_data_load and self.data_version == "v1":
-            raise ValueError("remote_data_load is only available with data_version v2")
-
         if self.remote_data_load and self.data_path is not None:
             raise ValueError("remote_data_load and data_path are mutually exclusive")
 
@@ -342,39 +335,31 @@ def _save(self, ckpt_path: str):
 
         dcp.save(self.states, checkpoint_id=ckpt_path)
 
-        ## we have two formats to to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
-
-        ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk
-        with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
-            state = {"data_loader": self.dataloader.state_dict()} if self.config.data_version == "v1" else {}
-            if self.diloco_offloaded_optimizer:
+        if self.diloco_offloaded_optimizer:
+            with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
+                state = {}
                 state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict()
-            torch.save(state, f)
+                torch.save(state, f)
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(ckpt_path, "data")
-            self.save_data_v2(data_path, self.dataloader, self.world_info.local_rank)
+        data_path = os.path.join(ckpt_path, "data")
+        self.save_data(data_path, self.dataloader, self.world_info.local_rank)
 
-            non_error_barrier()
+        non_error_barrier()
 
-            if self.config.remote_data_path is not None:
-                remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
-                )
-                latest_remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", "latest"
-                )
+        if self.config.remote_data_path is not None:
+            remote_data_path = os.path.join(
+                self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
+            )
+            latest_remote_data_path = os.path.join(self.config.remote_data_path, f"data_{self.data_rank}", "latest")
 
-                self._async_save_remote(data_path, remote_data_path, blocking=False)
-                self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
+            self._async_save_remote(data_path, remote_data_path, blocking=False)
+            self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
 
         gc.collect()
 
     @staticmethod
-    def save_data_v2(data_path: str, dataloader, local_rank: int):
+    def save_data(data_path: str, dataloader, local_rank: int):
         os.makedirs(data_path, exist_ok=True)
         with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
             state = {"data_loader": dataloader.state_dict()}
@@ -422,33 +407,14 @@ def _del__(self):
 
     @torch.no_grad()
     def _load_data(self, resume_ckpt_path: str):
-        ## we have two formats to to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
         self._logger.debug(f"loading data from {resume_ckpt_path}")
         world_info = get_world_info()
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(resume_ckpt_path, "data")
-
-            if os.path.exists(os.path.join(data_path, f"_{world_info.local_rank}.pt")):
-                with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
-                    state = torch.load(f)
-                    self.dataloader.load_state_dict(state["data_loader"])
-                    return
-            else:
-                self._logger.debug(f"Data version is v2 but data folder {data_path} does not exist. trying v1 loading")
-
-        with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
-            rank_state_dict = torch.load(f)
+        data_path = os.path.join(resume_ckpt_path, "data")
 
-        try:
-            self.dataloader.load_state_dict(rank_state_dict["data_loader"])
-        except KeyError as e:
-            self._logger.warning(
-                "Data_loader state_dict is not found. You probably are loading a v2 ckpt with v1 dataloader. Aborting"
-            )
-            raise e
+        with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
+            state = torch.load(f)
+            self.dataloader.load_state_dict(state["data_loader"])
 
     @torch.no_grad()
     def load(
@@ -471,6 +437,14 @@
         world_info = get_world_info()
 
+        files = os.listdir(resume_ckpt_path)
+
+        if len(files) == 1 and files[0].startswith("diloco_"):
+            self._logger.warning(
+                f"Loading diloco ckpt from {files[0]}. This is deprecated and will be removed in the future"
+            )
+            resume_ckpt_path = os.path.join(resume_ckpt_path, files[0])
+
         dcp.load(self.states, checkpoint_id=resume_ckpt_path)
 
         if self.config.token_count is not None: