diff --git a/aiocomfoconnect/discovery.py b/aiocomfoconnect/discovery.py index c16541b..e9f2a38 100644 --- a/aiocomfoconnect/discovery.py +++ b/aiocomfoconnect/discovery.py @@ -6,6 +6,8 @@ import logging from typing import Any, List, Union +import netifaces + from .bridge import Bridge from .protobuf import zehnder_pb2 @@ -15,7 +17,7 @@ class BridgeDiscoveryProtocol(asyncio.DatagramProtocol): """UDP Protocol for the ComfoConnect LAN C bridge discovery.""" - def __init__(self, target: str = None, timeout: int = 5): + def __init__(self, target: str | None = None, timeout: int = 5): loop = asyncio.get_running_loop() self._bridges: List[Bridge] = [] @@ -33,8 +35,17 @@ def connection_made(self, transport: asyncio.transports.DatagramTransport): _LOGGER.debug("Sending discovery request to %s:%d", self._target, Bridge.PORT) self.transport.sendto(b"\x0a\x00", (self._target, Bridge.PORT)) else: - _LOGGER.debug("Sending discovery request to broadcast:%d", Bridge.PORT) - self.transport.sendto(b"\x0a\x00", ("", Bridge.PORT)) + # Determine broadcast address programmatically + try: + gws = netifaces.gateways() + default_iface = gws['default'][netifaces.AF_INET][1] + addrs = netifaces.ifaddresses(default_iface) + broadcast_addr = addrs[netifaces.AF_INET][0].get('broadcast', '255.255.255.255') + except Exception as e: + _LOGGER.warning("Could not determine broadcast address, using 255.255.255.255: %s", e) + broadcast_addr = '255.255.255.255' + _LOGGER.debug("Sending discovery request to broadcast:%d (%s)", Bridge.PORT, broadcast_addr) + self.transport.sendto(b"\x0a\x00", (broadcast_addr, Bridge.PORT)) def datagram_received(self, data: Union[bytes, str], addr: tuple[str | Any, int]): """Called when some datagram is received.""" @@ -43,12 +54,15 @@ def datagram_received(self, data: Union[bytes, str], addr: tuple[str | Any, int] return _LOGGER.debug("Data received from %s: %s", addr, data) - - # Decode the response - parser = zehnder_pb2.DiscoveryOperation() # pylint: disable=no-member - parser.ParseFromString(data) - - self._bridges.append(Bridge(host=parser.searchGatewayResponse.ipaddress, uuid=parser.searchGatewayResponse.uuid.hex())) + try: + # Decode the response + parser = zehnder_pb2.DiscoveryOperation() # pylint: disable=no-member + parser.ParseFromString(data) + + self._bridges.append(Bridge(host=parser.searchGatewayResponse.ipaddress, uuid=parser.searchGatewayResponse.uuid.hex())) + except Exception as exc: + _LOGGER.error("Failed to parse discovery response from %s: %s", addr, exc) + return # When we have passed a target, we only want to listen for that one if self._target: @@ -66,8 +80,30 @@ def get_bridges(self): return self._future -async def discover_bridges(host: str = None, timeout: int = 1, loop=None) -> List[Bridge]: - """Discover a bridge by IP.""" +async def discover_bridges(host: str | None = None, timeout: int = 1, loop=None) -> List[Bridge]: + """ + Discover ComfoConnect bridges on the local network or at a specified host. + + This asynchronous function sends a UDP broadcast (or unicast if a host is specified) + to discover available ComfoConnect bridges. It returns a list of discovered Bridge + instances. + + Args: + host (str | None): The IP address of a specific bridge to discover. If None, + a broadcast is sent to discover all available bridges. Defaults to None. + timeout (int): The time in seconds to wait for responses. Defaults to 1. + loop (asyncio.AbstractEventLoop, optional): The event loop to use. If None, + the default event loop is used. + + Returns: + List[Bridge]: A list of discovered Bridge objects. + + Raises: + Any exceptions raised by the underlying asyncio transport or protocol. + + Example: + bridges = await discover_bridges(timeout=2) + """ if loop is None: loop = asyncio.get_event_loop() diff --git a/pyproject.toml b/pyproject.toml index d741247..c442d15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ homepage = "https://github.com/michaelarnauts/aiocomfoconnect" python = "^3.10" aiohttp = "^3.8.0" protobuf = "^6.31" +netifaces = "^0.11.0" [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.73.0"