
Commit e2cf199

Author: amaurya
Update DataStates checkpoint engine with decoupled functionality
1 parent 5cf9911 · commit e2cf199

9 files changed: +70 −43 lines


deepspeed/datastates/README.md

Lines changed: 10 additions & 1 deletion
@@ -1,3 +1,12 @@
 # DataStates-LLM checkpointing engine.
 
-This feature is not enabled by default. To enable, set the following options in ds_config.json and download [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).
+This feature is not enabled by default. To enable, set the following options in ds_config.json and download the [DataStates-LLM checkpointing library](https://github.com/DataStates/datastates-llm/). A detailed tutorial is available [here](../../docs/_tutorials/datastates-async-checkpointing.md).
+
+```
+{
+    ... other deepspeed config options,
+    "datastates_ckpt": {
+        "host_cache_size": 16
+    }
+}
+```
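For context, a minimal sketch of how this config section would typically be consumed in a training script, assuming the DataStates library is installed and the script is launched with the deepspeed launcher; the model, optimizer settings, tag, and directory below are placeholders, not part of the commit:

```python
import torch
import deepspeed

# Hypothetical toy model; any torch.nn.Module works here.
model = torch.nn.Linear(1024, 1024)

# Same options as the README snippet above: the "datastates_ckpt" section turns the
# engine on and "host_cache_size" reserves pinned host memory (in GB) for async flushing.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "datastates_ckpt": {"host_cache_size": 16},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Checkpoints are written through the selected checkpoint engine; with
# DataStates enabled the flush to storage proceeds asynchronously.
engine.save_checkpoint("checkpoints", tag="step_100")
```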

deepspeed/datastates/config.py

Lines changed: 6 additions & 6 deletions
@@ -6,16 +6,16 @@
 # DeepSpeed Team
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigObject
+import copy
+
+DATASTATES_CHECKPOINTING = "datastates_ckpt"
+DATASTATES_CHECKPOINTING_ENABLED = False
 
 
 class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):
 
     def __init__(self, param_dict):
         super(DeepSpeedDataStatesConfig, self).__init__()
 
-        self.enabled = None
-        self.config = {}
-
-        if "datastates_ckpt" in param_dict.keys():
-            self.enabled = True
-            self.config = param_dict["datastates_ckpt"]
+        self.enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
+        self.config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
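A small standalone sketch of the new parsing semantics, mirroring the logic above without importing DeepSpeed (the function name is a local stand-in): the feature counts as enabled whenever the `datastates_ckpt` key is present and not explicitly `false`, and its value is deep-copied so later changes to the original dict do not leak into the config.

```python
import copy

DATASTATES_CHECKPOINTING = "datastates_ckpt"
DATASTATES_CHECKPOINTING_ENABLED = False


def parse(param_dict):
    # Mirrors DeepSpeedDataStatesConfig.__init__ above: enabled for any value
    # other than an absent key or an explicit False; section is deep-copied.
    enabled = param_dict.get(DATASTATES_CHECKPOINTING, DATASTATES_CHECKPOINTING_ENABLED) is not False
    config = copy.deepcopy(param_dict.get(DATASTATES_CHECKPOINTING, None))
    return enabled, config


print(parse({}))                                            # (False, None)
print(parse({"datastates_ckpt": False}))                    # (False, False)
print(parse({"datastates_ckpt": {"host_cache_size": 16}}))  # (True, {'host_cache_size': 16})
```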

deepspeed/runtime/checkpoint_engine/README.md

Lines changed: 2 additions & 3 deletions
@@ -39,13 +39,12 @@ class CheckpointEngine(object):
 
 ### Asynchronous Lazy Checkpointing using DataStates-LLM
 
-DataStates-LLM is an asynchronous checkpointing approach optimized for LLM pre-training and can be obtained at https://github.com/DataStates/datastates-llm. A detailed tutorial is available [here](../../../docs/_tutorials/datastates-async-checkpointing.md). To enable datastates-llm checkpointing, specify the `host_cache_size` (in gigabytes) which reserves pinned host memory for asynchronous checkpoint flushing, and `parser_threads` to parse multiple checkpoint file requests in parallel using the following lines in config.json supplied during the launch:
+DataStates-LLM is an asynchronous checkpointing approach optimized for LLM pre-training and can be obtained at https://github.com/DataStates/datastates-llm. A detailed tutorial is available [here](../../../docs/_tutorials/datastates-async-checkpointing.md). To enable datastates-llm checkpointing, specify the `host_cache_size` (in gigabytes) which reserves pinned host memory for asynchronous checkpoint flushing using the following lines in config.json supplied during the launch:
 ```
 {
     ... other deepspeed config options,
     "datastates_ckpt": {
-        "host_cache_size": 16,
-        "parser_threads": 8
+        "host_cache_size": 16
     }
 }
 ```

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 2 additions & 3 deletions
@@ -59,6 +59,5 @@ def get_commit_info(self):
     def cleanup(self):
         pass
 
-    def wait(self):
-        # To wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous snapshot to finish
-        pass
+    def preserves_storage_sharing(self):
+        return True
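A brief sketch of how an engine would use the new hook, assuming only the two base-class methods shown above; the subclass name is hypothetical and not part of the commit. Engines that snapshot tensors themselves return False so callers skip the pre-save clone (see the engine.py and pipe/module.py hunks further down).

```python
# Illustrative stand-ins only, not the real DeepSpeed classes.
class CheckpointEngine:

    def preserves_storage_sharing(self):
        # Default: saved tensors may share storage, so callers are expected to
        # clone them (clone_tensors_for_torch_save) before calling save().
        return True


class MyAsyncEngine(CheckpointEngine):

    def preserves_storage_sharing(self):
        # Hypothetical engine that makes its own copies: no caller-side clone needed.
        return False
```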

deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 25 additions & 10 deletions
@@ -5,20 +5,25 @@
 
 # DeepSpeed Team
 
-from deepspeed.utils import log_dist
 from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
-    CheckpointEngine
-from datastates.llm import Checkpointing
+    CheckpointEngine, CheckpointCommitInfo
+from datastates import CheckpointEngine as DataStatesEngine
+
+ENGINE_NAME = "DataStatesCheckpointEngine"
 
 
 class DataStatesCheckpointEngine(CheckpointEngine):
 
     def __init__(self, deepspeed_config, rank):
         super().__init__(deepspeed_config)
-        self.ckpt_engine = Checkpointing(deepspeed_config, rank)
+        self.commit_info = None
+        self.ckpt_engine = DataStatesEngine(deepspeed_config, rank)
+
+    def __del__(self):
+        self.cleanup()
 
-    def create(self, tag):
-        log_dist(f"[DataStates] Checkpoint {tag} is about to be saved!", ranks=[0])
+    def create(self, info: CheckpointCommitInfo):
+        self.commit_info = info
         return None
 
     def save(self, state_dict, path: str):
@@ -27,8 +32,18 @@ def save(self, state_dict, path: str):
     def load(self, path: str, map_location=None):
         return self.ckpt_engine.load(path, map_location)
 
-    def commit(self, tag):
-        return self.ckpt_engine.commit(tag)
+    def commit(self, info: CheckpointCommitInfo):
+        assert info == self.commit_info
+        self.ckpt_engine.wait()
+        return self.ckpt_engine.commit(info.tag)
+
+    def cleanup(self):
+        self.commit(self.commit_info)
+        self.ckpt_engine.wait(True)
+        del self.ckpt_engine
+
+    def is_decoupled(self):
+        return True
 
-    def wait(self):
-        return self.ckpt_engine.wait()
+    def preserves_storage_sharing(self):
+        return False
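The engine is now driven through a create/save/commit cycle keyed by a CheckpointCommitInfo object rather than a bare tag. A rough sketch of that flow, using stubs in place of the real DataStates backend; the stub classes and driver code below are illustrative only, not part of DeepSpeed or DataStates.

```python
# Illustrative stubs mirroring the interface above: create() records the pending
# commit, save() would enqueue asynchronous writes, and commit() finalizes the tag.
class StubCommitInfo:

    def __init__(self, tag):
        self.tag = tag


class StubDecoupledEngine:

    def __init__(self):
        self.commit_info = None

    def create(self, info):
        self.commit_info = info          # remember what is being checkpointed

    def save(self, state_dict, path):
        print(f"async save of {len(state_dict)} entries -> {path}")

    def commit(self, info):
        assert info == self.commit_info  # same guard as commit() above
        print(f"commit {info.tag}")      # the real engine waits for pending writes first
        return True


engine = StubDecoupledEngine()
info = StubCommitInfo("step_100")
engine.create(info)
engine.save({"weight": 0}, "ckpt/step_100/model.pt")
engine.commit(info)
```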

deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
     CheckpointEngine, CheckpointCommitInfo
 from deepspeed.utils import logger, log_dist
 from deepspeed.nebula.constants import *
-from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 
 
 def _get_tag_from_path(path):

deepspeed/runtime/checkpoint_engine/utils.py

Lines changed: 11 additions & 0 deletions
@@ -6,6 +6,7 @@
 from deepspeed.runtime.model_checkpointing.constants import *
 from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
 from deepspeed.utils import logger
+from deepspeed import comm as dist
 
 from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
 from .fast_checkpoint_engine import FastCheckpointEngine
@@ -35,4 +36,14 @@ def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers,
     else:
         return NebulaCheckpointEngine(config_params=config_params.nebula_config)
 
+    if config_params.datastates_config.enabled:
+        try:
+            from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
+            return DataStatesCheckpointEngine(deepspeed_config=config_params, rank=dist.get_rank())
+        except ImportError as err:
+            logger.error(
+                f"No datastates engine found! Install from https://github.com/DataStates/datastates-llm. Will fall back to torch.save. Details: {err}"
+            )
+            return TorchCheckpointEngine(config_params)
+
     return TorchCheckpointEngine(config_params)

deepspeed/runtime/engine.py

Lines changed: 9 additions & 18 deletions
@@ -1140,16 +1140,6 @@ def _configure_checkpointing(self):
                                                            has_moe_layers=self.has_moe_layers,
                                                            optimize_dp_state=optimize_dp_state)
 
-        if self._config is not None and self._config.datastates_config.enabled:
-            try:
-                from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
-                self.checkpoint_engine = DataStatesCheckpointEngine(deepspeed_config=self._config,
-                                                                    rank=dist.get_rank())
-            except ImportError as err:
-                raise Exception(
-                    f"The datastates-llm checkpoint engine was not found! Will fall back to torch.save. Details: {err}"
-                )
-
         dp_rank = groups._get_sequence_data_parallel_rank()
         rank = self.local_rank if self.use_node_local_storage() else dp_rank
 
@@ -2420,11 +2410,6 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
             master_params = amp.master_params(self.optimizer)
             clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
 
-        try:
-            self.checkpoint_engine.wait()
-        except Exception as exc:
-            logger.error(f"Error during optimizer wait step: {exc}")
-
         self.optimizer.step()
 
         if hasattr(self.optimizer, '_global_grad_norm'):
@@ -3610,7 +3595,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
                 moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
                 if self.random_ltd_enabled():
                     expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
-                saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
+                saveable_state_dict = expert_state_dict
+                if self.checkpoint_engine.preserves_storage_sharing():
+                    saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
                 self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
             moe_layer_id += 1
 
@@ -3632,7 +3619,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         # TODO: why use BufferedWriter not the path
         file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
-        saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
+        saveable_state_dict = optimizer_state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
         self.checkpoint_engine.save(saveable_state_dict, file_path)
 
         # Load flow uses below saved file for model parameters, RNG and more
@@ -3672,7 +3661,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
         }
         state.update(client_state)
         logger.info(f'Saving model checkpoint: {save_path}')
-        saveable_state_dict = clone_tensors_for_torch_save(state)
+        saveable_state_dict = state
+        if self.checkpoint_engine.preserves_storage_sharing():
+            saveable_state_dict = clone_tensors_for_torch_save(state)
         self.checkpoint_engine.save(saveable_state_dict, save_path)
 
     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
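The three save-path hunks above apply the same guard. A condensed standalone sketch of that pattern follows; the helper below is a simplified stand-in for deepspeed.checkpoint.utils.clone_tensors_for_torch_save and the function names are placeholders:

```python
import torch


def clone_tensors_for_torch_save(obj):
    # Simplified stand-in: detach and clone tensors so torch.save does not
    # serialize whole shared storages; non-tensor values pass through.
    if torch.is_tensor(obj):
        return obj.detach().clone()
    if isinstance(obj, dict):
        return {k: clone_tensors_for_torch_save(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(clone_tensors_for_torch_save(v) for v in obj)
    return obj


def save_via_engine(checkpoint_engine, state_dict, path):
    # Same guard as the engine.py hunks above: clone only when the engine
    # reports that it preserves storage sharing (the torch.save default).
    saveable = state_dict
    if checkpoint_engine.preserves_storage_sharing():
        saveable = clone_tensors_for_torch_save(state_dict)
    checkpoint_engine.save(saveable, path)
```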

deepspeed/runtime/pipe/module.py

Lines changed: 5 additions & 1 deletion
@@ -20,6 +20,7 @@
 from .topology import PipeDataParallelTopology, PipelineParallelGrid
 from deepspeed.runtime.state_dict_factory import SDLoaderFactory
 from deepspeed.accelerator import get_accelerator
+from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 
 
 class PipelineError(Exception):
@@ -620,6 +621,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
         layer_list = self.forward_funcs[start:end]
 
         checkpoint_engine.makedirs(save_dir, exist_ok=True)
+        should_clone = checkpoint_engine.preserves_storage_sharing()
         for idx, layer in enumerate(layer_list):
             model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
@@ -629,7 +631,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
             if exclude_frozen_params:
                 for n in self._get_frozen_parameter_names(layer):
                     del orig_state_dict[n]
-            final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+            final_state_dict = orig_state_dict
+            if should_clone:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
             checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)
 
     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
