diff --git a/requirements.txt b/requirements.txt index 48b4245..f1c5562 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -https://github.com/jhakulin/realtime-ai/releases/download/v0.1.4/realtime_ai-0.1.4-py3-none-any.whl +https://github.com/jhakulin/realtime-ai/releases/download/v0.1.6/realtime_ai-0.1.6-py3-none-any.whl pyaudio numpy websockets diff --git a/samples/async/sample_realtime_ai_with_keyword_and_vad.py b/samples/async/sample_realtime_ai_with_keyword_and_vad.py index c174e50..1bbdfcc 100644 --- a/samples/async/sample_realtime_ai_with_keyword_and_vad.py +++ b/samples/async/sample_realtime_ai_with_keyword_and_vad.py @@ -76,7 +76,7 @@ def on_speech_start(self): logger.info("Local VAD: User speech started") logger.info(f"on_speech_start: Current state: {self._state}") - if self._state == ConversationState.KEYWORD_DETECTED: + if self._state == ConversationState.KEYWORD_DETECTED or self._state == ConversationState.CONVERSATION_ACTIVE: asyncio.run_coroutine_threadsafe(self._set_state(ConversationState.CONVERSATION_ACTIVE), self._event_loop) asyncio.run_coroutine_threadsafe(self._cancel_silence_timer(), self._event_loop) diff --git a/samples/sample_realtime_ai_with_keyword_and_vad.py b/samples/sample_realtime_ai_with_keyword_and_vad.py index 993fea9..b1d5ae7 100644 --- a/samples/sample_realtime_ai_with_keyword_and_vad.py +++ b/samples/sample_realtime_ai_with_keyword_and_vad.py @@ -70,7 +70,7 @@ def on_speech_start(self): logger.info("Local VAD: User speech started") logger.info(f"on_speech_start: Current state: {self._state}") - if self._state == ConversationState.KEYWORD_DETECTED: + if self._state == ConversationState.KEYWORD_DETECTED or self._state == ConversationState.CONVERSATION_ACTIVE: self._set_state(ConversationState.CONVERSATION_ACTIVE) self._cancel_silence_timer() diff --git a/setup.py b/setup.py index da50e79..99ef2da 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="realtime-ai", - version="0.1.5", + version="0.1.6", description="Python SDK for real-time audio processing with OpenAI's Realtime REST API.", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/src/realtime_ai/aio/audio_stream_manager.py b/src/realtime_ai/aio/audio_stream_manager.py index 2039213..239fc4b 100644 --- a/src/realtime_ai/aio/audio_stream_manager.py +++ b/src/realtime_ai/aio/audio_stream_manager.py @@ -12,52 +12,52 @@ class AudioStreamManager: Manages streaming audio data to the Realtime API via the Service Manager. """ def __init__(self, stream_options: AudioStreamOptions, service_manager: RealtimeAIServiceManager): - self.stream_options = stream_options - self.service_manager = service_manager - self.audio_queue = asyncio.Queue() - self.is_streaming = False - self.stream_task = None + self._stream_options = stream_options + self._service_manager = service_manager + self._audio_queue = asyncio.Queue() + self._is_streaming = False + self._stream_task = None def _start_stream(self): - if not self.is_streaming: - self.is_streaming = True - self.stream_task = asyncio.create_task(self._stream_audio()) + if not self._is_streaming: + self._is_streaming = True + self._stream_task = asyncio.create_task(self._stream_audio()) logger.info("Audio streaming started.") async def stop_stream(self): - if self.is_streaming: - self.is_streaming = False - if self.stream_task: - self.stream_task.cancel() + if self._is_streaming: + self._is_streaming = False + if self._stream_task: + self._stream_task.cancel() try: - await self.stream_task + await self._stream_task except asyncio.CancelledError: logger.info("Audio streaming task cancelled.") logger.info("Audio streaming stopped.") async def write_audio_buffer(self, audio_data: bytes): - if not self.is_streaming: + if not self._is_streaming: self._start_stream() logger.info("Enqueuing audio data for streaming.") - await self.audio_queue.put(audio_data) + await self._audio_queue.put(audio_data) logger.info("Audio data enqueued for streaming.") async def _stream_audio(self): - logger.info(f"Streaming audio task started, is_streaming: {self.is_streaming}") - while self.is_streaming: + logger.info(f"Streaming audio task started, is_streaming: {self._is_streaming}") + while self._is_streaming: try: - audio_chunk = await self.audio_queue.get() + audio_chunk = await self._audio_queue.get() processed_audio = self._process_audio(audio_chunk) encoded_audio = base64.b64encode(processed_audio).decode() # Send input_audio_buffer.append event append_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.append", "audio": encoded_audio } - await self.service_manager.send_event(append_event) + await self._service_manager.send_event(append_event) logger.info("input_audio_buffer.append event sent.") except asyncio.CancelledError: diff --git a/src/realtime_ai/aio/realtime_ai_client.py b/src/realtime_ai/aio/realtime_ai_client.py index fbedb1a..59024f0 100644 --- a/src/realtime_ai/aio/realtime_ai_client.py +++ b/src/realtime_ai/aio/realtime_ai_client.py @@ -18,33 +18,33 @@ class RealtimeAIClient: """ def __init__(self, options: RealtimeAIOptions, stream_options: AudioStreamOptions, event_handler: RealtimeAIEventHandler): self._options = options - self.service_manager = RealtimeAIServiceManager(options) - self.audio_stream_manager = AudioStreamManager(stream_options, self.service_manager) - self.event_handler = event_handler - self.is_running = False + self._service_manager = RealtimeAIServiceManager(options) + self._audio_stream_manager = AudioStreamManager(stream_options, self._service_manager) + self._event_handler = event_handler + self._is_running = False self._consume_task = None async def start(self): """Starts the RealtimeAIClient.""" - if not self.is_running: - self.is_running = True + if not self._is_running: + self._is_running = True try: - await self.service_manager.connect() # Connect to the service initially + await self._service_manager.connect() # Connect to the service initially logger.info("RealtimeAIClient: Client started.") # Schedule the event consumption coroutine as a background task self._consume_task = asyncio.create_task(self._consume_events()) except Exception as e: logger.error(f"RealtimeAIClient: Error during client start: {e}") - self.is_running = False + self._is_running = False async def stop(self): """Stops the RealtimeAIClient gracefully.""" - if self.is_running: - self.is_running = False + if self._is_running: + self._is_running = False try: - await self.audio_stream_manager.stop_stream() - await self.service_manager.disconnect() + await self._audio_stream_manager.stop_stream() + await self._service_manager.disconnect() logger.info("RealtimeAIClient: Services stopped.") if self._consume_task: @@ -55,20 +55,20 @@ async def stop(self): except asyncio.CancelledError: logger.info("RealtimeAIClient: consume_events task cancelled.") - await self.service_manager.clear_event_queue() + await self._service_manager.clear_event_queue() except Exception as e: logger.error(f"RealtimeAIClient: Error during client stop: {e}") async def send_audio(self, audio_data: bytes): """Sends audio data to the audio stream manager for processing.""" logger.info("RealtimeAIClient: Queuing audio data for streaming.") - await self.audio_stream_manager.write_audio_buffer(audio_data) + await self._audio_stream_manager.write_audio_buffer(audio_data) async def send_text(self, text: str, role: str = "user", generate_response: bool = True): """Sends text input to the service manager. """ event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.create", "item": { "type": "message", @@ -81,7 +81,7 @@ async def send_text(self, text: str, role: str = "user", generate_response: bool ] } } - await self.service_manager.send_event(event) + await self._service_manager.send_event(event) logger.info("RealtimeAIClient: Sent text input to server.") # Generate a response if required @@ -90,25 +90,10 @@ async def send_text(self, text: str, role: str = "user", generate_response: bool async def update_session(self, options: RealtimeAIOptions): """Updates the session configuration with the provided options.""" - event = { - "event_id": self.service_manager._generate_event_id(), - "type": "session.update", - "session": { - "modalities": options.modalities, - "instructions": options.instructions, - "voice": options.voice, - "input_audio_format": options.input_audio_format, - "output_audio_format": options.output_audio_format, - "input_audio_transcription": { - "model": options.input_audio_transcription_model - }, - "turn_detection": options.turn_detection, - "tools": options.tools, - "tool_choice": options.tool_choice, - "temperature": options.temperature - } - } - await self.service_manager.send_event(event) + if self._is_running: + await self._service_manager.update_session(options) + + self._service_manager.options = options self._options = options logger.info("RealtimeAIClient: Sent session update to server.") @@ -117,49 +102,49 @@ async def generate_response(self, commit_audio_buffer: bool = True): logger.info("RealtimeAIClient: Generating response.") if commit_audio_buffer: commit_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.commit" } - await self.service_manager.send_event(commit_event) + await self._service_manager.send_event(commit_event) response_create_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.create", "response": {"modalities": self.options.modalities} } - await self.service_manager.send_event(response_create_event) + await self._service_manager.send_event(response_create_event) async def cancel_response(self): """Sends a response.cancel event to interrupt the model when playback is interrupted by user.""" cancel_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.cancel" } - await self.service_manager.send_event(cancel_event) + await self._service_manager.send_event(cancel_event) logger.info("Client: Sent response.cancel event to server.") # Clear the event queue in the service manager - await self.service_manager.clear_event_queue() + await self._service_manager.clear_event_queue() logger.info("RealtimeAIClient: Event queue cleared after cancellation.") async def truncate_response(self, item_id: str, content_index: int, audio_end_ms: int): """Sends a conversation.item.truncate event to truncate the response.""" truncate_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.truncate", "item_id": item_id, "content_index": content_index, "audio_end_ms": audio_end_ms } - await self.service_manager.send_event(truncate_event) + await self._service_manager.send_event(truncate_event) logger.info("Client: Sent conversation.item.truncate event to server.") async def clear_input_audio_buffer(self): clear_audio_buffers_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.clear" } - await self.service_manager.send_event(clear_audio_buffers_event) + await self._service_manager.send_event(clear_audio_buffers_event) logger.info("Client: Sent input_audio_buffer.clear event to server.") async def generate_response_from_function_call(self, call_id: str, function_output: str): @@ -173,7 +158,7 @@ async def generate_response_from_function_call(self, call_id: str, function_outp # Create the function call output event item_create_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.create", "item": { "id": str(uuid.uuid4()).replace('-', ''), @@ -184,24 +169,24 @@ async def generate_response_from_function_call(self, call_id: str, function_outp } # Send the function call output event - await self.service_manager.send_event(item_create_event) + await self._service_manager.send_event(item_create_event) logger.info("Function call output event sent.") # Create and send the response.create event response_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.create", "response": {"modalities": self.options.modalities} } - await self.service_manager.send_event(response_event) + await self._service_manager.send_event(response_event) async def _consume_events(self): """Consume events from the service manager asynchronously.""" logger.info("RealtimeAIClient: Started consuming events.") try: - while self.is_running: + while self._is_running: try: - event = await self.service_manager.get_next_event() + event = await self._service_manager.get_next_event() if event: # Schedule the event handler as an independent task asyncio.create_task(self._handle_event(event)) @@ -219,7 +204,7 @@ async def _handle_event(self, event: EventBase): """Handles the received event based on its type using the event handler.""" event_type = event.type method_name = f'on_{event_type.replace(".", "_")}' - handler = getattr(self.event_handler, method_name, None) + handler = getattr(self._event_handler, method_name, None) if callable(handler): try: @@ -227,8 +212,12 @@ async def _handle_event(self, event: EventBase): except Exception as e: logger.error(f"Error in handler {method_name} for event {event_type}: {e}") else: - await self.event_handler.on_unhandled_event(event_type, vars(event)) + await self._event_handler.on_unhandled_event(event_type, vars(event)) @property def options(self): - return self._options + return self._options + + @property + def is_running(self): + return self._is_running \ No newline at end of file diff --git a/src/realtime_ai/aio/realtime_ai_service_manager.py b/src/realtime_ai/aio/realtime_ai_service_manager.py index 065175f..560d842 100644 --- a/src/realtime_ai/aio/realtime_ai_service_manager.py +++ b/src/realtime_ai/aio/realtime_ai_service_manager.py @@ -43,34 +43,14 @@ class RealtimeAIServiceManager: """ def __init__(self, options: RealtimeAIOptions): - self.options = options - self.websocket_manager = WebSocketManager(options, self) - self.event_queue = asyncio.Queue() - self.is_connected = False - - # Pre-create session.update event details - self.session_update_event = { - "event_id": self._generate_event_id(), - "type": "session.update", - "session": { - "modalities": self.options.modalities, - "instructions": self.options.instructions, - "voice": self.options.voice, - "input_audio_format": self.options.input_audio_format, - "output_audio_format": self.options.output_audio_format, - "input_audio_transcription": { - "model": self.options.input_audio_transcription_model - }, - "turn_detection": self.options.turn_detection, - "tools": self.options.tools, - "tool_choice": self.options.tool_choice, - "temperature": self.options.temperature - } - } + self._options = options + self._websocket_manager = WebSocketManager(options, self) + self._event_queue = asyncio.Queue() + self._is_connected = False async def connect(self): try: - await self.websocket_manager.connect() + await self._websocket_manager.connect() except asyncio.CancelledError: logger.info("RealtimeAIServiceManager: Connection was cancelled.") except Exception as e: @@ -78,8 +58,8 @@ async def connect(self): async def disconnect(self): try: - await self.event_queue.put(None) # Signal the event loop to stop - await self.websocket_manager.disconnect() + await self._event_queue.put(None) # Signal the event loop to stop + await self._websocket_manager.disconnect() except asyncio.CancelledError: logger.info("RealtimeAIServiceManager: Disconnect was cancelled.") except Exception as e: @@ -87,15 +67,15 @@ async def disconnect(self): async def send_event(self, event: dict): try: - await self.websocket_manager.send(event) + await self._websocket_manager.send(event) logger.debug(f"RealtimeAIServiceManager: Sent event: {event.get('type')}") except Exception as e: logger.error(f"RealtimeAIServiceManager: Failed to send event {event.get('type')}: {e}") async def on_connected(self, reconnection: bool = False): - self.is_connected = True + self._is_connected = True logger.info("RealtimeAIServiceManager: Connected to WebSocket.") - await self.send_event(self.session_update_event) + await self.update_session(self._options) if reconnection: # If it's a reconnection, trigger a ReconnectedEvent reconnect_event = ReconnectedEvent( @@ -107,7 +87,7 @@ async def on_connected(self, reconnection: bool = False): logger.debug("RealtimeAIServiceManager: session.update event sent.") async def on_disconnected(self, status_code: int, reason: str): - self.is_connected = False + self._is_connected = False logger.warning(f"RealtimeAIServiceManager: WebSocket disconnected: {status_code} - {reason}") async def on_error(self, error: Exception): @@ -118,7 +98,7 @@ async def on_message_received(self, message: str): json_object = json.loads(message) event = self.parse_realtime_event(json_object) if event: - await self.event_queue.put(event) + await self._event_queue.put(event) logger.debug(f"RealtimeAIServiceManager: Event queued: {event.type}") except json.JSONDecodeError as e: logger.error(f"RealtimeAIServiceManager: JSON parse error: {e}") @@ -178,12 +158,33 @@ def parse_realtime_event(self, json_object: dict) -> Optional[EventBase]: logger.warning(f"RealtimeAIServiceManager: Unknown message type received: {event_type}") return None + async def update_session(self, options: RealtimeAIOptions) -> dict: + event = { + "event_id": self._generate_event_id(), + "type": "session.update", + "session": { + "modalities": options.modalities, + "instructions": options.instructions, + "voice": options.voice, + "input_audio_format": options.input_audio_format, + "output_audio_format": options.output_audio_format, + "input_audio_transcription": { + "model": options.input_audio_transcription_model + }, + "turn_detection": options.turn_detection, + "tools": options.tools, + "tool_choice": options.tool_choice, + "temperature": options.temperature + } + } + await self.send_event(event) + async def clear_event_queue(self): """Clears all events in the event queue.""" try: - while not self.event_queue.empty(): - await self.event_queue.get() - self.event_queue.task_done() + while not self._event_queue.empty(): + await self._event_queue.get() + self._event_queue.task_done() logger.info("RealtimeAIServiceManager: Event queue cleared.") except Exception as e: logger.error(f"RealtimeAIServiceManager: Failed to clear event queue: {e}") @@ -219,9 +220,18 @@ def _get_event_class(self, event_type: str) -> Optional[Type[EventBase]]: async def get_next_event(self) -> Optional[EventBase]: try: logger.debug("RealtimeAIServiceManager: Waiting for next event...") - return await asyncio.wait_for(self.event_queue.get(), timeout=5.0) + return await asyncio.wait_for(self._event_queue.get(), timeout=5.0) except asyncio.TimeoutError: return None def _generate_event_id(self) -> str: return f"event_{uuid.uuid4()}" + + @property + def options(self): + return self._options + + @options.setter + def options(self, options: RealtimeAIOptions): + self._options = options + logger.info("RealtimeAIServiceManager: Options updated.") diff --git a/src/realtime_ai/aio/web_socket_manager.py b/src/realtime_ai/aio/web_socket_manager.py index 203e0ed..187bf30 100644 --- a/src/realtime_ai/aio/web_socket_manager.py +++ b/src/realtime_ai/aio/web_socket_manager.py @@ -14,39 +14,39 @@ class WebSocketManager: """ def __init__(self, options: RealtimeAIOptions, service_manager): - self.options = options - self.service_manager = service_manager - self.websocket = None + self._options = options + self._service_manager = service_manager + self._websocket = None - if self.options.azure_openai_endpoint: - self.request_id = uuid.uuid4() - self.url = self.options.azure_openai_endpoint + f"?api-version={self.options.azure_openai_api_version}" + f"&deployment={self.options.model}" - self.headers = { - "x-ms-client-request-id": str(self.request_id), - "api-key": self.options.api_key, + if self._options.azure_openai_endpoint: + request_id = uuid.uuid4() + self._url = self._options.azure_openai_endpoint + f"?api-version={self._options.azure_openai_api_version}" + f"&deployment={self._options.model}" + self._headers = { + "x-ms-client-request-id": str(request_id), + "api-key": self._options.api_key, } else: - self.url = f"{self.options.url}?model={self.options.model}" - self.headers = { - "Authorization": f"Bearer {self.options.api_key}", + self._url = f"{self._options.url}?model={self._options.model}" + self._headers = { + "Authorization": f"Bearer {self._options.api_key}", "openai-beta": "realtime=v1", } - self.reconnect_delay = 5 # Time to wait before attempting to reconnect, in seconds + self._reconnect_delay = 5 # Time to wait before attempting to reconnect, in seconds async def connect(self, reconnection=False): """ Establishes a WebSocket connection. """ try: - if self.websocket and self.websocket.open: + if self._websocket: logger.info("WebSocketManager: Already connected.") return - logger.info(f"WebSocketManager: Connecting to {self.url}") - self.websocket = await websockets.connect(self.url, extra_headers=self.headers) + logger.info(f"WebSocketManager: Connecting to {self._url}") + self._websocket = await websockets.connect(self._url, additional_headers=self._headers) logger.info("WebSocketManager: WebSocket connection established.") - await self.service_manager.on_connected(reconnection=reconnection) + await self._service_manager.on_connected(reconnection=reconnection) asyncio.create_task(self._receive_messages()) # Begin listening as a separate task except Exception as e: @@ -57,16 +57,16 @@ async def _receive_messages(self): Listens for incoming WebSocket messages and delegates them to the service manager. """ try: - async for message in self.websocket: - await self.service_manager.on_message_received(message) + async for message in self._websocket: + await self._service_manager.on_message_received(message) logger.debug(f"WebSocketManager: Received message: {message}") if "session_expired" in message and "maximum duration of 15 minutes" in message: logger.info("WebSocketManager: Reconnecting due to maximum duration reached.") - await asyncio.sleep(self.reconnect_delay) + await asyncio.sleep(self._reconnect_delay) await self.connect(reconnection=True) except websockets.exceptions.ConnectionClosed as e: logger.warning(f"WebSocketManager: Connection closed during receive: {e.code} - {e.reason}") - await self.service_manager.on_disconnected(e.code, e.reason) + await self._service_manager.on_disconnected(e.code, e.reason) except asyncio.CancelledError: logger.info("WebSocketManager: Receive task was cancelled.") except Exception as e: @@ -76,26 +76,28 @@ async def disconnect(self): """ Gracefully disconnects the WebSocket connection. """ - if self.websocket: + if self._websocket: try: - await self.websocket.close() + await self._websocket.close() logger.info("WebSocketManager: WebSocket closed gracefully.") except Exception as e: logger.error(f"WebSocketManager: Error closing WebSocket: {e}") + finally: + self._websocket = None async def send(self, message: dict): """ Sends a message over the WebSocket. """ # check if message is cancel_event - if self.websocket and self.websocket.open: + if self._websocket: try: message_str = json.dumps(message) - await self.websocket.send(message_str) + await self._websocket.send(message_str) logger.debug(f"WebSocketManager: Sent message: {message_str}") except Exception as e: logger.error(f"WebSocketManager: Send failed: {e}") - await self.service_manager.on_error(e) + await self._service_manager.on_error(e) else: logger.error("WebSocketManager: Cannot send message. WebSocket is not connected.") raise ConnectionError("WebSocket is not connected.") diff --git a/src/realtime_ai/audio_stream_manager.py b/src/realtime_ai/audio_stream_manager.py index b282209..e13a4e3 100644 --- a/src/realtime_ai/audio_stream_manager.py +++ b/src/realtime_ai/audio_stream_manager.py @@ -14,53 +14,56 @@ class AudioStreamManager: """ def __init__(self, stream_options: AudioStreamOptions, service_manager: RealtimeAIServiceManager): - self.stream_options = stream_options - self.service_manager = service_manager - self.audio_queue = queue.Queue() - self.is_streaming = False - self.stream_thread = None + self._stream_options = stream_options + self._service_manager = service_manager + self._audio_queue = queue.Queue() + self._is_streaming = False + self._stream_thread = None self._stop_event = threading.Event() + self._lock = threading.RLock() - def start_stream(self): - if not self.is_streaming: - self.is_streaming = True - self._stop_event.clear() - self.stream_thread = threading.Thread(target=self._stream_audio) - self.stream_thread.start() - logger.info("Audio streaming started.") + def _start_stream(self): + with self._lock: + if not self._is_streaming: + self._is_streaming = True + self._stop_event.clear() + self._stream_thread = threading.Thread(target=self._stream_audio) + self._stream_thread.start() + logger.info("Audio streaming started.") def stop_stream(self): - if self.is_streaming: - self.is_streaming = False + if self._is_streaming: + self._is_streaming = False self._stop_event.set() # Signal to the thread to stop - if self.stream_thread: - self.stream_thread.join() + if self._stream_thread: + self._stream_thread.join() logger.info("Audio streaming stopped.") def write_audio_buffer_sync(self, audio_data: bytes): - if not self.is_streaming: - self.start_stream() + with self._lock: + if not self._is_streaming: + self._start_stream() logger.info("Enqueuing audio data for streaming.") - self.audio_queue.put_nowait(audio_data) + self._audio_queue.put_nowait(audio_data) logger.info("Audio data enqueued for streaming.") def _stream_audio(self): - logger.info(f"Streaming audio task started, is_streaming: {self.is_streaming}") + logger.info(f"Streaming audio task started, is_streaming: {self._is_streaming}") - while self.is_streaming and not self._stop_event.is_set(): + while self._is_streaming and not self._stop_event.is_set(): try: - audio_chunk = self.audio_queue.get(timeout=1) # Block for a short moment + audio_chunk = self._audio_queue.get(timeout=1) # Block for a short moment processed_audio = self._process_audio(audio_chunk) encoded_audio = base64.b64encode(processed_audio).decode() # Send input_audio_buffer.append event append_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.append", "audio": encoded_audio } - self.service_manager.send_event(append_event) + self._service_manager.send_event(append_event) logger.info("input_audio_buffer.append event sent.") except queue.Empty: diff --git a/src/realtime_ai/realtime_ai_client.py b/src/realtime_ai/realtime_ai_client.py index 381fcae..bed42ee 100644 --- a/src/realtime_ai/realtime_ai_client.py +++ b/src/realtime_ai/realtime_ai_client.py @@ -18,10 +18,10 @@ class RealtimeAIClient: """ def __init__(self, options: RealtimeAIOptions, stream_options: AudioStreamOptions, event_handler: RealtimeAIEventHandler): self._options = options - self.service_manager = RealtimeAIServiceManager(options) - self.audio_stream_manager = AudioStreamManager(stream_options, self.service_manager) - self.event_handler = event_handler - self.is_running = False + self._service_manager = RealtimeAIServiceManager(options) + self._audio_stream_manager = AudioStreamManager(stream_options, self._service_manager) + self._event_handler = event_handler + self._is_running = False self._lock = threading.Lock() # Initialize the consume thread and executor as None @@ -32,14 +32,14 @@ def __init__(self, options: RealtimeAIOptions, stream_options: AudioStreamOption def start(self): """Starts the RealtimeAIClient.""" with self._lock: - if self.is_running: + if self._is_running: logger.warning("RealtimeAIClient: Client is already running.") return - self.is_running = True + self._is_running = True self._stop_event.clear() try: - self.service_manager.connect() # Connect to the service + self._service_manager.connect() # Connect to the service logger.info("RealtimeAIClient: Client started.") # Initialize and start the ThreadPoolExecutor here @@ -51,24 +51,24 @@ def start(self): self._consume_thread.start() logger.info("RealtimeAIClient: Event consumption thread started.") except Exception as e: - self.is_running = False + self._is_running = False logger.error(f"RealtimeAIClient: Error during client start: {e}") def stop(self, timeout: float = 5.0): """Stops the RealtimeAIClient gracefully.""" with self._lock: - if not self.is_running: + if not self._is_running: logger.warning("RealtimeAIClient: Client is already stopped.") return - self.is_running = False + self._is_running = False # Signal stop event self._stop_event.set() try: - self.audio_stream_manager.stop_stream() - self.service_manager.disconnect() + self._audio_stream_manager.stop_stream() + self._service_manager.disconnect() if self._consume_thread is not None: @@ -85,7 +85,7 @@ def stop(self, timeout: float = 5.0): logger.info("RealtimeAIClient: ThreadPoolExecutor shut down.") self.executor = None - self.service_manager.clear_event_queue() + self._service_manager.clear_event_queue() logger.info("RealtimeAIClient: Services stopped.") except Exception as e: logger.error(f"RealtimeAIClient: Error during client stop: {e}") @@ -93,12 +93,12 @@ def stop(self, timeout: float = 5.0): def send_audio(self, audio_data: bytes): """Sends audio data to the audio stream manager for processing.""" logger.info("RealtimeAIClient: Queuing audio data for streaming.") - self.audio_stream_manager.write_audio_buffer_sync(audio_data) # Ensure this is a sync method + self._audio_stream_manager.write_audio_buffer_sync(audio_data) # Ensure this is a sync method def send_text(self, text: str, role: str = "user", generate_response: bool = True): """Sends text input to the service manager.""" event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.create", "item": { "type": "message", @@ -120,25 +120,10 @@ def send_text(self, text: str, role: str = "user", generate_response: bool = Tru def update_session(self, options: RealtimeAIOptions): """Updates the session configuration with the provided options.""" - event = { - "event_id": self.service_manager._generate_event_id(), - "type": "session.update", - "session": { - "modalities": options.modalities, - "instructions": options.instructions, - "voice": options.voice, - "input_audio_format": options.input_audio_format, - "output_audio_format": options.output_audio_format, - "input_audio_transcription": { - "model": options.input_audio_transcription_model - }, - "turn_detection": options.turn_detection, - "tools": options.tools, - "tool_choice": options.tool_choice, - "temperature": options.temperature - } - } - self._send_event_to_manager(event) + if self._is_running: + self._service_manager.update_session(options) + + self._service_manager.options = options self._options = options logger.info("RealtimeAIClient: Sent session update to server.") @@ -147,12 +132,12 @@ def generate_response(self, commit_audio_buffer: bool = True): logger.info("RealtimeAIClient: Generating response.") if commit_audio_buffer: self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.commit", }) self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.create", "response": {"modalities": self.options.modalities} }) @@ -160,19 +145,19 @@ def generate_response(self, commit_audio_buffer: bool = True): def cancel_response(self): """Sends a response.cancel event to interrupt the model when playback is interrupted by user.""" self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.cancel" }) logger.info("Client: Sent response.cancel event to server.") # Clear the event queue in the service manager - self.service_manager.clear_event_queue() + self._service_manager.clear_event_queue() logger.info("RealtimeAIClient: Event queue cleared after cancellation.") def truncate_response(self, item_id: str, content_index: int, audio_end_ms: int): """Sends a conversation.item.truncate event to truncate the response.""" self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.truncate", "item_id": item_id, "content_index": content_index, @@ -183,7 +168,7 @@ def truncate_response(self, item_id: str, content_index: int, audio_end_ms: int) def clear_input_audio_buffer(self): """Sends an input_audio_buffer.clear event to the server.""" self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "input_audio_buffer.clear" }) logger.info("Client: Sent input_audio_buffer.clear event to server.") @@ -197,7 +182,7 @@ def generate_response_from_function_call(self, call_id: str, function_output: st """ # Create the function call output event item_create_event = { - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "conversation.item.create", "item": { "id": str(uuid.uuid4()).replace('-', ''), @@ -213,7 +198,7 @@ def generate_response_from_function_call(self, call_id: str, function_output: st # Optionally trigger a response self._send_event_to_manager({ - "event_id": self.service_manager._generate_event_id(), + "event_id": self._service_manager._generate_event_id(), "type": "response.create", "response": {"modalities": self.options.modalities} }) @@ -223,7 +208,7 @@ def _consume_events(self): logger.info("Consume thread: Started consuming events.") while not self._stop_event.is_set(): try: - event = self.service_manager.get_next_event() + event = self._service_manager.get_next_event() if event is None: logger.info("Consume thread: Received sentinel, exiting.") break @@ -244,7 +229,7 @@ def _handle_event(self, event: EventBase): """Handles the received event based on its type using the event handler.""" event_type = event.type method_name = f'on_{event_type.replace(".", "_")}' - handler = getattr(self.event_handler, method_name, None) + handler = getattr(self._event_handler, method_name, None) if callable(handler): try: @@ -252,17 +237,21 @@ def _handle_event(self, event: EventBase): except Exception as e: logger.error(f"Error in handler {method_name} for event {event_type}: {e}") else: - self.event_handler.on_unhandled_event(event_type, vars(event)) + self._event_handler.on_unhandled_event(event_type, vars(event)) def _send_event_to_manager(self, event): """Helper method to send an event to the manager.""" - self.service_manager.send_event(event) + self._service_manager.send_event(event) @property def options(self): return self._options + + @property + def is_running(self): + return self._is_running # Optional: Ensure that threads are cleaned up if the object is deleted while running def __del__(self): - if self.is_running: + if self._is_running: self.stop() \ No newline at end of file diff --git a/src/realtime_ai/realtime_ai_service_manager.py b/src/realtime_ai/realtime_ai_service_manager.py index 5be073a..758dfc7 100644 --- a/src/realtime_ai/realtime_ai_service_manager.py +++ b/src/realtime_ai/realtime_ai_service_manager.py @@ -45,60 +45,39 @@ class RealtimeAIServiceManager: """ def __init__(self, options: RealtimeAIOptions): - self.options = options - self.websocket_manager = WebSocketManager(options, self) - self.event_queue = queue.Queue() - self.is_connected = False - + self._options = options + self._websocket_manager = WebSocketManager(options, self) + self._event_queue = queue.Queue() + self._is_connected = False self._thread_running_event = threading.Event() - # Pre-construct session.update event details - self.session_update_event = { - "event_id": self._generate_event_id(), - "type": "session.update", - "session": { - "modalities": self.options.modalities, - "instructions": self.options.instructions, - "voice": self.options.voice, - "input_audio_format": self.options.input_audio_format, - "output_audio_format": self.options.output_audio_format, - "input_audio_transcription": { - "model": self.options.input_audio_transcription_model - }, - "turn_detection": self.options.turn_detection, - "tools": self.options.tools, - "tool_choice": self.options.tool_choice, - "temperature": self.options.temperature - } - } - def connect(self): try: - self.websocket_manager.connect() - self.is_connected = True + self._websocket_manager.connect() + self._is_connected = True logger.info("RealtimeAIServiceManager: Connection started to WebSocket.") except Exception as e: logger.error(f"RealtimeAIServiceManager: Unexpected error during connect: {e}") def disconnect(self): try: - self.event_queue.put(None) # Signal the event loop to stop - self.websocket_manager.disconnect() - self.is_connected = False + self._event_queue.put(None) # Signal the event loop to stop + self._websocket_manager.disconnect() + self._is_connected = False logger.warning("RealtimeAIServiceManager: WebSocket disconnection started.") except Exception as e: logger.error(f"RealtimeAIServiceManager: Unexpected error during disconnect: {e}") def send_event(self, event: dict): try: - self.websocket_manager.send(event) + self._websocket_manager.send(event) logger.debug(f"RealtimeAIServiceManager: Sent event: {event.get('type')}") except Exception as e: logger.error(f"RealtimeAIServiceManager: Failed to send event {event.get('type')}: {e}") def on_connected(self, reconnection: bool = False): logger.info("RealtimeAIServiceManager: WebSocket connected.") - self.send_event(self.session_update_event) + self.update_session(self._options) if reconnection: # If it's a reconnection, trigger a ReconnectedEvent reconnect_event = ReconnectedEvent( @@ -120,7 +99,7 @@ def on_message_received(self, message: str): json_object = json.loads(message) event = self.parse_realtime_event(json_object) if event: - self.event_queue.put_nowait(event) + self._event_queue.put_nowait(event) logger.debug(f"RealtimeAIServiceManager: Event queued: {event.type}") except json.JSONDecodeError as e: logger.error(f"RealtimeAIServiceManager: JSON parse error: {e}") @@ -180,10 +159,31 @@ def parse_realtime_event(self, json_object: dict) -> Optional[EventBase]: logger.warning(f"RealtimeAIServiceManager: Unknown message type received: {event_type}") return None + def update_session(self, options: RealtimeAIOptions) -> dict: + event = { + "event_id": self._generate_event_id(), + "type": "session.update", + "session": { + "modalities": options.modalities, + "instructions": options.instructions, + "voice": options.voice, + "input_audio_format": options.input_audio_format, + "output_audio_format": options.output_audio_format, + "input_audio_transcription": { + "model": options.input_audio_transcription_model + }, + "turn_detection": options.turn_detection, + "tools": options.tools, + "tool_choice": options.tool_choice, + "temperature": options.temperature + } + } + self.send_event(event) + def clear_event_queue(self): """Clears all events in the event queue.""" try: - self.event_queue.queue.clear() + self._event_queue.queue.clear() logger.info("RealtimeAIServiceManager: Event queue cleared.") except Exception as e: logger.error(f"RealtimeAIServiceManager: Failed to clear event queue: {e}") @@ -219,9 +219,18 @@ def _get_event_class(self, event_type: str) -> Optional[Type[EventBase]]: def get_next_event(self, timeout=5.0) -> Optional[EventBase]: try: logger.info("RealtimeAIServiceManager: Waiting for next event...") - return self.event_queue.get(timeout=timeout) + return self._event_queue.get(timeout=timeout) except queue.Empty: raise def _generate_event_id(self) -> str: return f"event_{uuid.uuid4()}" + + @property + def options(self): + return self._options + + @options.setter + def options(self, options: RealtimeAIOptions): + self._options = options + logger.info("RealtimeAIServiceManager: Options updated.") diff --git a/src/realtime_ai/web_socket_manager.py b/src/realtime_ai/web_socket_manager.py index d2cad26..62475eb 100644 --- a/src/realtime_ai/web_socket_manager.py +++ b/src/realtime_ai/web_socket_manager.py @@ -15,48 +15,48 @@ class WebSocketManager: """ def __init__(self, options : RealtimeAIOptions, service_manager): - self.options = options - self.service_manager = service_manager - - if self.options.azure_openai_endpoint: - self.request_id = uuid.uuid4() - self.url = self.options.azure_openai_endpoint + f"?api-version={self.options.azure_openai_api_version}" + f"&deployment={self.options.model}" - self.headers = { - "x-ms-client-request-id": str(self.request_id), - "api-key": self.options.api_key, + self._options = options + self._service_manager = service_manager + + if self._options.azure_openai_endpoint: + request_id = uuid.uuid4() + self._url = self._options.azure_openai_endpoint + f"?api-version={self._options.azure_openai_api_version}" + f"&deployment={self._options.model}" + self._headers = { + "x-ms-client-request-id": str(request_id), + "api-key": self._options.api_key, } else: - self.url = f"{self.options.url}?model={self.options.model}" - self.headers = { - "Authorization": f"Bearer {self.options.api_key}", + self._url = f"{self._options.url}?model={self._options.model}" + self._headers = { + "Authorization": f"Bearer {self._options.api_key}", "openai-beta": "realtime=v1", } - self.ws = None + self._ws = None self._receive_thread = None - self.reconnect_delay = 5 # Time to wait before attempting to reconnect, in seconds - self.is_reconnection = False + self._reconnect_delay = 5 # Time to wait before attempting to reconnect, in seconds + self._is_reconnection = False def connect(self): """ Establishes a WebSocket connection. """ try: - if self.ws and self.ws.sock and self.ws.sock.connected: + if self._ws and self._ws.sock and self._ws.sock.connected: logger.info("WebSocketManager: Already connected.") return - logger.info(f"WebSocketManager: Connecting to {self.url}") - self.ws = websocket.WebSocketApp( - self.url, + logger.info(f"WebSocketManager: Connecting to {self._url}") + self._ws = websocket.WebSocketApp( + self._url, on_open=self._on_open, on_message=self._on_message, on_error=self._on_error, on_close=self._on_close, - header=self.headers + header=self._headers ) - self._receive_thread = threading.Thread(target=self.ws.run_forever) + self._receive_thread = threading.Thread(target=self._ws.run_forever) self._receive_thread.start() logger.info("WebSocketManager: WebSocket connection established.") except Exception as e: @@ -66,8 +66,8 @@ def disconnect(self): """ Gracefully disconnects the WebSocket connection. """ - if self.ws: - self.ws.close() + if self._ws: + self._ws.close() if self._receive_thread: self._receive_thread.join() logger.info("WebSocketManager: WebSocket closed gracefully.") @@ -76,46 +76,46 @@ def send(self, message: dict): """ Sends a message over the WebSocket. """ - if self.ws and self.ws.sock and self.ws.sock.connected: + if self._ws and self._ws.sock and self._ws.sock.connected: try: message_str = json.dumps(message) - self.ws.send(message_str) + self._ws.send(message_str) logger.debug(f"WebSocketManager: Sent message: {message_str}") except Exception as e: logger.error(f"WebSocketManager: Send failed: {e}") def _on_open(self, ws): logger.info("WebSocketManager: WebSocket connection opened.") - if self.is_reconnection: + if self._is_reconnection: logger.info("WebSocketManager: Connection reopened (Reconnection).") - self.service_manager.on_connected(reconnection=True) - self.is_reconnection = False + self._service_manager.on_connected(reconnection=True) + self._is_reconnection = False else: logger.info("WebSocketManager: Connection opened (Initial).") - self.service_manager.on_connected() + self._service_manager.on_connected() - self.is_reconnection = False + self._is_reconnection = False def _on_message(self, ws, message): logger.debug(f"WebSocketManager: Received message: {message}") - self.service_manager.on_message_received(message) + self._service_manager.on_message_received(message) def _on_error(self, ws, error): logger.error(f"WebSocketManager: WebSocket error: {error}") - self.service_manager.on_error(error) + self._service_manager.on_error(error) def _on_close(self, ws, close_status_code, close_msg): logger.warning(f"WebSocketManager: WebSocket connection closed: {close_status_code} - {close_msg}") - self.service_manager.on_disconnected(close_status_code, close_msg) + self._service_manager.on_disconnected(close_status_code, close_msg) # If the session ended due to maximum duration, attempt to reconnect if close_status_code == 1001 and "maximum duration of 15 minutes" in close_msg: logger.debug("WebSocketManager: Session ended due to maximum duration. Reconnecting...") - if self.options.enable_auto_reconnect: + if self._options.enable_auto_reconnect: self._schedule_reconnect() def _schedule_reconnect(self): logger.info("WebSocketManager: Scheduling reconnection...") - time.sleep(self.reconnect_delay) - self.is_reconnection = True + time.sleep(self._reconnect_delay) + self._is_reconnection = True self.connect()