Skip to content

Commit

Permalink
fix: set event loop for client
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlba91 committed Dec 12, 2023
1 parent b5e3fae commit b7e94a3
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 60 deletions.
13 changes: 5 additions & 8 deletions custom_components/hella_onyx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
CONF_ACCESS_TOKEN,
CONF_SCAN_INTERVAL,
CONF_FORCE_UPDATE,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
Expand All @@ -23,7 +22,9 @@
ONYX_API,
ONYX_COORDINATOR,
ONYX_TIMEZONE,
ONYX_THREAD,
)
from .event_thread import EventThread

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,19 +87,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
request_refresh_debouncer=Debouncer(hass, _LOGGER, cooldown=0, immediate=True),
)

def updated_device(device):
onyx_api.updated_device(device)
coordinator.async_set_updated_data(device)

hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, onyx_api.stop)
onyx_api.set_event_callback(updated_device)
onyx_api.start(force_update)
thread = EventThread(onyx_api, coordinator, force_update)

hass.data[DOMAIN][entry.entry_id] = {
ONYX_API: onyx_api,
ONYX_TIMEZONE: onyx_timezone,
ONYX_COORDINATOR: coordinator,
ONYX_THREAD: thread,
}
thread.start()

for platform in PLATFORMS:
hass.async_create_task(
Expand Down
30 changes: 4 additions & 26 deletions custom_components/hella_onyx/api_connector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
"""API connector for the ONYX integration."""
import logging

import asyncio
from typing import Any

from homeassistant.helpers.aiohttp_client import async_get_clientsession
from onyx_client.client import create
from onyx_client.data.device_command import DeviceCommand
from onyx_client.enum.action import Action

from custom_components.hella_onyx.const import MAX_BACKOFF_TIME

_LOGGER = logging.getLogger(__name__)


Expand All @@ -24,17 +19,14 @@ def __init__(self, hass, fingerprint, token):
self.token = token
self.devices = {}
self.groups = {}
self._loop = None
self.__client = None

def _client(self):
if self.__client is None:
self._loop = asyncio.new_event_loop()
self.__client = create(
fingerprint=self.fingerprint,
access_token=self.token,
client_session=async_get_clientsession(self.hass),
event_loop=self._loop,
)
return self.__client

Expand Down Expand Up @@ -84,24 +76,10 @@ async def send_device_command_properties(self, uuid: str, properties: dict):
if not success:
raise CommandException("ONYX_ACTION_ERROR", uuid)

def start(self, include_details):
"""Start the event loop."""
_LOGGER.info("Starting ONYX")
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
self._client().start(include_details, MAX_BACKOFF_TIME)

def set_event_callback(self, callback):
"""Set the event callback."""
self._client().set_event_callback(callback)

def stop(self, **kwargs: Any):
"""Stop the event loop."""
_LOGGER.info("Shutting down ONYX")
self._client().stop()
self._loop.stop()
self.__client = None
self._loop = None
async def events(self, force_update: bool = False):
"""Listen for events."""
async for device in self._client().events(force_update):
yield device


class CommandException(Exception):
Expand Down
1 change: 1 addition & 0 deletions custom_components/hella_onyx/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ONYX_API = "onyx_api"
ONYX_TIMEZONE = "onyx_timezone"
ONYX_COORDINATOR = "onyx_coordinator"
ONYX_THREAD = "onyx_thread"

CONF_FINGERPRINT = "fingerprint"

Expand Down
56 changes: 56 additions & 0 deletions custom_components/hella_onyx/event_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""ONYX API event thread."""
import asyncio
import logging
import threading
from random import uniform

from homeassistant.helpers.update_coordinator import DataUpdateCoordinator

from .api_connector import APIConnector
from .const import MAX_BACKOFF_TIME

_LOGGER = logging.getLogger(__name__)


class EventThread(threading.Thread):
"""The event thread for asynchronous updates."""

def __init__(
self,
api: APIConnector,
coordinator: DataUpdateCoordinator,
force_update: bool = False,
backoff=True,
):
threading.Thread.__init__(self, name="HellaOnyx")
self._api = api
self._coordinator = coordinator
self._force_update = force_update
self._backoff = backoff

async def _update(self):
"""Listen for updates."""
while True:
backoff = int(uniform(0, MAX_BACKOFF_TIME) * 60)
try:
async for device in self._api.events(self._force_update):
self._api.updated_device(device)
self._coordinator.async_set_updated_data(device)
except Exception as ex:
_LOGGER.error(
"connection reset: %s, restarting with backoff of %s seconds (%s)",
ex,
backoff,
self._backoff,
)
if self._backoff:
await asyncio.sleep(backoff)
else:
break

def run(self):
"""Start the thread."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(self._update())
loop.run_forever()
46 changes: 20 additions & 26 deletions tests/test_api_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
CommandException,
UnknownStateException,
)
from custom_components.hella_onyx.const import MAX_BACKOFF_TIME


class TestAPIConnector:
Expand Down Expand Up @@ -131,27 +130,21 @@ async def test__client(self, api):
assert client is not None
assert isinstance(client, OnyxClient)

