Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agents-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"pillow>=11.3.0",
"numpy>=1.24.0",
"mcp>=1.16.0",
"torchvision>=0.23.0",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion agents-core/vision_agents/core/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ async def on_audio_received(event: AudioReceivedEvent):
return

if self.turn_detection is not None:
await self.turn_detection.process_audio(pcm, participant.user_id)
await self.turn_detection.process_audio(pcm, participant, conversation=self.conversation)

await self._reply_to_audio(pcm, participant)

Expand Down
1 change: 0 additions & 1 deletion agents-core/vision_agents/core/turn_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
TurnEvent,
TurnEventData,
TurnDetector,
TurnDetection,
)
from .events import (
TurnStartedEvent,
Expand Down
79 changes: 13 additions & 66 deletions agents-core/vision_agents/core/turn_detection/turn_detection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional, Dict, Any, Callable, Protocol
from typing import Optional, Dict, Any, Callable
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
import uuid
from getstream.video.rtc.track_util import PcmData
from vision_agents.core.events.manager import EventManager
from vision_agents.core.events import PluginInitializedEvent
from . import events
from ..agents.conversation import Conversation
from ..edge.types import Participant


class TurnEvent(Enum):
Expand All @@ -33,43 +34,6 @@ class TurnEventData:
EventListener = Callable[[TurnEventData], None]


class TurnDetection(Protocol):
"""Turn Detection shape definition used by the Agent class"""

events: EventManager

def is_detecting(self) -> bool:
"""Check if turn detection is currently active."""
...

# --- Unified high-level interface used by Agent ---
def start(self) -> None:
"""Start detection (convenience alias to start_detection)."""
...

def stop(self) -> None:
"""Stop detection (convenience alias to stop_detection)."""
...

async def process_audio(
self,
audio_data: PcmData,
user_id: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Ingest PcmData audio for a user.

The implementation should track participants internally as audio comes in.
Use the event system (events.send) to notify when turns change.

Args:
audio_data: PcmData object containing audio samples from Stream
user_id: Identifier for the user providing the audio
metadata: Optional additional metadata about the audio
"""
...


class TurnDetector(ABC):
"""Base implementation for turn detection with common functionality."""

Expand All @@ -79,22 +43,11 @@ def __init__(
provider_name: Optional[str] = None
) -> None:
self._confidence_threshold = confidence_threshold
self._is_detecting = False
self.is_active = False
self.session_id = str(uuid.uuid4())
self.provider_name = provider_name or self.__class__.__name__
self.events = EventManager()
self.events.register_events_from_module(events, ignore_not_compatible=True)
self.events.send(PluginInitializedEvent(
session_id=self.session_id,
plugin_name=self.provider_name,
plugin_type="TurnDetection",
provider=self.provider_name,
))

@abstractmethod
def is_detecting(self) -> bool:
"""Check if turn detection is currently active."""
return self._is_detecting

def _emit_turn_event(
self, event_type: TurnEvent, event_data: TurnEventData
Expand Down Expand Up @@ -129,29 +82,23 @@ def _emit_turn_event(
async def process_audio(
self,
audio_data: PcmData,
user_id: str,
metadata: Optional[Dict[str, Any]] = None,
participant: Participant,
conversation: Optional[Conversation],
) -> None:
"""Ingest PcmData audio for a user.

The implementation should track participants internally as audio comes in.
Use the event system (emit/on) to notify when turns change.
"""Process the audio and trigger turn start or turn end events

Args:
audio_data: PcmData object containing audio samples from Stream
user_id: Identifier for the user providing the audio
metadata: Optional additional metadata about the audio
participant: Participant that's speaking, includes user data
conversation: Transcription/ chat history, sometimes useful for turn detection
"""

...

# Convenience aliases to align with the unified protocol expected by Agent
@abstractmethod
def start(self) -> None:
"""Start detection (alias for start_detection)."""
...
"""Some turn detection systems want to run warmup etc here"""
self.is_active = True

@abstractmethod
def stop(self) -> None:
"""Stop detection (alias for stop_detection)."""
...
"""Again, some turn detection systems want to run cleanup here"""
self.is_active = False
60 changes: 60 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,66 @@ def mia_audio_48khz():
return pcm


@pytest.fixture
def mia_audio_48khz_chunked():
"""Load mia.mp3 and yield 48kHz PCM data in 20ms chunks."""
audio_file_path = os.path.join(get_assets_dir(), "mia.mp3")

# Load audio file using PyAV
container = av.open(audio_file_path)
audio_stream = container.streams.audio[0]
original_sample_rate = audio_stream.sample_rate
target_rate = 48000

# Create resampler if needed
resampler = None
if original_sample_rate != target_rate:
resampler = av.AudioResampler(
format='s16',
layout='mono',
rate=target_rate
)

# Read all audio frames
samples = []
for frame in container.decode(audio_stream):
# Resample if needed
if resampler:
frame = resampler.resample(frame)[0]

# Convert to numpy array
frame_array = frame.to_ndarray()
if len(frame_array.shape) > 1:
# Convert stereo to mono
frame_array = np.mean(frame_array, axis=0)
samples.append(frame_array)

# Concatenate all samples
samples = np.concatenate(samples)

# Convert to int16
samples = samples.astype(np.int16)
container.close()

# Calculate chunk size for 20ms at 48kHz
chunk_size = int(target_rate * 0.020) # 960 samples per 20ms

# Yield chunks of audio
chunks = []
for i in range(0, len(samples), chunk_size):
chunk_samples = samples[i:i + chunk_size]

# Create PCM data for this chunk
pcm_chunk = PcmData(
samples=chunk_samples,
sample_rate=target_rate,
format="s16"
)
chunks.append(pcm_chunk)

return chunks


@pytest.fixture
def golf_swing_image():
"""Load golf_swing.png image and return as bytes."""
Expand Down
22 changes: 22 additions & 0 deletions docs/ai/instructions/ai-turn-detector.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## Turn Detector

Here's a minimal example

```python

class MyTurnDetector(TurnDetector):
async def process_audio(
self,
audio_data: PcmData,
user_id: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:

# turn end
self._emit_turn_event(TurnEvent.TURN_ENDED, event_data)

# turn start
self._emit_turn_event(TurnEvent.TURN_STARTED, event_data)


```
2 changes: 1 addition & 1 deletion examples/01_simple_agent_example/simple_agent_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def start_agent() -> None:
tts=cartesia.TTS(),
stt=deepgram.STT(),
turn_detection=smart_turn.TurnDetection(
buffer_duration=2.0, confidence_threshold=0.5
buffer_in_seconds=2.0, confidence_threshold=0.5
), # Enable turn detection with FAL/ Smart turn
# vad=silero.VAD(),
# realtime version (vad, tts and stt not needed)
Expand Down
2 changes: 1 addition & 1 deletion plugins/aws/example/aws_qwen_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def start_agent() -> None:
llm=aws.LLM(model="qwen.qwen3-32b-v1:0"),
tts=cartesia.TTS(),
stt=deepgram.STT(),
turn_detection=smart_turn.TurnDetection(buffer_duration=2.0, confidence_threshold=0.5),
turn_detection=smart_turn.TurnDetection(buffer_in_seconds=2.0, confidence_threshold=0.5),
# Enable turn detection with FAL/ Smart turn
)
await agent.create_user()
Expand Down
2 changes: 1 addition & 1 deletion plugins/fish/example/fish_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def start_agent() -> None:
tts=fish.TTS(), # Uses Fish Audio for text-to-speech
stt=fish.STT(), # Uses Fish Audio for speech-to-text
llm=gemini.LLM("gemini-2.0-flash"),
turn_detection=smart_turn.TurnDetection(buffer_duration=2.0, confidence_threshold=0.5),
turn_detection=smart_turn.TurnDetection(buffer_in_seconds=2.0, confidence_threshold=0.5),
)

