Skip to content

Commit

Permalink
PythonJob: add a `serializers` input for converting raw Python data to AiiDA nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Jan 23, 2025
1 parent c732449 commit 948c4d7
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 54 deletions.
2 changes: 1 addition & 1 deletion docs/gallery/autogen/how_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def add(x, y):
#
# {
# "serializers": {
# "ase.atoms.Atoms": "abc.ase.atoms.Atoms"
# "ase.atoms.Atoms": "abc.ase.atoms.AtomsData" # use the full path to the serializer
# }
# }
#
Expand Down
9 changes: 9 additions & 0 deletions src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,17 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
valid_type=Dict,
default=None,
required=False,
serializer=to_aiida_type,
help="The deserializers to convert the input AiiDA data nodes to raw Python data.",
)
spec.input(
"serializers",
valid_type=Dict,
default=None,
required=False,
serializer=to_aiida_type,
help="The serializers to convert the raw Python data to AiiDA data nodes.",
)
spec.outputs.dynamic = True
# set default options (optional)
spec.inputs["metadata"]["options"]["parser_name"].default = "pythonjob.pythonjob"
Expand Down
14 changes: 6 additions & 8 deletions src/aiida_pythonjob/data/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiida import common, orm

from aiida_pythonjob.config import load_config
from aiida_pythonjob.utils import import_from_path

