diff --git a/comps/asr/deployment/docker_compose/compose.yaml b/comps/asr/deployment/docker_compose/compose.yaml
index 3595eaf225..4b0ac07da3 100644
--- a/comps/asr/deployment/docker_compose/compose.yaml
+++ b/comps/asr/deployment/docker_compose/compose.yaml
@@ -14,11 +14,13 @@ services:
     environment:
       ASR_ENDPOINT: ${ASR_ENDPOINT}
       ASR_COMPONENT_NAME: ${ASR_COMPONENT_NAME:-OPEA_WHISPER_ASR}
+      ENABLE_MCP: ${ENABLE_MCP:-False}
   asr-whisper:
     extends: asr
     container_name: asr-whisper-service
     environment:
       ASR_COMPONENT_NAME: ${ASR_COMPONENT_NAME:-OPEA_WHISPER_ASR}
+      ENABLE_MCP: ${ENABLE_MCP:-False}
     depends_on:
       whisper-service:
         condition: service_healthy
@@ -27,6 +29,7 @@ services:
     container_name: asr-whisper-gaudi-service
     environment:
       ASR_COMPONENT_NAME: ${ASR_COMPONENT_NAME:-OPEA_WHISPER_ASR}
+      ENABLE_MCP: ${ENABLE_MCP:-False}
     depends_on:
       whisper-gaudi-service:
         condition: service_healthy
diff --git a/comps/asr/src/integrations/whisper.py b/comps/asr/src/integrations/whisper.py
--- a/comps/asr/src/integrations/whisper.py
+++ b/comps/asr/src/integrations/whisper.py
@@ -3,7 +3,7 @@
 
 import asyncio
 import os
-from typing import List
+from typing import List, Union
 
 import requests
 from fastapi import File, Form, UploadFile
@@ -32,7 +32,7 @@ def __init__(self, name: str, description: str, config: dict = None):
 
     async def invoke(
         self,
-        file: UploadFile = File(...),  # Handling the uploaded file directly
+        file: Union[str, UploadFile],  # accept base64 string or UploadFile
         model: str = Form("openai/whisper-small"),
         language: str = Form("english"),
         prompt: str = Form(None),
@@ -41,6 +41,15 @@ async def invoke(
         timestamp_granularities: List[str] = Form(None),
     ) -> AudioTranscriptionResponse:
-        """Involve the ASR service to generate transcription for the provided input."""
+        """Invoke the ASR service to generate transcription for the provided input."""
+        if isinstance(file, str):
+            # A plain string is treated as base64-encoded audio and posted as JSON to /v1/asr.
+            response = await asyncio.to_thread(
+                requests.post,
+                f"{self.base_url}/v1/asr",
+                json={"audio": file},
+            )
+            return AudioTranscriptionResponse(text=response.json()["asr_result"])
+
         # Read the uploaded file
         file_contents = await file.read()
 
diff --git a/comps/asr/src/opea_asr_microservice.py b/comps/asr/src/opea_asr_microservice.py
index 8210149613..db9bb37947 100644
--- a/comps/asr/src/opea_asr_microservice.py
+++ b/comps/asr/src/opea_asr_microservice.py
@@ -3,7 +3,7 @@
 
 import os
 import time
-from typing import List
+from typing import List, Union
 
 from fastapi import File, Form, UploadFile
 from integrations.whisper import OpeaWhisperAsr
@@ -19,12 +19,15 @@
     register_statistics,
     statistics_dict,
 )
+from comps.cores.mega.constants import MCPFuncType
 from comps.cores.proto.api_protocol import AudioTranscriptionResponse
 
 logger = CustomLogger("opea_asr_microservice")
 logflag = os.getenv("LOGFLAG", False)
 
 asr_component_name = os.getenv("ASR_COMPONENT_NAME", "OPEA_WHISPER_ASR")
+enable_mcp = os.getenv("ENABLE_MCP", "").strip().lower() in {"true", "1", "yes"}
+
 # Initialize OpeaComponentLoader
 loader = OpeaComponentLoader(asr_component_name, description=f"OPEA ASR Component: {asr_component_name}")
 
@@ -37,10 +40,13 @@
     port=9099,
     input_datatype=Base64ByteStrDoc,
     output_datatype=LLMParamsDoc,
