diff --git a/docs/gallery/autogen/how_to.py b/docs/gallery/autogen/how_to.py index abd1539..67885eb 100644 --- a/docs/gallery/autogen/how_to.py +++ b/docs/gallery/autogen/how_to.py @@ -349,7 +349,7 @@ def add(x, y): ###################################################################### -# Define your data serializer +# Define your data serializer and deserializer # -------------- # # PythonJob search data serializer from the `aiida.data` entry point by the @@ -376,13 +376,54 @@ def add(x, y): # # { # "serializers": { -# "ase.atoms.Atoms": "abc.ase.atoms.Atoms" +# "ase.atoms.Atoms": "abc.ase.atoms.AtomsData" # use the full path to the serializer # } # } # # Save the configuration file as `pythonjob.json` in the aiida configuration # directory (by default, `~/.aiida` directory). +# +# If you want to pass AiiDA Data node as input, and the node does not have a `value` attribute, +# then one must provide a deserializer for it. +# + +from aiida import orm # noqa: E402 + + +def make_supercell(structure, n=2): + return structure * [n, n, n] + + +structure = orm.StructureData(cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]) +structure.append_atom(position=(0.0, 0.0, 0.0), symbols="Li") + +inputs = prepare_pythonjob_inputs( + make_supercell, + function_inputs={"structure": structure}, + deserializers={ + "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms" + }, +) +result, node = run_get_node(PythonJob, inputs=inputs) +print("result: ", result["result"]) +###################################################################### +# One can also set the deserializer in the configuration file. +# +# +# .. code-block:: json +# +# { +# "serializers": { +# "ase.atoms.Atoms": "abc.ase.atoms.Atoms" +# }, +# "deserializers": { +# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_pymatgen" # noqa +# } +# } +# +# The `orm.List`, `orm.Dict`and `orm.StructureData` data types already have built-in deserializers. +# ###################################################################### # What's Next diff --git a/pyproject.toml b/pyproject.toml index 06c49c9..c2c4693 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob" "pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float" "pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str" "pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool" -"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List" -"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict" +"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List" +"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict" [project.entry-points."aiida.calculations"] diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py index 8b50bb3..1be7069 100644 --- a/src/aiida_pythonjob/calculations/pythonjob.py +++ b/src/aiida_pythonjob/calculations/pythonjob.py @@ -6,11 +6,11 @@ import typing as t from aiida.common.datastructures import CalcInfo, CodeInfo -from aiida.common.extendeddicts import AttributeDict from aiida.common.folders import Folder from aiida.engine import CalcJob, CalcJobProcessSpec from aiida.orm import ( Data, + Dict, FolderData, List, RemoteData, @@ -92,6 +92,22 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] serializer=to_aiida_type, help="Additional filenames to retrieve from the remote work directory", ) + spec.input( + "deserializers", + valid_type=Dict, + default=None, + required=False, + serializer=to_aiida_type, + help="The deserializers to convert the input AiiDA data nodes to raw Python data.", + ) + spec.input( + "serializers", + valid_type=Dict, + default=None, + required=False, + serializer=to_aiida_type, + help="The serializers to convert the raw Python data to AiiDA data nodes.", + ) spec.outputs.dynamic = True # set default options (optional) spec.inputs["metadata"]["options"]["parser_name"].default = "pythonjob.pythonjob" @@ -190,6 +206,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: import cloudpickle as pickle from aiida_pythonjob.calculations.utils import generate_script_py + from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data dirpath = pathlib.Path(folder._abspath) @@ -279,17 +296,13 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: # Create a pickle file for the user input values input_values = {} - for key, value in inputs.items(): - if isinstance(value, Data) and hasattr(value, "value"): - input_values[key] = value.value - elif isinstance(value, (AttributeDict, dict)): - # Convert an AttributeDict/dict with .value items - input_values[key] = {k: v.value for k, v in value.items()} - else: - raise ValueError( - f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or " - "AttributeDict/dict-of-Data are allowed." - ) + if "deserializers" in self.inputs and self.inputs.deserializers: + deserializers = self.inputs.deserializers.get_dict() + # replace "__dot__" with "." in the keys + deserializers = {k.replace("__dot__", "."): v for k, v in deserializers.items()} + else: + deserializers = None + input_values = deserialize_to_raw_python_data(inputs, deserializers=deserializers) filename = "inputs.pickle" with folder.open(filename, "wb") as handle: diff --git a/src/aiida_pythonjob/data/data_with_value.py b/src/aiida_pythonjob/data/data_with_value.py deleted file mode 100644 index 469b810..0000000 --- a/src/aiida_pythonjob/data/data_with_value.py +++ /dev/null @@ -1,13 +0,0 @@ -from aiida import orm - - -class Dict(orm.Dict): - @property - def value(self): - return self.get_dict() - - -class List(orm.List): - @property - def value(self): - return self.get_list() diff --git a/src/aiida_pythonjob/data/deserializer.py b/src/aiida_pythonjob/data/deserializer.py new file mode 100644 index 0000000..ee93bce --- /dev/null +++ b/src/aiida_pythonjob/data/deserializer.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any + +from aiida import common, orm + +from aiida_pythonjob.config import load_config +from aiida_pythonjob.utils import import_from_path + +builtin_deserializers = { + "aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list", + "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict", + "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms", +} + + +def generate_aiida_node_deserializer(data: orm.Node) -> dict: + if isinstance(data, orm.Data): + return data.backend_entity.attributes + elif isinstance(data, (common.extendeddicts.AttributeDict, dict)): + # if the data is an AttributeDict, use it directly + return {k: generate_aiida_node_deserializer(v) for k, v in data.items()} + + +def list_data_to_list(data): + return data.get_list() + + +def dict_data_to_dict(data): + return data.get_dict() + + +def structure_data_to_atoms(structure): + return structure.get_ase() + + +def structure_data_to_pymatgen(structure): + return structure.get_pymatgen() + + +def get_deserializer() -> dict: + """Retrieve the serializer from the entry points.""" + configs = load_config() + custom_deserializers = configs.get("deserializers", {}) + deserializers = builtin_deserializers.copy() + deserializers.update(custom_deserializers) + return deserializers + + +all_deserializers = get_deserializer() + + +def deserialize_to_raw_python_data(data: orm.Node, deserializers: dict | None = None) -> Any: + """Deserialize the AiiDA data node to an raw Python data.""" + + updated_deserializers = all_deserializers.copy() + + if deserializers is not None: + updated_deserializers.update(deserializers) + + if isinstance(data, orm.Data): + if hasattr(data, "value"): + return getattr(data, "value") + data_type = type(data) + ep_key = f"{data_type.__module__}.{data_type.__name__}" + if ep_key in updated_deserializers: + deserializer = import_from_path(updated_deserializers[ep_key]) + return deserializer(data) + else: + raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.") + elif isinstance(data, (common.extendeddicts.AttributeDict, dict)): + # if the data is an AttributeDict, use it directly + return {k: deserialize_to_raw_python_data(v, deserializers=deserializers) for k, v in data.items()} diff --git a/src/aiida_pythonjob/data/serializer.py b/src/aiida_pythonjob/data/serializer.py index beb8e7f..fb7a6be 100644 --- a/src/aiida_pythonjob/data/serializer.py +++ b/src/aiida_pythonjob/data/serializer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from importlib.metadata import entry_points from typing import Any @@ -5,18 +7,17 @@ from aiida import common, orm from aiida_pythonjob.config import load_config +from aiida_pythonjob.utils import import_from_path +from .deserializer import all_deserializers from .pickled_data import PickledData -def get_serializer_from_entry_points() -> dict: - """Retrieve the serializer from the entry points.""" - # import time +def atoms_to_structure_data(structure): + return orm.StructureData(ase=structure) - # ts = time.time() - configs = load_config() - serializers = configs.get("serializers", {}) - excludes = serializers.get("excludes", []) + +def get_serializers_from_entry_points() -> dict: # Retrieve the entry points for 'aiida.data' and store them in a dictionary eps = entry_points() if sys.version_info >= (3, 10): @@ -28,28 +29,39 @@ def get_serializer_from_entry_points() -> dict: # split the entry point name by first ".", and check the last part key = ep.name.split(".", 1)[-1] # skip key without "." because it is not a module name for a data type - if "." not in key or key in excludes: + if "." not in key: continue eps.setdefault(key, []) - eps[key].append(ep) + # get the path of the entry point value and replace ":" with "." + eps[key].append(ep.value.replace(":", ".")) + return eps + +def get_serializers() -> dict: + """Retrieve the serializer from the entry points.""" + # import time + + # ts = time.time() + all_serializers = {} + configs = load_config() + custom_serializers = configs.get("serializers", {}) + eps = get_serializers_from_entry_points() # check if there are duplicates for key, value in eps.items(): if len(value) > 1: - if key in serializers: - eps[key] = [ep for ep in value if ep.name == serializers[key]] - if not eps[key]: - raise ValueError(f"Entry point {serializers[key]} not found for {key}") - else: - msg = f"Duplicate entry points for {key}: {[ep.name for ep in value]}" + if key not in custom_serializers: + msg = f"Duplicate entry points for {key}: {value}. You can specify the one to use in the configuration file." # noqa raise ValueError(msg) - return eps + all_serializers[key] = value[0] + all_serializers.update(custom_serializers) + # print("Time to get serializer", time.time() - ts) + return all_serializers -eps = get_serializer_from_entry_points() +all_serializers = get_serializers() -def serialize_to_aiida_nodes(inputs: dict) -> dict: +def serialize_to_aiida_nodes(inputs: dict, serializers: dict | None = None, deserializers: dict | None = None) -> dict: """Serialize the inputs to a dictionary of AiiDA data nodes. Args: @@ -61,7 +73,7 @@ def serialize_to_aiida_nodes(inputs: dict) -> dict: new_inputs = {} # save all kwargs to inputs port for key, data in inputs.items(): - new_inputs[key] = general_serializer(data) + new_inputs[key] = general_serializer(data, serializers=serializers, deserializers=deserializers) return new_inputs @@ -72,11 +84,24 @@ def clean_dict_key(data): return data -def general_serializer(data: Any, check_value=True) -> orm.Node: +def general_serializer( + data: Any, serializers: dict | None = None, deserializers: dict | None = None, check_value=True +) -> orm.Node: """Serialize the data to an AiiDA data node.""" + updated_deserializers = all_deserializers.copy() + if deserializers is not None: + updated_deserializers.update(deserializers) + + updated_serializers = all_serializers.copy() + if serializers is not None: + updated_serializers.update(serializers) + if isinstance(data, orm.Data): if check_value and not hasattr(data, "value"): - raise ValueError("Only AiiDA data Node with a value attribute is allowed.") + data_type = type(data) + ep_key = f"{data_type.__module__}.{data_type.__name__}" + if ep_key not in updated_deserializers: + raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.") return data elif isinstance(data, common.extendeddicts.AttributeDict): # if the data is an AttributeDict, use it directly @@ -92,9 +117,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node: data_type = type(data) ep_key = f"{data_type.__module__}.{data_type.__name__}" # search for the key in the entry points - if ep_key in eps: + if ep_key in updated_serializers: try: - new_node = eps[ep_key][0].load()(data) + serializer = import_from_path(updated_serializers[ep_key]) + new_node = serializer(data) except Exception as e: raise ValueError(f"Error in serializing {ep_key}: {e}") finally: diff --git a/src/aiida_pythonjob/launch.py b/src/aiida_pythonjob/launch.py index 0224117..41b486b 100644 --- a/src/aiida_pythonjob/launch.py +++ b/src/aiida_pythonjob/launch.py @@ -20,6 +20,8 @@ def prepare_pythonjob_inputs( upload_files: Dict[str, str] = {}, process_label: Optional[str] = None, function_data: dict | None = None, + deserializers: dict | None = None, + serializers: dict | None = None, **kwargs: Any, ) -> Dict[str, Any]: """Prepare the inputs for PythonJob""" @@ -55,14 +57,21 @@ def prepare_pythonjob_inputs( code = get_or_create_code(computer=computer, **command_info) # serialize the kwargs into AiiDA Data function_inputs = function_inputs or {} - function_inputs = serialize_to_aiida_nodes(function_inputs) + function_inputs = serialize_to_aiida_nodes(function_inputs, serializers=serializers, deserializers=deserializers) function_data["outputs"] = function_outputs or [{"name": "result"}] + # replace "." with "__dot__" in the keys of a dictionary + if deserializers: + deserializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in deserializers.items()}) + if serializers: + serializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in serializers.items()}) inputs = { "function_data": function_data, "code": code, "function_inputs": function_inputs, "upload_files": new_upload_files, "metadata": metadata or {}, + "deserializers": deserializers, + "serializers": serializers, **kwargs, } if process_label: diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py index 6222dfe..5810342 100644 --- a/src/aiida_pythonjob/parsers/pythonjob.py +++ b/src/aiida_pythonjob/parsers/pythonjob.py @@ -28,6 +28,14 @@ def parse(self, **kwargs): function_outputs = [{"name": "result"}] self.output_list = function_outputs + # load custom serializers + if "serializers" in self.node.inputs and self.node.inputs.serializers: + serializers = self.node.inputs.serializers.get_dict() + # replace "__dot__" with "." in the keys + self.serializers = {k.replace("__dot__", "."): v for k, v in serializers.items()} + else: + self.serializers = None + # If nested outputs like "add_multiply.add", keep only top-level top_level_output_list = [output for output in self.output_list if "." not in output["name"]] @@ -144,10 +152,10 @@ def serialize_output(self, result, output): if full_name_output and full_name_output.get("identifier", "Any").upper() == "NAMESPACE": serialized_result[key] = self.serialize_output(value, full_name_output) else: - serialized_result[key] = general_serializer(value) + serialized_result[key] = general_serializer(value, serializers=self.serializers) return serialized_result else: self.logger.error(f"Expected a dict for namespace '{name}', got {type(result)}.") return self.exit_codes.ERROR_INVALID_OUTPUT else: - return general_serializer(result) + return general_serializer(result, serializers=self.serializers) diff --git a/src/aiida_pythonjob/utils.py b/src/aiida_pythonjob/utils.py index 5cdaffd..df32ae6 100644 --- a/src/aiida_pythonjob/utils.py +++ b/src/aiida_pythonjob/utils.py @@ -5,6 +5,17 @@ from aiida.orm import Computer, InstalledCode, User, load_code, load_computer +def import_from_path(path: str) -> Any: + import importlib + + module_name, object_name = path.rsplit(".", 1) + module = importlib.import_module(module_name) + try: + return getattr(module, object_name) + except AttributeError: + raise AttributeError(f"{object_name} not found in module {module_name}.") + + def get_required_imports(func: Callable) -> Dict[str, set]: """Retrieve type hints and the corresponding modules.""" type_hints = get_type_hints(func) diff --git a/tests/test_data.py b/tests/test_data.py index 9b62e71..465784c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,5 @@ import aiida +import pytest def test_typing(): @@ -36,15 +37,6 @@ def test_python_job(): assert isinstance(new_inputs["c"], PickledData) -def test_dict_list(): - from aiida_pythonjob.data.data_with_value import Dict, List - - data = List([1, 2, 3]) - assert data.value == [1, 2, 3] - data = Dict({"a": 1, "b": 2}) - assert data.value == {"a": 1, "b": 2} - - def test_atoms_data(): from aiida_pythonjob.data.atoms import AtomsData from ase.build import bulk @@ -58,7 +50,26 @@ def test_atoms_data(): def test_only_data_with_value(): from aiida_pythonjob.data import general_serializer - try: - general_serializer(aiida.orm.List([1])) - except ValueError as e: - assert str(e) == "Only AiiDA data Node with a value attribute is allowed." + # do not raise error because the built-in serializer can handle it + general_serializer(aiida.orm.List([1])) + # Test case: aiida.orm.ArrayData should raise a ValueError + with pytest.raises( + ValueError, + match="AiiDA data: aiida.orm.nodes.data.array.array.ArrayData, does not have a value attribute or deserializer.", # noqa + ): + general_serializer(aiida.orm.ArrayData()) + + +def test_deserializer(): + import numpy as np + from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data + + data = aiida.orm.ArrayData() + data.set_array("data", np.array([1, 2, 3])) + data = deserialize_to_raw_python_data( + data, + deserializers={ + "aiida.orm.nodes.data.array.array.ArrayData": "aiida_pythonjob.data.deserializer.generate_aiida_node_deserializer" # noqa + }, + ) + assert data == {"array|data": [3]} diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py index af451f8..a2a4578 100644 --- a/tests/test_entry_points.py +++ b/tests/test_entry_points.py @@ -24,13 +24,9 @@ def create_mock_entry_points(entry_point_list): @patch("aiida_pythonjob.data.serializer.load_config") @patch("aiida_pythonjob.data.serializer.entry_points") -def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config): +def test_get_serializers(mock_entry_points, mock_load_config): # Mock the configuration - mock_load_config.return_value = { - "serializers": { - "excludes": ["excluded_entry"], - } - } + mock_load_config.return_value = {"serializers": {}} # Mock entry points mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data") mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data") @@ -40,22 +36,18 @@ def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config): mock_entry_points.return_value = create_mock_entry_points([mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4]) # Import the function and run - from aiida_pythonjob.data.serializer import get_serializer_from_entry_points + from aiida_pythonjob.data.serializer import get_serializers with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"): - get_serializer_from_entry_points() + get_serializers() # Mock the configuration mock_load_config.return_value = { "serializers": { - "excludes": ["excluded_entry"], - "abc.Cde": "another_xyz.abc.Cde", + "abc.Cde": "another_xyz.abc.CdeData", } } - result = get_serializer_from_entry_points() + result = get_serializers() # Assert results - expected = { - "abc.Abc": [mock_ep_1], - "abc.Bcd": [mock_ep_2], - "abc.Cde": [mock_ep_4], - } + expected = {"abc.Abc": "xyz.abc.AbcData", "abc.Bcd": "xyz.abc.BcdData", "abc.Cde": "another_xyz.abc.CdeData"} + print("result", result) assert result == expected