Skip to content

Commit

Permalink
Merge pull request #8349 from OpenMined/hotfix/hb_897
Browse files Browse the repository at this point in the history
ADD ActionDataLink/ Update action object
  • Loading branch information
koenvanderveen authored Dec 20, 2023
2 parents a8fca78 + 6d3a25e commit 8e6d948
Show file tree
Hide file tree
Showing 8 changed files with 1,002 additions and 187 deletions.
954 changes: 778 additions & 176 deletions notebooks/helm/helm-syft.ipynb

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,11 +1137,26 @@ def add_queueitem_to_queue(
self, queue_item, credentials, action=None, parent_job_id=None
):
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(node=self, credentials=credentials, role=role)

result_obj = ActionObject.empty()
if action is not None:
result_obj = ActionObject.obj_not_ready(id=action.result_id)
result_obj.id = action.result_id
result_obj.syft_resolved = False
result_obj.syft_node_location = self.id
result_obj.syft_client_verify_key = credentials

action_service = self.get_service("actionservice")

if not action_service.store.exists(uid=action.result_id):
result = action_service.set_result_to_store(
result_action_object=result_obj,
context=context,
)
if result.is_err():
return result.err()

job = Job(
id=queue_item.job_id,
Expand All @@ -1159,8 +1174,7 @@ def add_queueitem_to_queue(
self.job_stash.set(credentials, job)

log_service = self.get_service("logservice")
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(node=self, credentials=credentials, role=role)

result = log_service.add(context, log_id)
if isinstance(result, SyftError):
return result
Expand Down
14 changes: 14 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,20 @@
"action": "add"
}
},
"ActionDataLink": {
"1": {
"version": 1,
"hash": "10bf94e99637695f1ba283f0b10e70743a4ebcb9ee75aefb1a05e6d6e1d21a71",
"action": "add"
}
},
"ObjectNotReady": {
"1": {
"version": 1,
"hash": "88207988639b11eaca686b6e079616d9caecc3dbc2a8112258e0f39ee5c3e113",
"action": "add"
}
},
"ContainerImage": {
"1": {
"version": 1,
Expand Down
17 changes: 17 additions & 0 deletions packages/syft/src/syft/service/action/action_data_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ...serde.serializable import serializable
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.uid import UID

NoneType = type(None)

Expand All @@ -32,6 +33,22 @@ def __str__(self) -> str:
return f"{type(self).__name__} <{self.syft_internal_type}>"


@serializable()
class ObjectNotReady(SyftObject):
    # Placeholder payload for an action object whose real data has not been
    # produced yet; `ActionObject.obj_not_ready` wraps one of these.
    __canonical_name__ = "ObjectNotReady"
    __version__ = SYFT_OBJECT_VERSION_1

    # Id of the object that is not ready yet.
    obj_id: UID


@serializable()
class ActionDataLink(SyftObject):
    # Payload that stands in for another action object: instead of carrying
    # data it carries the id of the action object it links to.
    __canonical_name__ = "ActionDataLink"
    __version__ = SYFT_OBJECT_VERSION_1

    # Id of the action object this link points at.
    action_object_id: UID


@serializable()
class ActionFileData(SyftObject):
__canonical_name__ = "ActionFileData"
Expand Down
50 changes: 50 additions & 0 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
from io import BytesIO
from pathlib import Path
import time
import traceback
import types
from typing import Any
Expand Down Expand Up @@ -49,7 +50,9 @@
from ..response import SyftException
from ..service import from_api_or_context
from .action_data_empty import ActionDataEmpty
from .action_data_empty import ActionDataLink
from .action_data_empty import ActionFileData
from .action_data_empty import ObjectNotReady
from .action_permissions import ActionPermission
from .action_types import action_type_for_object
from .action_types import action_type_for_type
Expand Down Expand Up @@ -551,6 +554,8 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]:
"_repr_debug_",
"as_empty",
"get",
"is_link",
"wait",
"_save_to_blob_storage",
"_save_to_blob_storage_",
"syft_action_data",
Expand Down Expand Up @@ -1207,6 +1212,47 @@ def remove_trace_hook(cls):
def as_empty_data(self) -> ActionDataEmpty:
return ActionDataEmpty(syft_internal_type=self.syft_internal_type)

def wait(self):
    """Block until the node reports this action object as resolved.

    Looks up the API client registered for this object's node location and
    client verify key, then polls ``action.is_resolved`` for this object's
    id, sleeping one second between polls.

    Returns:
        This same object, once the polling loop has ended.
    """
    # relative
    from ...client.api import APIRegistry

    client_api = APIRegistry.api_for(
        node_uid=self.syft_node_location,
        user_verify_key=self.syft_client_verify_key,
    )

    # Unwrap a LineageID to its inner ``.id`` before querying the service.
    target_id = self.id.id if isinstance(self.id, LineageID) else self.id

    while True:
        if client_api.services.action.is_resolved(target_id):
            break
        time.sleep(1)
    return self

@staticmethod
def link(
    result_id: UID,
    pointer_id: Optional[UID] = None,
) -> ActionObject:
    """Create an action object whose payload links to another action object.

    Args:
        result_id: Id given to the newly created action object.
        pointer_id: Id of the action object being linked to.

    Returns:
        The ``ActionObject`` produced by ``ActionObject.from_obj`` with an
        ``ActionDataLink`` as its ``syft_action_data``.
    """
    return ActionObject.from_obj(
        id=result_id,
        syft_action_data=ActionDataLink(action_object_id=pointer_id),
    )

@staticmethod
def obj_not_ready(
    id: UID,
) -> ActionObject:
    """Create a placeholder action object for a result that is not ready.

    Args:
        id: Id used both for the new action object and for the
            ``ObjectNotReady`` payload it carries.

    Returns:
        The ``ActionObject`` produced by ``ActionObject.from_obj`` with an
        ``ObjectNotReady`` as its ``syft_action_data``.
    """
    return ActionObject.from_obj(
        id=id,
        syft_action_data=ObjectNotReady(obj_id=id),
    )

@staticmethod
def empty(
syft_internal_type: Type[Any] = NoneType,
Expand Down Expand Up @@ -1619,6 +1665,10 @@ def __getattribute__(self, name: str) -> Any:
res = self._syft_wrap_attribute_for_methods(name)
return res

@property
def is_link(self) -> bool:
    """Whether this object's action data is an ``ActionDataLink``."""
    payload = self.syft_action_data
    return isinstance(payload, ActionDataLink)

def __setattr__(self, name: str, value: Any) -> Any:
defined_on_self = name in self.__dict__ or name in self.__private_attributes__

Expand Down
111 changes: 105 additions & 6 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,97 @@ def set(
return Ok(action_object)
return result.err()

@service_method(
    path="action.is_resolved", name="is_resolved", roles=GUEST_ROLE_LEVEL
)
def is_resolved(
    self,
    context: AuthedServiceContext,
    uid: UID,
) -> Result[Ok[bool], Err[str]]:
    """Report whether the action object stored under ``uid`` is resolved.

    Args:
        context: Authenticated context used for the lookups.
        uid: Id of the action object to check.

    Returns:
        ``Ok(True)`` or ``Ok(False)`` when the object could be fetched. For a
        link, the value is the ``syft_resolved`` flag of the object the link
        chain ends at. If fetching the object or following its link fails,
        that failing result is returned unchanged.
    """
    # relative
    from .action_data_empty import ActionDataLink

    result = self._get(context, uid)
    if not result.is_ok():
        # Not in the store, or a permission error: hand the error back.
        return result

    obj = result.ok()
    if isinstance(obj.syft_action_data, ActionDataLink):
        linked = self.resolve_links(
            context, obj.syft_action_data.action_object_id.id
        )

        # Checking in case any error occurred while following the link
        if linked.is_err():
            return linked

        # `linked` is the Result wrapper; the flag lives on the wrapped
        # action object, so unwrap before reading it.
        return Ok(linked.ok().syft_resolved)

    # A leaf that is not resolved yet
    if not obj.syft_resolved:
        return Ok(False)

    # Neither a link nor unresolved: it is resolved
    return Ok(True)

@service_method(
    path="action.resolve_links", name="resolve_links", roles=GUEST_ROLE_LEVEL
)
def resolve_links(
    self,
    context: AuthedServiceContext,
    uid: UID,
    twin_mode: TwinMode = TwinMode.PRIVATE,
) -> Result[Ok[ActionObject], Err[str]]:
    """Follow a chain of ``ActionDataLink`` objects down to its last object.

    Args:
        context: Authenticated context; its credentials are used for every
            store lookup.
        uid: Id of the action object to start from.
        twin_mode: Passed along unchanged to each recursive call.

    Returns:
        The store result for the first object in the chain whose action data
        is not an ``ActionDataLink``, or the first failing store result.
    """
    # relative
    from .action_data_empty import ActionDataLink

    lookup = self.store.get(uid=uid, credentials=context.credentials)

    # Missing object or no permission: surface the store's result as-is.
    if not lookup.is_ok():
        return lookup

    payload = lookup.ok().syft_action_data
    if not isinstance(payload, ActionDataLink):
        # Reached a leaf
        return lookup

    # Still on a link: continue with the object it points at.
    return self.resolve_links(context, payload.action_object_id.id, twin_mode)

@service_method(path="action.get", name="get", roles=GUEST_ROLE_LEVEL)
def get(
    self,
    context: AuthedServiceContext,
    uid: UID,
    twin_mode: TwinMode = TwinMode.PRIVATE,
    resolve_nested: bool = True,
) -> Result[Ok[ActionObject], Err[str]]:
    """Get an object from the action store.

    Thin service-level wrapper that delegates to ``_get``.

    Args:
        context: Authenticated context of the caller.
        uid: Id of the action object to fetch.
        twin_mode: Forwarded to ``_get``.
        resolve_nested: Forwarded to ``_get``; callers pass ``False`` to get
            a link object itself rather than what it points at.

    Returns:
        Whatever ``_get`` returns for these arguments.
    """
    # Forward `resolve_nested` explicitly; a bare
    # `self._get(context, uid, twin_mode)` would silently ignore it.
    return self._get(context, uid, twin_mode, resolve_nested=resolve_nested)

def _get(
self,
context: AuthedServiceContext,
uid: UID,
twin_mode: TwinMode = TwinMode.PRIVATE,
has_permission=False,
resolve_nested: bool = True,
) -> Result[ActionObject, str]:
"""Get an object from the action store"""
# stdlib

# relative
from .action_data_empty import ActionDataLink

result = self.store.get(
uid=uid, credentials=context.credentials, has_permission=has_permission
)
Expand All @@ -133,6 +206,20 @@ def _get(
context.node.id,
context.credentials,
)
# Resolve graph links
if (
not isinstance(obj, TwinObject)
and resolve_nested
and isinstance(obj.syft_action_data, ActionDataLink)
):
if not self.is_resolved(
context, obj.syft_action_data.action_object_id.id
).ok():
return SyftError(message="This object is not resolved yet.")
result = self.resolve_links(
context, obj.syft_action_data.action_object_id.id, twin_mode
)
return result
if isinstance(obj, TwinObject):
if twin_mode == TwinMode.PRIVATE:
obj = obj.private
Expand Down Expand Up @@ -234,7 +321,12 @@ def _user_code_execute(
real_kwargs, twin_mode=TwinMode.NONE
)
exec_result = execute_byte_code(code_item, filtered_kwargs, context)
result_action_object = wrap_result(result_id, exec_result.result)
if isinstance(exec_result.result, ActionObject):
result_action_object = ActionObject.link(
result_id=result_id, pointer_id=exec_result.result.id
)
else:
result_action_object = wrap_result(result_id, exec_result.result)
else:
# twins
private_kwargs = filter_twin_kwargs(
Expand Down Expand Up @@ -272,12 +364,19 @@ def _user_code_execute(
return Err(f"_user_code_execute failed. {e}")
return Ok(result_action_object)

def set_result_to_store(self, result_action_object, context, output_policy):
def set_result_to_store(self, result_action_object, context, output_policy=None):
result_id = result_action_object.id
# result_blob_id = result_action_object.syft_blob_storage_entry_id
output_readers = (
output_policy.output_readers if not context.has_execute_permissions else []
)

if output_policy is not None:
output_readers = (
output_policy.output_readers
if not context.has_execute_permissions
else []
)
else:
output_readers = []

read_permission = ActionPermission.READ

result_action_object._set_obj_location_(
Expand Down
19 changes: 19 additions & 0 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ...types.uid import UID
from ...util.markdown import as_markdown_code
from ...util.telemetry import instrument
from ..action.action_data_empty import ActionDataLink
from ..action.action_object import Action
from ..action.action_permissions import ActionObjectPermission
from ..response import SyftError
Expand Down Expand Up @@ -373,11 +374,29 @@ def wait(self):
# stdlib
from time import sleep

api = APIRegistry.api_for(
node_uid=self.node_uid,
user_verify_key=self.syft_client_verify_key,
)

# todo: timeout
if self.resolved:
return self.resolve

print_warning = True
while True:
self.fetch()
if print_warning:
result_obj = api.services.action.get(
self.result.id, resolve_nested=False
)
if isinstance(result_obj.syft_action_data, ActionDataLink):
print(
"You're trying to wait on a job that has a link as a result."
"This means that the job may be ready but the linked result may not."
"Use job.result.wait() instead to wait for the linked result."
)
print_warning = False
sleep(2)
if self.resolved:
break
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/tests/syft/syft_functions/syft_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def process_all(domain, x):
batch_job = domain.launch_job(process_batch, batch=elem)
job_results += [batch_job.result]

result = domain.launch_job(aggregate_job, job_results=job_results)
return result.wait().get()
job = domain.launch_job(aggregate_job, job_results=job_results)
return job.result

process_all.code = dedent(process_all.code)

Expand All @@ -93,4 +93,4 @@ def process_all(domain, x):

sub_results = [j.wait().get() for j in job.subjobs]
assert set(sub_results) == {2, 3, 5}
assert job.wait().get() == 5
assert job.result.wait().get() == 5

0 comments on commit 8e6d948

Please sign in to comment.