From b57c36c2b01c290d0838cc7b70c3e0d7e05f31d6 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Thu, 23 Jan 2025 02:39:15 +0100 Subject: [PATCH] Add deserializer --- docs/gallery/autogen/how_to.py | 21 +++++- pyproject.toml | 4 +- src/aiida_pythonjob/calculations/pythonjob.py | 14 +--- src/aiida_pythonjob/data/data_with_value.py | 13 ---- src/aiida_pythonjob/data/deserializer.py | 72 +++++++++++++++++++ src/aiida_pythonjob/data/serializer.py | 12 ++-- tests/test_data.py | 22 +++--- 7 files changed, 112 insertions(+), 46 deletions(-) delete mode 100644 src/aiida_pythonjob/data/data_with_value.py create mode 100644 src/aiida_pythonjob/data/deserializer.py diff --git a/docs/gallery/autogen/how_to.py b/docs/gallery/autogen/how_to.py index abd1539..afe77e8 100644 --- a/docs/gallery/autogen/how_to.py +++ b/docs/gallery/autogen/how_to.py @@ -349,7 +349,7 @@ def add(x, y): ###################################################################### -# Define your data serializer +# Define your data serializer and deserializer # -------------- # # PythonJob search data serializer from the `aiida.data` entry point by the @@ -382,7 +382,24 @@ def add(x, y): # # Save the configuration file as `pythonjob.json` in the aiida configuration # directory (by default, `~/.aiida` directory). - +# +# If you want to pass AiiDA Data node as input, and the node does not has a `value` attribute +# then one mush provide a deserializer for it. One can set the deserializer in the configuration file. +# +# +# .. code-block:: json +# +# { +# "serializers": { +# "ase.atoms.Atoms": "abc.ase.atoms.Atoms" +# }, +# "deserializers": { +# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms" +# } +# } +# +# The `orm.List`, `orm.Dict` and `orm.StructureData` data types already have built-in deserializers. +# ###################################################################### # What's Next diff --git a/pyproject.toml b/pyproject.toml index 06c49c9..c2c4693 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob" "pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float" "pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str" "pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool" -"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List" -"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict" +"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List" +"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict" [project.entry-points."aiida.calculations"] diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py index 8b50bb3..99897cf 100644 --- a/src/aiida_pythonjob/calculations/pythonjob.py +++ b/src/aiida_pythonjob/calculations/pythonjob.py @@ -6,7 +6,6 @@ import typing as t from aiida.common.datastructures import CalcInfo, CodeInfo -from aiida.common.extendeddicts import AttributeDict from aiida.common.folders import Folder from aiida.engine import CalcJob, CalcJobProcessSpec from aiida.orm import ( @@ -190,6 +189,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: import cloudpickle as pickle from aiida_pythonjob.calculations.utils import generate_script_py + from aiida_pythonjob.data.deserializer import general_deserializer dirpath = pathlib.Path(folder._abspath) @@ -279,17 +279,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: # Create a pickle file for the user input values input_values = {} - for key, value in inputs.items(): - if isinstance(value, Data) and hasattr(value, "value"): - input_values[key] = value.value - elif isinstance(value, (AttributeDict, dict)): - # Convert an AttributeDict/dict with .value items - input_values[key] = {k: v.value for k, v in value.items()} - else: - raise ValueError( - f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or " - "AttributeDict/dict-of-Data are allowed." - ) + input_values = general_deserializer(inputs) filename = "inputs.pickle" with folder.open(filename, "wb") as handle: diff --git a/src/aiida_pythonjob/data/data_with_value.py b/src/aiida_pythonjob/data/data_with_value.py deleted file mode 100644 index 469b810..0000000 --- a/src/aiida_pythonjob/data/data_with_value.py +++ /dev/null @@ -1,13 +0,0 @@ -from aiida import orm - - -class Dict(orm.Dict): - @property - def value(self): - return self.get_dict() - - -class List(orm.List): - @property - def value(self): - return self.get_list() diff --git a/src/aiida_pythonjob/data/deserializer.py b/src/aiida_pythonjob/data/deserializer.py new file mode 100644 index 0000000..4acfc29 --- /dev/null +++ b/src/aiida_pythonjob/data/deserializer.py @@ -0,0 +1,72 @@ +from typing import Any + +from aiida import common, orm + +from aiida_pythonjob.config import load_config + +builtin_deserializers = { + "aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list", + "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict", + "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms", +} + + +def list_data_to_list(data): + return data.get_list() + + +def dict_data_to_dict(data): + return data.get_dict() + + +def structure_data_to_atoms(structure): + return structure.get_ase() + + +def get_deserializer() -> dict: + """Retrieve the serializer from the entry points.""" + configs = load_config() + custom_deserializers = configs.get("deserializers", {}) + deserializers = builtin_deserializers.copy() + deserializers.update(custom_deserializers) + return deserializers + + +eps_deserializers = get_deserializer() + + +def deserialize_to_raw_python_data(datas: dict) -> dict: + """Deserialize the datas to a dictionary of raw Python data. + + Args: + datas (dict): The datas to be deserialized. + + Returns: + dict: The deserialized datas. + """ + new_datas = {} + # save all kwargs to inputs port + for key, data in datas.items(): + new_datas[key] = general_deserializer(data) + return new_datas + + +def general_deserializer(data: Any) -> orm.Node: + """Deserialize the AiiDA data node to an raw Python data.""" + import importlib + + if isinstance(data, orm.Data): + if hasattr(data, "value"): + return getattr(data, "value") + data_type = type(data) + ep_key = f"{data_type.__module__}.{data_type.__name__}" + if ep_key in eps_deserializers: + module_name, deserializer_name = eps_deserializers[ep_key].rsplit(".", 1) + module = importlib.import_module(module_name) + deserializer = getattr(module, deserializer_name) + return deserializer(data) + else: + raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.") + elif isinstance(data, (common.extendeddicts.AttributeDict, dict)): + # if the data is an AttributeDict, use it directly + return {k: general_deserializer(v) for k, v in data.items()} diff --git a/src/aiida_pythonjob/data/serializer.py b/src/aiida_pythonjob/data/serializer.py index beb8e7f..f5ea204 100644 --- a/src/aiida_pythonjob/data/serializer.py +++ b/src/aiida_pythonjob/data/serializer.py @@ -6,6 +6,7 @@ from aiida_pythonjob.config import load_config +from .deserializer import eps_deserializers from .pickled_data import PickledData @@ -46,7 +47,7 @@ def get_serializer_from_entry_points() -> dict: return eps -eps = get_serializer_from_entry_points() +eps_serializers = get_serializer_from_entry_points() def serialize_to_aiida_nodes(inputs: dict) -> dict: @@ -76,7 +77,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node: """Serialize the data to an AiiDA data node.""" if isinstance(data, orm.Data): if check_value and not hasattr(data, "value"): - raise ValueError("Only AiiDA data Node with a value attribute is allowed.") + data_type = type(data) + ep_key = f"{data_type.__module__}.{data_type.__name__}" + if ep_key not in eps_deserializers: + raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.") return data elif isinstance(data, common.extendeddicts.AttributeDict): # if the data is an AttributeDict, use it directly @@ -92,9 +96,9 @@ def general_serializer(data: Any, check_value=True) -> orm.Node: data_type = type(data) ep_key = f"{data_type.__module__}.{data_type.__name__}" # search for the key in the entry points - if ep_key in eps: + if ep_key in eps_serializers: try: - new_node = eps[ep_key][0].load()(data) + new_node = eps_serializers[ep_key][0].load()(data) except Exception as e: raise ValueError(f"Error in serializing {ep_key}: {e}") finally: diff --git a/tests/test_data.py b/tests/test_data.py index 9b62e71..7d4b795 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,5 @@ import aiida +import pytest def test_typing(): @@ -36,15 +37,6 @@ def test_python_job(): assert isinstance(new_inputs["c"], PickledData) -def test_dict_list(): - from aiida_pythonjob.data.data_with_value import Dict, List - - data = List([1, 2, 3]) - assert data.value == [1, 2, 3] - data = Dict({"a": 1, "b": 2}) - assert data.value == {"a": 1, "b": 2} - - def test_atoms_data(): from aiida_pythonjob.data.atoms import AtomsData from ase.build import bulk @@ -58,7 +50,11 @@ def test_atoms_data(): def test_only_data_with_value(): from aiida_pythonjob.data import general_serializer - try: - general_serializer(aiida.orm.List([1])) - except ValueError as e: - assert str(e) == "Only AiiDA data Node with a value attribute is allowed." + # do not raise error because the built-in serializer can handle it + general_serializer(aiida.orm.List([1])) + # Test case: aiida.orm.ArrayData should raise a ValueError + with pytest.raises( + ValueError, + match="AiiDA data: aiida.orm.nodes.data.array.array.ArrayData, does not have a value attribute or deserializer.", # noqa + ): + general_serializer(aiida.orm.ArrayData())