Add test for parser (#5)
superstar54 authored Dec 3, 2024
1 parent e059d36 commit 8009878
Showing 10 changed files with 333 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/aiida_pythonjob/calculations/pythonjob.py
@@ -125,7 +125,7 @@ def _build_process_label(self) -> str:
if "process_label" in self.inputs:
return self.inputs.process_label.value
else:
data = self.get_function_data()
data = self.inputs.function_data.get_dict()
return f"PythonJob<{data['name']}>"

def on_create(self) -> None:
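For context, the process label is now read directly from the function_data input Dict instead of going through a removed helper. A minimal sketch of the resulting label, assuming function_data carries at least a "name" key (the dict below is illustrative, not part of the commit):

    # hypothetical function_data contents
    function_data = {"name": "add", "outputs": [{"name": "sum"}]}
    process_label = f"PythonJob<{function_data['name']}>"
    assert process_label == "PythonJob<add>"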
10 changes: 1 addition & 9 deletions src/aiida_pythonjob/data/serializer.py
@@ -33,12 +33,10 @@ def get_serializer_from_entry_points() -> dict:
eps.setdefault(key, [])
eps[key].append(ep)

# print("Time to load entry points: ", time.time() - ts)
# check if there are duplicates
for key, value in eps.items():
if len(value) > 1:
if key in serializers:
[ep for ep in value if ep.name == serializers[key]]
eps[key] = [ep for ep in value if ep.name == serializers[key]]
if not eps[key]:
raise ValueError(f"Entry point {serializers[key]} not found for {key}")
@@ -105,13 +103,7 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
new_node.store()
return new_node
except Exception:
# try to serialize the value as a PickledData
try:
new_node = PickledData(data)
new_node.store()
return new_node
except Exception as e:
raise ValueError(f"Error in serializing {ep_key}: {e}")
raise ValueError(f"Error in storing data {ep_key}")
else:
# try to serialize the data as a PickledData
try:
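The assignment added above makes the duplicate check effective: the filtered list is now stored back into eps[key] instead of being computed and discarded. Which duplicate wins is driven by the "serializers" section of the configuration, as exercised by tests/test_entry_points.py further down. A sketch of the expected dict shape, taken from that test (the config file location and loader are assumptions):

    # illustrative shape of the dict returned by load_config()
    config = {
        "serializers": {
            "excludes": ["excluded_entry"],       # from the mocked config; presumably entry points to skip
            "abc.Cde": "another_xyz.abc.Cde",     # resolves the abc.Cde duplicate in favour of another_xyz.abc.Cde
        }
    }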
2 changes: 2 additions & 0 deletions src/aiida_pythonjob/launch.py
@@ -44,6 +44,8 @@ def prepare_pythonjob_inputs(
new_upload_files[new_key] = orm.SinglefileData(file=source)
elif os.path.isdir(source):
new_upload_files[new_key] = orm.FolderData(tree=source)
else:
raise ValueError(f"Invalid upload file path: {source}")
elif isinstance(source, (orm.SinglefileData, orm.FolderData)):
new_upload_files[new_key] = source
else:
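With the new else branch, a string entry in upload_files must point to an existing file or directory; anything else now fails early with a ValueError rather than being passed through silently. A usage sketch, assuming prepare_pythonjob_inputs takes the function followed by an upload_files mapping (my_function, the paths, and the omitted parameters are placeholders):

    # illustrative call, not part of the commit
    inputs = prepare_pythonjob_inputs(
        my_function,
        upload_files={
            "input_txt": "/path/to/input.txt",   # existing file -> wrapped as orm.SinglefileData
            "data_dir": "/path/to/data_dir",     # existing directory -> wrapped as orm.FolderData
            # "missing": "/no/such/path",        # would raise ValueError("Invalid upload file path: ...")
        },
    )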
68 changes: 37 additions & 31 deletions src/aiida_pythonjob/parsers/pythonjob.py
@@ -28,52 +28,58 @@ def parse(self, **kwargs):
self.output_list = function_outputs
# first we remove nested outputs, e.g., "add_multiply.add"
top_level_output_list = [output for output in self.output_list if "." not in output["name"]]
exit_code = 0
try:
with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
results = pickle.load(handle)
if isinstance(results, tuple):
if len(top_level_output_list) != len(results):
self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
for i in range(len(top_level_output_list)):
top_level_output_list[i]["value"] = self.serialize_output(results[i], top_level_output_list[i])
elif isinstance(results, dict) and len(top_level_output_list) > 1:
elif isinstance(results, dict):
# pop the exit code if it exists
exit_code = results.pop("exit_code", 0)
for output in top_level_output_list:
if output.get("required", False):
if exit_code:
if isinstance(exit_code, dict):
exit_code = ExitCode(exit_code["status"], exit_code["message"])
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code
if len(top_level_output_list) == 1:
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results.pop(top_level_output_list[0]["name"]),
top_level_output_list[0],
)
# if there are any remaining results, raise a warning
if len(results) > 0:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
elif len(top_level_output_list) > 1:
for output in top_level_output_list:
if output["name"] not in results:
self.exit_codes.ERROR_MISSING_OUTPUT
output["value"] = self.serialize_output(results.pop(output["name"]), output)
# if there are any remaining results, raise a warning
if results:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
elif isinstance(results, dict) and len(top_level_output_list) == 1:
exit_code = results.pop("exit_code", 0)
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results[top_level_output_list[0]["name"]],
top_level_output_list[0],
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
if output.get("required", True):
return self.exit_codes.ERROR_MISSING_OUTPUT
else:
output["value"] = self.serialize_output(results.pop(output["name"]), output)
# if there are any remaining results, raise a warning
if len(results) > 0:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)

elif len(top_level_output_list) == 1:
# otherwise, we assume the results is the output
# otherwise it returns a single value, we assume the results is the output
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
else:
raise ValueError("The number of results does not match the number of outputs.")
return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
for output in top_level_output_list:
self.out(output["name"], output["value"])
if exit_code:
if isinstance(exit_code, dict):
exit_code = ExitCode(exit_code["status"], exit_code["message"])
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code
except OSError:
return self.exit_codes.ERROR_READING_OUTPUT_FILE
except ValueError as exception:
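The reworked dict branch first pops an optional "exit_code" entry and, if it is truthy, converts it to an ExitCode (from a {"status", "message"} dict or a bare int) and returns it before any outputs are attached; with one declared output the dict is matched by name or taken as a whole, and with several outputs a missing required name gives ERROR_MISSING_OUTPUT while leftover keys only produce a warning. A sketch of the exit-code convention from the user function's side (the function itself is hypothetical):

    # hypothetical user function run by the PythonJob
    def divide(x, y):
        if y == 0:
            # dict form; the parser also accepts a plain integer status
            return {"exit_code": {"status": 1, "message": "division by zero"}}
        return {"quotient": x / y}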
1 change: 0 additions & 1 deletion tests/input.txt

This file was deleted.

17 changes: 17 additions & 0 deletions tests/test_data.py
@@ -1,4 +1,5 @@
import aiida
from aiida_pythonjob.data import general_serializer
from aiida_pythonjob.utils import get_required_imports


@@ -36,6 +37,15 @@ def test_python_job():
assert isinstance(new_inputs["c"], PickledData)


def test_dict_list():
from aiida_pythonjob.data.data_with_value import Dict, List

data = List([1, 2, 3])
assert data.value == [1, 2, 3]
data = Dict({"a": 1, "b": 2})
assert data.value == {"a": 1, "b": 2}


def test_atoms_data():
from aiida_pythonjob.data.atoms import AtomsData
from ase.build import bulk
@@ -44,3 +54,10 @@ def test_atoms_data():

atoms_data = AtomsData(atoms)
assert atoms_data.value == atoms


def test_only_data_with_value():
try:
general_serializer(aiida.orm.List([1]))
except ValueError as e:
assert str(e) == "Only AiiDA data Node with a value attribute is allowed."
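Together, test_dict_list and test_only_data_with_value pin down the contract of general_serializer for pre-built AiiDA nodes: only nodes exposing a value attribute are accepted, which the Dict and List wrappers in aiida_pythonjob.data.data_with_value provide and the plain aiida.orm.List does not. A short illustrative contrast:

    from aiida_pythonjob.data.data_with_value import List

    data = List([1, 2, 3])
    data.value                    # [1, 2, 3]; nodes like this pass the value-attribute check
    # aiida.orm.List([1, 2, 3])   # plain orm.List has no value attribute; general_serializer rejects it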
61 changes: 61 additions & 0 deletions tests/test_entry_points.py
@@ -0,0 +1,61 @@
import sys
from importlib.metadata import EntryPoint
from unittest.mock import patch

import pytest


# Helper function to mock EntryPoint creation
def create_entry_point(name, value, group):
return EntryPoint(name=name, value=value, group=group)


def create_mock_entry_points(entry_point_list):
if sys.version_info >= (3, 10):
# Mock the EntryPoints object for Python 3.10+
# Conditional import for EntryPoints
from importlib.metadata import EntryPoints

return EntryPoints(entry_point_list)
else:
# Return a dictionary for older Python versions
return {"aiida.data": entry_point_list}


@patch("aiida_pythonjob.data.serializer.load_config")
@patch("aiida_pythonjob.data.serializer.entry_points")
def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config):
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
}
}
# Mock entry points
mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data")
mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data")
mock_ep_3 = create_entry_point("xyz.abc.Cde", "xyz.abc:CdeData", "aiida.data")
mock_ep_4 = create_entry_point("another_xyz.abc.Cde", "another_xyz.abc:CdeData", "aiida.data")

mock_entry_points.return_value = create_mock_entry_points([mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4])

# Import the function and run
from aiida_pythonjob.data.serializer import get_serializer_from_entry_points

with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"):
get_serializer_from_entry_points()
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
"abc.Cde": "another_xyz.abc.Cde",
}
}
result = get_serializer_from_entry_points()
# Assert results
expected = {
"abc.Abc": [mock_ep_1],
"abc.Bcd": [mock_ep_2],
"abc.Cde": [mock_ep_4],
}
assert result == expected
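The duplicate scenario relies on the naming convention for lookup keys: an aiida.data entry point named "xyz.abc.Abc" is registered under "abc.Abc", i.e. the leading distribution prefix is dropped, which is why xyz.abc.Cde and another_xyz.abc.Cde collide until the configuration picks one. A sketch of the assumed key derivation (the real logic lives in get_serializer_from_entry_points):

    # illustrative only: drop the first dotted component of the entry-point name
    ep_name = "another_xyz.abc.Cde"
    key = ep_name.split(".", 1)[1]
    assert key == "abc.Cde"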
112 changes: 112 additions & 0 deletions tests/test_parser.py
@@ -0,0 +1,112 @@
import pathlib
import tempfile

import cloudpickle as pickle
from aiida import orm
from aiida.cmdline.utils.common import get_workchain_report
from aiida.common.links import LinkType
from aiida_pythonjob.parsers import PythonJobParser


def create_retrieved_folder(result: dict, output_filename="results.pickle"):
# Create a retrieved ``FolderData`` node with results
with tempfile.TemporaryDirectory() as tmpdir:
dirpath = pathlib.Path(tmpdir)
with open((dirpath / output_filename), "wb") as handle:
pickle.dump(result, handle)
folder_data = orm.FolderData(tree=dirpath.absolute())
return folder_data


def create_process_node(result: dict, function_data: dict, output_filename: str = "results.pickle"):
node = orm.CalcJobNode()
node.set_process_type("aiida.calculations:pythonjob.pythonjob")
function_data = orm.Dict(function_data)
retrieved = create_retrieved_folder(result, output_filename=output_filename)
node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data")
retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved")
function_data.store()
node.store()
retrieved.store()
return node


def create_parser(result, function_data, output_filename="results.pickle"):
node = create_process_node(result, function_data, output_filename=output_filename)
parser = PythonJobParser(node=node)
return parser


def test_tuple_result(fixture_localhost):
result = (1, 2, 3)
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 3


def test_tuple_result_mismatch(fixture_localhost):
result = (1, 2)
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH


def test_dict_result(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 2
report = get_workchain_report(parser.node, levelname="WARNING")
assert "Found extra results that are not included in the output: dict_keys(['c'])" in report


def test_dict_result_missing(fixture_localhost):
result = {"a": 1, "b": 2}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT


def test_dict_result_as_one_output(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "result"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 1
assert parser.outputs["result"] == result


def test_dict_result_only_show_one_output(fixture_localhost):
result = {"a": 1, "b": 2}
function_data = {"outputs": [{"name": "a"}]}
parser = create_parser(result, function_data)
parser.parse()
assert len(parser.outputs) == 1
assert parser.outputs["a"] == 1
report = get_workchain_report(parser.node, levelname="WARNING")
assert "Found extra results that are not included in the output: dict_keys(['b'])" in report


def test_exit_code(fixture_localhost):
result = {"exit_code": {"status": 1, "message": "error"}}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is not None
assert exit_code.status == 1
assert exit_code.message == "error"


def test_no_output_file(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "result"}]}
parser = create_parser(result, function_data, output_filename="not_results.pickle")
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_READING_OUTPUT_FILE
Empty file removed tests/test_parsers.py
Empty file.