from typing import List, Dict, Optional, Tuple, Union
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
- from pydantic import BaseModel
+ from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from numba.core.errors import NumbaWarning
from sklearn.mixture import GaussianMixture
from sentence_transformers import SentenceTransformer
- from fastapi import FastAPI, UploadFile, File, HTTPException, Query
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
from fastapi.responses import JSONResponse
from tqdm import tqdm
from contextlib import asynccontextmanager
@@ -465,7 +465,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(
    title="RAPTOR API",
    description="API for Recursive Abstraction and Processing for Text Organization and Reduction",
-     version="0.5.2",
+     version="0.5.3",
    lifespan=lifespan,
)

@@ -1550,6 +1550,41 @@ class ClusteringResult(BaseModel):
    metadata: ClusteringMetadata


+ class RaptorInput(BaseModel):
+     """Input parameters for RAPTOR clustering endpoint."""
+     llm_model: Optional[str] = Field(
+         default=get_settings().llm_model,
+         description="LLM model to use for summarization"
+     )
+     embedder_model: Optional[str] = Field(
+         default=get_settings().embedder_model,
+         description="Embedding model to use for generating embeddings"
+     )
+     threshold_tokens: Optional[int] = Field(
+         default=None,
+         description="Token threshold for chunk optimization. If None, no optimization is applied"
+     )
+     temperature: Optional[float] = Field(
+         default=get_settings().temperature,
+         description="Temperature for text generation (0.0 to 1.0). Controls randomness in LLM output",
+         ge=0.0,
+         le=1.0
+     )
+     context_window: Optional[int] = Field(
+         default=get_settings().context_window,
+         description="Maximum context window size for LLM",
+         gt=0
+     )
+     custom_prompt: Optional[str] = Field(
+         default=None,
+         description="Custom prompt template for summarization"
+     )
+
+     class Config:
+         # Allow extra fields to be ignored for forward compatibility
+         extra = "ignore"
+
+
@app.get("/")
async def health_check():
    """Check the health status of the API service.
@@ -1579,40 +1614,14 @@ async def health_check():

@app.post("/raptor/", response_class=JSONResponse)
async def raptor(
-     file: UploadFile = File(...),
-     llm_model: str = Query(
-         None, description="LLM model to use", example=get_settings().llm_model
-     ),
-     embedder_model: str = Query(
-         None,
-         description="Embedding model to use",
-         example=get_settings().embedder_model,
-     ),
-     threshold_tokens: Optional[int] = Query(
-         None, description="Token threshold for chunk optimization"
-     ),
-     temperature: float = Query(
-         None,
-         description="Temperature for text generation",
-         example=get_settings().temperature,
-     ),
-     context_window: int = Query(
-         None, description="Context window size", example=get_settings().context_window
-     ),
-     custom_prompt: Optional[str] = Query(
-         None, description="Custom prompt template for summarization", example=""
-     ),
+     file: UploadFile = File(..., description="JSON file (.json) containing chunks to process with a 'chunks' array"),
+     input_data: RaptorInput = Depends(),
):
    """Process semantic chunks from an uploaded JSON file for hierarchical clustering.

    Args:
-         file (UploadFile): JSON file (.json) containing chunks to process with a 'chunks' array
-         llm_model (str): LLM model to use for summarization
-         embedder_model (str): Model to use for generating embeddings
-         threshold_tokens (Optional[int]): Maximum token limit for summaries
-         temperature (float): Controls randomness in LLM output (0.0 to 1.0)
-         context_window (int): Maximum context window size for LLM
-         custom_prompt (Optional[str]): Optional custom prompt template as a string
+         file: JSON file (.json) containing chunks to process with a 'chunks' array
+         input_data: RAPTOR processing parameters (uses settings defaults if not provided)

    Returns:
        JSONResponse: Hierarchical clustering results with metadata
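
Note on the `input_data: RaptorInput = Depends()` line above: when a Pydantic model is supplied to FastAPI through `Depends()`, each model field is exposed as an individual query parameter, so clients keep sending `?llm_model=...&temperature=...` exactly as they did with the old `Query(...)` signature, while the `ge`/`le`/`gt` constraints now live on the model. A minimal, standalone sketch of that pattern follows; the `demo_app`/`DemoParams` names are illustrative only and are not part of this change.

from typing import Optional
from fastapi import Depends, FastAPI
from pydantic import BaseModel, Field

demo_app = FastAPI()

class DemoParams(BaseModel):
    llm_model: Optional[str] = Field(default=None, description="LLM model to use")
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)

@demo_app.get("/demo/")
async def demo(params: DemoParams = Depends()):
    # A request such as /demo/?llm_model=llama3&temperature=0.2 arrives
    # already parsed into the model, with the field constraints applied.
    return params.model_dump()
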
@@ -1629,12 +1638,18 @@ async def raptor(
            detail="Ollama server is not reachable. Please ensure Ollama is running and accessible.",
        )

+     # Get settings and apply defaults
+     settings = get_settings()
+     llm_model = input_data.llm_model or settings.llm_model
+     embedder_model = input_data.embedder_model or settings.embedder_model
+     temperature = input_data.temperature if input_data.temperature is not None else settings.temperature
+     context_window = input_data.context_window if input_data.context_window is not None else settings.context_window
+     threshold_tokens = input_data.threshold_tokens
+     custom_prompt = input_data.custom_prompt
+
    # Verify model availability before processing
    # This handles the case where models might be deleted from Ollama after the app has started
    logger.info(f"Verifying availability of LLM model: '{llm_model}'")
-     settings = get_settings()
-     # Use the provided model or get from settings
-     llm_model = llm_model or settings.llm_model
    llm_model = ensure_ollama_model(llm_model, fallback_model=settings.llm_model)

    # Use the custom prompt if provided as a string parameter, otherwise use the default
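
For reference, a call to the reworked endpoint could look like the sketch below. Because the parameters now come from `RaptorInput` via `Depends()`, they are still passed in the query string alongside the uploaded file. The host, port, file name, and model name are assumptions for illustration, not part of the change.

import requests  # client-side dependency for this example only

with open("chunks.json", "rb") as f:  # JSON file with a top-level "chunks" array
    response = requests.post(
        "http://localhost:8000/raptor/",   # assumed host and port
        params={
            "llm_model": "llama3.1",       # omit to fall back to settings.llm_model
            "temperature": 0.2,            # constrained to the 0.0-1.0 range
            "threshold_tokens": 2048,      # omit to skip chunk optimization
        },
        files={"file": ("chunks.json", f, "application/json")},
    )

response.raise_for_status()
print(response.json())  # hierarchical clustering results with metadata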