diff --git a/src/openfermion/utils/operator_utils.py b/src/openfermion/utils/operator_utils.py index 82d2e345e..6798f3407 100644 --- a/src/openfermion/utils/operator_utils.py +++ b/src/openfermion/utils/operator_utils.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """This module provides generic tools for classes in ops/""" -from builtins import map, zip + import marshal import os @@ -22,19 +22,24 @@ from openfermion.ops.operators import ( BosonOperator, FermionOperator, + IsingOperator, MajoranaOperator, QuadOperator, QubitOperator, - IsingOperator, ) from openfermion.ops.representations import ( - PolynomialTensor, DiagonalCoulombHamiltonian, InteractionOperator, InteractionRDM, + PolynomialTensor, ) from openfermion.transforms.opconversions.term_reordering import normal_ordered +# Maximum size allowed for data files read by load_operator(). This is a (weak) safety +# measure against corrupted or insecure files. +_MAX_TEXT_OPERATOR_DATA = 5 * 1024 * 1024 +_MAX_BINARY_OPERATOR_DATA = 1024 * 1024 + class OperatorUtilsError(Exception): pass @@ -246,27 +251,45 @@ def get_file_path(file_name, data_directory): def load_operator(file_name=None, data_directory=None, plain_text=False): - """Load FermionOperator or QubitOperator from file. + """Load an operator (such as a FermionOperator) from a file. Args: - file_name: The name of the saved file. + file_name: The name of the data file to read. data_directory: Optional data directory to change from default data - directory specified in config file. + directory specified in config file. plain_text: Whether the input file is plain text Returns: operator: The stored FermionOperator, BosonOperator, - QuadOperator, or QubitOperator + QuadOperator, or QubitOperator. Raises: TypeError: Operator of invalid type. + ValueError: If the file is larger than the maximum allowed. + ValueError: If the file content is not as expected or loading fails. + IOError: If the file cannot be opened. + + Warning: + Loading from binary files (plain_text=False) uses the Python 'marshal' + module, which is not secure against untrusted or maliciously crafted + data. Only load binary operator files from sources that you trust. + Prefer using the plain_text format for data from untrusted sources. """ + file_path = get_file_path(file_name, data_directory) + operator_type = None + operator_terms = None + if plain_text: with open(file_path, 'r') as f: - data = f.read() - operator_type, operator_terms = data.split(":\n") + data = f.read(_MAX_TEXT_OPERATOR_DATA) + try: + operator_type, operator_terms = data.split(":\n", 1) + except ValueError: + raise ValueError( + "Invalid format in plain-text data file {file_path}: " "expected 'TYPE:\\nTERMS'" + ) if operator_type == 'FermionOperator': operator = FermionOperator(operator_terms) @@ -277,12 +300,24 @@ def load_operator(file_name=None, data_directory=None, plain_text=False): elif operator_type == 'QuadOperator': operator = QuadOperator(operator_terms) else: - raise TypeError('Operator of invalid type.') + raise TypeError( + f"Invalid operator type '{operator_type}' encountered " + f"found in plain-text data file '{file_path}'." + ) else: - with open(file_path, 'rb') as f: - data = marshal.load(f) - operator_type = data[0] - operator_terms = data[1] + # marshal.load() doesn't have a size parameter, so we test it ourselves. + if os.path.getsize(file_path) > _MAX_BINARY_OPERATOR_DATA: + raise ValueError( + f"Size of {file_path} exceeds maximum allowed " + f"({_MAX_BINARY_OPERATOR_DATA} bytes)." + ) + try: + with open(file_path, 'rb') as f: + raw_data = marshal.load(f) + except Exception as e: + raise ValueError(f"Failed to load marshaled data from {file_path}: {e}") + + operator_type, operator_terms = _validate_operator_data(raw_data) if operator_type == 'FermionOperator': operator = FermionOperator() @@ -309,17 +344,17 @@ def load_operator(file_name=None, data_directory=None, plain_text=False): def save_operator( operator, file_name=None, data_directory=None, allow_overwrite=False, plain_text=False ): - """Save FermionOperator or QubitOperator to file. + """Save an operator (such as a FermionOperator) to a file. Args: operator: An instance of FermionOperator, BosonOperator, or QubitOperator. file_name: The name of the saved file. data_directory: Optional data directory to change from default data - directory specified in config file. + directory specified in config file. allow_overwrite: Whether to allow files to be overwritten. plain_text: Whether the operator should be saved to a - plain-text format for manual analysis + plain-text format for manual analysis. Raises: OperatorUtilsError: Not saved, file already exists. @@ -358,3 +393,53 @@ def save_operator( tm = operator.terms with open(file_path, 'wb') as f: marshal.dump((operator_type, dict(zip(tm.keys(), map(complex, tm.values())))), f) + + +def _validate_operator_data(raw_data): + """Validates the structure and types of data loaded using marshal. + + The file is expected to contain a tuple of (type, data), where the + "type" is one of the currently-supported operators, and "data" is a dict. + + Args: + raw_data: text or binary data. + + Returns: + tuple(str, dict) where the 0th element is the name of the operator + type (e.g., 'FermionOperator') and the dict is the operator data. + + Raises: + TypeError: raw_data did not contain a tuple of length 2. + TypeError: the first element of the tuple is not a string. + TypeError: the second element of the tuple is not a dict. + TypeError: the given operator type is not supported. + """ + + if not isinstance(raw_data, tuple) or len(raw_data) != 2: + raise TypeError( + f"Invalid marshaled structure: Expected a tuple " + "of length 2, but got {type(raw_data)} instead." + ) + + operator_type, operator_terms = raw_data + + if not isinstance(operator_type, str): + raise TypeError( + f"Invalid type for operator_type: Expected str but " + "got type {type(operator_type)} instead." + ) + + allowed = {'FermionOperator', 'BosonOperator', 'QubitOperator', 'QuadOperator'} + if operator_type not in allowed: + raise TypeError( + f"Operator type '{operator_type}' is not supported. " + "The operator must be one of {allowed}." + ) + + if not isinstance(operator_terms, dict): + raise TypeError( + f"Invalid type for operator_terms: Expected dict " + "but got type {type(operator_terms)} instead." + ) + + return operator_type, operator_terms diff --git a/src/openfermion/utils/operator_utils_test.py b/src/openfermion/utils/operator_utils_test.py index 95b7d32cb..6d436565d 100644 --- a/src/openfermion/utils/operator_utils_test.py +++ b/src/openfermion/utils/operator_utils_test.py @@ -11,41 +11,38 @@ # limitations under the License. """Tests for operator_utils.""" -import os - import itertools - +import marshal +import os import unittest import numpy - import sympy - from scipy.sparse import csc_matrix from openfermion.config import DATA_DIRECTORY from openfermion.hamiltonians import fermi_hubbard from openfermion.ops.operators import ( + BosonOperator, FermionOperator, + IsingOperator, MajoranaOperator, - BosonOperator, - QubitOperator, QuadOperator, - IsingOperator, + QubitOperator, ) from openfermion.ops.representations import InteractionOperator -from openfermion.transforms.opconversions import jordan_wigner, bravyi_kitaev -from openfermion.transforms.repconversions import get_interaction_operator from openfermion.testing.testing_utils import random_interaction_operator +from openfermion.transforms.opconversions import bravyi_kitaev, jordan_wigner +from openfermion.transforms.repconversions import get_interaction_operator from openfermion.utils.operator_utils import ( + OperatorUtilsError, count_qubits, + get_file_path, hermitian_conjugated, - is_identity, - save_operator, - OperatorUtilsError, is_hermitian, + is_identity, load_operator, - get_file_path, + save_operator, ) @@ -586,15 +583,95 @@ def test_overwrite_flag_save_on_top_of_existing_operator(self): self.assertEqual(fermion_operator, self.fermion_operator) - def test_load_bad_type(self): - with self.assertRaises(TypeError): - _ = load_operator('bad_type_operator') + def test_load_nonexistent_file_raises_error(self): + with self.assertRaises(FileNotFoundError): + load_operator('non_existent_file_for_testing') def test_save_bad_type(self): with self.assertRaises(TypeError): save_operator('ping', 'somewhere') +class LoadOperatorTest(unittest.TestCase): + def setUp(self): + self.file_name = "test_load_operator_file.data" + self.file_path = os.path.join(DATA_DIRECTORY, self.file_name) + + def tearDown(self): + if os.path.isfile(self.file_path): + os.remove(self.file_path) + + def test_load_plain_text_invalid_format_raises_value_error(self): + with open(self.file_path, 'w') as f: + f.write("some text without the required separator") + with self.assertRaises(ValueError) as cm: + load_operator(self.file_name, plain_text=True) + self.assertIn("expected 'TYPE:\\nTERMS'", str(cm.exception)) + + def test_load_binary_too_large_raises_value_error(self): + with open(self.file_path, 'wb') as f: + f.write(b'data') + + original_getsize = os.path.getsize + + def mock_getsize(path): + from openfermion.utils.operator_utils import _MAX_BINARY_OPERATOR_DATA + + if path == self.file_path: + return _MAX_BINARY_OPERATOR_DATA + 1 + return original_getsize(path) # pragma: no cover + + os.path.getsize = mock_getsize + try: + with self.assertRaises(ValueError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("exceeds maximum allowed", str(cm.exception)) + finally: + os.path.getsize = original_getsize + + def test_load_binary_corrupted_data_raises_value_error(self): + with open(self.file_path, 'wb') as f: + f.write(b"corrupted marshal data") + with self.assertRaises(ValueError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("Failed to load marshaled data", str(cm.exception)) + + def test_load_binary_invalid_structure_not_tuple_raises_type_error(self): + with open(self.file_path, 'wb') as f: + marshal.dump("this is not a tuple", f) + with self.assertRaises(TypeError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("Expected a tuple", str(cm.exception)) + + def test_load_binary_invalid_structure_wrong_len_tuple_raises_type_error(self): + with open(self.file_path, 'wb') as f: + marshal.dump(("one", "two", "three"), f) + with self.assertRaises(TypeError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("Expected a tuple of length 2", str(cm.exception)) + + def test_load_binary_invalid_operator_type_not_string_raises_type_error(self): + with open(self.file_path, 'wb') as f: + marshal.dump((123, {}), f) + with self.assertRaises(TypeError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("Expected str but got type", str(cm.exception)) + + def test_load_binary_invalid_terms_type_not_dict_raises_type_error(self): + with open(self.file_path, 'wb') as f: + marshal.dump(("FermionOperator", ["a", "list"]), f) + with self.assertRaises(TypeError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("Expected dict but got type", str(cm.exception)) + + def test_load_binary_unsupported_operator_type_raises_type_error(self): + with open(self.file_path, 'wb') as f: + marshal.dump(("UnsupportedOperator", {}), f) + with self.assertRaises(TypeError) as cm: + load_operator(self.file_name, plain_text=False) + self.assertIn("is not supported", str(cm.exception)) + + class GetFileDirTest(unittest.TestCase): def setUp(self): self.filename = 'foo'