Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality for websocket listening #8

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions rhasspyclient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@
class RhasspyClient:
"""Client object for remote Rhasspy server."""

def __init__(self, api_url: str, session: aiohttp.ClientSession):
self.api_url = api_url
if not self.api_url.endswith("/"):
self.api_url += "/"
def __init__(
self, host: str, port: str, session: aiohttp.ClientSession, secure: bool = False
):
if secure:
self.api_url = "https://{}:{}/api/".format(host, port)
self.events_url = "wss://{}:{}/api/events/".format(host, port)
else:
self.api_url = "http://{}:{}/api/".format(host, port)
self.events_url = "ws://{}:{}/api/events/".format(host, port)

# Construct URLs for end-points
self.sentences_url = urljoin(self.api_url, "sentences")
Expand All @@ -44,6 +49,10 @@ def __init__(self, api_url: str, session: aiohttp.ClientSession):
self.lookup_url = urljoin(self.api_url, "lookup")
self.version_url = urljoin(self.api_url, "version")

self.intent_listen_url = urljoin(self.events_url, "intent")
self.wake_listen_url = urljoin(self.events_url, "wake")
self.speech_listen_url = urljoin(self.events_url, "text")

self.session = session
assert self.session is not None, "ClientSession is required"

Expand Down Expand Up @@ -296,3 +305,40 @@ async def stream_to_text(self, raw_stream: aiohttp.StreamReader) -> str:
) as response:
response.raise_for_status()
return await response.text()

# -------------------------------------------------------------------------
async def listen_for_intent(self, handler, **handlerargs) -> None:
"""Given a handler function at startup handles the intents as they arrive"""
async with self.session.ws_connect(self.intent_listen_url) as ws:
async for msg in ws:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break

if handlerargs:
handler(msg, handlerargs)
else:
handler(msg)

async def listen_for_speech(self, handler, **handlerargs) -> None:
"""Given a handler function at startup handles the transcribed text as it arrives"""
async with self.session.ws_connect(self.speech_listen_url) as ws:
async for msg in ws:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break

if handlerargs:
handler(msg, handlerargs)
else:
handler(msg)

async def listen_for_wake(self, handler, **handlerargs) -> None:
"""Given a handler function at startup, handles when a wakeword is heard"""
async with self.session.ws_connect(self.wake_listen_url) as ws:
async for msg in ws:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break

if handlerargs:
handler(msg, handlerargs)
else:
handler(msg)
11 changes: 7 additions & 4 deletions rhasspyclient/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ async def main():
"--debug", action="store_true", help="Print DEBUG messages to console"
)
parser.add_argument(
"--api-url",
default="http://localhost:12101/api",
help="URL of Rhasspy server (with /api)",
"--api-host",
default="localhost",
help="Host where the Rhasspy API can be found",
)
parser.add_argument(
"--api-port", default="12101", help="Port where Rhasspy is exposed"
)

sub_parsers = parser.add_subparsers()
Expand Down Expand Up @@ -105,7 +108,7 @@ async def main():

# Begin client session
async with aiohttp.ClientSession() as session:
client = RhasspyClient(args.api_url, session)
client = RhasspyClient(args.api_host, args.api_port, session)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The secure parameter you added to RhasspyClient.__init__ isn't used in the main function. Shouldn't you add an argument to the argument parser to use https/wss? I also think tls (and the corresponding --tls argument) is more descriptive than secure.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd agree with that. It'll be a little while before I can get back to this but I can fix when I have time


# Call sub-commmand
await args.func(args, client)
Expand Down