Skip to content

Commit

Permalink
Merge pull request #21 from jhakulin/jhakulin/minor-fixes4
Browse files Browse the repository at this point in the history
Fix for breaking change, minor styling updates, improved session update
  • Loading branch information
jhakulin authored Dec 13, 2024
2 parents 8209bbb + 5e2da4d commit 2d2f5cc
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 283 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
https://github.com/jhakulin/realtime-ai/releases/download/v0.1.4/realtime_ai-0.1.4-py3-none-any.whl
https://github.com/jhakulin/realtime-ai/releases/download/v0.1.6/realtime_ai-0.1.6-py3-none-any.whl
pyaudio
numpy
websockets
Expand Down
2 changes: 1 addition & 1 deletion samples/async/sample_realtime_ai_with_keyword_and_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def on_speech_start(self):
logger.info("Local VAD: User speech started")
logger.info(f"on_speech_start: Current state: {self._state}")

if self._state == ConversationState.KEYWORD_DETECTED:
if self._state == ConversationState.KEYWORD_DETECTED or self._state == ConversationState.CONVERSATION_ACTIVE:
asyncio.run_coroutine_threadsafe(self._set_state(ConversationState.CONVERSATION_ACTIVE), self._event_loop)
asyncio.run_coroutine_threadsafe(self._cancel_silence_timer(), self._event_loop)

Expand Down
2 changes: 1 addition & 1 deletion samples/sample_realtime_ai_with_keyword_and_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def on_speech_start(self):
logger.info("Local VAD: User speech started")
logger.info(f"on_speech_start: Current state: {self._state}")

if self._state == ConversationState.KEYWORD_DETECTED:
if self._state == ConversationState.KEYWORD_DETECTED or self._state == ConversationState.CONVERSATION_ACTIVE:
self._set_state(ConversationState.CONVERSATION_ACTIVE)
self._cancel_silence_timer()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="realtime-ai",
version="0.1.5",
version="0.1.6",
description="Python SDK for real-time audio processing with OpenAI's Realtime REST API.",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
Expand Down
40 changes: 20 additions & 20 deletions src/realtime_ai/aio/audio_stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,52 @@ class AudioStreamManager:
Manages streaming audio data to the Realtime API via the Service Manager.
"""
def __init__(self, stream_options: AudioStreamOptions, service_manager: RealtimeAIServiceManager):
self.stream_options = stream_options
self.service_manager = service_manager
self.audio_queue = asyncio.Queue()
self.is_streaming = False
self.stream_task = None
self._stream_options = stream_options
self._service_manager = service_manager
self._audio_queue = asyncio.Queue()
self._is_streaming = False
self._stream_task = None

def _start_stream(self):
if not self.is_streaming:
self.is_streaming = True
self.stream_task = asyncio.create_task(self._stream_audio())
if not self._is_streaming:
self._is_streaming = True
self._stream_task = asyncio.create_task(self._stream_audio())
logger.info("Audio streaming started.")

async def stop_stream(self):
if self.is_streaming:
self.is_streaming = False
if self.stream_task:
self.stream_task.cancel()
if self._is_streaming:
self._is_streaming = False
if self._stream_task:
self._stream_task.cancel()
try:
await self.stream_task
await self._stream_task
except asyncio.CancelledError:
logger.info("Audio streaming task cancelled.")
logger.info("Audio streaming stopped.")

async def write_audio_buffer(self, audio_data: bytes):
if not self.is_streaming:
if not self._is_streaming:
self._start_stream()
logger.info("Enqueuing audio data for streaming.")
await self.audio_queue.put(audio_data)
await self._audio_queue.put(audio_data)
logger.info("Audio data enqueued for streaming.")

async def _stream_audio(self):
logger.info(f"Streaming audio task started, is_streaming: {self.is_streaming}")
while self.is_streaming:
logger.info(f"Streaming audio task started, is_streaming: {self._is_streaming}")
while self._is_streaming:
try:
audio_chunk = await self.audio_queue.get()
audio_chunk = await self._audio_queue.get()
processed_audio = self._process_audio(audio_chunk)
encoded_audio = base64.b64encode(processed_audio).decode()

# Send input_audio_buffer.append event
append_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "input_audio_buffer.append",
"audio": encoded_audio
}

await self.service_manager.send_event(append_event)
await self._service_manager.send_event(append_event)
logger.info("input_audio_buffer.append event sent.")

except asyncio.CancelledError:
Expand Down
99 changes: 44 additions & 55 deletions src/realtime_ai/aio/realtime_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,33 @@ class RealtimeAIClient:
"""
def __init__(self, options: RealtimeAIOptions, stream_options: AudioStreamOptions, event_handler: RealtimeAIEventHandler):
self._options = options
self.service_manager = RealtimeAIServiceManager(options)
self.audio_stream_manager = AudioStreamManager(stream_options, self.service_manager)
self.event_handler = event_handler
self.is_running = False
self._service_manager = RealtimeAIServiceManager(options)
self._audio_stream_manager = AudioStreamManager(stream_options, self._service_manager)
self._event_handler = event_handler
self._is_running = False
self._consume_task = None

