Skip to content

Commit

Permalink
Merge pull request #8970 from OpenMined/aziz/cache
Browse files Browse the repository at this point in the history
cache Request.code and DataProtocol
  • Loading branch information
abyesilyurt authored Jun 25, 2024
2 parents a2bd3e2 + 3ac23c0 commit ce9a902
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 17 deletions.
24 changes: 24 additions & 0 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

# third party
from argon2 import PasswordHasher
from cachetools import TTLCache
from cachetools import cached
from pydantic import field_validator
import requests
from requests import Response
Expand Down Expand Up @@ -200,6 +202,8 @@ def session(self) -> Session:
return self.session_cache

def _make_get(self, path: str, params: dict | None = None) -> bytes:
if params is None:
return self._make_get_no_params(path)
url = self.url.with_path(path)
response = self.session.get(
str(url),
Expand All @@ -218,6 +222,26 @@ def _make_get(self, path: str, params: dict | None = None) -> bytes:

return response.content

@cached(cache=TTLCache(maxsize=128, ttl=300))
def _make_get_no_params(self, path: str) -> bytes:
print(path)
url = self.url.with_path(path)
response = self.session.get(
str(url),
headers=self.headers,
verify=verify_tls(),
proxies={},
)
if response.status_code != 200:
raise requests.ConnectionError(
f"Failed to fetch {url}. Response returned with code {response.status_code}"
)

# upgrade to tls if available
self.url = upgrade_tls(self.url, response)

return response.content

def _make_post(
self,
path: str,
Expand Down
11 changes: 10 additions & 1 deletion packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterable
from collections.abc import MutableMapping
from collections.abc import MutableSequence
from functools import cache
import hashlib
import json
from operator import itemgetter
Expand Down Expand Up @@ -529,12 +530,20 @@ def reset_dev_protocol(self) -> None:


def get_data_protocol(raise_exception: bool = False) -> DataProtocol:
return DataProtocol(
return _get_data_protocol(
filename=data_protocol_file_name(),
raise_exception=raise_exception,
)


@cache
def _get_data_protocol(filename: str, raise_exception: bool = False) -> DataProtocol:
return DataProtocol(
filename=filename,
raise_exception=raise_exception,
)


def stage_protocol_changes() -> Result[SyftSuccess, SyftError]:
data_protocol = get_data_protocol(raise_exception=True)
return data_protocol.stage_protocol_changes()
Expand Down
13 changes: 10 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ class UserCode(SyncableSyftObject):
origin_node_side_type: NodeSideType
l0_deny_reason: str | None = None

_has_output_read_permissions_cache: bool | None = None

__table_coll_widths__ = [
"min-content",
"auto",
Expand Down Expand Up @@ -439,9 +441,14 @@ def _compute_status_l0(
if isinstance(api, SyftError):
return api
node_identity = NodeIdentity.from_api(api)
is_approved = api.output.has_output_read_permissions(
self.id, self.user_verify_key
)

if self._has_output_read_permissions_cache is None:
is_approved = api.output.has_output_read_permissions(
self.id, self.user_verify_key
)
self._has_output_read_permissions_cache = is_approved
else:
is_approved = self._has_output_read_permissions_cache
else:
# Serverside
node_identity = NodeIdentity.from_node(context.node)
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,8 @@ class UserCodeStatusChange(Change):

@property
def code(self) -> UserCode:
if self.linked_user_code._resolve_cache:
return self.linked_user_code._resolve_cache
return self.linked_user_code.resolve

def get_user_code(self, context: AuthedServiceContext) -> UserCode:
Expand Down
10 changes: 6 additions & 4 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,11 +566,11 @@ class ObjectDiffBatch(SyftObject):
root_diff: ObjectDiff
sync_direction: SyncDirection | None

def resolve(self) -> "ResolveWidget":
def resolve(self, build_state: bool = True) -> "ResolveWidget":
# relative
from .resolve_widget import ResolveWidget

return ResolveWidget(self)
return ResolveWidget(self, build_state=build_state)

def walk_graph(
self,
Expand Down Expand Up @@ -1142,14 +1142,16 @@ class NodeDiff(SyftObject):

include_ignored: bool = False

def resolve(self) -> "PaginatedResolveWidget | SyftSuccess":
def resolve(
self, build_state: bool = True
) -> "PaginatedResolveWidget | SyftSuccess":
if len(self.batches) == 0:
return SyftSuccess(message="No batches to resolve")

# relative
from .resolve_widget import PaginatedResolveWidget

return PaginatedResolveWidget(batches=self.batches)
return PaginatedResolveWidget(batches=self.batches, build_state=build_state)

def __getitem__(self, idx: Any) -> ObjectDiffBatch:
return self.batches[idx]
Expand Down
35 changes: 28 additions & 7 deletions packages/syft/src/syft/service/sync/resolve_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,19 @@ def __init__(
direction: SyncDirection,
with_box: bool = True,
show_share_warning: bool = False,
build_state: bool = True,
):
self.low_properties = diff.repr_attr_dict("low")
self.high_properties = diff.repr_attr_dict("high")
self.statuses = diff.repr_attr_diffstatus_dict()
build_state = build_state

if build_state:
self.low_properties = diff.repr_attr_dict("low")
self.high_properties = diff.repr_attr_dict("high")
self.statuses = diff.repr_attr_diffstatus_dict()
else:
self.low_properties = {}
self.high_properties = {}
self.statuses = {}

self.direction = direction
self.diff: ObjectDiff = diff
self.with_box = with_box
Expand Down Expand Up @@ -203,9 +212,10 @@ def __init__(
self,
diff: ObjectDiff,
direction: SyncDirection,
build_state: bool = True,
):
self.direction = direction

self.build_state = build_state
self.share_private_data = False
self.diff: ObjectDiff = diff
self.sync: bool = False
Expand Down Expand Up @@ -275,6 +285,7 @@ def build(self) -> widgets.VBox:
self.direction,
with_box=False,
show_share_warning=self.show_share_button,
build_state=self.build_state,
).widget

accordion, share_private_checkbox, sync_checkbox = self.build_accordion(
Expand Down Expand Up @@ -411,8 +422,12 @@ def _on_share_private_data_change(self, change: Any) -> None:

class ResolveWidget:
def __init__(
self, obj_diff_batch: ObjectDiffBatch, on_sync_callback: Callable | None = None
self,
obj_diff_batch: ObjectDiffBatch,
on_sync_callback: Callable | None = None,
build_state: bool = True,
):
self.build_state = build_state
self.obj_diff_batch: ObjectDiffBatch = obj_diff_batch
self.id2widget: dict[
UID, CollapsableObjectDiffWidget | MainObjectDiffWidget
Expand Down Expand Up @@ -483,6 +498,7 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
CollapsableObjectDiffWidget(
diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
for diff in dependents
]
Expand All @@ -498,7 +514,9 @@ def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
]
widgets = [
CollapsableObjectDiffWidget(
diff, direction=self.obj_diff_batch.sync_direction
diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
for diff in other_roots
]
Expand All @@ -509,6 +527,7 @@ def main_object_diff_widget(self) -> MainObjectDiffWidget:
obj_diff_widget = MainObjectDiffWidget(
self.obj_diff_batch.root_diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
return obj_diff_widget

Expand Down Expand Up @@ -712,12 +731,14 @@ class PaginatedResolveWidget:
paginated by a PaginationControl widget.
"""

def __init__(self, batches: list[ObjectDiffBatch]):
def __init__(self, batches: list[ObjectDiffBatch], build_state: bool = True):
self.build_state = build_state
self.batches = batches
self.resolve_widgets: list[ResolveWidget] = [
ResolveWidget(
batch,
on_sync_callback=partial(self.on_click_sync, i),
build_state=build_state,
)
for i, batch in enumerate(self.batches)
]
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/store/linked_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class LinkedObject(SyftObject):
object_type: type[SyftObject]
object_uid: UID

_resolve_cache: SyftObject | None = None

__exclude_sync_diff_attrs__ = ["node_uid"]

def __str__(self) -> str:
Expand All @@ -46,7 +48,9 @@ def resolve(self) -> SyftObject:
if api is None:
raise ValueError(f"api is None. You must login to {self.node_uid}")

return api.services.notifications.resolve_object(self)
resolve: SyftObject = api.services.notifications.resolve_object(self)
self._resolve_cache = resolve
return resolve

def resolve_with_context(self, context: NodeServiceContext) -> Any:
if context.node is None:
Expand Down
8 changes: 7 additions & 1 deletion packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set
from functools import cache
from hashlib import sha256
import inspect
from inspect import Signature
Expand Down Expand Up @@ -229,6 +230,11 @@ def get_transform(
)


@cache
def cached_get_type_hints(cls: type) -> dict[str, Any]:
return typing.get_type_hints(cls)


class SyftMigrationRegistry:
__migration_version_registry__: dict[str, dict[int, str]] = {}
__migration_transform_registry__: dict[str, dict[str, Callable]] = {}
Expand Down Expand Up @@ -578,7 +584,7 @@ def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None:
return
# Validate and set private attributes
# https://github.com/pydantic/pydantic/issues/2105
annotations = typing.get_type_hints(self.__class__)
annotations = cached_get_type_hints(self.__class__)
for attr, decl in self.__private_attributes__.items():
value = kwargs.get(attr, decl.get_default())
var_annotation = annotations.get(attr)
Expand Down

0 comments on commit ce9a902

Please sign in to comment.