Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions comps/asr/src/integrations/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import asyncio
import os
from typing import List
from typing import List, Union

import requests
from fastapi import File, Form, UploadFile
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(self, name: str, description: str, config: dict = None):

async def invoke(
self,
file: UploadFile = File(...), # Handling the uploaded file directly
file: Union[str, UploadFile], # accept base64 string or UploadFile
model: str = Form("openai/whisper-small"),
language: str = Form("english"),
prompt: str = Form(None),
Expand All @@ -41,28 +41,39 @@ async def invoke(
timestamp_granularities: List[str] = Form(None),
) -> AudioTranscriptionResponse:
"""Involve the ASR service to generate transcription for the provided input."""
# Read the uploaded file
file_contents = await file.read()
if isinstance(file, str):
data = {"audio": file}
# Send the file and model to the server
response = await asyncio.to_thread(
requests.post,
f"{self.base_url}/v1/asr",
json=data,
)
res = response.json()["asr_result"]
return AudioTranscriptionResponse(text=res)
else:
# Read the uploaded file
file_contents = await file.read()

# Prepare the files and data
files = {
"file": (file.filename, file_contents, file.content_type),
}
data = {
"model": model,
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature,
"timestamp_granularities": timestamp_granularities,
}
# Prepare the files and data
files = {
"file": (file.filename, file_contents, file.content_type),
}
data = {
"model": model,
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature,
"timestamp_granularities": timestamp_granularities,
}

# Send the file and model to the server
response = await asyncio.to_thread(
requests.post, f"{self.base_url}/v1/audio/transcriptions", files=files, data=data
)
res = response.json()["text"]
return AudioTranscriptionResponse(text=res)
# Send the file and model to the server
response = await asyncio.to_thread(
requests.post, f"{self.base_url}/v1/audio/transcriptions", files=files, data=data
)
res = response.json()["text"]
return AudioTranscriptionResponse(text=res)

def check_health(self) -> bool:
"""Checks the health of the embedding service.
Expand Down
10 changes: 8 additions & 2 deletions comps/asr/src/opea_asr_microservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
import time
from typing import List
from typing import List, Union

from fastapi import File, Form, UploadFile
from integrations.whisper import OpeaWhisperAsr
Expand All @@ -19,12 +19,15 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.constants import MCPFuncType
from comps.cores.proto.api_protocol import AudioTranscriptionResponse

logger = CustomLogger("opea_asr_microservice")
logflag = os.getenv("LOGFLAG", False)

asr_component_name = os.getenv("ASR_COMPONENT_NAME", "OPEA_WHISPER_ASR")
enable_mcp = os.getenv("ENABLE_MCP", "").strip().lower() in {"true", "1", "yes"}

# Initialize OpeaComponentLoader
loader = OpeaComponentLoader(asr_component_name, description=f"OPEA ASR Component: {asr_component_name}")

Expand All @@ -37,10 +40,13 @@
port=9099,
input_datatype=Base64ByteStrDoc,
output_datatype=LLMParamsDoc,
enable_mcp=enable_mcp,
mcp_func_type=MCPFuncType.TOOL,
description="Convert audio to text.",
)
@register_statistics(names=["opea_service@asr"])
async def audio_to_text(
file: UploadFile = File(...), # Handling the uploaded file directly
file: Union[str, UploadFile], # accept base64 string or UploadFile
model: str = Form("openai/whisper-small"),
language: str = Form("english"),
prompt: str = Form(None),
Expand Down
10 changes: 9 additions & 1 deletion comps/cores/mega/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from enum import Enum, auto


class ServiceRoleType(Enum):
Expand Down Expand Up @@ -92,3 +92,11 @@ class MicroServiceEndpoint(Enum):

def __str__(self):
return self.value


class MCPFuncType(Enum):
    """Kinds of callables that can be registered with an MCP server.

    Used to pick the matching ``FastMCP`` registration method
    (tool, resource, or prompt) for a microservice handler.
    """

    # Explicit values match what auto() would assign (1, 2, 3).
    TOOL = 1
    RESOURCE = 2
    PROMPT = 3
44 changes: 40 additions & 4 deletions comps/cores/mega/micro_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import asyncio
import os
from collections import defaultdict, deque
from collections.abc import Callable
from enum import Enum
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, TypeAlias

from ..proto.docarray import TextDoc
from .constants import ServiceRoleType, ServiceType
from .constants import MCPFuncType, ServiceRoleType, ServiceType
from .http_service import HTTPService
from .logger import CustomLogger
from .utils import check_ports_availability
Expand All @@ -17,6 +18,7 @@

logger = CustomLogger("micro_service")
logflag = os.getenv("LOGFLAG", False)
AnyFunction: TypeAlias = Callable[..., Any]


class MicroService(HTTPService):
Expand All @@ -43,6 +45,9 @@
dynamic_batching: bool = False,
dynamic_batching_timeout: int = 1,
dynamic_batching_max_batch_size: int = 32,
enable_mcp: bool = False,
mcp_func_type: Enum = MCPFuncType.TOOL,
func: AnyFunction = None,
):
"""Init the microservice."""
self.service_role = service_role
Expand All @@ -56,6 +61,7 @@
self.output_datatype = output_datatype
self.use_remote_service = use_remote_service
self.description = description
self.enable_mcp = enable_mcp
self.dynamic_batching = dynamic_batching
self.dynamic_batching_timeout = dynamic_batching_timeout
self.dynamic_batching_max_batch_size = dynamic_batching_max_batch_size
Expand All @@ -82,7 +88,7 @@
"host": self.host,
"port": self.port,
"title": name,
"description": "OPEA Microservice Infrastructure",
"description": self.description or "OPEA Microservice Infrastructure",
}

super().__init__(uvicorn_kwargs=self.uvicorn_kwargs, runtime_args=runtime_args)
Expand All @@ -93,7 +99,21 @@
self.request_buffer = defaultdict(deque)
self.add_startup_event(self._dynamic_batch_processor())

self._async_setup()
if not enable_mcp:
self._async_setup()
else:
from mcp.server.fastmcp import FastMCP

Check warning on line 105 in comps/cores/mega/micro_service.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/micro_service.py#L105

Added line #L105 was not covered by tests

self.mcp = FastMCP(name, host=self.host, port=self.port)
dispatch = {

Check warning on line 108 in comps/cores/mega/micro_service.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/micro_service.py#L107-L108

Added lines #L107 - L108 were not covered by tests
MCPFuncType.TOOL: self.mcp.add_tool,
MCPFuncType.RESOURCE: self.mcp.add_resource,
MCPFuncType.PROMPT: self.mcp.add_prompt,
}
try:
dispatch[mcp_func_type](func, name=name, description=description)
except KeyError:
raise ValueError(f"Unknown MCP func type: {mcp_func_type}")

Check warning on line 116 in comps/cores/mega/micro_service.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/micro_service.py#L113-L116

Added lines #L113 - L116 were not covered by tests

# overwrite name
self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__
Expand Down Expand Up @@ -144,6 +164,15 @@
else:
return f"{self.protocol}://{self.host}:{self.port}{self.endpoint}"

def start(self):
    """Launch the service.

    When MCP is enabled, run the FastMCP server (SSE transport) that was
    built in ``__init__``; otherwise fall back to the default HTTP
    service startup provided by the base class.
    """
    if not self.enable_mcp:
        super().start()
        return
    # MCP mode: serve the registered tool/resource/prompt over SSE.
    self.mcp.run(transport="sse")

Check warning on line 174 in comps/cores/mega/micro_service.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/micro_service.py#L174

Added line #L174 was not covered by tests

@property
def api_key_value(self):
    """Return the configured API key (the value of ``self.api_key``)."""
    return self.api_key
Expand All @@ -167,6 +196,9 @@
dynamic_batching: bool = False,
dynamic_batching_timeout: int = 1,
dynamic_batching_max_batch_size: int = 32,
enable_mcp: bool = False,
description: str = None,
mcp_func_type: Enum = MCPFuncType.TOOL,
):
def decorator(func):
if name not in opea_microservices:
Expand All @@ -187,6 +219,10 @@
dynamic_batching=dynamic_batching,
dynamic_batching_timeout=dynamic_batching_timeout,
dynamic_batching_max_batch_size=dynamic_batching_max_batch_size,
enable_mcp=enable_mcp,
func=func,
description=description,
mcp_func_type=mcp_func_type,
)
opea_microservices[name] = micro_service
opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods)
Expand Down
Loading