builtin_deserializers = {
"aiida.orm.nodes.data.list.List": "aiida_pythonjob.data.deserializer.list_data_to_list",
Expand Down Expand Up @@ -46,27 +47,24 @@ def get_deserializer() -> dict:
return deserializers


eps_deserializers = get_deserializer()
all_deserializers = get_deserializer()


def deserialize_to_raw_python_data(data: orm.Node, deserializers: dict | None = None) -> Any:
"""Deserialize the AiiDA data node to an raw Python data."""
import importlib

all_deserializers = eps_deserializers.copy()
updated_deserializers = all_deserializers.copy()

if deserializers is not None:
all_deserializers.update(deserializers)
updated_deserializers.update(deserializers)

if isinstance(data, orm.Data):
if hasattr(data, "value"):
return getattr(data, "value")
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key in all_deserializers:
module_name, deserializer_name = all_deserializers[ep_key].rsplit(".", 1)
module = importlib.import_module(module_name)
deserializer = getattr(module, deserializer_name)
if ep_key in updated_deserializers:
deserializer = import_from_path(updated_deserializers[ep_key])
return deserializer(data)
else:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
Expand Down
68 changes: 42 additions & 26 deletions src/aiida_pythonjob/data/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
from aiida import common, orm

from aiida_pythonjob.config import load_config
from aiida_pythonjob.utils import import_from_path

from .deserializer import eps_deserializers
from .deserializer import all_deserializers
from .pickled_data import PickledData


def get_serializer_from_entry_points() -> dict:
"""Retrieve the serializer from the entry points."""
# import time
def atoms_to_structure_data(structure):
    """Serialize an ASE ``Atoms`` object into an AiiDA ``StructureData`` node.

    :param structure: an ``ase.Atoms`` instance (passed via the ``ase`` keyword
        of ``orm.StructureData``).
    :return: a new, unstored ``orm.StructureData`` node wrapping the structure.
    """
    return orm.StructureData(ase=structure)

# ts = time.time()
configs = load_config()
serializers = configs.get("serializers", {})
excludes = serializers.get("excludes", [])

def get_serializers_from_entry_points() -> dict:
# Retrieve the entry points for 'aiida.data' and store them in a dictionary
eps = entry_points()
if sys.version_info >= (3, 10):
Expand All @@ -31,28 +29,39 @@ def get_serializer_from_entry_points() -> dict:
# split the entry point name by first ".", and check the last part
key = ep.name.split(".", 1)[-1]
# skip key without "." because it is not a module name for a data type
if "." not in key or key in excludes:
if "." not in key:
continue
eps.setdefault(key, [])
eps[key].append(ep)
# get the path of the entry point value and replace ":" with "."
eps[key].append(ep.value.replace(":", "."))
return eps


def get_serializers() -> dict:
"""Retrieve the serializer from the entry points."""
# import time

# ts = time.time()
all_serializers = {}
configs = load_config()
custom_serializers = configs.get("serializers", {})
eps = get_serializers_from_entry_points()
# check if there are duplicates
for key, value in eps.items():
if len(value) > 1:
if key in serializers:
eps[key] = [ep for ep in value if ep.name == serializers[key]]
if not eps[key]:
raise ValueError(f"Entry point {serializers[key]} not found for {key}")
else:
msg = f"Duplicate entry points for {key}: {[ep.name for ep in value]}"
if key not in custom_serializers:
msg = f"Duplicate entry points for {key}: {value}. You can specify the one to use in the configuration file." # noqa
raise ValueError(msg)
return eps
all_serializers[key] = value[0]
all_serializers.update(custom_serializers)
# print("Time to get serializer", time.time() - ts)
return all_serializers


eps_serializers = get_serializer_from_entry_points()
all_serializers = get_serializers()


def serialize_to_aiida_nodes(inputs: dict, deserializers: dict | None = None) -> dict:
def serialize_to_aiida_nodes(inputs: dict, serializers: dict | None = None, deserializers: dict | None = None) -> dict:
"""Serialize the inputs to a dictionary of AiiDA data nodes.
Args:
Expand All @@ -64,7 +73,7 @@ def serialize_to_aiida_nodes(inputs: dict, deserializers: dict | None = None) ->
new_inputs = {}
# save all kwargs to inputs port
for key, data in inputs.items():
new_inputs[key] = general_serializer(data, deserializers=deserializers)
new_inputs[key] = general_serializer(data, serializers=serializers, deserializers=deserializers)
return new_inputs


Expand All @@ -75,17 +84,23 @@ def clean_dict_key(data):
return data


def general_serializer(data: Any, check_value=True, deserializers: dict | None = None) -> orm.Node:
def general_serializer(
data: Any, serializers: dict | None = None, deserializers: dict | None = None, check_value=True
) -> orm.Node:
"""Serialize the data to an AiiDA data node."""
all_deserializers = eps_deserializers.copy()
updated_deserializers = all_deserializers.copy()
if deserializers is not None:
all_deserializers.update(deserializers)
updated_deserializers.update(deserializers)

updated_serializers = all_serializers.copy()
if serializers is not None:
updated_serializers.update(serializers)

if isinstance(data, orm.Data):
if check_value and not hasattr(data, "value"):
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
if ep_key not in all_deserializers:
if ep_key not in updated_deserializers:
raise ValueError(f"AiiDA data: {ep_key}, does not have a value attribute or deserializer.")
return data
elif isinstance(data, common.extendeddicts.AttributeDict):
Expand All @@ -102,9 +117,10 @@ def general_serializer(data: Any, check_value=True, deserializers: dict | None =
data_type = type(data)
ep_key = f"{data_type.__module__}.{data_type.__name__}"
# search for the key in the entry points
if ep_key in eps_serializers:
if ep_key in updated_serializers:
try:
new_node = eps_serializers[ep_key][0].load()(data)
serializer = import_from_path(updated_serializers[ep_key])
new_node = serializer(data)
except Exception as e:
raise ValueError(f"Error in serializing {ep_key}: {e}")
finally:
Expand Down
6 changes: 5 additions & 1 deletion src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def prepare_pythonjob_inputs(
process_label: Optional[str] = None,
function_data: dict | None = None,
deserializers: dict | None = None,
serializers: dict | None = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob"""
Expand Down Expand Up @@ -56,18 +57,21 @@ def prepare_pythonjob_inputs(
code = get_or_create_code(computer=computer, **command_info)
# serialize the kwargs into AiiDA Data
function_inputs = function_inputs or {}
function_inputs = serialize_to_aiida_nodes(function_inputs, deserializers=deserializers)
function_inputs = serialize_to_aiida_nodes(function_inputs, serializers=serializers, deserializers=deserializers)
function_data["outputs"] = function_outputs or [{"name": "result"}]
# replace "." with "__dot__" in the keys of a dictionary
if deserializers:
deserializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in deserializers.items()})
if serializers:
serializers = orm.Dict({k.replace(".", "__dot__"): v for k, v in serializers.items()})
inputs = {
"function_data": function_data,
"code": code,
"function_inputs": function_inputs,
"upload_files": new_upload_files,
"metadata": metadata or {},
"deserializers": deserializers,
"serializers": serializers,
**kwargs,
}
if process_label:
Expand Down
12 changes: 10 additions & 2 deletions src/aiida_pythonjob/parsers/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def parse(self, **kwargs):
function_outputs = [{"name": "result"}]
self.output_list = function_outputs

