Skip to content

Commit

Permalink
[syft/network] move logic of comparing routes into their __eq__
Browse files Browse the repository at this point in the history
  • Loading branch information
khoaguin committed Apr 9, 2024
1 parent 4f68061 commit 0bc9757
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 86 deletions.
76 changes: 2 additions & 74 deletions packages/syft/src/syft/service/network/node_peer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# stdlib
from collections.abc import Callable

# third party
from result import Err
Expand Down Expand Up @@ -67,10 +66,8 @@ def existed_route(
if route:
if not isinstance(route, HTTPNodeRoute | PythonNodeRoute | VeilidNodeRoute):
raise ValueError(f"Unsupported route type: {type(route)}")

same_route: Callable = _route_type_to_same_route_check(route)
for i, r in enumerate(self.node_routes):
if same_route(route, r):
if route == r:
return (True, i)

elif route_id:
Expand Down Expand Up @@ -256,79 +253,10 @@ def delete_route(

if route:
try:
same_route: Callable = _route_type_to_same_route_check(route)
self.node_routes = [
r for r in self.node_routes if not same_route(r, route)
]
self.node_routes = [r for r in self.node_routes if r != route]
except Exception as e:
return SyftError(
message=f"Error deleting route with id {route.id}. Exception: {e}"
)

return None


def _route_type_to_same_route_check(
route: NodeRouteType,
) -> Callable[[NodeRouteType, NodeRouteType], bool]:
"""
Takes a route as input and returns a function that can be
used to compare if the two routes are the same.
Args:
route (NodeRouteType): The route for which to get a comparison function.
Returns:
Callable[[NodeRouteType, NodeRouteType], bool]: A function that takes two routes as input and returns a boolean
indicating whether the routes are the same.
"""
route_type_to_comparison_method: dict[
type[NodeRouteType], Callable[[NodeRouteType, NodeRouteType], bool]
] = {
HTTPNodeRoute: _same_http_route,
PythonNodeRoute: _same_python_route,
VeilidNodeRoute: _same_veilid_route,
}
return route_type_to_comparison_method[type(route)]


def _same_http_route(route: HTTPNodeRoute, other: HTTPNodeRoute) -> bool:
"""
Check if two HTTPNodeRoute are the same based on protocol, host_or_ip (url) and port
"""
if type(route) != type(other):
return False
return (
(route.host_or_ip == other.host_or_ip)
and (route.port == other.port)
and (route.protocol == other.protocol)
)


def _same_python_route(route: PythonNodeRoute, other: PythonNodeRoute) -> bool:
"""
Check if two PythonNodeRoute are the same based on the metatdata of their worker settings (name, id...)
"""
if type(route) != type(other):
return False
return (
(route.worker_settings.id == other.worker_settings.id)
and (route.worker_settings.name == other.worker_settings.name)
and (route.worker_settings.node_type == other.worker_settings.node_type)
and (
route.worker_settings.node_side_type == other.worker_settings.node_side_type
)
and (route.worker_settings.signing_key == other.worker_settings.signing_key)
)


def _same_veilid_route(route: VeilidNodeRoute, other: VeilidNodeRoute) -> bool:
"""
Check if two VeilidNodeRoute are the same based on their veilid keys and proxy_target_uid
"""
if type(route) != type(other):
return False
return (
route.vld_key == other.vld_key
and route.proxy_target_uid == other.proxy_target_uid
)
35 changes: 26 additions & 9 deletions packages/syft/src/syft/service/network/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ class HTTPNodeRoute(SyftObject, NodeRoute):
priority: int = 1

def __eq__(self, other: Any) -> bool:
if isinstance(other, HTTPNodeRoute):
return hash(self) == hash(other)
return self == other
if not isinstance(other, HTTPNodeRoute):
return False
return (
(self.host_or_ip == other.host_or_ip)
and (self.port == other.port)
and (self.protocol == other.protocol)
and (self.proxy_target_uid == other.proxy_target_uid)
)

def __hash__(self) -> int:
return hash(self.host_or_ip) + hash(self.port) + hash(self.protocol)
Expand All @@ -119,9 +124,12 @@ class VeilidNodeRoute(SyftObject, NodeRoute):
priority: int = 1

def __eq__(self, other: Any) -> bool:
if isinstance(other, VeilidNodeRoute):
return hash(self) == hash(other)
return self == other
if not isinstance(other, VeilidNodeRoute):
return False
return (
self.vld_key == other.vld_key
and self.proxy_target_uid == other.proxy_target_uid
)

def __hash__(self) -> int:
return hash(self.vld_key)
Expand Down Expand Up @@ -159,9 +167,18 @@ def with_node(cls, node: AbstractNode) -> Self:
return cls(id=worker_settings.id, worker_settings=worker_settings)

def __eq__(self, other: Any) -> bool:
if isinstance(other, PythonNodeRoute):
return hash(self) == hash(other)
return self == other
if not isinstance(other, PythonNodeRoute):
return False
return (
(self.worker_settings.id == other.worker_settings.id)
and (self.worker_settings.name == other.worker_settings.name)
and (self.worker_settings.node_type == other.worker_settings.node_type)
and (
self.worker_settings.node_side_type
== other.worker_settings.node_side_type
)
and (self.worker_settings.signing_key == other.worker_settings.signing_key)
)

def __hash__(self) -> int:
return hash(self.worker_settings.id)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/network/gateway_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,14 @@ def test_delete_route_on_peer(

# gateway delete the routes for the domain
res = gateway_client.api.services.network.delete_route_on_peer(
peer_verify_key=domain_peer.verify_key, route_id=new_route.id
peer=domain_peer, route_id=new_route.id
)
assert isinstance(res, SyftSuccess)
gateway_peer = domain_client.peers[0]
assert len(gateway_peer.node_routes) == 2

res = gateway_client.api.services.network.delete_route_on_peer(
peer_verify_key=domain_peer.verify_key, route=new_route2
peer=domain_peer, route=new_route2
)
assert isinstance(res, SyftSuccess)
gateway_peer = domain_client.peers[0]
Expand All @@ -582,7 +582,7 @@ def test_delete_route_on_peer(
# gateway deletes the last the route to it for the domain
last_route: NodeRouteType = gateway_peer.node_routes[0]
res = gateway_client.api.services.network.delete_route_on_peer(
peer_verify_key=domain_peer.verify_key, route=last_route
peer=domain_peer, route=last_route
)
assert isinstance(res, SyftSuccess)
assert "There is no routes left" in res.message
Expand Down

0 comments on commit 0bc9757

Please sign in to comment.