diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f59085ac4..a979df6957 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -2,67 +2,60 @@ LightRAG FastAPI Server """ -from fastapi import FastAPI, Depends, HTTPException, Request -from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse -from fastapi.openapi.docs import ( - get_swagger_ui_html, - get_swagger_ui_oauth2_redirect_html, -) -import os +import configparser import logging import logging.config +import os import sys -import uvicorn -import pipmaster as pm -from fastapi.staticfiles import StaticFiles -from fastapi.responses import RedirectResponse +from contextlib import asynccontextmanager from pathlib import Path -import configparser + +import pipmaster as pm +import uvicorn from ascii_colors import ASCIIColors -from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager from dotenv import load_dotenv +from fastapi import Depends, FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import ( + get_swagger_ui_html, + get_swagger_ui_oauth2_redirect_html, +) +from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.security import OAuth2PasswordRequestForm +from fastapi.staticfiles import StaticFiles + +from lightrag import LightRAG +from lightrag import __version__ as core_version +from lightrag.api import __api_version__ +from lightrag.api.auth import auth_handler +from lightrag.api.routers.document_routes import DocumentManager, create_document_routes +from lightrag.api.routers.graph_routes import create_graph_routes +from lightrag.api.routers.ollama_api import OllamaAPI +from lightrag.api.routers.query_routes import create_query_routes from lightrag.api.utils_api import ( - get_combined_auth_dependency, - display_splash_screen, check_env_file, + display_splash_screen, + get_combined_auth_dependency, ) -from .config import ( - global_args, - update_uvicorn_mode_config, - get_default_host, -) -from lightrag.utils import get_env_value -from lightrag import LightRAG, __version__ as core_version -from lightrag.api import __api_version__ -from lightrag.types import GPTKeywordExtractionFormat -from lightrag.utils import EmbeddingFunc from lightrag.constants import ( - DEFAULT_LOG_MAX_BYTES, + DEFAULT_EMBEDDING_TIMEOUT, + DEFAULT_LLM_TIMEOUT, DEFAULT_LOG_BACKUP_COUNT, DEFAULT_LOG_FILENAME, - DEFAULT_LLM_TIMEOUT, - DEFAULT_EMBEDDING_TIMEOUT, -) -from lightrag.api.routers.document_routes import ( - DocumentManager, - create_document_routes, + DEFAULT_LOG_MAX_BYTES, ) -from lightrag.api.routers.query_routes import create_query_routes -from lightrag.api.routers.graph_routes import create_graph_routes -from lightrag.api.routers.ollama_api import OllamaAPI - -from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( - get_namespace_data, - get_default_workspace, - # set_default_workspace, cleanup_keyed_lock, finalize_share_data, + get_default_workspace, + get_namespace_data, + set_default_workspace, ) -from fastapi.security import OAuth2PasswordRequestForm -from lightrag.api.auth import auth_handler +from lightrag.types import GPTKeywordExtractionFormat +from lightrag.utils import EmbeddingFunc, get_env_value, logger, set_verbose_debug + +from .config import get_default_host, global_args, update_uvicorn_mode_config # use the .env that 
is inside the current folder # allows to use different .env file for each lightrag instance @@ -343,16 +336,73 @@ def create_app(args): # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - # Initialize document manager with workspace support for data isolation - doc_manager = DocumentManager(args.input_dir, workspace=args.workspace) + doc_manager_cache = {} - @asynccontextmanager - async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - # Store background tasks - app.state.background_tasks = set() + def create_doc_manager(request: Request | None) -> DocumentManager: + """Create or retrieve DocumentManager for the current workspace""" + workspace = args.workspace + if request is not None: + workspace = get_workspace_from_request(request, args.workspace) + + logger.debug(f"Using DocumentManager for workspace: '{workspace}'") + if workspace in doc_manager_cache: + return doc_manager_cache[workspace] + + doc_manager = DocumentManager(args.input_dir, workspace=workspace) + doc_manager_cache[workspace] = doc_manager + + return doc_manager_cache[workspace] + + rag_cache = {} + + async def create_rag(request: Request | None) -> LightRAG: + """Create or retrieve LightRAG instance for the current workspace""" + workspace = args.workspace + if request is not None: + workspace = get_workspace_from_request(request, args.workspace) + + logger.debug(f"Using LightRAG instance for workspace: '{workspace}'") + if workspace in rag_cache: + return rag_cache[workspace] + + # Create ollama_server_infos from command line arguments + from lightrag.api.config import OllamaServerInfos + ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag) + + # Initialize RAG with unified configuration try: + rag = LightRAG( + working_dir=args.working_dir, + workspace=workspace, + llm_model_func=create_llm_model_func(args.llm_binding), + llm_model_name=args.llm_model, + llm_model_max_async=args.max_async, + summary_max_tokens=args.summary_max_tokens, + summary_context_size=args.summary_context_size, + chunk_token_size=int(args.chunk_size), + chunk_overlap_token_size=int(args.chunk_overlap_size), + llm_model_kwargs=create_llm_model_kwargs(args.llm_binding, args, llm_timeout), + embedding_func=embedding_func, + default_llm_timeout=llm_timeout, + default_embedding_timeout=embedding_timeout, + kv_storage=args.kv_storage, + graph_storage=args.graph_storage, + vector_storage=args.vector_storage, + doc_status_storage=args.doc_status_storage, + vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, + enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, + enable_llm_cache=args.enable_llm_cache, + rerank_model_func=rerank_model_func, + max_parallel_insert=args.max_parallel_insert, + max_graph_nodes=args.max_graph_nodes, + addon_params={ + "language": args.summary_language, + "entity_types": args.entity_types, + }, + ollama_server_infos=ollama_server_infos, + ) + # Initialize database connections # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await rag.initialize_storages() @@ -360,13 +410,30 @@ async def lifespan(app: FastAPI): # Data migration regardless of storage implementation await rag.check_and_migrate_data() + rag_cache[workspace] = rag + return rag + except Exception as e: + logger.error(f"Failed to initialize LightRAG: {e}") + raise + + @asynccontextmanager + async def lifespan(app: FastAPI): 
+ """Lifespan context manager for startup and shutdown events""" + # Store background tasks + app.state.background_tasks = set() + + try: + create_doc_manager(None) # Pre-create default DocumentManager + await create_rag(None) # Pre-create default LightRAG + ASCIIColors.green("\nServer is ready to accept connections! 🚀\n") yield finally: # Clean up database connections - await rag.finalize_storages() + for rag in rag_cache.values(): + await rag.finalize_storages() if "LIGHTRAG_GUNICORN_MODE" not in os.environ: # Only perform cleanup in Uvicorn single-process mode @@ -404,6 +471,7 @@ async def lifespan(app: FastAPI): "tryItOutEnabled": True, } + set_default_workspace(args.workspace) app = FastAPI(**app_kwargs) # Add custom validation error handler for /query/data endpoint @@ -456,7 +524,7 @@ def get_cors_origins(): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) - def get_workspace_from_request(request: Request) -> str | None: + def get_workspace_from_request(request: Request, default: str) -> str: """ Extract workspace from HTTP request header or use default. @@ -474,7 +542,7 @@ def get_workspace_from_request(request: Request) -> str | None: workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() if not workspace: - workspace = None + workspace = default return workspace @@ -1022,66 +1090,19 @@ async def server_rerank_func( else: logger.info("Reranking is disabled") - # Create ollama_server_infos from command line arguments - from lightrag.api.config import OllamaServerInfos - - ollama_server_infos = OllamaServerInfos( - name=args.simulated_model_name, tag=args.simulated_model_tag - ) - - # Initialize RAG with unified configuration - try: - rag = LightRAG( - working_dir=args.working_dir, - workspace=args.workspace, - llm_model_func=create_llm_model_func(args.llm_binding), - llm_model_name=args.llm_model, - llm_model_max_async=args.max_async, - summary_max_tokens=args.summary_max_tokens, - summary_context_size=args.summary_context_size, - chunk_token_size=int(args.chunk_size), - chunk_overlap_token_size=int(args.chunk_overlap_size), - llm_model_kwargs=create_llm_model_kwargs( - args.llm_binding, args, llm_timeout - ), - embedding_func=embedding_func, - default_llm_timeout=llm_timeout, - default_embedding_timeout=embedding_timeout, - kv_storage=args.kv_storage, - graph_storage=args.graph_storage, - vector_storage=args.vector_storage, - doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": args.cosine_threshold - }, - enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, - enable_llm_cache=args.enable_llm_cache, - rerank_model_func=rerank_model_func, - max_parallel_insert=args.max_parallel_insert, - max_graph_nodes=args.max_graph_nodes, - addon_params={ - "language": args.summary_language, - "entity_types": args.entity_types, - }, - ollama_server_infos=ollama_server_infos, - ) - except Exception as e: - logger.error(f"Failed to initialize LightRAG: {e}") - raise - # Add routes app.include_router( create_document_routes( - rag, - doc_manager, + create_rag, + create_doc_manager, api_key, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k)) - app.include_router(create_graph_routes(rag, api_key)) + app.include_router(create_query_routes(create_rag, api_key, args.top_k)) + app.include_router(create_graph_routes(create_rag, api_key)) # Add Ollama API routes - ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) + ollama_api = 
OllamaAPI(create_rag, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix="/api") # Custom Swagger UI endpoint for offline support @@ -1212,10 +1233,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): async def get_status(request: Request): """Get current system status including WebUI availability""" try: - workspace = get_workspace_from_request(request) - default_workspace = get_default_workspace() - if workspace is None: - workspace = default_workspace + workspace = get_workspace_from_request(request, get_default_workspace()) pipeline_status = await get_namespace_data( "pipeline_status", workspace=workspace ) @@ -1250,7 +1268,7 @@ async def get_status(request: Request): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": default_workspace, + "workspace": workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 85183bbd18..b81a401cdb 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -3,29 +3,31 @@ """ import asyncio -from functools import lru_cache -from lightrag.utils import logger, get_pinyin_sort_key -import aiofiles import shutil import traceback from datetime import datetime, timezone -from pathlib import Path -from typing import Dict, List, Optional, Any, Literal +from functools import lru_cache from io import BytesIO +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional + +import aiofiles from fastapi import ( APIRouter, BackgroundTasks, Depends, File, HTTPException, + Request, UploadFile, ) from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG -from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus -from lightrag.utils import generate_track_id from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus +from lightrag.utils import generate_track_id, get_pinyin_sort_key, logger + from ..config import global_args @@ -2029,16 +2031,12 @@ async def background_delete_documents( logger.error(f"Error processing pending documents after deletion: {e}") -def create_document_routes( - rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None -): +def create_document_routes(create_rag, create_doc_manager, api_key: Optional[str] = None): # Create combined auth dependency for document routes combined_auth = get_combined_auth_dependency(api_key) - @router.post( - "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)] - ) - async def scan_for_new_documents(background_tasks: BackgroundTasks): + @router.post("/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]) + async def scan_for_new_documents(raw_request: Request, background_tasks: BackgroundTasks): """ Trigger the scanning process for new documents. 
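The server code above swaps the single module-level rag/doc_manager pair for lazily built, workspace-keyed caches (rag_cache, doc_manager_cache), with the workspace taken from the LIGHTRAG-WORKSPACE header and falling back to the configured default. A minimal sketch of that pattern, using a plain dict and a placeholder payload instead of real LightRAG/DocumentManager objects (names below are illustrative, not the actual API):

```python
from fastapi import Request

# Workspace-keyed cache, mirroring the rag_cache / doc_manager_cache dicts above.
_instances: dict[str, dict] = {}


def resolve_workspace(request: Request | None, default: str) -> str:
    """Take the workspace from the LIGHTRAG-WORKSPACE header, else use the default."""
    if request is None:
        return default
    header = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
    return header or default


async def get_instance(request: Request | None, default_workspace: str = "_") -> dict:
    """Create-on-first-use lookup: one instance per workspace, reused on later requests."""
    workspace = resolve_workspace(request, default_workspace)
    if workspace not in _instances:
        # Stand-in for building LightRAG(...) and awaiting rag.initialize_storages().
        _instances[workspace] = {"workspace": workspace}
    return _instances[workspace]
```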
@@ -2049,6 +2047,9 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): Returns: ScanResponse: A response object containing the scanning status and track_id """ + rag = await create_rag(raw_request) + doc_manager = create_doc_manager(raw_request) + # Generate track_id with "scan" prefix for scanning operation track_id = generate_track_id("scan") @@ -2060,11 +2061,9 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): track_id=track_id, ) - @router.post( - "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] - ) + @router.post("/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]) async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) + raw_request: Request, background_tasks: BackgroundTasks, file: UploadFile = File(...) ): """ Upload a file to the input directory and index it. @@ -2085,6 +2084,9 @@ async def upload_to_input_dir( HTTPException: If the file type is not supported (400) or other errors occur (500). """ try: + rag = await create_rag(raw_request) + doc_manager = create_doc_manager(raw_request) + # Sanitize filename to prevent Path Traversal attacks safe_filename = sanitize_filename(file.filename, doc_manager.input_dir) @@ -2133,12 +2135,8 @@ async def upload_to_input_dir( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.post( - "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] - ) - async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks - ): + @router.post("/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]) + async def insert_text(raw_request: Request, request: InsertTextRequest, background_tasks: BackgroundTasks): """ Insert text into the RAG system. @@ -2156,6 +2154,8 @@ async def insert_text( HTTPException: If an error occurs during text processing (500). """ try: + rag = await create_rag(raw_request) + # Check if file_source already exists in doc_status storage if ( request.file_source @@ -2200,9 +2200,7 @@ async def insert_text( response_model=InsertResponse, dependencies=[Depends(combined_auth)], ) - async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks - ): + async def insert_texts(raw_request: Request, request: InsertTextsRequest, background_tasks: BackgroundTasks): """ Insert multiple texts into the RAG system. @@ -2220,6 +2218,8 @@ async def insert_texts( HTTPException: If an error occurs during text processing (500). """ try: + rag = await create_rag(raw_request) + # Check if any file_sources already exist in doc_status storage if request.file_sources: for file_source in request.file_sources: @@ -2261,10 +2261,8 @@ async def insert_texts( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.delete( - "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)] - ) - async def clear_documents(): + @router.delete("", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]) + async def clear_documents(raw_request: Request): """ Clear all documents from the RAG system. @@ -2285,46 +2283,49 @@ async def clear_documents(): HTTPException: Raised when a serious error occurs during the clearing process, with status code 500 and error details in the detail field. 
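Downstream, create_document_routes (and the other route factories) no longer receive a ready-made rag/doc_manager; each handler declares raw_request: Request and resolves its workspace-scoped objects on entry. A compact sketch of that shape, with a hypothetical /ping route standing in for the real endpoints:

```python
from typing import Awaitable, Callable

from fastapi import APIRouter, FastAPI, Request


async def demo_create_rag(request: Request | None) -> dict:
    """Hypothetical async factory; the real one returns a workspace-scoped LightRAG."""
    workspace = "_" if request is None else request.headers.get("LIGHTRAG-WORKSPACE", "_")
    return {"workspace": workspace}


def create_demo_routes(create_rag: Callable[[Request | None], Awaitable[dict]]) -> APIRouter:
    """Router factory that takes the async factory instead of a ready-made instance."""
    router = APIRouter(tags=["demo"])

    @router.get("/ping")
    async def ping(raw_request: Request) -> dict:
        rag = await create_rag(raw_request)  # resolved per request, per workspace
        return {"workspace": rag["workspace"]}

    return router


app = FastAPI()
app.include_router(create_demo_routes(demo_create_rag))
```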
""" - from lightrag.kg.shared_storage import ( - get_namespace_data, - get_namespace_lock, - ) + try: + rag = await create_rag(raw_request) + doc_manager = create_doc_manager(raw_request) - # Get pipeline status and lock - pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace - ) - pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace - ) + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_namespace_lock, + ) - # Check and set status with lock - async with pipeline_status_lock: - if pipeline_status.get("busy", False): - return ClearDocumentsResponse( - status="busy", - message="Cannot clear documents while pipeline is busy", - ) - # Set busy to true - pipeline_status.update( - { - "busy": True, - "job_name": "Clearing Documents", - "job_start": datetime.now().isoformat(), - "docs": 0, - "batchs": 0, - "cur_batch": 0, - "request_pending": False, # Clear any previous request - "latest_message": "Starting document clearing process", - } + # Get pipeline status and lock + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace ) - # Cleaning history_messages without breaking it as a shared list object - del pipeline_status["history_messages"][:] - pipeline_status["history_messages"].append( - "Starting document clearing process" + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - try: + # Check and set status with lock + async with pipeline_status_lock: + if pipeline_status.get("busy", False): + return ClearDocumentsResponse( + status="busy", + message="Cannot clear documents while pipeline is busy", + ) + # Set busy to true + pipeline_status.update( + { + "busy": True, + "job_name": "Clearing Documents", + "job_start": datetime.now().isoformat(), + "docs": 0, + "batchs": 0, + "cur_batch": 0, + "request_pending": False, # Clear any previous request + "latest_message": "Starting document clearing process", + } + ) + # Cleaning history_messages without breaking it as a shared list object + del pipeline_status["history_messages"][:] + pipeline_status["history_messages"].append( + "Starting document clearing process" + ) + # Use drop method to clear all data drop_tasks = [] storages = [ @@ -2460,7 +2461,7 @@ async def clear_documents(): dependencies=[Depends(combined_auth)], response_model=PipelineStatusResponse, ) - async def get_pipeline_status() -> PipelineStatusResponse: + async def get_pipeline_status(raw_request: Request) -> PipelineStatusResponse: """ Get the current status of the document indexing pipeline. 
@@ -2485,10 +2486,12 @@ async def get_pipeline_status() -> PipelineStatusResponse: HTTPException: If an error occurs while retrieving pipeline status (500) """ try: + rag = await create_rag(raw_request) + from lightrag.kg.shared_storage import ( + get_all_update_flags_status, get_namespace_data, get_namespace_lock, - get_all_update_flags_status, ) pipeline_status = await get_namespace_data( @@ -2556,10 +2559,8 @@ async def get_pipeline_status() -> PipelineStatusResponse: raise HTTPException(status_code=500, detail=str(e)) # TODO: Deprecated, use /documents/paginated instead - @router.get( - "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] - ) - async def documents() -> DocsStatusesResponse: + @router.get("", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]) + async def documents(raw_request: Request) -> DocsStatusesResponse: """ Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead. To prevent excessive resource consumption, a maximum of 1,000 records is returned. @@ -2578,6 +2579,8 @@ async def documents() -> DocsStatusesResponse: HTTPException: If an error occurs while retrieving document statuses (500). """ try: + rag = await create_rag(raw_request) + statuses = ( DocStatus.PENDING, DocStatus.PROCESSING, @@ -2673,6 +2676,7 @@ class DeleteDocByIdResponse(BaseModel): summary="Delete a document and all its associated data by its ID.", ) async def delete_document( + raw_request: Request, delete_request: DeleteDocRequest, background_tasks: BackgroundTasks, ) -> DeleteDocByIdResponse: @@ -2699,9 +2703,12 @@ async def delete_document( HTTPException: - 500: If an unexpected internal error occurs during initialization. """ - doc_ids = delete_request.doc_ids - try: + rag = await create_rag(raw_request) + doc_manager = create_doc_manager(raw_request) + + doc_ids = delete_request.doc_ids + from lightrag.kg.shared_storage import ( get_namespace_data, get_namespace_lock, @@ -2750,7 +2757,7 @@ async def delete_document( response_model=ClearCacheResponse, dependencies=[Depends(combined_auth)], ) - async def clear_cache(request: ClearCacheRequest): + async def clear_cache(raw_request: Request, request: ClearCacheRequest): """ Clear all cache data from the LLM response cache storage. @@ -2767,6 +2774,7 @@ async def clear_cache(request: ClearCacheRequest): HTTPException: If an error occurs during cache clearing (500). """ try: + rag = await create_rag(raw_request) # Call the aclear_cache method (no modes parameter) await rag.aclear_cache() @@ -2784,7 +2792,7 @@ async def clear_cache(request: ClearCacheRequest): response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_entity(request: DeleteEntityRequest): + async def delete_entity(raw_request: Request, request: DeleteEntityRequest): """ Delete an entity and all its relationships from the knowledge graph. @@ -2798,6 +2806,8 @@ async def delete_entity(request: DeleteEntityRequest): HTTPException: If the entity is not found (404) or an error occurs (500). 
""" try: + rag = await create_rag(raw_request) + result = await rag.adelete_by_entity(entity_name=request.entity_name) if result.status == "not_found": raise HTTPException(status_code=404, detail=result.message) @@ -2819,7 +2829,7 @@ async def delete_entity(request: DeleteEntityRequest): response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_relation(request: DeleteRelationRequest): + async def delete_relation(raw_request: Request, request: DeleteRelationRequest): """ Delete a relationship between two entities from the knowledge graph. @@ -2833,6 +2843,8 @@ async def delete_relation(request: DeleteRelationRequest): HTTPException: If the relation is not found (404) or an error occurs (500). """ try: + rag = await create_rag(raw_request) + result = await rag.adelete_by_relation( source_entity=request.source_entity, target_entity=request.target_entity, @@ -2857,7 +2869,7 @@ async def delete_relation(request: DeleteRelationRequest): response_model=TrackStatusResponse, dependencies=[Depends(combined_auth)], ) - async def get_track_status(track_id: str) -> TrackStatusResponse: + async def get_track_status(raw_request: Request, track_id: str) -> TrackStatusResponse: """ Get the processing status of documents by tracking ID. @@ -2877,6 +2889,8 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: HTTPException: If track_id is invalid (400) or an error occurs (500). """ try: + rag = await create_rag(raw_request) + # Validate track_id if not track_id or not track_id.strip(): raise HTTPException(status_code=400, detail="Track ID cannot be empty") @@ -2932,6 +2946,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: dependencies=[Depends(combined_auth)], ) async def get_documents_paginated( + raw_request: Request, request: DocumentsRequest, ) -> PaginatedDocsResponse: """ @@ -2954,6 +2969,8 @@ async def get_documents_paginated( HTTPException: If an error occurs while retrieving documents (500). """ try: + rag = await create_rag(raw_request) + # Get paginated documents and status counts in parallel docs_task = rag.doc_status.get_docs_paginated( status_filter=request.status_filter, @@ -3018,7 +3035,7 @@ async def get_documents_paginated( response_model=StatusCountsResponse, dependencies=[Depends(combined_auth)], ) - async def get_document_status_counts() -> StatusCountsResponse: + async def get_document_status_counts(raw_request: Request) -> StatusCountsResponse: """ Get counts of documents by status. @@ -3032,6 +3049,8 @@ async def get_document_status_counts() -> StatusCountsResponse: HTTPException: If an error occurs while retrieving status counts (500). """ try: + rag = await create_rag(raw_request) + status_counts = await rag.doc_status.get_all_status_counts() return StatusCountsResponse(status_counts=status_counts) @@ -3045,7 +3064,7 @@ async def get_document_status_counts() -> StatusCountsResponse: response_model=ReprocessResponse, dependencies=[Depends(combined_auth)], ) - async def reprocess_failed_documents(background_tasks: BackgroundTasks): + async def reprocess_failed_documents(raw_request: Request, background_tasks: BackgroundTasks): """ Reprocess failed and pending documents. @@ -3068,6 +3087,8 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): HTTPException: If an error occurs while initiating reprocessing (500). 
""" try: + rag = await create_rag(raw_request) + # Generate track_id with "retry" prefix for retry operation track_id = generate_track_id("retry") @@ -3093,7 +3114,7 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): response_model=CancelPipelineResponse, dependencies=[Depends(combined_auth)], ) - async def cancel_pipeline(): + async def cancel_pipeline(raw_request: Request): """ Request cancellation of the currently running pipeline. @@ -3115,6 +3136,8 @@ async def cancel_pipeline(): HTTPException: If an error occurs while setting cancellation flag (500). """ try: + rag = await create_rag(raw_request) + from lightrag.kg.shared_storage import ( get_namespace_data, get_namespace_lock, diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e892ff011c..51cd795b8d 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -2,12 +2,14 @@ This module contains all graph-related routes for the LightRAG API. """ -from typing import Optional, Dict, Any import traceback -from fastapi import APIRouter, Depends, Query, HTTPException +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel, Field from lightrag.utils import logger + from ..utils_api import get_combined_auth_dependency router = APIRouter(tags=["graph"]) @@ -86,11 +88,11 @@ class RelationCreateRequest(BaseModel): ) -def create_graph_routes(rag, api_key: Optional[str] = None): +def create_graph_routes(create_rag, api_key: Optional[str] = None): combined_auth = get_combined_auth_dependency(api_key) @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) - async def get_graph_labels(): + async def get_graph_labels(raw_request: Request): """ Get all graph labels @@ -98,6 +100,8 @@ async def get_graph_labels(): List[str]: List of graph labels """ try: + rag = await create_rag(raw_request) + return await rag.get_graph_labels() except Exception as e: logger.error(f"Error getting graph labels: {str(e)}") @@ -108,9 +112,8 @@ async def get_graph_labels(): @router.get("/graph/label/popular", dependencies=[Depends(combined_auth)]) async def get_popular_labels( - limit: int = Query( - 300, description="Maximum number of popular labels to return", ge=1, le=1000 - ), + raw_request: Request, + limit: int = Query(300, description="Maximum number of popular labels to return", ge=1, le=1000), ): """ Get popular labels by node degree (most connected entities) @@ -122,6 +125,8 @@ async def get_popular_labels( List[str]: List of popular labels sorted by degree (highest first) """ try: + rag = await create_rag(raw_request) + return await rag.chunk_entity_relation_graph.get_popular_labels(limit) except Exception as e: logger.error(f"Error getting popular labels: {str(e)}") @@ -132,6 +137,7 @@ async def get_popular_labels( @router.get("/graph/label/search", dependencies=[Depends(combined_auth)]) async def search_labels( + raw_request: Request, q: str = Query(..., description="Search query string"), limit: int = Query( 50, description="Maximum number of search results to return", ge=1, le=100 @@ -148,6 +154,8 @@ async def search_labels( List[str]: List of matching labels sorted by relevance """ try: + rag = await create_rag(raw_request) + return await rag.chunk_entity_relation_graph.search_labels(q, limit) except Exception as e: logger.error(f"Error searching labels with query '{q}': {str(e)}") @@ -158,6 +166,7 @@ async def search_labels( @router.get("/graphs", 
dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( + raw_request: Request, label: str = Query(..., description="Label to get knowledge graph for"), max_depth: int = Query(3, description="Maximum depth of graph", ge=1), max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), @@ -177,6 +186,8 @@ async def get_knowledge_graph( Dict[str, List[str]]: Knowledge graph for label """ try: + rag = await create_rag(raw_request) + # Log the label parameter to check for leading spaces logger.debug( f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})" @@ -196,6 +207,7 @@ async def get_knowledge_graph( @router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)]) async def check_entity_exists( + raw_request: Request, name: str = Query(..., description="Entity name to check"), ): """ @@ -208,6 +220,8 @@ async def check_entity_exists( Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists """ try: + rag = await create_rag(raw_request) + exists = await rag.chunk_entity_relation_graph.has_node(name) return {"exists": exists} except Exception as e: @@ -218,7 +232,7 @@ async def check_entity_exists( ) @router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)]) - async def update_entity(request: EntityUpdateRequest): + async def update_entity(raw_request: Request, request: EntityUpdateRequest): """ Update an entity's properties in the knowledge graph @@ -353,6 +367,8 @@ async def update_entity(request: EntityUpdateRequest): } """ try: + rag = await create_rag(raw_request) + result = await rag.aedit_entity( entity_name=request.entity_name, updated_data=request.updated_data, @@ -408,7 +424,7 @@ async def update_entity(request: EntityUpdateRequest): ) @router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)]) - async def update_relation(request: RelationUpdateRequest): + async def update_relation(raw_request: Request, request: RelationUpdateRequest): """Update a relation's properties in the knowledge graph Args: @@ -418,6 +434,8 @@ async def update_relation(request: RelationUpdateRequest): Dict: Updated relation information """ try: + rag = await create_rag(raw_request) + result = await rag.aedit_relation( source_entity=request.source_id, target_entity=request.target_id, @@ -443,7 +461,7 @@ async def update_relation(request: RelationUpdateRequest): ) @router.post("/graph/entity/create", dependencies=[Depends(combined_auth)]) - async def create_entity(request: EntityCreateRequest): + async def create_entity(raw_request: Request, request: EntityCreateRequest): """ Create a new entity in the knowledge graph @@ -488,6 +506,8 @@ async def create_entity(request: EntityCreateRequest): } """ try: + rag = await create_rag(raw_request) + # Use the proper acreate_entity method which handles: # - Graph lock for concurrency # - Vector embedding creation in entities_vdb @@ -516,7 +536,7 @@ async def create_entity(request: EntityCreateRequest): ) @router.post("/graph/relation/create", dependencies=[Depends(combined_auth)]) - async def create_relation(request: RelationCreateRequest): + async def create_relation(raw_request: Request, request: RelationCreateRequest): """ Create a new relationship between two entities in the knowledge graph @@ -573,6 +593,8 @@ async def create_relation(request: RelationCreateRequest): } """ try: + rag = await create_rag(raw_request) + # Use the proper acreate_relation method which handles: # - Graph lock for concurrency # - Entity existence validation @@ 
-605,7 +627,7 @@ async def create_relation(request: RelationCreateRequest): ) @router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)]) - async def merge_entities(request: EntityMergeRequest): + async def merge_entities(raw_request: Request, request: EntityMergeRequest): """ Merge multiple entities into a single entity, preserving all relationships @@ -662,6 +684,8 @@ async def merge_entities(request: EntityMergeRequest): - This operation cannot be undone, so verify entity names before merging """ try: + rag = await create_rag(raw_request) + result = await rag.amerge_entities( source_entities=request.entities_to_change, target_entity=request.entity_to_change_into, diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 15c695cee7..7163dfd4d9 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -1,17 +1,17 @@ -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel -from typing import List, Dict, Any, Optional, Type -from lightrag.utils import logger -import time +import asyncio import json import re +import time from enum import Enum +from typing import Any, Dict, List, Optional, Type + +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse -import asyncio +from pydantic import BaseModel + from lightrag import LightRAG, QueryParam -from lightrag.utils import TiktokenTokenizer from lightrag.api.utils_api import get_combined_auth_dependency -from fastapi import Depends +from lightrag.utils import TiktokenTokenizer, logger # query mode according to query prefix (bypass is not LightRAG quer mode) @@ -117,9 +117,7 @@ class OllamaPsResponse(BaseModel): models: List[OllamaRunningModel] -async def parse_request_body( - request: Request, model_class: Type[BaseModel] -) -> BaseModel: +async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel: """ Parse request body based on Content-Type header. Supports both application/json and application/octet-stream. 
@@ -151,9 +149,7 @@ async def parse_request_body( except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid JSON in request body") except Exception as e: - raise HTTPException( - status_code=400, detail=f"Error parsing request body: {str(e)}" - ) + raise HTTPException(status_code=400, detail=f"Error parsing request body: {str(e)}") def estimate_tokens(text: str) -> int: @@ -218,9 +214,8 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]: class OllamaAPI: - def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): - self.rag = rag - self.ollama_server_infos = rag.ollama_server_infos + def __init__(self, create_rag, top_k: int = 60, api_key: Optional[str] = None): + self.create_rag = create_rag self.top_k = top_k self.api_key = api_key self.router = APIRouter(tags=["ollama"]) @@ -236,21 +231,24 @@ async def get_version(): return OllamaVersionResponse(version="0.9.3") @self.router.get("/tags", dependencies=[Depends(combined_auth)]) - async def get_tags(): + async def get_tags(raw_request: Request): """Return available models acting as an Ollama server""" + rag = await self.create_rag(raw_request) + ollama_server_infos = rag.ollama_server_infos + return OllamaTagResponse( models=[ { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "name": ollama_server_infos.LIGHTRAG_MODEL, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "size": ollama_server_infos.LIGHTRAG_SIZE, + "digest": ollama_server_infos.LIGHTRAG_DIGEST, "details": { "parent_model": "", "format": "gguf", - "family": self.ollama_server_infos.LIGHTRAG_NAME, - "families": [self.ollama_server_infos.LIGHTRAG_NAME], + "family": ollama_server_infos.LIGHTRAG_NAME, + "families": [ollama_server_infos.LIGHTRAG_NAME], "parameter_size": "13B", "quantization_level": "Q4_0", }, @@ -259,15 +257,18 @@ async def get_tags(): ) @self.router.get("/ps", dependencies=[Depends(combined_auth)]) - async def get_running_models(): + async def get_running_models(raw_request: Request): """List Running Models - returns currently running models""" + rag = await self.create_rag(raw_request) + ollama_server_infos = rag.ollama_server_infos + return OllamaPsResponse( models=[ { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "name": ollama_server_infos.LIGHTRAG_MODEL, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "size": ollama_server_infos.LIGHTRAG_SIZE, + "digest": ollama_server_infos.LIGHTRAG_DIGEST, "details": { "parent_model": "", "format": "gguf", @@ -277,14 +278,12 @@ async def get_running_models(): "quantization_level": "Q4_0", }, "expires_at": "2050-12-31T14:38:31.83753-07:00", - "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE, + "size_vram": ollama_server_infos.LIGHTRAG_SIZE, } ] ) - @self.router.post( - "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True - ) + @self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True) async def generate(raw_request: Request): """Handle generate completion requests acting as an Ollama model For compatibility purpose, the request is not processed by 
LightRAG, @@ -292,6 +291,9 @@ async def generate(raw_request: Request): Supports both application/json and application/octet-stream Content-Types. """ try: + rag = await self.create_rag(raw_request) + ollama_server_infos = rag.ollama_server_infos + # Parse the request body manually request = await parse_request_body(raw_request, OllamaGenerateRequest) @@ -300,12 +302,10 @@ async def generate(raw_request: Request): prompt_tokens = estimate_tokens(query) if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system + rag.llm_model_kwargs["system_prompt"] = request.system if request.stream: - response = await self.rag.llm_model_func( - query, stream=True, **self.rag.llm_model_kwargs - ) + response = await rag.llm_model_func(query, stream=True, **rag.llm_model_kwargs) async def stream_generator(): first_chunk_time = None @@ -320,8 +320,8 @@ async def stream_generator(): total_response = response data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": response, "done": False, } @@ -333,8 +333,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, "done_reason": "stop", @@ -358,8 +358,8 @@ async def stream_generator(): total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": chunk, "done": False, } @@ -375,8 +375,8 @@ async def stream_generator(): # Send error message to client error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": f"\n\nError: {error_msg}", "error": f"\n\nError: {error_msg}", "done": False, @@ -385,8 +385,8 @@ async def stream_generator(): # Send final message to close the stream final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, } @@ -400,8 +400,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, "done_reason": "stop", @@ -428,9 +428,7 @@ async def stream_generator(): ) else: first_chunk_time = time.time_ns() - response_text = await self.rag.llm_model_func( - query, stream=False, **self.rag.llm_model_kwargs - ) + response_text = await rag.llm_model_func(query, stream=False, **rag.llm_model_kwargs) last_chunk_time = time.time_ns() if not response_text: @@ -442,8 +440,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": 
self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "response": str(response_text), "done": True, "done_reason": "stop", @@ -468,7 +466,11 @@ async def chat(raw_request: Request): Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. Supports both application/json and application/octet-stream Content-Types. """ + try: + rag = await self.create_rag(raw_request) + ollama_server_infos = rag.ollama_server_infos + # Parse the request body manually request = await parse_request_body(raw_request, OllamaChatRequest) @@ -516,15 +518,15 @@ async def chat(raw_request: Request): # Determine if the request is prefix with "/bypass" if mode == SearchMode.bypass: if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - response = await self.rag.llm_model_func( + rag.llm_model_kwargs["system_prompt"] = request.system + response = await rag.llm_model_func( cleaned_query, stream=True, history_messages=conversation_history, - **self.rag.llm_model_kwargs, + **rag.llm_model_kwargs, ) else: - response = await self.rag.aquery( + response = await rag.aquery( cleaned_query, param=query_param ) @@ -541,8 +543,8 @@ async def stream_generator(): total_response = response data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": response, @@ -558,8 +560,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -586,8 +588,8 @@ async def stream_generator(): total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": chunk, @@ -607,8 +609,8 @@ async def stream_generator(): # Send error message to client error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": f"\n\nError: {error_msg}", @@ -621,8 +623,8 @@ async def stream_generator(): # Send final message to close the stream final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -641,8 +643,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -678,18 +680,16 @@ async def stream_generator(): ) if match_result or mode 
== SearchMode.bypass: if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system + rag.llm_model_kwargs["system_prompt"] = request.system - response_text = await self.rag.llm_model_func( + response_text = await rag.llm_model_func( cleaned_query, stream=False, history_messages=conversation_history, - **self.rag.llm_model_kwargs, + **rag.llm_model_kwargs, ) else: - response_text = await self.rag.aquery( - cleaned_query, param=query_param - ) + response_text = await rag.aquery(cleaned_query, param=query_param) last_chunk_time = time.time_ns() @@ -702,8 +702,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": str(response_text), diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 99a799c182..90c084b6d2 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -4,11 +4,13 @@ import json from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException -from lightrag.base import QueryParam + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel, Field, field_validator + from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.base import QueryParam from lightrag.utils import logger -from pydantic import BaseModel, Field, field_validator router = APIRouter(tags=["query"]) @@ -190,7 +192,7 @@ class StreamChunkResponse(BaseModel): ) -def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): +def create_query_routes(create_rag, api_key: Optional[str] = None, top_k: int = 60): combined_auth = get_combined_auth_dependency(api_key) @router.post( @@ -322,7 +324,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, }, ) - async def query_text(request: QueryRequest): + async def query_text(raw_request: Request, request: QueryRequest): """ Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. @@ -402,6 +404,8 @@ async def query_text(request: QueryRequest): - 500: Internal processing error (e.g., LLM service unavailable) """ try: + rag = await create_rag(raw_request) + param = request.to_query_params( False ) # Ensure stream=False for non-streaming endpoint @@ -532,7 +536,7 @@ async def query_text(request: QueryRequest): }, }, ) - async def query_text_stream(request: QueryRequest): + async def query_text_stream(raw_request: Request, request: QueryRequest): """ Advanced RAG query endpoint with flexible streaming response. @@ -660,6 +664,8 @@ async def query_text_stream(request: QueryRequest): Use streaming mode for real-time interfaces and non-streaming for batch processing. """ try: + rag = await create_rag(raw_request) + # Use the stream parameter from the request, defaulting to True if not specified stream_mode = request.stream if request.stream is not None else True param = request.to_query_params(stream_mode) @@ -1035,7 +1041,7 @@ async def stream_generator(): }, }, ) - async def query_data(request: QueryRequest): + async def query_data(raw_request: Request, request: QueryRequest): """ Advanced data retrieval endpoint for structured RAG analysis. 
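The OllamaAPI changes above follow the same rule as the function-style route factories: the class keeps the create_rag callable instead of a bound rag instance, and reads ollama_server_infos from the per-workspace instance inside each handler rather than caching it at construction time. A rough sketch of that class shape with placeholder types (not the real OllamaAPI surface):

```python
from typing import Awaitable, Callable

from fastapi import APIRouter, Request


class DemoOllamaStyleAPI:
    """Route group that resolves its backing instance per request instead of at init."""

    def __init__(self, create_rag: Callable[[Request | None], Awaitable[dict]]):
        self.create_rag = create_rag  # keep the factory, not a bound instance
        self.router = APIRouter(tags=["demo-ollama"])
        self.setup_routes()

    def setup_routes(self) -> None:
        @self.router.get("/tags")
        async def get_tags(raw_request: Request) -> dict:
            rag = await self.create_rag(raw_request)
            # Per-workspace metadata is read from the resolved instance on every call.
            return {"models": [{"name": rag.get("model_name", "lightrag:latest")}]}
```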
@@ -1139,6 +1145,8 @@ async def query_data(request: QueryRequest): as structured data analysis typically requires source attribution. """ try: + rag = await create_rag(raw_request) + param = request.to_query_params(False) # No streaming for data endpoint response = await rag.aquery_data(request.query, param=param) @@ -1151,6 +1159,7 @@ async def query_data(request: QueryRequest): status="failure", message="Invalid response type", data={}, + metadata={}, ) except Exception as e: logger.error(f"Error processing data query: {str(e)}", exc_info=True) diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index 7cf1aec6e0..33218a3425 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -289,6 +289,7 @@ const axiosInstance = axios.create({ axiosInstance.interceptors.request.use((config) => { const apiKey = useSettingsStore.getState().apiKey const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); + const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE'); // Always include token if it exists, regardless of path if (token) { @@ -297,6 +298,9 @@ axiosInstance.interceptors.request.use((config) => { if (apiKey) { config.headers['X-API-Key'] = apiKey } + if (workspace) { + config.headers['LIGHTRAG-WORKSPACE'] = workspace + } return config }) @@ -397,6 +401,7 @@ export const queryTextStream = async ( ) => { const apiKey = useSettingsStore.getState().apiKey; const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); + const workspace = localStorage.getItem('LIGHTRAG-WORKSPACE'); const headers: HeadersInit = { 'Content-Type': 'application/json', 'Accept': 'application/x-ndjson', @@ -407,6 +412,9 @@ export const queryTextStream = async ( if (apiKey) { headers['X-API-Key'] = apiKey; } + if (workspace) { + headers['LIGHTRAG-WORKSPACE'] = workspace; + } try { const response = await fetch(`${backendBaseUrl}/query/stream`, { diff --git a/lightrag_webui/src/features/SiteHeader.tsx b/lightrag_webui/src/features/SiteHeader.tsx index dbea38bd3e..14d425ec46 100644 --- a/lightrag_webui/src/features/SiteHeader.tsx +++ b/lightrag_webui/src/features/SiteHeader.tsx @@ -7,8 +7,10 @@ import { useAuthStore } from '@/stores/state' import { cn } from '@/lib/utils' import { useTranslation } from 'react-i18next' import { navigationService } from '@/services/navigation' -import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react' +import { ZapIcon, GithubIcon, LogOutIcon, CheckIcon } from 'lucide-react' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip' +import {useState, useEffect} from "react"; +import {toast} from 'sonner'; interface NavigationTabProps { value: string @@ -57,6 +59,7 @@ function TabsNavigation() { export default function SiteHeader() { const { t } = useTranslation() const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore() + const [workspace, setWorkspace] = useState(''); const versionDisplay = (coreVersion && apiVersion) ? `${coreVersion}/${apiVersion}` @@ -72,6 +75,26 @@ export default function SiteHeader() { navigationService.navigateToLogin(); } + useEffect(() => { + const ws = localStorage.getItem('LIGHTRAG-WORKSPACE') || ''; + setWorkspace(ws); + }, []); + + const handleWorkspaceUpdate = () => { + const trimed = workspace.trim(); + if (trimed) { + localStorage.setItem('LIGHTRAG-WORKSPACE', trimed); + toast.success(t('Workspace set. 
Reloading page...')); + } else { + localStorage.removeItem('LIGHTRAG-WORKSPACE'); + toast.success(t('Workspace cleared. Reloading page...')); + } + + setTimeout(() => { + window.location.reload(); + }, 500); + } + return (
@@ -111,6 +134,10 @@ export default function SiteHeader() {
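On the client side, the WebUI now persists the chosen workspace in localStorage and attaches it as a LIGHTRAG-WORKSPACE header on every axios/fetch call, which is what routes each request to the matching server-side instance. Any HTTP client can do the same; a small sketch with Python's requests, where the base URL, API key, and payload shape are placeholders for a locally running server:

```python
import requests

BASE_URL = "http://localhost:9621"  # placeholder; use your server's address
HEADERS = {
    "LIGHTRAG-WORKSPACE": "team_a",  # selects the per-workspace instance on the server
    "X-API-Key": "your-api-key",     # only needed if the server was started with an API key
}

# Query inside workspace "team_a"; omit the header to fall back to the server default.
resp = requests.post(
    f"{BASE_URL}/query",
    json={"query": "What documents are indexed?", "mode": "hybrid"},
    headers=HEADERS,
    timeout=60,
)
print(resp.json())
```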