Skip to content

Commit c8fab8a

Browse files
committed
Resolve conflicts during merge
1 parent b2d5196 commit c8fab8a

File tree

2 files changed

+187
-63
lines changed

2 files changed

+187
-63
lines changed

src/openfermion/utils/operator_utils.py

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
# limitations under the License.
1212
"""This module provides generic tools for classes in ops/"""
1313
from builtins import map, zip
14-
import json
14+
import marshal
1515
import os
16-
from ast import literal_eval
1716

1817
import numpy
1918
import sympy
@@ -37,6 +36,12 @@
3736
from openfermion.transforms.opconversions.term_reordering import normal_ordered
3837

3938

39+
# Maximum size allowed for data files read by load_operator(). This is a (weak) safety
40+
# measure against corrupted or insecure files.
41+
_MAX_TEXT_OPERATOR_DATA = 5 * 1024 * 1024
42+
_MAX_BINARY_OPERATOR_DATA = 1024 * 1024
43+
44+
4045
class OperatorUtilsError(Exception):
4146
pass
4247

@@ -247,27 +252,45 @@ def get_file_path(file_name, data_directory):
247252

248253

249254
def load_operator(file_name=None, data_directory=None, plain_text=False):
250-
"""Load FermionOperator or QubitOperator from file.
255+
"""Load an operator (such as a FermionOperator) from a file.
251256
252257
Args:
253-
file_name: The name of the saved file.
258+
file_name: The name of the data file to read.
254259
data_directory: Optional data directory to change from default data
255-
directory specified in config file.
260+
directory specified in config file.
256261
plain_text: Whether the input file is plain text
257262
258263
Returns:
259264
operator: The stored FermionOperator, BosonOperator,
260-
QuadOperator, or QubitOperator
265+
QuadOperator, or QubitOperator.
261266
262267
Raises:
263268
TypeError: Operator of invalid type.
269+
ValueError: If the file is larger than the maximum allowed.
270+
ValueError: If the file content is not as expected or loading fails.
271+
IOError: If the file cannot be opened.
272+
273+
Warning:
274+
Loading from binary files (plain_text=False) uses the Python 'marshal'
275+
module, which is not secure against untrusted or maliciously crafted
276+
data. Only load binary operator files from sources that you trust.
277+
Prefer using the plain_text format for data from untrusted sources.
264278
"""
279+
265280
file_path = get_file_path(file_name, data_directory)
266281

282+
operator_type = None
283+
operator_terms = None
284+
267285
if plain_text:
268286
with open(file_path, 'r') as f:
269-
data = f.read()
270-
operator_type, operator_terms = data.split(":\n")
287+
data = f.read(_MAX_TEXT_OPERATOR_DATA)
288+
try:
289+
operator_type, operator_terms = data.split(":\n", 1)
290+
except ValueError:
291+
raise ValueError(
292+
"Invalid format in plain-text data file {file_path}: " "expected 'TYPE:\\nTERMS'"
293+
)
271294

272295
if operator_type == 'FermionOperator':
273296
operator = FermionOperator(operator_terms)
@@ -278,15 +301,24 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
278301
elif operator_type == 'QuadOperator':
279302
operator = QuadOperator(operator_terms)
280303
else:
281-
raise TypeError('Operator of invalid type.')
304+
raise TypeError(
305+
f"Invalid operator type '{operator_type}' encountered "
306+
f"found in plain-text data file '{file_path}'."
307+
)
282308
else:
283-
with open(file_path, 'r') as f:
284-
data = json.load(f)
285-
operator_type, serializable_terms = data
286-
operator_terms = {
287-
literal_eval(key): complex(value[0], value[1])
288-
for key, value in serializable_terms.items()
289-
}
309+
# marshal.load() doesn't have a size parameter, so we test it ourselves.
310+
if os.path.getsize(file_path) > _MAX_BINARY_OPERATOR_DATA:
311+
raise ValueError(
312+
f"Size of {file_path} exceeds maximum allowed "
313+
f"({_MAX_BINARY_OPERATOR_DATA} bytes)."
314+
)
315+
try:
316+
with open(file_path, 'rb') as f:
317+
raw_data = marshal.load(f)
318+
except Exception as e:
319+
raise ValueError(f"Failed to load marshaled data from {file_path}: {e}")
320+
321+
operator_type, operator_terms = _validate_operator_data(raw_data)
290322

291323
if operator_type == 'FermionOperator':
292324
operator = FermionOperator()
@@ -313,17 +345,17 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
313345
def save_operator(
314346
operator, file_name=None, data_directory=None, allow_overwrite=False, plain_text=False
315347
):
316-
"""Save FermionOperator or QubitOperator to file.
348+
"""Save an operator (such as a FermionOperator) to a file.
317349
318350
Args:
319351
operator: An instance of FermionOperator, BosonOperator,
320352
or QubitOperator.
321353
file_name: The name of the saved file.
322354
data_directory: Optional data directory to change from default data
323-
directory specified in config file.
355+
directory specified in config file.
324356
allow_overwrite: Whether to allow files to be overwritten.
325357
plain_text: Whether the operator should be saved to a
326-
plain-text format for manual analysis
358+
plain-text format for manual analysis.
327359
328360
Raises:
329361
OperatorUtilsError: Not saved, file already exists.
@@ -360,6 +392,55 @@ def save_operator(
360392
f.write(operator_type + ":\n" + str(operator))
361393
else:
362394
tm = operator.terms
363-
serializable_terms = {str(key): (value.real, value.imag) for key, value in tm.items()}
364-
with open(file_path, 'w') as f:
365-
json.dump((operator_type, serializable_terms), f)
395+
with open(file_path, 'wb') as f:
396+
marshal.dump((operator_type, dict(zip(tm.keys(), map(complex, tm.values())))), f)
397+
398+
399+
def _validate_operator_data(raw_data):
400+
"""Validates the structure and types of data loaded using marshal.
401+
402+
The file is expected to contain a tuple of (type, data), where the
403+
"type" is one of the currently-supported operators, and "data" is a dict.
404+
405+
Args:
406+
raw_data: text or binary data.
407+
408+
Returns:
409+
tuple(str, dict) where the 0th element is the name of the operator
410+
type (e.g., 'FermionOperator') and the dict is the operator data.
411+
412+
Raises:
413+
TypeError: raw_data did not contain a tuple of length 2.
414+
TypeError: the first element of the tuple is not a string.
415+
TypeError: the second element of the tuple is not a dict.
416+
TypeError: the given operator type is not supported.
417+
"""
418+
419+
if not isinstance(raw_data, tuple) or len(raw_data) != 2:
420+
raise TypeError(
421+
f"Invalid marshaled structure: Expected a tuple "
422+
"of length 2, but got {type(raw_data)} instead."
423+
)
424+
425+
operator_type, operator_terms = raw_data
426+
427+
if not isinstance(operator_type, str):
428+
raise TypeError(
429+
f"Invalid type for operator_type: Expected str but "
430+
"got type {type(operator_type)} instead."
431+
)
432+
433+
allowed = {'FermionOperator', 'BosonOperator', 'QubitOperator', 'QuadOperator'}
434+
if operator_type not in allowed:
435+
raise TypeError(
436+
f"Operator type '{operator_type}' is not supported. "
437+
"The operator must be one of {allowed}."
438+
)
439+
440+
if not isinstance(operator_terms, dict):
441+
raise TypeError(
442+
f"Invalid type for operator_terms: Expected dict "
443+
"but got type {type(operator_terms)} instead."
444+
)
445+
446+
return operator_type, operator_terms

src/openfermion/utils/operator_utils_test.py

Lines changed: 84 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
"""Tests for operator_utils."""
1313

1414
import os
15-
import json
1615

1716
import itertools
1817

@@ -587,53 +586,97 @@ def test_overwrite_flag_save_on_top_of_existing_operator(self):
587586

588587
self.assertEqual(fermion_operator, self.fermion_operator)
589588

590-
def test_load_bad_type(self):
591-
with self.assertRaises(TypeError):
592-
_ = load_operator('bad_type_operator')
589+
def test_load_nonexistent_file_raises_error(self):
590+
with self.assertRaises(FileNotFoundError):
591+
load_operator('non_existent_file_for_testing')
593592

594593
def test_save_bad_type(self):
595594
with self.assertRaises(TypeError):
596595
save_operator('ping', 'somewhere')
597596

598-
def test_save_and_load_complex_json(self):
599-
fermion_op = FermionOperator('1^ 2', 1 + 2j)
600-
boson_op = BosonOperator('1^ 2', 1 + 2j)
601-
qubit_op = QubitOperator('X1 Y2', 1 + 2j)
602-
quad_op = QuadOperator('q1 p2', 1 + 2j)
603-
604-
save_operator(fermion_op, self.file_name)
605-
loaded_op = load_operator(self.file_name)
606-
self.assertEqual(fermion_op, loaded_op)
607-
608-
save_operator(boson_op, self.file_name, allow_overwrite=True)
609-
loaded_op = load_operator(self.file_name)
610-
self.assertEqual(boson_op, loaded_op)
611-
612-
save_operator(qubit_op, self.file_name, allow_overwrite=True)
613-
loaded_op = load_operator(self.file_name)
614-
self.assertEqual(qubit_op, loaded_op)
615-
616-
save_operator(quad_op, self.file_name, allow_overwrite=True)
617-
loaded_op = load_operator(self.file_name)
618-
self.assertEqual(quad_op, loaded_op)
619597

620-
def test_saved_json_content(self):
621-
import json
622-
623-
qubit_op = QubitOperator('X1 Y2', 1 + 2j)
624-
save_operator(qubit_op, self.file_name)
625-
626-
file_path = get_file_path(self.file_name, None)
627-
with open(file_path, 'r') as f:
628-
data = json.load(f)
629-
630-
self.assertEqual(len(data), 2)
631-
self.assertEqual(data[0], 'QubitOperator')
598+
class LoadOperatorTest(unittest.TestCase):
599+
def setUp(self):
600+
self.file_name = "test_load_operator_file.data"
601+
self.file_path = os.path.join(DATA_DIRECTORY, self.file_name)
632602

633-
# The key is stringified tuple
634-
# The value is a list [real, imag]
635-
expected_terms = {"((1, 'X'), (2, 'Y'))": [1.0, 2.0]}
636-
self.assertEqual(data[1], expected_terms)
603+
def tearDown(self):
604+
if os.path.isfile(self.file_path):
605+
os.remove(self.file_path)
606+
607+
def test_load_plain_text_invalid_format_raises_value_error(self):
608+
with open(self.file_path, 'w') as f:
609+
f.write("some text without the required separator")
610+
with self.assertRaises(ValueError) as cm:
611+
load_operator(self.file_name, plain_text=True)
612+
self.assertIn("expected 'TYPE:\\nTERMS'", str(cm.exception))
613+
614+
def test_load_binary_too_large_raises_value_error(self):
615+
with open(self.file_path, 'wb') as f:
616+
f.write(b'data')
617+
618+
original_getsize = os.path.getsize
619+
620+
def mock_getsize(path):
621+
from openfermion.utils.operator_utils import _MAX_BINARY_OPERATOR_DATA
622+
if path == self.file_path:
623+
return _MAX_BINARY_OPERATOR_DATA + 1
624+
return original_getsize(path)
625+
626+
os.path.getsize = mock_getsize
627+
try:
628+
with self.assertRaises(ValueError) as cm:
629+
load_operator(self.file_name, plain_text=False)
630+
self.assertIn("exceeds maximum allowed", str(cm.exception))
631+
finally:
632+
os.path.getsize = original_getsize
633+
634+
def test_load_binary_corrupted_data_raises_value_error(self):
635+
with open(self.file_path, 'wb') as f:
636+
f.write(b"corrupted marshal data")
637+
with self.assertRaises(ValueError) as cm:
638+
load_operator(self.file_name, plain_text=False)
639+
self.assertIn("Failed to load marshaled data", str(cm.exception))
640+
641+
def test_load_binary_invalid_structure_not_tuple_raises_type_error(self):
642+
import marshal
643+
with open(self.file_path, 'wb') as f:
644+
marshal.dump("this is not a tuple", f)
645+
with self.assertRaises(TypeError) as cm:
646+
load_operator(self.file_name, plain_text=False)
647+
self.assertIn("Expected a tuple", str(cm.exception))
648+
649+
def test_load_binary_invalid_structure_wrong_len_tuple_raises_type_error(self):
650+
import marshal
651+
with open(self.file_path, 'wb') as f:
652+
marshal.dump(("one", "two", "three"), f)
653+
with self.assertRaises(TypeError) as cm:
654+
load_operator(self.file_name, plain_text=False)
655+
self.assertIn("Expected a tuple of length 2", str(cm.exception))
656+
657+
def test_load_binary_invalid_operator_type_not_string_raises_type_error(self):
658+
import marshal
659+
with open(self.file_path, 'wb') as f:
660+
marshal.dump((123, {}), f)
661+
with self.assertRaises(TypeError) as cm:
662+
load_operator(self.file_name, plain_text=False)
663+
self.assertIn("Expected str but got type", str(cm.exception))
664+
665+
def test_load_binary_invalid_terms_type_not_dict_raises_type_error(self):
666+
import marshal
667+
with open(self.file_path, 'wb') as f:
668+
marshal.dump(("FermionOperator", ["a", "list"]), f)
669+
with self.assertRaises(TypeError) as cm:
670+
load_operator(self.file_name, plain_text=False)
671+
self.assertIn("Expected dict but got type", str(cm.exception))
672+
673+
def test_load_binary_unsupported_operator_type_raises_type_error(self):
674+
import marshal
675+
with open(self.file_path, 'wb') as f:
676+
marshal.dump(("UnsupportedOperator", {}), f)
677+
with self.assertRaises(TypeError) as cm:
678+
load_operator(self.file_name, plain_text=False)
679+
self.assertIn("is not supported", str(cm.exception))
637680

638681

639682
class GetFileDirTest(unittest.TestCase):

0 commit comments

Comments
 (0)