diff --git a/.github/workflows/docker/compose/router-compose.yaml b/.github/workflows/docker/compose/router-compose.yaml new file mode 100644 index 0000000000..a3aa7e8f7a --- /dev/null +++ b/.github/workflows/docker/compose/router-compose.yaml @@ -0,0 +1,9 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# this file should be run in the root of the repo +services: + router: + build: + dockerfile: comps/router/src/Dockerfile + image: ${REGISTRY:-opea}/opea-router:${TAG:-latest} diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index 5068a81145..20f36697a7 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -1014,3 +1014,7 @@ class FineTuningJobCheckpoint(BaseModel): step_number: Optional[int] = None """The step number that the checkpoint was created at.""" + + +class RouteEndpointDoc(BaseModel): + url: str = Field(..., description="URL of the chosen inference endpoint") diff --git a/comps/router/deployment/docker_compose/compose.yaml b/comps/router/deployment/docker_compose/compose.yaml new file mode 100644 index 0000000000..430e5e2087 --- /dev/null +++ b/comps/router/deployment/docker_compose/compose.yaml @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + router_service: + build: + context: ../../../.. + dockerfile: comps/router/src/Dockerfile + + image: "${REGISTRY_AND_REPO:-opea/router}:${TAG:-latest}" + container_name: opea_router + + volumes: + - ./configs:/app/configs + + environment: + CONFIG_PATH: /app/configs/router.yaml + + WEAK_ENDPOINT: ${WEAK_ENDPOINT:-http://opea_router:8000/weak} + STRONG_ENDPOINT: ${STRONG_ENDPOINT:-http://opea_router:8000/strong} + WEAK_MODEL_ID: ${WEAK_MODEL_ID:-openai/gpt-3.5-turbo} + STRONG_MODEL_ID: ${STRONG_MODEL_ID:-openai/gpt-4} + + HF_TOKEN: ${HF_TOKEN:?set HF_TOKEN} + OPENAI_API_KEY: ${OPENAI_API_KEY:-""} + + CONTROLLER_TYPE: ${CONTROLLER_TYPE:-routellm} + + ports: + - "${ROUTER_PORT:-6000}:6000" + restart: unless-stopped + +networks: + default: + driver: bridge diff --git a/comps/router/deployment/docker_compose/configs/routellm_config.yaml b/comps/router/deployment/docker_compose/configs/routellm_config.yaml new file mode 100644 index 0000000000..b387712b86 --- /dev/null +++ b/comps/router/deployment/docker_compose/configs/routellm_config.yaml @@ -0,0 +1,29 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# which embedder backend to use ("huggingface" or "openai") +embedding_provider: "huggingface" + +embedding_model_name: "intfloat/e5-base-v2" + +routing_algorithm: "mf" +threshold: 0.3 + +config: + sw_ranking: + arena_battle_datasets: + - "lmsys/lmsys-arena-human-preference-55k" + - "routellm/gpt4_judge_battles" + arena_embedding_datasets: + - "routellm/arena_battles_embeddings" + - "routellm/gpt4_judge_battles_embeddings" + + causal_llm: + checkpoint_path: "routellm/causal_llm_gpt4_augmented" + + bert: + checkpoint_path: "routellm/bert_gpt4_augmented" + + mf: + checkpoint_path: "OPEA/routellm-e5-base-v2" + use_openai_embeddings: false diff --git a/comps/router/deployment/docker_compose/configs/router.yaml b/comps/router/deployment/docker_compose/configs/router.yaml new file mode 100644 index 0000000000..b9dd1eac56 --- /dev/null +++ b/comps/router/deployment/docker_compose/configs/router.yaml @@ -0,0 +1,14 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +model_map: + weak: + endpoint: 
"${WEAK_ENDPOINT:-http://opea_router:8000/weak}" + model_id: "${WEAK_MODEL_ID}" + strong: + endpoint: "${STRONG_ENDPOINT:-http://opea_router:8000/strong}" + model_id: "${STRONG_MODEL_ID}" + +controller_config_paths: + routellm: "/app/configs/routellm_config.yaml" + semantic_router: "/app/configs/semantic_router_config.yaml" diff --git a/comps/router/deployment/docker_compose/configs/semantic_router_config.yaml b/comps/router/deployment/docker_compose/configs/semantic_router_config.yaml new file mode 100644 index 0000000000..fcfcec2689 --- /dev/null +++ b/comps/router/deployment/docker_compose/configs/semantic_router_config.yaml @@ -0,0 +1,20 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +embedding_provider: "huggingface" + +embedding_models: + huggingface: "BAAI/bge-base-en-v1.5" + openai: "text-embedding-ada-002" + +routes: + - name: "strong" + utterances: + - "Prove the Pythagorean theorem using geometric arguments..." + - "Explain the Calvin cycle..." + - "Discuss the ethical implications of deploying AI..." + - name: "weak" + utterances: + - "Hello, how are you?" + - "What's 2 + 2?" + - "Can you tell me a funny joke?" diff --git a/comps/router/deployment/docker_compose/deploy_router.sh b/comps/router/deployment/docker_compose/deploy_router.sh new file mode 100755 index 0000000000..fe29c94943 --- /dev/null +++ b/comps/router/deployment/docker_compose/deploy_router.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ======================== +# OPEA Router Deploy Script +# ======================== + +# Load environment variables from a .env file if present +if [ -f .env ]; then + echo "[INFO] Loading environment variables from .env" + export $(grep -v '^#' .env | xargs) +fi + +# Required variables +REQUIRED_VARS=("HF_TOKEN") + +# Validate that all required variables are set +for VAR in "${REQUIRED_VARS[@]}"; do + if [ -z "${!VAR}" ]; then + echo "[ERROR] $VAR is not set. Please set it in your environment or .env file." + exit 1 + fi +done + +export HUGGINGFACEHUB_API_TOKEN="$HF_TOKEN" + +# Default values for Docker image +REGISTRY_AND_REPO=${REGISTRY_AND_REPO:-opea/router} +TAG=${TAG:-latest} + +# Export them so Docker Compose can see them +export REGISTRY_AND_REPO +export TAG + +# Print summary +echo "[INFO] Starting deployment with the following config:" +echo " Image: ${REGISTRY_AND_REPO}:${TAG}" +echo " HF_TOKEN: ***${HF_TOKEN: -4}" +echo " OPENAI_API_KEY: ***${OPENAI_API_KEY: -4}" +echo "" + +# Compose up +echo "[INFO] Launching Docker Compose service..." 
+docker compose -f compose.yaml up --build + +# Wait a moment then check status +sleep 2 +docker ps --filter "name=opea-router" + +echo "[SUCCESS] Router service deployed and running on http://localhost:6000" diff --git a/comps/router/src/Dockerfile b/comps/router/src/Dockerfile new file mode 100644 index 0000000000..ded4d44ea0 --- /dev/null +++ b/comps/router/src/Dockerfile @@ -0,0 +1,30 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +FROM python:3.10-slim + +# Install git +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* + +# Add a non-root user +RUN useradd -m -s /bin/bash user && chown -R user /home/user + +# Copy the *entire* comps/ package +WORKDIR /home/user +COPY comps /home/user/comps + +# Install deps from the router’s requirements.txt +RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r /home/user/comps/router/src/requirements.txt && git clone --depth 1 https://github.com/lm-sys/RouteLLM.git /tmp/RouteLLM && patch -p1 -d /tmp/RouteLLM < /home/user/comps/router/src/hf_compatibility.patch && pip install --no-cache-dir /tmp/RouteLLM && rm -rf /tmp/RouteLLM + +# Make imports work +ENV PYTHONPATH=/home/user + +# Switch to non-root +USER user + +# Expose the port +EXPOSE 6000 + +# Run the microservice +WORKDIR /home/user/comps/router/src +CMD ["python", "opea_router_microservice.py"] diff --git a/comps/router/src/README.md b/comps/router/src/README.md new file mode 100644 index 0000000000..97cf56362c --- /dev/null +++ b/comps/router/src/README.md @@ -0,0 +1,120 @@ +# Router Microservice + +> Location: comps/router/src/README.md + +A lightweight HTTP service that routes incoming text prompts to the most appropriate LLM back‑end (e.g. strong vs weak) and returns the target inference endpoint. It is built on the OPEA micro‑service SDK and can switch between two controller back‑ends: + +- RouteLLM (matrix‑factorisation, dataset‑driven) +- Semantic‑Router (encoder‑based semantic similarity) + +The router is stateless; it inspects the prompt, consults the configured controller, and replies with a single URL such as http://opea_router:8000/strong. + +## Build + +```bash +# From repo root 📂 +# Build the container image directly +$ docker build -t opea/router:latest -f comps/router/src/Dockerfile . +``` + +Alternatively, the Docker Compose workflow below will build the image for you. + +```bash +# Navigate to the compose bundle +$ cd comps/router/deployment/docker_compose + +# Populate required secrets (or create a .env file) +$ export HF_TOKEN="" +$ export OPENAI_API_KEY="" + +# Optional: point to custom inference endpoints / models +$ export WEAK_ENDPOINT=http://my‑llm‑gateway:8000/weak +$ export STRONG_ENDPOINT=http://my‑llm‑gateway:8000/strong +$ export CONTROLLER_TYPE=routellm # or semantic_router + +# Launch (using the helper script) +$ ./deploy_router.sh +``` + +_The service listens on http://localhost:6000 (host‑mapped from container port 6000). Logs stream to STDOUT; use Ctrl‑C to stop or docker compose down to clean up._ + +## RouteLLM compatibility patch + +The upstream **RouteLLM** project is geared toward OpenAI embeddings and GPT-4–augmented +checkpoints. +We include a small patch – `hf_compatibility.patch` – that: + +- adds a `hf_token` plumb-through, +- switches the Matrix-Factorisation router to Hugging Face sentence embeddings, +- removes hard-coded GPT-4 “golden-label” defaults. 
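+
+As a reference, the sketch below drives the patched `Controller` directly. It is only an illustration of the patched API, not the service code path: the `hf_...` token is a placeholder, the model names mirror the upstream defaults rather than `WEAK_MODEL_ID`/`STRONG_MODEL_ID`, and the `mf` settings follow `routellm_config.yaml`. The microservice performs the equivalent wiring in `integrations/controllers/routellm_controller/routellm_controller.py`.
+
+```python
+# Minimal sketch (assumptions noted above): exercise the patched RouteLLM Controller by hand.
+from routellm.controller import Controller
+
+controller = Controller(
+    routers=["mf"],
+    strong_model="gpt-4-1106-preview",
+    weak_model="mixtral-8x7b-instruct-v0.1",
+    # A config dict is now mandatory (the GPT-4 "golden-label" defaults were removed).
+    config={
+        "mf": {
+            "checkpoint_path": "OPEA/routellm-e5-base-v2",
+            "use_openai_embeddings": False,
+            "embedding_model_name": "intfloat/e5-base-v2",
+        }
+    },
+    hf_token="hf_...",  # placeholder; forwarded to the Hugging Face tokenizer/model loaders
+)
+
+# Returns the name of the model the router would pick for this prompt.
+routed = controller.get_routed_model(
+    [{"content": "What's 2 + 2?"}], router="mf", threshold=0.3
+)
+print(routed)  # e.g. "mixtral-8x7b-instruct-v0.1" for an easy prompt
+```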
+ +**Container users:** +The Dockerfile applies the patch automatically during `docker build`, so you don’t have to do anything. + +**Local development:** + +```bash +# 1. Clone upstream RouteLLM +git clone https://github.com/lm-sys/RouteLLM.git +cd RouteLLM + +# 2. Apply the patch shipped with this repo +patch -p1 < ../comps/router/src/hf_compatibility.patch + +# 3. Install the patched library +pip install -e . +``` + +## API Usage + +| Method | URL | Body schema | Success response | +| ------ | ----------- | ----------------------------- | ---------------------------------------------- | +| `POST` | `/v1/route` | `{ "text": "" }` | `200 OK` → `{ "url": "" }` | + +**Example** + +``` +curl -X POST http://localhost:6000/v1/route \ + -H "Content-Type: application/json" \ + -d '{"text": "Explain the Calvin cycle in photosynthesis."}' +``` + +Expected JSON _(assuming the strong model wins the routing decision)_: + +``` +{ + "url": "http://opea_router:8000/strong" +} +``` + +## Configuration Reference + +| Variable / file | Purpose | Default | Where set | +| ----------------------------------- | ------------------------------------------------- | -------------------------------------- | ------------------- | +| `HF_TOKEN` | Hugging Face auth token for encoder models | — | `.env` / shell | +| `OPENAI_API_KEY` | OpenAI key (only if `embedding_provider: openai`) | — | `.env` / shell | +| `CONTROLLER_TYPE` | `routellm` or `semantic_router` | `routellm` | env / `router.yaml` | +| `CONFIG_PATH` | Path to global router YAML | `/app/configs/router.yaml` | Compose env | +| `WEAK_ENDPOINT` / `STRONG_ENDPOINT` | Final inference URLs | container DNS | Compose env | +| `WEAK_MODEL_ID` / `STRONG_MODEL_ID` | Model IDs forwarded to controllers | `openai/gpt-3.5-turbo`, `openai/gpt-4` | Compose env | + +## Troubleshooting + +`HF_TOKEN` is not set – export the token or place it in a .env file next to compose.yaml. + +Unknown controller type – `CONTROLLER_TYPE` must be either routellm or semantic_router and a matching entry must exist in controller_config_paths. + +Routed model `` not in `model_map` – make sure model_map in router.yaml lists both strong and weak with the correct model_id values. + +Use docker compose logs -f router_service for real‑time debugging. + +## Testing + +Includes an end-to-end script for the RouteLLM controller: + +```bash +chmod +x tests/router/test_router_routellm.sh +export HF_TOKEN="" +export OPENAI_API_KEY="" +tests/router/test_router_routellm.sh +``` diff --git a/comps/router/src/hf_compatibility.patch b/comps/router/src/hf_compatibility.patch new file mode 100644 index 0000000000..585b78824f --- /dev/null +++ b/comps/router/src/hf_compatibility.patch @@ -0,0 +1,326 @@ +diff -ruN upstream-RouteLLM/routellm/controller.py patched-RouteLLM/routellm/controller.py +--- upstream-RouteLLM/routellm/controller.py 2025-05-28 19:32:46.029844725 +0000 ++++ patched-RouteLLM/routellm/controller.py 2025-05-28 19:32:14.595998148 +0000 +@@ -9,24 +9,6 @@ + + from routellm.routers.routers import ROUTER_CLS + +-# Default config for routers augmented using golden label data from GPT-4. +-# This is exactly the same as config.example.yaml. 
+-GPT_4_AUGMENTED_CONFIG = { +- "sw_ranking": { +- "arena_battle_datasets": [ +- "lmsys/lmsys-arena-human-preference-55k", +- "routellm/gpt4_judge_battles", +- ], +- "arena_embedding_datasets": [ +- "routellm/arena_battles_embeddings", +- "routellm/gpt4_judge_battles_embeddings", +- ], +- }, +- "causal_llm": {"checkpoint_path": "routellm/causal_llm_gpt4_augmented"}, +- "bert": {"checkpoint_path": "routellm/bert_gpt4_augmented"}, +- "mf": {"checkpoint_path": "routellm/mf_gpt4_augmented"}, +-} +- + + class RoutingError(Exception): + pass +@@ -48,7 +30,9 @@ + api_base: Optional[str] = None, + api_key: Optional[str] = None, + progress_bar: bool = False, ++ hf_token: Optional[str] = None, # Add hf_token as a parameter + ): ++ self.hf_token = hf_token # Store the hf_token + self.model_pair = ModelPair(strong=strong_model, weak=weak_model) + self.routers = {} + self.api_base = api_base +@@ -57,7 +41,7 @@ + self.progress_bar = progress_bar + + if config is None: +- config = GPT_4_AUGMENTED_CONFIG ++ raise ValueError("Config cannot be None. Please provide a valid configuration dictionary.") + + router_pbar = None + if progress_bar: +@@ -67,7 +51,8 @@ + for router in routers: + if router_pbar is not None: + router_pbar.set_description(f"Loading {router}") +- self.routers[router] = ROUTER_CLS[router](**config.get(router, {})) ++ self.routers[router] = ROUTER_CLS[router](hf_token=self.hf_token, **config.get(router, {})) ++ + + # Some Python magic to match the OpenAI Python SDK + self.chat = SimpleNamespace( +@@ -101,6 +86,14 @@ + f"Invalid model {model}. Model name must be of the format 'router-[router name]-[threshold]." + ) + return router, threshold ++ ++ def get_routed_model(self, messages: list, router: str, threshold: float) -> str: ++ """ ++ Get the routed model for a given message using the specified router and threshold. 
++ """ ++ self._validate_router_threshold(router, threshold) ++ routed_model = self._get_routed_model_for_completion(messages, router, threshold) ++ return routed_model + + def _get_routed_model_for_completion( + self, messages: list, router: str, threshold: float +diff -ruN upstream-RouteLLM/routellm/routers/matrix_factorization/model.py patched-RouteLLM/routellm/routers/matrix_factorization/model.py +--- upstream-RouteLLM/routellm/routers/matrix_factorization/model.py 2025-05-28 19:32:46.084844456 +0000 ++++ patched-RouteLLM/routellm/routers/matrix_factorization/model.py 2025-05-28 19:32:14.651997875 +0000 +@@ -1,7 +1,14 @@ + import torch + from huggingface_hub import PyTorchModelHubMixin +- ++from transformers import AutoTokenizer, AutoModel + from routellm.routers.similarity_weighted.utils import OPENAI_CLIENT ++import logging ++ ++logging.basicConfig( ++ level=logging.INFO, ++ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ++) ++logger = logging.getLogger(__name__) + + MODEL_IDS = { + "RWKV-4-Raven-14B": 0, +@@ -70,7 +77,6 @@ + "zephyr-7b-beta": 63, + } + +- + class MFModel(torch.nn.Module, PyTorchModelHubMixin): + def __init__( + self, +@@ -79,51 +85,80 @@ + text_dim, + num_classes, + use_proj, ++ use_openai_embeddings=False, # Default: Hugging Face embeddings ++ embedding_model_name="BAAI/bge-base-en", # Match notebook ++ hf_token=None, # Hugging Face API token + ): + super().__init__() +- self._name = "TextMF" + self.use_proj = use_proj +- self.P = torch.nn.Embedding(num_models, dim) ++ self.use_openai_embeddings = use_openai_embeddings ++ self.hf_token = hf_token ++ self.embedding_model_name = embedding_model_name + +- self.embedding_model = "text-embedding-3-small" ++ # Model embedding matrix ++ self.P = torch.nn.Embedding(num_models, dim) + + if self.use_proj: +- self.text_proj = torch.nn.Sequential( +- torch.nn.Linear(text_dim, dim, bias=False) +- ) ++ self.text_proj = torch.nn.Linear(text_dim, dim, bias=False) + else: +- assert ( +- text_dim == dim +- ), f"text_dim {text_dim} must be equal to dim {dim} if not using projection" +- +- self.classifier = torch.nn.Sequential( +- torch.nn.Linear(dim, num_classes, bias=False) +- ) ++ assert text_dim == dim, f"text_dim {text_dim} must be equal to dim {dim} if not using projection" ++ ++ self.classifier = torch.nn.Linear(dim, num_classes, bias=False) ++ ++ if not self.use_openai_embeddings: ++ logger.info(f"Loading Hugging Face tokenizer and model: {self.embedding_model_name}") ++ ++ # Load tokenizer & model exactly as in the notebook ++ self.tokenizer = AutoTokenizer.from_pretrained( ++ self.embedding_model_name, ++ token=hf_token ++ ) ++ self.embedding_model = AutoModel.from_pretrained( ++ self.embedding_model_name, ++ token=hf_token ++ ) ++ self.embedding_model.eval() # Set to inference mode ++ self.embedding_model.to(self.get_device()) + + def get_device(self): + return self.P.weight.device + ++ def get_prompt_embedding(self, prompt): ++ """Generate sentence embedding using mean pooling (matches notebook).""" ++ ++ inputs = self.tokenizer( ++ prompt, ++ padding=True, ++ truncation=True, ++ return_tensors="pt" ++ ).to(self.get_device()) ++ ++ with torch.no_grad(): ++ outputs = self.embedding_model(**inputs) ++ last_hidden_state = outputs.last_hidden_state ++ ++ # Mean pooling over token embeddings ++ prompt_embed = last_hidden_state.mean(dim=1).squeeze() ++ ++ return prompt_embed ++ + def forward(self, model_id, prompt): + model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) +- + model_embed = 
self.P(model_id) + model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1) ++ prompt_embed = self.get_prompt_embedding(prompt) + +- prompt_embed = ( +- OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model) +- .data[0] +- .embedding +- ) +- prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) +- prompt_embed = self.text_proj(prompt_embed) ++ if self.use_proj: ++ prompt_embed = self.text_proj(prompt_embed) + + return self.classifier(model_embed * prompt_embed).squeeze() + + @torch.no_grad() + def pred_win_rate(self, model_a, model_b, prompt): + logits = self.forward([model_a, model_b], prompt) +- winrate = torch.sigmoid(logits[0] - logits[1]).item() ++ raw_diff = logits[0] - logits[1] ++ winrate = torch.sigmoid(raw_diff).item() + return winrate + + def load(self, path): +- self.load_state_dict(torch.load(path)) ++ self.load_state_dict(torch.load(path)) +\ No newline at end of file +diff -ruN upstream-RouteLLM/routellm/routers/routers.py patched-RouteLLM/routellm/routers/routers.py +--- upstream-RouteLLM/routellm/routers/routers.py 2025-05-28 19:32:46.084844456 +0000 ++++ patched-RouteLLM/routellm/routers/routers.py 2025-05-28 19:32:14.651997875 +0000 +@@ -1,7 +1,7 @@ + import abc + import functools + import random +- ++from transformers import AutoTokenizer, AutoModel + import numpy as np + import torch + from datasets import concatenate_datasets, load_dataset +@@ -21,6 +21,13 @@ + compute_tiers, + preprocess_battles, + ) ++import logging ++ ++logging.basicConfig( ++ level=logging.INFO, ++ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ++) ++logger = logging.getLogger(__name__) + + + def no_parallel(cls): +@@ -211,18 +218,47 @@ + def __init__( + self, + checkpoint_path, +- # This is the model pair for scoring at inference time, +- # and can be different from the model pair used for routing. + strong_model="gpt-4-1106-preview", + weak_model="mixtral-8x7b-instruct-v0.1", + hidden_size=128, +- num_models=64, +- text_dim=1536, ++ num_models=None, ++ text_dim=None, + num_classes=1, + use_proj=True, ++ use_openai_embeddings=True, ++ embedding_model_name=None, ++ hf_token=None, + ): ++ """ ++ A simplified constructor that flattens the logic for: ++ 1) Setting num_models from MODEL_IDS, ++ 2) Determining embedding_model_name defaults, ++ 3) Setting text_dim for OpenAI vs. HF embeddings, ++ 4) Initializing the MFModel, ++ 5) Setting strong/weak model IDs. ++ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ++ # Default num_models to the length of MODEL_IDS if not provided ++ num_models = num_models or len(MODEL_IDS) ++ ++ # Decide which embedding model_name to use if none provided ++ if not embedding_model_name: ++ if use_openai_embeddings: ++ # e.g. "text-embedding-ada-002" or your default ++ embedding_model_name = "text-embedding-3-small" ++ else: ++ raise ValueError("Missing model id in config file. 
Please add a valid model id") ++ ++ # Decide text_dim if not provided ++ if text_dim is None: ++ if use_openai_embeddings: ++ # e.g., 1536 for text-embedding-ada-002 ++ text_dim = 1536 ++ else: ++ text_dim = self._infer_hf_text_dim(embedding_model_name) ++ ++ # Initialize the MFModel + self.model = MFModel.from_pretrained( + checkpoint_path, + dim=hidden_size, +@@ -230,14 +266,40 @@ + text_dim=text_dim, + num_classes=num_classes, + use_proj=use_proj, +- ) +- self.model = self.model.eval().to(device) ++ use_openai_embeddings=use_openai_embeddings, ++ embedding_model_name=embedding_model_name, ++ hf_token=hf_token, ++ ).eval().to(device) ++ ++ # Store strong/weak model IDs + self.strong_model_id = MODEL_IDS[strong_model] + self.weak_model_id = MODEL_IDS[weak_model] + ++ @staticmethod ++ def _infer_hf_text_dim(embedding_model_name: str) -> int: ++ """ ++ Helper to load a huggingface model and extract its hidden size. ++ Immediately frees model from memory. ++ """ ++ tokenizer = AutoTokenizer.from_pretrained(embedding_model_name) ++ hf_model = AutoModel.from_pretrained(embedding_model_name) ++ dim = hf_model.config.hidden_size ++ ++ del tokenizer ++ del hf_model ++ ++ return dim ++ + def calculate_strong_win_rate(self, prompt): ++ """ ++ Scores the prompt using the MF model to see how ++ often the 'strong' model is predicted to win ++ over the 'weak' model. ++ """ + winrate = self.model.pred_win_rate( +- self.strong_model_id, self.weak_model_id, prompt ++ self.strong_model_id, ++ self.weak_model_id, ++ prompt + ) + return winrate + diff --git a/comps/router/src/integrations/controllers/base_controller.py b/comps/router/src/integrations/controllers/base_controller.py new file mode 100644 index 0000000000..2601274970 --- /dev/null +++ b/comps/router/src/integrations/controllers/base_controller.py @@ -0,0 +1,14 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + + +class BaseController(ABC): + """An abstract base controller class providing a framework for routing and + endpoint retrieval functionality.""" + + @abstractmethod + def route(self, messages, **kwargs): + """Determines the appropriate routing based on input messages.""" + pass diff --git a/comps/router/src/integrations/controllers/controller_factory.py b/comps/router/src/integrations/controllers/controller_factory.py new file mode 100644 index 0000000000..a9c41f927e --- /dev/null +++ b/comps/router/src/integrations/controllers/controller_factory.py @@ -0,0 +1,47 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Dict + +import yaml +from dotenv import load_dotenv + +from comps.router.src.integrations.controllers.routellm_controller.routellm_controller import RouteLLMController +from comps.router.src.integrations.controllers.semantic_router_controller.semantic_router_controller import ( + SemanticRouterController, +) + +load_dotenv() + +HF_TOKEN = os.getenv("HF_TOKEN", "") +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") +CONTROLLER_TYPE = os.getenv("CONTROLLER_TYPE", None) + + +class ControllerFactory: + + @staticmethod + def get_controller_config(config_filename: str) -> Dict: + try: + with open(config_filename, "r") as file: + config = yaml.safe_load(file) + return config + except FileNotFoundError as e: + raise FileNotFoundError(f"Configuration file '{config_filename}' not found.") from e + except yaml.YAMLError as e: + raise ValueError(f"Error parsing the configuration file: {e}") from e + + 
@staticmethod + def factory(controller_config: str, model_map: Dict): + """Returns an instance of the appropriate controller based on the controller_type.""" + + config = ControllerFactory.get_controller_config(controller_config) + + if CONTROLLER_TYPE == "routellm": + return RouteLLMController(config=config, api_key=OPENAI_API_KEY, hf_token=HF_TOKEN, model_map=model_map) + + elif CONTROLLER_TYPE == "semantic_router": + return SemanticRouterController(config=config, api_key=OPENAI_API_KEY, model_map=model_map) + else: + raise ValueError(f"Unknown controller type: {CONTROLLER_TYPE}") diff --git a/comps/router/src/integrations/controllers/routellm_controller/routellm_controller.py b/comps/router/src/integrations/controllers/routellm_controller/routellm_controller.py new file mode 100644 index 0000000000..1c17d96c4f --- /dev/null +++ b/comps/router/src/integrations/controllers/routellm_controller/routellm_controller.py @@ -0,0 +1,75 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os + +from routellm.controller import Controller as RouteLLM_Controller + +from comps.router.src.integrations.controllers.base_controller import BaseController + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + + +class RouteLLMController(BaseController): + def __init__(self, config, hf_token=None, api_key=None, model_map=None): + self.config = config + self.model_map = model_map or {} + + # Determine embedding provider + provider = config.get("embedding_provider", "huggingface").lower() + + # Resolve embedding model: env override ↔️ config default + env_var = "ROUTELLM_EMBEDDING_MODEL_NAME" + default_model = config.get("embedding_model_name") + self.embedding_model = os.getenv(env_var, default_model) + if not self.embedding_model: + raise ValueError(f"No embedding_model_name in config and {env_var} not set") + logging.info(f"[RouteLLM] using {provider} embedding model: {self.embedding_model}") + + # Inject into nested mf config + nested = self.config.setdefault("config", {}) + mf = nested.setdefault("mf", {}) + mf["embedding_model_name"] = self.embedding_model + + # Validate routing settings + self.routing_algorithm = config.get("routing_algorithm") + if not self.routing_algorithm: + raise ValueError("routing_algorithm must be specified in configuration") + self.threshold = config.get("threshold", 0.2) + + # Extract strong/weak model IDs + strong_model = self.model_map.get("strong", {}).get("model_id") + weak_model = self.model_map.get("weak", {}).get("model_id") + if not strong_model or not weak_model: + raise ValueError("model_map must include both 'strong' and 'weak' entries") + + # Prepare Env for OpenAI if needed + if provider == "openai": + if not api_key: + raise ValueError("api_key is required for OpenAI embeddings") + os.environ["OPENAI_API_KEY"] = api_key + + # Initialize the underlying controller (keyword args to match signature) + self.controller = RouteLLM_Controller( + routers=[self.routing_algorithm], + strong_model=strong_model, + weak_model=weak_model, + config=nested, + hf_token=hf_token if provider == "huggingface" else None, + api_key=api_key if provider == "openai" else None, + ) + + def route(self, messages): + routed_name = self.controller.get_routed_model( + messages, + router=self.routing_algorithm, + threshold=self.threshold, + ) + endpoint_key = next((k for k, v in self.model_map.items() if v.get("model_id") == routed_name), None) + if not endpoint_key: + raise 
ValueError(f"Routed model '{routed_name}' not in model_map") + return self.model_map[endpoint_key]["endpoint"] diff --git a/comps/router/src/integrations/controllers/semantic_router_controller/semantic_router_controller.py b/comps/router/src/integrations/controllers/semantic_router_controller/semantic_router_controller.py new file mode 100644 index 0000000000..961f7a86ab --- /dev/null +++ b/comps/router/src/integrations/controllers/semantic_router_controller/semantic_router_controller.py @@ -0,0 +1,92 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os + +# from decorators import log_latency +from dotenv import load_dotenv +from semantic_router import Route +from semantic_router.encoders import HuggingFaceEncoder, OpenAIEncoder +from semantic_router.routers import SemanticRouter + +from comps.cores.telemetry.opea_telemetry import opea_telemetry +from comps.router.src.integrations.controllers.base_controller import BaseController + +load_dotenv() +hf_token = os.getenv("HF_TOKEN", "") +openai_api_key = os.getenv("OPENAI_API_KEY", "") + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + + +class SemanticRouterController(BaseController): + def __init__(self, config, api_key=None, model_map=None): + self.config = config + self.model_map = model_map or {} + + # grab provider + model mapping + provider = config.get("embedding_provider", "").lower() + models = config.get("embedding_models", {}) + + if provider not in {"huggingface", "openai"}: + raise ValueError(f"Unsupported embedding_provider: '{provider}'") + if provider not in models: + raise ValueError(f"No embedding_models entry for provider '{provider}'") + + model_name = models[provider] + logging.info(f"SemanticRouter using {provider} encoder '{model_name}'") + + if provider == "huggingface": + hf_token = os.getenv("HF_TOKEN", "") + self.encoder = HuggingFaceEncoder( + name=model_name, + model_kwargs={"token": hf_token}, + tokenizer_kwargs={"token": hf_token}, + ) + else: + if not api_key: + raise ValueError("valid api key is required for selected model provider") + os.environ["OPENAI_API_KEY"] = api_key + self.encoder = OpenAIEncoder(model=model_name) + + # build your routing layer + self._build_route_layer() + + def _build_route_layer(self): + # Build routes from the local controller config + routes = self.config.get("routes", []) + route_list = [Route(name=route["name"], utterances=route["utterances"]) for route in routes] + + # Reinitialize SemanticRouter to clear previous embeddings when switching models + self.route_layer = SemanticRouter(encoder=self.encoder, routes=route_list) + logging.info("[DEBUG] Successfully re-initialized SemanticRouter with fresh embeddings.") + + @opea_telemetry + def route(self, messages): + """Determines which inference endpoint to use based on the provided messages. + + It looks up the model_map to retrieve the nested endpoint value. 
+ """ + query = messages[0]["content"] + + route_choice = self.route_layer(query) + endpoint_key = route_choice.name + + if not endpoint_key: + routes = self.config.get("routes", []) + if routes: + endpoint_key = routes[0]["name"] + else: + raise ValueError("No routes available in the configuration.") + + # Lookup the endpoint in the model_map + model_entry = self.model_map.get(endpoint_key) + if model_entry is None: + raise ValueError(f"Inference endpoint '{endpoint_key}' not found in global model_map.") + + # Return the endpoint from the model map + return model_entry["endpoint"] diff --git a/comps/router/src/opea_router_microservice.py b/comps/router/src/opea_router_microservice.py new file mode 100644 index 0000000000..83960816f3 --- /dev/null +++ b/comps/router/src/opea_router_microservice.py @@ -0,0 +1,96 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os + +import yaml + +from comps import ( + CustomLogger, + ServiceType, + TextDoc, + opea_microservices, + register_microservice, +) +from comps.cores.proto.api_protocol import RouteEndpointDoc +from comps.router.src.integrations.controllers.controller_factory import ControllerFactory + +# Set up logging +logger = CustomLogger("opea_router_microservice") +logflag = os.getenv("LOGFLAG", False) + +CONFIG_PATH = os.getenv("CONFIG_PATH") + +_config_data = {} +_controller_factory = None +_controller = None + + +def _load_config(): + global _config_data, _controller_factory, _controller + + try: + with open(CONFIG_PATH, "r") as f: + new_data = yaml.safe_load(f) or {} + except Exception as e: + logger.error(f"Failed to load config: {e}") + raise RuntimeError(f"Failed to load config: {e}") + + _config_data = new_data + logger.info(f"[Router] Loaded config data from: {CONFIG_PATH}") + + if _controller_factory is None: + _controller_factory = ControllerFactory() + + model_map = _config_data.get("model_map", {}) + controller_type = os.getenv("CONTROLLER_TYPE") or _config_data.get("controller_type", "routellm") + + # look up the correct controller-config path + try: + controller_config_path = _config_data["controller_config_paths"][controller_type] + except KeyError: + raise RuntimeError(f"No config path for controller_type='{controller_type}' in global config") + + _controller = _controller_factory.factory(controller_config=controller_config_path, model_map=model_map) + + logger.info("[Router] Controller re-initialized successfully.") + + +# Initial config load at startup +_load_config() + + +@register_microservice( + name="opea_service@router", + service_type=ServiceType.LLM, + endpoint="/v1/route", + host="0.0.0.0", + port=6000, + input_datatype=TextDoc, + output_datatype=RouteEndpointDoc, +) +def route_microservice(input: TextDoc) -> RouteEndpointDoc: + """Microservice that decides which model endpoint is best for the given text input. + + Returns only the route URL (does not forward). 
+ """ + if not _controller: + raise RuntimeError("Controller is not initialized — config load failed?") + + query_content = input.text + messages = [{"content": query_content}] + + try: + endpoint = _controller.route(messages) + if not endpoint: + raise ValueError("No suitable model endpoint found.") + return RouteEndpointDoc(url=endpoint) + + except Exception as e: + logger.error(f"[Router] Error during model routing: {e}") + raise + + +if __name__ == "__main__": + logger.info("OPEA Router Microservice is starting...") + opea_microservices["opea_service@router"].start() diff --git a/comps/router/src/requirements.txt b/comps/router/src/requirements.txt new file mode 100644 index 0000000000..5b2bb91185 --- /dev/null +++ b/comps/router/src/requirements.txt @@ -0,0 +1,23 @@ +aiofiles +aiohttp +docarray[full] +docx2txt +fastapi +httpx +kubernetes +langchain +langchain-community +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk +pillow +prometheus-fastapi-instrumentator +pydantic +pypdf +python-dotenv +python-multipart +pyyaml +requests +semantic-router +shortuuid +uvicorn[standard] diff --git a/tests/router/test_router_routellm_on_xeon.sh b/tests/router/test_router_routellm_on_xeon.sh new file mode 100755 index 0000000000..4880bcb8c4 --- /dev/null +++ b/tests/router/test_router_routellm_on_xeon.sh @@ -0,0 +1,71 @@ +#!/usr/bin/env bash +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# End-to-end test – Router micro-service, RouteLLM controller (CPU/Xeon) +set -xeuo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +WORKPATH="$(cd "$SCRIPT_DIR/../.." && pwd)" +host=127.0.0.1 +LOG_PATH="$WORKPATH/tests" +ROUTER_PORT=6000 +CONTAINER=opea_router + +# Required secrets +: "${HF_TOKEN:?Need HF_TOKEN}" +: "${OPENAI_API_KEY:=}" + +REGISTRY_AND_REPO=${REGISTRY_AND_REPO:-opea/router} +TAG=${TAG:-latest} + +export HF_TOKEN OPENAI_API_KEY REGISTRY_AND_REPO TAG + +build_image() { + cd "$WORKPATH" + docker build --no-cache -t "${REGISTRY_AND_REPO}:${TAG}" \ + -f comps/router/src/Dockerfile . +} + +start_router() { + cd "$WORKPATH/comps/router/deployment/docker_compose" + + export CONTROLLER_TYPE=routellm + + docker compose -f compose.yaml up router_service -d + sleep 20 +} + +validate() { + # weak route + rsp=$( + curl -s --noproxy localhost,127.0.0.1 \ + -X POST http://${host}:${ROUTER_PORT}/v1/route \ + -H 'Content-Type: application/json' \ + -d '{"text":"What is 2 + 2?"}' + ) + [[ $rsp == *"weak"* ]] || { echo "weak routing failed ($rsp)"; exit 1; } + + # strong route + hard='Given a 100x100 grid where each cell is independently colored black or white such that for every cell the sum of black cells in its row, column, and both main diagonals is a distinct prime number, determine whether there exists a unique configuration of the grid that satisfies this condition and, if so, compute the total number of black cells in that configuration.' + rsp=$( + curl -s --noproxy localhost,127.0.0.1 \ + -X POST http://${host}:${ROUTER_PORT}/v1/route \ + -H 'Content-Type: application/json' \ + -d "{\"text\":\"$hard\"}" + ) + [[ $rsp == *"strong"* ]] || { echo "strong routing failed ($rsp)"; exit 1; } +} + +cleanup() { + cd "$WORKPATH/comps/router/deployment/docker_compose" + docker compose -f compose.yaml down --remove-orphans +} + +trap cleanup EXIT +cleanup +build_image +start_router +validate + +echo "✅ RouteLLM controller test passed."