
Commit acc4dc3

Author: amaurya
Update datastates using decoupled checkpointing APIs (fix pre-commit)
1 parent 4eb3772

File tree

10 files changed: +172 -5 lines

deepspeed/datastates/README.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# DataStates-LLM checkpointing engine.
+
+This feature is not enabled by default. To enable, set the following options in ds_config.json and download the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).
+
+```
+{
+    ... other deepspeed config options,
+    "datastates_ckpt": {
+        "host_cache_size": 16
+    }
+}
+```

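For context, a minimal sketch of what enabling this looks like from a training script, assuming the datastates-llm package is installed alongside this DeepSpeed build; the toy model, batch size, and optimizer settings are placeholders, and only the `datastates_ckpt` block comes from the README above (such scripts are normally launched via the `deepspeed` launcher on a GPU node):

```python
import torch
import deepspeed

# Toy model; any torch.nn.Module is handled the same way.
model = torch.nn.Linear(8, 8)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    # Enables the DataStates engine; host_cache_size is pinned host memory in GB.
    "datastates_ckpt": {"host_cache_size": 16},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Checkpoints written through the engine now take the asynchronous DataStates path.
engine.save_checkpoint("checkpoints", tag="global_step100")
```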
deepspeed/datastates/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
+
+# DeepSpeed Team

deepspeed/datastates/config.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
+
+# DeepSpeed Team
+
+from deepspeed.runtime.config_utils import DeepSpeedConfigObject
+import copy
+
+DATASTATES_CHECKPOINTING = "datastates_ckpt"
+DATASTATES_CHECKPOINTING_ENABLED = False
+
+
+class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):
+
+    def __init__(self, param_dict):
+        super(DeepSpeedDataStatesConfig, self).__init__()
+
+        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
+        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))

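A quick illustration (not part of the commit) of how the `enabled` flag above resolves for different `param_dict` inputs, using only the class added in this file:

```python
from deepspeed.datastates.config import DeepSpeedDataStatesConfig

# Any dict under "datastates_ckpt" enables the engine; a missing key or an
# explicit False leaves it disabled.
on = DeepSpeedDataStatesConfig({"datastates_ckpt": {"host_cache_size": 16}})
off = DeepSpeedDataStatesConfig({})
forced_off = DeepSpeedDataStatesConfig({"datastates_ckpt": False})

print(on.enabled, on.config)    # True {'host_cache_size': 16}
print(off.enabled, off.config)  # False None
print(forced_off.enabled)       # False
```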
deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 3 additions & 0 deletions
@@ -58,3 +58,6 @@ def get_commit_info(self):

    def cleanup(self):
        pass
+
+    def preserves_storage_sharing(self):
+        return True
deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
+
+# DeepSpeed Team
+
+from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
+    CheckpointEngine, CheckpointCommitInfo
+from datastates import CheckpointEngine as DataStatesEngine
+
+ENGINE_NAME = "DataStatesCheckpointEngine"
+
+
+class DataStatesCheckpointEngine(CheckpointEngine):
+
+    def __init__(self, deepspeed_config, rank):
+        super().__init__(deepspeed_config)
+        self.commit_info = None
+        self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)
+
+    def __del__(self):
+        self.cleanup()
+
+    def create(self, info: CheckpointCommitInfo):
+        self.commit_info = info
+        return None
+
+    def save(self, state_dict, path: str):
+        return self.ckpt_engine.save(state_dict, path)
+
+    def load(self, path: str, map_location=None):
+        return self.ckpt_engine.load(path, map_location)
+
+    def commit(self, info: CheckpointCommitInfo):
+        assert info == self.commit_info
+        self.ckpt_engine.wait()
+        return self.ckpt_engine.commit(info.tag)
+
+    def cleanup(self):
+        self.commit(self.commit_info)
+        self.ckpt_engine.wait(True)
+        del self.ckpt_engine
+
+    def is_decoupled(self):
+        return True
+
+    def preserves_storage_sharing(self):
+        return False

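The engine above follows the decoupled checkpointing contract: `create()` records the pending checkpoint, `save()` hands state dicts to the asynchronous backend, and `commit()` drains outstanding writes before the tag is treated as durable. The sketch below mimics only that ordering with stand-in classes so it runs without the `datastates` package; the real `CheckpointCommitInfo` has its own fields, and `tag` here is just an illustrative attribute:

```python
from dataclasses import dataclass


@dataclass
class FakeCommitInfo:  # stand-in for CheckpointCommitInfo
    tag: str


class FakeAsyncBackend:  # stand-in for the datastates CheckpointEngine
    def save(self, state_dict, path):
        print(f"queued asynchronous write of {list(state_dict)} -> {path}")

    def wait(self, final=False):
        print("drained pending writes")

    def commit(self, tag):
        print(f"checkpoint '{tag}' is now durable")
        return True


backend = FakeAsyncBackend()
info = FakeCommitInfo(tag="global_step100")

pending = info                                     # create(): remember what is in flight
backend.save({"weight": [0.0] * 4}, "ckpt/global_step100/model.pt")  # save(): returns immediately
assert info is pending                             # commit() finalizes the same pending checkpoint
backend.wait()                                     # block until the flush completes
backend.commit(info.tag)                           # only now is the tag safe to rely on
```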
deepspeed/runtime/checkpoint_engine/utils.py

Lines changed: 11 additions & 1 deletion
@@ -6,7 +6,7 @@
from deepspeed.runtime.model_checkpointing.constants import *
from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
from deepspeed.utils import logger
-
+from deepspeed import comm as dist
from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine
@@ -35,4 +35,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
    else:
        return NebulaCheckpointEngine(config_params=config_params.nebula_config)

+    if config_params.datastates_config.enabled:
+        try:
+            from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
+            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
+        except ImportError as err:
+            logger.error(
+                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
+            )
+            return TorchCheckpointEngine(config_params)
+
    return TorchCheckpointEngine(config_params)

deepspeed/runtime/config.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@
from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig
+from ..datastates.config import DeepSpeedDataStatesConfig

from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
@@ -859,6 +860,7 @@ def _initialize_params(self, param_dict):
        self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)
        self.checkpoint_config = get_checkpoint_config(param_dict)

        self.weight_quantization_config = WeightQuantConfig(

deepspeed/runtime/engine.py

Lines changed: 9 additions & 3 deletions
@@ -3594,7 +3594,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
                moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
                if self.random_ltd_enabled():
                    expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
-                saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
+                saveable_state_dict = expert_state_dict
+                if self.checkpoint_engine.preserves_storage_sharing():
+                    saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
                self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
            moe_layer_id += 1

@@ -3616,7 +3618,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
            }
            # TODO: why use BufferedWriter not the path
            file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-            saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
+            saveable_state_dict = optimizer_state
+            if self.checkpoint_engine.preserves_storage_sharing():
+                saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
            self.checkpoint_engine.save(saveable_state_dict, file_path)

            # Load flow uses below saved file for model parameters, RNG and more
@@ -3656,7 +3660,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
            }
            state.update(client_state)
            logger.info(f'Saving model checkpoint: {save_path}')
-            saveable_state_dict = clone_tensors_for_torch_save(state)
+            saveable_state_dict = state
+            if self.checkpoint_engine.preserves_storage_sharing():
+                saveable_state_dict = clone_tensors_for_torch_save(state)
            self.checkpoint_engine.save(saveable_state_dict, save_path)

    def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):

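The `preserves_storage_sharing()` checks introduced above matter because `torch.save` serializes whole backing storages: saving a small view of a large tensor without cloning it first writes the entire underlying buffer. The snippet below is a torch-only illustration of that effect; engines that copy tensor data themselves (the DataStates engine returns `False` for this hook) can skip the `clone_tensors_for_torch_save` pass:

```python
import io
import torch

big = torch.zeros(1_000_000)
view = big[:10]                                # shares storage with `big`

shared_buf, cloned_buf = io.BytesIO(), io.BytesIO()
torch.save({"w": view}, shared_buf)            # serializes the full 1M-element storage
torch.save({"w": view.clone()}, cloned_buf)    # serializes only the 10 elements

print(shared_buf.getbuffer().nbytes, "bytes vs", cloned_buf.getbuffer().nbytes, "bytes")
```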
deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion
@@ -621,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
        layer_list = self.forward_funcs[start:end]

        checkpoint_engine.makedirs(save_dir, exist_ok=True)
+        should_clone = checkpoint_engine.preserves_storage_sharing()
        for idx, layer in enumerate(layer_list):
            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
            if not hasattr(layer, 'state_dict'):
@@ -630,7 +631,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
            if exclude_frozen_params:
                for n in self._get_frozen_parameter_names(layer):
                    del orig_state_dict[n]
-            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+            final_state_dict = orig_state_dict
+            if should_clone:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
            checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)

    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
docs/_tutorials/datastates-async-checkpointing.md

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+---
+title: "DataStates-LLM Checkpointing Engine"
+tags: asynchronous checkpointing for minimizing I/O overheads.
+---
+This tutorial shows how to use [DataStates-LLM](https://github.com/DataStates/datastates-llm) for asynchronous checkpointing. DataStates-LLM introduces a lazy asynchronous checkpointing mechanism tailored for LLMs, aiming to minimize I/O overhead and enhance training efficiency. This tutorial provides a guide on integrating DataStates-LLM with the DeepSpeed framework.
+
+## Overview of DataStates-LLM
+
+DataStates-LLM is designed to address the challenges of frequent checkpointing in LLM training by introducing a lazy asynchronous multi-level approach. It leverages the immutability of model parameters and optimizer states during the forward and backward passes to perform non-blocking data transfers, thereby reducing interference with the training process. This method has demonstrated up to 48x faster checkpointing and 2.2x faster end-to-end training times compared to traditional approaches, as outlined in [DataStates-LLM: Lazy Asynchronous Checkpointing for Large Language Models](https://arxiv.org/abs/2406.10707).
+
+## Prerequisites
+
+Before integrating DataStates-LLM with DeepSpeed, ensure the following:
+
+- **DeepSpeed Installation**: DeepSpeed should be installed in your environment. If not, refer to the [DeepSpeed Getting Started Guide](https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/getting-started.md) for installation instructions.
+
+- **DataStates-LLM Repository**: Access the DataStates-LLM source code from its [GitHub repository](https://github.com/DataStates/datastates-llm) and follow the installation instructions provided therein.
+
+## Configuring DeepSpeed for DataStates-LLM
+
+To enable DataStates-LLM's asynchronous checkpointing within DeepSpeed, modify the `deepspeed_config.json` file to include the `datastates_ckpt` section. Below is an example configuration:
+
+```json
+{
+    // ... other DeepSpeed configuration options
+    "datastates_ckpt": {
+        "host_cache_size": 16
+    }
+}
+```
+
+### Configuration Parameters
+
+- **`host_cache_size`**: Specifies the amount of pinned host memory (in gigabytes) reserved for asynchronous checkpoint flushing. Adjust this value based on your system's memory capacity and the size of your model checkpoints.
+
+## Implementing DataStates-LLM in Your Training Script
+
+After enabling DataStates checkpointing in `deepspeed_config.json`, the checkpointing frequency can be configured with the command-line parameter `--save-interval`, which specifies the number of iterations between checkpoints.
+
+## Limitations and Ongoing Work
+
+1. DataStates-LLM currently supports only the CUDA runtime on NVIDIA GPUs.
+
+
+2. DataStates-LLM has only been tested with ZeRO stage-1 without offloading to any other tiers.
+
+
+3. While the checkpoint layout of DataStates matches Hugging Face's [safetensors](https://huggingface.co/docs/safetensors/) format, it is not yet fully compatible with the safetensors library because of the pickled objects that DeepSpeed requires during restart.
+
+4. DataStates-LLM does not yet support universal or elastic checkpointing.
+
+
+## Questions and Support
+
+Please use the [DataStates-LLM GitHub repository](https://github.com/DataStates/datastates-llm) for any questions, issues, or feature requests.

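Since `host_cache_size` is given in gigabytes of pinned host memory, a rough starting value can be obtained by measuring the size of the state you checkpoint and rounding up with headroom for optimizer state. The helper below is only an illustrative aid, not part of DataStates or DeepSpeed, and the Linear layer is a placeholder for your model:

```python
import math

import torch


def state_dict_gib(sd):
    """Approximate size of a state_dict's tensors in GiB."""
    return sum(t.numel() * t.element_size() for t in sd.values() if torch.is_tensor(t)) / 2**30


model = torch.nn.Linear(4096, 4096)            # placeholder model
model_gib = state_dict_gib(model.state_dict())
# Round up and leave headroom for optimizer state and other checkpointed buffers.
suggested = max(1, math.ceil(model_gib * 2))
print(f"model ~{model_gib:.3f} GiB, suggested host_cache_size >= {suggested}")
```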