Skip to content

Commit

Permalink
Use HassKey in stt (#126335)
Browse files Browse the repository at this point in the history
  • Loading branch information
epenet authored Sep 21, 2024
1 parent 91c1e75 commit 0299fa1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
30 changes: 10 additions & 20 deletions homeassistant/components/stt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .const import (
DATA_PROVIDERS,
DOMAIN,
DOMAIN_DATA,
AudioBitRates,
AudioChannels,
AudioCodecs,
Expand Down Expand Up @@ -72,11 +73,9 @@
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]

default_entity_id: str | None = None

for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id

Expand All @@ -91,9 +90,7 @@ def async_get_speech_to_text_entity(
hass: HomeAssistant, entity_id: str
) -> SpeechToTextEntity | None:
"""Return stt entity."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]

return component.get_entity(entity_id)
return hass.data[DOMAIN_DATA].get_entity(entity_id)


@callback
Expand All @@ -111,13 +108,11 @@ def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by stt engines."""
languages = set()

component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)

for engine in legacy_providers.values():
for engine in hass.data[DATA_PROVIDERS].values():
for language_tag in engine.supported_languages:
languages.add(language_tag)

Expand All @@ -128,7 +123,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
websocket_api.async_register_command(hass, websocket_list_engines)

component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
component = hass.data[DOMAIN_DATA] = EntityComponent[SpeechToTextEntity](
_LOGGER, DOMAIN, hass
)

Expand All @@ -150,14 +145,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)


class SpeechToTextEntity(RestoreEntity):
Expand Down Expand Up @@ -426,15 +419,12 @@ def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List speech-to-text engines and, optionally, if they support a given language."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]

country = msg.get("country")
language = msg.get("language")
providers = []
provider_info: dict[str, Any]

for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
provider_info = {
"engine_id": entity.entity_id,
"supported_languages": entity.supported_languages,
Expand All @@ -445,7 +435,7 @@ def websocket_list_engines(
)
providers.append(provider_info)

for engine_id, provider in legacy_providers.items():
for engine_id, provider in hass.data[DATA_PROVIDERS].items():
provider_info = {
"engine_id": engine_id,
"name": provider.name,
Expand Down
14 changes: 13 additions & 1 deletion homeassistant/components/stt/const.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
"""STT constante."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from homeassistant.util.hass_dict import HassKey

if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent

from . import SpeechToTextEntity
from .legacy import Provider

DOMAIN = "stt"
DATA_PROVIDERS = f"{DOMAIN}_providers"
DOMAIN_DATA: HassKey[EntityComponent[SpeechToTextEntity]] = HassKey(DOMAIN)
DATA_PROVIDERS: HassKey[dict[str, Provider]] = HassKey(f"{DOMAIN}_providers")


class AudioCodecs(str, Enum):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/stt/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
"""Return the domain of the default provider."""
return next(iter(hass.data[DATA_PROVIDERS]), None)
providers = hass.data[DATA_PROVIDERS]
return next(iter(providers), None)


@callback
def async_get_provider(
hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
"""Return provider."""
providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
providers = hass.data[DATA_PROVIDERS]
if domain:
return providers.get(domain)

Expand Down

0 comments on commit 0299fa1

Please sign in to comment.