await agent.create_user()
Expand Down
10 changes: 8 additions & 2 deletions plugins/krisp/vision_agents/plugins/krisp/turn_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
from getstream.audio.utils import resample_audio
from getstream.video.rtc.track_util import PcmData
from vision_agents.core.agents import Conversation
from vision_agents.core.edge.types import Participant
from vision_agents.core.turn_detection import (
TurnDetector,
TurnStartedEvent,
Expand Down Expand Up @@ -60,6 +62,7 @@ def __init__(
self._krisp_instance = None
self._buffer: Optional[bytearray] = None
self._turn_in_progress = False
self._is_detecting = False

def _initialize_krisp(self):
try:
Expand All @@ -85,15 +88,18 @@ def is_detecting(self) -> bool:
async def process_audio(
self,
audio_data: PcmData,
user_id: str,
metadata: Optional[Dict[str, Any]] = None,
participant: Participant,
conversation: Optional[Conversation] = None,
) -> None:
if not self.is_detecting():
return

if self._krisp_instance is None:
self.logger.error("Krisp instance is not initialized. Call start() first.")
return

user_id = participant.user_id
metadata = None # Can be extended if needed from participant/conversation

# Validate sample format
valid_formats = ["int16", "s16", "pcm_s16le"]
Expand Down
3 changes: 1 addition & 2 deletions plugins/openrouter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ uv pip install vision-agents-plugins-openrouter
```python
from vision_agents.plugins import openrouter, getstream, elevenlabs, cartesia, deepgram, smart_turn


agent = Agent(
edge=getstream.Edge(),
agent_user=User(name="OpenRouter AI"),
Expand All @@ -28,7 +27,7 @@ agent = Agent(
tts=elevenlabs.TTS(),
stt=deepgram.STT(),
turn_detection=smart_turn.TurnDetection(
buffer_duration=2.0, confidence_threshold=0.5
buffer_in_seconds=2.0, confidence_threshold=0.5
)
)
```
2 changes: 1 addition & 1 deletion plugins/openrouter/example/openrouter_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def start_agent() -> None:
tts=elevenlabs.TTS(),
stt=deepgram.STT(),
turn_detection=smart_turn.TurnDetection(
buffer_duration=2.0, confidence_threshold=0.5
buffer_in_seconds=2.0, confidence_threshold=0.5
)
)
await agent.create_user()
Expand Down
2 changes: 1 addition & 1 deletion plugins/sample_plugin/example/my_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def start_agent() -> None:
llm=aws.LLM(model="qwen.qwen3-32b-v1:0"),
tts=cartesia.TTS(),
stt=deepgram.STT(),
turn_detection=smart_turn.TurnDetection(buffer_duration=2.0, confidence_threshold=0.5),
turn_detection=smart_turn.TurnDetection(buffer_in_seconds=2.0, confidence_threshold=0.5),
# Enable turn detection with FAL/ Smart turn
)
await agent.create_user()
Expand Down
55 changes: 55 additions & 0 deletions plugins/smart_turn/tests/test_smart_turn_td.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio

