Skip to content

Commit 104281c

Browse files
committed
Add deserializer
1 parent 16c854c commit 104281c

File tree

7 files changed

+112
-46
lines changed

7 files changed

+112
-46
lines changed

docs/gallery/autogen/how_to.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def add(x, y):
349349

350350

351351
######################################################################
352-
# Define your data serializer
352+
# Define your data serializer and deserializer
353353
# --------------------------------------------
354354
#
355355
# PythonJob search data serializer from the `aiida.data` entry point by the
@@ -382,7 +382,24 @@ def add(x, y):
382382
#
383383
# Save the configuration file as `pythonjob.json` in the aiida configuration
384384
# directory (by default, `~/.aiida` directory).
385-
385+
#
386+
# If you want to pass an AiiDA Data node as input and the node does not have a `value` attribute,
387+
# then you must provide a deserializer for it. The deserializer can be set in the configuration file.
388+
#
389+
#
390+
# .. code-block:: json
391+
#
392+
# {
393+
# "serializers": {
394+
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
395+
# },
396+
# "deserializers": {
397+
# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms"
398+
# }
399+
# }
400+
#
401+
# The `orm.List`, `orm.Dict` and `orm.StructureData` data types already have built-in deserializers.
402+
#
386403

387404
######################################################################
388405
# What's Next

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
5252
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
5353
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
5454
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
55-
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
56-
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"
55+
"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List"
56+
"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict"
5757

5858

5959
[project.entry-points."aiida.calculations"]

src/aiida_pythonjob/calculations/pythonjob.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import typing as t
77