async def start(self):
"""Starts the RealtimeAIClient."""
if not self.is_running:
self.is_running = True
if not self._is_running:
self._is_running = True
try:
await self.service_manager.connect() # Connect to the service initially
await self._service_manager.connect() # Connect to the service initially
logger.info("RealtimeAIClient: Client started.")

# Schedule the event consumption coroutine as a background task
self._consume_task = asyncio.create_task(self._consume_events())
except Exception as e:
logger.error(f"RealtimeAIClient: Error during client start: {e}")
self.is_running = False
self._is_running = False

async def stop(self):
"""Stops the RealtimeAIClient gracefully."""
if self.is_running:
self.is_running = False
if self._is_running:
self._is_running = False
try:
await self.audio_stream_manager.stop_stream()
await self.service_manager.disconnect()
await self._audio_stream_manager.stop_stream()
await self._service_manager.disconnect()
logger.info("RealtimeAIClient: Services stopped.")

if self._consume_task:
Expand All @@ -55,20 +55,20 @@ async def stop(self):
except asyncio.CancelledError:
logger.info("RealtimeAIClient: consume_events task cancelled.")

await self.service_manager.clear_event_queue()
await self._service_manager.clear_event_queue()
except Exception as e:
logger.error(f"RealtimeAIClient: Error during client stop: {e}")

async def send_audio(self, audio_data: bytes):
"""Sends audio data to the audio stream manager for processing."""
logger.info("RealtimeAIClient: Queuing audio data for streaming.")
await self.audio_stream_manager.write_audio_buffer(audio_data)
await self._audio_stream_manager.write_audio_buffer(audio_data)

async def send_text(self, text: str, role: str = "user", generate_response: bool = True):
"""Sends text input to the service manager.
"""
event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "conversation.item.create",
"item": {
"type": "message",
Expand All @@ -81,7 +81,7 @@ async def send_text(self, text: str, role: str = "user", generate_response: bool
]
}
}
await self.service_manager.send_event(event)
await self._service_manager.send_event(event)
logger.info("RealtimeAIClient: Sent text input to server.")

# Generate a response if required
Expand All @@ -90,25 +90,10 @@ async def send_text(self, text: str, role: str = "user", generate_response: bool

async def update_session(self, options: RealtimeAIOptions):
"""Updates the session configuration with the provided options."""
event = {
"event_id": self.service_manager._generate_event_id(),
"type": "session.update",
"session": {
"modalities": options.modalities,
"instructions": options.instructions,
"voice": options.voice,
"input_audio_format": options.input_audio_format,
"output_audio_format": options.output_audio_format,
"input_audio_transcription": {
"model": options.input_audio_transcription_model
},
"turn_detection": options.turn_detection,
"tools": options.tools,
"tool_choice": options.tool_choice,
"temperature": options.temperature
}
}
await self.service_manager.send_event(event)
if self._is_running:
await self._service_manager.update_session(options)

self._service_manager.options = options
self._options = options
logger.info("RealtimeAIClient: Sent session update to server.")

Expand All @@ -117,49 +102,49 @@ async def generate_response(self, commit_audio_buffer: bool = True):
logger.info("RealtimeAIClient: Generating response.")
if commit_audio_buffer:
commit_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "input_audio_buffer.commit"
}
await self.service_manager.send_event(commit_event)
await self._service_manager.send_event(commit_event)

