@@ -1,11 +1,11 @@
 import copy
 import logging
 import os
+from contextlib import nullcontext
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -26,14 +26,14 @@
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 from .distributed_checkpoint_utils import (
+    MODEL_WEIGHT_PREFIX,
+    RestoreDefaultStateDictBehavior,
     create_model_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
-    get_dist_files_name,
-    get_dist_meta_file_name,
-    MODEL_WEIGHT_PREFIX,
-    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -108,7 +108,7 @@ def _model_sharder(
     keep_vars: bool = False,
     size_per_shard: int = 1024,
     pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-    gather_dtensor: bool = True, 
+    gather_dtensor: bool = True,
 ) -> Iterator[Tuple[OrderedDict, int]]:
     # An internal method that breaks state_dict of model into shards within limited size.
 
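For readers unfamiliar with the sharding helper: `_model_sharder` is a generator that packs tensors into shards of at most `size_per_shard` megabytes and yields each shard together with its byte size. A minimal sketch of that pattern, with illustrative names only (the real implementation delegates to ColossalAI's `StateDictSharder` and also walks buffers and extra states):

from collections import OrderedDict
from typing import Dict, Iterator, Tuple

import torch


def shard_state_dict(
    state_dict: Dict[str, torch.Tensor], size_per_shard: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
    # Pack tensors into shards of at most `size_per_shard` MB each.
    max_bytes = size_per_shard * 1024 * 1024
    shard, shard_bytes = OrderedDict(), 0
    for name, tensor in state_dict.items():
        tensor_bytes = tensor.numel() * tensor.element_size()
        if shard and shard_bytes + tensor_bytes > max_bytes:
            yield shard, shard_bytes  # shard is full; hand it to the writer
            shard, shard_bytes = OrderedDict(), 0
        shard[name] = tensor
        shard_bytes += tensor_bytes
    if shard:
        yield shard, shard_bytes  # final, possibly smaller shard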
@@ -118,7 +118,7 @@ def _model_sharder(
     for name, param in model.named_parameters():
         if param is None:
             continue
-
+
         # Gather tensor pieces when using tensor parallel.
         param_ = gather_distributed_param(param, keep_vars=False)
         if is_padded_tensor(param_):
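`gather_distributed_param` reassembles a tensor-parallel parameter before it is checkpointed, so the shards written to disk hold full tensors rather than per-rank slices. A rough sketch of that gather step, assuming the caller knows the parameter's process group and shard dimension (both arguments here are hypothetical):

import torch
import torch.distributed as dist


def gather_tp_param(param: torch.Tensor, group: dist.ProcessGroup, dim: int = 0) -> torch.Tensor:
    # All-gather the equal-sized slices held by each tensor-parallel rank
    # and concatenate them along the sharded dimension.
    world_size = dist.get_world_size(group)
    pieces = [torch.empty_like(param) for _ in range(world_size)]
    dist.all_gather(pieces, param.contiguous(), group=group)
    return torch.cat(pieces, dim=dim)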
@@ -245,12 +245,12 @@ def save_sharded_model(
         model._force_wait_all_gather()
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
-
+
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
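`create_model_metadata` records per-parameter placement information so a checkpoint saved without gathering dtensors can be reassembled at load time. The exact schema lives in `distributed_checkpoint_utils`; the sketch below only illustrates the kind of record involved, and its fields are assumptions:

import torch.nn as nn


def build_param_metadata(model: nn.Module, tp_size: int, tp_rank: int) -> dict:
    # Record, for every parameter, the local shard shape plus enough
    # placement information to reassemble the global tensor later.
    # Illustrative layout only; the real format is defined by
    # distributed_checkpoint_utils.create_model_metadata.
    metadata = {}
    for name, param in model.named_parameters():
        metadata[name] = {
            "shape": list(param.shape),  # local shard shape
            "tp_size": tp_size,
            "tp_rank": tp_rank,
        }
    return metadata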
@@ -280,7 +280,9 @@ def save_sharded_model(
         if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
-            metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+            metadata_file = get_dist_meta_file_name(
+                checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors
+            )
 
         if use_async:
             total_size, writers = async_save_state_dict_shards(
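The `dist_id` expression flattens the (pipeline rank, tensor-parallel rank) pair into a single index, so every pp/tp pair writes its own weight files and metadata file. A quick enumeration for `tp_size = 2`, `pp_size = 2` shows the mapping is collision-free:

tp_size, pp_size = 2, 2
for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        dist_id = tp_size * pp_rank + tp_rank
        print(f"pp={pp_rank} tp={tp_rank} -> dist_id={dist_id}")
# pp=0 tp=0 -> dist_id=0
# pp=0 tp=1 -> dist_id=1
# pp=1 tp=0 -> dist_id=2
# pp=1 tp=1 -> dist_id=3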
@@ -413,9 +415,7 @@ def load_sharded_model(
             )
             model = model.unwrap()
             with RestoreDefaultStateDictBehavior(model):
-                load_state_dict_into_model(
-                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
-                )
+                load_state_dict_into_model(model, state_dict, missing_keys=[], strict=False, load_sub_module=True)
             return
 
         model_before_wrapping = model  # backup for model before wrapping
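`RestoreDefaultStateDictBehavior` (imported above from `distributed_checkpoint_utils`) temporarily reverts any customized `state_dict`/`load_state_dict` behavior so the flat state dict loads with stock `nn.Module` semantics. A plausible sketch of such a context manager, assuming the override can simply be shadowed on the instance; the real helper's mechanics may differ:

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def restore_default_state_dict_behavior(model: nn.Module):
    # Swap the module's (possibly customized) state-dict methods for the
    # stock nn.Module implementations, then put the originals back.
    originals = (model.state_dict, model.load_state_dict)
    model.state_dict = nn.Module.state_dict.__get__(model)
    model.load_state_dict = nn.Module.load_state_dict.__get__(model)
    try:
        yield model
    finally:
        model.state_dict, model.load_state_dict = originals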
@@ -897,7 +897,7 @@ def load_unsharded_model(
                     load_dtensor = True
                     break
 
-        model_metadata = None # used for dist model
+        model_metadata = None  # used for dist model
         if load_dtensor:
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
 
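The loop above flips `load_dtensor` as soon as one distributed-metadata file is found in the checkpoint directory. `is_pytorch_model_meta_dist_file` is a filename predicate; a hypothetical stand-in showing the shape of that check:

import os


def looks_like_dist_meta_file(filename: str) -> bool:
    # Hypothetical predicate; the real check is
    # distributed_checkpoint_utils.is_pytorch_model_meta_dist_file.
    return "meta" in os.path.basename(filename)


def checkpoint_is_distributed(checkpoint_dir: str) -> bool:
    # One metadata file is enough to choose the dtensor load path.
    return any(looks_like_dist_meta_file(f) for f in os.listdir(checkpoint_dir))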