Ckpt update #145

Merged · 2 commits · Nov 6, 2024
scripts/skip_data.py (2 changes: 1 addition & 1 deletion)

@@ -71,7 +71,7 @@ def skip_data(config: Config):
         if total_steps >= config.optim.total_steps:
             break
 
-    CkptManager.save_data_v2(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
+    CkptManager.save_data(os.path.join(config.ckpt.data_path, "data"), train_dataloader, world_info.local_rank)
 
     logger.info("skipped data up to step: %d", config.optim.total_steps)
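Review note: the only change here is the call-site rename from `save_data_v2` to `save_data`; the arguments are unchanged. For readers skimming this file only, here is a minimal sketch of what the renamed helper does, mirroring the `CkptManager.save_data` body that appears later in this diff (the standalone function form is illustrative):

```python
import os
import torch

def save_data(data_path: str, dataloader, local_rank: int) -> None:
    # Each rank writes its own dataloader state to data_path/_{local_rank}.pt.
    os.makedirs(data_path, exist_ok=True)
    with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
        torch.save({"data_loader": dataloader.state_dict()}, f)
```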
src/zeroband/checkpoint.py (82 changes: 28 additions & 54 deletions)

@@ -6,7 +6,7 @@
 import shutil
 import threading
 import time
-from typing import Any, Literal
+from typing import Any
 import uuid
 import fsspec
 from fsspec.generic import rsync as rsync_fsspec
@@ -164,7 +164,6 @@ class CkptConfig(BaseConfig):

     live_recovery_rank_src: int | None = None
 
-    data_version: Literal["v1", "v2"] = "v2"
     data_path: str | None = None
 
     token_count: int | None = None
@@ -180,12 +179,6 @@ def validate_path_and_interval(self):

     @model_validator(mode="after")
     def validate_remote_data_path(self):
-        if self.remote_data_path is not None and self.data_version == "v1":
-            raise ValueError("remote_data_path is only available with data_version v2")
-
-        if self.remote_data_load and self.data_version == "v1":
-            raise ValueError("remote_data_load is only available with data_version v2")
-
         if self.remote_data_load and self.data_path is not None:
             raise ValueError("remote_data_load and data_path are mutually exclusive")

@@ -342,39 +335,31 @@ def _save(self, ckpt_path: str):

         dcp.save(self.states, checkpoint_id=ckpt_path)
 
-        ## we have two formats to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
-
-        ## the next part is a fix so that each rank saves a different dataloader rank. It is not efficient because it reads the state two times from disk
-        with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
-            state = {"data_loader": self.dataloader.state_dict()} if self.config.data_version == "v1" else {}
-            if self.diloco_offloaded_optimizer:
+        if self.diloco_offloaded_optimizer:
+            with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
+                state = {}
                 state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict()
-
-            torch.save(state, f)
+                torch.save(state, f)
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(ckpt_path, "data")
-            self.save_data_v2(data_path, self.dataloader, self.world_info.local_rank)
+        data_path = os.path.join(ckpt_path, "data")
+        self.save_data(data_path, self.dataloader, self.world_info.local_rank)
 
-            non_error_barrier()
+        non_error_barrier()
 
-            if self.config.remote_data_path is not None:
-                remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
-                )
-                latest_remote_data_path = os.path.join(
-                    self.config.remote_data_path, f"data_{self.data_rank}", "latest"
-                )
+        if self.config.remote_data_path is not None:
+            remote_data_path = os.path.join(
+                self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}"
+            )
+            latest_remote_data_path = os.path.join(self.config.remote_data_path, f"data_{self.data_rank}", "latest")
 
-                self._async_save_remote(data_path, remote_data_path, blocking=False)
-                self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
+            self._async_save_remote(data_path, remote_data_path, blocking=False)
+            self._async_save_remote(data_path, latest_remote_data_path, blocking=False)
 
         gc.collect()
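Review note: the remote copy now always runs through `_async_save_remote` with `blocking=False`, uploading each rank's data shard to both a per-step path and a rolling `latest` alias. The helper itself is not changed by this PR; below is a minimal sketch of what such a non-blocking copy could look like, assuming `rsync_fsspec` (imported at the top of this file) handles the transfer:

```python
import threading
from fsspec.generic import rsync as rsync_fsspec

def _async_save_remote(local_path: str, remote_path: str, blocking: bool = True) -> None:
    # Copy a local checkpoint folder to a (possibly remote) fsspec destination.
    def _upload() -> None:
        rsync_fsspec(local_path, remote_path)

    if blocking:
        _upload()
    else:
        # Fire-and-forget so a slow object store does not stall the training loop.
        threading.Thread(target=_upload, daemon=True).start()
```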

     @staticmethod
-    def save_data_v2(data_path: str, dataloader, local_rank: int):
+    def save_data(data_path: str, dataloader, local_rank: int):
         os.makedirs(data_path, exist_ok=True)
         with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
             state = {"data_loader": dataloader.state_dict()}
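Review note: after this PR, `save_data` and `_load_data` form a single per-rank round trip through `data/_{local_rank}.pt`. A self-contained sketch of that round trip, using a toy stateful loader (`TinyStatefulLoader` is illustrative, not part of the codebase):

```python
import os
import torch

class TinyStatefulLoader:
    """Stand-in for a dataloader exposing state_dict/load_state_dict."""

    def __init__(self) -> None:
        self.position = 0

    def state_dict(self) -> dict:
        return {"position": self.position}

    def load_state_dict(self, state: dict) -> None:
        self.position = state["position"]

data_path, local_rank = "/tmp/ckpt/data", 0
loader = TinyStatefulLoader()
loader.position = 1234

# Save path: what CkptManager.save_data does for this rank.
os.makedirs(data_path, exist_ok=True)
with open(os.path.join(data_path, f"_{local_rank}.pt"), "wb") as f:
    torch.save({"data_loader": loader.state_dict()}, f)

# Load path: what _load_data does on resume.
restored = TinyStatefulLoader()
with open(os.path.join(data_path, f"_{local_rank}.pt"), "rb") as f:
    restored.load_state_dict(torch.load(f)["data_loader"])
assert restored.position == 1234
```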
@@ -422,33 +407,14 @@ def _del__(self):

     @torch.no_grad()
     def _load_data(self, resume_ckpt_path: str):
-        ## we have two formats to save the dataloader:
-        ## 1. v1: save the dataloader in the same file as the outer optimizer
-        ## 2. v2: save the dataloader in a data folder inside the ckpt path
         self._logger.debug(f"loading data from {resume_ckpt_path}")
         world_info = get_world_info()
 
-        if self.config.data_version == "v2":
-            data_path = os.path.join(resume_ckpt_path, "data")
-
-            if os.path.exists(os.path.join(data_path, f"_{world_info.local_rank}.pt")):
-                with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
-                    state = torch.load(f)
-                    self.dataloader.load_state_dict(state["data_loader"])
-                    return
-            else:
-                self._logger.debug(f"Data version is v2 but data folder {data_path} does not exist. trying v1 loading")
-
-        with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
-            rank_state_dict = torch.load(f)
+        data_path = os.path.join(resume_ckpt_path, "data")
 
-        try:
-            self.dataloader.load_state_dict(rank_state_dict["data_loader"])
-        except KeyError as e:
-            self._logger.warning(
-                "Data_loader state_dict is not found. You probably are loading a v2 ckpt with v1 dataloader. Aborting"
-            )
-            raise e
+        with open(os.path.join(data_path, f"_{world_info.local_rank}.pt"), "rb") as f:
+            state = torch.load(f)
+            self.dataloader.load_state_dict(state["data_loader"])
 
     @torch.no_grad()
     def load(
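Review note: with the v1 fallback deleted, `_load_data` now unconditionally reads `data/_{world_info.local_rank}.pt`, so resuming from a checkpoint written before this PR will fail with a plain `FileNotFoundError`. A sketch of an explicit guard that would make the failure mode obvious; this is illustrative only, not part of the PR:

```python
import os
import torch

def load_data_strict(resume_ckpt_path: str, dataloader, local_rank: int) -> None:
    # Single-format load: the v1 fallback is gone, so a missing shard is fatal.
    shard = os.path.join(resume_ckpt_path, "data", f"_{local_rank}.pt")
    if not os.path.exists(shard):
        raise FileNotFoundError(
            f"{shard} not found; checkpoints written in the old v1 format "
            "can no longer be resumed from."
        )
    with open(shard, "rb") as f:
        dataloader.load_state_dict(torch.load(f)["data_loader"])
```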
@@ -471,6 +437,14 @@ def load(

         world_info = get_world_info()
 
+        files = os.listdir(resume_ckpt_path)
+
+        if len(files) == 1 and files[0].startswith("diloco_"):
+            self._logger.warning(
+                f"Loading diloco ckpt from {files[0]}. This is deprecated and will be removed in the future"
+            )
+            resume_ckpt_path = os.path.join(resume_ckpt_path, files[0])
+
         dcp.load(self.states, checkpoint_id=resume_ckpt_path)
 
         if self.config.token_count is not None:
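Review note: the added block transparently redirects old-style checkpoints whose payload sits in a single `diloco_*` subfolder, emitting a deprecation warning. The detection logic in isolation, as a minimal sketch (the function name is illustrative):

```python
import os

def resolve_ckpt_path(resume_ckpt_path: str) -> str:
    # Legacy layout: the entire checkpoint lives inside one diloco_* subfolder.
    files = os.listdir(resume_ckpt_path)
    if len(files) == 1 and files[0].startswith("diloco_"):
        # Deprecated; descend so dcp.load sees the real checkpoint root.
        return os.path.join(resume_ckpt_path, files[0])
    return resume_ckpt_path
```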