Skip to content

Commit

Permalink
Merge pull request #455 from mj-will/fix-torch-loading
Browse files Browse the repository at this point in the history
MAINT: specify `weights_only=True` in `load.torch`
  • Loading branch information
mj-will authored Feb 3, 2025
2 parents 49ce6d8 + 04b3c80 commit b1e88f3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/nessai/flowmodel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def load_weights(self, weights_file):
# TODO: these two methods are basically the same
if not self.initialised:
self.initialise()
self.model.load_state_dict(torch.load(weights_file))
self.model.load_state_dict(torch.load(weights_file, weights_only=True))
self.model.eval()
self.weights_file = weights_file

Expand Down
2 changes: 1 addition & 1 deletion src/nessai/flowmodel/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def load_all_weights(self) -> None:
for wf in self.weights_files:
new_flow = configure_model(self.flow_config)
new_flow.device = self.device
new_flow.load_state_dict(torch.load(wf))
new_flow.load_state_dict(torch.load(wf, weights_only=True))
self.models.append(new_flow)
self.models.eval()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_flowmodel/test_flowmodel_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def test_load_weights(model, initialised):
model.initialise.assert_not_called()
else:
model.initialise.assert_called_once()
mock_load.assert_called_once_with(weights_file)
mock_load.assert_called_once_with(weights_file, weights_only=True)
model.model.load_state_dict.assert_called_once_with(d)
model.model.eval.assert_called_once()
assert model.weights_file == weights_file
Expand Down

0 comments on commit b1e88f3

Please sign in to comment.