
Add custom de-serializers and serializers (#15)
This PR allows users to set custom deserializers and serializers, either as direct inputs or via the `pythonjob.json` configuration file.
superstar54 authored Jan 23, 2025
1 parent 16c854c commit 083f592
Showing 11 changed files with 255 additions and 84 deletions.
45 changes: 43 additions & 2 deletions docs/gallery/autogen/how_to.py
@@ -349,7 +349,7 @@ def add(x, y):


######################################################################
-# Define your data serializer
+# Define your data serializer and deserializer
# --------------
#
# PythonJob searches for a data serializer from the `aiida.data` entry point by the
@@ -376,13 +376,54 @@ def add(x, y):
#
# {
# "serializers": {
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
# "ase.atoms.Atoms": "abc.ase.atoms.AtomsData" # use the full path to the serializer
# }
# }
#
# Save the configuration file as `pythonjob.json` in the AiiDA configuration
# directory (by default, the `~/.aiida` directory).
#
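For reference, the value in this mapping must be an importable path to a callable that accepts the raw Python object and returns an AiiDA data node. A minimal sketch of a hypothetical `abc/ase/atoms.py` matching the config above (mirroring the `atoms_to_structure_data` helper added in this PR):

from aiida import orm


def AtomsData(atoms):
    # Hypothetical serializer: wrap a raw ase.Atoms object in an AiiDA node.
    return orm.StructureData(ase=atoms)
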
# If you want to pass an AiiDA Data node as input and the node does not have a `value` attribute,
# then you must provide a deserializer for it.
#

from aiida import orm # noqa: E402


def make_supercell(structure, n=2):
    return structure * [n, n, n]


structure = orm.StructureData(cell=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
structure.append_atom(position=(0.0, 0.0, 0.0), symbols="Li")

inputs = prepare_pythonjob_inputs(
    make_supercell,
    function_inputs={"structure": structure},
    deserializers={
        "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms"
    },
)
result, node = run_get_node(PythonJob, inputs=inputs)
print("result: ", result["result"])

######################################################################
# One can also set the deserializer in the configuration file.
#
#
# .. code-block:: json
#
# {
# "serializers": {
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
# },
# "deserializers": {
# "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_pymatgen" # noqa
# }
# }
#
# The `orm.List`, `orm.Dict` and `orm.StructureData` data types already have built-in deserializers.
#
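As a quick illustration of the built-ins, a sketch reusing the setup from the examples above (the same configured localhost code is assumed):

def merge(x, y):
    return x + y


inputs = prepare_pythonjob_inputs(
    merge,
    function_inputs={"x": orm.List([1, 2]), "y": orm.List([3, 4])},
)
result, node = run_get_node(PythonJob, inputs=inputs)
# Inside `merge`, x and y arrive as plain Python lists via the built-in
# deserializer, so the job returns [1, 2, 3, 4] (re-wrapped on output).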

######################################################################
# What's Next
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -52,8 +52,8 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"
"pythonjob.builtins.list"="aiida.orm.nodes.data.list:List"
"pythonjob.builtins.dict"="aiida.orm.nodes.data.dict:Dict"


[project.entry-points."aiida.calculations"]
37 changes: 25 additions & 12 deletions src/aiida_pythonjob/calculations/pythonjob.py
@@ -6,11 +6,11 @@
import typing as t

from aiida.common.datastructures import CalcInfo, CodeInfo
-from aiida.common.extendeddicts import AttributeDict
from aiida.common.folders import Folder
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import (
Data,
+Dict,
FolderData,
List,
RemoteData,
@@ -92,6 +92,22 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
serializer=to_aiida_type,
help="Additional filenames to retrieve from the remote work directory",
)
spec.input(
"deserializers",
valid_type=Dict,
default=None,
required=False,
serializer=to_aiida_type,
help="The deserializers to convert the input AiiDA data nodes to raw Python data.",
)
spec.input(
"serializers",
valid_type=Dict,
default=None,
required=False,
serializer=to_aiida_type,
help="The serializers to convert the raw Python data to AiiDA data nodes.",
)
spec.outputs.dynamic = True
# set default options (optional)
spec.inputs["metadata"]["options"]["parser_name"].default = "pythonjob.pythonjob"
@@ -190,6 +206,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
import cloudpickle as pickle

from aiida_pythonjob.calculations.utils import generate_script_py
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

dirpath = pathlib.Path(folder._abspath)

@@ -279,17 +296,13 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:

# Create a pickle file for the user input values
input_values = {}
-for key, value in inputs.items():
-    if isinstance(value, Data) and hasattr(value, "value"):
-        input_values[key] = value.value
-    elif isinstance(value, (AttributeDict, dict)):
-        # Convert an AttributeDict/dict with .value items
-        input_values[key] = {k: v.value for k, v in value.items()}
-    else:
-        raise ValueError(
-            f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or "
-            "AttributeDict/dict-of-Data are allowed."
-        )
+if "deserializers" in self.inputs and self.inputs.deserializers:
+    deserializers = self.inputs.deserializers.get_dict()
+    # replace "__dot__" with "." in the keys
+    deserializers = {k.replace("__dot__", "."): v for k, v in deserializers.items()}
+else:
+    deserializers = None
+input_values = deserialize_to_raw_python_data(inputs, deserializers=deserializers)

filename = "inputs.pickle"
with folder.open(filename, "wb") as handle:
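The `__dot__` escaping round-trips with `launch.py` below: mapping keys are full class paths containing dots, which are awkward as `orm.Dict` keys (dots double as AiiDA's nested-attribute separator), so they are escaped on the way in and restored here. A sketch (the deserializer path is hypothetical):

deserializers = {"aiida.orm.nodes.data.structure.StructureData": "my.module.func"}
escaped = {k.replace(".", "__dot__"): v for k, v in deserializers.items()}  # done in launch.py
restored = {k.replace("__dot__", "."): v for k, v in escaped.items()}  # done here
assert restored == deserializers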
13 changes: 0 additions & 13 deletions src/aiida_pythonjob/data/data_with_value.py

This file was deleted.

73 changes: 73 additions & 0 deletions src/aiida_pythonjob/data/deserializer.py
@@ -0,0 +1,73 @@
from __future__ import annotations

from typing import Any

from aiida import common, orm

from aiida_pythonjob.config import load_config
from aiida_pythonjob.utils import import_from_path

builtin_deserializers = {
    "aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list",
    "aiida.orm.nodes.data.dict.Dict": "aiida_pythonjob.data.deserializer.dict_data_to_dict",
    "aiida.orm.nodes.data.structure.StructureData": "aiida_pythonjob.data.deserializer.structure_data_to_atoms",
}


def generate_aiida_node_deserializer(data: orm.Node) -> dict:
    if isinstance(data, orm.Data):
        return data.backend_entity.attributes
    elif isinstance(data, (common.extendeddicts.AttributeDict, dict)):
        # recurse into dict-like containers (including AttributeDict)
        return {k: generate_aiida_node_deserializer(v) for k, v in data.items()}


def list_data_to_list(data):
    return data.get_list()


def dict_data_to_dict(data):
    return data.get_dict()


def structure_data_to_atoms(structure):
    return structure.get_ase()


def structure_data_to_pymatgen(structure):
    return structure.get_pymatgen()


def get_deserializer() -> dict:
    """Retrieve the deserializers: the built-ins, merged with any custom ones from the configuration file."""
    configs = load_config()
    custom_deserializers = configs.get("deserializers", {})
    deserializers = builtin_deserializers.copy()
    deserializers.update(custom_deserializers)
    return deserializers


all_deserializers = get_deserializer()


def deserialize_to_raw_python_data(data: orm.Node, deserializers: dict | None = None) -> Any:
    """Deserialize an AiiDA data node to raw Python data."""

    updated_deserializers = all_deserializers.copy()

    if deserializers is not None:
        updated_deserializers.update(deserializers)

    if isinstance(data, orm.Data):
        if hasattr(data, "value"):
            return getattr(data, "value")
        data_type = type(data)
        ep_key = f"{data_type.__module__}.{data_type.__name__}"
        if ep_key in updated_deserializers:
            deserializer = import_from_path(updated_deserializers[ep_key])
            return deserializer(data)
        else:
            raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
    elif isinstance(data, (common.extendeddicts.AttributeDict, dict)):
        # recurse into dict-like containers (including AttributeDict)
        return {k: deserialize_to_raw_python_data(v, deserializers=deserializers) for k, v in data.items()}
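A quick sketch of how this module behaves (values illustrative):

from aiida import orm

from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

raw = deserialize_to_raw_python_data({"x": orm.Int(1), "y": orm.List([1, 2])})
# -> {"x": 1, "y": [1, 2]}: Int exposes .value directly, while List is routed
# through the built-in list_data_to_list registered in all_deserializers.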
72 changes: 49 additions & 23 deletions src/aiida_pythonjob/data/serializer.py
@@ -1,22 +1,23 @@
from __future__ import annotations

import sys
from importlib.metadata import entry_points
from typing import Any

from aiida import common, orm

from aiida_pythonjob.config import load_config
from aiida_pythonjob.utils import import_from_path

from .deserializer import all_deserializers
from .pickled_data import PickledData


-def get_serializer_from_entry_points() -> dict:
-    """Retrieve the serializer from the entry points."""
-    # import time
+def atoms_to_structure_data(structure):
+    return orm.StructureData(ase=structure)

-    # ts = time.time()
-    configs = load_config()
-    serializers = configs.get("serializers", {})
-    excludes = serializers.get("excludes", [])

+def get_serializers_from_entry_points() -> dict:
    # Retrieve the entry points for 'aiida.data' and store them in a dictionary
eps = entry_points()
if sys.version_info >= (3, 10):
@@ -28,28 +29,39 @@ def get_serializer_from_entry_points() -> dict:
# split the entry point name by first ".", and check the last part
key = ep.name.split(".", 1)[-1]
# skip key without "." because it is not a module name for a data type
if "." not in key or key in excludes:
if "." not in key:
continue
eps.setdefault(key, [])
-        eps[key].append(ep)
+        # get the path of the entry point value and replace ":" with "."
+        eps[key].append(ep.value.replace(":", "."))
return eps


+def get_serializers() -> dict:
+    """Retrieve the serializer from the entry points."""
+    # import time
+
+    # ts = time.time()
+    all_serializers = {}
+    configs = load_config()
+    custom_serializers = configs.get("serializers", {})
+    eps = get_serializers_from_entry_points()
    # check if there are duplicates
    for key, value in eps.items():
        if len(value) > 1:
-            if key in serializers:
-                eps[key] = [ep for ep in value if ep.name == serializers[key]]
-                if not eps[key]:
-                    raise ValueError(f"Entry point {serializers[key]} not found for {key}")
-            else:
-                msg = f"Duplicate entry points for {key}: {[ep.name for ep in value]}"
+            if key not in custom_serializers:
+                msg = f"Duplicate entry points for {key}: {value}. You can specify the one to use in the configuration file."  # noqa
                raise ValueError(msg)
-    return eps
+        all_serializers[key] = value[0]
+    all_serializers.update(custom_serializers)
+    # print("Time to get serializer", time.time() - ts)
+    return all_serializers


-eps = get_serializer_from_entry_points()
+all_serializers = get_serializers()
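For example, if two installed plugins were both to register an `aiida.data` entry point for the same type, the duplicate check above raises unless the configuration file picks a winner. A sketch with hypothetical paths:

eps = {"ase.atoms.Atoms": ["abc.ase.atoms.AtomsData", "xyz.ase.atoms.AtomsData"]}
custom_serializers = {"ase.atoms.Atoms": "abc.ase.atoms.AtomsData"}  # from pythonjob.json
# The key is present in custom_serializers, so no ValueError is raised and
# the configured value wins after all_serializers.update(custom_serializers).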


-def serialize_to_aiida_nodes(inputs: dict) -> dict:
+def serialize_to_aiida_nodes(inputs: dict, serializers: dict | None = None, deserializers: dict | None = None) -> dict:
"""Serialize the inputs to a dictionary of AiiDA data nodes.
Args:
@@ -61,7 +73,7 @@ def serialize_to_aiida_nodes(inputs: dict) -> dict:
new_inputs = {}
# save all kwargs to inputs port
for key, data in inputs.items():
-    new_inputs[key] = general_serializer(data)
+    new_inputs[key] = general_serializer(data, serializers=serializers, deserializers=deserializers)
return new_inputs


@@ -72,11 +84,24 @@ def clean_dict_key(data):
return data


-def general_serializer(data: Any, check_value=True) -> orm.Node:
+def general_serializer(
+    data: Any, serializers: dict | None = None, deserializers: dict | None = None, check_value=True
+) -> orm.Node:
"""Serialize the data to an AiiDA data node."""
updated_deserializers = all_deserializers.copy()
if deserializers is not None:
updated_deserializers.update(deserializers)

updated_serializers = all_serializers.copy()
if serializers is not None:
updated_serializers.update(serializers)

if isinstance(data, orm.Data):
if check_value and not hasattr(data, "value"):
raise ValueError("Only AiiDA data Node with a value attribute is allowed.")
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key not in updated_deserializers:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
return data
elif isinstance(data, common.extendeddicts.AttributeDict):
# if the data is an AttributeDict, use it directly
@@ -92,9 +117,10 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
# search for the key in the entry points
-    if ep_key in eps:
+    if ep_key in updated_serializers:
        try:
-            new_node = eps[ep_key][0].load()(data)
+            serializer = import_from_path(updated_serializers[ep_key])
+            new_node = serializer(data)
except Exception as e:
raise ValueError(f"Error in serializing {ep_key}: {e}")
finally:
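To make the resolution concrete, a sketch of the path an unfamiliar raw object takes (the numpy value is illustrative):

import numpy as np

data = np.int64(5)
ep_key = f"{type(data).__module__}.{type(data).__name__}"  # "numpy.int64"
# If ep_key appears in updated_serializers, that callable is imported via
# import_from_path and wraps the value in an AiiDA node. The PickledData
# import at the top of the module suggests a pickle fallback for keys with
# no serializer, though that branch lies outside the hunk shown here.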
11 changes: 10 additions & 1 deletion src/aiida_pythonjob/launch.py
@@ -20,6 +20,8 @@ def prepare_pythonjob_inputs(
upload_files: Dict[str, str] = {},
process_label: Optional[str] = None,
function_data: dict | None = None,
deserializers: dict | None = None,
serializers: dict | None = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob"""
@@ -55,14 +57,21 @@
code = get_or_create_code(computer=computer, **command_info)
# serialize the kwargs into AiiDA Data
function_inputs = function_inputs or {}
-    function_inputs = serialize_to_aiida_nodes(function_inputs)
+    function_inputs = serialize_to_aiida_nodes(function_inputs, serializers=serializers, deserializers=deserializers)
function_data["outputs"] = function_outputs or [{"name": "result"}]
# replace "." with "__dot__" in the keys of a dictionary
if deserializers:
deserializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in deserializers.items()})
if serializers:
serializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in serializers.items()})
inputs = {
"function_data": function_data,
"code": code,
"function_inputs": function_inputs,
"upload_files": new_upload_files,
"metadata": metadata or {},
"deserializers": deserializers,
"serializers": serializers,
**kwargs,
}
if process_label:
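Mirroring the deserializer example in the docs above, custom serializers can be passed the same way; a sketch (the serializer path is the hypothetical one from the docs):

from ase.build import bulk

atoms = bulk("Li")  # any ase.Atoms instance
inputs = prepare_pythonjob_inputs(
    make_supercell,
    function_inputs={"structure": atoms},
    serializers={"ase.atoms.Atoms": "abc.ase.atoms.AtomsData"},
)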
