From 47bfb1eb2cbcd7d69403f2a587a806060d92d07b Mon Sep 17 00:00:00 2001 From: teo Date: Thu, 8 Feb 2024 16:33:39 +0200 Subject: [PATCH] remove nested_requests from UserCode --- .../src/syft/protocol/protocol_version.json | 2 +- .../syft/src/syft/service/code/user_code.py | 34 ++++++-- .../syft/service/code/user_code_service.py | 30 +++---- .../syft/service/request/request_service.py | 87 ++++++++++--------- 4 files changed, 88 insertions(+), 65 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index ee7a0cde1bb..2de7c8259cc 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1072,7 +1072,7 @@ "UserCode": { "3": { "version": 3, - "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", + "hash": "4b7909bd79f4bd44da79a76fceb781bba98f6c788fb38b7ffa0fe1ffd13d8b51", "action": "add" } }, diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index a449845115d..038e7b238e4 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -342,7 +342,7 @@ class UserCode(SyftObject): enclave_metadata: Optional[EnclaveMetadata] = None submit_time: Optional[DateTime] uses_domain = False # tracks if the code calls domain.something, variable is set during parsing - nested_requests: Dict[str, str] = {} + # nested_requests: Dict[str, str] = {} nested_codes: Optional[Dict[str, Tuple[LinkedObject, Dict]]] = {} worker_pool_name: Optional[str] @@ -1028,9 +1028,24 @@ def new_check_code(context: TransformContext) -> TransformContext: return context +# def locate_launch_jobs(context: TransformContext) -> TransformContext: +# # stdlib +# nested_requests = {} +# tree = ast.parse(context.output["raw_code"]) + +# # look for domain arg +# if "domain" in [arg.arg for arg in tree.body[0].args.args]: +# v = LaunchJobVisitor() +# v.visit(tree) +# nested_calls = v.nested_calls +# for call in nested_calls: +# nested_requests[call] = "latest" + +# context.output["nested_requests"] = nested_requests +# return context + def locate_launch_jobs(context: TransformContext) -> TransformContext: - # stdlib - nested_requests = {} + nested_codes = {} tree = ast.parse(context.output["raw_code"]) # look for domain arg @@ -1038,10 +1053,17 @@ def locate_launch_jobs(context: TransformContext) -> TransformContext: v = LaunchJobVisitor() v.visit(tree) nested_calls = v.nested_calls + user_code_service = context.node.get_service("usercodeService") for call in nested_calls: - nested_requests[call] = "latest" - - context.output["nested_requests"] = nested_requests + user_code = user_code_service.get_by_service_name(context, call) + if isinstance(user_code, SyftError): + raise Exception(user_code.message) + # TODO: Not great + print(user_code) + user_code_link = LinkedObject.from_obj(user_code[0], node_uid=context.node.id) + + nested_codes[call] = (user_code_link, user_code[0].nested_codes) + context.output["nested_codes"] = nested_codes return context diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index da409b1dac0..f45800fabf8 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -105,21 +105,21 @@ def get_by_service_name( return SyftError(message=str(result.err())) return result.ok() - def solve_nested_requests(self, context: AuthedServiceContext, code: UserCode): - nested_requests = code.nested_requests - nested_codes = {} - for service_func_name, version in nested_requests.items(): - codes = self.get_by_service_name( - context=context, service_func_name=service_func_name - ) - if isinstance(codes, SyftError): - return codes - if version == "latest": - nested_codes[service_func_name] = codes[-1] - else: - nested_codes[service_func_name] = codes[int(version)] - - return nested_codes + # def solve_nested_requests(self, context: AuthedServiceContext, code: UserCode): + # nested_requests = code.nested_requests + # nested_codes = {} + # for service_func_name, version in nested_requests.items(): + # codes = self.get_by_service_name( + # context=context, service_func_name=service_func_name + # ) + # if isinstance(codes, SyftError): + # return codes + # if version == "latest": + # nested_codes[service_func_name] = codes[-1] + # else: + # nested_codes[service_func_name] = codes[int(version)] + + # return nested_codes def _request_code_execution( self, diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index b9bfe8761e6..987b6f8e782 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -104,48 +104,48 @@ def submit( print("Failed to submit Request", e) raise e - def expand_node(self, context: AuthedServiceContext, code_obj: UserCode): - user_code_service = context.node.get_service("usercodeservice") - nested_requests = user_code_service.solve_nested_requests(context, code_obj) - - new_nested_requests = {} - for func_name, code in nested_requests.items(): - nested_dict = self.expand_node(context, code) - if isinstance(nested_dict, SyftError): - return nested_dict - code.nested_codes = nested_dict - res = user_code_service.stash.update(context.credentials, code) - if isinstance(res, Err): - return res - linked_obj = LinkedObject.from_obj(code, node_uid=context.node.id) - new_nested_requests[func_name] = (linked_obj, nested_dict) - - return new_nested_requests - - def resolve_nested_requests(self, context, request): - # TODO: change this if we have more UserCode Changes - if len(request.changes) != 1: - return request - - change = request.changes[0] - if isinstance(change, UserCodeStatusChange): - if change.nested_solved: - return request - code_obj = change.linked_obj.resolve_with_context(context=context).ok() - # recursively check what other UserCodes to approve - nested_requests: Dict[str : Tuple[LinkedObject, Dict]] = self.expand_node( - context, code_obj - ) - if isinstance(nested_requests, Err): - return SyftError(message=nested_requests.value) - change.nested_solved = True - code_obj.nested_codes = nested_requests - change.linked_obj.update_with_context(context=context, obj=code_obj) - - request.changes = [change] - new_request = self.save(context=context, request=request) - return new_request - return request + # def expand_node(self, context: AuthedServiceContext, code_obj: UserCode): + # user_code_service = context.node.get_service("usercodeservice") + # nested_requests = user_code_service.solve_nested_requests(context, code_obj) + + # new_nested_requests = {} + # for func_name, code in nested_requests.items(): + # nested_dict = self.expand_node(context, code) + # if isinstance(nested_dict, SyftError): + # return nested_dict + # code.nested_codes = nested_dict + # res = user_code_service.stash.update(context.credentials, code) + # if isinstance(res, Err): + # return res + # linked_obj = LinkedObject.from_obj(code, node_uid=context.node.id) + # new_nested_requests[func_name] = (linked_obj, nested_dict) + + # return new_nested_requests + + # def resolve_nested_requests(self, context, request): + # # TODO: change this if we have more UserCode Changes + # if len(request.changes) != 1: + # return request + + # change = request.changes[0] + # if isinstance(change, UserCodeStatusChange): + # if change.nested_solved: + # return request + # code_obj = change.linked_obj.resolve_with_context(context=context).ok() + # # recursively check what other UserCodes to approve + # nested_requests: Dict[str : Tuple[LinkedObject, Dict]] = self.expand_node( + # context, code_obj + # ) + # if isinstance(nested_requests, Err): + # return SyftError(message=nested_requests.value) + # change.nested_solved = True + # code_obj.nested_codes = nested_requests + # change.linked_obj.update_with_context(context=context, obj=code_obj) + + # request.changes = [change] + # new_request = self.save(context=context, request=request) + # return new_request + # return request @service_method(path="request.get_all", name="get_all") def get_all(self, context: AuthedServiceContext) -> Union[List[Request], SyftError]: @@ -153,7 +153,8 @@ def get_all(self, context: AuthedServiceContext) -> Union[List[Request], SyftErr if result.is_err(): return SyftError(message=str(result.err())) requests = result.ok() - return [self.resolve_nested_requests(context, request) for request in requests] + # return [self.resolve_nested_requests(context, request) for request in requests] + return requests @service_method(path="request.get_all_info", name="get_all_info") def get_all_info(