Skip to content

Commit

Permalink
Merge pull request #8778 from OpenMined/aziz/job_kill
Browse files Browse the repository at this point in the history
feat: update kill/restart behaviour for nested jobs 

Fixes OpenMined/Heartbeat#1159
Fixes OpenMined/Heartbeat#1160
  • Loading branch information
abyesilyurt authored May 14, 2024
2 parents bf873e3 + 2c3815a commit 221986c
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 220 deletions.
53 changes: 30 additions & 23 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,27 @@ def add_api_endpoint_execution_to_queue(
None,
)

def get_worker_pool_ref_by_name(
self, credentials: SyftVerifyKey, worker_pool_name: str | None = None
) -> LinkedObject | SyftError:
# If worker pool id is not set, then use default worker pool
# Else, get the worker pool for given uid
if worker_pool_name is None:
worker_pool = self.get_default_worker_pool()
else:
result = self.pool_stash.get_by_name(credentials, worker_pool_name)
if result.is_err():
return SyftError(message=f"{result.err()}")
worker_pool = result.ok()

# Create a Worker pool reference object
worker_pool_ref = LinkedObject.from_obj(
worker_pool,
service_type=SyftWorkerPoolService,
node_uid=self.id,
)
return worker_pool_ref

def add_action_to_queue(
self,
action: Action,
Expand All @@ -1312,23 +1333,11 @@ def add_action_to_queue(
user_code = result.ok()
worker_pool_name = user_code.worker_pool_name

# If worker pool id is not set, then use default worker pool
# Else, get the worker pool for given uid
if worker_pool_name is None:
worker_pool = self.get_default_worker_pool()
else:
result = self.pool_stash.get_by_name(credentials, worker_pool_name)
if result.is_err():
return SyftError(message=f"{result.err()}")
worker_pool = result.ok()

# Create a Worker pool reference object
worker_pool_ref = LinkedObject.from_obj(
worker_pool,
service_type=SyftWorkerPoolService,
node_uid=self.id,
worker_pool_ref = self.get_worker_pool_ref_by_name(
credentials, worker_pool_name
)

if isinstance(worker_pool_ref, SyftError):
return worker_pool_ref
queue_item = ActionQueueItem(
id=task_uid,
node_uid=self.id,
Expand Down Expand Up @@ -1473,12 +1482,10 @@ def add_api_call_to_queue(

else:
worker_settings = WorkerSettings.from_node(node=self)
default_worker_pool = self.get_default_worker_pool()
worker_pool = LinkedObject.from_obj(
default_worker_pool,
service_type=SyftWorkerPoolService,
node_uid=self.id,
)
worker_pool_ref = self.get_worker_pool_ref_by_name(credentials=credentials)
if isinstance(worker_pool_ref, SyftError):
return worker_pool_ref

queue_item = QueueItem(
id=UID(),
node_uid=self.id,
Expand All @@ -1490,7 +1497,7 @@ def add_api_call_to_queue(
method=method_str,
args=unsigned_call.args,
kwargs=unsigned_call.kwargs,
worker_pool=worker_pool,
worker_pool=worker_pool_ref,
)
return self.add_queueitem_to_queue(
queue_item,
Expand Down
103 changes: 84 additions & 19 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# stdlib
from collections.abc import Callable
import inspect
import time
from typing import Any
from typing import cast

Expand Down Expand Up @@ -28,6 +31,18 @@
from .job_stash import JobStatus


def wait_until(
predicate: Callable[[], bool], timeout: int = 10
) -> SyftSuccess | SyftError:
start = time.time()
code_string = inspect.getsource(predicate).strip()
while time.time() - start < timeout:
if predicate():
return SyftSuccess(message=f"Predicate {code_string} is True")
time.sleep(1)
return SyftError(message=f"Timeout reached for predicate {code_string}")


@instrument
@serializable()
class JobService(AbstractService):
Expand Down Expand Up @@ -112,16 +127,31 @@ def get_by_result_id(
def restart(
self, context: AuthedServiceContext, uid: UID
) -> SyftSuccess | SyftError:
res = self.stash.get_by_uid(context.credentials, uid=uid)
if res.is_err():
return SyftError(message=res.err())
job_or_err = self.stash.get_by_uid(context.credentials, uid=uid)
if job_or_err.is_err():
return SyftError(message=job_or_err.err())
if job_or_err.ok() is None:
return SyftError(message="Job not found")

job = job_or_err.ok()
if job.parent_job_id is not None:
return SyftError(
message="Not possible to restart subjobs. Please restart the parent job."
)
if job.status == JobStatus.PROCESSING:
return SyftError(
message="Jobs in progress cannot be restarted. "
"Please wait for completion or cancel the job via .cancel() to proceed."
)

job = res.ok()
job.status = JobStatus.CREATED
self.update(context=context, job=job)

task_uid = UID()
worker_settings = WorkerSettings.from_node(context.node)
worker_pool_ref = context.node.get_worker_pool_ref_by_name(context.credentials)
if isinstance(worker_pool_ref, SyftError):
return worker_pool_ref

queue_item = ActionQueueItem(
id=task_uid,
Expand All @@ -132,15 +162,16 @@ def restart(
worker_settings=worker_settings,
args=[],
kwargs={"action": job.action},
worker_pool=worker_pool_ref,
)

context.node.queue_stash.set_placeholder(context.credentials, queue_item)
context.node.job_stash.set(context.credentials, job)

log_service = context.node.get_service("logservice")
result = log_service.restart(context, job.log_id)
if result.is_err():
return SyftError(message=str(result.err()))
if isinstance(result, SyftError):
return result

return SyftSuccess(message="Great Success!")

Expand All @@ -158,28 +189,62 @@ def update(
res = res.ok()
return SyftSuccess(message="Great Success!")

def _kill(self, context: AuthedServiceContext, job: Job) -> SyftSuccess | SyftError:
# set job and subjobs status to TERMINATING
# so that MonitorThread can kill them
job.status = JobStatus.TERMINATING
res = self.stash.update(context.credentials, obj=job)
results = [res]

# attempt to kill all subjobs
subjobs_or_err = self.stash.get_by_parent_id(context.credentials, uid=job.id)
if subjobs_or_err.is_ok() and subjobs_or_err.ok() is not None:
subjobs = subjobs_or_err.ok()
for subjob in subjobs:
subjob.status = JobStatus.TERMINATING
res = self.stash.update(context.credentials, obj=subjob)
results.append(res)

errors = [res.err() for res in results if res.is_err()]
if errors:
return SyftError(message=f"Failed to kill job: {errors}")

# wait for job and subjobs to be killed by MonitorThread
wait_until(lambda: job.fetched_status == JobStatus.INTERRUPTED)
wait_until(
lambda: all(
subjob.fetched_status == JobStatus.INTERRUPTED for subjob in job.subjobs
)
)

return SyftSuccess(message="Job killed successfully!")

@service_method(
path="job.kill",
name="kill",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def kill(self, context: AuthedServiceContext, id: UID) -> SyftSuccess | SyftError:
res = self.stash.get_by_uid(context.credentials, uid=id)
if res.is_err():
return SyftError(message=res.err())
job_or_err = self.stash.get_by_uid(context.credentials, uid=id)
if job_or_err.is_err():
return SyftError(message=job_or_err.err())
if job_or_err.ok() is None:
return SyftError(message="Job not found")

job = res.ok()
if job.job_pid is not None and job.status == JobStatus.PROCESSING:
job.status = JobStatus.INTERRUPTED
res = self.stash.update(context.credentials, obj=job)
if res.is_err():
return SyftError(message=res.err())
return SyftSuccess(message="Job killed successfully!")
else:
job = job_or_err.ok()
if job.parent_job_id is not None:
return SyftError(
message="Job is not running or isn't running in multiprocessing mode."
"Killing threads is currently not supported"
message="Not possible to cancel subjobs. To stop execution, please cancel the parent job."
)
if job.status != JobStatus.PROCESSING:
return SyftError(message="Job is not running")
if job.job_pid is None:
return SyftError(
message="Job termination disabled in dev mode. "
"Set 'dev_mode=False' or 'thread_workers=False' to enable."
)

return self._kill(context, job)

@service_method(
path="job.get_subjobs",
Expand Down
70 changes: 29 additions & 41 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class JobStatus(str, Enum):
PROCESSING = "processing"
ERRORED = "errored"
COMPLETED = "completed"
TERMINATING = "terminating"
INTERRUPTED = "interrupted"


Expand Down Expand Up @@ -254,47 +255,26 @@ def apply_info(self, info: "JobInfo") -> None:
self.result = info.result

def restart(self, kill: bool = False) -> None:
if kill:
self.kill()
self.fetch()
if not self.has_parent:
# this is currently the limitation, we will need to implement
# killing toplevel jobs later
print("Can only kill nested jobs")
elif kill or (
self.status != JobStatus.PROCESSING and self.status != JobStatus.CREATED
):
api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
)
if api is None:
raise ValueError(
f"Can't access Syft API. You must login to {self.syft_node_location}"
)
call = SyftAPICall(
node_uid=self.node_uid,
path="job.restart",
args=[],
kwargs={"uid": self.id},
blocking=True,
)

api.make_call(call)
else:
print(
"Job is running or scheduled, if you want to kill it use job.kill() first"
api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
)
if api is None:
raise ValueError(
f"Can't access Syft API. You must login to {self.syft_node_location}"
)
return None
call = SyftAPICall(
node_uid=self.node_uid,
path="job.restart",
args=[],
kwargs={"uid": self.id},
blocking=True,
)
res = api.make_call(call)
self.fetch()
return res

def kill(self) -> SyftError | SyftSuccess:
if self.status != JobStatus.PROCESSING:
return SyftError(message="Job is not running")
if self.job_pid is None:
return SyftError(
message="Job termination disabled in dev mode. "
"Set 'dev_mode=False' or 'thread_workers=False' to enable."
)
api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
Expand All @@ -310,8 +290,9 @@ def kill(self) -> SyftError | SyftSuccess:
kwargs={"id": self.id},
blocking=True,
)
api.make_call(call)
return SyftSuccess(message="Job is killed successfully!")
res = api.make_call(call)
self.fetch()
return res

def fetch(self) -> None:
api = APIRegistry.api_for(
Expand All @@ -329,7 +310,9 @@ def fetch(self) -> None:
kwargs={"uid": self.id},
blocking=True,
)
job: Job = api.make_call(call)
job: Job | None = api.make_call(call)
if job is None:
return
self.resolved = job.resolved
if job.resolved:
self.result = job.result
Expand Down Expand Up @@ -532,6 +515,11 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
"""
return as_markdown_code(md)

@property
def fetched_status(self) -> JobStatus:
self.fetch()
return self.status

@property
def requesting_user(self) -> UserView | SyftError:
api = APIRegistry.api_for(
Expand Down
Loading

0 comments on commit 221986c

Please sign in to comment.