Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 102 additions & 17 deletions src/openfermion/utils/operator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides generic tools for classes in ops/"""
from builtins import map, zip

import marshal
import os

Expand All @@ -22,19 +22,24 @@
from openfermion.ops.operators import (
BosonOperator,
FermionOperator,
IsingOperator,
MajoranaOperator,
QuadOperator,
QubitOperator,
IsingOperator,
)
from openfermion.ops.representations import (
PolynomialTensor,
DiagonalCoulombHamiltonian,
InteractionOperator,
InteractionRDM,
PolynomialTensor,
)
from openfermion.transforms.opconversions.term_reordering import normal_ordered

# Maximum size allowed for data files read by load_operator(). This is a (weak) safety
# measure against corrupted or insecure files.
_MAX_TEXT_OPERATOR_DATA = 5 * 1024 * 1024
_MAX_BINARY_OPERATOR_DATA = 1024 * 1024


class OperatorUtilsError(Exception):
pass
Expand Down Expand Up @@ -246,27 +251,45 @@


def load_operator(file_name=None, data_directory=None, plain_text=False):
"""Load FermionOperator or QubitOperator from file.
"""Load an operator (such as a FermionOperator) from a file.

Args:
file_name: The name of the saved file.
file_name: The name of the data file to read.
data_directory: Optional data directory to change from default data
directory specified in config file.
directory specified in config file.
plain_text: Whether the input file is plain text

Returns:
operator: The stored FermionOperator, BosonOperator,
QuadOperator, or QubitOperator
QuadOperator, or QubitOperator.

Raises:
TypeError: Operator of invalid type.
ValueError: If the file is larger than the maximum allowed.
ValueError: If the file content is not as expected or loading fails.
IOError: If the file cannot be opened.

Warning:
Loading from binary files (plain_text=False) uses the Python 'marshal'
module, which is not secure against untrusted or maliciously crafted
data. Only load binary operator files from sources that you trust.
Prefer using the plain_text format for data from untrusted sources.
"""

file_path = get_file_path(file_name, data_directory)

operator_type = None
operator_terms = None

if plain_text:
with open(file_path, 'r') as f:
data = f.read()
operator_type, operator_terms = data.split(":\n")
data = f.read(_MAX_TEXT_OPERATOR_DATA)
try:
operator_type, operator_terms = data.split(":\n", 1)
except ValueError:
raise ValueError(
"Invalid format in plain-text data file {file_path}: " "expected 'TYPE:\\nTERMS'"
)

if operator_type == 'FermionOperator':
operator = FermionOperator(operator_terms)
Expand All @@ -277,12 +300,24 @@
elif operator_type == 'QuadOperator':
operator = QuadOperator(operator_terms)
else:
raise TypeError('Operator of invalid type.')
raise TypeError(
f"Invalid operator type '{operator_type}' encountered "
f"found in plain-text data file '{file_path}'."
)
else:
with open(file_path, 'rb') as f:
data = marshal.load(f)
operator_type = data[0]
operator_terms = data[1]
# marshal.load() doesn't have a size parameter, so we test it ourselves.
if os.path.getsize(file_path) > _MAX_BINARY_OPERATOR_DATA:
raise ValueError(
f"Size of {file_path} exceeds maximum allowed "
f"({_MAX_BINARY_OPERATOR_DATA} bytes)."
)
try:
with open(file_path, 'rb') as f:
raw_data = marshal.load(f)
except Exception as e:
raise ValueError(f"Failed to load marshaled data from {file_path}: {e}")

operator_type, operator_terms = _validate_operator_data(raw_data)

if operator_type == 'FermionOperator':
operator = FermionOperator()
Expand All @@ -309,17 +344,17 @@
def save_operator(
operator, file_name=None, data_directory=None, allow_overwrite=False, plain_text=False
):
"""Save FermionOperator or QubitOperator to file.
"""Save an operator (such as a FermionOperator) to a file.

Args:
operator: An instance of FermionOperator, BosonOperator,
or QubitOperator.
file_name: The name of the saved file.
data_directory: Optional data directory to change from default data
directory specified in config file.
directory specified in config file.
allow_overwrite: Whether to allow files to be overwritten.
plain_text: Whether the operator should be saved to a
plain-text format for manual analysis
plain-text format for manual analysis.

Raises:
OperatorUtilsError: Not saved, file already exists.
Expand Down Expand Up @@ -358,3 +393,53 @@
tm = operator.terms
with open(file_path, 'wb') as f:
marshal.dump((operator_type, dict(zip(tm.keys(), map(complex, tm.values())))), f)


def _validate_operator_data(raw_data):
"""Validates the structure and types of data loaded using marshal.

The file is expected to contain a tuple of (type, data), where the
"type" is one of the currently-supported operators, and "data" is a dict.

Args:
raw_data: text or binary data.

Returns:
tuple(str, dict) where the 0th element is the name of the operator
type (e.g., 'FermionOperator') and the dict is the operator data.

Raises:
TypeError: raw_data did not contain a tuple of length 2.
TypeError: the first element of the tuple is not a string.
TypeError: the second element of the tuple is not a dict.
TypeError: the given operator type is not supported.
"""

if not isinstance(raw_data, tuple) or len(raw_data) != 2:
raise TypeError(
f"Invalid marshaled structure: Expected a tuple "
"of length 2, but got {type(raw_data)} instead."
)

