From 8009878c2368e7f50bc7ff125597992134b97c06 Mon Sep 17 00:00:00 2001
From: Xing Wang
Date: Tue, 3 Dec 2024 10:41:25 +0100
Subject: [PATCH] Add test for parser (#5)

---
 src/aiida_pythonjob/calculations/pythonjob.py |   2 +-
 src/aiida_pythonjob/data/serializer.py        |  10 +-
 src/aiida_pythonjob/launch.py                 |   2 +
 src/aiida_pythonjob/parsers/pythonjob.py      |  68 +++++----
 tests/input.txt                               |   1 -
 tests/test_data.py                            |  17 +++
 tests/test_entry_points.py                    |  61 ++++++++
 tests/test_parser.py                          | 112 +++++++++++++++
 tests/test_parsers.py                         |   0
 tests/test_pythonjob.py                       | 132 ++++++++++++++----
 10 files changed, 333 insertions(+), 72 deletions(-)
 delete mode 100644 tests/input.txt
 create mode 100644 tests/test_entry_points.py
 create mode 100644 tests/test_parser.py
 delete mode 100644 tests/test_parsers.py

diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py
index 8fd12b5..5299efd 100644
--- a/src/aiida_pythonjob/calculations/pythonjob.py
+++ b/src/aiida_pythonjob/calculations/pythonjob.py
@@ -125,7 +125,7 @@ def _build_process_label(self) -> str:
         if "process_label" in self.inputs:
             return self.inputs.process_label.value
         else:
-            data = self.get_function_data()
+            data = self.inputs.function_data.get_dict()
             return f"PythonJob<{data['name']}>"
 
     def on_create(self) -> None:
diff --git a/src/aiida_pythonjob/data/serializer.py b/src/aiida_pythonjob/data/serializer.py
index cb56cf3..beb8e7f 100644
--- a/src/aiida_pythonjob/data/serializer.py
+++ b/src/aiida_pythonjob/data/serializer.py
@@ -33,12 +33,10 @@ def get_serializer_from_entry_points() -> dict:
         eps.setdefault(key, [])
         eps[key].append(ep)
 
-    # print("Time to load entry points: ", time.time() - ts)
     # check if there are duplicates
     for key, value in eps.items():
         if len(value) > 1:
             if key in serializers:
-                [ep for ep in value if ep.name == serializers[key]]
                 eps[key] = [ep for ep in value if ep.name == serializers[key]]
                 if not eps[key]:
                     raise ValueError(f"Entry point {serializers[key]} not found for {key}")
@@ -105,13 +103,7 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
                 new_node.store()
                 return new_node
             except Exception:
-                # try to serialize the value as a PickledData
-                try:
-                    new_node = PickledData(data)
-                    new_node.store()
-                    return new_node
-                except Exception as e:
-                    raise ValueError(f"Error in serializing {ep_key}: {e}")
+                raise ValueError(f"Error in storing data {ep_key}")
         else:
             # try to serialize the data as a PickledData
             try:
diff --git a/src/aiida_pythonjob/launch.py b/src/aiida_pythonjob/launch.py
index a08b568..fee94a7 100644
--- a/src/aiida_pythonjob/launch.py
+++ b/src/aiida_pythonjob/launch.py
@@ -44,6 +44,8 @@ def prepare_pythonjob_inputs(
                 new_upload_files[new_key] = orm.SinglefileData(file=source)
             elif os.path.isdir(source):
                 new_upload_files[new_key] = orm.FolderData(tree=source)
+            else:
+                raise ValueError(f"Invalid upload file path: {source}")
         elif isinstance(source, (orm.SinglefileData, orm.FolderData)):
             new_upload_files[new_key] = source
         else:
diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py
index 2fb659f..7e5fe04 100644
--- a/src/aiida_pythonjob/parsers/pythonjob.py
+++ b/src/aiida_pythonjob/parsers/pythonjob.py
@@ -28,52 +28,58 @@ def parse(self, **kwargs):
         self.output_list = function_outputs
         # first we remove nested outputs, e.g., "add_multiply.add"
         top_level_output_list = [output for output in self.output_list if "." not in output["name"]]
-        exit_code = 0
         try:
             with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
                 results = pickle.load(handle)
                 if isinstance(results, tuple):
                     if len(top_level_output_list) != len(results):
-                        self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
+                        return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
                     for i in range(len(top_level_output_list)):
                         top_level_output_list[i]["value"] = self.serialize_output(results[i], top_level_output_list[i])
-                elif isinstance(results, dict) and len(top_level_output_list) > 1:
+                elif isinstance(results, dict):
                     # pop the exit code if it exists
                     exit_code = results.pop("exit_code", 0)
-                    for output in top_level_output_list:
-                        if output.get("required", False):
+                    if exit_code:
+                        if isinstance(exit_code, dict):
+                            exit_code = ExitCode(exit_code["status"], exit_code["message"])
+                        elif isinstance(exit_code, int):
+                            exit_code = ExitCode(exit_code)
+                        return exit_code
+                    if len(top_level_output_list) == 1:
+                        # if output name in results, use it
+                        if top_level_output_list[0]["name"] in results:
+                            top_level_output_list[0]["value"] = self.serialize_output(
+                                results.pop(top_level_output_list[0]["name"]),
+                                top_level_output_list[0],
+                            )
+                            # if there are any remaining results, raise an warning
+                            if len(results) > 0:
+                                self.logger.warning(
+                                    f"Found extra results that are not included in the output: {results.keys()}"
+                                )
+                        # otherwise, we assume the results is the output
+                        else:
+                            top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
+                    elif len(top_level_output_list) > 1:
+                        for output in top_level_output_list:
                             if output["name"] not in results:
-                                self.exit_codes.ERROR_MISSING_OUTPUT
-                        output["value"] = self.serialize_output(results.pop(output["name"]), output)
-                    # if there are any remaining results, raise an warning
-                    if results:
-                        self.logger.warning(
-                            f"Found extra results that are not included in the output: {results.keys()}"
-                        )
-                elif isinstance(results, dict) and len(top_level_output_list) == 1:
-                    exit_code = results.pop("exit_code", 0)
-                    # if output name in results, use it
-                    if top_level_output_list[0]["name"] in results:
-                        top_level_output_list[0]["value"] = self.serialize_output(
-                            results[top_level_output_list[0]["name"]],
-                            top_level_output_list[0],
-                        )
-                    # otherwise, we assume the results is the output
-                    else:
-                        top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
+                                if output.get("required", True):
+                                    return self.exit_codes.ERROR_MISSING_OUTPUT
+                            else:
+                                output["value"] = self.serialize_output(results.pop(output["name"]), output)
+                        # if there are any remaining results, raise an warning
+                        if len(results) > 0:
+                            self.logger.warning(
+                                f"Found extra results that are not included in the output: {results.keys()}"
+                            )
+
                 elif len(top_level_output_list) == 1:
-                    # otherwise, we assume the results is the output
+                    # otherwise it returns a single value, we assume the results is the output
                     top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
                 else:
-                    raise ValueError("The number of results does not match the number of outputs.")
+                    return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
                 for output in top_level_output_list:
                     self.out(output["name"], output["value"])
-                if exit_code:
-                    if isinstance(exit_code, dict):
-                        exit_code = ExitCode(exit_code["status"], exit_code["message"])
-                    elif isinstance(exit_code, int):
-                        exit_code = ExitCode(exit_code)
-                    return exit_code
         except OSError:
             return self.exit_codes.ERROR_READING_OUTPUT_FILE
         except ValueError as exception:
diff --git a/tests/input.txt b/tests/input.txt
deleted file mode 100644
index d8263ee..0000000
--- a/tests/input.txt
+++ /dev/null
@@ -1 +0,0 @@
-2
\ No newline at end of file
diff --git a/tests/test_data.py b/tests/test_data.py
index 82c8f9c..7cedeae 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,4 +1,5 @@
 import aiida
+from aiida_pythonjob.data import general_serializer
 from aiida_pythonjob.utils import get_required_imports
 
 
@@ -36,6 +37,15 @@ def test_python_job():
     assert isinstance(new_inputs["c"], PickledData)
 
 
+def test_dict_list():
+    from aiida_pythonjob.data.data_with_value import Dict, List
+
+    data = List([1, 2, 3])
+    assert data.value == [1, 2, 3]
+    data = Dict({"a": 1, "b": 2})
+    assert data.value == {"a": 1, "b": 2}
+
+
 def test_atoms_data():
     from aiida_pythonjob.data.atoms import AtomsData
     from ase.build import bulk
@@ -44,3 +54,10 @@
 
     atoms_data = AtomsData(atoms)
     assert atoms_data.value == atoms
+
+
+def test_only_data_with_value():
+    try:
+        general_serializer(aiida.orm.List([1]))
+    except ValueError as e:
+        assert str(e) == "Only AiiDA data Node with a value attribute is allowed."
diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py
new file mode 100644
index 0000000..af451f8
--- /dev/null
+++ b/tests/test_entry_points.py
@@ -0,0 +1,61 @@
+import sys
+from importlib.metadata import EntryPoint
+from unittest.mock import patch
+
+import pytest
+
+
+# Helper function to mock EntryPoint creation
+def create_entry_point(name, value, group):
+    return EntryPoint(name=name, value=value, group=group)
+
+
+def create_mock_entry_points(entry_point_list):
+    if sys.version_info >= (3, 10):
+        # Mock the EntryPoints object for Python 3.10+
+        # Conditional import for EntryPoints
+        from importlib.metadata import EntryPoints
+
+        return EntryPoints(entry_point_list)
+    else:
+        # Return a dictionary for older Python versions
+        return {"aiida.data": entry_point_list}
+
+
+@patch("aiida_pythonjob.data.serializer.load_config")
+@patch("aiida_pythonjob.data.serializer.entry_points")
+def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config):
+    # Mock the configuration
+    mock_load_config.return_value = {
+        "serializers": {
+            "excludes": ["excluded_entry"],
+        }
+    }
+    # Mock entry points
+    mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data")
+    mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data")
+    mock_ep_3 = create_entry_point("xyz.abc.Cde", "xyz.abc:CdeData", "aiida.data")
+    mock_ep_4 = create_entry_point("another_xyz.abc.Cde", "another_xyz.abc:CdeData", "aiida.data")
+
+    mock_entry_points.return_value = create_mock_entry_points([mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4])
+
+    # Import the function and run
+    from aiida_pythonjob.data.serializer import get_serializer_from_entry_points
+
+    with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"):
+        get_serializer_from_entry_points()
+    # Mock the configuration
+    mock_load_config.return_value = {
+        "serializers": {
+            "excludes": ["excluded_entry"],
+            "abc.Cde": "another_xyz.abc.Cde",
+        }
+    }
+    result = get_serializer_from_entry_points()
+    # Assert results
+    expected = {
+        "abc.Abc": [mock_ep_1],
+        "abc.Bcd": [mock_ep_2],
+        "abc.Cde": [mock_ep_4],
+    }
+    assert result == expected
diff --git a/tests/test_parser.py b/tests/test_parser.py
new file mode 100644
index 0000000..1bef763
--- /dev/null
+++ b/tests/test_parser.py
@@ -0,0 +1,112 @@
+import pathlib
+import tempfile
+
+import cloudpickle as pickle
+from aiida import orm
+from aiida.cmdline.utils.common import get_workchain_report
+from aiida.common.links import LinkType
+from aiida_pythonjob.parsers import PythonJobParser
+
+
+def create_retrieved_folder(result: dict, output_filename="results.pickle"):
+    # Create a retrieved ``FolderData`` node with results
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dirpath = pathlib.Path(tmpdir)
+        with open((dirpath / output_filename), "wb") as handle:
+            pickle.dump(result, handle)
+        folder_data = orm.FolderData(tree=dirpath.absolute())
+    return folder_data
+
+
+def create_process_node(result: dict, function_data: dict, output_filename: str = "results.pickle"):
+    node = orm.CalcJobNode()
+    node.set_process_type("aiida.calculations:pythonjob.pythonjob")
+    function_data = orm.Dict(function_data)
+    retrieved = create_retrieved_folder(result, output_filename=output_filename)
+    node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data")
+    retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved")
+    function_data.store()
+    node.store()
+    retrieved.store()
+    return node
+
+
+def create_parser(result, function_data, output_filename="results.pickle"):
+    node = create_process_node(result, function_data, output_filename=output_filename)
+    parser = PythonJobParser(node=node)
+    return parser
+
+
+def test_tuple_result(fixture_localhost):
+    result = (1, 2, 3)
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is None
+    assert len(parser.outputs) == 3
+
+
+def test_tuple_result_mismatch(fixture_localhost):
+    result = (1, 2)
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code == parser.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
+
+
+def test_dict_result(fixture_localhost):
+    result = {"a": 1, "b": 2, "c": 3}
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is None
+    assert len(parser.outputs) == 2
+    report = get_workchain_report(parser.node, levelname="WARNING")
+    assert "Found extra results that are not included in the output: dict_keys(['c'])" in report
+
+
+def test_dict_result_missing(fixture_localhost):
+    result = {"a": 1, "b": 2}
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT
+
+
+def test_dict_result_as_one_output(fixture_localhost):
+    result = {"a": 1, "b": 2, "c": 3}
+    function_data = {"outputs": [{"name": "result"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is None
+    assert len(parser.outputs) == 1
+    assert parser.outputs["result"] == result
+
+
+def test_dict_result_only_show_one_output(fixture_localhost):
+    result = {"a": 1, "b": 2}
+    function_data = {"outputs": [{"name": "a"}]}
+    parser = create_parser(result, function_data)
+    parser.parse()
+    assert len(parser.outputs) == 1
+    assert parser.outputs["a"] == 1
+    report = get_workchain_report(parser.node, levelname="WARNING")
+    assert "Found extra results that are not included in the output: dict_keys(['b'])" in report
+
+
+def test_exit_code(fixture_localhost):
+    result = {"exit_code": {"status": 1, "message": "error"}}
"message": "error"}} + function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + parser = create_parser(result, function_data) + exit_code = parser.parse() + assert exit_code is not None + assert exit_code.status == 1 + assert exit_code.message == "error" + + +def test_no_output_file(fixture_localhost): + result = {"a": 1, "b": 2, "c": 3} + function_data = {"outputs": [{"name": "result"}]} + parser = create_parser(result, function_data, output_filename="not_results.pickle") + exit_code = parser.parse() + assert exit_code == parser.exit_codes.ERROR_READING_OUTPUT_FILE diff --git a/tests/test_parsers.py b/tests/test_parsers.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py index c1f4cb7..130c535 100644 --- a/tests/test_pythonjob.py +++ b/tests/test_pythonjob.py @@ -1,8 +1,28 @@ +import os +import pathlib +import tempfile + import pytest +from aiida import orm from aiida.engine import run_get_node from aiida_pythonjob import PythonJob, prepare_pythonjob_inputs +def test_validate_inputs(): + def add(x, y): + return x + y + + with pytest.raises(ValueError, match="Either function or function_data must be provided"): + prepare_pythonjob_inputs( + function_inputs={"x": 1, "y": 2}, + ) + with pytest.raises(ValueError, match="Only one of function or function_data should be provided"): + prepare_pythonjob_inputs( + function=add, + function_data={"module": "math", "name": "sqrt", "is_pickle": False}, + ) + + def test_function_default_outputs(fixture_localhost): """Test decorator.""" @@ -14,7 +34,6 @@ def add(x, y): function_inputs={"x": 1, "y": 2}, ) result, node = run_get_node(PythonJob, **inputs) - print("result: ", result) assert result["result"].value == 3 assert node.process_label == "PythonJob" @@ -34,10 +53,12 @@ def add(x, y): {"name": "diff"}, ], ) + inputs.pop("process_label") result, node = run_get_node(PythonJob, **inputs) assert result["sum"].value == 3 assert result["diff"].value == -1 + assert node.process_label == "PythonJob" @pytest.mark.skip("Can not inspect the built-in function.") @@ -110,7 +131,7 @@ def myfunc(x, y): assert result["add_multiply"]["multiply"].value == 2 -def test_parent_folder(fixture_localhost): +def test_parent_folder_remote(fixture_localhost): """Test function with parent folder.""" def add(x, y): @@ -142,43 +163,70 @@ def multiply(x, y): assert result2["product"].value == 5 -def test_upload_files(fixture_localhost): - """Test function with upload files.""" +def test_parent_folder_local(fixture_localhost): + """Test function with parent folder.""" - # create a temporary file "input.txt" in the current directory - with open("input.txt", "w") as f: - f.write("2") + with tempfile.TemporaryDirectory() as tmpdir: + dirpath = pathlib.Path(tmpdir) + with open((dirpath / "result.txt"), "w") as f: + f.write("3") + + parent_folder = orm.FolderData(tree=dirpath.absolute()) + + def multiply(x, y): + with open("parent_folder/result.txt", "r") as f: + z = int(f.read()) + return x * y + z - # create a temporary folder "inputs_folder" in the current directory - # and add a file "another_input.txt" in the folder - import os + inputs2 = prepare_pythonjob_inputs( + multiply, + function_inputs={"x": 1, "y": 2}, + function_outputs=[{"name": "product"}], + parent_folder=parent_folder, + ) + result2, node2 = run_get_node(PythonJob, inputs=inputs2) - os.makedirs("inputs_folder", exist_ok=True) - with open("inputs_folder/another_input.txt", "w") as f: - f.write("3") + assert result2["product"].value == 5 
+
+
+def test_upload_files(fixture_localhost):
+    """Test function with upload files."""
 
     def add():
         with open("input.txt", "r") as f:
             a = int(f.read())
