Skip to content

Commit

Permalink
Merge pull request #8836 from OpenMined/aziz/ignore_twin_jobs
Browse files Browse the repository at this point in the history
sync: ignore jobs created by custom endpoints
  • Loading branch information
abyesilyurt authored May 21, 2024
2 parents 20dd090 + 04aa85a commit 9b21714
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 12 deletions.
22 changes: 15 additions & 7 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from ..service.job.job_service import JobService
from ..service.job.job_stash import Job
from ..service.job.job_stash import JobStash
from ..service.job.job_stash import JobType
from ..service.log.log_service import LogService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadataV3
Expand Down Expand Up @@ -1288,10 +1289,10 @@ def add_api_endpoint_execution_to_queue(

action = Action.from_api_endpoint_execution()
return self.add_queueitem_to_queue(
queue_item,
credentials,
action,
None,
queue_item=queue_item,
credentials=credentials,
action=action,
job_type=JobType.TWINAPIJOB,
)

def get_worker_pool_ref_by_name(
Expand Down Expand Up @@ -1360,16 +1361,22 @@ def add_action_to_queue(
)

return self.add_queueitem_to_queue(
queue_item, credentials, action, parent_job_id, user_id
queue_item=queue_item,
credentials=credentials,
action=action,
parent_job_id=parent_job_id,
user_id=user_id,
)

def add_queueitem_to_queue(
self,
*,
queue_item: QueueItem,
credentials: SyftVerifyKey,
action: Action | None = None,
parent_job_id: UID | None = None,
user_id: UID | None = None,
job_type: JobType = JobType.JOB,
) -> Job | SyftError:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
Expand Down Expand Up @@ -1403,6 +1410,7 @@ def add_queueitem_to_queue(
parent_job_id=parent_job_id,
action=action,
requested_by=user_id,
job_type=job_type,
)

# 🟡 TODO 36: Needs distributed lock
Expand Down Expand Up @@ -1505,8 +1513,8 @@ def add_api_call_to_queue(
worker_pool=worker_pool_ref,
)
return self.add_queueitem_to_queue(
queue_item,
api_call.credentials,
queue_item=queue_item,
credentials=api_call.credentials,
action=None,
parent_job_id=parent_job_id,
)
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
},
"5": {
"version": 5,
"hash": "82ee08442b09797ed7a3710c31de633bb308b1d2215f51b58a3e01a4c201055d",
"hash": "95a2367bce2e4deb5f8c807561779876c1ec010dbf4d4f68abb526e4eca4487e",
"action": "add"
}
},
Expand Down
14 changes: 12 additions & 2 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...store.document_store import UIDPartitionKey
from ...types.datetime import DateTime
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_5
from ...types.syft_object import SYFT_OBJECT_VERSION_6
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.uid import UID
Expand Down Expand Up @@ -73,10 +73,19 @@ def center_content(text: Any) -> str:
return center_div


@serializable()
class JobType(str, Enum):
JOB = "job"
TWINAPIJOB = "twinapijob"

def __str__(self) -> str:
return self.value


@serializable()
class Job(SyncableSyftObject):
__canonical_name__ = "JobItem"
__version__ = SYFT_OBJECT_VERSION_5
__version__ = SYFT_OBJECT_VERSION_6

id: UID
node_uid: UID
Expand All @@ -94,6 +103,7 @@ class Job(SyncableSyftObject):
updated_at: DateTime | None = None
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB

__attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"]
__repr_attrs__ = [
Expand Down
8 changes: 7 additions & 1 deletion packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..code.user_code import UserCode
from ..code.user_code import UserCodeStatusCollection
from ..job.job_stash import Job
from ..job.job_stash import JobType
from ..log.log import SyftLog
from ..output.output_service import ExecutionOutput
from ..request.request import Request
Expand Down Expand Up @@ -1288,7 +1289,12 @@ def hierarchies(
# TODO: Figure out nested user codes, do we even need that?

root_ids.append(diff.object_id) # type: ignore
elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore
elif (
isinstance(diff_obj, Job) # type: ignore
and diff_obj.parent_job_id is None
# ignore Job objects created by TwinAPIEndpoint
and diff_obj.job_type != JobType.TWINAPIJOB
):
root_ids.append(diff.object_id) # type: ignore

for root_uid in root_ids:
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@
SYFT_OBJECT_VERSION_3 = 3
SYFT_OBJECT_VERSION_4 = 4
SYFT_OBJECT_VERSION_5 = 5
SYFT_OBJECT_VERSION_6 = 6

supported_object_versions = [
SYFT_OBJECT_VERSION_1,
SYFT_OBJECT_VERSION_2,
SYFT_OBJECT_VERSION_3,
SYFT_OBJECT_VERSION_4,
SYFT_OBJECT_VERSION_5,
SYFT_OBJECT_VERSION_6,
]

HIGHEST_SYFT_OBJECT_VERSION = max(supported_object_versions)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/local/twin_api_sync_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute(query):

# verify that ds cannot access private job
assert client_low_ds.api.services.job.get(private_job_id) is None
assert low_client.api.services.job.get(private_job_id) is not None
assert low_client.api.services.job.get(private_job_id) is None

# we only sync the mock function, we never sync the private function to the low side
mock_res = low_client.api.services.testapi.query.mock()
Expand Down

0 comments on commit 9b21714

Please sign in to comment.