operator_type, operator_terms = raw_data

if not isinstance(operator_type, str):
raise TypeError(
f"Invalid type for operator_type: Expected str but "
"got type {type(operator_type)} instead."
)

allowed = {'FermionOperator', 'BosonOperator', 'QubitOperator', 'QuadOperator'}
if operator_type not in allowed:
raise TypeError(
f"Operator type '{operator_type}' is not supported. "
"The operator must be one of {allowed}."
)

if not isinstance(operator_terms, dict):
raise TypeError(
f"Invalid type for operator_terms: Expected dict "
"but got type {type(operator_terms)} instead."
)

return operator_type, operator_terms
111 changes: 94 additions & 17 deletions src/openfermion/utils/operator_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,38 @@
# limitations under the License.
"""Tests for operator_utils."""

import os

import itertools

import marshal
import os
import unittest

import numpy

import sympy

from scipy.sparse import csc_matrix

from openfermion.config import DATA_DIRECTORY
from openfermion.hamiltonians import fermi_hubbard
from openfermion.ops.operators import (
BosonOperator,
FermionOperator,
IsingOperator,
MajoranaOperator,
BosonOperator,
QubitOperator,
QuadOperator,
IsingOperator,
QubitOperator,
)
from openfermion.ops.representations import InteractionOperator
from openfermion.transforms.opconversions import jordan_wigner, bravyi_kitaev
from openfermion.transforms.repconversions import get_interaction_operator
from openfermion.testing.testing_utils import random_interaction_operator
from openfermion.transforms.opconversions import bravyi_kitaev, jordan_wigner
from openfermion.transforms.repconversions import get_interaction_operator
from openfermion.utils.operator_utils import (
OperatorUtilsError,
count_qubits,
get_file_path,
hermitian_conjugated,
is_identity,
save_operator,
OperatorUtilsError,
is_hermitian,
is_identity,
load_operator,
get_file_path,
save_operator,
)


Expand Down Expand Up @@ -586,15 +583,95 @@ def test_overwrite_flag_save_on_top_of_existing_operator(self):

self.assertEqual(fermion_operator, self.fermion_operator)

def test_load_bad_type(self):
with self.assertRaises(TypeError):
_ = load_operator('bad_type_operator')
def test_load_nonexistent_file_raises_error(self):
with self.assertRaises(FileNotFoundError):
load_operator('non_existent_file_for_testing')

def test_save_bad_type(self):
with self.assertRaises(TypeError):
save_operator('ping', 'somewhere')


class LoadOperatorTest(unittest.TestCase):
def setUp(self):
self.file_name = "test_load_operator_file.data"
self.file_path = os.path.join(DATA_DIRECTORY, self.file_name)

def tearDown(self):
if os.path.isfile(self.file_path):
os.remove(self.file_path)

def test_load_plain_text_invalid_format_raises_value_error(self):
with open(self.file_path, 'w') as f:
f.write("some text without the required separator")
with self.assertRaises(ValueError) as cm:
load_operator(self.file_name, plain_text=True)
self.assertIn("expected 'TYPE:\\nTERMS'", str(cm.exception))

def test_load_binary_too_large_raises_value_error(self):
with open(self.file_path, 'wb') as f:
f.write(b'data')

original_getsize = os.path.getsize

def mock_getsize(path):
from openfermion.utils.operator_utils import _MAX_BINARY_OPERATOR_DATA

if path == self.file_path:
return _MAX_BINARY_OPERATOR_DATA + 1
return original_getsize(path) # pragma: no cover

os.path.getsize = mock_getsize
try:
with self.assertRaises(ValueError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("exceeds maximum allowed", str(cm.exception))
finally:
os.path.getsize = original_getsize

def test_load_binary_corrupted_data_raises_value_error(self):
with open(self.file_path, 'wb') as f:
f.write(b"corrupted marshal data")
with self.assertRaises(ValueError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("Failed to load marshaled data", str(cm.exception))

def test_load_binary_invalid_structure_not_tuple_raises_type_error(self):
with open(self.file_path, 'wb') as f:
marshal.dump("this is not a tuple", f)
with self.assertRaises(TypeError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("Expected a tuple", str(cm.exception))

def test_load_binary_invalid_structure_wrong_len_tuple_raises_type_error(self):
with open(self.file_path, 'wb') as f:
marshal.dump(("one", "two", "three"), f)
with self.assertRaises(TypeError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("Expected a tuple of length 2", str(cm.exception))

def test_load_binary_invalid_operator_type_not_string_raises_type_error(self):
with open(self.file_path, 'wb') as f:
marshal.dump((123, {}), f)
with self.assertRaises(TypeError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("Expected str but got type", str(cm.exception))

def test_load_binary_invalid_terms_type_not_dict_raises_type_error(self):
with open(self.file_path, 'wb') as f:
marshal.dump(("FermionOperator", ["a", "list"]), f)
with self.assertRaises(TypeError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("Expected dict but got type", str(cm.exception))

def test_load_binary_unsupported_operator_type_raises_type_error(self):
with open(self.file_path, 'wb') as f:
marshal.dump(("UnsupportedOperator", {}), f)
with self.assertRaises(TypeError) as cm:
load_operator(self.file_name, plain_text=False)
self.assertIn("is not supported", str(cm.exception))


class GetFileDirTest(unittest.TestCase):
def setUp(self):
self.filename = 'foo'
Expand Down
Loading