Skip to content

Commit

Permalink
Pytorch Load / Save Plugin (#1114)
Browse files Browse the repository at this point in the history
* Pytorch Load / Save Plugin

This plugin checks for the use of `torch.load` and `torch.save`.
Using `torch.load` with untrusted data can lead to arbitrary code
execution, and improper use of `torch.save` might expose sensitive
data or lead to data corruption.

Signed-off-by: Luke Hinds <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing save check

Signed-off-by: Luke Hinds <[email protected]>

* Review fixes from 8b92a02

Signed-off-by: Luke Hinds <[email protected]>

* Fix tox issues

Signed-off-by: Luke Hinds <[email protected]>

* Review fixes

Signed-off-by: Luke Hinds <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_functional.py

* Update bandit/plugins/pytorch_load_save.py

Co-authored-by: Eric Brown <[email protected]>

* Update bandit/plugins/pytorch_load_save.py

Co-authored-by: Eric Brown <[email protected]>

* Update doc/source/plugins/b704_pytorch_load_save.rst

Co-authored-by: Eric Brown <[email protected]>

* Update bandit/plugins/pytorch_load_save.py

Co-authored-by: Eric Brown <[email protected]>

---------

Signed-off-by: Luke Hinds <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Brown <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2024
1 parent 4ac55df commit 36fd650
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 0 deletions.
72 changes: 72 additions & 0 deletions bandit/plugins/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B614: Test for unsafe PyTorch load or save
==========================================
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from hugingface, which provides a safe deserialization mechanism.
:Example:
.. code-block:: none
>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth',
map_location='cpu'))
9
10 print("Model loaded successfully!")
.. seealso::
- https://cwe.mitre.org/data/definitions/94.html
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
- https://github.com/huggingface/safetensors
.. versionadded:: 1.7.10
"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id("B614")
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
text="Use of unsafe PyTorch load or save",
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----------------------
B614: pytorch_load_save
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
21 changes: 21 additions & 0 deletions examples/pytorch_load_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Save the model
torch.save(loaded_model.state_dict(), 'model_weights.pth')

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

# Save the model
torch.save(another_model.state_dict(), 'model_weights.pth')

3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

# bandit/plugins/trojansource.py
trojansource = bandit.plugins.trojansource:trojansource

Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,14 @@ def test_tarfile_unsafe_members(self):
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
}
self.check_example("pytorch_load_save.py", expect)

def test_trojansource(self):
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 1},
Expand Down

0 comments on commit 36fd650

Please sign in to comment.