Skip to content

Commit 2fa6f58

Browse files
committed
Resolve conflicts during merge
1 parent b2d5196 commit 2fa6f58

File tree

2 files changed

+193
-25
lines changed

2 files changed

+193
-25
lines changed

src/openfermion/utils/operator_utils.py

Lines changed: 106 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
# limitations under the License.
1212
"""This module provides generic tools for classes in ops/"""
1313
from builtins import map, zip
14-
import json
14+
import marshal
1515
import os
16-
from ast import literal_eval
1716

1817
import numpy
1918
import sympy
@@ -37,6 +36,12 @@
3736
from openfermion.transforms.opconversions.term_reordering import normal_ordered
3837

3938

39+
# Maximum size allowed for data files read by load_operator(). This is a (weak) safety
40+
# measure against corrupted or insecure files.
41+
_MAX_TEXT_OPERATOR_DATA = 5 * 1024 * 1024
42+
_MAX_BINARY_OPERATOR_DATA = 1024 * 1024
43+
44+
4045
class OperatorUtilsError(Exception):
    """Error raised by the operator utilities in this module.

    For example, save_operator() documents raising it when the target file
    already exists and was not saved.
    """
4247

@@ -247,27 +252,45 @@ def get_file_path(file_name, data_directory):
247252

248253

