Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3135,11 +3135,30 @@ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):

shard_files = list(set(index["weight_map"].values()))
loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")

ret = {}
try:
    # Fast path: fastsafetensors opens every shard in one pass and can use
    # GPU Direct Storage (GDS). Any failure here (package not installed,
    # GDS/device unsupported, open error) falls through to the per-shard
    # loader below.
    from fastsafetensors import fastsafe_open

    shard_paths = [os.path.join(folder, shard_file) for shard_file in shard_files]
    # NOTE(review): picks "gpu" whenever any CUDA device is visible — confirm
    # this is desired when the caller wants CPU tensors (return_numpy=True).
    device = "gpu" if paddle.device.cuda.device_count() else "cpu"
    with fastsafe_open(
        filenames=shard_paths,
        nogds=False,  # GDS enabled; failure is handled by the fallback path
        device=device,
        max_copy_block_size=256 * 1024 * 1024,
        framework="paddle",
    ) as f:
        for key in f.get_keys():
            # Must clone: the CUDA memory backing the tensors is freed
            # when the `with` block exits.
            ret[key] = f.get_tensor(key).clone().detach()
except Exception:
    # Slow path: load each shard file individually and merge. Reset `ret`
    # first — the fast path may have partially populated it before failing.
    ret = {}
    for shard_file in tqdm(shard_files):
        state_dict = loader(os.path.join(folder, shard_file))
        ret.update(state_dict)

if not return_numpy:
for key in list(ret.keys()):
Expand Down
Loading