Skip to content

Commit

Permalink
Fix meta file processing in storage and improve schedule job retrieval (
Browse files Browse the repository at this point in the history
#2193)

* Fix meta file processing in storage and improve schedule job retrieval

* changed update_unfinished_jobs to use one get_jobs_by_status
  • Loading branch information
yanchengnv authored Dec 7, 2023
1 parent 5e28e31 commit ca6da08
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 58 deletions.
86 changes: 64 additions & 22 deletions nvflare/apis/impl/job_def_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import Job, JobDataKey, JobMetaKey, job_from_meta
Expand All @@ -30,47 +30,77 @@
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes

_OBJ_TAG_SCHEDULED = "scheduled"


class JobInfo:
def __init__(self, meta: dict, job_id: str, uri: str):
self.meta = meta
self.job_id = job_id
self.uri = uri


class _JobFilter(ABC):
@abstractmethod
def filter_job(self, meta: dict) -> bool:
def filter_job(self, info: JobInfo) -> bool:
pass


class _StatusFilter(_JobFilter):
def __init__(self, status_to_check):
self.result = []
if not isinstance(status_to_check, list):
# turning to list
status_to_check = [status_to_check]
self.status_to_check = status_to_check

def filter_job(self, meta: dict):
if meta[JobMetaKey.STATUS] == self.status_to_check:
self.result.append(job_from_meta(meta))
def filter_job(self, info: JobInfo):
status = info.meta.get(JobMetaKey.STATUS.value)
if status in self.status_to_check:
self.result.append(job_from_meta(info.meta))
return True


class _AllJobsFilter(_JobFilter):
def __init__(self):
self.result = []

def filter_job(self, meta: dict):
self.result.append(job_from_meta(meta))
def filter_job(self, info: JobInfo):
self.result.append(job_from_meta(info.meta))
return True


class _ReviewerFilter(_JobFilter):
def __init__(self, reviewer_name, fl_ctx: FLContext):
def __init__(self, reviewer_name):
"""Not used yet, for use in future implementations."""
self.result = []
self.reviewer_name = reviewer_name

def filter_job(self, meta: dict):
approvals = meta.get(JobMetaKey.APPROVALS)
def filter_job(self, info: JobInfo):
approvals = info.meta.get(JobMetaKey.APPROVALS.value)
if not approvals or self.reviewer_name not in approvals:
self.result.append(job_from_meta(meta))
self.result.append(job_from_meta(info.meta))
return True


# TODO:: use try block around storage calls
class _ScheduleJobFilter(_JobFilter):

"""
This filter is optimized for selecting jobs to schedule since it is used so frequently (every 1 sec).
"""

def __init__(self, store):
self.store = store
self.result = []

def filter_job(self, info: JobInfo):
status = info.meta.get(JobMetaKey.STATUS.value)
if status == RunStatus.SUBMITTED.value:
self.result.append(job_from_meta(info.meta))
elif status:
# skip this job in all future calls (so the meta file of this job won't be read)
self.store.tag_object(uri=info.uri, tag=_OBJ_TAG_SCHEDULED)
return True


class SimpleJobDefManager(JobDefManagerSpec):
Expand Down Expand Up @@ -239,28 +269,40 @@ def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
self._scan(job_filter, fl_ctx)
return job_filter.result

def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext):
def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]:
job_filter = _ScheduleJobFilter(self._get_job_store(fl_ctx))
self._scan(job_filter, fl_ctx, skip_tag=_OBJ_TAG_SCHEDULED)
return job_filter.result

def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext, skip_tag=None):
store = self._get_job_store(fl_ctx)
jid_paths = store.list_objects(self.uri_root)
if not jid_paths:
obj_uris = store.list_objects(self.uri_root, without_tag=skip_tag)
self.log_debug(fl_ctx, f"objects to scan: {len(obj_uris)}")
if not obj_uris:
return

for jid_path in jid_paths:
jid = pathlib.PurePath(jid_path).name

meta = store.get_meta(self.job_uri(jid))
for uri in obj_uris:
jid = pathlib.PurePath(uri).name
job_uri = self.job_uri(jid)
meta = store.get_meta(job_uri)
if meta:
ok = job_filter.filter_job(meta)
ok = job_filter.filter_job(JobInfo(meta, jid, job_uri))
if not ok:
break

def get_jobs_by_status(self, status, fl_ctx: FLContext) -> List[Job]:
def get_jobs_by_status(self, status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]:
"""Get jobs that are in the specified status
Args:
status: a single status value or a list of status values
fl_ctx: the FL context
Returns: list of jobs that are in specified status
"""
job_filter = _StatusFilter(status)
self._scan(job_filter, fl_ctx)
return job_filter.result

def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Job]:
job_filter = _ReviewerFilter(reviewer_name, fl_ctx)
job_filter = _ReviewerFilter(reviewer_name)
self._scan(job_filter, fl_ctx)
return job_filter.result

Expand Down
21 changes: 17 additions & 4 deletions nvflare/apis/job_def_manager_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -103,7 +103,8 @@ def get_job_data(self, jid: str, fl_ctx: FLContext) -> dict:
fl_ctx (FLContext): FLContext information
Returns:
a dict to hold the job data and workspace. With the format: {JobDataKey.JOB_DATA.value: stored_data, JobDataKey.WORKSPACE_DATA: workspace_data}
a dict to hold the job data and workspace. With the format:
{JobDataKey.JOB_DATA.value: stored_data, JobDataKey.WORKSPACE_DATA: workspace_data}
"""
pass
Expand Down Expand Up @@ -145,6 +146,18 @@ def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
"""
pass

@abstractmethod
def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]:
"""Get job candidates for scheduling.
Args:
fl_ctx: FL context
Returns: list of jobs for scheduling
"""
pass

@abstractmethod
def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
"""Gets all Jobs in the system.
Expand All @@ -158,11 +171,11 @@ def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
pass

@abstractmethod
def get_jobs_by_status(self, run_status: RunStatus, fl_ctx: FLContext) -> List[Job]:
def get_jobs_by_status(self, run_status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]:
"""Gets Jobs of a specified status.
Args:
run_status (RunStatus): status to filter for
run_status: status to filter for: a single or a list of status values
fl_ctx (FLContext): FLContext information
Returns:
Expand Down
14 changes: 13 additions & 1 deletion nvflare/apis/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ def update_data(self, uri: str, data: bytes):
pass

@abstractmethod
def list_objects(self, path: str) -> List[str]:
def list_objects(self, path: str, without_tag=None) -> List[str]:
"""Lists all objects in the specified path.
Args:
path: the path to the objects
without_tag: skip the objects with this specified tag
Returns:
list of URIs of objects
Expand Down Expand Up @@ -163,3 +164,14 @@ def delete_object(self, uri: str):
"""
pass

@abstractmethod
def tag_object(self, uri: str, tag: str, data=None):
"""Tag an object with specified tag and data.
Args:
uri: URI of the object
tag: tag to be placed on the object
data: data associated with the tag.
Returns: None
"""
pass
Loading

0 comments on commit ca6da08

Please sign in to comment.