1 change: 1 addition & 0 deletions pyproject.toml
@@ -361,6 +361,7 @@ module = [
'aiida.cmdline.params.*',
'aiida.common.*',
'aiida.repository.*',
'aiida.schedulers.*',
'aiida.tools.graph.*',
'aiida.tools.query.*'
]
52 changes: 27 additions & 25 deletions src/aiida/schedulers/datastructures.py
@@ -21,7 +21,9 @@
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal

from typing_extensions import Self

from aiida.common import AIIDA_LOGGER, CodeRunMode
from aiida.common.extendeddicts import AttributeDict, DefaultFieldsAttributeDict
@@ -78,16 +80,16 @@ class JobResource(DefaultFieldsAttributeDict, metaclass=abc.ABCMeta):

@classmethod
@abc.abstractmethod
def validate_resources(cls, **kwargs):
def validate_resources(cls, **kwargs: Any) -> dict[Any, Any] | None:
"""Validate the resources against the job resource class of this scheduler.

:param kwargs: dictionary of values to define the job resources
:raises ValueError: if the resources are invalid or incomplete
:return: optional tuple of parsed resource settings
:return: optional dict of parsed resource settings
"""

@classmethod
def get_valid_keys(cls):
def get_valid_keys(cls) -> list[str]:
"""Return a list of valid keys to be passed to the constructor."""
return list(cls._default_fields)

@@ -123,7 +125,7 @@ class NodeNumberJobResource(JobResource):
num_cores_per_mpiproc: int

@classmethod
def validate_resources(cls, **kwargs):
def validate_resources(cls, **kwargs: Any) -> AttributeDict:
"""Validate the resources against the job resource class of this scheduler.

:param kwargs: dictionary of values to define the job resources
@@ -132,7 +134,7 @@ def validate_resources(cls, **kwargs):
"""
resources = AttributeDict()

def is_greater_equal_one(parameter):
def is_greater_equal_one(parameter: str) -> None:
value = getattr(resources, parameter, None)
if value is not None and value < 1:
raise ValueError(f'`{parameter}` must be greater than or equal to one.')
@@ -176,7 +178,7 @@ def is_greater_equal_one(parameter):

return resources

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
"""Initialize the job resources from the passed arguments.

:raises ValueError: if the resources are invalid or incomplete
@@ -185,16 +187,16 @@ def __init__(self, **kwargs):
super().__init__(resources)

@classmethod
def get_valid_keys(cls):
def get_valid_keys(cls) -> list[str]:
"""Return a list of valid keys to be passed to the constructor."""
return super().get_valid_keys() + ['tot_num_mpiprocs']

@classmethod
def accepts_default_mpiprocs_per_machine(cls):
def accepts_default_mpiprocs_per_machine(cls) -> Literal[True]:
"""Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise."""
return True

def get_tot_num_mpiprocs(self):
def get_tot_num_mpiprocs(self) -> int:
"""Return the total number of cpus of this job resource."""
return self.num_machines * self.num_mpiprocs_per_machine

@@ -212,7 +214,7 @@ class ParEnvJobResource(JobResource):
tot_num_mpiprocs: int

@classmethod
def validate_resources(cls, **kwargs):
def validate_resources(cls, **kwargs: Any) -> AttributeDict:
"""Validate the resources against the job resource class of this scheduler.

:param kwargs: dictionary of values to define the job resources
@@ -242,7 +244,7 @@ def validate_resources(cls, **kwargs):

return resources

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
"""Initialize the job resources from the passed arguments (the valid keys can be
obtained with the function self.get_valid_keys()).

@@ -252,11 +254,11 @@ def __init__(self, **kwargs):
super().__init__(resources)

@classmethod
def accepts_default_mpiprocs_per_machine(cls):
def accepts_default_mpiprocs_per_machine(cls) -> Literal[False]:
"""Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise."""
return False

def get_tot_num_mpiprocs(self):
def get_tot_num_mpiprocs(self) -> int:
"""Return the total number of cpus of this job resource."""
return self.tot_num_mpiprocs

@@ -569,20 +571,20 @@ class JobInfo(DefaultFieldsAttributeDict):
}

@staticmethod
def _serialize_job_state(job_state):
def _serialize_job_state(job_state: JobState) -> str:
"""Return the serialized value of the JobState instance."""
if not isinstance(job_state, JobState):
raise TypeError(f'invalid type for value {job_state}, should be an instance of `JobState`')

return job_state.value

@staticmethod
def _deserialize_job_state(job_state):
def _deserialize_job_state(job_state: str) -> JobState:
"""Return an instance of JobState from the job_state string."""
return JobState(job_state)

@staticmethod
def _serialize_date(value):
def _serialize_date(value: datetime | None) -> dict[str, str | None] | None:
"""Serialise a data value
:param value: The value to serialise
:return: The serialised value
@@ -600,7 +602,7 @@ def _serialize_date(value):
return {'date': value.astimezone(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f'), 'timezone': 'UTC'}

@staticmethod
def _deserialize_date(value):
def _deserialize_date(value: dict[str, str] | None) -> datetime | None:
"""Deserialise a date
:param value: The date value
:return: The deserialised date
@@ -610,7 +612,7 @@ def _deserialize_date(value):

if value['timezone'] is None:
# naive date
return datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f')
return datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f') # type: ignore[unreachable]
if value['timezone'] == 'UTC':
return make_aware(datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f'), timezone.utc)

@@ -620,7 +622,7 @@ def _deserialize_date(value):
)

@classmethod
def serialize_field(cls, value, field_type):
def serialize_field(cls, value: Any, field_type: str | None) -> Any:
"""Serialise a particular field value

:param value: The value to serialise
@@ -635,7 +637,7 @@ def serialize_field(cls, value, field_type):
return serializer_method(value)

@classmethod
def deserialize_field(cls, value, field_type):
def deserialize_field(cls, value: Any, field_type: str | None) -> Any:
"""Deserialise the value of a particular field with a type
:param value: The value
:param field_type: The field type
@@ -648,22 +650,22 @@ def deserialize_field(cls, value, field_type):

return deserializer_method(value)

def serialize(self):
def serialize(self) -> str:
"""Serialize the current data (as obtained by ``self.get_dict()``) into a JSON string.

:return: A string with serialised representation of the current data.
"""
return json.dumps(self.get_dict())

def get_dict(self):
def get_dict(self) -> dict[str, Any]:
"""Serialise the current data into a dictionary that is JSON-serializable.

:return: A dictionary
"""
return {k: self.serialize_field(v, self._special_serializers.get(k, None)) for k, v in self.items()}

@classmethod
def load_from_dict(cls, data):
def load_from_dict(cls, data: dict[str, Any]) -> Self:
"""Create a new instance loading the values from serialised data in dictionary form

:param data: The dictionary with the data to load from
@@ -674,7 +676,7 @@ def load_from_dict(cls, data):
return instance

@classmethod
def load_from_serialized(cls, data):
def load_from_serialized(cls, data: str) -> Self:
"""Create a new instance loading the values from JSON-serialised data as a string

:param data: The string with the JSON-serialised data to load from
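Taken together, the annotations above make the `JobInfo` serialisation API fully checkable: `serialize()` is declared to return `str` and `load_from_serialized()` returns `Self`, so subclasses keep their own type. A minimal sketch of how the annotated round trip reads (not part of this diff; the field values are illustrative only):

from aiida.schedulers.datastructures import JobInfo, JobState

info = JobInfo()
info.job_id = '12345'
info.job_state = JobState.RUNNING

serialized = info.serialize()                         # annotated as -> str
restored = JobInfo.load_from_serialized(serialized)   # annotated as -> Self
assert restored.job_state is JobState.RUNNING         # special serializers round-trip the enum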
11 changes: 9 additions & 2 deletions src/aiida/schedulers/plugins/pbsbaseclasses.py
@@ -8,6 +8,8 @@
###########################################################################
"""Base classes for PBSPro and PBS/Torque plugins."""

from __future__ import annotations

import logging

from aiida.common.escaping import escape_for_bash
@@ -118,8 +120,13 @@ class PbsBaseClass(BashCliScheduler):
_map_status = _MAP_STATUS_PBS_COMMON

def _get_resource_lines(
self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds
):
self,
num_machines: int,
num_mpiprocs_per_machine: int | None,
num_cores_per_machine: int | None,
max_memory_kb: int | None,
max_wallclock_seconds: int | None,
) -> list[str]:
"""Return a set a list of lines (possibly empty) with the header
lines relative to:

11 changes: 9 additions & 2 deletions src/aiida/schedulers/plugins/pbspro.py
@@ -10,6 +10,8 @@
This has been tested on PBSPro v. 12.
"""

from __future__ import annotations

import logging

from .pbsbaseclasses import PbsBaseClass
@@ -48,8 +50,13 @@ class PbsproScheduler(PbsBaseClass):
# _map_status = _map_status_pbs_common

def _get_resource_lines(
self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds
):
self,
num_machines: int,
num_mpiprocs_per_machine: int | None,
num_cores_per_machine: int | None,
max_memory_kb: int | None,
max_wallclock_seconds: int | None,
) -> list[str]:
"""Return the lines for machines, memory and wallclock relative
to pbspro.
"""
11 changes: 9 additions & 2 deletions src/aiida/schedulers/plugins/torque.py
@@ -10,6 +10,8 @@
This has been tested on Torque v.2.4.16 (from Ubuntu).
"""

from __future__ import annotations

import logging

from .pbsbaseclasses import PbsBaseClass
@@ -42,8 +44,13 @@ class TorqueScheduler(PbsBaseClass):
# _map_status = _map_status_pbs_common

def _get_resource_lines(
self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds
):
self,
num_machines: int,
num_mpiprocs_per_machine: int | None,
num_cores_per_machine: int | None,
max_memory_kb: int | None,
max_wallclock_seconds: int | None,
) -> list[str]:
"""Return the lines for machines, memory and wallclock relative
to pbspro.
"""
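The three PBS-style plugins above (the shared base class, PBSPro and Torque) all receive the same annotated `_get_resource_lines` signature, and each adds `from __future__ import annotations` so the `int | None` union syntax is stored as a string annotation and never evaluated at runtime. A hypothetical stand-alone sketch of that pattern (the function name and body are illustrative, not the plugins' actual implementation):

from __future__ import annotations


def resource_lines_sketch(num_machines: int, max_wallclock_seconds: int | None = None) -> list[str]:
    # Illustrative only: emit PBS-style resource header lines.
    lines = [f'#PBS -l nodes={num_machines}']
    if max_wallclock_seconds is not None:
        lines.append(f'#PBS -l walltime={max_wallclock_seconds}')
    return lines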
24 changes: 12 additions & 12 deletions src/aiida/schedulers/scheduler.py
@@ -49,11 +49,15 @@ class Scheduler(metaclass=abc.ABCMeta):
# The class to be used for the job resource.
_job_resource_class: t.Type[JobResource] | None = None

def __str__(self):
def __init__(self) -> None:
assert self._job_resource_class is not None and issubclass(self._job_resource_class, JobResource)
self._transport: Transport | None = None

def __str__(self) -> str:
return self.__class__.__name__

@classmethod
def preprocess_resources(cls, resources: dict[str, t.Any], default_mpiprocs_per_machine: int | None = None):
def preprocess_resources(cls, resources: dict[str, t.Any], default_mpiprocs_per_machine: int | None = None) -> None:
"""Pre process the resources.

Add the `num_mpiprocs_per_machine` key to the `resources` if it is not already defined and it cannot be deduced
@@ -73,7 +77,7 @@ class of this scheduler does not accept the `num_mpiprocs_per_machine` keyword.
resources['num_mpiprocs_per_machine'] = default_mpiprocs_per_machine

@classmethod
def validate_resources(cls, **resources):
def validate_resources(cls, **resources: t.Any) -> None:
"""Validate the resources against the job resource class of this scheduler.

:param resources: keyword arguments to define the job resources
@@ -82,12 +86,8 @@ def validate_resources(cls, **resources):
assert cls._job_resource_class is not None and issubclass(cls._job_resource_class, JobResource)
cls._job_resource_class.validate_resources(**resources)

def __init__(self):
assert self._job_resource_class is not None and issubclass(self._job_resource_class, JobResource)
self._transport = None

@classmethod
def get_short_doc(cls):
def get_short_doc(cls) -> str:
"""Return the first non-empty line of the class docstring, if available."""
# Remove empty lines
docstring = cls.__doc__
@@ -107,7 +107,7 @@ def get_feature(self, feature_name: str) -> bool:
raise NotImplementedError(f'Feature {feature_name} not implemented for this scheduler')

@property
def logger(self):
def logger(self) -> log.AiidaLoggerType:
"""Return the internal logger."""
try:
return self._logger
@@ -120,7 +120,7 @@ def job_resource_class(cls) -> t.Type[JobResource]: # noqa: N805
return cls._job_resource_class

@classmethod
def create_job_resource(cls, **kwargs):
def create_job_resource(cls, **kwargs: t.Any) -> JobResource:
"""Create a suitable job resource from the kwargs specified."""
assert cls._job_resource_class is not None and issubclass(cls._job_resource_class, JobResource)
return cls._job_resource_class(**kwargs)
@@ -374,14 +374,14 @@ def get_detailed_job_info(self, job_id: str) -> dict[str, str | int]:
return detailed_job_info

@property
def transport(self):
def transport(self) -> Transport:
"""Return the transport set for this scheduler."""
if self._transport is None:
raise SchedulerError('Use the set_transport function to set the transport for the scheduler first.')

return self._transport

def set_transport(self, transport: Transport):
def set_transport(self, transport: Transport) -> None:
"""Set the transport to be used to query the machine or to submit scripts.

This class assumes that the transport is open and active.
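With `__init__` moved above `__str__` and annotated, `_transport` is declared as `Transport | None` from construction, and the public `transport` property narrows it to `Transport`. A short sketch of how the typed scheduler API reads (`PbsproScheduler` is used only as a convenient concrete plugin; the resource values are illustrative):

from aiida.schedulers.plugins.pbspro import PbsproScheduler

scheduler = PbsproScheduler()

# create_job_resource() is annotated to return JobResource, so the result's
# methods are visible to the type checker.
resources = scheduler.create_job_resource(num_machines=1, num_mpiprocs_per_machine=4)
assert resources.get_tot_num_mpiprocs() == 4

# transport is annotated as Transport; reading the property before calling
# set_transport() still raises SchedulerError at runtime, exactly as before.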