From 03e52f07c96e09688cd6fb40341540119abfb7cb Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 25 Jun 2024 15:10:57 +0200 Subject: [PATCH] add build_state arg to resolve --- .../syft/src/syft/service/sync/diff_state.py | 10 +++--- .../src/syft/service/sync/resolve_widget.py | 35 +++++++++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 9778e98f200..d5f8eb60caf 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -566,11 +566,11 @@ class ObjectDiffBatch(SyftObject): root_diff: ObjectDiff sync_direction: SyncDirection | None - def resolve(self) -> "ResolveWidget": + def resolve(self, build_state: bool = True) -> "ResolveWidget": # relative from .resolve_widget import ResolveWidget - return ResolveWidget(self) + return ResolveWidget(self, build_state=build_state) def walk_graph( self, @@ -1142,14 +1142,16 @@ class NodeDiff(SyftObject): include_ignored: bool = False - def resolve(self) -> "PaginatedResolveWidget | SyftSuccess": + def resolve( + self, build_state: bool = True + ) -> "PaginatedResolveWidget | SyftSuccess": if len(self.batches) == 0: return SyftSuccess(message="No batches to resolve") # relative from .resolve_widget import PaginatedResolveWidget - return PaginatedResolveWidget(batches=self.batches) + return PaginatedResolveWidget(batches=self.batches, build_state=build_state) def __getitem__(self, idx: Any) -> ObjectDiffBatch: return self.batches[idx] diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 496fb7a65eb..4a868634df3 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -105,10 +105,19 @@ def __init__( direction: SyncDirection, with_box: bool = True, show_share_warning: bool = False, + build_state: bool = True, ): - self.low_properties = diff.repr_attr_dict("low") - self.high_properties = diff.repr_attr_dict("high") - self.statuses = diff.repr_attr_diffstatus_dict() + build_state = build_state + + if build_state: + self.low_properties = diff.repr_attr_dict("low") + self.high_properties = diff.repr_attr_dict("high") + self.statuses = diff.repr_attr_diffstatus_dict() + else: + self.low_properties = {} + self.high_properties = {} + self.statuses = {} + self.direction = direction self.diff: ObjectDiff = diff self.with_box = with_box @@ -203,9 +212,10 @@ def __init__( self, diff: ObjectDiff, direction: SyncDirection, + build_state: bool = True, ): self.direction = direction - + self.build_state = build_state self.share_private_data = False self.diff: ObjectDiff = diff self.sync: bool = False @@ -275,6 +285,7 @@ def build(self) -> widgets.VBox: self.direction, with_box=False, show_share_warning=self.show_share_button, + build_state=self.build_state, ).widget accordion, share_private_checkbox, sync_checkbox = self.build_accordion( @@ -411,8 +422,12 @@ def _on_share_private_data_change(self, change: Any) -> None: class ResolveWidget: def __init__( - self, obj_diff_batch: ObjectDiffBatch, on_sync_callback: Callable | None = None + self, + obj_diff_batch: ObjectDiffBatch, + on_sync_callback: Callable | None = None, + build_state: bool = True, ): + self.build_state = build_state self.obj_diff_batch: ObjectDiffBatch = obj_diff_batch self.id2widget: dict[ UID, CollapsableObjectDiffWidget | MainObjectDiffWidget @@ -483,6 +498,7 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) for diff in dependents ] @@ -498,7 +514,9 @@ def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: ] widgets = [ CollapsableObjectDiffWidget( - diff, direction=self.obj_diff_batch.sync_direction + diff, + direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) for diff in other_roots ] @@ -509,6 +527,7 @@ def main_object_diff_widget(self) -> MainObjectDiffWidget: obj_diff_widget = MainObjectDiffWidget( self.obj_diff_batch.root_diff, direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) return obj_diff_widget @@ -712,12 +731,14 @@ class PaginatedResolveWidget: paginated by a PaginationControl widget. """ - def __init__(self, batches: list[ObjectDiffBatch]): + def __init__(self, batches: list[ObjectDiffBatch], build_state: bool = True): + self.build_state = build_state self.batches = batches self.resolve_widgets: list[ResolveWidget] = [ ResolveWidget( batch, on_sync_callback=partial(self.on_click_sync, i), + build_state=build_state, ) for i, batch in enumerate(self.batches) ]