Skip to content

Commit

Permalink
add deserializers input
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Jan 23, 2025
1 parent 8442912 commit c59d19e
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 28 deletions.
26 changes: 25 additions & 1 deletion docs/gallery/autogen/how_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,31 @@ def add(x, y):
# directory (by default, `~/.aiida` directory).
#
# If you want to pass an AiiDA Data node as input, and the node does not have a `value` attribute,
# then one must provide a deserializer for it. One can set the deserializer in the configuration file.
# then one must provide a deserializer for it.
#

from aiida import orm # noqa: E402


def make_supercell(structure, n=2):
    """Return ``structure`` multiplied by the three-element list ``[n, n, n]``.

    NOTE(review): the example registers ``structure_data_to_atoms`` as the
    deserializer, so ``structure`` is presumably an ASE ``Atoms``-like object
    for which this multiplication repeats the cell ``n`` times per axis --
    confirm against that deserializer.
    """
    repetitions = [n, n, n]
    return structure * repetitions


# Build a StructureData node: identity cell with a single Li atom at the origin.
structure = orm.StructureData(cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
structure.append_atom(position=(0.0, 0.0, 0.0), symbols="Li")

# Map the node's full class path to the dotted import path of the function
# that should turn it into raw Python data before `make_supercell` is called.
structure_deserializers = {
    "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms"
}

inputs = prepare_pythonjob_inputs(
    make_supercell,
    function_inputs={"structure": structure},
    deserializers=structure_deserializers,
)
result, node = run_get_node(PythonJob, inputs=inputs)
print("result: ", result["result"])

######################################################################
# One can also set the deserializer in the configuration file.
#
#
# .. code-block:: json
Expand Down
18 changes: 16 additions & 2 deletions src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import (
Data,
Dict,
FolderData,
List,
RemoteData,
Expand Down Expand Up @@ -91,6 +92,13 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
serializer=to_aiida_type,
help="Additional filenames to retrieve from the remote work directory",
)
spec.input(
"deserializers",
valid_type=Dict,
default=None,
required=False,
help="The deserializers to convert the input AiiDA data nodes to raw Python data.",
)
spec.outputs.dynamic = True
# set default options (optional)
spec.inputs["metadata"]["options"]["parser_name"].default = "pythonjob.pythonjob"
Expand Down Expand Up @@ -189,7 +197,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
import cloudpickle as pickle

from aiida_pythonjob.calculations.utils import generate_script_py
from aiida_pythonjob.data.deserializer import general_deserializer
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

dirpath = pathlib.Path(folder._abspath)

Expand Down Expand Up @@ -279,7 +287,13 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:

# Create a pickle file for the user input values
input_values = {}
input_values = general_deserializer(inputs)
if "deserializers" in self.inputs and self.inputs.deserializers:
deserializers = self.inputs.deserializers.get_dict()
# replace "__dot__" with "." in the keys
deserializers = {k.replace("__dot__", "."): v for k, v in deserializers.items()}
else:
deserializers = None
input_values = deserialize_to_raw_python_data(inputs, deserializers=deserializers)

filename = "inputs.pickle"
with folder.open(filename, "wb") as handle:
Expand Down
39 changes: 19 additions & 20 deletions src/aiida_pythonjob/data/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any

from aiida import common, orm
Expand All @@ -10,6 +12,14 @@
}


def generate_aiida_node_deserializer(data: orm.Node) -> dict:
    """Convert an AiiDA data node (or a mapping of them) to its stored attributes.

    * An ``orm.Data`` instance yields ``data.backend_entity.attributes``.
    * A ``dict`` / ``AttributeDict`` yields a new dict with the same keys whose
      values have been converted recursively by this function.
    * Any other input falls through and yields ``None``.
    """
    if isinstance(data, orm.Data):
        return data.backend_entity.attributes

    mapping_types = (common.extendeddicts.AttributeDict, dict)
    if not isinstance(data, mapping_types):
        # Neither a data node nor a mapping: nothing to convert.
        return None

    converted = {}
    for key, value in data.items():
        converted[key] = generate_aiida_node_deserializer(value)
    return converted


def list_data_to_list(data):
    """Return whatever ``data.get_list()`` returns."""
    values = data.get_list()
    return values

Expand All @@ -34,38 +44,27 @@ def get_deserializer() -> dict:
eps_deserializers = get_deserializer()


def deserialize_to_raw_python_data(datas: dict) -> dict:
    """Deserialize every value of ``datas`` with ``general_deserializer``.

    Args:
        datas (dict): The datas to be deserialized, keyed by name.

    Returns:
        dict: A new dictionary with the same keys and the deserialized values.
    """
    return {name: general_deserializer(node) for name, node in datas.items()}


