We are ByteDance Seed team.
You can get to know us better through the following channels:
ByteCheckpoint is a unified, efficient and production-grade checkpointing system for large foundation model development.
ByteCheckpoint is the open-source implementation of our research paper: ByteCheckpoint: A Unified Checkpointing System for Large Foundation Model Development.
ByteCheckpoint is easy to use and efficient with:
- Framework-Agnostic API: Provides a unified checkpointing entrypoint, i.e., `bytecheckpoint.save` and `bytecheckpoint.load`, to support various parallelism configurations across different frameworks.
- Load-time Checkpoint Resharding: Enables seamless checkpoint reloading with arbitrary new parallelism configurations, eliminating the need for manual resharding scripts.
- Optimized I/O Performance: Integrates advanced techniques such as asynchronous and parallel I/O, D2H tensor copying with pinned memory, load-balanced checkpointing, and decomposed tensor representation.
- Comprehensive Toolset: Provides utilities for checkpoint merging/conversion/modification and metadata/tensor file inspection, enabling flexible checkpoint transfer and management.
[2025/04] We officially released ByteCheckpoint!
[2024/12] ByteCheckpoint is accepted to NSDI 2025.
Install ByteCheckpoint from source.
git clone https://github.com/ByteDance-Seed/ByteCheckpoint.git
cd ByteCheckpoint
pip install -e .
Install ByteCheckpoint from PyPI.
pip install bytecheckpoint
This section introduces how to use ByteCheckpoint to save, load, and merge checkpoints.
In ByteCheckpoint, a checkpoint consists of three parts (folders):
- `model`: Contains the model checkpoint, including one `.metadata` checkpoint metadata file and multiple `.distcp` tensor data files.
- `optimizer`: Contains the optimizer checkpoint, including one `.metadata` checkpoint metadata file and multiple `.distcp` tensor data files.
- `extra_state`: Contains user-saved picklable objects, e.g., the dataloader state dictionary and RNG states.
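Schematically, a saved checkpoint directory therefore looks like the following (only the folder split and the `.metadata`/`.distcp` extensions come from the description above; the concrete file names inside each folder depend on the run configuration):

```
<ckpt_path>/
├── model/         # one .metadata file + multiple .distcp tensor data files
├── optimizer/     # one .metadata file + multiple .distcp tensor data files
└── extra_state/   # user-saved picklable objects (dataloader state, RNG states, ...)
```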
Get the model, optimizer, and extra states (RNG states, learning rate scheduler) from your training code:
checkpoint_state = {
"model": model,
"optimizer": optimizer,
"extra_state": {'torch_rng_state': torch.get_rng_state()}
}
Save them with the ByteCheckpoint save API:
import bytecheckpoint as bcp
bcp.save(ckpt_path, checkpoint_state, framework="fsdp")
Load them with the ByteCheckpoint load API.
The model and optimizer will be loaded in an in-place manner.
The extra state will be loaded in checkpoint_state["extra_state"].
bcp.load(ckpt_path, checkpoint_state, framework="fsdp")
torch.set_rng_state(checkpoint_state["extra_state"]["torch_rng_state"])
A simple single-machine FSDP training demo with ByteCheckpoint is provided in demo/fsdp_save_reshard.py.
Start training and save checkpoint at each step:
# Train on 8 GPUs
torchrun --master_addr=localhost --master_port=6000 --nproc_per_node=8 --nnodes=1 demo/fsdp_save_reshard.py --mode normal
Load the checkpoint and resume training:
# Load on 4 GPUs
torchrun --master_addr=localhost --master_port=6000 --nproc_per_node=4 --nnodes=1 demo/fsdp_save_reshard.py --mode resume
For multi-machine training, we recommend storing and loading checkpoints on a shared file system that supports POSIX semantics, such as NFS.
To merge a model checkpoint, you can use scripts/merge_bcp.py.
Merge the checkpoint saved by the demo training code into the safetensors format:
python3 scripts/merge_bcp.py --framework fsdp \
--ckpt_path tmp_checkpoint_dir_fsdp/global_step_0 \
--output_path merged_ckpt_fsdp \
--safetensors_format \
--model_only
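After merging, the result can be inspected with the standard safetensors library. The snippet below is a minimal sketch; the exact file name inside merged_ckpt_fsdp (assumed here to be model.safetensors) depends on what scripts/merge_bcp.py emits, so check the output directory first.

```python
# Minimal inspection sketch for a merged checkpoint in safetensors format.
# NOTE: "model.safetensors" is an assumed file name; list the --output_path
# directory produced by scripts/merge_bcp.py to find the actual file(s).
from safetensors import safe_open

with safe_open("merged_ckpt_fsdp/model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape), tensor.dtype)
```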
- Enable `fast_saving` and `fast_loading` to use asynchronous and parallel I/O techniques (see the sketch after this list).
- Enable `save_decomposed_model_optimizer` and `load_decomposed_model_optimizer` for FSDP (`use_orig_params=True` is required) to obtain the model/optimizer state dict without additional communication and GPU-CPU synchronization.
- Pass the `role` keyword (e.g., actor, critic) to support checkpointing in multi-role training scenarios, such as PPO training.
- Enable `strict` in the `load` API to check whether the FQNs in a given state_dict are strictly the same as those recorded in the `.metadata` file.
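A minimal sketch of how these options might be passed, assuming they are accepted as keyword arguments of bcp.save and bcp.load as the list above suggests (the exact signatures may differ; consult the API docstrings):

```python
import bytecheckpoint as bcp

# Sketch only: option names are taken from the list above; treat the exact
# signatures as assumptions and consult the API docstrings.
bcp.save(
    ckpt_path,
    checkpoint_state,
    framework="fsdp",
    fast_saving=True,                       # asynchronous and parallel I/O
    save_decomposed_model_optimizer=True,   # requires FSDP use_orig_params=True
    role="actor",                           # multi-role training, e.g., PPO
)

bcp.load(
    ckpt_path,
    checkpoint_state,
    framework="fsdp",
    fast_loading=True,
    load_decomposed_model_optimizer=True,
    role="actor",
    strict=True,                            # verify FQNs against the .metadata file
)
```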
- Enable `BYTECHECKPOINT_ENABLE_TREE_TOPO` to improve the stability of large-scale model/optimizer planning (a usage sketch follows after this list).
- Enable `BYTECHECKPOINT_ENABLE_PINNED_MEM_D2H` to use the pinned CPU memory pool to accelerate D2H tensor copying.
- Adjust `BYTECHECKPOINT_STORE_WORKER_COUNT` and `BYTECHECKPOINT_LOAD_WORKER_COUNT` to tune the I/O performance.
Please refer to config.py for more details.
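A minimal sketch of configuring these environment variables before ByteCheckpoint is initialized; the values shown are illustrative assumptions, and config.py remains the authoritative reference for names, defaults, and accepted values:

```python
import os

# Sketch only: the values below are assumptions; see config.py for defaults.
os.environ["BYTECHECKPOINT_ENABLE_TREE_TOPO"] = "1"       # stabilize large-scale planning
os.environ["BYTECHECKPOINT_ENABLE_PINNED_MEM_D2H"] = "1"  # pinned CPU memory for D2H copies
os.environ["BYTECHECKPOINT_STORE_WORKER_COUNT"] = "8"     # tune save-side I/O parallelism
os.environ["BYTECHECKPOINT_LOAD_WORKER_COUNT"] = "8"      # tune load-side I/O parallelism

import bytecheckpoint as bcp  # import after the environment is configured
```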
Community contributions are welcome. Please check out the Contribution Guidance.
We use ruff to enforce strict code formatting when reviewing PRs. To reformat your code locally, make sure you have installed the latest version of ruff.
pip install ruff
Then you can format code with:
bash format_code.sh
Run local tests with:
bash test.sh
This project is licensed under Apache License 2.0. See the LICENSE file for details.
If you find this project helpful, please give us a star and cite our paper:
@article{wan2024bytecheckpoint,
title={ByteCheckpoint: A Unified Checkpointing System for Large Foundation Model Development},
author={Wan, Borui and Han, Mingji and Sheng, Yiyao and Peng, Yanghua and Lin, Haibin and Zhang, Mofan and Lai, Zhichao and Yu, Menghan and Zhang, Junda and Song, Zuquan and Liu, Xin and Wu, Chuan},
journal={arXiv preprint arXiv:2407.20143},
year={2024}
}
ByteCheckpoint is inspired by the design of PyTorch Distributed Checkpoint (DCP).
About ByteDance Seed Team
Founded in 2023, ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society.