88
from aiida.common.datastructures import CalcInfo, CodeInfo
9-
from aiida.common.extendeddicts import AttributeDict
109
from aiida.common.folders import Folder
1110
from aiida.engine import CalcJob, CalcJobProcessSpec
1211
from aiida.orm import (
@@ -190,6 +189,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
190189
import cloudpickle as pickle
191190

192191
from aiida_pythonjob.calculations.utils import generate_script_py
192+
from aiida_pythonjob.data.deserializer import general_deserializer
193193

194194
dirpath = pathlib.Path(folder._abspath)
195195

@@ -279,17 +279,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
279279

280280
# Create a pickle file for the user input values
281281
input_values = {}
282-
for key, value in inputs.items():
283-
if isinstance(value, Data) and hasattr(value, "value"):
284-
input_values[key] = value.value
285-
elif isinstance(value, (AttributeDict, dict)):
286-
# Convert an AttributeDict/dict with .value items
287-
input_values[key] = {k: v.value for k, v in value.items()}
288-
else:
289-
raise ValueError(
290-
f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or "
291-
"AttributeDict/dict-of-Data are allowed."
292-
)
282+
input_values = general_deserializer(inputs)
293283

294284
filename = "inputs.pickle"
295285
with folder.open(filename, "wb") as handle:

src/aiida_pythonjob/data/data_with_value.py

-13
This file was deleted.
+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any
2+
3+
from aiida import common, orm
4+
5+
from aiida_pythonjob.config import load_config
6+
7+
# Deserializers shipped with the package.
# Key: fully qualified class name of the AiiDA data type ("<module>.<class>").
# Value: dotted import path of the function that converts such a node to raw
# Python data; it is imported lazily by `general_deserializer`.
builtin_deserializers = {
    "aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list",
    "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict",
    "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms",
}
12+
13+
14+
def list_data_to_list(data):
    """Convert a list-like AiiDA node to raw Python data.

    Args:
        data: A node exposing ``get_list()``; this function is the built-in
            deserializer registered for ``aiida.orm.nodes.data.list.List``.

    Returns:
        The result of ``data.get_list()``.
    """
    return data.get_list()
16+
17+
18+
def dict_data_to_dict(data):
    """Convert a dict-like AiiDA node to raw Python data.

    Args:
        data: A node exposing ``get_dict()``; this function is the built-in
            deserializer registered for ``aiida.orm.nodes.data.dict.Dict``.

    Returns:
        The result of ``data.get_dict()``.
    """
    return data.get_dict()
20+
21+
22+
def structure_data_to_atoms(structure):
    """Convert a structure node to its ASE representation.

    Args:
        structure: A node exposing ``get_ase()``; this function is the
            built-in deserializer registered for
            ``aiida.orm.nodes.data.structure.StructureData``.

    Returns:
        The result of ``structure.get_ase()``.
    """
    return structure.get_ase()
24+
25+
26+
def get_deserializer() -> dict:
    """Build the mapping from data-type names to deserializer import paths.

    The result starts as a copy of ``builtin_deserializers`` and is then
    updated with the ``"deserializers"`` section of the configuration
    returned by ``load_config()``, so a configured entry replaces a built-in
    entry that has the same key.

    Returns:
        dict: Fully qualified data class name -> dotted path of the
        deserializer function.
    """
    configs = load_config()
    custom_deserializers = configs.get("deserializers", {})
    # Work on a copy so the module-level built-in table is never mutated.
    deserializers = builtin_deserializers.copy()
    deserializers.update(custom_deserializers)
    return deserializers
33+
34+
35+
# Deserializer table used by `general_deserializer`. It is computed once, when
# this module is imported; nothing in this module refreshes it afterwards.
eps_deserializers = get_deserializer()
36+
37+
38+
def deserialize_to_raw_python_data(datas: dict) -> dict:
    """Deserialize a dictionary of AiiDA data into raw Python data.

    Every value is passed through ``general_deserializer``; the keys are kept
    unchanged.

    Args:
        datas (dict): The data to be deserialized, keyed by name.

    Returns:
        dict: A new dictionary with the same keys and the deserialized values.
    """
    return {key: general_deserializer(data) for key, data in datas.items()}
52+
53+
54+
def general_deserializer(data: Any) -> Any:
    """Deserialize an AiiDA data node, or a mapping of nodes, to raw Python data.

    For an ``orm.Data`` node the resolution order is:

    1. If the node has a ``value`` attribute, that value is returned.
    2. Otherwise the node's fully qualified class name
       (``<module>.<class name>``) is looked up in ``eps_deserializers``; the
       dotted path stored there is imported and called with the node.

    ``AttributeDict`` and ``dict`` inputs are converted recursively, value by
    value, into a plain ``dict``.

    Args:
        data: An ``orm.Data`` node, or an ``AttributeDict``/``dict`` whose
            values can themselves be deserialized.

    Returns:
        The raw Python object for a node, or a ``dict`` of raw Python objects
        for a mapping.

    Raises:
        ValueError: If a node has neither a ``value`` attribute nor a
            registered deserializer, or if ``data`` is of an unsupported type.
    """
    import importlib

    if isinstance(data, orm.Data):
        if hasattr(data, "value"):
            return data.value
        data_type = type(data)
        ep_key = f"{data_type.__module__}.{data_type.__name__}"
        if ep_key not in eps_deserializers:
            raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
        # The registry stores "package.module.function"; split off the function
        # name and import the module only when it is actually needed.
        module_name, deserializer_name = eps_deserializers[ep_key].rsplit(".", 1)
        deserializer = getattr(importlib.import_module(module_name), deserializer_name)
        return deserializer(data)
    if isinstance(data, (common.extendeddicts.AttributeDict, dict)):
        # Namespaces are deserialized recursively into a plain dict.
        return {k: general_deserializer(v) for k, v in data.items()}
    # Fail loudly instead of silently returning None, which would otherwise be
    # passed on as if it were the user's input value.
    raise ValueError(
        f"Input data of type {type(data).__name__} is not supported. Only AiiDA Data nodes or "
        "AttributeDict/dict of AiiDA Data nodes are allowed."
    )

src/aiida_pythonjob/data/serializer.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from aiida_pythonjob.config import load_config
88

9+
from .deserializer import eps_deserializers
910
from .pickled_data import PickledData
1011

1112

@@ -46,7 +47,7 @@ def get_serializer_from_entry_points() -> dict:
4647
return eps
4748

4849

49-
eps = get_serializer_from_entry_points()
50+
eps_serializers = get_serializer_from_entry_points()
5051

5152

5253
def serialize_to_aiida_nodes(inputs: dict) -> dict:
@@ -76,7 +77,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
7677
"""Serialize the data to an AiiDA data node."""
7778
if isinstance(data, orm.Data):
7879
if check_value and not hasattr(data, "value"):
79-
raise ValueError("Only AiiDA data Node with a value attribute is allowed.")
80+
data_type = type(data)
81+
ep_key = f"{data_type.__module__}.{data_type.__name__}"
82+
if ep_key not in eps_deserializers:
83+
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
8084
return data
8185
elif isinstance(data, common.extendeddicts.AttributeDict):
8286
# if the data is an AttributeDict, use it directly
@@ -92,9 +96,9 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
9296
data_type = type(data)
9397
ep_key = f"{data_type.__module__}.{data_type.__name__}"
9498
# search for the key in the entry points
95-
if ep_key in eps:
99+
if ep_key in eps_serializers:
96100
try:
97-
new_node = eps[ep_key][0].load()(data)
101+
new_node = eps_serializers[ep_key][0].load()(data)
98102
except Exception as e:
99103
raise ValueError(f"Error in serializing {ep_key}: {e}")
100104
finally:

tests/test_data.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import aiida
2+
import pytest
23

34

45
def test_typing():
@@ -36,15 +37,6 @@ def test_python_job():
3637
assert isinstance(new_inputs["c"], PickledData)
3738

3839

39-
def test_dict_list():
40-
from aiida_pythonjob.data.data_with_value import Dict, List
41-
42-
data = List([1, 2, 3])
43-
assert data.value == [1, 2, 3]
44-
data = Dict({"a": 1, "b": 2})
45-
assert data.value == {"a": 1, "b": 2}
46-
47-
4840
def test_atoms_data():
4941
from aiida_pythonjob.data.atoms import AtomsData
5042
from ase.build import bulk
@@ -58,7 +50,11 @@ def test_atoms_data():
5850
def test_only_data_with_value():
5951
from aiida_pythonjob.data import general_serializer
6052

61-
try:
62-
general_serializer(aiida.orm.List([1]))
63-
except ValueError as e:
64-
assert str(e) == "Only AiiDA data Node with a value attribute is allowed."
53+
# do not raise error because the built-in serializer can handle it
54+
general_serializer(aiida.orm.List([1]))
55+
# Test case: aiida.orm.ArrayData should raise a ValueError
56+
with pytest.raises(
57+
ValueError,
58+
match="AiiDA data: aiida.orm.nodes.data.array.array.ArrayData, does not have a value attribute or deserializer.", # noqa
59+
):
60+
general_serializer(aiida.orm.ArrayData())

0 commit comments

Comments
 (0)