From 0bc97577463526134695a52f3d2f72519cdc6f43 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Tue, 9 Apr 2024 17:21:40 +0700 Subject: [PATCH] [syft/network] move logic of comparing routes into their `__eq__` --- .../src/syft/service/network/node_peer.py | 76 +------------------ .../syft/src/syft/service/network/routes.py | 35 ++++++--- tests/integration/network/gateway_test.py | 6 +- 3 files changed, 31 insertions(+), 86 deletions(-) diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 15e30e4bca7..d78da3024c9 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -1,5 +1,4 @@ # stdlib -from collections.abc import Callable # third party from result import Err @@ -67,10 +66,8 @@ def existed_route( if route: if not isinstance(route, HTTPNodeRoute | PythonNodeRoute | VeilidNodeRoute): raise ValueError(f"Unsupported route type: {type(route)}") - - same_route: Callable = _route_type_to_same_route_check(route) for i, r in enumerate(self.node_routes): - if same_route(route, r): + if route == r: return (True, i) elif route_id: @@ -256,79 +253,10 @@ def delete_route( if route: try: - same_route: Callable = _route_type_to_same_route_check(route) - self.node_routes = [ - r for r in self.node_routes if not same_route(r, route) - ] + self.node_routes = [r for r in self.node_routes if r != route] except Exception as e: return SyftError( message=f"Error deleting route with id {route.id}. Exception: {e}" ) return None - - -def _route_type_to_same_route_check( - route: NodeRouteType, -) -> Callable[[NodeRouteType, NodeRouteType], bool]: - """ - Takes a route as input and returns a function that can be - used to compare if the two routes are the same. - - Args: - route (NodeRouteType): The route for which to get a comparison function. - - Returns: - Callable[[NodeRouteType, NodeRouteType], bool]: A function that takes two routes as input and returns a boolean - indicating whether the routes are the same. - """ - route_type_to_comparison_method: dict[ - type[NodeRouteType], Callable[[NodeRouteType, NodeRouteType], bool] - ] = { - HTTPNodeRoute: _same_http_route, - PythonNodeRoute: _same_python_route, - VeilidNodeRoute: _same_veilid_route, - } - return route_type_to_comparison_method[type(route)] - - -def _same_http_route(route: HTTPNodeRoute, other: HTTPNodeRoute) -> bool: - """ - Check if two HTTPNodeRoute are the same based on protocol, host_or_ip (url) and port - """ - if type(route) != type(other): - return False - return ( - (route.host_or_ip == other.host_or_ip) - and (route.port == other.port) - and (route.protocol == other.protocol) - ) - - -def _same_python_route(route: PythonNodeRoute, other: PythonNodeRoute) -> bool: - """ - Check if two PythonNodeRoute are the same based on the metatdata of their worker settings (name, id...) - """ - if type(route) != type(other): - return False - return ( - (route.worker_settings.id == other.worker_settings.id) - and (route.worker_settings.name == other.worker_settings.name) - and (route.worker_settings.node_type == other.worker_settings.node_type) - and ( - route.worker_settings.node_side_type == other.worker_settings.node_side_type - ) - and (route.worker_settings.signing_key == other.worker_settings.signing_key) - ) - - -def _same_veilid_route(route: VeilidNodeRoute, other: VeilidNodeRoute) -> bool: - """ - Check if two VeilidNodeRoute are the same based on their veilid keys and proxy_target_uid - """ - if type(route) != type(other): - return False - return ( - route.vld_key == other.vld_key - and route.proxy_target_uid == other.proxy_target_uid - ) diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 635b0ac6b73..fc02a75de34 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -98,9 +98,14 @@ class HTTPNodeRoute(SyftObject, NodeRoute): priority: int = 1 def __eq__(self, other: Any) -> bool: - if isinstance(other, HTTPNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, HTTPNodeRoute): + return False + return ( + (self.host_or_ip == other.host_or_ip) + and (self.port == other.port) + and (self.protocol == other.protocol) + and (self.proxy_target_uid == other.proxy_target_uid) + ) def __hash__(self) -> int: return hash(self.host_or_ip) + hash(self.port) + hash(self.protocol) @@ -119,9 +124,12 @@ class VeilidNodeRoute(SyftObject, NodeRoute): priority: int = 1 def __eq__(self, other: Any) -> bool: - if isinstance(other, VeilidNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, VeilidNodeRoute): + return False + return ( + self.vld_key == other.vld_key + and self.proxy_target_uid == other.proxy_target_uid + ) def __hash__(self) -> int: return hash(self.vld_key) @@ -159,9 +167,18 @@ def with_node(cls, node: AbstractNode) -> Self: return cls(id=worker_settings.id, worker_settings=worker_settings) def __eq__(self, other: Any) -> bool: - if isinstance(other, PythonNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, PythonNodeRoute): + return False + return ( + (self.worker_settings.id == other.worker_settings.id) + and (self.worker_settings.name == other.worker_settings.name) + and (self.worker_settings.node_type == other.worker_settings.node_type) + and ( + self.worker_settings.node_side_type + == other.worker_settings.node_side_type + ) + and (self.worker_settings.signing_key == other.worker_settings.signing_key) + ) def __hash__(self) -> int: return hash(self.worker_settings.id) diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index 7d9da0db9f5..44fe5477a45 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -566,14 +566,14 @@ def test_delete_route_on_peer( # gateway delete the routes for the domain res = gateway_client.api.services.network.delete_route_on_peer( - peer_verify_key=domain_peer.verify_key, route_id=new_route.id + peer=domain_peer, route_id=new_route.id ) assert isinstance(res, SyftSuccess) gateway_peer = domain_client.peers[0] assert len(gateway_peer.node_routes) == 2 res = gateway_client.api.services.network.delete_route_on_peer( - peer_verify_key=domain_peer.verify_key, route=new_route2 + peer=domain_peer, route=new_route2 ) assert isinstance(res, SyftSuccess) gateway_peer = domain_client.peers[0] @@ -582,7 +582,7 @@ def test_delete_route_on_peer( # gateway deletes the last the route to it for the domain last_route: NodeRouteType = gateway_peer.node_routes[0] res = gateway_client.api.services.network.delete_route_on_peer( - peer_verify_key=domain_peer.verify_key, route=last_route + peer=domain_peer, route=last_route ) assert isinstance(res, SyftSuccess) assert "There is no routes left" in res.message