import pytest

from plugins.smart_turn.vision_agents.plugins.smart_turn import TurnDetection
from vision_agents.core.agents.conversation import InMemoryConversation
from vision_agents.core.edge.types import Participant
from vision_agents.core.turn_detection import TurnStartedEvent, TurnEndedEvent

import logging

logger = logging.getLogger(__name__)

class TestSmartTurnTD:

@pytest.fixture
async def td(self):
td = TurnDetection()
try:
td.start()
yield td
finally:
td.stop()

async def test_turn_detection(self, td, mia_audio_48khz_chunked):
participant = Participant(user_id="mia", original={})
conversation = InMemoryConversation(instructions="be nice", messages=[])
event_order = []


# Subscribe to events
@td.events.subscribe
async def on_start(event: TurnStartedEvent):
logger.info(f"Smart turn turn started on {event.session_id}")
event_order.append("start")

@td.events.subscribe
async def on_stop(event: TurnEndedEvent):
logger.info(f"Smart turn turn ended on {event.session_id}")
event_order.append("stop")

# Process each 20ms audio chunk
chunks = list(mia_audio_48khz_chunked)
logger.info("len %d", len(chunks))
i = 0
for chunk in chunks:
i += 1
await td.process_audio(chunk, participant, conversation)
await asyncio.sleep(0.001)

await asyncio.sleep(5)

# Verify that turn detection is working - we should get at least some turn events
# With continuous processing, we may get multiple start/stop cycles
assert event_order == ["start", "stop"]
Loading
Loading