diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index 4415d8a74afa..4a3ad92ba21f 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -3135,11 +3135,32 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
     shard_files = list(set(index["weight_map"].values()))
     loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")
 
-    ret = {}
-    for shard_file in tqdm(shard_files):
-        state_dict = loader(os.path.join(folder, shard_file))
-        ret.update(state_dict)
+    ret = {}
+    try:
+        # fastsafetensors reads every shard in a single pass (optionally via
+        # GPUDirect Storage); fall back to the per-shard loader below when it
+        # is not installed or the fast path fails at runtime.
+        from fastsafetensors import fastsafe_open
+
+        path = [os.path.join(folder, shard_file) for shard_file in shard_files]
+        device = "gpu" if paddle.device.cuda.device_count() else "cpu"
+        not_use_gds = False
+        with fastsafe_open(
+            filenames=path,
+            nogds=not_use_gds,
+            device=device,
+            max_copy_block_size=256 * 1024 * 1024,
+            framework="paddle",
+        ) as f:
+            for key in f.get_keys():
+                # Must clone: the backing cuda memory is freed once the
+                # `with` block exits.
+                ret[key] = f.get_tensor(key).clone().detach()
+    except Exception:
+        # Broad on purpose (covers ImportError and runtime load errors), but
+        # not bare: a bare `except:` would also swallow KeyboardInterrupt.
+        for shard_file in tqdm(shard_files):
+            state_dict = loader(os.path.join(folder, shard_file))
+            ret.update(state_dict)
 
     if not return_numpy:
         for key in list(ret.keys()):