+    enable_mcp=enable_mcp,
+    mcp_func_type=MCPFuncType.TOOL,
+    description="Convert audio to text.",
 )
 @register_statistics(names=["opea_service@asr"])
 async def audio_to_text(
-    file: UploadFile = File(...),  # Handling the uploaded file directly
+    file: Union[str, UploadFile],  # accept base64 string or UploadFile
     model: str = Form("openai/whisper-small"),
     language: str = Form("english"),
     prompt: str = Form(None),
diff --git a/comps/asr/src/requirements.txt b/comps/asr/src/requirements.txt
index f73cc5821a..cca9450d79 100644
--- a/comps/asr/src/requirements.txt
+++ b/comps/asr/src/requirements.txt
@@ -3,6 +3,7 @@ aiohttp
 datasets
 docarray[full]
 fastapi
+mcp
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk
diff --git a/comps/cores/mega/constants.py b/comps/cores/mega/constants.py
index 0723bbd12a..ed1a2271d0 100644
--- a/comps/cores/mega/constants.py
+++ b/comps/cores/mega/constants.py
@@ -1,7 +1,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from enum import Enum
+from enum import Enum, auto
 
 
 class ServiceRoleType(Enum):
@@ -92,3 +92,11 @@ class MicroServiceEndpoint(Enum):
 
     def __str__(self):
         return self.value
+
+
+class MCPFuncType(Enum):
+    """The enum of a MCP function type."""
+
+    TOOL = auto()
+    RESOURCE = auto()
+    PROMPT = auto()
diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py
--- a/comps/cores/mega/micro_service.py
+++ b/comps/cores/mega/micro_service.py
@@ -4,11 +4,12 @@
 import asyncio
 import os
 from collections import defaultdict, deque
+from collections.abc import Callable
 from enum import Enum
-from typing import Any, List, Optional, Type
+from typing import Any, List, Optional, Type, TypeAlias
 
 from ..proto.docarray import TextDoc
-from .constants import ServiceRoleType, ServiceType
+from .constants import MCPFuncType, ServiceRoleType, ServiceType
 from .http_service import HTTPService
 from .logger import CustomLogger
 from .utils import check_ports_availability
@@ -17,6 +18,7 @@
 
 logger = CustomLogger("micro_service")
 logflag = os.getenv("LOGFLAG", False)
+AnyFunction: TypeAlias = Callable[..., Any]
 
 
 class MicroService(HTTPService):
@@ -43,6 +45,9 @@ def __init__(
         dynamic_batching: bool = False,
         dynamic_batching_timeout: int = 1,
         dynamic_batching_max_batch_size: int = 32,
+        enable_mcp: bool = False,
+        mcp_func_type: Enum = MCPFuncType.TOOL,
+        func: AnyFunction = None,
     ):
         """Init the microservice."""
         self.service_role = service_role
@@ -56,6 +61,7 @@ def __init__(
         self.output_datatype = output_datatype
         self.use_remote_service = use_remote_service
         self.description = description
+        self.enable_mcp = enable_mcp
         self.dynamic_batching = dynamic_batching
         self.dynamic_batching_timeout = dynamic_batching_timeout
         self.dynamic_batching_max_batch_size = dynamic_batching_max_batch_size
@@ -82,7 +88,7 @@ def __init__(
                 "host": self.host,
                 "port": self.port,
                 "title": name,
-                "description": "OPEA Microservice Infrastructure",
+                "description": self.description or "OPEA Microservice Infrastructure",
             }
 
             super().__init__(uvicorn_kwargs=self.uvicorn_kwargs, runtime_args=runtime_args)
@@ -93,7 +99,13 @@ def __init__(
                 self.request_buffer = defaultdict(deque)
                 self.add_startup_event(self._dynamic_batch_processor())
 
-            self._async_setup()
+            if not enable_mcp:
+                self._async_setup()
+            else:
+                from mcp.server.fastmcp import FastMCP
+
+                self.mcp = FastMCP(name, host=self.host, port=self.port)
+                self._register_mcp_func(func, mcp_func_type, description)
 
         # overwrite name
         self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__
@@ -144,6 +156,44 @@ def endpoint_path(self, model=None):
         else:
             return f"{self.protocol}://{self.host}:{self.port}{self.endpoint}"
 
+    def _register_mcp_func(self, func, mcp_func_type, description=None):
+        """Register ``func`` with the FastMCP server as a tool, resource or prompt.
+
+        ``FastMCP.add_tool`` takes the callable directly, while ``add_resource`` and
+        ``add_prompt`` take ``Resource``/``Prompt`` objects, so ``func`` is wrapped first.
+        """
+        if func is None:
+            raise ValueError("enable_mcp=True requires a function to register on the MCP server.")
+        if mcp_func_type is MCPFuncType.TOOL:
+            self.mcp.add_tool(func, name=func.__name__, description=description)
+        elif mcp_func_type is MCPFuncType.RESOURCE:
+            from mcp.server.fastmcp.resources import FunctionResource
+
+            # Resources are addressed by URI; derive a stable one from the function name.
+            self.mcp.add_resource(
+                FunctionResource(
+                    uri=f"resource://{func.__name__}",
+                    name=func.__name__,
+                    description=description,
+                    fn=func,
+                )
+            )
+        elif mcp_func_type is MCPFuncType.PROMPT:
+            from mcp.server.fastmcp.prompts import Prompt
+
+            self.mcp.add_prompt(Prompt.from_function(func, name=func.__name__, description=description))
+        else:
+            raise ValueError(f"Unknown MCP func type: {mcp_func_type}")
+
+    def start(self):
+        """Start the server using MCP if enabled, otherwise fall back to default."""
+        if self.enable_mcp:
+            self.mcp.run(
+                transport="sse",
+            )
+        else:
+            super().start()
+
     @property
     def api_key_value(self):
         return self.api_key
@@ -167,6 +217,9 @@ def register_microservice(
     dynamic_batching: bool = False,
     dynamic_batching_timeout: int = 1,
     dynamic_batching_max_batch_size: int = 32,
+    enable_mcp: bool = False,
+    description: str = None,
+    mcp_func_type: Enum = MCPFuncType.TOOL,
 ):
     def decorator(func):
         if name not in opea_microservices:
@@ -187,8 +240,17 @@ def decorator(func):
                 dynamic_batching=dynamic_batching,
                 dynamic_batching_timeout=dynamic_batching_timeout,
                 dynamic_batching_max_batch_size=dynamic_batching_max_batch_size,
+                enable_mcp=enable_mcp,
+                func=func,
+                description=description,
+                mcp_func_type=mcp_func_type,
             )
             opea_microservices[name] = micro_service
+
+        elif enable_mcp:
+            # The service already exists; attach this function to its MCP server as well.
+            opea_microservices[name]._register_mcp_func(func, mcp_func_type, description)
+
         opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods)
 
         return func
diff --git a/requirements.txt b/requirements.txt
index c16f8ad52b..cca4354342 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ httpx
 kubernetes
 langchain
 langchain-community
+mcp
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk
diff --git a/tests/asr/test_asr_whisper_mcp.sh b/tests/asr/test_asr_whisper_mcp.sh
new file mode 100644
index 0000000000..8bdfa65a7d
--- /dev/null
+++ b/tests/asr/test_asr_whisper_mcp.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+set -x
+
+WORKPATH=$(dirname "$PWD")
+ip_address=$(hostname -I | awk '{print $1}')
+export TAG=comps
+export WHISPER_PORT=10104
+export ASR_PORT=10105
+export ENABLE_MCP=True
+cd $WORKPATH
+
+
+function build_docker_images() {
+    echo $(pwd)
+    docker build --no-cache -t opea/whisper:$TAG --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/third_parties/whisper/src/Dockerfile .
+
+    if [ $? -ne 0 ]; then
+        echo "opea/whisper built fail"
+        exit 1
+    else
+        echo "opea/whisper built successful"
+    fi
+
+    docker build --no-cache -t opea/asr:$TAG --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/asr/src/Dockerfile .
+
+    if [ $? -ne 0 ]; then
+        echo "opea/asr built fail"
+        exit 1
+    else
+        echo "opea/asr built successful"
+    fi
+}
+
+function start_service() {
+    unset http_proxy
+    export ASR_ENDPOINT=http://$ip_address:$WHISPER_PORT
+
+    docker compose -f comps/asr/deployment/docker_compose/compose.yaml up whisper-service asr -d
+    sleep 1m
+}
+
+function validate_microservice() {
+    pip install mcp
+    python3 ${WORKPATH}/tests/utils/validate_svc_with_mcp.py $ip_address $ASR_PORT "asr"
+    if [ $? -ne 0 ]; then
+        docker logs whisper-service
+        docker logs asr-service
+        exit 1
+    fi
+
+}
+
+function stop_docker() {
+    docker ps -a --filter "name=whisper-service" --filter "name=asr-service" --format "{{.Names}}" | xargs -r docker stop
+}
+
+function main() {
+
+    stop_docker
+
+    build_docker_images
+    start_service
+
+    validate_microservice
+
+    stop_docker
+    echo y | docker system prune
+
+}
+
+main
diff --git a/tests/cores/mega/test_mcp.py b/tests/cores/mega/test_mcp.py
new file mode 100644
--- /dev/null
+++ b/tests/cores/mega/test_mcp.py
@@ -0,0 +1,73 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import multiprocessing
+import unittest
+
+from mcp.client.session import ClientSession
+from mcp.client.sse import sse_client
+
+from comps import TextDoc, opea_microservices, register_microservice
+from comps.cores.mega.constants import MCPFuncType
+from comps.version import __version__
+
+
+@register_microservice(
+    name="mcp_dummy",
+    host="0.0.0.0",
+    port=8087,
+    enable_mcp=True,
+    mcp_func_type=MCPFuncType.TOOL,
+    description="dummy mcp add func",
+)
+async def mcp_dummy(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "OPEA Project MCP!"
+    return {"text": text}
+
+
+@register_microservice(
+    name="mcp_dummy",
+    host="0.0.0.0",
+    port=8087,
+    enable_mcp=True,
+    mcp_func_type=MCPFuncType.TOOL,
+    description="dummy mcp sum func",
+)
+async def mcp_dummy_sum():
+    return 1 + 1
+
+
+class TestMicroService(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.process = multiprocessing.Process(
+            target=opea_microservices["mcp_dummy"].start, daemon=False, name="mcp_dummy"
+        )
+        self.process.start()
+
+        self.server_url = "http://localhost:8087"
+
+    def tearDown(self):
+        # Always stop the server process, even when an assertion in the test fails.
+        self.process.kill()
+        self.process.join(timeout=2)
+
+    async def test_mcp(self):
+        async with sse_client(self.server_url + "/sse") as streams:
+            async with ClientSession(*streams) as session:
+                result = await session.initialize()
+                self.assertEqual(result.serverInfo.name, "mcp_dummy")
+                tool_result = await session.call_tool("mcp_dummy", {"request": {"text": "Hello "}})
+                self.assertEqual(json.loads(tool_result.content[0].text)["text"], "Hello OPEA Project MCP!")
+
+                tool_result = await session.call_tool(
+                    "mcp_dummy_sum",
+                )
+                self.assertEqual(tool_result.content[0].text, "2")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/utils/validate_svc_with_mcp.py b/tests/utils/validate_svc_with_mcp.py
new file mode 100644
--- /dev/null
+++ b/tests/utils/validate_svc_with_mcp.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import base64
+import json
+import os
+import sys
+
+import requests
+from mcp.client.session import ClientSession
+from mcp.client.sse import sse_client
+
+
+async def validate_svc(ip_address, service_port, service_type):
+
+    endpoint = f"http://{ip_address}:{service_port}"
+
+    async with sse_client(endpoint + "/sse") as streams:
+        async with ClientSession(*streams) as session:
+            await session.initialize()
+            if service_type == "asr":
+                url = "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav"
+                response = requests.get(url, timeout=60)  # bound the download so CI cannot hang
+                response.raise_for_status()  # Ensure the download succeeded
+                binary_data = response.content
+                base64_str = base64.b64encode(binary_data).decode("utf-8")
+                input_dict = {"file": base64_str, "model": "openai/whisper-small", "language": "english"}
+                tool_result = await session.call_tool(
+                    "audio_to_text",
+                    input_dict,
+                )
+                result_content = tool_result.content
+                # Check result
+                if json.loads(result_content[0].text)["text"].startswith("who is"):
+                    print("Result correct.")
+                else:
+                    print(f"Result wrong. Received was {result_content}")
+                    sys.exit(1)
+            else:
+                print(f"Unknown service type: {service_type}")
+                sys.exit(1)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print("Usage: python3 validate_svc_with_mcp.py IP_ADDRESS SERVICE_PORT SERVICE_TYPE")
+        sys.exit(1)
+    ip_address = sys.argv[1]
+    service_port = sys.argv[2]
+    service_type = sys.argv[3]
+    asyncio.run(validate_svc(ip_address, service_port, service_type))