Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ my_text.set_text("Some awesome text")

## Availability Management

> [!WARNING]
> This features is not supported if using an [existing MQTT client](#using-an-existing-mqtt-client).

If `manual_availability` is set to `True`:

* `set_availability` has to be called to indicate if an entity is _available_ or _unavailable_
Expand Down Expand Up @@ -644,8 +647,12 @@ from paho.mqtt.client import Client

# Creating the MQTT client
client = Client()
# Doing other stuff with the client, like connecting to the broker
# Doing other stuff with the client
# ...
# Make sure the client is connected to the broker
client.connect(host="localhost")
# Also make sure the network communication is started
client.loop_start()

# Providing the client to the Settings object
# In this case, no other MQTT settings are needed
Expand Down
37 changes: 13 additions & 24 deletions ha_mqtt_discoverable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,16 @@ def __init__(self, settings: Settings[EntityType], on_connect: Callable | None =
self.availability_topic = f"{self._settings.mqtt.state_prefix}/{self._entity_topic}/availability"
logger.debug(f"availability_topic: {self.availability_topic}")

# Create the MQTT client, registering the user `on_connect` callback
self._setup_client(on_connect)
# If there is a callback function defined, the user must manually connect
# to the MQTT client
if not (on_connect or self._settings.mqtt.client is not None):
self._connect_client()
# If the user has passed in a MQTT client, use it
if self._settings.mqtt.client:
self.mqtt_client = self._settings.mqtt.client
else:
# Create the MQTT client, registering the user `on_connect` callback
self._setup_client(on_connect)
# If there is a callback function defined, the user must manually connect
# to the MQTT client
if not on_connect:
self._connect_client()

def __str__(self) -> str:
"""
Expand All @@ -227,14 +231,6 @@ def __str__(self) -> str:
def _setup_client(self, on_connect: Callable | None = None) -> None:
"""Create an MQTT client and setup some basic properties on it"""

# If the user has passed in an MQTT client, use it
if self._settings.mqtt.client:
self.mqtt_client = self._settings.mqtt.client
if on_connect:
logger.debug("Registering custom callback function")
self.mqtt_client.on_connect = on_connect
return

mqtt_settings = self._settings.mqtt
logger.debug(f"Creating mqtt client ({mqtt_settings.client_name}) for {mqtt_settings.host}:{mqtt_settings.port}")
self.mqtt_client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2, client_id=mqtt_settings.client_name)
Expand Down Expand Up @@ -436,16 +432,9 @@ def on_client_connected(client: mqtt.Client, *_):

if self._settings.mqtt.client:
# externally created MQTT client is used
if self.mqtt_client.is_connected():
# MQTT client is already connected, therefor explicitly
# subscribe to the command topic
on_client_connected(self.mqtt_client)
else:
# externally created MQTT client is not connected yet
# the 'on_connect' callback named 'on_client_connected'
# will subscribe to the command topic
# when the externally created MQTT client connects
pass
# which needs to be connected already
# therefor explicitly subscribe to the command topic
on_client_connected(self.mqtt_client)
else:
# Manually connect the MQTT client
self._connect_client()
Expand Down
85 changes: 57 additions & 28 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,44 @@
# limitations under the License.
#
import logging
import random
import string
import time
from collections.abc import Callable
from threading import Event
from typing import Any, TypeVar

import paho.mqtt.client as mqtt
import pytest
from paho.mqtt import publish
from paho.mqtt.client import MQTTMessage
from paho.mqtt.enums import CallbackAPIVersion

from ha_mqtt_discoverable import EntityInfo, Settings, Subscriber

T = TypeVar("T") # Used in the callback function


@pytest.fixture
def subscriber() -> Subscriber[EntityInfo]:
mqtt_settings = Settings.MQTT(host="localhost")
sensor_info = EntityInfo(name="test", component="button")
settings = Settings(mqtt=mqtt_settings, entity=sensor_info)
# Define an empty `command_callback`
return Subscriber(settings, lambda _, __, ___: None)
def make_subscriber():
def _make_subscriber(
callback: Callable[[mqtt.Client, T, mqtt.MQTTMessage], Any] = lambda _, __, ___: None, mqtt_client: mqtt.Client = None
):
mqtt_settings = Settings.MQTT(client=mqtt_client) if mqtt_client else Settings.MQTT(host="localhost")
sensor_info = EntityInfo(name="".join(random.choices(string.ascii_lowercase + string.digits, k=10)), component="button")
settings = Settings(mqtt=mqtt_settings, entity=sensor_info)
return Subscriber(settings, callback)

return _make_subscriber


def test_required_config():
mqtt_settings = Settings.MQTT(host="localhost")
sensor_info = EntityInfo(name="test", component="button")
settings = Settings(mqtt=mqtt_settings, entity=sensor_info)
# Define empty callback
sensor = Subscriber(settings, lambda _, __, ___: None)
assert sensor is not None
@pytest.fixture
def subscriber(make_subscriber) -> Subscriber[EntityInfo]:
return make_subscriber()


def test_required_config(subscriber: Subscriber):
assert subscriber is not None


def test_generate_config(subscriber: Subscriber):
Expand All @@ -50,27 +62,44 @@ def test_generate_config(subscriber: Subscriber):
assert config["command_topic"] == subscriber._command_topic


def test_command_callback():
mqtt_settings = Settings.MQTT(host="localhost")
sensor_info = EntityInfo(name="test", component="switch")
settings = Settings(mqtt=mqtt_settings, entity=sensor_info)

# Flag that waits for the command to be received
message_received = Event()

def create_callback(event: Event, expected_payload: str) -> Callable:
# Callback to receive the command message
def custom_callback(_, user_data, message: MQTTMessage):
payload = message.payload.decode()
logging.info(f"Received {payload}")
assert payload == "on"
assert payload == expected_payload
assert user_data is None
message_received.set()
event.set()

return custom_callback


switch = Subscriber(settings, custom_callback)
# Wait some seconds for the subscription to take effect
time.sleep(2)
def create_external_mqtt_client() -> mqtt.Client:
mqtt_client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2)
mqtt_client.connect(host="localhost")
mqtt_client.loop_start()
return mqtt_client


@pytest.mark.parametrize("mqtt_client", [None, create_external_mqtt_client()])
def test_command_callbacks(make_subscriber, mqtt_client):
# Flag that waits for the command to be received
event1 = Event()
event2 = Event()
expected_payload1 = "on"
expected_payload2 = "off"
custom_callback1 = create_callback(event1, expected_payload1)
custom_callback2 = create_callback(event2, expected_payload2)
subscriber1 = make_subscriber(custom_callback1, mqtt_client)
subscriber2 = make_subscriber(custom_callback2, mqtt_client)
# Wait for the subscription to take effect
time.sleep(1)

assert subscriber1._command_topic != subscriber2._command_topic

# Send a command to the command topic
publish.single(switch._command_topic, "on", hostname="localhost")
publish.single(subscriber1._command_topic, expected_payload1)
publish.single(subscriber2._command_topic, expected_payload2)

assert message_received.wait(2)
assert event1.wait(1)
assert event2.wait(1)
Loading