Skip to content

Commit

Permalink
Add deserializer
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Jan 23, 2025
1 parent 16c854c commit b57c36c
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 46 deletions.
21 changes: 19 additions & 2 deletions docs/gallery/autogen/how_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def add(x, y):


######################################################################
# Define your data serializer
# Define your data serializer and deserializer
# --------------------------------------------
#
# PythonJob search data serializer from the `aiida.data` entry point by the
Expand Down Expand Up @@ -382,7 +382,24 @@ def add(x, y):
#
# Save the configuration file as `pythonjob.json` in the aiida configuration
# directory (by default, `~/.aiida` directory).

#
# If you want to pass an AiiDA Data node as input, and the node does not have a `value` attribute,
# then you must provide a deserializer for it. You can set the deserializer in the configuration file.
#
#
# .. code-block:: json
#
# {
# "serializers": {
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
# },
# "deserializers": {
# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms"
# }
# }
#
# The `orm.List`, `orm.Dict` and `orm.StructureData` data types already have built-in deserializers.
#

######################################################################
# What's Next
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"
"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List"
"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict"


[project.entry-points."aiida.calculations"]
Expand Down
14 changes: 2 additions & 12 deletions src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import typing as t

from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.common.extendeddicts import AttributeDict
from aiida.common.folders import Folder
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import (
Expand Down Expand Up @@ -190,6 +189,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
import cloudpickle as pickle

from aiida_pythonjob.calculations.utils import generate_script_py
from aiida_pythonjob.data.deserializer import general_deserializer

dirpath = pathlib.Path(folder._abspath)

Expand Down Expand Up @@ -279,17 +279,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:

# Create a pickle file for the user input values
input_values = {}
for key, value in inputs.items():
if isinstance(value, Data) and hasattr(value, "value"):
input_values[key] = value.value
elif isinstance(value, (AttributeDict, dict)):
# Convert an AttributeDict/dict with .value items
input_values[key] = {k: v.value for k, v in value.items()}
else:
raise ValueError(
f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or "
"AttributeDict/dict-of-Data are allowed."
)
input_values = general_deserializer(inputs)

filename = "inputs.pickle"
with folder.open(filename, "wb") as handle:
Expand Down
13 changes: 0 additions & 13 deletions src/aiida_pythonjob/data/data_with_value.py

This file was deleted.

72 changes: 72 additions & 0 deletions src/aiida_pythonjob/data/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any

from aiida import common, orm

from aiida_pythonjob.config import load_config

# Deserializers shipped with the package.
# Key: fully qualified class path of the AiiDA node ("<module>.<ClassName>").
# Value: dotted path of the function that converts such a node to raw Python data.
builtin_deserializers = {
    "aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list",
    "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict",
    "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms",
}


def list_data_to_list(data):
    """Return the content of a list-like node by calling its ``get_list()`` method."""
    plain_list = data.get_list()
    return plain_list


def dict_data_to_dict(data):
    """Return the content of a dict-like node by calling its ``get_dict()`` method."""
    plain_dict = data.get_dict()
    return plain_dict


def structure_data_to_atoms(structure):
    """Return the ASE representation of a structure node via its ``get_ase()`` method."""
    atoms = structure.get_ase()
    return atoms


def get_deserializer() -> dict:
    """Build the deserializer table from the built-ins and the user configuration.

    The ``"deserializers"`` section of the loaded configuration (if any) is
    applied on top of a copy of ``builtin_deserializers``, so a user entry
    replaces the built-in entry registered under the same key.

    Returns:
        dict: Mapping of node class path to deserializer function path.
    """
    merged = dict(builtin_deserializers)
    merged.update(load_config().get("deserializers", {}))
    return merged


# Deserializer table, built once when this module is imported. It maps a node
# class path ("<module>.<ClassName>") to the dotted path of its deserializer.
eps_deserializers = get_deserializer()


def deserialize_to_raw_python_data(datas: dict) -> dict:
    """Convert every value of a mapping to raw Python data.

    Args:
        datas (dict): Mapping of name to the data that should be deserialized.

    Returns:
        dict: A new mapping with the same keys, where each value has been
        passed through ``general_deserializer``.
    """
    return {name: general_deserializer(node) for name, node in datas.items()}


def general_deserializer(data: Any) -> Any:
    """Deserialize an AiiDA data node, or a mapping of them, to raw Python data.

    Args:
        data: An ``orm.Data`` node, or a ``dict``/``AttributeDict`` whose values
            are themselves deserializable (nested mappings are handled
            recursively).

    Returns:
        For a node with a ``value`` attribute, that value. For any other node,
        the result of the deserializer registered for its class in
        ``eps_deserializers``. For a mapping, a plain ``dict`` with the same
        keys and every value deserialized.

    Raises:
        ValueError: If a node has neither a ``value`` attribute nor a
            registered deserializer, or if ``data`` is neither a node nor a
            mapping.
    """
    import importlib

    if isinstance(data, orm.Data):
        # Read ``value`` once: it may be a property that does real work.
        missing = object()
        value = getattr(data, "value", missing)
        if value is not missing:
            return value

        data_type = type(data)
        ep_key = f"{data_type.__module__}.{data_type.__name__}"
        if ep_key not in eps_deserializers:
            raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")

        # The registry stores a dotted path "<module>.<function>"; import lazily.
        module_name, deserializer_name = eps_deserializers[ep_key].rsplit(".", 1)
        module = importlib.import_module(module_name)
        deserializer = getattr(module, deserializer_name)
        return deserializer(data)

    if isinstance(data, (common.extendeddicts.AttributeDict, dict)):
        # Namespaces arrive as (Attribute)dicts; deserialize each entry.
        return {k: general_deserializer(v) for k, v in data.items()}

    # Previously this fell through and silently returned None, hiding bad input.
    raise ValueError(
        f"Input data of type {type(data)} is not supported. Only AiiDA Data nodes or "
        "AttributeDict/dict-of-Data are allowed."
    )
12 changes: 8 additions & 4 deletions src/aiida_pythonjob/data/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aiida_pythonjob.config import load_config

from .deserializer import eps_deserializers
from .pickled_data import PickledData


Expand Down Expand Up @@ -46,7 +47,7 @@ def get_serializer_from_entry_points() -> dict:
return eps


eps = get_serializer_from_entry_points()
eps_serializers = get_serializer_from_entry_points()


def serialize_to_aiida_nodes(inputs: dict) -> dict:
Expand Down Expand Up @@ -76,7 +77,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
"""Serialize the data to an AiiDA data node."""
if isinstance(data, orm.Data):
if check_value and not hasattr(data, "value"):
raise ValueError("Only AiiDA data Node with a value attribute is allowed.")
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key not in eps_deserializers:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
return data
elif isinstance(data, common.extendeddicts.AttributeDict):
# if the data is an AttributeDict, use it directly
Expand All @@ -92,9 +96,9 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
# search for the key in the entry points
if ep_key in eps:
if ep_key in eps_serializers:
try:
new_node = eps[ep_key][0].load()(data)
new_node = eps_serializers[ep_key][0].load()(data)
except Exception as e:
raise ValueError(f"Error in serializing {ep_key}: {e}")
finally:
Expand Down
22 changes: 9 additions & 13 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import aiida
import pytest


def test_typing():
Expand Down Expand Up @@ -36,15 +37,6 @@ def test_python_job():
assert isinstance(new_inputs["c"], PickledData)


def test_dict_list():
from aiida_pythonjob.data.data_with_value import Dict, List

data = List([1, 2, 3])
assert data.value == [1, 2, 3]
data = Dict({"a": 1, "b": 2})
assert data.value == {"a": 1, "b": 2}


def test_atoms_data():
from aiida_pythonjob.data.atoms import AtomsData
from ase.build import bulk
Expand All @@ -58,7 +50,11 @@ def test_atoms_data():
def test_only_data_with_value():
from aiida_pythonjob.data import general_serializer

try:
general_serializer(aiida.orm.List([1]))
except ValueError as e:
assert str(e) == "Only AiiDA data Node with a value attribute is allowed."
# do not raise error because the built-in serializer can handle it
general_serializer(aiida.orm.List([1]))
# Test case: aiida.orm.ArrayData should raise a ValueError
with pytest.raises(
ValueError,
match="AiiDA data: aiida.orm.nodes.data.array.array.ArrayData, does not have a value attribute or deserializer.", # noqa
):
general_serializer(aiida.orm.ArrayData())

0 comments on commit b57c36c

Please sign in to comment.