def general_deserializer(data: Any) -> orm.Node:
def deserialize_to_raw_python_data(data: orm.Node, deserializers: dict | None = None) -> Any:
"""Deserialize the AiiDA data node to raw Python data."""
import importlib

all_deserializers = eps_deserializers.copy()

if deserializers is not None:
all_deserializers.update(deserializers)

if isinstance(data, orm.Data):
if hasattr(data, "value"):
return getattr(data, "value")
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key in eps_deserializers:
module_name, deserializer_name = eps_deserializers[ep_key].rsplit(".", 1)
if ep_key in all_deserializers:
module_name, deserializer_name = all_deserializers[ep_key].rsplit(".", 1)
module = importlib.import_module(module_name)
deserializer = getattr(module, deserializer_name)
return deserializer(data)
else:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
elif isinstance(data, (common.extendeddicts.AttributeDict, dict)):
# if the data is an AttributeDict, use it directly
return {k: general_deserializer(v) for k, v in data.items()}
return {k: deserialize_to_raw_python_data(v, deserializers=deserializers) for k, v in data.items()}
14 changes: 10 additions & 4 deletions src/aiida_pythonjob/data/serializer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import sys
from importlib.metadata import entry_points
from typing import Any
Expand Down Expand Up @@ -50,7 +52,7 @@ def get_serializer_from_entry_points() -> dict:
eps_serializers = get_serializer_from_entry_points()


def serialize_to_aiida_nodes(inputs: dict) -> dict:
def serialize_to_aiida_nodes(inputs: dict, deserializers: dict | None = None) -> dict:
"""Serialize the inputs to a dictionary of AiiDA data nodes.
Args:
Expand All @@ -62,7 +64,7 @@ def serialize_to_aiida_nodes(inputs: dict) -> dict:
new_inputs = {}
# save all kwargs to inputs port
for key, data in inputs.items():
new_inputs[key] = general_serializer(data)
new_inputs[key] = general_serializer(data, deserializers=deserializers)
return new_inputs


Expand All @@ -73,13 +75,17 @@ def clean_dict_key(data):
return data


def general_serializer(data: Any, check_value=True) -> orm.Node:
def general_serializer(data: Any, check_value=True, deserializers: dict | None = None) -> orm.Node:
"""Serialize the data to an AiiDA data node."""
all_deserializers = eps_deserializers.copy()
if deserializers is not None:
all_deserializers.update(deserializers)

if isinstance(data, orm.Data):
if check_value and not hasattr(data, "value"):
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key not in eps_deserializers:
if ep_key not in all_deserializers:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
return data
elif isinstance(data, common.extendeddicts.AttributeDict):
Expand Down
7 changes: 6 additions & 1 deletion src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def prepare_pythonjob_inputs(
upload_files: Dict[str, str] = {},
process_label: Optional[str] = None,
function_data: dict | None = None,
deserializers: dict | None = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob"""
Expand Down Expand Up @@ -55,14 +56,18 @@ def prepare_pythonjob_inputs(
code = get_or_create_code(computer=computer, **command_info)
# serialize the kwargs into AiiDA Data
function_inputs = function_inputs or {}
function_inputs = serialize_to_aiida_nodes(function_inputs)
function_inputs = serialize_to_aiida_nodes(function_inputs, deserializers=deserializers)
function_data["outputs"] = function_outputs or [{"name": "result"}]
# replace "." with "__dot__" in the keys of a dictionary
if deserializers:
deserializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in deserializers.items()})
inputs = {
"function_data": function_data,
"code": code,
"function_inputs": function_inputs,
"upload_files": new_upload_files,
"metadata": metadata or {},
"deserializers": deserializers,
**kwargs,
}
if process_label:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,18 @@ def test_only_data_with_value():
match="AiiDA data: aiida.orm.nodes.data.array.array.ArrayData, does not have a value attribute or deserializer.", # noqa
):
general_serializer(aiida.orm.ArrayData())


def test_deserializer():
    """A deserializer passed explicitly is used for a node type without ``value``.

    ``ArrayData`` is mapped to ``generate_aiida_node_deserializer``; the result
    is expected to be the dictionary ``{"array|data": [3]}``.
    """
    import numpy as np
    from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

    # Full class path of the node type -> dotted import path of its deserializer.
    custom_deserializers = {
        "aiida.orm.nodes.data.array.array.ArrayData": "aiida_pythonjob.data.deserializer.generate_aiida_node_deserializer"  # noqa
    }

    array_node = aiida.orm.ArrayData()
    array_node.set_array("data", np.array([1, 2, 3]))

    raw = deserialize_to_raw_python_data(array_node, deserializers=custom_deserializers)
    assert raw == {"array|data": [3]}

0 comments on commit c59d19e

Please sign in to comment.