-        with open("inputs_folder/another_input.txt", "r") as f:
+        with open("another_input.txt", "r") as f:
             b = int(f.read())
-        return a + b
-
-    # ------------------------- Submit the calculation -------------------
-    # we need use full path to the file
-    input_file = os.path.abspath("input.txt")
-    input_folder = os.path.abspath("inputs_folder")
-    inputs = prepare_pythonjob_inputs(
-        add,
-        upload_files={
-            "input.txt": input_file,
-            "inputs_folder": input_folder,
-        },
-    )
-    result, node = run_get_node(PythonJob, inputs=inputs)
+        with open("inputs_folder/another_input.txt", "r") as f:
+            c = int(f.read())
+        return a + b + c
 
-    # wait=True)
-    assert result["result"].value == 5
+    # create a temporary file "input.txt" in the current directory
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dirpath = pathlib.Path(tmpdir)
+        with open((dirpath / "input.txt"), "w") as f:
+            f.write("2")
+        with open((dirpath / "another_input.txt"), "w") as f:
+            f.write("3")
+        # create a temporary folder "inputs_folder"
+        os.makedirs((dirpath / "inputs_folder"), exist_ok=True)
+        with open((dirpath / "inputs_folder/another_input.txt"), "w") as f:
+            f.write("4")
+        # we need use full path to the file
+        input_file = str(dirpath / "input.txt")
+        input_folder = str(dirpath / "inputs_folder")
+        single_file_data = orm.SinglefileData(file=(dirpath / "another_input.txt"))
+        # ------------------------- Submit the calculation -------------------
+        inputs = prepare_pythonjob_inputs(
+            add,
+            upload_files={
+                "input.txt": input_file,
+                "another_input.txt": single_file_data,
+                "inputs_folder": input_folder,
+            },
+        )
+        result, node = run_get_node(PythonJob, inputs=inputs)
+        assert result["result"].value == 9
 
 
 def test_retrieve_files(fixture_localhost):
@@ -205,6 +253,30 @@ def add(x, y):
     assert "result.txt" in result["retrieved"].list_object_names()
 
 
+def test_copy_files(fixture_localhost):
+    """Test function with copy files."""
+
+    def add(x, y):
+        z = x + y
+        with open("result.txt", "w") as f:
+            f.write(str(z))
+
+    def multiply(x_folder_name, y):
+        with open(f"{x_folder_name}/result.txt", "r") as f:
+            x = int(f.read())
+        return x * y
+
+    inputs = prepare_pythonjob_inputs(add, function_inputs={"x": 1, "y": 2})
+    result, node = run_get_node(PythonJob, inputs=inputs)
+    inputs = prepare_pythonjob_inputs(
+        multiply,
+        function_inputs={"x_folder_name": "x_folder_name", "y": 2},
+        copy_files={"x_folder_name": result["remote_folder"]},
+    )
+    result, node = run_get_node(PythonJob, inputs=inputs)
+    assert result["result"].value == 6
+
+
 def test_exit_code(fixture_localhost):
     """Test function with exit code."""
     from numpy import array