diff --git a/pyproject.toml b/pyproject.toml index eedd227bee..5c9d205c87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -361,6 +361,7 @@ module = [ 'aiida.cmdline.params.*', 'aiida.common.*', 'aiida.repository.*', + 'aiida.schedulers.*', 'aiida.tools.graph.*', 'aiida.tools.query.*' ] diff --git a/src/aiida/schedulers/datastructures.py b/src/aiida/schedulers/datastructures.py index f3095e97a4..67982a100b 100644 --- a/src/aiida/schedulers/datastructures.py +++ b/src/aiida/schedulers/datastructures.py @@ -21,7 +21,9 @@ import json from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal + +from typing_extensions import Self from aiida.common import AIIDA_LOGGER, CodeRunMode from aiida.common.extendeddicts import AttributeDict, DefaultFieldsAttributeDict @@ -78,16 +80,16 @@ class JobResource(DefaultFieldsAttributeDict, metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def validate_resources(cls, **kwargs): + def validate_resources(cls, **kwargs: Any) -> dict[Any, Any] | None: """Validate the resources against the job resource class of this scheduler. :param kwargs: dictionary of values to define the job resources :raises ValueError: if the resources are invalid or incomplete - :return: optional tuple of parsed resource settings + :return: optional dict of parsed resource settings """ @classmethod - def get_valid_keys(cls): + def get_valid_keys(cls) -> list[str]: """Return a list of valid keys to be passed to the constructor.""" return list(cls._default_fields) @@ -123,7 +125,7 @@ class NodeNumberJobResource(JobResource): num_cores_per_mpiproc: int @classmethod - def validate_resources(cls, **kwargs): + def validate_resources(cls, **kwargs: Any) -> AttributeDict: """Validate the resources against the job resource class of this scheduler. :param kwargs: dictionary of values to define the job resources @@ -132,7 +134,7 @@ def validate_resources(cls, **kwargs): """ resources = AttributeDict() - def is_greater_equal_one(parameter): + def is_greater_equal_one(parameter: str) -> None: value = getattr(resources, parameter, None) if value is not None and value < 1: raise ValueError(f'`{parameter}` must be greater than or equal to one.') @@ -176,7 +178,7 @@ def is_greater_equal_one(parameter): return resources - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): """Initialize the job resources from the passed arguments. :raises ValueError: if the resources are invalid or incomplete @@ -185,16 +187,16 @@ def __init__(self, **kwargs): super().__init__(resources) @classmethod - def get_valid_keys(cls): + def get_valid_keys(cls) -> list[str]: """Return a list of valid keys to be passed to the constructor.""" return super().get_valid_keys() + ['tot_num_mpiprocs'] @classmethod - def accepts_default_mpiprocs_per_machine(cls): + def accepts_default_mpiprocs_per_machine(cls) -> Literal[True]: """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise.""" return True - def get_tot_num_mpiprocs(self): + def get_tot_num_mpiprocs(self) -> int: """Return the total number of cpus of this job resource.""" return self.num_machines * self.num_mpiprocs_per_machine @@ -212,7 +214,7 @@ class ParEnvJobResource(JobResource): tot_num_mpiprocs: int @classmethod - def validate_resources(cls, **kwargs): + def validate_resources(cls, **kwargs: Any) -> AttributeDict: """Validate the resources against the job resource class of this scheduler. :param kwargs: dictionary of values to define the job resources @@ -242,7 +244,7 @@ def validate_resources(cls, **kwargs): return resources - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): """Initialize the job resources from the passed arguments (the valid keys can be obtained with the function self.get_valid_keys()). @@ -252,11 +254,11 @@ def __init__(self, **kwargs): super().__init__(resources) @classmethod - def accepts_default_mpiprocs_per_machine(cls): + def accepts_default_mpiprocs_per_machine(cls) -> Literal[False]: """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise.""" return False - def get_tot_num_mpiprocs(self): + def get_tot_num_mpiprocs(self) -> int: """Return the total number of cpus of this job resource.""" return self.tot_num_mpiprocs @@ -569,7 +571,7 @@ class JobInfo(DefaultFieldsAttributeDict): } @staticmethod - def _serialize_job_state(job_state): + def _serialize_job_state(job_state: JobState) -> str: """Return the serialized value of the JobState instance.""" if not isinstance(job_state, JobState): raise TypeError(f'invalid type for value {job_state}, should be an instance of `JobState`') @@ -577,12 +579,12 @@ def _serialize_job_state(job_state): return job_state.value @staticmethod - def _deserialize_job_state(job_state): + def _deserialize_job_state(job_state: str) -> JobState: """Return an instance of JobState from the job_state string.""" return JobState(job_state) @staticmethod - def _serialize_date(value): + def _serialize_date(value: datetime | None) -> dict[str, str | None] | None: """Serialise a data value :param value: The value to serialise :return: The serialised value @@ -600,7 +602,7 @@ def _serialize_date(value): return {'date': value.astimezone(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f'), 'timezone': 'UTC'} @staticmethod - def _deserialize_date(value): + def _deserialize_date(value: dict[str, str] | None) -> datetime | None: """Deserialise a date :param value: The date vlue :return: The deserialised date @@ -610,7 +612,7 @@ def _deserialize_date(value): if value['timezone'] is None: # naive date - return datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f') + return datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f') # type: ignore[unreachable] if value['timezone'] == 'UTC': return make_aware(datetime.strptime(value['date'], '%Y-%m-%dT%H:%M:%S.%f'), timezone.utc) @@ -620,7 +622,7 @@ def _deserialize_date(value): ) @classmethod - def serialize_field(cls, value, field_type): + def serialize_field(cls, value: Any, field_type: str | None) -> Any: """Serialise a particular field value :param value: The value to serialise @@ -635,7 +637,7 @@ def serialize_field(cls, value, field_type): return serializer_method(value) @classmethod - def deserialize_field(cls, value, field_type): + def deserialize_field(cls, value: Any, field_type: str | None) -> Any: """Deserialise the value of a particular field with a type :param value: The value :param field_type: The field type @@ -648,14 +650,14 @@ def deserialize_field(cls, value, field_type): return deserializer_method(value) - def serialize(self): + def serialize(self) -> str: """Serialize the current data (as obtained by ``self.get_dict()``) into a JSON string. :return: A string with serialised representation of the current data. """ return json.dumps(self.get_dict()) - def get_dict(self): + def get_dict(self) -> dict[str, Any]: """Serialise the current data into a dictionary that is JSON-serializable. :return: A dictionary @@ -663,7 +665,7 @@ def get_dict(self): return {k: self.serialize_field(v, self._special_serializers.get(k, None)) for k, v in self.items()} @classmethod - def load_from_dict(cls, data): + def load_from_dict(cls, data: dict[str, Any]) -> Self: """Create a new instance loading the values from serialised data in dictionary form :param data: The dictionary with the data to load from @@ -674,7 +676,7 @@ def load_from_dict(cls, data): return instance @classmethod - def load_from_serialized(cls, data): + def load_from_serialized(cls, data: str) -> Self: """Create a new instance loading the values from JSON-serialised data as a string :param data: The string with the JSON-serialised data to load from diff --git a/src/aiida/schedulers/plugins/pbsbaseclasses.py b/src/aiida/schedulers/plugins/pbsbaseclasses.py index bcceeae19d..24921cf786 100644 --- a/src/aiida/schedulers/plugins/pbsbaseclasses.py +++ b/src/aiida/schedulers/plugins/pbsbaseclasses.py @@ -8,6 +8,8 @@ ########################################################################### """Base classes for PBSPro and PBS/Torque plugins.""" +from __future__ import annotations + import logging from aiida.common.escaping import escape_for_bash @@ -118,8 +120,13 @@ class PbsBaseClass(BashCliScheduler): _map_status = _MAP_STATUS_PBS_COMMON def _get_resource_lines( - self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds - ): + self, + num_machines: int, + num_mpiprocs_per_machine: int | None, + num_cores_per_machine: int | None, + max_memory_kb: int | None, + max_wallclock_seconds: int | None, + ) -> list[str]: """Return a set a list of lines (possibly empty) with the header lines relative to: diff --git a/src/aiida/schedulers/plugins/pbspro.py b/src/aiida/schedulers/plugins/pbspro.py index 221b526254..d54ee11b62 100644 --- a/src/aiida/schedulers/plugins/pbspro.py +++ b/src/aiida/schedulers/plugins/pbspro.py @@ -10,6 +10,8 @@ This has been tested on PBSPro v. 12. """ +from __future__ import annotations + import logging from .pbsbaseclasses import PbsBaseClass @@ -48,8 +50,13 @@ class PbsproScheduler(PbsBaseClass): # _map_status = _map_status_pbs_common def _get_resource_lines( - self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds - ): + self, + num_machines: int, + num_mpiprocs_per_machine: int | None, + num_cores_per_machine: int | None, + max_memory_kb: int | None, + max_wallclock_seconds: int | None, + ) -> list[str]: """Return the lines for machines, memory and wallclock relative to pbspro. """ diff --git a/src/aiida/schedulers/plugins/torque.py b/src/aiida/schedulers/plugins/torque.py index b2ca36d64b..0d2fe56f09 100644 --- a/src/aiida/schedulers/plugins/torque.py +++ b/src/aiida/schedulers/plugins/torque.py @@ -10,6 +10,8 @@ This has been tested on Torque v.2.4.16 (from Ubuntu). """ +from __future__ import annotations + import logging from .pbsbaseclasses import PbsBaseClass @@ -42,8 +44,13 @@ class TorqueScheduler(PbsBaseClass): # _map_status = _map_status_pbs_common def _get_resource_lines( - self, num_machines, num_mpiprocs_per_machine, num_cores_per_machine, max_memory_kb, max_wallclock_seconds - ): + self, + num_machines: int, + num_mpiprocs_per_machine: int | None, + num_cores_per_machine: int | None, + max_memory_kb: int | None, + max_wallclock_seconds: int | None, + ) -> list[str]: """Return the lines for machines, memory and wallclock relative to pbspro. """ diff --git a/src/aiida/schedulers/scheduler.py b/src/aiida/schedulers/scheduler.py index 08cc74023a..b664c98c0f 100644 --- a/src/aiida/schedulers/scheduler.py +++ b/src/aiida/schedulers/scheduler.py @@ -49,11 +49,15 @@ class Scheduler(metaclass=abc.ABCMeta): # The class to be used for the job resource. _job_resource_class: t.Type[JobResource] | None = None - def __str__(self): + def __init__(self) -> None: + assert self._job_resource_class is not None and issubclass(self._job_resource_class, JobResource) + self._transport: Transport | None = None + + def __str__(self) -> str: return self.__class__.__name__ @classmethod - def preprocess_resources(cls, resources: dict[str, t.Any], default_mpiprocs_per_machine: int | None = None): + def preprocess_resources(cls, resources: dict[str, t.Any], default_mpiprocs_per_machine: int | None = None) -> None: """Pre process the resources. Add the `num_mpiprocs_per_machine` key to the `resources` if it is not already defined and it cannot be deduced @@ -73,7 +77,7 @@ class of this scheduler does not accept the `num_mpiprocs_per_machine` keyword. resources['num_mpiprocs_per_machine'] = default_mpiprocs_per_machine @classmethod - def validate_resources(cls, **resources): + def validate_resources(cls, **resources: t.Any) -> None: """Validate the resources against the job resource class of this scheduler. :param resources: keyword arguments to define the job resources @@ -82,12 +86,8 @@ def validate_resources(cls, **resources): assert cls._job_resource_class is not None and issubclass(cls._job_resource_class, JobResource) cls._job_resource_class.validate_resources(**resources) - def __init__(self): - assert self._job_resource_class is not None and issubclass(self._job_resource_class, JobResource) - self._transport = None - @classmethod - def get_short_doc(cls): + def get_short_doc(cls) -> str: """Return the first non-empty line of the class docstring, if available.""" # Remove empty lines docstring = cls.__doc__ @@ -107,7 +107,7 @@ def get_feature(self, feature_name: str) -> bool: raise NotImplementedError(f'Feature {feature_name} not implemented for this scheduler') @property - def logger(self): + def logger(self) -> log.AiidaLoggerType: """Return the internal logger.""" try: return self._logger @@ -120,7 +120,7 @@ def job_resource_class(cls) -> t.Type[JobResource]: # noqa: N805 return cls._job_resource_class @classmethod - def create_job_resource(cls, **kwargs): + def create_job_resource(cls, **kwargs: t.Any) -> JobResource: """Create a suitable job resource from the kwargs specified.""" assert cls._job_resource_class is not None and issubclass(cls._job_resource_class, JobResource) return cls._job_resource_class(**kwargs) @@ -374,14 +374,14 @@ def get_detailed_job_info(self, job_id: str) -> dict[str, str | int]: return detailed_job_info @property - def transport(self): + def transport(self) -> Transport: """Return the transport set for this scheduler.""" if self._transport is None: raise SchedulerError('Use the set_transport function to set the transport for the scheduler first.') return self._transport - def set_transport(self, transport: Transport): + def set_transport(self, transport: Transport) -> None: """Set the transport to be used to query the machine or to submit scripts. This class assumes that the transport is open and active.