diff --git a/custom_components/hella_onyx/__init__.py b/custom_components/hella_onyx/__init__.py index 31227ad..e3a0e63 100644 --- a/custom_components/hella_onyx/__init__.py +++ b/custom_components/hella_onyx/__init__.py @@ -9,7 +9,6 @@ CONF_ACCESS_TOKEN, CONF_SCAN_INTERVAL, CONF_FORCE_UPDATE, - EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv @@ -23,7 +22,9 @@ ONYX_API, ONYX_COORDINATOR, ONYX_TIMEZONE, + ONYX_THREAD, ) +from .event_thread import EventThread _LOGGER = logging.getLogger(__name__) @@ -86,19 +87,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): request_refresh_debouncer=Debouncer(hass, _LOGGER, cooldown=0, immediate=True), ) - def updated_device(device): - onyx_api.updated_device(device) - coordinator.async_set_updated_data(device) - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, onyx_api.stop) - onyx_api.set_event_callback(updated_device) - onyx_api.start(force_update) + thread = EventThread(onyx_api, coordinator, force_update) hass.data[DOMAIN][entry.entry_id] = { ONYX_API: onyx_api, ONYX_TIMEZONE: onyx_timezone, ONYX_COORDINATOR: coordinator, + ONYX_THREAD: thread, } + thread.start() for platform in PLATFORMS: hass.async_create_task( diff --git a/custom_components/hella_onyx/api_connector.py b/custom_components/hella_onyx/api_connector.py index 09b6bb1..0b10c04 100644 --- a/custom_components/hella_onyx/api_connector.py +++ b/custom_components/hella_onyx/api_connector.py @@ -1,16 +1,11 @@ """API connector for the ONYX integration.""" import logging -import asyncio -from typing import Any - from homeassistant.helpers.aiohttp_client import async_get_clientsession from onyx_client.client import create from onyx_client.data.device_command import DeviceCommand from onyx_client.enum.action import Action -from custom_components.hella_onyx.const import MAX_BACKOFF_TIME - _LOGGER = logging.getLogger(__name__) @@ -24,17 +19,14 @@ def __init__(self, hass, fingerprint, token): self.token = token self.devices = {} self.groups = {} - self._loop = None self.__client = None def _client(self): if self.__client is None: - self._loop = asyncio.new_event_loop() self.__client = create( fingerprint=self.fingerprint, access_token=self.token, client_session=async_get_clientsession(self.hass), - event_loop=self._loop, ) return self.__client @@ -84,24 +76,10 @@ async def send_device_command_properties(self, uuid: str, properties: dict): if not success: raise CommandException("ONYX_ACTION_ERROR", uuid) - def start(self, include_details): - """Start the event loop.""" - _LOGGER.info("Starting ONYX") - asyncio.set_event_loop(self._loop) - self._loop.run_forever() - self._client().start(include_details, MAX_BACKOFF_TIME) - - def set_event_callback(self, callback): - """Set the event callback.""" - self._client().set_event_callback(callback) - - def stop(self, **kwargs: Any): - """Stop the event loop.""" - _LOGGER.info("Shutting down ONYX") - self._client().stop() - self._loop.stop() - self.__client = None - self._loop = None + async def events(self, force_update: bool = False): + """Listen for events.""" + async for device in self._client().events(force_update): + yield device class CommandException(Exception): diff --git a/custom_components/hella_onyx/const.py b/custom_components/hella_onyx/const.py index b2a1260..2688575 100644 --- a/custom_components/hella_onyx/const.py +++ b/custom_components/hella_onyx/const.py @@ -5,6 +5,7 @@ ONYX_API = "onyx_api" ONYX_TIMEZONE = "onyx_timezone" ONYX_COORDINATOR = "onyx_coordinator" +ONYX_THREAD = "onyx_thread" CONF_FINGERPRINT = "fingerprint" diff --git a/custom_components/hella_onyx/event_thread.py b/custom_components/hella_onyx/event_thread.py new file mode 100644 index 0000000..7aaa425 --- /dev/null +++ b/custom_components/hella_onyx/event_thread.py @@ -0,0 +1,56 @@ +"""ONYX API event thread.""" +import asyncio +import logging +import threading +from random import uniform + +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator + +from .api_connector import APIConnector +from .const import MAX_BACKOFF_TIME + +_LOGGER = logging.getLogger(__name__) + + +class EventThread(threading.Thread): + """The event thread for asynchronous updates.""" + + def __init__( + self, + api: APIConnector, + coordinator: DataUpdateCoordinator, + force_update: bool = False, + backoff=True, + ): + threading.Thread.__init__(self, name="HellaOnyx") + self._api = api + self._coordinator = coordinator + self._force_update = force_update + self._backoff = backoff + + async def _update(self): + """Listen for updates.""" + while True: + backoff = int(uniform(0, MAX_BACKOFF_TIME) * 60) + try: + async for device in self._api.events(self._force_update): + self._api.updated_device(device) + self._coordinator.async_set_updated_data(device) + except Exception as ex: + _LOGGER.error( + "connection reset: %s, restarting with backoff of %s seconds (%s)", + ex, + backoff, + self._backoff, + ) + if self._backoff: + await asyncio.sleep(backoff) + else: + break + + def run(self): + """Start the thread.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(self._update()) + loop.run_forever() diff --git a/tests/test_api_connector.py b/tests/test_api_connector.py index dcfb3e3..3559da4 100644 --- a/tests/test_api_connector.py +++ b/tests/test_api_connector.py @@ -20,7 +20,6 @@ CommandException, UnknownStateException, ) -from custom_components.hella_onyx.const import MAX_BACKOFF_TIME class TestAPIConnector: @@ -131,27 +130,21 @@ async def test__client(self, api): assert client is not None assert isinstance(client, OnyxClient) - @patch("asyncio.set_event_loop") @pytest.mark.asyncio - async def test_start(self, mock_set_event_loop, api, client): + async def test_events(self, api, client): with patch.object(api, "_client", new=client.make): - with patch.object(api, "_loop"): - api.start(True) - assert client.is_called - assert mock_set_event_loop.called - - @pytest.mark.asyncio - async def test_stop(self, api, client): - with patch.object(api, "_client", new=client.make): - with patch.object(api, "_loop"): - api.stop() - assert client.is_called + async for device in api.events(): + assert device is not None + assert client.is_called + assert not client.is_force_update @pytest.mark.asyncio - async def test_set_event_callback(self, api, client): + async def test_events_force_update(self, api, client): with patch.object(api, "_client", new=client.make): - api.set_event_callback(None) + async for device in api.events(True): + assert device is not None assert client.is_called + assert client.is_force_update class MockClient: @@ -201,22 +194,23 @@ async def device(self, uuid: str): list(Action), ) + async def events(self, force_update: bool): + self.called = True + self.force_update = force_update + yield Shutter( + "id", + "other", + DeviceType.RAFFSTORE_90, + DeviceMode(DeviceType.RAFFSTORE_90), + list(Action), + ) + async def send_command(self, uuid: str, command: DeviceCommand): self.called = True return command.action == Action.STOP or ( command.properties is not None and "fail" not in command.properties ) - def start(self, include_details, backoff_time): - self.called = True - assert backoff_time == MAX_BACKOFF_TIME - - def stop(self): - self.called = True - - def set_event_callback(self, callback): - self.called = True - async def date_information(self): self.called = True return self.date diff --git a/tests/test_event_thread.py b/tests/test_event_thread.py new file mode 100644 index 0000000..445711a --- /dev/null +++ b/tests/test_event_thread.py @@ -0,0 +1,139 @@ +"""Test for the EventThread.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from onyx_client.data.device_mode import DeviceMode +from onyx_client.data.numeric_value import NumericValue +from onyx_client.device.shutter import Shutter +from onyx_client.enum.action import Action +from onyx_client.enum.device_type import DeviceType + +from custom_components.hella_onyx import ( + EventThread, +) +from custom_components.hella_onyx.api_connector import ( + UnknownStateException, +) +from custom_components.hella_onyx.const import MAX_BACKOFF_TIME + + +class TestEventThread: + @pytest.fixture + def api(self): + yield MockAPI() + + @pytest.fixture + def coordinator(self): + yield AsyncMock() + + @pytest.fixture + def thread(self, api, coordinator): + yield EventThread(api, coordinator, force_update=False, backoff=False) + + @pytest.mark.asyncio + async def test_update(self, thread, api, coordinator): + api.called = False + await thread._update() + assert api.is_called + assert not api.is_force_update + assert coordinator.async_set_updated_data.called + + @pytest.mark.asyncio + async def test_update_force_update(self, thread, api, coordinator): + thread._force_update = True + api.called = False + await thread._update() + assert api.is_called + assert api.is_force_update + assert coordinator.async_set_updated_data.called + + @pytest.mark.asyncio + async def test_update_invalid_device(self, thread, api, coordinator): + api.called = False + api.fail_device = True + await thread._update() + assert api.is_called + assert not api.is_force_update + assert coordinator.async_set_updated_data.called + + @pytest.mark.asyncio + async def test_update_none_device(self, thread, api, coordinator): + api.called = False + api.none_device = True + await thread._update() + assert api.is_called + assert not api.is_force_update + assert coordinator.async_set_updated_data.called + + @pytest.mark.asyncio + async def test_update_connection_error(self, thread, api, coordinator): + api.called = False + api.fail = True + await thread._update() + assert api.is_called + assert not api.is_force_update + assert not coordinator.async_set_updated_data.called + + @pytest.mark.asyncio + async def test_update_backoff(self, thread, api, coordinator): + api.called = False + + async def sleep_called(backoff: int): + assert backoff > 0 + assert backoff / 60 < MAX_BACKOFF_TIME + thread._backoff = False + + with patch("asyncio.sleep", new=sleep_called): + thread._backoff = True + api.fail = True + assert thread._backoff + await thread._update() + assert api.is_called + assert not api.is_force_update + assert not thread._backoff + assert not coordinator.async_set_updated_data.called + + +class MockAPI: + def __init__(self): + self.called = False + self.force_update = False + self.fail = False + self.fail_device = False + self.none_device = False + + @property + def is_called(self): + return self.called + + @property + def is_force_update(self): + return self.force_update + + def device(self, uuid: str): + self.called = True + if self.none_device: + return None + if self.fail_device: + raise UnknownStateException("ERROR") + numeric = NumericValue(10, 10, 10, False, None) + return Shutter( + "uuid", "name", None, None, None, None, numeric, numeric, numeric + ) + + def updated_device(self, device): + self.called = True + + async def events(self, force_update: bool): + self.called = True + self.force_update = force_update + if self.fail: + raise NotImplementedError() + yield Shutter( + "uuid", + "other", + DeviceType.RAFFSTORE_90, + DeviceMode(DeviceType.RAFFSTORE_90), + list(Action), + )