response_create_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "response.create",
"response": {"modalities": self.options.modalities}
}
await self.service_manager.send_event(response_create_event)
await self._service_manager.send_event(response_create_event)

async def cancel_response(self):
"""Sends a response.cancel event to interrupt the model when playback is interrupted by user."""
cancel_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "response.cancel"
}
await self.service_manager.send_event(cancel_event)
await self._service_manager.send_event(cancel_event)
logger.info("Client: Sent response.cancel event to server.")

# Clear the event queue in the service manager
await self.service_manager.clear_event_queue()
await self._service_manager.clear_event_queue()
logger.info("RealtimeAIClient: Event queue cleared after cancellation.")

async def truncate_response(self, item_id: str, content_index: int, audio_end_ms: int):
"""Sends a conversation.item.truncate event to truncate the response."""
truncate_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "conversation.item.truncate",
"item_id": item_id,
"content_index": content_index,
"audio_end_ms": audio_end_ms
}
await self.service_manager.send_event(truncate_event)
await self._service_manager.send_event(truncate_event)
logger.info("Client: Sent conversation.item.truncate event to server.")

async def clear_input_audio_buffer(self):
clear_audio_buffers_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "input_audio_buffer.clear"
}
await self.service_manager.send_event(clear_audio_buffers_event)
await self._service_manager.send_event(clear_audio_buffers_event)
logger.info("Client: Sent input_audio_buffer.clear event to server.")

async def generate_response_from_function_call(self, call_id: str, function_output: str):
Expand All @@ -173,7 +158,7 @@ async def generate_response_from_function_call(self, call_id: str, function_outp

# Create the function call output event
item_create_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "conversation.item.create",
"item": {
"id": str(uuid.uuid4()).replace('-', ''),
Expand All @@ -184,24 +169,24 @@ async def generate_response_from_function_call(self, call_id: str, function_outp
}

# Send the function call output event
await self.service_manager.send_event(item_create_event)
await self._service_manager.send_event(item_create_event)
logger.info("Function call output event sent.")

# Create and send the response.create event
response_event = {
"event_id": self.service_manager._generate_event_id(),
"event_id": self._service_manager._generate_event_id(),
"type": "response.create",
"response": {"modalities": self.options.modalities}
}
await self.service_manager.send_event(response_event)
await self._service_manager.send_event(response_event)

async def _consume_events(self):
"""Consume events from the service manager asynchronously."""
logger.info("RealtimeAIClient: Started consuming events.")
try:
while self.is_running:
while self._is_running:
try:
event = await self.service_manager.get_next_event()
event = await self._service_manager.get_next_event()
if event:
# Schedule the event handler as an independent task
asyncio.create_task(self._handle_event(event))
Expand All @@ -219,16 +204,20 @@ async def _handle_event(self, event: EventBase):
"""Handles the received event based on its type using the event handler."""
event_type = event.type
method_name = f'on_{event_type.replace(".", "_")}'
handler = getattr(self.event_handler, method_name, None)
handler = getattr(self._event_handler, method_name, None)

if callable(handler):
try:
await handler(event)
except Exception as e:
logger.error(f"Error in handler {method_name} for event {event_type}: {e}")
else:
await self.event_handler.on_unhandled_event(event_type, vars(event))
await self._event_handler.on_unhandled_event(event_type, vars(event))

@property
def options(self):
return self._options
return self._options

@property
def is_running(self):
return self._is_running
Loading

0 comments on commit 2d2f5cc

Please sign in to comment.