From 154c46ae37118eed0c5227339afd4dc077d2167f Mon Sep 17 00:00:00 2001 From: MyButtermilk <153296172+MyButtermilk@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:31:02 +0100 Subject: [PATCH] Improve session handling and add API tests --- backend/PIPECAT_MIGRATION.md | 23 ++ backend/app/main.py | 34 +-- backend/app/pipecat/__init__.py | 6 + backend/app/pipecat/api.py | 66 +++++ backend/app/pipecat/components.py | 262 ++++++++++++++++++ backend/app/pipecat/schemas.py | 50 ++++ backend/app/pipecat/session.py | 44 +++ backend/app/pipecat/settings.py | 41 +++ backend/app/services/model_registry.py | 93 +------ backend/app/services/transcription_service.py | 134 +-------- backend/app/services/vad.py | 75 +---- backend/requirements.txt | 1 + backend/tests/conftest.py | 9 + backend/tests/test_pipecat_api.py | 90 ++++++ 14 files changed, 615 insertions(+), 313 deletions(-) create mode 100644 backend/PIPECAT_MIGRATION.md create mode 100644 backend/app/pipecat/__init__.py create mode 100644 backend/app/pipecat/api.py create mode 100644 backend/app/pipecat/components.py create mode 100644 backend/app/pipecat/schemas.py create mode 100644 backend/app/pipecat/session.py create mode 100644 backend/app/pipecat/settings.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_pipecat_api.py diff --git a/backend/PIPECAT_MIGRATION.md b/backend/PIPECAT_MIGRATION.md new file mode 100644 index 0000000..0447533 --- /dev/null +++ b/backend/PIPECAT_MIGRATION.md @@ -0,0 +1,23 @@ +# Pipecat backend alignment + +The FastAPI backend now mirrors Pipecat's session-aware routing. Existing `.env` +values remain compatible through the `PipecatSettings` shim, which reuses the +legacy `Settings` model while exposing `asr_model_path`, `tokenizer_path`, and +`vad_model_path` aliases expected by Pipecat components. + +## Routing changes +- `POST /api/sessions` creates or upserts a Pipecat session (you can provide + your own `session_id`). +- `POST /api/sessions/{session_id}/transcriptions` sends audio for that session. +- `POST /api/transcriptions` remains as a compatibility shim that routes through + a default session. + +## Settings surface +- The existing environment variables continue to work. Paths defined under + `models__parakeet_model_path`, `models__parakeet_tokenizer_path`, and + `models__silero_vad_path` are now surfaced to Pipecat as + `asr_model_path`, `tokenizer_path`, and `vad_model_path`. +- API prefix remains driven by `API_PREFIX` (`/api` by default) to match + Pipecat's defaults. + +No data migrations are required; model downloads remain in the same locations. diff --git a/backend/app/main.py b/backend/app/main.py index f637740..b755964 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,20 +1,12 @@ from __future__ import annotations -import json -from typing import Annotated - -from fastapi import FastAPI, File, Form, UploadFile +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from loguru import logger -from app.config import Settings, get_settings -from app.models.requests import TranscriptionRequest -from app.models.responses import TranscriptionResult -from app.services.transcription_service import ParakeetTranscriptionService +from app.pipecat import router as pipecat_router def create_app() -> FastAPI: - settings = get_settings() app = FastAPI(title="Parakeet Local", version="1.0.0") app.add_middleware( @@ -25,27 +17,7 @@ def create_app() -> FastAPI: allow_headers=["*"], ) - service = ParakeetTranscriptionService(settings) - - @app.get(f"{settings.api_prefix}/health") - async def healthcheck() -> dict[str, str]: - return {"status": "ok"} - - @app.post(f"{settings.api_prefix}/transcriptions", response_model=TranscriptionResult) - async def transcribe_audio( - file: UploadFile = File(...), - payload: Annotated[str | None, Form()] = None, - ) -> TranscriptionResult: - body = TranscriptionRequest() - if payload: - try: - body = TranscriptionRequest(**json.loads(payload)) - except json.JSONDecodeError as exc: - logger.warning("Failed to decode payload JSON: {}", exc) - audio_bytes = await file.read() - result = service.transcribe_bytes(audio_bytes, request=body, filename=file.filename) - return result - + app.include_router(pipecat_router) return app diff --git a/backend/app/pipecat/__init__.py b/backend/app/pipecat/__init__.py new file mode 100644 index 0000000..b018ddd --- /dev/null +++ b/backend/app/pipecat/__init__.py @@ -0,0 +1,6 @@ +"""Pipecat-compatible backend wiring for Parakeet Local.""" + +from app.pipecat.api import router +from app.pipecat.settings import PipecatSettings, get_pipecat_settings + +__all__ = ["router", "PipecatSettings", "get_pipecat_settings"] diff --git a/backend/app/pipecat/api.py b/backend/app/pipecat/api.py new file mode 100644 index 0000000..2f08c30 --- /dev/null +++ b/backend/app/pipecat/api.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import json +from typing import Annotated + +from fastapi import APIRouter, File, Form, UploadFile +from loguru import logger + +from app.pipecat.schemas import ( + SessionCreateRequest, + SessionCreateResponse, + TranscriptionRequest, + TranscriptionResult, +) +from app.pipecat.session import SessionManager +from app.pipecat.settings import get_pipecat_settings + +router = APIRouter() + +settings = get_pipecat_settings() +session_manager = SessionManager(settings) + + +@router.get(f"{settings.api_prefix}/health") +async def healthcheck() -> dict[str, str]: + return {"status": "ok"} + + +@router.post(f"{settings.api_prefix}/sessions", response_model=SessionCreateResponse) +async def create_session(payload: SessionCreateRequest | None = None) -> SessionCreateResponse: + return session_manager.create_session(payload) + + +@router.post(f"{settings.api_prefix}/sessions/{{session_id}}/transcriptions", response_model=TranscriptionResult) +async def transcribe_audio( + session_id: str, + file: UploadFile = File(...), + payload: Annotated[str | None, Form()] = None, +) -> TranscriptionResult: + body = TranscriptionRequest(session_id=session_id) + if payload: + try: + body = TranscriptionRequest(**json.loads(payload)) + except json.JSONDecodeError as exc: + logger.warning("Failed to decode payload JSON: {}", exc) + body = TranscriptionRequest(session_id=session_id) + + if body.session_id and body.session_id != session_id: + logger.warning( + "Payload session_id %s does not match path session_id %s; using path value", + body.session_id, + session_id, + ) + body.session_id = session_id + audio_bytes = await file.read() + result = session_manager.transcribe(session_id, body, audio_bytes, filename=file.filename) + return result + + +@router.post(f"{settings.api_prefix}/transcriptions", response_model=TranscriptionResult) +async def transcribe_with_default_session( + file: UploadFile = File(...), + payload: Annotated[str | None, Form()] = None, +) -> TranscriptionResult: + default_session = session_manager.create_session(SessionCreateRequest(session_id="default")) + return await transcribe_audio(default_session.session_id, file=file, payload=payload) diff --git a/backend/app/pipecat/components.py b/backend/app/pipecat/components.py new file mode 100644 index 0000000..38b6fe2 --- /dev/null +++ b/backend/app/pipecat/components.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence + +import numpy as np +import onnxruntime as ort +import requests +from loguru import logger + +from app.pipecat.settings import PipecatSettings, get_pipecat_settings +from app.utils.audio_utils import load_audio, save_waveform, split_segments + +PARAKEET_MODEL_URL = ( + "https://huggingface.co/onnx-community/parakeet-ctc-v3/resolve/main/model.onnx?download=1" +) +PARAKEET_TOKENIZER_URL = ( + "https://huggingface.co/onnx-community/parakeet-ctc-v3/resolve/main/tokenizer.json?download=1" +) +SILERO_VAD_URL = "https://huggingface.co/snakers4/silero-vad/resolve/main/models/silero_vad.onnx?download=1" + + +class PipecatModelRegistry: + """Pipecat-flavored registry around the Parakeet model assets.""" + + def __init__(self, settings: PipecatSettings | None = None) -> None: + self.settings = settings or get_pipecat_settings() + self._sessions: Dict[str, ort.InferenceSession] = {} + self._tokenizer: Dict[str, Any] | None = None + + def ensure_resources(self) -> None: + self._download_if_missing(self.settings.asr_model_path, PARAKEET_MODEL_URL) + self._download_if_missing(self.settings.tokenizer_path, PARAKEET_TOKENIZER_URL) + self._download_if_missing(self.settings.vad_model_path, SILERO_VAD_URL) + + def get_asr_session(self) -> ort.InferenceSession: + if "asr" not in self._sessions: + self.ensure_resources() + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + self._sessions["asr"] = ort.InferenceSession( + str(self.settings.asr_model_path), providers=providers + ) + logger.info("Loaded ASR model from {}", self.settings.asr_model_path) + return self._sessions["asr"] + + def get_vad_session(self) -> ort.InferenceSession: + if "vad" not in self._sessions: + self.ensure_resources() + providers = ["CPUExecutionProvider"] + self._sessions["vad"] = ort.InferenceSession( + str(self.settings.vad_model_path), providers=providers + ) + logger.info("Loaded VAD model from {}", self.settings.vad_model_path) + return self._sessions["vad"] + + def get_tokenizer(self) -> Dict[str, Any]: + if self._tokenizer is None: + self.ensure_resources() + with self.settings.tokenizer_path.open("r", encoding="utf-8") as handle: + self._tokenizer = json.load(handle) + return self._tokenizer + + def _download_if_missing(self, path: Path, url: str) -> None: + if path.exists(): + return + logger.info("Downloading %s", url) + path.parent.mkdir(parents=True, exist_ok=True) + response = requests.get(url, timeout=60, stream=True) + response.raise_for_status() + with path.open("wb") as destination: + shutil.copyfileobj(response.raw, destination) + logger.info("Downloaded resource to %s", path) + + +@dataclass +class DecoderVocabulary: + tokens: Sequence[str] + blank_id: int + word_delimiter: str = " " + + @classmethod + def from_tokenizer_dict(cls, tokenizer_dict: dict) -> "DecoderVocabulary": + model = tokenizer_dict.get("model", {}) + vocab = model.get("vocab", {}) + tokens = [""] * len(vocab) + for token, index in vocab.items(): + tokens[index] = token.replace("▁", " ") + blank_id = tokenizer_dict.get("added_tokens", [{}])[0].get("id", len(tokens) - 1) + return cls(tokens=tokens, blank_id=blank_id) + + def decode(self, token_ids: Sequence[int]) -> str: + pieces: List[str] = [] + previous = None + for token_id in token_ids: + if token_id == self.blank_id or token_id == previous: + previous = token_id + continue + pieces.append(self.tokens[token_id]) + previous = token_id + text = "".join(pieces) + return " ".join(text.split()) + + +@dataclass +class SpeechSegment: + start: int + end: int + + +class PipecatVAD: + def __init__(self, settings: PipecatSettings | None = None) -> None: + self.settings = settings or get_pipecat_settings() + self.registry = PipecatModelRegistry(self.settings) + self.session: ort.InferenceSession = self.registry.get_vad_session() + + def detect(self, waveform: np.ndarray, sample_rate: int, threshold: float | None = None) -> List[SpeechSegment]: + threshold = threshold or self.settings.vad_threshold + stride = 512 + window = 1536 + probs: List[float] = [] + for start in range(0, len(waveform) - window, stride): + chunk = waveform[start : start + window] + ort_inputs = { + "input": chunk.reshape(1, -1), + "sr": np.array(sample_rate, dtype=np.int64), + } + (prob,) = self.session.run(None, ort_inputs) + probs.append(float(prob.squeeze())) + + speech_segments: List[SpeechSegment] = [] + active = False + seg_start = 0 + for index, prob in enumerate(probs): + time_start = index * stride + time_end = time_start + window + if prob >= threshold and not active: + active = True + seg_start = time_start + elif prob < threshold and active: + active = False + if time_end - seg_start >= self.settings.vad_min_speech_seconds * sample_rate: + speech_segments.append(SpeechSegment(seg_start, time_end)) + + if active: + speech_segments.append(SpeechSegment(seg_start, len(waveform))) + + merged: List[SpeechSegment] = [] + for segment in speech_segments: + if not merged: + merged.append(segment) + continue + prev = merged[-1] + gap = segment.start - prev.end + if gap / sample_rate <= self.settings.vad_min_silence_seconds: + merged[-1] = SpeechSegment(prev.start, segment.end) + else: + merged.append(segment) + + return merged + + def extract(self, waveform: np.ndarray, segments: Iterable[SpeechSegment]) -> np.ndarray: + pieces = [waveform[segment.start : segment.end] for segment in segments] + if not pieces: + return waveform + return np.concatenate(pieces) + + +class PipecatDecoder: + def __init__(self, registry: PipecatModelRegistry) -> None: + tokenizer = registry.get_tokenizer() + self.vocab = DecoderVocabulary.from_tokenizer_dict(tokenizer) + + def decode(self, logits: np.ndarray) -> List[int]: + token_ids = np.argmax(logits, axis=-1).flatten().tolist() + return token_ids + + def tokens_to_text(self, token_ids: Sequence[int]) -> str: + return self.vocab.decode(token_ids) + + +class PipecatPipeline: + def __init__(self, settings: PipecatSettings | None = None) -> None: + self.settings = settings or get_pipecat_settings() + self.registry = PipecatModelRegistry(self.settings) + self.vad = PipecatVAD(self.settings) + self.decoder = PipecatDecoder(self.registry) + self.session = self.registry.get_asr_session() + + def transcribe(self, audio_bytes: bytes, *, request_settings: dict | None = None, filename: str | None = None): + from app.pipecat.schemas import TranscriptSegment, TranscriptionResult + + waveform, sample_rate = load_audio(audio_bytes, self.settings.sample_rate) + request_settings = request_settings or {} + vad_enabled = request_settings.get("enable_vad", True) + if vad_enabled: + vad_segments = self.vad.detect( + waveform, sample_rate, threshold=request_settings.get("vad_threshold") + ) + if vad_segments: + logger.debug("Detected %d speech segments via VAD", len(vad_segments)) + waveform = self.vad.extract(waveform, vad_segments) + + text_segments: List[TranscriptSegment] = [] + transcript_parts: List[str] = [] + processed_duration = len(waveform) / sample_rate + offset = 0.0 + + for segment_waveform in split_segments( + waveform, sample_rate, self.settings.max_segment_seconds + ): + inputs = self.session.get_inputs() + if not inputs: + raise RuntimeError("ASR ONNX session has no inputs") + input_name = inputs[0].name + audio = segment_waveform.astype(np.float32)[np.newaxis, :] + outputs = self.session.run(None, {input_name: audio}) + tokens = self.decoder.decode(outputs[0]) + text = self.decoder.tokens_to_text(tokens) + transcript_parts.append(text) + end_time = offset + len(segment_waveform) / sample_rate + text_segments.append( + TranscriptSegment( + text=text, + start=offset, + end=end_time, + speaker=request_settings.get("speaker_hint"), + confidence=None, + ) + ) + offset = end_time + + transcript_text = " ".join(part for part in transcript_parts if part) + if request_settings.get("enable_punctuation", True): + transcript_text = self._restore_punctuation(transcript_text) + for segment in text_segments: + segment.text = self._restore_punctuation(segment.text) + + settings_applied = {k: v for k, v in request_settings.items() if v is not None} + + if filename: + target = Path(self.settings.storage_dir) / (request_settings.get("request_id") or "transcript") + save_waveform(target.with_suffix(".wav"), waveform, sample_rate) + + return TranscriptionResult( + request_id=request_settings.get("request_id"), + text=transcript_text, + duration=processed_duration, + segments=text_segments, + settings_applied=settings_applied, + ) + + def _restore_punctuation(self, text: str) -> str: + normalized = text.strip() + if not normalized: + return normalized + normalized = normalized[0].upper() + normalized[1:] + if normalized[-1] not in {".", "?", "!"}: + normalized += "." + return normalized diff --git a/backend/app/pipecat/schemas.py b/backend/app/pipecat/schemas.py new file mode 100644 index 0000000..0c51511 --- /dev/null +++ b/backend/app/pipecat/schemas.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class SessionSettings(BaseModel): + """Pipecat session options mapped to legacy request settings.""" + + language: Optional[str] = Field(default=None) + enable_punctuation: bool = Field(default=True) + enable_vad: bool = Field(default=True) + vad_threshold: Optional[float] = Field(default=None) + diarization: bool = Field(default=False) + + +class SessionCreateRequest(BaseModel): + session_id: Optional[str] = Field(default=None, description="Optional client-chosen session id") + settings: SessionSettings = Field(default_factory=SessionSettings) + + +class SessionCreateResponse(BaseModel): + session_id: str + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class TranscriptSegment(BaseModel): + text: str = Field(..., description="Recognized text for this segment.") + start: float = Field(..., description="Segment start time in seconds.") + end: float = Field(..., description="Segment end time in seconds.") + speaker: Optional[str] = Field(default=None) + confidence: Optional[float] = Field(default=None) + + +class TranscriptionRequest(BaseModel): + request_id: Optional[str] = Field(default=None) + session_id: Optional[str] = Field(default=None) + settings: SessionSettings = Field(default_factory=SessionSettings) + + +class TranscriptionResult(BaseModel): + request_id: Optional[str] = Field(default=None) + session_id: Optional[str] = Field(default=None) + created_at: datetime = Field(default_factory=datetime.utcnow) + text: str = Field(...) + duration: float = Field(...) + segments: List[TranscriptSegment] = Field(default_factory=list) + settings_applied: dict = Field(default_factory=dict) diff --git a/backend/app/pipecat/session.py b/backend/app/pipecat/session.py new file mode 100644 index 0000000..a95172b --- /dev/null +++ b/backend/app/pipecat/session.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from typing import Dict + +from app.pipecat.components import PipecatPipeline +from app.pipecat.schemas import SessionCreateRequest, SessionCreateResponse, SessionSettings, TranscriptionRequest +from app.pipecat.settings import PipecatSettings, get_pipecat_settings + + +@dataclass +class PipecatSession: + session_id: str + settings: SessionSettings + pipeline: PipecatPipeline = field(default_factory=PipecatPipeline) + + +class SessionManager: + def __init__(self, settings: PipecatSettings | None = None) -> None: + self.settings = settings or get_pipecat_settings() + self.sessions: Dict[str, PipecatSession] = {} + + def create_session(self, payload: SessionCreateRequest | None = None) -> SessionCreateResponse: + payload = payload or SessionCreateRequest() + session_id = payload.session_id or str(uuid.uuid4()) + pipeline = PipecatPipeline(self.settings) + self.sessions[session_id] = PipecatSession(session_id=session_id, settings=payload.settings, pipeline=pipeline) + return SessionCreateResponse(session_id=session_id) + + def get_session(self, session_id: str) -> PipecatSession: + if session_id not in self.sessions: + self.sessions[session_id] = PipecatSession( + session_id=session_id, settings=SessionSettings(), pipeline=PipecatPipeline(self.settings) + ) + return self.sessions[session_id] + + def transcribe(self, session_id: str, request: TranscriptionRequest, audio_bytes: bytes, filename: str | None = None): + session = self.get_session(session_id) + settings = request.settings.dict(exclude_none=True) + settings["request_id"] = request.request_id + result = session.pipeline.transcribe(audio_bytes, request_settings=settings, filename=filename) + result.session_id = session_id + return result diff --git a/backend/app/pipecat/settings.py b/backend/app/pipecat/settings.py new file mode 100644 index 0000000..6990c18 --- /dev/null +++ b/backend/app/pipecat/settings.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +"""Pipecat-facing settings shims. + +These objects mirror the upstream Pipecat FastAPI backend configuration while +continuing to honor the existing ``Settings`` model and ``.env`` values used by +Parakeet Local. This keeps environment variables stable while letting Pipecat +code import familiar names. +""" + +from pydantic import Field + +from app.config import Settings as ParakeetSettings, get_settings as get_parakeet_settings + + +class PipecatSettings(ParakeetSettings): + """Alias for the legacy Parakeet settings with Pipecat-friendly names.""" + + api_prefix: str = Field(default="/api", description="Pipecat API prefix.") + + class Config(ParakeetSettings.Config): + env_file = ParakeetSettings.Config.env_file + env_nested_delimiter = ParakeetSettings.Config.env_nested_delimiter + + @property + def asr_model_path(self): + """Pipecat naming shim for the Parakeet model path.""" + + return self.models.parakeet_model_path + + @property + def tokenizer_path(self): + return self.models.parakeet_tokenizer_path + + @property + def vad_model_path(self): + return self.models.silero_vad_path + + +def get_pipecat_settings() -> PipecatSettings: + return PipecatSettings(**get_parakeet_settings().dict()) diff --git a/backend/app/services/model_registry.py b/backend/app/services/model_registry.py index 095a8d0..1172a0c 100644 --- a/backend/app/services/model_registry.py +++ b/backend/app/services/model_registry.py @@ -1,96 +1,15 @@ from __future__ import annotations -import json -import shutil -from pathlib import Path -from typing import Any, Dict -import onnxruntime as ort -import requests -from loguru import logger +"""Backward-compatible shim for Pipecat model registry wiring.""" -from app.config import Settings, get_settings +from app.pipecat.components import PipecatModelRegistry +from app.pipecat.settings import PipecatSettings -PARAKEET_MODEL_URL = ( - "https://huggingface.co/onnx-community/parakeet-ctc-v3/resolve/main/model.onnx?download=1" -) -PARAKEET_TOKENIZER_URL = ( - "https://huggingface.co/onnx-community/parakeet-ctc-v3/resolve/main/tokenizer.json?download=1" -) -SILERO_VAD_URL = "https://huggingface.co/snakers4/silero-vad/resolve/main/models/silero_vad.onnx?download=1" +_registry: PipecatModelRegistry | None = None -class ModelRegistry: - """Handle model downloading and lazy loading for the ASR pipeline.""" - - def __init__(self, settings: Settings | None = None) -> None: - self.settings = settings or get_settings() - self._sessions: Dict[str, ort.InferenceSession] = {} - self._tokenizer: Dict[str, Any] | None = None - - def ensure_resources(self) -> None: - """Ensure that all required model files are available locally.""" - - self._download_if_missing( - self.settings.models.parakeet_model_path, PARAKEET_MODEL_URL - ) - self._download_if_missing( - self.settings.models.parakeet_tokenizer_path, PARAKEET_TOKENIZER_URL - ) - self._download_if_missing(self.settings.models.silero_vad_path, SILERO_VAD_URL) - - def get_asr_session(self) -> ort.InferenceSession: - """Return the cached ASR ONNX session.""" - - if "parakeet" not in self._sessions: - self.ensure_resources() - providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] - self._sessions["parakeet"] = ort.InferenceSession( - str(self.settings.models.parakeet_model_path), - providers=providers, - ) - logger.info("Loaded Parakeet v3 ASR model from {}", self.settings.models.parakeet_model_path) - return self._sessions["parakeet"] - - def get_vad_session(self) -> ort.InferenceSession: - """Return the cached Silero VAD session.""" - - if "silero_vad" not in self._sessions: - self.ensure_resources() - providers = ["CPUExecutionProvider"] - self._sessions["silero_vad"] = ort.InferenceSession( - str(self.settings.models.silero_vad_path), - providers=providers, - ) - logger.info("Loaded Silero VAD model from {}", self.settings.models.silero_vad_path) - return self._sessions["silero_vad"] - - def get_tokenizer(self) -> Dict[str, Any]: - """Return the tokenizer metadata for the ASR model.""" - - if self._tokenizer is None: - self.ensure_resources() - with self.settings.models.parakeet_tokenizer_path.open("r", encoding="utf-8") as handle: - self._tokenizer = json.load(handle) - return self._tokenizer - - def _download_if_missing(self, path: Path, url: str) -> None: - if path.exists(): - return - - logger.info("Downloading %s", url) - path.parent.mkdir(parents=True, exist_ok=True) - response = requests.get(url, timeout=60, stream=True) - response.raise_for_status() - with path.open("wb") as destination: - shutil.copyfileobj(response.raw, destination) - logger.info("Downloaded resource to %s", path) - - -_registry: ModelRegistry | None = None - - -def get_registry(settings: Settings | None = None) -> ModelRegistry: +def get_registry(settings: PipecatSettings | None = None) -> PipecatModelRegistry: global _registry if _registry is None: - _registry = ModelRegistry(settings=settings) + _registry = PipecatModelRegistry(settings=settings) return _registry diff --git a/backend/app/services/transcription_service.py b/backend/app/services/transcription_service.py index 0ce12f9..197fa44 100644 --- a/backend/app/services/transcription_service.py +++ b/backend/app/services/transcription_service.py @@ -1,58 +1,15 @@ from __future__ import annotations -import uuid -from dataclasses import dataclass -from pathlib import Path -from typing import List, Sequence +"""Legacy entry point redirected to Pipecat's pipeline.""" -import numpy as np -from loguru import logger - -from app.config import Settings, get_settings -from app.models.requests import TranscriptionRequest -from app.models.responses import TranscriptSegment, TranscriptionResult -from app.services.model_registry import get_registry -from app.services.vad import SileroVAD -from app.utils.audio_utils import load_audio, save_waveform, split_segments - - -@dataclass -class DecoderVocabulary: - tokens: Sequence[str] - blank_id: int - word_delimiter: str = " " - - @classmethod - def from_tokenizer_dict(cls, tokenizer_dict: dict) -> "DecoderVocabulary": - model = tokenizer_dict.get("model", {}) - vocab = model.get("vocab", {}) - tokens = [""] * len(vocab) - for token, index in vocab.items(): - tokens[index] = token.replace("▁", " ") - blank_id = tokenizer_dict.get("added_tokens", [{}])[0].get("id", len(tokens) - 1) - return cls(tokens=tokens, blank_id=blank_id) - - def decode(self, token_ids: Sequence[int]) -> str: - pieces: List[str] = [] - previous = None - for token_id in token_ids: - if token_id == self.blank_id or token_id == previous: - previous = token_id - continue - pieces.append(self.tokens[token_id]) - previous = token_id - text = "".join(pieces) - return " ".join(text.split()) +from app.pipecat.components import PipecatPipeline +from app.pipecat.settings import PipecatSettings +from app.pipecat.schemas import TranscriptionRequest, TranscriptionResult class ParakeetTranscriptionService: - def __init__(self, settings: Settings | None = None) -> None: - self.settings = settings or get_settings() - self.registry = get_registry(self.settings) - self.vad = SileroVAD(self.settings) - tokenizer = self.registry.get_tokenizer() - self.vocab = DecoderVocabulary.from_tokenizer_dict(tokenizer) - self.session = self.registry.get_asr_session() + def __init__(self, settings: PipecatSettings | None = None) -> None: + self.pipeline = PipecatPipeline(settings) def transcribe_bytes( self, @@ -61,77 +18,8 @@ def transcribe_bytes( filename: str | None = None, ) -> TranscriptionResult: request = request or TranscriptionRequest() - waveform, sample_rate = load_audio(audio_bytes, self.settings.sample_rate) - - if request.settings.enable_vad: - vad_segments = self.vad.detect( - waveform, sample_rate, threshold=request.settings.vad_threshold - ) - if vad_segments: - logger.debug("Detected %d speech segments via VAD", len(vad_segments)) - waveform = self.vad.extract(waveform, vad_segments) - - text_segments: List[TranscriptSegment] = [] - transcript_parts: List[str] = [] - processed_duration = len(waveform) / sample_rate - offset = 0.0 - - for segment_waveform in split_segments( - waveform, sample_rate, self.settings.max_segment_seconds - ): - tokens = self._infer(segment_waveform) - text = self.vocab.decode(tokens) - transcript_parts.append(text) - end_time = offset + len(segment_waveform) / sample_rate - text_segments.append( - TranscriptSegment( - text=text, - start=offset, - end=end_time, - speaker="SPEAKER_1" if request.settings.diarization else None, - confidence=None, - ) - ) - offset = end_time - - transcript_text = " ".join(part for part in transcript_parts if part) - - if request.settings.enable_punctuation: - transcript_text = self._restore_punctuation(transcript_text) - for segment in text_segments: - segment.text = self._restore_punctuation(segment.text) - - request_id = request.request_id or str(uuid.uuid4()) - settings_applied = request.settings.dict(exclude_none=True) - - if filename: - target = Path(self.settings.storage_dir) / request_id - save_waveform(target.with_suffix(".wav"), waveform, sample_rate) - - return TranscriptionResult( - request_id=request_id, - text=transcript_text, - duration=processed_duration, - segments=text_segments, - settings_applied=settings_applied, - ) - - def _infer(self, waveform: np.ndarray) -> Sequence[int]: - inputs = self.session.get_inputs() - if not inputs: - raise RuntimeError("Parakeet ONNX session has no inputs") - input_name = inputs[0].name - audio = waveform.astype(np.float32)[np.newaxis, :] - outputs = self.session.run(None, {input_name: audio}) - logits = outputs[0] - token_ids = np.argmax(logits, axis=-1).flatten().tolist() - return token_ids - - def _restore_punctuation(self, text: str) -> str: - normalized = text.strip() - if not normalized: - return normalized - normalized = normalized[0].upper() + normalized[1:] - if normalized[-1] not in {".", "?", "!"}: - normalized += "." - return normalized + settings = request.settings.dict(exclude_none=True) + settings["request_id"] = request.request_id + result = self.pipeline.transcribe(audio_bytes, request_settings=settings, filename=filename) + result.session_id = request.session_id + return result diff --git a/backend/app/services/vad.py b/backend/app/services/vad.py index 78b2a27..00feba9 100644 --- a/backend/app/services/vad.py +++ b/backend/app/services/vad.py @@ -1,76 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Iterable, List +"""Compatibility layer pointing to Pipecat's VAD wrapper.""" -import numpy as np -import onnxruntime as ort +from app.pipecat.components import PipecatVAD as SileroVAD, SpeechSegment -from app.config import Settings, get_settings -from app.services.model_registry import get_registry - - -@dataclass -class SpeechSegment: - start: int - end: int - - -class SileroVAD: - """Wrapper around the Silero Voice Activity Detection ONNX model.""" - - def __init__(self, settings: Settings | None = None) -> None: - self.settings = settings or get_settings() - self.registry = get_registry(self.settings) - self.session: ort.InferenceSession = self.registry.get_vad_session() - - def detect(self, waveform: np.ndarray, sample_rate: int, threshold: float | None = None) -> List[SpeechSegment]: - threshold = threshold or self.settings.vad_threshold - stride = 512 - window = 1536 - probs: List[float] = [] - for start in range(0, len(waveform) - window, stride): - chunk = waveform[start : start + window] - ort_inputs = { - "input": chunk.reshape(1, -1), - "sr": np.array(sample_rate, dtype=np.int64), - } - (prob,) = self.session.run(None, ort_inputs) - probs.append(float(prob.squeeze())) - - speech_segments: List[SpeechSegment] = [] - active = False - seg_start = 0 - for index, prob in enumerate(probs): - time_start = index * stride - time_end = time_start + window - if prob >= threshold and not active: - active = True - seg_start = time_start - elif prob < threshold and active: - active = False - if time_end - seg_start >= self.settings.vad_min_speech_seconds * sample_rate: - speech_segments.append(SpeechSegment(seg_start, time_end)) - - if active: - speech_segments.append(SpeechSegment(seg_start, len(waveform))) - - merged: List[SpeechSegment] = [] - for segment in speech_segments: - if not merged: - merged.append(segment) - continue - prev = merged[-1] - gap = segment.start - prev.end - if gap / sample_rate <= self.settings.vad_min_silence_seconds: - merged[-1] = SpeechSegment(prev.start, segment.end) - else: - merged.append(segment) - - return merged - - def extract(self, waveform: np.ndarray, segments: Iterable[SpeechSegment]) -> np.ndarray: - pieces = [waveform[segment.start : segment.end] for segment in segments] - if not pieces: - return waveform - return np.concatenate(pieces) +__all__ = ["SileroVAD", "SpeechSegment"] diff --git a/backend/requirements.txt b/backend/requirements.txt index c0302ad..c5862b1 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,3 +9,4 @@ requests==2.31.0 loguru==0.7.2 python-multipart==0.0.9 aiofiles==23.2.1 +pytest==8.4.2 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..47180da --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) diff --git a/backend/tests/test_pipecat_api.py b/backend/tests/test_pipecat_api.py new file mode 100644 index 0000000..962753e --- /dev/null +++ b/backend/tests/test_pipecat_api.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +import io +import json + +from starlette.datastructures import UploadFile + +from app.pipecat import api +from app.pipecat.schemas import ( + SessionCreateRequest, + SessionCreateResponse, + TranscriptionRequest, + TranscriptionResult, +) + + +class DummySessionManager: + def __init__(self, session_id: str = "session-1") -> None: + self.session_id = session_id + self.created_payloads: list[SessionCreateRequest | None] = [] + self.transcriptions: list[tuple[str, TranscriptionRequest, bytes, str | None]] = [] + + def create_session(self, payload: SessionCreateRequest | None = None) -> SessionCreateResponse: + self.created_payloads.append(payload) + return SessionCreateResponse(session_id=self.session_id) + + def transcribe( + self, session_id: str, request: TranscriptionRequest, audio_bytes: bytes, filename: str | None = None + ) -> TranscriptionResult: + self.transcriptions.append((session_id, request, audio_bytes, filename)) + return TranscriptionResult( + session_id=session_id, + request_id=request.request_id, + text="hello", + duration=0.1, + segments=[], + settings_applied={"enable_punctuation": request.settings.enable_punctuation}, + ) + + +def test_healthcheck_returns_ok(monkeypatch): + api.session_manager = DummySessionManager() + + result = asyncio.run(api.healthcheck()) + + assert result == {"status": "ok"} + + +def test_create_session_uses_manager(monkeypatch): + dummy_manager = DummySessionManager(session_id="abc123") + api.session_manager = dummy_manager + + response = asyncio.run(api.create_session(SessionCreateRequest(session_id="abc123"))) + + assert response.session_id == "abc123" + assert dummy_manager.created_payloads[0].session_id == "abc123" + + +def test_transcribe_enforces_path_session_id(monkeypatch): + dummy_manager = DummySessionManager(session_id="path-id") + api.session_manager = dummy_manager + payload = json.dumps({"request_id": "req-1", "session_id": "payload-id", "settings": {"enable_vad": False}}) + upload = UploadFile(filename="sample.wav", file=io.BytesIO(b"audio-bytes")) + + response = asyncio.run(api.transcribe_audio("path-id", file=upload, payload=payload)) + + assert response.session_id == "path-id" + assert dummy_manager.transcriptions, "Transcription should be called" + session_id, request, audio_bytes, filename = dummy_manager.transcriptions[0] + assert session_id == "path-id" + assert request.session_id == "path-id" + assert request.request_id == "req-1" + assert audio_bytes == b"audio-bytes" + assert filename == "sample.wav" + + +def test_default_transcription_creates_session(monkeypatch): + dummy_manager = DummySessionManager(session_id="generated") + api.session_manager = dummy_manager + upload = UploadFile(filename="voice.wav", file=io.BytesIO(b"audio")) + + response = asyncio.run(api.transcribe_with_default_session(file=upload)) + + assert response.session_id == "generated" + assert dummy_manager.created_payloads, "Default session should be created" + assert dummy_manager.transcriptions, "Default session should be used for transcription" + session_id, request, *_ = dummy_manager.transcriptions[0] + assert session_id == "generated" + assert request.session_id == "generated"