
Commit 1367540

refactor: migrate RAPTOR API parameters to Pydantic model for better validation
1 parent cc228f8 commit 1367540

File tree

1 file changed: +51 −36 lines changed


raptor_api.py

Lines changed: 51 additions & 36 deletions
@@ -17,12 +17,12 @@
 from typing import List, Dict, Optional, Tuple, Union
 from pathlib import Path
 from concurrent.futures import ThreadPoolExecutor
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from numba.core.errors import NumbaWarning
 from sklearn.mixture import GaussianMixture
 from sentence_transformers import SentenceTransformer
-from fastapi import FastAPI, UploadFile, File, HTTPException, Query
+from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
 from fastapi.responses import JSONResponse
 from tqdm import tqdm
 from contextlib import asynccontextmanager
@@ -465,7 +465,7 @@ async def lifespan(app: FastAPI):
 app = FastAPI(
     title="RAPTOR API",
     description="API for Recursive Abstraction and Processing for Text Organization and Reduction",
-    version="0.5.2",
+    version="0.5.3",
     lifespan=lifespan,
 )
 
@@ -1550,6 +1550,41 @@ class ClusteringResult(BaseModel):
     metadata: ClusteringMetadata
 
 
+class RaptorInput(BaseModel):
+    """Input parameters for RAPTOR clustering endpoint."""
+    llm_model: Optional[str] = Field(
+        default=get_settings().llm_model,
+        description="LLM model to use for summarization"
+    )
+    embedder_model: Optional[str] = Field(
+        default=get_settings().embedder_model,
+        description="Embedding model to use for generating embeddings"
+    )
+    threshold_tokens: Optional[int] = Field(
+        default=None,
+        description="Token threshold for chunk optimization. If None, no optimization is applied"
+    )
+    temperature: Optional[float] = Field(
+        default=get_settings().temperature,
+        description="Temperature for text generation (0.0 to 1.0). Controls randomness in LLM output",
+        ge=0.0,
+        le=1.0
+    )
+    context_window: Optional[int] = Field(
+        default=get_settings().context_window,
+        description="Maximum context window size for LLM",
+        gt=0
+    )
+    custom_prompt: Optional[str] = Field(
+        default=None,
+        description="Custom prompt template for summarization"
+    )
+
+    class Config:
+        # Allow extra fields to be ignored for forward compatibility
+        extra = "ignore"
+
+
 @app.get("/")
 async def health_check():
     """Check the health status of the API service.
@@ -1579,40 +1614,14 @@ async def health_check():
 
 @app.post("/raptor/", response_class=JSONResponse)
 async def raptor(
-    file: UploadFile = File(...),
-    llm_model: str = Query(
-        None, description="LLM model to use", example=get_settings().llm_model
-    ),
-    embedder_model: str = Query(
-        None,
-        description="Embedding model to use",
-        example=get_settings().embedder_model,
-    ),
-    threshold_tokens: Optional[int] = Query(
-        None, description="Token threshold for chunk optimization"
-    ),
-    temperature: float = Query(
-        None,
-        description="Temperature for text generation",
-        example=get_settings().temperature,
-    ),
-    context_window: int = Query(
-        None, description="Context window size", example=get_settings().context_window
-    ),
-    custom_prompt: Optional[str] = Query(
-        None, description="Custom prompt template for summarization", example=""
-    ),
+    file: UploadFile = File(..., description="JSON file (.json) containing chunks to process with a 'chunks' array"),
+    input_data: RaptorInput = Depends(),
 ):
     """Process semantic chunks from an uploaded JSON file for hierarchical clustering.
 
     Args:
-        file (UploadFile): JSON file (.json) containing chunks to process with a 'chunks' array
-        llm_model (str): LLM model to use for summarization
-        embedder_model (str): Model to use for generating embeddings
-        threshold_tokens (Optional[int]): Maximum token limit for summaries
-        temperature (float): Controls randomness in LLM output (0.0 to 1.0)
-        context_window (int): Maximum context window size for LLM
-        custom_prompt (Optional[str]): Optional custom prompt template as a string
+        file: JSON file (.json) containing chunks to process with a 'chunks' array
+        input_data: RAPTOR processing parameters (uses settings defaults if not provided)
 
     Returns:
         JSONResponse: Hierarchical clustering results with metadata
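With the new signature, the tuning parameters are still supplied as query-string values (now validated through RaptorInput) while the file is sent as multipart form data, so the client-facing interface is essentially unchanged. A hypothetical client call, assuming the service listens on localhost:8000 and chunks.json holds a top-level "chunks" array:

import requests  # any HTTP client works; shown here only for illustration

with open("chunks.json", "rb") as f:
    response = requests.post(
        "http://localhost:8000/raptor/",
        files={"file": ("chunks.json", f, "application/json")},
        params={
            "temperature": 0.2,      # checked against ge=0.0 / le=1.0
            "context_window": 8192,  # checked against gt=0
            # llm_model, embedder_model, etc. are omitted here and fall back
            # to the defaults drawn from get_settings()
        },
    )

print(response.status_code)
print(response.json())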
@@ -1629,12 +1638,18 @@ async def raptor(
             detail="Ollama server is not reachable. Please ensure Ollama is running and accessible.",
         )
 
+    # Get settings and apply defaults
+    settings = get_settings()
+    llm_model = input_data.llm_model or settings.llm_model
+    embedder_model = input_data.embedder_model or settings.embedder_model
+    temperature = input_data.temperature if input_data.temperature is not None else settings.temperature
+    context_window = input_data.context_window if input_data.context_window is not None else settings.context_window
+    threshold_tokens = input_data.threshold_tokens
+    custom_prompt = input_data.custom_prompt
+
     # Verify model availability before processing
     # This handles the case where models might be deleted from Ollama after the app has started
     logger.info(f"Verifying availability of LLM model: '{llm_model}'")
-    settings = get_settings()
-    # Use the provided model or get from settings
-    llm_model = llm_model or settings.llm_model
     llm_model = ensure_ollama_model(llm_model, fallback_model=settings.llm_model)
 
     # Use the custom prompt if provided as a string parameter, otherwise use the default
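One detail worth noting in the defaults block added above: the string fields fall back with `or`, while the numeric fields use an explicit `is not None` check. The distinction matters because 0 and 0.0 are falsy in Python, so an `or` fallback would silently discard a deliberately chosen zero. A small illustrative sketch (the values are made up):

settings_temperature = 0.7   # hypothetical settings default
requested = 0.0              # a valid, intentional request value

via_or = requested or settings_temperature
via_none_check = requested if requested is not None else settings_temperature

assert via_or == 0.7          # 0.0 is falsy, so the default silently wins
assert via_none_check == 0.0  # the explicit None check preserves the request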
