From 815eda9afdd791cd414a8d602129c7de95524f14 Mon Sep 17 00:00:00 2001 From: Steffen Pankratz Date: Thu, 4 Dec 2025 15:13:03 +0100 Subject: [PATCH] Fix command topic subscriptions for externally created MQTT clients (fixes #468) Signed-off-by: Steffen Pankratz --- README.md | 9 +++- ha_mqtt_discoverable/__init__.py | 37 +++++--------- tests/test_subscriber.py | 85 +++++++++++++++++++++----------- 3 files changed, 78 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index f3f63ba..e81e5e2 100644 --- a/README.md +++ b/README.md @@ -606,6 +606,9 @@ my_text.set_text("Some awesome text") ## Availability Management +> [!WARNING] +> This features is not supported if using an [existing MQTT client](#using-an-existing-mqtt-client). + If `manual_availability` is set to `True`: * `set_availability` has to be called to indicate if an entity is _available_ or _unavailable_ @@ -644,8 +647,12 @@ from paho.mqtt.client import Client # Creating the MQTT client client = Client() -# Doing other stuff with the client, like connecting to the broker +# Doing other stuff with the client # ... +# Make sure the client is connected to the broker +client.connect(host="localhost") +# Also make sure the network communication is started +client.loop_start() # Providing the client to the Settings object # In this case, no other MQTT settings are needed diff --git a/ha_mqtt_discoverable/__init__.py b/ha_mqtt_discoverable/__init__.py index 6db9eba..5ecd395 100644 --- a/ha_mqtt_discoverable/__init__.py +++ b/ha_mqtt_discoverable/__init__.py @@ -204,12 +204,16 @@ def __init__(self, settings: Settings[EntityType], on_connect: Callable | None = self.availability_topic = f"{self._settings.mqtt.state_prefix}/{self._entity_topic}/availability" logger.debug(f"availability_topic: {self.availability_topic}") - # Create the MQTT client, registering the user `on_connect` callback - self._setup_client(on_connect) - # If there is a callback function defined, the user must manually connect - # to the MQTT client - if not (on_connect or self._settings.mqtt.client is not None): - self._connect_client() + # If the user has passed in a MQTT client, use it + if self._settings.mqtt.client: + self.mqtt_client = self._settings.mqtt.client + else: + # Create the MQTT client, registering the user `on_connect` callback + self._setup_client(on_connect) + # If there is a callback function defined, the user must manually connect + # to the MQTT client + if not on_connect: + self._connect_client() def __str__(self) -> str: """ @@ -227,14 +231,6 @@ def __str__(self) -> str: def _setup_client(self, on_connect: Callable | None = None) -> None: """Create an MQTT client and setup some basic properties on it""" - # If the user has passed in an MQTT client, use it - if self._settings.mqtt.client: - self.mqtt_client = self._settings.mqtt.client - if on_connect: - logger.debug("Registering custom callback function") - self.mqtt_client.on_connect = on_connect - return - mqtt_settings = self._settings.mqtt logger.debug(f"Creating mqtt client ({mqtt_settings.client_name}) for {mqtt_settings.host}:{mqtt_settings.port}") self.mqtt_client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2, client_id=mqtt_settings.client_name) @@ -436,16 +432,9 @@ def on_client_connected(client: mqtt.Client, *_): if self._settings.mqtt.client: # externally created MQTT client is used - if self.mqtt_client.is_connected(): - # MQTT client is already connected, therefor explicitly - # subscribe to the command topic - on_client_connected(self.mqtt_client) - else: - # externally created MQTT client is not connected yet - # the 'on_connect' callback named 'on_client_connected' - # will subscribe to the command topic - # when the externally created MQTT client connects - pass + # which needs to be connected already + # therefor explicitly subscribe to the command topic + on_client_connected(self.mqtt_client) else: # Manually connect the MQTT client self._connect_client() diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py index c396ff8..9cc3268 100644 --- a/tests/test_subscriber.py +++ b/tests/test_subscriber.py @@ -14,32 +14,44 @@ # limitations under the License. # import logging +import random +import string import time +from collections.abc import Callable from threading import Event +from typing import Any, TypeVar +import paho.mqtt.client as mqtt import pytest from paho.mqtt import publish from paho.mqtt.client import MQTTMessage +from paho.mqtt.enums import CallbackAPIVersion from ha_mqtt_discoverable import EntityInfo, Settings, Subscriber +T = TypeVar("T") # Used in the callback function + @pytest.fixture -def subscriber() -> Subscriber[EntityInfo]: - mqtt_settings = Settings.MQTT(host="localhost") - sensor_info = EntityInfo(name="test", component="button") - settings = Settings(mqtt=mqtt_settings, entity=sensor_info) - # Define an empty `command_callback` - return Subscriber(settings, lambda _, __, ___: None) +def make_subscriber(): + def _make_subscriber( + callback: Callable[[mqtt.Client, T, mqtt.MQTTMessage], Any] = lambda _, __, ___: None, mqtt_client: mqtt.Client = None + ): + mqtt_settings = Settings.MQTT(client=mqtt_client) if mqtt_client else Settings.MQTT(host="localhost") + sensor_info = EntityInfo(name="".join(random.choices(string.ascii_lowercase + string.digits, k=10)), component="button") + settings = Settings(mqtt=mqtt_settings, entity=sensor_info) + return Subscriber(settings, callback) + + return _make_subscriber -def test_required_config(): - mqtt_settings = Settings.MQTT(host="localhost") - sensor_info = EntityInfo(name="test", component="button") - settings = Settings(mqtt=mqtt_settings, entity=sensor_info) - # Define empty callback - sensor = Subscriber(settings, lambda _, __, ___: None) - assert sensor is not None +@pytest.fixture +def subscriber(make_subscriber) -> Subscriber[EntityInfo]: + return make_subscriber() + + +def test_required_config(subscriber: Subscriber): + assert subscriber is not None def test_generate_config(subscriber: Subscriber): @@ -50,27 +62,44 @@ def test_generate_config(subscriber: Subscriber): assert config["command_topic"] == subscriber._command_topic -def test_command_callback(): - mqtt_settings = Settings.MQTT(host="localhost") - sensor_info = EntityInfo(name="test", component="switch") - settings = Settings(mqtt=mqtt_settings, entity=sensor_info) - - # Flag that waits for the command to be received - message_received = Event() - +def create_callback(event: Event, expected_payload: str) -> Callable: # Callback to receive the command message def custom_callback(_, user_data, message: MQTTMessage): payload = message.payload.decode() logging.info(f"Received {payload}") - assert payload == "on" + assert payload == expected_payload assert user_data is None - message_received.set() + event.set() + + return custom_callback + - switch = Subscriber(settings, custom_callback) - # Wait some seconds for the subscription to take effect - time.sleep(2) +def create_external_mqtt_client() -> mqtt.Client: + mqtt_client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2) + mqtt_client.connect(host="localhost") + mqtt_client.loop_start() + return mqtt_client + + +@pytest.mark.parametrize("mqtt_client", [None, create_external_mqtt_client()]) +def test_command_callbacks(make_subscriber, mqtt_client): + # Flag that waits for the command to be received + event1 = Event() + event2 = Event() + expected_payload1 = "on" + expected_payload2 = "off" + custom_callback1 = create_callback(event1, expected_payload1) + custom_callback2 = create_callback(event2, expected_payload2) + subscriber1 = make_subscriber(custom_callback1, mqtt_client) + subscriber2 = make_subscriber(custom_callback2, mqtt_client) + # Wait for the subscription to take effect + time.sleep(1) + + assert subscriber1._command_topic != subscriber2._command_topic # Send a command to the command topic - publish.single(switch._command_topic, "on", hostname="localhost") + publish.single(subscriber1._command_topic, expected_payload1) + publish.single(subscriber2._command_topic, expected_payload2) - assert message_received.wait(2) + assert event1.wait(1) + assert event2.wait(1)