@patch("asyncio.set_event_loop")
@pytest.mark.asyncio
async def test_start(self, mock_set_event_loop, api, client):
async def test_events(self, api, client):
with patch.object(api, "_client", new=client.make):
with patch.object(api, "_loop"):
api.start(True)
assert client.is_called
assert mock_set_event_loop.called

@pytest.mark.asyncio
async def test_stop(self, api, client):
with patch.object(api, "_client", new=client.make):
with patch.object(api, "_loop"):
api.stop()
assert client.is_called
async for device in api.events():
assert device is not None
assert client.is_called
assert not client.is_force_update

@pytest.mark.asyncio
async def test_set_event_callback(self, api, client):
async def test_events_force_update(self, api, client):
with patch.object(api, "_client", new=client.make):
api.set_event_callback(None)
async for device in api.events(True):
assert device is not None
assert client.is_called
assert client.is_force_update


class MockClient:
Expand Down Expand Up @@ -201,22 +194,23 @@ async def device(self, uuid: str):
list(Action),
)

async def events(self, force_update: bool):
self.called = True
self.force_update = force_update
yield Shutter(
"id",
"other",
DeviceType.RAFFSTORE_90,
DeviceMode(DeviceType.RAFFSTORE_90),
list(Action),
)

async def send_command(self, uuid: str, command: DeviceCommand):
self.called = True
return command.action == Action.STOP or (
command.properties is not None and "fail" not in command.properties
)

def start(self, include_details, backoff_time):
self.called = True
assert backoff_time == MAX_BACKOFF_TIME

def stop(self):
self.called = True

def set_event_callback(self, callback):
self.called = True

async def date_information(self):
self.called = True
return self.date
Expand Down
139 changes: 139 additions & 0 deletions tests/test_event_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Test for the EventThread."""

from unittest.mock import AsyncMock, patch

import pytest
from onyx_client.data.device_mode import DeviceMode
from onyx_client.data.numeric_value import NumericValue
from onyx_client.device.shutter import Shutter
from onyx_client.enum.action import Action
from onyx_client.enum.device_type import DeviceType

from custom_components.hella_onyx import (
EventThread,
)
from custom_components.hella_onyx.api_connector import (
UnknownStateException,
)
from custom_components.hella_onyx.const import MAX_BACKOFF_TIME


class TestEventThread:
@pytest.fixture
def api(self):
yield MockAPI()

@pytest.fixture
def coordinator(self):
yield AsyncMock()

@pytest.fixture
def thread(self, api, coordinator):
yield EventThread(api, coordinator, force_update=False, backoff=False)

@pytest.mark.asyncio
async def test_update(self, thread, api, coordinator):
api.called = False
await thread._update()
assert api.is_called
assert not api.is_force_update
assert coordinator.async_set_updated_data.called

@pytest.mark.asyncio
async def test_update_force_update(self, thread, api, coordinator):
thread._force_update = True
api.called = False
await thread._update()
assert api.is_called
assert api.is_force_update
assert coordinator.async_set_updated_data.called

@pytest.mark.asyncio
async def test_update_invalid_device(self, thread, api, coordinator):
api.called = False
api.fail_device = True
await thread._update()
assert api.is_called
assert not api.is_force_update
assert coordinator.async_set_updated_data.called

@pytest.mark.asyncio
async def test_update_none_device(self, thread, api, coordinator):
api.called = False
api.none_device = True
await thread._update()
assert api.is_called
assert not api.is_force_update
assert coordinator.async_set_updated_data.called

@pytest.mark.asyncio
async def test_update_connection_error(self, thread, api, coordinator):
api.called = False
api.fail = True
await thread._update()
assert api.is_called
assert not api.is_force_update
assert not coordinator.async_set_updated_data.called

@pytest.mark.asyncio
async def test_update_backoff(self, thread, api, coordinator):
api.called = False

async def sleep_called(backoff: int):
assert backoff > 0
assert backoff / 60 < MAX_BACKOFF_TIME
thread._backoff = False

with patch("asyncio.sleep", new=sleep_called):
thread._backoff = True
api.fail = True
assert thread._backoff
await thread._update()
assert api.is_called
assert not api.is_force_update
assert not thread._backoff
assert not coordinator.async_set_updated_data.called


class MockAPI:
def __init__(self):
self.called = False
self.force_update = False
self.fail = False
self.fail_device = False
self.none_device = False

@property
def is_called(self):
return self.called

@property
def is_force_update(self):
return self.force_update

def device(self, uuid: str):
self.called = True
if self.none_device:
return None
if self.fail_device:
raise UnknownStateException("ERROR")
numeric = NumericValue(10, 10, 10, False, None)
return Shutter(
"uuid", "name", None, None, None, None, numeric, numeric, numeric
)

def updated_device(self, device):
self.called = True

async def events(self, force_update: bool):
self.called = True
self.force_update = force_update
if self.fail:
raise NotImplementedError()
yield Shutter(
"uuid",
"other",
DeviceType.RAFFSTORE_90,
DeviceMode(DeviceType.RAFFSTORE_90),
list(Action),
)

0 comments on commit b7e94a3

Please sign in to comment.