Skip to content

Commit

Permalink
move filters to compare_client method
Browse files Browse the repository at this point in the history
  • Loading branch information
eelcovdw committed May 22, 2024
1 parent 5573f1b commit fafa4b9
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 79 deletions.
28 changes: 25 additions & 3 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@


def compare_states(
from_state: SyncState, to_state: SyncState, include_ignored: bool = False
from_state: SyncState,
to_state: SyncState,
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: str | type | None = None,
) -> NodeDiff:
# NodeDiff
if (
Expand All @@ -42,11 +47,28 @@ def compare_states(
high_state=high_state,
direction=direction,
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
)


def compare_clients(low_client: SyftClient, high_client: SyftClient) -> NodeDiff:
return compare_states(low_client.get_sync_state(), high_client.get_sync_state())
def compare_clients(
from_client: SyftClient,
to_client: SyftClient,
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: type | None = None,
) -> NodeDiff:
return compare_states(
from_client.get_sync_state(),
to_client.get_sync_state(),
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
)


def get_user_input_for_resolve() -> SyncDecision:
Expand Down
147 changes: 71 additions & 76 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,10 +1054,10 @@ def from_batch(self, batch: ObjectDiffBatch) -> Any:
if isinstance(user, UserView):
return user.email
return None
elif self == FilterProperty.BATCH_TYPE:
return batch.root_diff.obj_type
elif self == FilterProperty.TYPE:
return batch.root_diff.obj_type.__name__.lower()
elif self == FilterProperty.STATUS:
return batch.status
return batch.status.lower()
elif self == FilterProperty.IGNORED:
return batch.is_ignored
else:
Expand All @@ -1069,7 +1069,7 @@ class NodeDiffFilter:
"""
Filter to apply to a NodeDiff object to determine if it should be included in a batch.
Tests for `property op value` , where
Checks for `property op value` , where
property: FilterProperty - property to filter on
value: Any - value to compare against
op: callable[[Any, Any], bool] - comparison operator. Default is `operator.eq`
Expand All @@ -1082,28 +1082,21 @@ class NodeDiffFilter:
op: Callable[[Any, Any], bool] = operator.eq

def __call__(self, batch: ObjectDiffBatch) -> bool:
filter_value = self.filter_value
if isinstance(filter_value, str):
filter_value = filter_value.lower()

try:
p = self.filter_property.from_batch(batch)
if self.op == operator.contains:
# Contains check has reversed arg order: check if p in self.filter_value
return p in self.filter_value
return p in filter_value
else:
return self.op(p, self.filter_value)
return self.op(p, filter_value)
except Exception as e:
# By default, exclude the batch if there is an error
logger.debug(f"Error filtering batch {batch} with {self}: {e}")
return True

def __hash__(self) -> int:
return hash(self.filter_property) + hash(self.filter_value) + hash(self.op)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, NodeDiffFilter):
return False
return (
self.filter_property == other.filter_property
and self.filter_value == other.filter_value
and self.op == other.op
)


class NodeDiff(SyftObject):
Expand Down Expand Up @@ -1160,6 +1153,8 @@ def from_sync_state(
direction: SyncDirection,
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: type | None = None,
_include_node_status: bool = False,
) -> "NodeDiff":
obj_uid_to_diff = {}
Expand Down Expand Up @@ -1212,31 +1207,31 @@ def from_sync_state(
previously_ignored_batches = low_state.ignored_batches
NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches)

filters = []
if not include_ignored:
filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
if not include_same:
filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))

batches = all_batches
for f in filters:
batches = [b for b in batches if f(b)]

return cls(
res = cls(
low_node_uid=low_state.node_uid,
high_node_uid=high_state.node_uid,
user_verify_key_low=low_state.syft_client_verify_key,
user_verify_key_high=high_state.syft_client_verify_key,
obj_uid_to_diff=obj_uid_to_diff,
obj_dependencies=obj_dependencies,
batches=batches,
batches=all_batches,
all_batches=all_batches,
low_state=low_state,
high_state=high_state,
direction=direction,
filters=filters,
filters=[],
)

res._filter(
user_email=filter_by_email,
obj_type=filter_by_type,
include_ignored=include_ignored,
include_same=include_same,
inplace=True,
)

return res

@staticmethod
def apply_previous_ignore_state(
batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int]
Expand Down Expand Up @@ -1414,65 +1409,65 @@ def hierarchies(
def is_same(self) -> bool:
return all(object_diff.status == "SAME" for object_diff in self.diffs)

def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self:
def _apply_filters(
self, filters: list[NodeDiffFilter], inplace: bool = True
) -> Self:
"""
Apply filters to the NodeDiff object and return a new NodeDiff object
"""
batches = self.all_batches
for filter in filters:
batches = [b for b in batches if filter(b)]
return NodeDiff(
low_node_uid=self.low_node_uid,
high_node_uid=self.high_node_uid,
user_verify_key_low=self.user_verify_key_low,
user_verify_key_high=self.user_verify_key_high,
obj_uid_to_diff=self.obj_uid_to_diff,
obj_dependencies=self.obj_dependencies,
batches=batches,
all_batches=self.all_batches,
low_state=self.low_state,
high_state=self.high_state,
direction=self.direction,
filters=filters,
)

def reset_filters(
if inplace:
self.filters = filters
self.batches = batches
return self
else:
return NodeDiff(
low_node_uid=self.low_node_uid,
high_node_uid=self.high_node_uid,
user_verify_key_low=self.user_verify_key_low,
user_verify_key_high=self.user_verify_key_high,
obj_uid_to_diff=self.obj_uid_to_diff,
obj_dependencies=self.obj_dependencies,
batches=batches,
all_batches=self.all_batches,
low_state=self.low_state,
high_state=self.high_state,
direction=self.direction,
filters=filters,
)

def _filter(
self,
user_email: str | None = None,
obj_type: str | type | None = None,
include_ignored: bool = False,
include_same: bool = False,
inplace: bool = True,
) -> Self:
filters = []
if not include_ignored:
filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
if not include_same:
filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))
return self._apply_filters(filters)

def filter(
self,
user: str | None = None,
obj_type: type | None = None,
) -> Self:
current_filters = self.filters
new_filters = []
if user is not None:
new_filters.append(NodeDiffFilter(FilterProperty.USER, user))
if user_email is not None:
new_filters.append(
NodeDiffFilter(FilterProperty.USER, user_email, operator.eq)
)
if obj_type is not None:
new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type))

if len(new_filters) == 0:
return self

new_filter_properties = {f.filter_property for f in new_filters}
# Only add filters that are not in the new filters
# - remove duplicate filters
# - overwrite filters with the same property but different value
# (example: cannot filter on 2 different users)
for current_filter in current_filters:
if current_filter.filter_property not in new_filter_properties:
new_filters.append(current_filter)
if isinstance(obj_type, type):
obj_type = obj_type.__name__
new_filters.append(
NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq)
)
if not include_ignored:
new_filters.append(
NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)
)
if not include_same:
new_filters.append(
NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)
)

return self._apply_filters(new_filters)
return self._apply_filters(new_filters, inplace=inplace)


class SyncInstruction(SyftObject):
Expand Down

0 comments on commit fafa4b9

Please sign in to comment.