Feat ckpt #18 (Merged)
Commits (20):
272e903 fix vocab size for debugmodel and real data (samsja)
5b01f13 add torchdata (samsja)
b01eb90 add ckpt v0 (samsja)
98e9565 fix real dataset loading error (samsja)
986aa20 ckpt save in the right step folder (samsja)
7ccce1b change total tokens diloco (samsja)
09d59f0 add tests (samsja)
10b6215 refactor easier step (samsja)
d7119ae use fsspec rsync for backing up the ckpt to remote (samsja)
7622bbe add async saving to remote (samsja)
aab5a5b remove unused file (samsja)
6b63fe3 fix rebase (samsja)
00eaec7 fix ckpt (samsja)
fcbd102 add diloco ckpt (samsja)
3c052dd save into diloco specific folder (samsja)
47c5396 remove process group (samsja)
e080641 remove process group (samsja)
1c098b2 add diloco rank (samsja)
2c54c2c fix ckpt issue (samsja)
b9e6eef remove ckpt tests (samsja)
New file (225 added lines):
from dataclasses import dataclass
import multiprocessing
import os
import time
from typing import Any
from fsspec.generic import rsync as rsync_fsspec
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torchdata.stateful_dataloader import StatefulDataLoader
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    set_optimizer_state_dict,
    set_model_state_dict,
    get_model_state_dict,
    get_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from zeroband.utils.logging import get_logger
import warnings
import logging

from zeroband.utils.world_info import get_world_info

## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py


@dataclass
class TrainingProgress(Stateful):
    total_tokens: int
    outer_step: int
    step: int

    def state_dict(self) -> dict[str, Any]:
        return {"total_tokens": self.total_tokens, "outer_step": self.outer_step, "step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.total_tokens = state_dict["total_tokens"]
        self.outer_step = state_dict["outer_step"]
        self.step = state_dict["step"]


class ModelWrapper(Stateful):
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def state_dict(self) -> dict[str, Any]:
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))


class OptimizerWrapper(Stateful):
    def __init__(
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
    ) -> None:
        self.model = model
        self.optim = optim

    def state_dict(self) -> dict[str, Any]:
        return get_optimizer_state_dict(
            model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
        )

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        set_optimizer_state_dict(
            model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False)
        )


class CkptManager:
    """It is named CkptManager because I (sami) always mistyped "checkpoint".

    Checkpoints are saved in a folder with the following structure:
    ckpt_path/
        step_0/
            _0_0.pt
            _1_0.pt
            ...
        step_1/
            ...
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: LambdaLR,
        dataloader: StatefulDataLoader,
        training_progress: TrainingProgress,
        diloco_offloaded_param_list: list[nn.Parameter] | None,
        diloco_offloaded_optimizer: Optimizer | None,
    ):
        self.model = ModelWrapper(model)
        self.optimizer = OptimizerWrapper(model, optimizer)
        self.scheduler = scheduler
        self.dataloader = dataloader
        self.training_progress = training_progress

        # states can only be Stateful objects, hence we need to wrap Model and Optimizer
        self.states: dict[str, Stateful] = {
            "model": self.model,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler,
            # "dataloader": self.dataloader,  # ignoring dataloader for now as each rank has its own dataloader
            "training_progress": self.training_progress,
        }

        assert (diloco_offloaded_param_list is None) == (
            diloco_offloaded_optimizer is None
        ), "diloco_offloaded_param_list and diloco_offloaded_optimizer must both be None or both have values"

        self.diloco_offloaded_optimizer = diloco_offloaded_optimizer  # here we don't use a Wrapper because it failed,
        # which might make the ckpt less generic when loading across a different number of devices. FSDP ckpt seems to be a mess anyway
        self.diloco_offloaded_param_list = diloco_offloaded_param_list

        if diloco_offloaded_optimizer is not None:
            # even if diloco_offloaded targets the cpu param list, we still use the gpu model to load and save state.
            # the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
            self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer

        self._logger = get_logger()

        self.async_save_process: list[multiprocessing.Process] = []

    def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
        """
        Each rank saves its own shard of the model and optimizer.

        Saving is done in place.
        """

        time_start = time.perf_counter()
        world_info = get_world_info()

        ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}")
        if self.diloco_offloaded_optimizer:
            # here we save the model and the offloaded optimizer on each diloco rank even though they are the same.
            # this is done for two reasons:
            #   * if the nodes share neither a filesystem nor a remote path, they still save all of the data
            #   * it is easier to implement and avoids race conditions on the shared data
            ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")

        catch_warning = self._logger.getEffectiveLevel() <= logging.INFO

        with warnings.catch_warnings():
            # pytorch emits an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
            # we can ignore it if we are not logging in DEBUG mode
            if catch_warning:
                warnings.simplefilter("ignore")

            dcp.save(self.states, checkpoint_id=ckpt_path)

        ## the next part is a fix so that each rank saves its own dataloader state. It is not efficient because the state is stored twice
        with open(os.path.join(ckpt_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
            torch.save({"data_loader": self.dataloader.state_dict()}, f)

        self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds")

        if remote_ckpt_path is not None:
            self._async_save_remote(ckpt_path, remote_ckpt_path)

    def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str):
        """Asynchronously rsync a ckpt folder to a remote location. fsspec handles remote cloud storage without
        having to install provider-specific libraries (e.g. s3fs).
        """

        def rsync():
            time_start = time.perf_counter()
            self._logger.info(f"start pushing {ckpt_path} to {remote_ckpt_path} asynchronously")
            rsync_fsspec(ckpt_path, destination=remote_ckpt_path)
            self._logger.info(
                f"finish pushing {ckpt_path} to {remote_ckpt_path} in {time.perf_counter() - time_start} seconds"
            )

        process = multiprocessing.Process(target=rsync, daemon=True)
        process.start()

        self.async_save_process.append(process)

    def wait_async_save_process(self):
        """
        Wait for all async save processes to finish.
        """
        for process in self.async_save_process:
            process.join()

    def __del__(self):
        self.wait_async_save_process()

    def load(self, resume_ckpt_path: str) -> None:
        """
        Loading should be done after the fsdp wrap and optimizer init.
        Each rank loads its own shard of the model and optimizer.
        All ranks load the global states (scheduler, step, total_tokens, dataloader).

        `resume_ckpt_path` should point to a specific step and not to the base ckpt folder. Example: `ckpt_path/step_100`

        Loading is done in place.
        """
        time_start = time.perf_counter()

        world_info = get_world_info()
        if self.diloco_offloaded_param_list is not None:
            resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")

        # dcp.load mutates self.states in place and returns None, so there is nothing to re-assign
        dcp.load(self.states, checkpoint_id=resume_ckpt_path)

        # we don't load the offloaded param list from the state dict since it is the same as the model's, we just copy
        if self.diloco_offloaded_param_list is not None:
            for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
                param_offloaded.data.copy_(param_model.data)

        ## the next part is a fix so that each rank loads its own dataloader state. It is not efficient because the state is read from disk twice
        with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
            rank_state_dict = torch.load(f)

        self.dataloader.load_state_dict(rank_state_dict["data_loader"])

        self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
Review comment:
this prevents us from resuming with a different number of local ranks, right? For now, since we are just running on 8xH100 nodes, it is fine; just good to keep in mind.
Reply:
yes indeed, if we need to resume with a different diloco setup we can easily implement it
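Since the per-rank dataloader state is written as __{local_rank}_0.pt inside the step folder, resuming with a different number of local ranks would find missing (or extra) per-rank files. A hypothetical guard, not part of this PR, could make that failure explicit before dcp.load runs:

import os

def check_resumable(resume_ckpt_path: str, local_world_size: int) -> bool:
    # CkptManager.save writes one __{local_rank}_0.pt file per local rank,
    # so a world-size mismatch shows up as missing per-rank files here
    missing = [
        rank
        for rank in range(local_world_size)
        if not os.path.exists(os.path.join(resume_ckpt_path, f"__{rank}_0.pt"))
    ]
    return len(missing) == 0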