Commit

remove deprecated v1 data ckpt
samsja committed Nov 6, 2024
1 parent 57d0b64 commit b306f16
Showing 2 changed files with 21 additions and 55 deletions.
scripts/skip_data.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def skip_data(config: Config):
         if total_steps >= config.optim.total_steps:
             break
 
-    CkptManager.save_data_v2(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
+    CkptManager.save_data(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
 
     logger.info("skipped data up to step: %d", config.optim.total_steps)
 
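For context, a minimal sketch of what the renamed call writes: `save_data` (the renamed `save_data_v2`, see the checkpoint.py diff below) persists each rank's dataloader state to its own `_{local_rank}.pt` file under the given data directory. `DummyLoader` and the `/tmp` path here are hypothetical stand-ins, not part of the repository.

```python
# Hedged sketch: what CkptManager.save_data writes per rank.
# DummyLoader is a hypothetical stand-in for the stateful train_dataloader.
import os

import torch


class DummyLoader:
    def state_dict(self):
        return {"samples_seen": 1234}


def save_data(data_path: str, dataloader, local_rank: int) -> None:
    # Mirrors the save_data body shown in the checkpoint.py diff below.
    os.makedirs(data_path, exist_ok=True)
    with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
        torch.save({"data_loader": dataloader.state_dict()}, f)


save_data("/tmp/ckpt/data", DummyLoader(), local_rank=0)
# -> /tmp/ckpt/data/_0.pt containing {"data_loader": {"samples_seen": 1234}}
```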
src/zeroband/checkpoint.py (74 changes: 20 additions & 54 deletions)
@@ -6,7 +6,7 @@
 import shutil
 import threading
 import time
-from typing import Any, Literal
+from typing import Any
 import uuid
 import fsspec
 from fsspec.generic import rsync as rsync_fsspec
@@ -164,7 +164,6 @@ class CkptConfig(BaseConfig):
 
     live_recovery_rank_src: int | None = None
 
-    data_version: Literal["v1", "v2"] = "v2"
     data_path: str | None = None
 
     token_count: int | None = None
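With `data_version` removed, the data-related checkpoint settings reduce to plain optional fields. A minimal pydantic sketch of what survives, restricted to the fields visible in this diff; `BaseModel` stands in for the project's `BaseConfig`, and the `remote_data_path`/`remote_data_load` defaults are assumptions inferred from the validator below.

```python
# Sketch of the surviving data-related fields; defaults for the two
# remote_* fields are assumptions, not taken from the repository.
from pydantic import BaseModel


class CkptConfig(BaseModel):
    live_recovery_rank_src: int | None = None

    data_path: str | None = None  # local directory for dataloader state

    token_count: int | None = None

    remote_data_path: str | None = None  # assumed: remote sync target
    remote_data_load: bool = False       # assumed default


# A config that previously set data_version="v2" now simply omits it:
cfg = CkptConfig(data_path="/ckpt/data")
```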
@@ -180,12 +179,6 @@ def validate_path_and_interval(self):
 
     @model_validator(mode="after")
     def validate_remote_data_path(self):
-        if self.remote_data_path is not None and self.data_version == "v1":
-            raise ValueError("remote_data_path is only available with data_version v2")
-
-        if self.remote_data_load and self.data_version == "v1":
-            raise ValueError("remote_data_load is only available with data_version v2")
-
         if self.remote_data_load and self.data_path is not None:
             raise ValueError("remote_data_load and data_path are mutually exclusive")
 
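With the two v1 guards gone, `validate_remote_data_path` reduces to a single mutual-exclusion check. A runnable reproduction of that behavior, assuming pydantic v2 semantics for `model_validator` (which the decorator in the diff suggests):

```python
# Reproduction of the remaining validation rule under pydantic v2.
from pydantic import BaseModel, ValidationError, model_validator


class CkptConfig(BaseModel):
    data_path: str | None = None
    remote_data_load: bool = False

    @model_validator(mode="after")
    def validate_remote_data_path(self):
        if self.remote_data_load and self.data_path is not None:
            raise ValueError("remote_data_load and data_path are mutually exclusive")
        return self


CkptConfig(remote_data_load=True)  # fine: data state will come from the remote path
try:
    CkptConfig(remote_data_load=True, data_path="/ckpt/data")
except ValidationError as err:
    print(err)  # ... remote_data_load and data_path are mutually exclusive
```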
@@ -342,39 +335,31 @@ def _save(self, ckpt_path: str):
 
         dcp.save(self.states, checkpoint_id=ckpt_path)
 
-        ## we have two formats to to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
-
-        ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk
-        with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
-            state = {"data_loader": self.dataloader.state_dict()} if self.config.data_version == "v1" else {}
-            if self.diloco_offloaded_optimizer:
+        if self.diloco_offloaded_optimizer:
+            with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
+                state = {}
                 state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict()
-
-            torch.save(state, f)
+                torch.save(state, f)
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(ckpt_path, "data")
-            self.save_data_v2(data_path, self.dataloader, self.world_info.local_rank)
+        data_path = os.path.join(ckpt_path, "data")
+        self.save_data(data_path, self.dataloader, self.world_info.local_rank)
 
-            non_error_barrier()
+        non_error_barrier()
 
-            if self.config.remote_data_path is not None:
-                remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
-                )
-                latest_remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", "latest"
-                )
+        if self.config.remote_data_path is not None:
+            remote_data_path = os.path.join(
+                self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
+            )
+            latest_remote_data_path = os.path.join(self.config.remote_data_path, f"data_{self.data_rank}", "latest")
 
-                self._async_save_remote(data_path, remote_data_path, blocking=False)
-                self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
+            self._async_save_remote(data_path, remote_data_path, blocking=False)
+            self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
 
         gc.collect()
 
     @staticmethod
-    def save_data_v2(data_path: str, dataloader, local_rank: int):
+    def save_data(data_path: str, dataloader, local_rank: int):
         os.makedirs(data_path, exist_ok=True)
         with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
             state = {"data_loader": dataloader.state_dict()}
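The simplified `_save` always routes dataloader state through `save_data`, and only creates the per-rank `__{rank}_0.pt` file when an offloaded outer optimizer exists, which removes the inefficient double disk read the deleted comment apologized for. A condensed, runnable sketch of the resulting flow; DCP, the barrier, and the remote sync are omitted, and the `Fake*` classes plus paths are stand-ins.

```python
# Condensed sketch of the post-change save flow for a single rank.
# dcp.save, non_error_barrier, and _async_save_remote are intentionally
# left out; FakeOptimizer/FakeLoader are hypothetical stand-ins.
import os

import torch


class FakeOptimizer:
    def state_dict(self):
        return {"outer_lr": 0.7}


class FakeLoader:
    def state_dict(self):
        return {"step": 42}


def save_checkpoint(ckpt_path, dataloader, outer_optimizer, local_rank):
    os.makedirs(ckpt_path, exist_ok=True)

    # Per-rank file now holds only the offloaded outer-optimizer state.
    if outer_optimizer is not None:
        with open(os.path.join(ckpt_path, f"__{local_rank}_0.pt"), "wb") as f:
            torch.save({"optimizer": outer_optimizer.state_dict()}, f)

    # Dataloader state is written unconditionally to the data/ folder.
    data_path = os.path.join(ckpt_path, "data")
    os.makedirs(data_path, exist_ok=True)
    with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
        torch.save({"data_loader": dataloader.state_dict()}, f)


save_checkpoint("/tmp/ckpt_demo", FakeLoader(), FakeOptimizer(), local_rank=0)
# Layout: /tmp/ckpt_demo/__0_0.pt (optimizer) and /tmp/ckpt_demo/data/_0.pt
```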
@@ -422,33 +407,14 @@ def _del__(self):
 
     @torch.no_grad()
     def _load_data(self, resume_ckpt_path: str):
-        ## we have two formats to to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
         self._logger.debug(f"loading data from {resume_ckpt_path}")
         world_info = get_world_info()
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(resume_ckpt_path, "data")
+        data_path = os.path.join(resume_ckpt_path, "data")
 
-            if os.path.exists(os.path.join(data_path, f"_{world_info.local_rank}.pt")):
-                with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
-                    state = torch.load(f)
-                self.dataloader.load_state_dict(state["data_loader"])
-                return
-            else:
-                self._logger.debug(f"Data version is v2 but data folder {data_path} does not exist. trying v1 loading")
-
-        with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
-            rank_state_dict = torch.load(f)
-
-        try:
-            self.dataloader.load_state_dict(rank_state_dict["data_loader"])
-        except KeyError as e:
-            self._logger.warning(
-                "Data_loader state_dict is not found. You probably are loading a v2 ckpt with v1 dataloader. Aborting"
-            )
-            raise e
+        with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
+            state = torch.load(f)
+        self.dataloader.load_state_dict(state["data_loader"])
 
     @torch.no_grad()
     def load(
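Loading is now the mirror image: `_load_data` reads `data/_{local_rank}.pt` directly, with no v1 fallback and no `KeyError` recovery. A round-trip sketch under the same stand-in assumptions as the save example above:

```python
# Round-trip sketch for the simplified load path; FakeLoader gains a
# load_state_dict so the restore side can be demonstrated end to end.
import os

import torch


class FakeLoader:
    def __init__(self):
        self.state = {"step": 0}

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state):
        self.state = dict(state)


def load_data(resume_ckpt_path: str, dataloader, local_rank: int) -> None:
    # Mirrors the post-change _load_data: one format, one path.
    data_path = os.path.join(resume_ckpt_path, "data")
    with open(os.path.join(data_path, f"_{local_rank}.pt"), "rb") as f:
        state = torch.load(f)
    dataloader.load_state_dict(state["data_loader"])


# Save with one loader, restore into a fresh one.
saved = FakeLoader()
saved.state["step"] = 42
os.makedirs("/tmp/ckpt_demo/data", exist_ok=True)
with open("/tmp/ckpt_demo/data/_0.pt", "wb") as f:
    torch.save({"data_loader": saved.state_dict()}, f)

fresh = FakeLoader()
load_data("/tmp/ckpt_demo", fresh, local_rank=0)
assert fresh.state["step"] == 42
```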
