Skip to content

Commit e29fb7a

Browse files
committed
Update `torch.load` calls to pass `weights_only=False` explicitly in the DeepSpeed checkpoint utility
1 parent 902b04f commit e29fb7a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/lightning/pytorch/utilities/deepspeed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(
9393
]
9494
checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
9595
optim_files = get_optim_files(checkpoint_dir)
96-
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
96+
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
9797
zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
9898
model_file = get_model_state_file(checkpoint_dir, zero_stage)
99-
client_state = torch.load(model_file, map_location=CPU_DEVICE)
99+
client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
100100
client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
101101
# State dict keys will include reference to wrapper _LightningModuleWrapperBase in old checkpoints created in
102102
# Lightning version < 2.1. Delete the `_forward_module` prefix before saving.

0 commit comments

Comments
 (0)