# load custom serializers
if "serializers" in self.node.inputs and self.node.inputs.serializers:
serializers = self.node.inputs.serializers.get_dict()
# replace "__dot__" with "." in the keys
self.serializers = {k.replace("__dot__", "."): v for k, v in serializers.items()}
else:
self.serializers = None

# If nested outputs like "add_multiply.add", keep only top-level
top_level_output_list = [output for output in self.output_list if "." not in output["name"]]

Expand Down Expand Up @@ -144,10 +152,10 @@ def serialize_output(self, result, output):
if full_name_output and full_name_output.get("identifier", "Any").upper() == "NAMESPACE":
serialized_result[key] = self.serialize_output(value, full_name_output)
else:
serialized_result[key] = general_serializer(value)
serialized_result[key] = general_serializer(value, serializers=self.serializers)
return serialized_result
else:
self.logger.error(f"Expected a dict for namespace '{name}', got {type(result)}.")
return self.exit_codes.ERROR_INVALID_OUTPUT
else:
return general_serializer(result)
return general_serializer(result, serializers=self.serializers)
11 changes: 11 additions & 0 deletions src/aiida_pythonjob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from aiida.orm import Computer, InstalledCode, User, load_code, load_computer


def import_from_path(path: str) -> Any:
    """Import and return an object given its full dotted path.

    The text before the final dot is imported as a module; the final
    component is looked up as an attribute on that module.

    :param path: dotted path of the form ``"package.module.object_name"``.
    :return: the object named by the last component of ``path``.
    :raises ValueError: if ``path`` contains no dot (no module/object split).
    :raises ModuleNotFoundError: if the module part cannot be imported.
    :raises AttributeError: if the module has no attribute with that name.
    """
    import importlib

    if "." not in path:
        # Without this guard, rsplit-unpacking raises an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(f"Invalid path {path!r}: expected 'module.object_name'.")
    module_name, object_name = path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, object_name)
    except AttributeError as exc:
        # Chain the original exception so the traceback keeps the root cause.
        raise AttributeError(f"{object_name} not found in module {module_name}.") from exc


def get_required_imports(func: Callable) -> Dict[str, set]:
"""Retrieve type hints and the corresponding modules."""
type_hints = get_type_hints(func)
Expand Down
24 changes: 8 additions & 16 deletions tests/test_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@ def create_mock_entry_points(entry_point_list):

@patch("aiida_pythonjob.data.serializer.load_config")
@patch("aiida_pythonjob.data.serializer.entry_points")
def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config):
def test_get_serializers(mock_entry_points, mock_load_config):
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
}
}
mock_load_config.return_value = {"serializers": {}}
# Mock entry points
mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data")
mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data")
Expand All @@ -40,22 +36,18 @@ def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config):
mock_entry_points.return_value = create_mock_entry_points([mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4])

# Import the function and run
from aiida_pythonjob.data.serializer import get_serializer_from_entry_points
from aiida_pythonjob.data.serializer import get_serializers

with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"):
get_serializer_from_entry_points()
get_serializers()
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
"abc.Cde": "another_xyz.abc.Cde",
"abc.Cde": "another_xyz.abc.CdeData",
}
}
result = get_serializer_from_entry_points()
result = get_serializers()
# Assert results
expected = {
"abc.Abc": [mock_ep_1],
"abc.Bcd": [mock_ep_2],
"abc.Cde": [mock_ep_4],
}
expected = {"abc.Abc": "xyz.abc.AbcData", "abc.Bcd": "xyz.abc.BcdData", "abc.Cde": "another_xyz.abc.CdeData"}
print("result", result)
assert result == expected

0 comments on commit 948c4d7

Please sign in to comment.