Skip to content

Commit

Permalink
add test for parser
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 3, 2024
1 parent e059d36 commit 21fad7e
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _build_process_label(self) -> str:
if "process_label" in self.inputs:
return self.inputs.process_label.value
else:
data = self.get_function_data()
data = self.inputs.function_data.get_dict()
return f"PythonJob<{data['name']}>"

def on_create(self) -> None:
Expand Down
66 changes: 35 additions & 31 deletions src/aiida_pythonjob/parsers/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,57 +28,61 @@ def parse(self, **kwargs):
self.output_list = function_outputs
# first we remove nested outputs, e.g., "add_multiply.add"
top_level_output_list = [output for output in self.output_list if "." not in output["name"]]
exit_code = 0
try:
with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
results = pickle.load(handle)
if isinstance(results, tuple):
if len(top_level_output_list) != len(results):
self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
for i in range(len(top_level_output_list)):
top_level_output_list[i]["value"] = self.serialize_output(results[i], top_level_output_list[i])
elif isinstance(results, dict) and len(top_level_output_list) > 1:
elif isinstance(results, dict):
# pop the exit code if it exists
exit_code = results.pop("exit_code", 0)
for output in top_level_output_list:
if output.get("required", False):
if exit_code:
if isinstance(exit_code, dict):
exit_code = ExitCode(exit_code["status"], exit_code["message"])
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code

if len(top_level_output_list) > 1:
for output in top_level_output_list:
if output["name"] not in results:
self.exit_codes.ERROR_MISSING_OUTPUT
output["value"] = self.serialize_output(results.pop(output["name"]), output)
# if there are any remaining results, raise an warning
if results:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
elif isinstance(results, dict) and len(top_level_output_list) == 1:
exit_code = results.pop("exit_code", 0)
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results[top_level_output_list[0]["name"]],
top_level_output_list[0],
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
if output.get("required", True):
return self.exit_codes.ERROR_MISSING_OUTPUT
else:
output["value"] = self.serialize_output(results.pop(output["name"]), output)
# if there are any remaining results, raise an warning
if results:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
elif len(top_level_output_list) == 1:
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results[top_level_output_list[0]["name"]],
top_level_output_list[0],
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
elif len(top_level_output_list) == 1:
# otherwise, we assume the results is the output
# otherwise it returns a single value, we assume the results is the output
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
else:
raise ValueError("The number of results does not match the number of outputs.")
return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
for output in top_level_output_list:
self.out(output["name"], output["value"])
if exit_code:
if isinstance(exit_code, dict):
exit_code = ExitCode(exit_code["status"], exit_code["message"])
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code
except OSError:
return self.exit_codes.ERROR_READING_OUTPUT_FILE
except ValueError as exception:
self.logger.error(exception)
return self.exit_codes.ERROR_INVALID_OUTPUT
except Exception as exception:
self.logger.error(exception)
return self.exit_codes.ERROR_INVALID_OUTPUT

def find_output(self, name):
"""Find the output with the given name."""
Expand Down
80 changes: 80 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pathlib
import tempfile

import cloudpickle as pickle
from aiida import orm
from aiida.common.links import LinkType
from aiida_pythonjob.parsers import PythonJobParser


def create_retrieved_folder(result: dict):
# Create a retrieved ``FolderData`` node with results
with tempfile.TemporaryDirectory() as tmpdir:
dirpath = pathlib.Path(tmpdir)
with open((dirpath / "results.pickle"), "wb") as handle:
pickle.dump(result, handle)
folder_data = orm.FolderData(tree=dirpath.absolute())
return folder_data


def create_process_node(result: dict, function_data: dict):
node = orm.CalcJobNode()
node.set_process_type("aiida.calculations:pythonjob.pythonjob")
function_data = orm.Dict(function_data)
retrieved = create_retrieved_folder(result)
node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data")
retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved")
function_data.store()
node.store()
retrieved.store()
return node


def create_parser(result, function_data):
node = create_process_node(result, function_data)
parser = PythonJobParser(node=node)
return parser


def test_tuple_result(fixture_localhost):
result = (1, 2, 3)
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 3


def test_tuple_result_mismatch(fixture_localhost):
result = (1, 2)
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH


def test_dict_result(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 3


def test_dict_result_missing(fixture_localhost):
result = {"a": 1, "b": 2}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT


def test_exit_code(fixture_localhost):
result = {"exit_code": {"status": 1, "message": "error"}}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is not None
assert exit_code.status == 1
assert exit_code.message == "error"
Empty file removed tests/test_parsers.py
Empty file.

0 comments on commit 21fad7e

Please sign in to comment.