249254
def load_operator(file_name=None, data_directory=None, plain_text=False):
250-
"""Load FermionOperator or QubitOperator from file.
255+
"""Load an operator (such as a FermionOperator) from a file.
251256
252257
Args:
253-
file_name: The name of the saved file.
258+
file_name: The name of the data file to read.
254259
data_directory: Optional data directory to change from default data
255-
directory specified in config file.
260+
directory specified in config file.
256261
plain_text: Whether the input file is plain text
257262
258263
Returns:
259264
operator: The stored FermionOperator, BosonOperator,
260-
QuadOperator, or QubitOperator
265+
QuadOperator, or QubitOperator.
261266
262267
Raises:
263268
TypeError: Operator of invalid type.
269+
ValueError: If the file is larger than the maximum allowed.
270+
ValueError: If the file content is not as expected or loading fails.
271+
IOError: If the file cannot be opened.
272+
273+
Warning:
274+
Loading from binary files (plain_text=False) uses the Python 'marshal'
275+
module, which is not secure against untrusted or maliciously crafted
276+
data. Only load binary operator files from sources that you trust.
277+
Prefer using the plain_text format for data from untrusted sources.
264278
"""
279+
265280
file_path = get_file_path(file_name, data_directory)
266281

282+
operator_type = None
283+
operator_terms = None
284+
267285
if plain_text:
268286
with open(file_path, 'r') as f:
269-
data = f.read()
270-
operator_type, operator_terms = data.split(":\n")
287+
data = f.read(_MAX_TEXT_OPERATOR_DATA)
288+
try:
289+
operator_type, operator_terms = data.split(":\n", 1)
290+
except ValueError:
291+
raise ValueError(
292+
"Invalid format in plain-text data file {file_path}: " "expected 'TYPE:\\nTERMS'"
293+
)
271294

272295
if operator_type == 'FermionOperator':
273296
operator = FermionOperator(operator_terms)
@@ -278,15 +301,27 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
278301
elif operator_type == 'QuadOperator':
279302
operator = QuadOperator(operator_terms)
280303
else:
281-
raise TypeError('Operator of invalid type.')
304+
raise TypeError(
305+
f"Invalid operator type '{operator_type}' encountered "
306+
f"found in plain-text data file '{file_path}'."
307+
)
282308
else:
283-
with open(file_path, 'r') as f:
284-
data = json.load(f)
285-
operator_type, serializable_terms = data
286-
operator_terms = {
287-
literal_eval(key): complex(value[0], value[1])
288-
for key, value in serializable_terms.items()
289-
}
309+
# marshal.load() doesn't have a size parameter, so we test it ourselves.
310+
if os.path.getsize(file_path) > _MAX_BINARY_OPERATOR_DATA:
311+
raise ValueError(
312+
f"Size of {file_path} exceeds maximum allowed "
313+
f"({_MAX_BINARY_OPERATOR_DATA} bytes)."
314+
)
315+
try:
316+
with open(file_path, 'rb') as f:
317+
raw_data = marshal.load(f)
318+
except Exception as e:
319+
raise ValueError(f"Failed to load marshaled data from {file_path}: {e}")
320+
321+
operator_type, operator_terms = _validate_operator_data(raw_data)
322+
323+
with open(file_path, 'rb') as f:
324+
operator_type, operator_terms = marshal.load(f)
290325

291326
if operator_type == 'FermionOperator':
292327
operator = FermionOperator()
@@ -313,17 +348,17 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
313348
def save_operator(
314349
operator, file_name=None, data_directory=None, allow_overwrite=False, plain_text=False
315350
):
316-
"""Save FermionOperator or QubitOperator to file.
351+
"""Save an operator (such as a FermionOperator) to a file.
317352
318353
Args:
319354
operator: An instance of FermionOperator, BosonOperator,
320355
or QubitOperator.
321356
file_name: The name of the saved file.
322357
data_directory: Optional data directory to change from default data
323-
directory specified in config file.
358+
directory specified in config file.
324359
allow_overwrite: Whether to allow files to be overwritten.
325360
plain_text: Whether the operator should be saved to a
326-
plain-text format for manual analysis
361+
plain-text format for manual analysis.
327362
328363
Raises:
329364
OperatorUtilsError: Not saved, file already exists.
@@ -360,6 +395,55 @@ def save_operator(
360395
f.write(operator_type + ":\n" + str(operator))
361396
else:
362397
tm = operator.terms
363-
serializable_terms = {str(key): (value.real, value.imag) for key, value in tm.items()}
364-
with open(file_path, 'w') as f:
365-
json.dump((operator_type, serializable_terms), f)
398+
with open(file_path, 'wb') as f:
399+
marshal.dump((operator_type, dict(zip(tm.keys(), map(complex, tm.values())))), f)
400+
401+
402+
def _validate_operator_data(raw_data):
403+
"""Validates the structure and types of data loaded using marshal.
404+
405+
The file is expected to contain a tuple of (type, data), where the
406+
"type" is one of the currently-supported operators, and "data" is a dict.
407+
408+
Args:
409+
raw_data: text or binary data.
410+
411+
Returns:
412+
tuple(str, dict) where the 0th element is the name of the operator
413+
type (e.g., 'FermionOperator') and the dict is the operator data.
414+
415+
Raises:
416+
TypeError: raw_data did not contain a tuple of length 2.
417+
TypeError: the first element of the tuple is not a string.
418+
TypeError: the second element of the tuple is not a dict.
419+
TypeError: the given operator type is not supported.
420+
"""
421+
422+
if not isinstance(raw_data, tuple) or len(raw_data) != 2:
423+
raise TypeError(
424+
f"Invalid marshaled structure: Expected a tuple "
425+
"of length 2, but got {type(raw_data)} instead."
426+
)
427+
428+
operator_type, operator_terms = raw_data
429+
430+
if not isinstance(operator_type, str):
431+
raise TypeError(
432+
f"Invalid type for operator_type: Expected str but "
433+
"got type {type(operator_type)} instead."
434+
)
435+
436+
allowed = {'FermionOperator', 'BosonOperator', 'QubitOperator', 'QuadOperator'}
437+
if operator_type not in allowed:
438+
raise TypeError(
439+
f"Operator type '{operator_type}' is not supported. "
440+
"The operator must be one of {allowed}."
441+
)
442+
443+
if not isinstance(operator_terms, dict):
444+
raise TypeError(
445+
f"Invalid type for operator_terms: Expected dict "
446+
"but got type {type(operator_terms)} instead."
447+
)
448+
449+
return operator_type, operator_terms

src/openfermion/utils/operator_utils_test.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,9 @@ def test_overwrite_flag_save_on_top_of_existing_operator(self):
587587

588588
self.assertEqual(fermion_operator, self.fermion_operator)
589589

590-
def test_load_bad_type(self):
591-
with self.assertRaises(TypeError):
592-
_ = load_operator('bad_type_operator')
590+
def test_load_nonexistent_file_raises_error(self):
591+
with self.assertRaises(FileNotFoundError):
592+
load_operator('non_existent_file_for_testing')
593593

594594
def test_save_bad_type(self):
595595
with self.assertRaises(TypeError):
@@ -636,6 +636,90 @@ def test_saved_json_content(self):
636636
self.assertEqual(data[1], expected_terms)
637637

638638

639+
class LoadOperatorTest(unittest.TestCase):
    """Error-path tests for load_operator() on malformed or oversized files."""

    def setUp(self):
        # Every test writes its fixture to the same scratch file in the data
        # directory; tearDown removes it again.
        self.file_name = "test_load_operator_file.data"
        self.file_path = os.path.join(DATA_DIRECTORY, self.file_name)

    def tearDown(self):
        if os.path.isfile(self.file_path):
            os.remove(self.file_path)

    def _dump_marshaled(self, payload):
        """Write `payload` to the scratch file using marshal."""
        import marshal

        with open(self.file_path, 'wb') as out:
            marshal.dump(payload, out)

    def _load_expecting(self, exc_type, fragment, plain_text=False):
        """Load the scratch file and check the exception type and message."""
        with self.assertRaises(exc_type) as caught:
            load_operator(self.file_name, plain_text=plain_text)
        self.assertIn(fragment, str(caught.exception))

    def test_load_plain_text_invalid_format_raises_value_error(self):
        with open(self.file_path, 'w') as out:
            out.write("some text without the required separator")
        self._load_expecting(ValueError, "expected 'TYPE:\\nTERMS'", plain_text=True)

    def test_load_binary_too_large_raises_value_error(self):
        from openfermion.utils.operator_utils import _MAX_BINARY_OPERATOR_DATA

        with open(self.file_path, 'wb') as out:
            out.write(b'data')

        real_getsize = os.path.getsize

        def fake_getsize(path):
            # Only the scratch file is reported as oversized; every other
            # path is delegated to the real implementation.
            if path != self.file_path:
                return real_getsize(path)
            return _MAX_BINARY_OPERATOR_DATA + 1

        os.path.getsize = fake_getsize
        try:
            self._load_expecting(ValueError, "exceeds maximum allowed")
        finally:
            os.path.getsize = real_getsize

    def test_load_binary_corrupted_data_raises_value_error(self):
        with open(self.file_path, 'wb') as out:
            out.write(b"corrupted marshal data")
        self._load_expecting(ValueError, "Failed to load marshaled data")

    def test_load_binary_invalid_structure_not_tuple_raises_type_error(self):
        self._dump_marshaled("this is not a tuple")
        self._load_expecting(TypeError, "Expected a tuple")

    def test_load_binary_invalid_structure_wrong_len_tuple_raises_type_error(self):
        self._dump_marshaled(("one", "two", "three"))
        self._load_expecting(TypeError, "Expected a tuple of length 2")

    def test_load_binary_invalid_operator_type_not_string_raises_type_error(self):
        self._dump_marshaled((123, {}))
        self._load_expecting(TypeError, "Expected str but got type")

    def test_load_binary_invalid_terms_type_not_dict_raises_type_error(self):
        self._dump_marshaled(("FermionOperator", ["a", "list"]))
        self._load_expecting(TypeError, "Expected dict but got type")

    def test_load_binary_unsupported_operator_type_raises_type_error(self):
        self._dump_marshaled(("UnsupportedOperator", {}))
        self._load_expecting(TypeError, "is not supported")
721+
722+
639723
class GetFileDirTest(unittest.TestCase):
640724
def setUp(self):
641725
self.filename = 'foo'

0 commit comments

Comments
 (0)