
Commit aace707

Author: amaurya (committed)

Update DataStates as decoupled checkpoint engine

2 parents 7d9a2f2 + e2cf199 · commit aace707

13 files changed: +198 -4 lines changed


deepspeed/datastates/README.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
# DataStates-LLM checkpointing engine

This feature is not enabled by default. To enable it, install the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/) and set the following options in ds_config.json. A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).

```
{
    ... other deepspeed config options,
    "datastates_ckpt": {
        "host_cache_size": 16
    }
}
```
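For illustration, the same options can also be passed programmatically as a config dict. The sketch below assumes the standard `deepspeed.initialize` entry point; the toy model and the other config keys are placeholders and the snippet is not part of this commit:

```python
import torch
import deepspeed

# Placeholder model; the only DataStates-specific piece is the "datastates_ckpt" block.
model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "datastates_ckpt": {
        "host_cache_size": 16  # pinned host memory reserved for async flushing, in GB
    },
}

# Checkpoints saved via engine.save_checkpoint(...) are then routed through DataStates.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```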

deepspeed/datastates/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team
```

deepspeed/datastates/config.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.config_utils import DeepSpeedConfigObject
import copy

DATASTATES_CHECKPOINTING = "datastates_ckpt"
DATASTATES_CHECKPOINTING_ENABLED = False


class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedDataStatesConfig, self).__init__()

        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
```
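As a quick illustration of how this config object behaves (a minimal sketch assuming DeepSpeed with this new module is importable; the snippet itself is not part of the commit):

```python
from deepspeed.datastates.config import DeepSpeedDataStatesConfig

# A "datastates_ckpt" block enables the feature and is kept verbatim in .config.
cfg = DeepSpeedDataStatesConfig({"datastates_ckpt": {"host_cache_size": 16}})
print(cfg.enabled)  # True
print(cfg.config)   # {'host_cache_size': 16}

# Absent or explicitly disabled, the feature stays off and the default engine is used.
print(DeepSpeedDataStatesConfig({}).enabled)                          # False
print(DeepSpeedDataStatesConfig({"datastates_ckpt": False}).enabled)  # False
```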

deepspeed/runtime/checkpoint_engine/README.md

Lines changed: 13 additions & 0 deletions
@@ -35,3 +35,16 @@ class CheckpointEngine(object):
### Asynchronous Lazy Checkpointing using DataStates-LLM

DataStates-LLM is an asynchronous checkpointing approach optimized for LLM pre-training; it can be obtained at https://github.com/DataStates/datastates-llm. A detailed tutorial is available [here](../../../docs/_tutorials/datastates-async-checkpointing.md). To enable datastates-llm checkpointing, specify `host_cache_size` (in gigabytes), which reserves pinned host memory for asynchronous checkpoint flushing, with the following lines in the config.json supplied at launch:

```
{
    ... other deepspeed config options,
    "datastates_ckpt": {
        "host_cache_size": 16
    }
}
```
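To clarify what reserving pinned host memory buys here, the rough sketch below stages a GPU tensor into a pinned buffer with a non-blocking copy. It is illustrative only and is not the DataStates allocation code; the buffer size is a small stand-in rather than the configured number of gigabytes:

```python
import torch

if torch.cuda.is_available():
    # Pinned (page-locked) host memory permits asynchronous device-to-host copies,
    # which lets checkpoint data be staged off the GPU and flushed in the background.
    cache_bytes = 64 * (1 << 20)  # 64 MiB stand-in; "host_cache_size" is whole GiB
    pinned_cache = torch.empty(cache_bytes, dtype=torch.uint8, pin_memory=True)

    gpu_tensor = torch.randn(1024, 1024, device="cuda")
    num_bytes = gpu_tensor.numel() * gpu_tensor.element_size()
    staging = pinned_cache[:num_bytes].view(torch.float32)
    staging.copy_(gpu_tensor.flatten(), non_blocking=True)  # can overlap with compute
    torch.cuda.synchronize()  # ensure the copy finished before the staged data is used
```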

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -58,3 +58,6 @@ def get_commit_info(self):

     def cleanup(self):
         pass
+
+    def preserves_storage_sharing(self):
+        return True
```
deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# Apache-2.0 License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.

# DeepSpeed Team

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine, CheckpointCommitInfo
from datastates import CheckpointEngine as DataStatesEngine

ENGINE_NAME = "DataStatesCheckpointEngine"


class DataStatesCheckpointEngine(CheckpointEngine):

    def __init__(self, deepspeed_config, rank):
        super().__init__(deepspeed_config)
        self.commit_info = None
        self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)

    def __del__(self):
        self.cleanup()

    def create(self, info: CheckpointCommitInfo):
        self.commit_info = info
        return None

    def save(self, state_dict, path: str):
        return self.ckpt_engine.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return self.ckpt_engine.load(path, map_location)

    def commit(self, info: CheckpointCommitInfo):
        assert info == self.commit_info
        self.ckpt_engine.wait()
        return self.ckpt_engine.commit(info.tag)

    def cleanup(self):
        self.commit(self.commit_info)
        self.ckpt_engine.wait(True)
        del self.ckpt_engine

    def is_decoupled(self):
        return True

    def preserves_storage_sharing(self):
        return False
```
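To make the decoupled flow concrete, here is a minimal, self-contained toy engine that only illustrates the idea behind the interface above: save() returns quickly while the flush happens in the background, and commit() waits for it. The class is invented for illustration and is not how the DataStates library is implemented:

```python
import threading
import torch


class ToyAsyncCheckpointEngine:
    """Illustrative stand-in for a decoupled (asynchronous) checkpoint engine."""

    def __init__(self):
        self._pending = []

    def save(self, state_dict, path):
        # Snapshot tensors on the caller's thread, then flush them in the background.
        snapshot = {k: v.detach().clone() if torch.is_tensor(v) else v for k, v in state_dict.items()}
        t = threading.Thread(target=torch.save, args=(snapshot, path))
        t.start()
        self._pending.append(t)

    def wait(self):
        for t in self._pending:
            t.join()
        self._pending.clear()

    def commit(self, tag):
        # The checkpoint only counts as complete once all background writes are done.
        self.wait()
        return True


engine = ToyAsyncCheckpointEngine()
engine.save({"weight": torch.randn(4, 4)}, "/tmp/toy_ckpt.pt")
# ... training could continue here while the flush proceeds ...
engine.commit(tag="step_100")
```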

deepspeed/runtime/checkpoint_engine/utils.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 from deepspeed.runtime.model_checkpointing.constants import *
 from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
 from deepspeed.utils import logger
+from deepspeed import comm as dist

 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .fast_checkpoint_engine import FastCheckpointEngine
@@ -35,4 +36,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
         else:
             return NebulaCheckpointEngine(config_params=config_params.nebula_config)

+    if config_params.datastates_config.enabled:
+        try:
+            from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
+            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
+        except ImportError as err:
+            logger.error(
+                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
+            )
+            return TorchCheckpointEngine(config_params)
+
     return TorchCheckpointEngine(config_params)
```

deepspeed/runtime/config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -52,6 +52,7 @@
 from ..profiling.config import DeepSpeedFlopsProfilerConfig
 from ..autotuning.config import DeepSpeedAutotuningConfig
 from ..nebula.config import DeepSpeedNebulaConfig
+from ..datastates.config import DeepSpeedDataStatesConfig

 from ..compression.config import get_compression_config, get_quantize_enabled
 from ..compression.constants import *
@@ -859,6 +860,7 @@ def _initialize_params(self, param_dict):
         self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

         self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+        self.datastates_config = DeepSpeedDataStatesConfig(param_dict)
         self.checkpoint_config = get_checkpoint_config(param_dict)

         self.weight_quantization_config = WeightQuantConfig(
```

deepspeed/runtime/engine.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -2409,6 +2409,7 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                 # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                 master_params = amp.master_params(self.optimizer)
                 clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
+
         self.optimizer.step()

         if hasattr(self.optimizer, '_global_grad_norm'):
@@ -3594,7 +3595,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
                 moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
                 if self.random_ltd_enabled():
                     expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
-                saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
+                saveable_state_dict = expert_state_dict
+                if self.checkpoint_engine.preserves_storage_sharing():
+                    saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
                 self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
             moe_layer_id += 1

@@ -3616,7 +3619,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         # TODO: why use BufferedWriter not the path
         file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-        saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
+        saveable_state_dict = optimizer_state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
         self.checkpoint_engine.save(saveable_state_dict, file_path)

         # Load flow uses below saved file for model parameters, RNG and more
@@ -3656,7 +3661,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
             }
             state.update(client_state)
             logger.info(f'Saving model checkpoint: {save_path}')
-            saveable_state_dict = clone_tensors_for_torch_save(state)
+            saveable_state_dict = state
+            if self.checkpoint_engine.preserves_storage_sharing():
+                saveable_state_dict = clone_tensors_for_torch_save(state)
             self.checkpoint_engine.save(saveable_state_dict, save_path)

     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
```
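Some context on why these hunks gate the cloning step: with an engine that preserves storage sharing the way plain `torch.save` does, a tensor that is a view into a large shared buffer would drag the whole underlying storage into the checkpoint, so the state dict is cloned first; an engine such as DataStates copies tensors itself and reports `preserves_storage_sharing()` as False, so the extra clone and its host-memory cost are skipped. A rough, self-contained sketch of the pattern, where the helper is a simplified stand-in and not DeepSpeed's actual `clone_tensors_for_torch_save`:

```python
import torch


def clone_tensors_for_save(obj):
    # Simplified stand-in: detach and clone tensors recursively so the saved copies
    # no longer share storage with large flattened training buffers.
    if torch.is_tensor(obj):
        return obj.detach().clone()
    if isinstance(obj, dict):
        return {k: clone_tensors_for_save(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(clone_tensors_for_save(v) for v in obj)
    return obj


def save_with_engine(engine, state_dict, path):
    # Mirrors the pattern used in the hunks above: clone only when the engine
    # would otherwise serialize shared storages wholesale.
    saveable = clone_tensors_for_save(state_dict) if engine.preserves_storage_sharing() else state_dict
    engine.save(saveable, path)
```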

deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -621,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
         layer_list = self.forward_funcs[start:end]

         checkpoint_engine.makedirs(save_dir, exist_ok=True)
+        should_clone = checkpoint_engine.preserves_storage_sharing()
         for idx, layer in enumerate(layer_list):
             model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
@@ -630,7 +631,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
             if exclude_frozen_params:
                 for n in self._get_frozen_parameter_names(layer):
                     del orig_state_dict[n]
-            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+            final_state_dict = orig_state_dict
+            if should_clone:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
             checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)

     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
```
