Skip to content

Commit

Permalink
Merge branch 'aziz/cache' of github.com:OpenMined/PySyft into aziz/cache
Browse files Browse the repository at this point in the history
  • Loading branch information
abyesilyurt committed Jun 25, 2024
2 parents 258afeb + 17d62db commit 9e71b68
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
10 changes: 6 additions & 4 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,11 +566,11 @@ class ObjectDiffBatch(SyftObject):
root_diff: ObjectDiff
sync_direction: SyncDirection | None

def resolve(self) -> "ResolveWidget":
def resolve(self, build_state: bool = True) -> "ResolveWidget":
# relative
from .resolve_widget import ResolveWidget

return ResolveWidget(self)
return ResolveWidget(self, build_state=build_state)

def walk_graph(
self,
Expand Down Expand Up @@ -1142,14 +1142,16 @@ class NodeDiff(SyftObject):

include_ignored: bool = False

def resolve(self) -> "PaginatedResolveWidget | SyftSuccess":
def resolve(
self, build_state: bool = True
) -> "PaginatedResolveWidget | SyftSuccess":
if len(self.batches) == 0:
return SyftSuccess(message="No batches to resolve")

# relative
from .resolve_widget import PaginatedResolveWidget

return PaginatedResolveWidget(batches=self.batches)
return PaginatedResolveWidget(batches=self.batches, build_state=build_state)

def __getitem__(self, idx: Any) -> ObjectDiffBatch:
return self.batches[idx]
Expand Down
35 changes: 28 additions & 7 deletions packages/syft/src/syft/service/sync/resolve_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,19 @@ def __init__(
direction: SyncDirection,
with_box: bool = True,
show_share_warning: bool = False,
build_state: bool = True,
):
self.low_properties = diff.repr_attr_dict("low")
self.high_properties = diff.repr_attr_dict("high")
self.statuses = diff.repr_attr_diffstatus_dict()
build_state = build_state

if build_state:
self.low_properties = diff.repr_attr_dict("low")
self.high_properties = diff.repr_attr_dict("high")
self.statuses = diff.repr_attr_diffstatus_dict()
else:
self.low_properties = {}
self.high_properties = {}
self.statuses = {}

self.direction = direction
self.diff: ObjectDiff = diff
self.with_box = with_box
Expand Down Expand Up @@ -203,9 +212,10 @@ def __init__(
self,
diff: ObjectDiff,
direction: SyncDirection,
build_state: bool = True,
):
self.direction = direction

self.build_state = build_state
self.share_private_data = False
self.diff: ObjectDiff = diff
self.sync: bool = False
Expand Down Expand Up @@ -275,6 +285,7 @@ def build(self) -> widgets.VBox:
self.direction,
with_box=False,
show_share_warning=self.show_share_button,
build_state=self.build_state,
).widget

accordion, share_private_checkbox, sync_checkbox = self.build_accordion(
Expand Down Expand Up @@ -411,8 +422,12 @@ def _on_share_private_data_change(self, change: Any) -> None:

class ResolveWidget:
def __init__(
self, obj_diff_batch: ObjectDiffBatch, on_sync_callback: Callable | None = None
self,
obj_diff_batch: ObjectDiffBatch,
on_sync_callback: Callable | None = None,
build_state: bool = True,
):
self.build_state = build_state
self.obj_diff_batch: ObjectDiffBatch = obj_diff_batch
self.id2widget: dict[
UID, CollapsableObjectDiffWidget | MainObjectDiffWidget
Expand Down Expand Up @@ -483,6 +498,7 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
CollapsableObjectDiffWidget(
diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
for diff in dependents
]
Expand All @@ -498,7 +514,9 @@ def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
]
widgets = [
CollapsableObjectDiffWidget(
diff, direction=self.obj_diff_batch.sync_direction
diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
for diff in other_roots
]
Expand All @@ -509,6 +527,7 @@ def main_object_diff_widget(self) -> MainObjectDiffWidget:
obj_diff_widget = MainObjectDiffWidget(
self.obj_diff_batch.root_diff,
direction=self.obj_diff_batch.sync_direction,
build_state=self.build_state,
)
return obj_diff_widget

Expand Down Expand Up @@ -712,12 +731,14 @@ class PaginatedResolveWidget:
paginated by a PaginationControl widget.
"""

def __init__(self, batches: list[ObjectDiffBatch]):
def __init__(self, batches: list[ObjectDiffBatch], build_state: bool = True):
self.build_state = build_state
self.batches = batches
self.resolve_widgets: list[ResolveWidget] = [
ResolveWidget(
batch,
on_sync_callback=partial(self.on_click_sync, i),
build_state=build_state,
)
for i, batch in enumerate(self.batches)
]
Expand Down

0 comments on commit 9e71b68

Please sign in to comment.