diff --git a/.gitignore b/.gitignore index 5092b7b542..538298b0f2 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,8 @@ LightRAG.pdf download_models_hf.py lightrag-dev/ gui/ +/md +/uv.lock # unit-test files test_* diff --git a/Metadata_Filtering.md b/Metadata_Filtering.md new file mode 100644 index 0000000000..3d0236a758 --- /dev/null +++ b/Metadata_Filtering.md @@ -0,0 +1,481 @@ +# Metadata Filtering in LightRAG + +## Overview + +LightRAG supports metadata filtering during queries to retrieve only relevant chunks based on metadata criteria. + +**Important Limitations**: +- Metadata filtering is **only supported for PostgreSQL (PGVectorStorage), with metadata insertion also visible on Neo4j** +- Only **chunk-based queries** support metadata filtering (Mix and Naive modes) +- Metadata is stored in document status and propagated to chunks during extraction + +## Metadata Structure + +Metadata is stored as a dictionary (`dict[str, Any]`) in: +- Entity nodes (graph storage) +- Relationship edges (graph storage) +- Text chunks (KV storage) +- Vector embeddings (vector storage) + +```python +metadata = { + "author": "John Doe", + "department": "Engineering", + "document_type": "technical_spec", + "version": "1.0" +} +``` + +## Critical: Metadata Persistence in Document Status + +**Metadata is stored in DocProcessingStatus** - This ensures metadata is not lost if the processing queue is stopped or interrupted. + +### How It Works + +1. **Document Status Storage** (`lightrag/base.py` - `DocProcessingStatus`) + ```python + @dataclass + class DocProcessingStatus: + # ... other fields + metadata: dict[str, Any] = field(default_factory=dict) + """Additional metadata - PERSISTED across queue restarts""" + ``` + +2. **Metadata Flow**: + - Metadata stored in `DocProcessingStatus.metadata` when document is enqueued + - If queue stops, metadata persists in document status storage + - When processing resumes, metadata is read from document status + - Metadata is propagated to chunks during extraction + +3. 
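**Attaching Metadata at Insertion** (`lightrag/lightrag.py` - `apipeline_enqueue_documents`): metadata is supplied when documents are enqueued. A minimal sketch of this step (inside an async context, with an already-initialized `rag` instance; argument values are illustrative):
   ```python
   # Enqueue a document together with its metadata (persisted in DocProcessingStatus)
   await rag.apipeline_enqueue_documents(
       "Quarterly engineering report ...",
       file_paths="eng_report.txt",
       metadata={"department": "Engineering", "year": 2024},
   )
   # Process the queue; the metadata is copied onto every chunk during extraction
   await rag.apipeline_process_enqueue_documents()
   ```

4. 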
**Why This Matters**: + - Queue can be stopped/restarted without losing metadata + - Metadata survives system crashes or interruptions + - Ensures data consistency across processing pipeline + +## Metadata Filtering During Queries + +### MetadataFilter Class + +```python +from lightrag.types import MetadataFilter + +# Simple filter +filter1 = MetadataFilter( + operator="AND", + operands=[{"department": "Engineering"}] +) + +# Complex filter with OR +filter2 = MetadataFilter( + operator="OR", + operands=[ + {"author": "John Doe"}, + {"author": "Jane Smith"} + ] +) + +# Nested filter +filter3 = MetadataFilter( + operator="AND", + operands=[ + {"document_type": "technical_spec"}, + MetadataFilter( + operator="OR", + operands=[ + {"version": "1.0"}, + {"version": "2.0"} + ] + ) + ] +) +``` + +### Supported Operators + +- **AND**: All conditions must be true +- **OR**: At least one condition must be true +- **NOT**: Negates the condition + +## Supported Query Modes + +### Mix Mode (Recommended) +Filters vector chunks from both KG and direct vector search: +```python +query_param = QueryParam( + mode="mix", + metadata_filter=MetadataFilter( + operator="AND", + operands=[ + {"department": "Engineering"}, + {"status": "approved"} + ] + ) +) +``` + +### Naive Mode +Filters vector chunks directly: +```python +query_param = QueryParam( + mode="naive", + metadata_filter=MetadataFilter( + operator="AND", + operands=[{"document_type": "manual"}] + ) +) +``` + +## Implementation Details + +### Architecture Flow + +1. **API Layer** (`lightrag/api/routers/query_routes.py`) + - REST endpoint receives `metadata_filter` as JSON dict + - Converts JSON to `MetadataFilter` object using `MetadataFilter.from_dict()` + +2. **QueryParam** (`lightrag/base.py`) + - `MetadataFilter` object is passed into `QueryParam.metadata_filter` + - QueryParam carries the filter through the query pipeline + +3. **Query Execution** (`lightrag/operate.py`) + - Only chunk-based queries use the filter: + - Line 2749: `chunks_vdb.query(..., metadata_filter=query_param.metadata_filter)` (Mix/Naive modes) + +4. 
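**Vector Store Call** (`lightrag/base.py` - `BaseVectorStorage.query`): the filter travels as an extra keyword argument of the vector storage `query` method. A sketch of the call shape, with illustrative variable names (`chunks_vdb`, `query_text`):
   ```python
   # metadata_filter is None when no filtering was requested
   results = await chunks_vdb.query(
       query_text,
       top_k=query_param.top_k,
       metadata_filter=query_param.metadata_filter,
   )
   ```

5. 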
**Storage Layer** (`lightrag/kg/postgres_impl.py`) + - PGVectorStorage: Converts filter to SQL WHERE clause with JSONB operators + +### Code Locations + +Key files implementing metadata support: +- `lightrag/types.py`: `MetadataFilter` class definition +- `lightrag/base.py`: `QueryParam` with `metadata_filter` field, `DocProcessingStatus` with metadata persistence +- `lightrag/api/routers/query_routes.py`: API endpoint that initializes MetadataFilter from JSON +- `lightrag/operate.py`: Query functions that pass filter to storage (Line 2749) +- `lightrag/kg/postgres_impl.py`: PostgreSQL JSONB filter implementation + +## Query Examples + +### Example 1: Filter by Department (Mix Mode) +```python +from lightrag import QueryParam +from lightrag.types import MetadataFilter + +query_param = QueryParam( + mode="mix", + metadata_filter=MetadataFilter( + operator="AND", + operands=[{"department": "Engineering"}] + ) +) + +response = rag.query("What are the key projects?", param=query_param) +``` + +### Example 2: Multi-tenant Filtering (Naive Mode) +```python +query_param = QueryParam( + mode="naive", + metadata_filter=MetadataFilter( + operator="AND", + operands=[ + {"tenant_id": "tenant_a"}, + {"access_level": "admin"} + ] + ) +) + +response = rag.query("Show admin resources", param=query_param) +``` + +### Example 3: Version Filtering (Mix Mode) +```python +query_param = QueryParam( + mode="mix", + metadata_filter=MetadataFilter( + operator="AND", + operands=[ + {"doc_type": "manual"}, + {"status": "current"} + ] + ) +) + +response = rag.query("How to configure?", param=query_param) +``` + +## Storage Backend Support + +**Important**: Metadata filtering is currently only supported for PostgreSQL vector storage. + +### Vector Storage +- **PGVectorStorage**: Full support with JSONB filtering +- **NanoVectorDBStorage**: Not supported +- **MilvusVectorDBStorage**: Not supported +- **ChromaVectorDBStorage**: Not supported +- **FaissVectorDBStorage**: Not supported +- **QdrantVectorDBStorage**: Not supported +- **MongoVectorDBStorage**: Not supported + +### Recommended Configuration + +For metadata filtering support: +```python +rag = LightRAG( + working_dir="./storage", + vector_storage="PGVectorStorage", + # Graph storage can be any type + # ... 
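embedding / LLM settings as usual
    # Note: with PGVectorStorage, chunk rows gain a JSONB `metadata` column;
    # filters are applied with JSONB containment (`@>`) in the WHERE clause
    # ... 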
other config +) +``` + +## Server API Examples + +### REST API Query with Metadata Filter + +#### Simple Filter (Naive Mode) +```bash +curl -X POST http://localhost:9621/query \ + -H "Content-Type: application/json" \ + -d '{ + "query": "What are the key features?", + "mode": "naive", + "metadata_filter": { + "operator": "AND", + "operands": [ + {"department": "Engineering"}, + {"year": 2024} + ] + } + }' +``` + +#### Complex Nested Filter (Mix Mode) +```bash +curl -X POST http://localhost:9621/query \ + -H "Content-Type: application/json" \ + -d '{ + "query": "Show me technical documentation", + "mode": "mix", + "metadata_filter": { + "operator": "AND", + "operands": [ + {"document_type": "technical_spec"}, + { + "operator": "OR", + "operands": [ + {"version": "1.0"}, + {"version": "2.0"} + ] + } + ] + } + }' +``` + +#### Multi-tenant Query (Mix Mode) +```bash +curl -X POST http://localhost:9621/query \ + -H "Content-Type: application/json" \ + -d '{ + "query": "List all projects", + "mode": "mix", + "metadata_filter": { + "operator": "AND", + "operands": [ + {"tenant_id": "tenant_a"}, + {"access_level": "admin"} + ] + }, + "top_k": 20 + }' +``` + +### Python Client with Server + +```python +import requests +from lightrag.types import MetadataFilter + +# Option 1: Use MetadataFilter class and convert to dict +metadata_filter = MetadataFilter( + operator="AND", + operands=[ + {"department": "Engineering"}, + {"status": "approved"} + ] +) + +response = requests.post( + "http://localhost:9621/query", + json={ + "query": "What are the approved engineering documents?", + "mode": "mix", # Use mix or naive mode + "metadata_filter": metadata_filter.to_dict(), + "top_k": 10 + } +) + +# Option 2: Send dict directly (API will convert to MetadataFilter) +response = requests.post( + "http://localhost:9621/query", + json={ + "query": "What are the approved engineering documents?", + "mode": "naive", # Use mix or naive mode + "metadata_filter": { + "operator": "AND", + "operands": [ + {"department": "Engineering"}, + {"status": "approved"} + ] + }, + "top_k": 10 + } +) + +result = response.json() +print(result["response"]) +``` + +### How the API Processes Metadata Filters + +When you send a query to the REST API: + +1. **JSON Request** → API receives `metadata_filter` as a dict +2. **API Conversion** → `MetadataFilter.from_dict()` creates MetadataFilter object +3. **QueryParam** → MetadataFilter is set in `QueryParam.metadata_filter` +4. **Query Execution** → QueryParam with filter is passed to `kg_query()` or `naive_query()` +5. **Storage Query** → Filter is passed to vector storage query methods (chunks only) +6. **SQL** → PGVectorStorage converts filter to JSONB WHERE clause + +## Best Practices + +### 1. Consistent Metadata Schema +```python +# Good - consistent schema +metadata1 = {"author": "John", "dept": "Eng", "year": 2024} +metadata2 = {"author": "Jane", "dept": "Sales", "year": 2024} +``` + +### 2. Simple Indexable Values +```python +# Good - simple values +metadata = { + "status": "approved", + "priority": "high", + "year": 2024 +} +``` + +### 3. Use Appropriate Query Mode +- **Mix mode**: Best for combining KG context with filtered chunks +- **Naive mode**: Best for pure vector search with metadata filtering + +### 4. 
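Attach Metadata at Upload
The REST upload endpoint in this PR accepts a `metadata` form field containing a JSON object string; send it together with the file so the metadata is in place before processing starts. A sketch using `requests` (the `/documents/upload` path and port are assumed to match the default server setup):
```python
import json
import requests

# Upload a file and attach metadata in one request
response = requests.post(
    "http://localhost:9621/documents/upload",
    files={"file": open("report.pdf", "rb")},
    data={"metadata": json.dumps({"department": "Engineering", "year": 2024})},
)
print(response.json())  # InsertResponse: status, message, track_id
```
A `metadata` value that is not a valid JSON object is rejected with HTTP 400.

### 5. 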
Performance Considerations +- Keep metadata fields minimal (Should be done automatically by the ORM) +- For PostgreSQL: Create GIN indexes on JSONB metadata columns: + ```sql + CREATE INDEX idx_chunks_metadata ON chunks USING GIN (metadata); + ``` +- Avoid overly complex nested filters + +## Troubleshooting + +### Filter Not Working +1. **Verify storage backend**: Ensure you're using PGVectorStorage +2. **Verify query mode**: Use "mix" or "naive" mode only +3. Verify metadata exists in chunks +4. Check metadata field names match exactly (case-sensitive) +5. Check logs for filter parsing errors +6. Test without filter first to ensure data exists + +### Performance Issues +1. Reduce filter complexity +2. Create GIN indexes on JSONB metadata columns in PostgreSQL +3. Profile query execution time +4. Consider caching frequently used filters + +### Unsupported Storage Backend +If you're using a storage backend that doesn't support metadata filtering: +1. Migrate to PGVectorStorage +2. Or implement post-filtering in application code +3. Or contribute metadata filtering support for your backend + +### Metadata Not Persisting After Queue Restart +- Metadata is stored in `DocProcessingStatus.metadata` +- Check document status storage is properly configured +- Verify metadata is set before document is enqueued + +## API Reference + +### MetadataFilter +```python +class MetadataFilter(BaseModel): + operator: str # "AND", "OR", or "NOT" + operands: List[Union[Dict[str, Any], 'MetadataFilter']] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization""" + ... + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter': + """Create MetadataFilter from dictionary (used by API)""" + ... +``` + +### QueryParam +```python +@dataclass +class QueryParam: + metadata_filter: MetadataFilter | None = None # Filter passed to chunk queries + mode: str = "mix" # Only "mix" and "naive" support metadata filtering + top_k: int = 60 + # ... other fields +``` + +### DocProcessingStatus +```python +@dataclass +class DocProcessingStatus: + # ... 
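status and timestamp fields
    # Note: the pipeline adds processing_start_time / processing_end_time to `metadata`
    # alongside any user-supplied keys
    # ... 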
other fields + metadata: dict[str, Any] = field(default_factory=dict) + """Additional metadata - PERSISTED across queue restarts""" +``` + +### Query Method +```python +# Synchronous +response = rag.query( + query: str, + param: QueryParam # QueryParam contains metadata_filter +) + +# Asynchronous +response = await rag.aquery( + query: str, + param: QueryParam # QueryParam contains metadata_filter +) +``` + +### REST API Query Endpoint +```python +# In lightrag/api/routers/query_routes.py +@router.post("/query") +async def query_endpoint(request: QueryRequest): + # API receives metadata_filter as dict + metadata_filter_dict = request.metadata_filter + + # Convert dict to MetadataFilter object + metadata_filter = MetadataFilter.from_dict(metadata_filter_dict) if metadata_filter_dict else None + + # Create QueryParam with MetadataFilter + query_param = QueryParam( + mode=request.mode, # Must be "mix" or "naive" + metadata_filter=metadata_filter, + top_k=request.top_k + ) + + # Execute query with QueryParam + result = await rag.aquery(request.query, param=query_param) + return result +``` diff --git a/lightrag/api/config.py b/lightrag/api/config.py index de569f4722..e17c57e72d 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace: "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int ) + # Token tracking configuration + parser.add_argument( + "--enable-token-tracking", + action="store_true", + default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool), + help="Enable token usage tracking for LLM calls (default: from env or False)", + ) + args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool) + ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index fb0f798562..e0f45cbde6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -301,7 +301,10 @@ def get_cors_origins(): Path(args.working_dir).mkdir(parents=True, exist_ok=True) def create_optimized_openai_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int + config_cache: LLMConfigCache, + args, + llm_timeout: int, + enable_token_tracking=False, ): """Create optimized OpenAI LLM function with pre-processed configuration""" @@ -325,6 +328,10 @@ async def optimized_openai_alike_model_complete( if config_cache.openai_llm_options: kwargs.update(config_cache.openai_llm_options) + # Remove token_tracker from kwargs if it exists to avoid duplicate argument error + # (we pass it explicitly below) + kwargs.pop("token_tracker", None) + return await openai_complete_if_cache( args.llm_model, prompt, @@ -332,13 +339,19 @@ async def optimized_openai_alike_model_complete( history_messages=history_messages, base_url=args.llm_binding_host, api_key=args.llm_binding_api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, **kwargs, ) return optimized_openai_alike_model_complete def create_optimized_azure_openai_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int + config_cache: LLMConfigCache, + args, + llm_timeout: int, + enable_token_tracking=False, ): """Create optimized Azure OpenAI LLM function with pre-processed configuration""" @@ -359,8 +372,12 @@ async def optimized_azure_openai_model_complete( # Use pre-processed configuration to avoid repeated parsing 
kwargs["timeout"] = llm_timeout - if config_cache.openai_llm_options: - kwargs.update(config_cache.openai_llm_options) + if config_cache.azure_openai_llm_options: + kwargs.update(config_cache.azure_openai_llm_options) + # Remove token_tracker from kwargs if it exists to avoid duplicate argument error + # (we pass it explicitly below) + kwargs.pop("token_tracker", None) + return await azure_openai_complete_if_cache( args.llm_model, @@ -370,12 +387,15 @@ async def optimized_azure_openai_model_complete( base_url=args.llm_binding_host, api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key), api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"), + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, **kwargs, ) return optimized_azure_openai_model_complete - def create_llm_model_func(binding: str): + def create_llm_model_func(binding: str, enable_token_tracking=False): """ Create LLM model function based on binding type. Uses optimized functions for OpenAI bindings and lazy import for others. @@ -384,21 +404,42 @@ def create_llm_model_func(binding: str): if binding == "lollms": from lightrag.llm.lollms import lollms_model_complete - return lollms_model_complete + async def lollms_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await lollms_model_complete(*args, **kwargs) + + return lollms_model_complete_with_tracker elif binding == "ollama": from lightrag.llm.ollama import ollama_model_complete - return ollama_model_complete + async def ollama_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await ollama_model_complete(*args, **kwargs) + + return ollama_model_complete_with_tracker elif binding == "aws_bedrock": - return bedrock_model_complete # Already defined locally + + async def bedrock_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await bedrock_model_complete(*args, **kwargs) + + return bedrock_model_complete_with_tracker elif binding == "azure_openai": # Use optimized function with pre-processed configuration return create_optimized_azure_openai_llm_func( - config_cache, args, llm_timeout + config_cache, args, llm_timeout, enable_token_tracking ) else: # openai and compatible # Use optimized function with pre-processed configuration - return create_optimized_openai_llm_func(config_cache, args, llm_timeout) + return create_optimized_openai_llm_func( + config_cache, args, llm_timeout, enable_token_tracking + ) except ImportError as e: raise Exception(f"Failed to import {binding} LLM binding: {e}") @@ -422,7 +463,14 @@ def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict: return {} def create_optimized_embedding_function( - config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args + config_cache: LLMConfigCache, + binding, + model, + host, + api_key, + dimensions, + args, + enable_token_tracking=False, ): """ Create optimized embedding function with pre-processed configuration for applicable bindings. 
@@ -435,7 +483,13 @@ async def optimized_embedding_function(texts): from lightrag.llm.lollms import lollms_embed return await lollms_embed( - texts, embed_model=model, host=host, api_key=api_key + texts, + embed_model=model, + host=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed @@ -455,26 +509,54 @@ async def optimized_embedding_function(texts): host=host, api_key=api_key, options=ollama_options, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - return await azure_openai_embed(texts, model=model, api_key=api_key) + return await azure_openai_embed( + texts, + model=model, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, + ) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - return await bedrock_embed(texts, model=model) + return await bedrock_embed( + texts, + model=model, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, + ) elif binding == "jina": from lightrag.llm.jina import jina_embed return await jina_embed( - texts, dimensions=dimensions, base_url=host, api_key=api_key + texts, + dimensions=dimensions, + base_url=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) else: # openai and compatible from lightrag.llm.openai import openai_embed return await openai_embed( - texts, model=model, base_url=host, api_key=api_key + texts, + model=model, + base_url=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") @@ -594,7 +676,9 @@ async def server_rerank_func( rag = LightRAG( working_dir=args.working_dir, workspace=args.workspace, - llm_model_func=create_llm_model_func(args.llm_binding), + llm_model_func=create_llm_model_func( + args.llm_binding, args.enable_token_tracking + ), llm_model_name=args.llm_model, llm_model_max_async=args.max_async, summary_max_tokens=args.summary_max_tokens, @@ -604,7 +688,16 @@ async def server_rerank_func( llm_model_kwargs=create_llm_model_kwargs( args.llm_binding, args, llm_timeout ), - embedding_func=embedding_func, + embedding_func=create_optimized_embedding_function( + config_cache, + args.embedding_binding, + args.embedding_model, + args.embedding_binding_host, + args.embedding_binding_api_key, + args.embedding_dim, + args, + args.enable_token_tracking, + ), default_llm_timeout=llm_timeout, default_embedding_timeout=embedding_timeout, kv_storage=args.kv_storage, @@ -629,6 +722,17 @@ async def server_rerank_func( logger.error(f"Failed to initialize LightRAG: {e}") raise + # Initialize token tracking if enabled + token_tracker = None + if args.enable_token_tracking: + from lightrag.utils import TokenTracker + + token_tracker = TokenTracker() + logger.info("Token tracking enabled") + + # Add token tracker to the app state for use in endpoints + app.state.token_tracker = 
token_tracker + # Add routes app.include_router( create_document_routes( @@ -637,7 +741,7 @@ async def server_rerank_func( api_key, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker)) app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 7e44b57d3e..f8bbb9ab7e 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -11,11 +11,13 @@ from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Optional, Any, Literal +import json from fastapi import ( APIRouter, BackgroundTasks, Depends, File, + Form, HTTPException, UploadFile, ) @@ -241,6 +243,7 @@ class InsertResponse(BaseModel): status: Status of the operation (success, duplicated, partial_success, failure) message: Detailed message describing the operation result track_id: Tracking ID for monitoring processing status + token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count) """ status: Literal["success", "duplicated", "partial_success", "failure"] = Field( @@ -248,6 +251,10 @@ class InsertResponse(BaseModel): ) message: str = Field(description="Message describing the operation result") track_id: str = Field(description="Tracking ID for monitoring processing status") + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)", + ) class Config: json_schema_extra = { @@ -255,10 +262,17 @@ class Config: "status": "success", "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.", "track_id": "upload_20250729_170612_abc123", + "token_usage": { + "prompt_tokens": 1250, + "completion_tokens": 450, + "total_tokens": 1700, + "call_count": 3, + }, } } + class ClearDocumentsResponse(BaseModel): """Response model for document clearing operation @@ -860,7 +874,7 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str async def pipeline_enqueue_file( - rag: LightRAG, file_path: Path, track_id: str = None + rag: LightRAG, file_path: Path, track_id: str = None, metadata: dict | None = None ) -> tuple[bool, str]: """Add a file to the queue for processing @@ -868,6 +882,7 @@ async def pipeline_enqueue_file( rag: LightRAG instance file_path: Path to the saved file track_id: Optional tracking ID, if not provided will be generated + metadata: Optional metadata to associate with the document Returns: tuple: (success: bool, track_id: str) """ @@ -1239,8 +1254,12 @@ async def pipeline_enqueue_file( return False, track_id try: + # Pass metadata to apipeline_enqueue_documents await rag.apipeline_enqueue_documents( - content, file_paths=file_path.name, track_id=track_id + content, + file_paths=file_path.name, + track_id=track_id, + metadata=metadata, ) logger.info( @@ -1722,19 +1741,21 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) + background_tasks: BackgroundTasks, + file: UploadFile = File(...), + metadata: Optional[str] = Form(None), ): """ - Upload a file to the input directory and index it. 
+ Upload a file to the input directory and index it with optional metadata. This API endpoint accepts a file through an HTTP POST request, checks if the uploaded file is of a supported type, saves it in the specified input directory, indexes it for retrieval, and returns a success status with relevant details. - + Metadata can be provided to associate custom data with the uploaded document. Args: background_tasks: FastAPI BackgroundTasks for async processing file (UploadFile): The file to be uploaded. It must have an allowed extension. - + metadata (dict, optional): Custom metadata to associate with the document. Returns: InsertResponse: A response object containing the upload status and a message. status can be "success", "duplicated", or error is thrown. @@ -1777,9 +1798,30 @@ async def upload_to_input_dir( track_id = generate_track_id("upload") - # Add to background tasks and get track_id - background_tasks.add_task(pipeline_index_file, rag, file_path, track_id) + # Parse metadata if provided + parsed_metadata = None + if metadata: + try: + parsed_metadata = json.loads(metadata) + if not isinstance(parsed_metadata, dict): + raise ValueError( + "Metadata must be a valid JSON dictionary string." + ) + except json.JSONDecodeError: + raise HTTPException( + status_code=400, detail="Metadata must be a valid JSON string." + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + # Add to background tasks with metadata + background_tasks.add_task( + pipeline_index_file_with_metadata, + rag, + file_path, + track_id, + parsed_metadata, + ) return InsertResponse( status="success", message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.", @@ -1791,6 +1833,35 @@ async def upload_to_input_dir( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + # New function to handle metadata during indexing + async def pipeline_index_file_with_metadata( + rag: LightRAG, file_path: Path, track_id: str, metadata: dict | None + ) -> tuple[bool, str]: + """ + Index a file with metadata by leveraging the existing pipeline. 
+ Args: + rag: LightRAG instance + file_path: Path to the file to index + track_id: Tracking ID for the document + metadata: Optional metadata dictionary to associate with the document + Returns: + tuple[bool, str]: Success status and track ID + """ + # Use the existing pipeline to enqueue the file + success, returned_track_id = await pipeline_enqueue_file( + rag, file_path, track_id, metadata + ) + + if success: + logger.info(f"Successfully enqueued file with metadata: {metadata}") + else: + logger.error("Failed to enqueue file with metadata") + + # Trigger the pipeline processing + await rag.apipeline_process_enqueue_documents() + + return success, returned_track_id + @router.post( "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 53cc41c007..046a1ea486 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -6,8 +6,9 @@ import logging from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request, Response from lightrag.base import QueryParam +from lightrag.types import MetadataFilter from lightrag.api.utils_api import get_combined_auth_dependency from pydantic import BaseModel, Field, field_validator @@ -22,6 +23,11 @@ class QueryRequest(BaseModel): description="The query text", ) + metadata_filter: MetadataFilter | None = Field( + default=None, + description="Optional metadata filter for nodes and edges. Can be a MetadataFilter object or a dict that will be converted to MetadataFilter.", + ) + mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( default="mix", description="Query mode", @@ -74,7 +80,7 @@ class QueryRequest(BaseModel): ) conversation_history: Optional[List[Dict[str, Any]]] = Field( - default=None, + default=[], description="Stores past conversation history to maintain context. 
Format: [{'role': 'user/assistant', 'content': 'message'}].", ) @@ -117,6 +123,16 @@ def conversation_history_role_check( raise ValueError("Each message 'role' must be a non-empty string.") return conversation_history + @field_validator("metadata_filter", mode="before") + @classmethod + def metadata_filter_convert(cls, v): + """Convert dict inputs to MetadataFilter objects.""" + if v is None: + return None + if isinstance(v, dict): + return MetadataFilter.from_dict(v) + return v + def to_query_params(self, is_stream: bool) -> "QueryParam": """Converts a QueryRequest instance into a QueryParam instance.""" # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically @@ -125,6 +141,11 @@ def to_query_params(self, is_stream: bool) -> "QueryParam": # Ensure `mode` and `stream` are set explicitly param = QueryParam(**request_data) param.stream = is_stream + + # Ensure metadata_filter remains as MetadataFilter object if it exists + if self.metadata_filter: + param.metadata_filter = self.metadata_filter + return param @@ -136,6 +157,10 @@ class QueryResponse(BaseModel): default=None, description="Reference list (Disabled when include_references=False, /query/data always includes references.)", ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)", + ) class QueryDataResponse(BaseModel): @@ -147,6 +172,10 @@ class QueryDataResponse(BaseModel): metadata: Dict[str, Any] = Field( description="Query metadata including mode, keywords, and processing information" ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)", + ) class StreamChunkResponse(BaseModel): @@ -162,9 +191,18 @@ class StreamChunkResponse(BaseModel): error: Optional[str] = Field( default=None, description="Error message if processing fails" ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the entire query (only in final chunk)", + ) -def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): +def create_query_routes( + rag, + api_key: Optional[str] = None, + top_k: int = 60, + token_tracker: Optional[Any] = None, +): combined_auth = get_combined_auth_dependency(api_key) @router.post( @@ -199,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, "examples": { "with_references": { - "summary": "Response with references", + "summary": "Response with references and token usage", "description": "Example response when include_references=True", "value": { "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", @@ -213,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): "file_path": "/documents/machine_learning.txt", }, ], + "token_usage": { + "prompt_tokens": 245, + "completion_tokens": 87, + "total_tokens": 332, + "call_count": 1, + }, }, }, "without_references": { - "summary": "Response without references", + "summary": "Response without references but with token usage", "description": "Example response when include_references=False", "value": { - "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent 
machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving." + "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", + "token_usage": { + "prompt_tokens": 245, + "completion_tokens": 87, + "total_tokens": 332, + "call_count": 1, + }, }, }, "different_modes": { @@ -337,6 +387,10 @@ async def query_text(request: QueryRequest): - 500: Internal processing error (e.g., LLM service unavailable) """ try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + param = request.to_query_params( False ) # Ensure stream=False for non-streaming endpoint @@ -355,11 +409,22 @@ async def query_text(request: QueryRequest): if not response_content: response_content = "No relevant context found for the query." + # Get token usage if available + token_usage = None + if token_tracker: + token_usage = token_tracker.get_usage() + # Return response with or without references based on request if request.include_references: - return QueryResponse(response=response_content, references=references) + return QueryResponse( + response=response_content, + references=references, + token_usage=token_usage, + ) else: - return QueryResponse(response=response_content, references=None) + return QueryResponse( + response=response_content, references=None, token_usage=token_usage + ) except Exception as e: trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @@ -556,6 +621,10 @@ async def query_text_stream(request: QueryRequest): Use streaming mode for real-time interfaces and non-streaming for batch processing. """ try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + # Use the stream parameter from the request, defaulting to True if not specified stream_mode = request.stream if request.stream is not None else True param = request.to_query_params(stream_mode) @@ -584,6 +653,10 @@ async def stream_generator(): except Exception as e: logging.error(f"Streaming error: {str(e)}") yield f"{json.dumps({'error': str(e)})}\n" + + # Add final token usage chunk if streaming and token tracker is available + if token_tracker and llm_response.get("is_streaming"): + yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n" else: # Non-streaming mode: send complete response in one message response_content = llm_response.get("content", "") @@ -595,6 +668,10 @@ async def stream_generator(): if request.include_references: complete_response["references"] = references + # Add token usage if available + if token_tracker: + complete_response["token_usage"] = token_tracker.get_usage() + yield f"{json.dumps(complete_response)}\n" return StreamingResponse( @@ -1001,18 +1078,31 @@ async def query_data(request: QueryRequest): as structured data analysis typically requires source attribution. 
""" try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + param = request.to_query_params(False) # No streaming for data endpoint response = await rag.aquery_data(request.query, param=param) + # Get token usage if available + token_usage = None + if token_tracker: + token_usage = token_tracker.get_usage() + # aquery_data returns the new format with status, message, data, and metadata if isinstance(response, dict): - return QueryDataResponse(**response) + response_dict = dict(response) + response_dict["token_usage"] = token_usage + return QueryDataResponse(**response_dict) else: # Handle unexpected response format return QueryDataResponse( status="failure", - message="Invalid response type", + message="Unexpected response format", data={}, + metadata={}, + token_usage=None, ) except Exception as e: trace_exception(e) diff --git a/lightrag/base.py b/lightrag/base.py index b9ebeca80b..42cdbb5342 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -15,9 +15,10 @@ Dict, List, AsyncIterator, + Union, ) from .utils import EmbeddingFunc -from .types import KnowledgeGraph +from .types import KnowledgeGraph, MetadataFilter from .constants import ( GRAPH_FIELD_SEP, DEFAULT_TOP_K, @@ -39,6 +40,11 @@ load_dotenv(dotenv_path=".env", override=False) + + + + + class OllamaServerInfos: def __init__(self, name=None, tag=None): self._lightrag_name = name or os.getenv( @@ -163,6 +169,9 @@ class QueryParam: Default is True to enable reranking when rerank model is available. """ + metadata_filter: MetadataFilter | None = None + """Metadata for filtering nodes and edges, allowing for more precise querying.""" + include_references: bool = False """If True, includes reference list in the response for supported endpoints. This parameter controls whether the API response includes a references field @@ -223,7 +232,7 @@ class BaseVectorStorage(StorageNameSpace, ABC): @abstractmethod async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: MetadataFilter | None = None ) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results. @@ -445,6 +454,12 @@ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | N or None if the node doesn't exist """ + async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]: + """Get node IDs that match the given metadata filter with logical expressions.""" + # Default implementation - subclasses should override this method + # This is a placeholder that will be overridden by specific implementations + raise NotImplementedError("Subclasses must implement get_nodes_by_metadata_filter") + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """Get nodes as a batch using UNWIND diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 7d6a6dac58..ee1f0d0ccb 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -180,7 +180,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: return [m["__id__"] for m in list_data] async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: dict[str, Any] | None = None ) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance. 
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f2368afeda..35e6bd6d59 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1047,7 +1047,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: return results async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: dict[str, Any] | None = None ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 0c11022e26..787d848d5a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -2216,7 +2216,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: return list_data async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: dict[str, Any] | None = None ) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" if query_embedding is not None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index def5a83d3c..271619d6fd 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -137,7 +137,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: ) async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: dict[str, Any] | None = None ) -> list[dict[str, Any]]: # Use provided embedding or compute it if query_embedding is not None: diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 896e597370..fb72730958 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,10 +1,12 @@ import os import re +import json from dataclasses import dataclass -from typing import final +from typing import final, Any import configparser + from tenacity import ( retry, stop_after_attempt, @@ -14,7 +16,7 @@ import logging from ..utils import logger -from ..base import BaseGraphStorage +from ..base import BaseGraphStorage, MetadataFilter from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock @@ -424,6 +426,89 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: await result.consume() # Ensure results are consumed even on error raise + async def get_nodes_by_metadata_filter(self, metadata_filter: MetadataFilter | None) -> list[str]: + """Get node IDs that match the given metadata filter with logical expressions.""" + workspace_label = self._get_workspace_label() + + # Build metadata conditions + params = {} + condition, params = self._build_metadata_conditions(metadata_filter, params) + + if not condition: + # If no condition, return empty list for safety + return [] + + # Build the query + query = f""" + MATCH (n:`{workspace_label}`) + WHERE {condition} + RETURN n.entity_id AS entity_id + """ + + async with self._driver.session(database=self._DATABASE) as session: + result = await session.run(query, params) + return [record["entity_id"] async for record in result] + + def _build_metadata_conditions( + self, + metadata_filter: MetadataFilter | None, + params: dict[str, Any], + node_var: str = "n" + ) -> tuple[str, dict[str, 
Any]]: + """ + Build Cypher WHERE conditions from a MetadataFilter. + + Args: + metadata_filter: The MetadataFilter object + params: Dictionary to collect parameters for the query + node_var: The variable name for the node in the Cypher query + + Returns: + Tuple of (condition_string, updated_params) + """ + if metadata_filter is None: + return "", params + + conditions = [] + + for operand in metadata_filter.operands: + if isinstance(operand, MetadataFilter): + # Recursive call for nested filters + sub_condition, params = self._build_metadata_conditions(operand, params, node_var) + if sub_condition: + conditions.append(f"({sub_condition})") + else: + # Simple key-value pair + for key, value in operand.items(): + prop_name = f"meta_{key}" # Using our prefix + param_name = f"{prop_name}_{len(params)}" + + if value is None: + # Check for existence of the key + conditions.append(f"{node_var}.{prop_name} IS NOT NULL") + else: + # Check for specific value + conditions.append(f"{node_var}.{prop_name} = ${param_name}") + params[param_name] = value + + if not conditions: + return "", params + + # Join conditions with the operator + if metadata_filter.operator == "AND": + condition = " AND ".join(conditions) + elif metadata_filter.operator == "OR": + condition = " OR ".join(conditions) + elif metadata_filter.operator == "NOT": + if len(conditions) == 1: + condition = f"NOT ({conditions[0]})" + else: + condition = f"NOT ({' AND '.join(conditions)})" + else: + raise ValueError(f"Unknown operator: {metadata_filter.operator}") + + return condition, params + async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -962,7 +1047,11 @@ async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: ) ), ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + async def upsert_node( + self, + node_id: str, + node_data: dict[str, str], + ) -> None: """ Upsert a node in the Neo4j database. 
@@ -971,8 +1060,25 @@ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: node_data: Dictionary of node properties """ workspace_label = self._get_workspace_label() - properties = node_data - entity_type = properties["entity_type"] + properties = node_data.copy() + + metadata = properties.pop("metadata", None) + for key, value in metadata.items(): + neo4j_key = key + + # Handle complex data types by converting them to strings + if isinstance(value, (dict, list)): + try: + properties[neo4j_key] = json.dumps(value) + except Exception as e: + logger.warning( + f"Failed to serialize metadata field {key} for node {node_id}: {e}" + ) + properties[neo4j_key] = str(value) + else: + properties[neo4j_key] = value + + entity_type = properties.get("entity_type", "Unknown") if "entity_id" not in properties: raise ValueError("Neo4j: node properties must contain an 'entity_id' field") @@ -992,7 +1098,7 @@ async def execute_upsert(tx: AsyncManagedTransaction): await session.execute_write(execute_upsert) except Exception as e: - logger.error(f"[{self.workspace}] Error during upsert: {str(e)}") + logger.error(f"[{self.workspace}] Error during node upsert: {str(e)}") raise @retry( @@ -1026,12 +1132,23 @@ async def upsert_edge( Raises: ValueError: If either source or target node does not exist or is not unique """ + edge_properties = edge_data + workspace_label = self._get_workspace_label() + + # Extract metadata if present and handle it properly + metadata = edge_properties.pop("metadata", None) + if metadata and isinstance(metadata, dict): + # Serialize metadata to JSON string + try: + edge_properties["metadata"] = json.dumps(metadata) + except Exception as e: + logger.warning(f"Failed to serialize metadata: {e}") + edge_properties["metadata"] = str(metadata) + try: - edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - workspace_label = self._get_workspace_label() query = f""" MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) WITH source diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 50c2108fff..afbc407d6f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -11,7 +11,12 @@ import ssl import itertools -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from lightrag.types import ( + KnowledgeGraph, + KnowledgeGraphNode, + KnowledgeGraphEdge, + MetadataFilter, +) from tenacity import ( retry, @@ -964,6 +969,20 @@ async def check_tables(self): logger.error( f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}" ) + + # Create GIN indexes for JSONB metadata columns + try: + await self._create_gin_metadata_indexes() + except Exception as e: + logger.error(f"PostgreSQL, Failed to create GIN metadata indexes: {e}") + # Compatibility check - add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS + try: + await self.add_metadata_to_tables() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to add metadata columns to existing tables: {e}" + ) + # After all tables are created, attempt to migrate timestamp fields try: await self._migrate_timestamp_columns() @@ -1043,6 +1062,52 @@ async def check_tables(self): f"PostgreSQL, Failed to create full entities/relations tables: {e}" ) + async def add_metadata_to_tables(self): + """Add metadata columns to LIGHTRAG_DOC_CHUNKS and LIGHTRAG_VDB_CHUNKS tables if they don't exist""" + 
tables_to_check = [ + { + "name": "LIGHTRAG_DOC_CHUNKS", + "description": "Document chunks storage table", + }, + { + "name": "LIGHTRAG_VDB_CHUNKS", + "description": "Vector database chunks storage table", + }, + ] + + for table_info in tables_to_check: + table_name = table_info["name"] + try: + # Check if metadata column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = $1 + AND column_name = 'metadata' + """ + + column_info = await self.query( + check_column_sql, {"table_name": table_name.lower()} + ) + + if not column_info: + logger.info(f"Adding metadata column to {table_name} table") + add_column_sql = f""" + ALTER TABLE {table_name} + ADD COLUMN metadata JSONB NULL DEFAULT '{{}}'::jsonb + """ + await self.execute(add_column_sql) + logger.info( + f"Successfully added metadata column to {table_name} table" + ) + else: + logger.debug( + f"metadata column already exists in {table_name} table" + ) + + except Exception as e: + logger.warning(f"Failed to add metadata column to {table_name}: {e}") + async def _migrate_create_full_entities_relations_tables(self): """Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist""" tables_to_check = [ @@ -1246,6 +1311,37 @@ async def _create_ivfflat_vector_indexes(self): except Exception as e: logger.error(f"Failed to create ivfflat index on {k}: {e}") + async def _create_gin_metadata_indexes(self): + """Create GIN indexes for JSONB metadata columns to speed up metadata filtering""" + metadata_tables = [ + "LIGHTRAG_DOC_CHUNKS", + "LIGHTRAG_VDB_CHUNKS", + "LIGHTRAG_DOC_STATUS", + ] + + for table in metadata_tables: + index_name = f"idx_{table.lower()}_metadata_gin" + check_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{index_name}' + AND tablename = '{table.lower()}' + """ + + try: + index_exists = await self.query(check_index_sql) + if not index_exists: + create_index_sql = f""" + CREATE INDEX CONCURRENTLY IF NOT EXISTS {index_name} + ON {table} USING gin (metadata jsonb_path_ops) + """ + logger.info(f"PostgreSQL, Creating GIN index {index_name} on table {table}") + await self.execute(create_index_sql) + logger.info(f"PostgreSQL, Successfully created GIN index {index_name} on table {table}") + else: + logger.info(f"PostgreSQL, GIN index {index_name} already exists on table {table}") + except Exception as e: + logger.error(f"PostgreSQL, Failed to create GIN index on table {table}, Got: {e}") + async def query( self, sql: str, @@ -1784,6 +1880,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), "create_time": current_time, "update_time": current_time, + "metadata": json.dumps(v.get("metadata", {})), } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): @@ -1951,6 +2048,7 @@ def _upsert_chunks( "file_path": item["file_path"], "create_time": current_time, "update_time": current_time, + "metadata": json.dumps(item.get("metadata", {})), } except Exception as e: logger.error( @@ -2045,9 +2143,59 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: await self.db.execute(upsert_sql, data) + ############# Metadata building function ################# + @staticmethod + def build_metadata_filter_clause(metadata_filter): + if not metadata_filter: + return "" + + def build_single_condition(key, value): + #Check and build "IN" conditions for the operand + if isinstance(value, list): + if not value: + return "1=0" + + 
in_conditions = [ + f"metadata @> '{{\"{key}\": {json.dumps(v)}}}'" for v in value + ] + return f"({ ' OR '.join(in_conditions) })" + + else: + # Use for scalars and dictionaries + json_value = json.dumps(value) + return f"metadata @> '{{\"{ key}\" : {json_value}}}'" + + try: + if isinstance(metadata_filter, dict): + conditions = [build_single_condition(k, v) for k, v in metadata_filter.items()] + return " AND " + " AND ".join(conditions) if conditions else "" + elif hasattr(metadata_filter, 'operands'): + sub_conditions = [] + for operand in metadata_filter.operands: + if isinstance(operand, dict): + conds = [build_single_condition(k, v) for k, v in operand.items()] + if conds: + sub_conditions.append("(" + " AND ".join(conds) + ")") + + if sub_conditions: + op = getattr(metadata_filter, 'operator', 'AND').upper() + connector = " OR " if op == "OR" else " AND " + prefix = " AND NOT (" if op == "NOT" else " AND (" + return prefix + connector.join(sub_conditions) + ")" + return "" + except Exception: + # Simple fallback + if isinstance(metadata_filter, dict): + return f" AND metadata @> '{json.dumps(metadata_filter)}'" + return "" + #################### query method ############### async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, + query: str, + top_k: int, + query_embedding: list[float] = None, + metadata_filter: MetadataFilter | None = None, ) -> list[dict[str, Any]]: if query_embedding is not None: embedding = query_embedding @@ -2058,8 +2206,11 @@ async def query( embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) - - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + metadata_filter_clause = self.build_metadata_filter_clause(metadata_filter) if metadata_filter else "" + sql = SQL_TEMPLATES[self.namespace].format( + embedding_string=embedding_string, + metadata_filter_clause=metadata_filter_clause, + ) params = { "workspace": self.workspace, "closer_than_threshold": 1 - self.cosine_better_than_threshold, @@ -4523,6 +4674,7 @@ def namespace_to_table_name(namespace: str) -> str: content TEXT, file_path TEXT NULL, llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, + metadata JSONB NULL DEFAULT '{}'::jsonb, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4538,6 +4690,7 @@ def namespace_to_table_name(namespace: str) -> str: content TEXT, content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), file_path TEXT NULL, + metadata JSONB NULL DEFAULT '{{}}'::jsonb, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4703,8 +4856,8 @@ def namespace_to_table_name(namespace: str) -> str: """, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, file_path, llm_cache_list, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + create_time, update_time, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4712,7 +4865,8 @@ def namespace_to_table_name(namespace: str) -> str: content = EXCLUDED.content, file_path=EXCLUDED.file_path, llm_cache_list=EXCLUDED.llm_cache_list, - update_time = EXCLUDED.update_time + update_time = EXCLUDED.update_time, + 
metadata = EXCLUDED.metadata """, "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count, create_time, update_time) @@ -4733,8 +4887,8 @@ def namespace_to_table_name(namespace: str) -> str: # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + create_time, update_time, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4742,7 +4896,8 @@ def namespace_to_table_name(namespace: str) -> str: content = EXCLUDED.content, content_vector=EXCLUDED.content_vector, file_path=EXCLUDED.file_path, - update_time = EXCLUDED.update_time + update_time = EXCLUDED.update_time, + metadata = EXCLUDED.metadata """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector, chunk_ids, file_path, create_time, update_time) @@ -4767,36 +4922,72 @@ def namespace_to_table_name(namespace: str) -> str: file_path=EXCLUDED.file_path, update_time = EXCLUDED.update_time """, - "relationships": """ - SELECT r.source_id AS src_id, - r.target_id AS tgt_id, - EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_RELATION r - WHERE r.workspace = $1 - AND r.content_vector <=> '[{embedding_string}]'::vector < $2 - ORDER BY r.content_vector <=> '[{embedding_string}]'::vector - LIMIT $3; - """, - "entities": """ - SELECT e.entity_name, - EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_ENTITY e - WHERE e.workspace = $1 - AND e.content_vector <=> '[{embedding_string}]'::vector < $2 - ORDER BY e.content_vector <=> '[{embedding_string}]'::vector +"relationships": """ + WITH filtered_chunks AS ( + SELECT + c.id + FROM + LIGHTRAG_VDB_CHUNKS c + WHERE + c.workspace = $1 + {metadata_filter_clause} + ) + SELECT + r.source_id AS src_id, + r.target_id AS tgt_id, + EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at + FROM + LIGHTRAG_VDB_RELATION r + JOIN + filtered_chunks fc ON r.chunk_ids && ARRAY[fc.id] -- Find relationships linked to our valid chunks + WHERE + r.workspace = $1 + AND r.content_vector <=> '[{embedding_string}]'::vector < $2 + ORDER BY + r.content_vector <=> '[{embedding_string}]'::vector -- Rank the resulting relationships LIMIT $3; - """, + """, +"entities": """ + WITH filtered_chunks AS ( + SELECT + c.id + FROM + LIGHTRAG_VDB_CHUNKS c + WHERE + c.workspace = $1 + {metadata_filter_clause} + ) + SELECT + e.entity_name, + EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at + FROM + LIGHTRAG_VDB_ENTITY e + JOIN + filtered_chunks fc ON e.chunk_ids && ARRAY[fc.id] + WHERE + e.workspace = $1 + AND e.content_vector <=> '[{embedding_string}]'::vector < $2 + ORDER BY + e.content_vector <=> '[{embedding_string}]'::vector + LIMIT $3; + """, "chunks": """ - SELECT c.id, - c.content, - c.file_path, - EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_CHUNKS c - WHERE c.workspace = $1 - AND c.content_vector <=> '[{embedding_string}]'::vector < $2 - ORDER BY c.content_vector <=> '[{embedding_string}]'::vector - LIMIT $3; - """, + SELECT c.id, + c.content, + c.file_path, + EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at, + c.metadata, + c.content_vector <=> '[{embedding_string}]'::vector AS distance + + FROM LIGHTRAG_VDB_CHUNKS c + WHERE + 
c.workspace = $1 + {metadata_filter_clause} + AND c.content_vector <=> '[{embedding_string}]'::vector < $2 + ORDER BY + c.content_vector <=> '[{embedding_string}]'::vector + LIMIT $3; + """, # DROP tables "drop_specifiy_table_workspace": """ DELETE FROM {table_name} WHERE workspace=$1 diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index de1d07e7e6..81351001d9 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -212,7 +212,7 @@ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: return results async def query( - self, query: str, top_k: int, query_embedding: list[float] = None + self, query: str, top_k: int, query_embedding: list[float] = None, metadata_filter: dict[str, Any] | None = None ) -> list[dict[str, Any]]: if query_embedding is not None: embedding = query_embedding diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b4345405de..05827dc945 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -870,7 +870,8 @@ def insert( ids: str | list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, - ) -> str: + token_tracker: TokenTracker | None = None, + ) -> tuple[str, dict]: """Sync Insert documents with checkpoint support Args: @@ -884,10 +885,10 @@ def insert( track_id: tracking ID for monitoring processing status, if not provided, will be generated Returns: - str: tracking ID for monitoring processing status + tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics) """ loop = always_get_an_event_loop() - return loop.run_until_complete( + result = loop.run_until_complete( self.ainsert( input, split_by_character, @@ -895,8 +896,10 @@ def insert( ids, file_paths, track_id, + token_tracker, ) ) + return result async def ainsert( self, @@ -906,7 +909,8 @@ async def ainsert( ids: str | list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, - ) -> str: + token_tracker: TokenTracker | None = None, + ) -> tuple[str, dict]: """Async Insert documents with checkpoint support Args: @@ -928,10 +932,17 @@ async def ainsert( await self.apipeline_enqueue_documents(input, ids, file_paths, track_id) await self.apipeline_process_enqueue_documents( - split_by_character, split_by_character_only + split_by_character, + split_by_character_only, + token_tracker, ) - return track_id + return track_id, token_tracker.get_usage() if token_tracker else { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } # TODO: deprecated, use insert instead def insert_custom_chunks( @@ -1011,6 +1022,7 @@ async def apipeline_enqueue_documents( ids: list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, + metadata: dict | None = None, ) -> str: """ Pipeline for Processing Documents @@ -1025,6 +1037,7 @@ async def apipeline_enqueue_documents( ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated file_paths: list of file paths corresponding to each document, used for citation track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix + metadata: Optional metadata to associate with the documents Returns: str: tracking ID for monitoring processing status @@ -1038,6 +1051,8 @@ async def apipeline_enqueue_documents( ids = [ids] if isinstance(file_paths, str): file_paths = [file_paths] + if isinstance(metadata, dict): + metadata = [metadata] # If file_paths is provided, ensure it 
matches the number of documents if file_paths is not None: @@ -1102,8 +1117,11 @@ async def apipeline_enqueue_documents( "file_path" ], # Store file path in document status "track_id": track_id, # Store track_id in document status + "metadata": metadata[i] + if isinstance(metadata, list) and i < len(metadata) + else metadata, # added provided custom metadata per document } - for id_, content_data in contents.items() + for i, (id_, content_data) in enumerate(contents.items()) } # 3. Filter out already processed documents @@ -1334,7 +1352,7 @@ async def _validate_and_fix_document_consistency( "track_id": getattr(status_doc, "track_id", ""), # Clear any error messages and processing metadata "error_msg": "", - "metadata": {}, + "metadata": getattr(status_doc, "metadata", {}), } # Update the status in to_process_docs as well @@ -1357,6 +1375,7 @@ async def apipeline_process_enqueue_documents( self, split_by_character: str | None = None, split_by_character_only: bool = False, + token_tracker: TokenTracker | None = None, ) -> None: """ Process pending documents by splitting them into chunks, processing @@ -1478,8 +1497,12 @@ async def process_document( pipeline_status: dict, pipeline_status_lock: asyncio.Lock, semaphore: asyncio.Semaphore, + token_tracker: TokenTracker | None = None, ) -> None: """Process single document""" + doc_metadata = getattr(status_doc, "metadata", None) + if doc_metadata is None: + doc_metadata = {} file_extraction_stage_ok = False async with semaphore: nonlocal processed_count @@ -1533,6 +1556,7 @@ async def process_document( "full_doc_id": doc_id, "file_path": file_path, # Add file path to each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk + "metadata": doc_metadata, } for dp in self.chunking_func( self.tokenizer, @@ -1552,6 +1576,10 @@ async def process_document( # Process document in two stages # Stage 1: Process text chunks and docs (parallel execution) + doc_metadata["processing_start_time"] = ( + processing_start_time + ) + doc_status_task = asyncio.create_task( self.doc_status.upsert( { @@ -1569,9 +1597,7 @@ async def process_document( ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time - }, + "metadata": doc_metadata, } } ) @@ -1596,8 +1622,12 @@ async def process_document( # Stage 2: Process entity relation graph (after text_chunks are saved) entity_relation_task = asyncio.create_task( - self._process_extract_entities( - chunks, pipeline_status, pipeline_status_lock + self._process_entity_relation_graph( + chunks, + doc_metadata, + pipeline_status, + pipeline_status_lock, + token_tracker, ) ) await entity_relation_task @@ -1631,6 +1661,10 @@ async def process_document( processing_end_time = int(time.time()) # Update document status to failed + doc_metadata["processing_start_time"] = ( + processing_start_time + ) + doc_metadata["processing_end_time"] = processing_end_time await self.doc_status.upsert( { doc_id: { @@ -1644,10 +1678,7 @@ async def process_document( ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": doc_metadata, } } ) @@ -1672,10 +1703,17 @@ async def process_document( current_file_number=current_file_number, total_files=total_files, file_path=file_path, + metadata=doc_metadata, # NEW: Pass metadata to merge function ) # Record 
processing end time processing_end_time = int(time.time()) + doc_metadata["processing_start_time"] = ( + processing_start_time + ) + doc_metadata["processing_end_time"] = ( + processing_end_time + ) await self.doc_status.upsert( { @@ -1691,10 +1729,7 @@ async def process_document( ).isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": doc_metadata, } } ) @@ -1732,6 +1767,13 @@ async def process_document( processing_end_time = int(time.time()) # Update document status to failed + doc_metadata["processing_start_time"] = ( + processing_start_time + ) + doc_metadata["processing_end_time"] = ( + processing_end_time + ) + await self.doc_status.upsert( { doc_id: { @@ -1743,10 +1785,7 @@ async def process_document( "updated_at": datetime.now().isoformat(), "file_path": file_path, "track_id": status_doc.track_id, # Preserve existing track_id - "metadata": { - "processing_start_time": processing_start_time, - "processing_end_time": processing_end_time, - }, + "metadata": doc_metadata, } } ) @@ -1763,6 +1802,7 @@ async def process_document( pipeline_status, pipeline_status_lock, semaphore, + token_tracker, ) ) @@ -1806,17 +1846,24 @@ async def process_document( pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - async def _process_extract_entities( - self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None + async def _process_entity_relation_graph( + self, + chunk: dict[str, Any], + metadata: dict | None, + pipeline_status=None, + pipeline_status_lock=None, + token_tracker: TokenTracker | None = None, ) -> list: try: chunk_results = await extract_entities( chunk, global_config=asdict(self), + metadata=metadata, # Pass metadata here pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, text_chunks_storage=self.text_chunks, + token_tracker=token_tracker, ) return chunk_results except Exception as e: @@ -2041,14 +2088,18 @@ def query( self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, system_prompt: str | None = None, - ) -> str | Iterator[str]: + ) -> Any: """ - Perform a sync query. + User query interface (backward compatibility wrapper). + + Delegates to aquery() for asynchronous execution and returns the result. Args: query (str): The query to be executed. param (QueryParam): Configuration parameters for query execution. + token_tracker (TokenTracker | None): Optional token tracker for monitoring usage. prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. Returns: @@ -2056,14 +2107,17 @@ def query( """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore + return loop.run_until_complete( + self.aquery(query, param, token_tracker, system_prompt) + ) # type: ignore async def aquery( self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, system_prompt: str | None = None, - ) -> str | AsyncIterator[str]: + ) -> Any: """ Perform a async query (backward compatibility wrapper). 
@@ -2082,7 +2136,7 @@ async def aquery( - Streaming: Returns AsyncIterator[str] """ # Call the new aquery_llm function to get complete results - result = await self.aquery_llm(query, param, system_prompt) + result = await self.aquery_llm(query, param, system_prompt, token_tracker) # Extract and return only the LLM response for backward compatibility llm_response = result.get("llm_response", {}) @@ -2117,6 +2171,7 @@ async def aquery_data( self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: """ Asynchronous data retrieval API: returns structured retrieval results without LLM generation. @@ -2310,6 +2365,7 @@ async def aquery_llm( query: str, param: QueryParam = QueryParam(), system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: """ Asynchronous complete query API: returns structured retrieval results with LLM generation. @@ -2344,6 +2400,7 @@ async def aquery_llm( hashing_kv=self.llm_response_cache, system_prompt=system_prompt, chunks_vdb=self.chunks_vdb, + token_tracker=token_tracker, ) elif param.mode == "naive": query_result = await naive_query( @@ -2353,6 +2410,7 @@ async def aquery_llm( global_config, hashing_kv=self.llm_response_cache, system_prompt=system_prompt, + token_tracker=token_tracker, ) elif param.mode == "bypass": # Bypass mode: directly use LLM without knowledge retrieval diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py index 824ff088b3..e2793238c8 100644 --- a/lightrag/llm/azure_openai.py +++ b/lightrag/llm/azure_openai.py @@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache( base_url: str | None = None, api_key: str | None = None, api_version: str | None = None, + token_tracker: Any | None = None, **kwargs, ): if enable_cot: @@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache( ) if hasattr(response, "__aiter__"): + final_chunk_usage = None + accumulated_response = "" async def inner(): - async for chunk in response: - if len(chunk.choices) == 0: - continue - content = chunk.choices[0].delta.content - if content is None: - continue - if r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - yield content + nonlocal final_chunk_usage, accumulated_response + try: + async for chunk in response: + if len(chunk.choices) == 0: + continue + content = chunk.choices[0].delta.content + if content is None: + continue + accumulated_response += content + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # Check for usage in the last chunk + if hasattr(chunk, "usage") and chunk.usage is not None: + final_chunk_usage = chunk.usage + except Exception as e: + logger.error(f"Error in Azure OpenAI stream response: {str(e)}") + raise + finally: + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr( + final_chunk_usage, "completion_tokens", 0 + ), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI streaming token usage: {token_counts}") + elif token_tracker: + logger.debug( + "No usage information available in Azure OpenAI streaming response" + ) return inner() else: content = response.choices[0].message.content if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + + # Track 
token usage for non-streaming response + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "completion_tokens": getattr(response.usage, "completion_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}") + return content async def azure_openai_complete( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs + prompt, + system_prompt=None, + history_messages=[], + keyword_extraction=False, + token_tracker=None, + **kwargs, ) -> str: kwargs.pop("keyword_extraction", None) result = await azure_openai_complete_if_cache( @@ -123,6 +169,7 @@ async def azure_openai_complete( prompt, system_prompt=system_prompt, history_messages=history_messages, + token_tracker=token_tracker, **kwargs, ) return result @@ -142,6 +189,7 @@ async def azure_openai_embed( base_url: str | None = None, api_key: str | None = None, api_version: str | None = None, + token_tracker: Any | None = None, ) -> np.ndarray: deployment = ( os.getenv("AZURE_EMBEDDING_DEPLOYMENT") @@ -174,4 +222,14 @@ async def azure_openai_embed( response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" ) + + # Track token usage for embeddings if token tracker is provided + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI embedding token usage: {token_counts}") + return np.array([dp.embedding for dp in response.data]) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 16737341f0..f3f04cfde0 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -48,6 +48,7 @@ async def bedrock_complete_if_cache( aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: if enable_cot: @@ -155,6 +156,18 @@ async def stream_generator(): yield text # Handle other event types that might indicate stream end elif "messageStop" in event: + # Track token usage for streaming if token tracker is provided + if token_tracker and "usage" in event: + usage = event["usage"] + token_counts = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + token_tracker.add_usage(token_counts) + logging.debug( + f"Bedrock streaming token usage: {token_counts}" + ) break except Exception as e: @@ -228,6 +241,17 @@ async def stream_generator(): if not content or content.strip() == "": raise BedrockError("Received empty content from Bedrock API") + # Track token usage for non-streaming if token tracker is provided + if token_tracker and "usage" in response: + usage = response["usage"] + token_counts = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + token_tracker.add_usage(token_counts) + logging.debug(f"Bedrock non-streaming token usage: {token_counts}") + return content except Exception as e: @@ -239,7 +263,12 @@ async def stream_generator(): # Generic Bedrock completion function async def bedrock_complete( - prompt, system_prompt=None, history_messages=[], 
keyword_extraction=False, **kwargs + prompt, + system_prompt=None, + history_messages=[], + keyword_extraction=False, + token_tracker=None, + **kwargs, ) -> Union[str, AsyncIterator[str]]: kwargs.pop("keyword_extraction", None) model_name = kwargs["hashing_kv"].global_config["llm_model_name"] @@ -248,6 +277,7 @@ async def bedrock_complete( prompt, system_prompt=system_prompt, history_messages=history_messages, + token_tracker=token_tracker, **kwargs, ) return result @@ -265,6 +295,7 @@ async def bedrock_embed( aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, + token_tracker=None, ) -> np.ndarray: # Respect existing env; only set if a non-empty value is available access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py index 9274dbfc4b..2fe2561fc1 100644 --- a/lightrag/llm/lollms.py +++ b/lightrag/llm/lollms.py @@ -108,6 +108,7 @@ async def lollms_model_complete( history_messages=[], enable_cot: bool = False, keyword_extraction=False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: """Complete function for lollms model generation.""" @@ -135,7 +136,11 @@ async def lollms_model_complete( async def lollms_embed( - texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs + texts: List[str], + embed_model=None, + base_url="http://localhost:9600", + token_tracker=None, + **kwargs, ) -> np.ndarray: """ Generate embeddings for a list of texts using lollms server. @@ -144,6 +149,7 @@ async def lollms_embed( texts: List of strings to embed embed_model: Model name (not used directly as lollms uses configured vectorizer) base_url: URL of the lollms server + token_tracker: Optional token usage tracker for monitoring API usage **kwargs: Additional arguments passed to the request Returns: diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index b013496e95..01f5e06cf3 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -39,6 +39,7 @@ async def _ollama_model_if_cache( system_prompt=None, history_messages=[], enable_cot: bool = False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: if enable_cot: @@ -74,13 +75,47 @@ async def _ollama_model_if_cache( """cannot cache stream response and process reasoning""" async def inner(): + accumulated_response = "" try: async for chunk in response: - yield chunk["message"]["content"] + chunk_content = chunk["message"]["content"] + accumulated_response += chunk_content + yield chunk_content except Exception as e: logger.error(f"Error in stream response: {str(e)}") raise finally: + # Track token usage for streaming if token tracker is provided + if token_tracker: + # Estimate prompt tokens: roughly 4 characters per token for English text + prompt_text = "" + if system_prompt: + prompt_text += system_prompt + " " + prompt_text += ( + " ".join( + [msg.get("content", "") for msg in history_messages] + ) + + " " + ) + prompt_text += prompt + prompt_tokens = len(prompt_text) // 4 + ( + 1 if len(prompt_text) % 4 else 0 + ) + + # Estimate completion tokens from accumulated response + completion_tokens = len(accumulated_response) // 4 + ( + 1 if len(accumulated_response) % 4 else 0 + ) + total_tokens = prompt_tokens + completion_tokens + + token_tracker.add_usage( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + ) + try: await ollama_client._client.aclose() logger.debug("Successfully closed Ollama client for streaming") @@ 
-91,6 +126,35 @@ async def inner(): else: model_response = response["message"]["content"] + # Track token usage if token tracker is provided + # Note: Ollama doesn't provide token usage in chat responses, so we estimate + if token_tracker: + # Estimate prompt tokens: roughly 4 characters per token for English text + prompt_text = "" + if system_prompt: + prompt_text += system_prompt + " " + prompt_text += ( + " ".join([msg.get("content", "") for msg in history_messages]) + " " + ) + prompt_text += prompt + prompt_tokens = len(prompt_text) // 4 + ( + 1 if len(prompt_text) % 4 else 0 + ) + + # Estimate completion tokens from response + completion_tokens = len(model_response) // 4 + ( + 1 if len(model_response) % 4 else 0 + ) + total_tokens = prompt_tokens + completion_tokens + + token_tracker.add_usage( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + ) + """ If the model also wraps its thoughts in a specific tag, this information is not needed for the final @@ -126,6 +190,7 @@ async def ollama_model_complete( history_messages=[], enable_cot: bool = False, keyword_extraction=False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: keyword_extraction = kwargs.pop("keyword_extraction", None) @@ -138,11 +203,14 @@ async def ollama_model_complete( system_prompt=system_prompt, history_messages=history_messages, enable_cot=enable_cot, + token_tracker=token_tracker, **kwargs, ) -async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: +async def ollama_embed( + texts: list[str], embed_model, token_tracker=None, **kwargs +) -> np.ndarray: api_key = kwargs.pop("api_key", None) headers = { "Content-Type": "application/json", @@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: data = await ollama_client.embed( model=embed_model, input=texts, options=options ) + + # Track token usage if token tracker is provided + # Note: Ollama doesn't provide token usage in embedding responses, so we estimate + if token_tracker: + # Estimate tokens: roughly 4 characters per token for English text + total_chars = sum(len(text) for text in texts) + estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0) + token_tracker.add_usage( + { + "prompt_tokens": estimated_tokens, + "completion_tokens": 0, + "total_tokens": estimated_tokens, + } + ) + return np.array(data["embeddings"]) except Exception as e: logger.error(f"Error in ollama_embed: {str(e)}") diff --git a/lightrag/operate.py b/lightrag/operate.py index 0551fdb559..13715e9447 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2,6 +2,7 @@ from functools import partial import asyncio +import re import json import json_repair from typing import Any, AsyncIterator, overload, Literal @@ -11,6 +12,7 @@ logger, compute_mdhash_id, Tokenizer, + TokenTracker, is_float_regex, sanitize_and_normalize_extracted_text, pack_user_ass_to_openai_messages, @@ -125,6 +127,7 @@ async def _handle_entity_relation_summary( seperator: str, global_config: dict, llm_response_cache: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[str, bool]: """Handle entity relation description summary using map-reduce approach. 
@@ -187,6 +190,7 @@ async def _handle_entity_relation_summary( current_list, global_config, llm_response_cache, + token_tracker, ) return final_summary, True # LLM was used for final summarization @@ -242,6 +246,7 @@ async def _handle_entity_relation_summary( chunk, global_config, llm_response_cache, + token_tracker, ) new_summaries.append(summary) llm_was_used = True # Mark that LLM was used in reduce phase @@ -256,6 +261,7 @@ async def _summarize_descriptions( description_list: list[str], global_config: dict, llm_response_cache: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> str: """Helper function to summarize a list of descriptions using LLM. @@ -311,9 +317,10 @@ async def _summarize_descriptions( # Use LLM function with cache (higher priority for summary generation) summary, _ = await use_llm_func_with_cache( use_prompt, - use_llm_func, - llm_response_cache=llm_response_cache, + use_llm_func=use_llm_func, + hashing_kv=llm_response_cache, cache_type="summary", + token_tracker=token_tracker, ) return summary @@ -323,6 +330,7 @@ async def _handle_single_entity_extraction( chunk_key: str, timestamp: int, file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, ): if len(record_attributes) != 4 or "entity" not in record_attributes[0]: if len(record_attributes) > 1 and "entity" in record_attributes[0]: @@ -376,6 +384,7 @@ async def _handle_single_entity_extraction( source_id=chunk_key, file_path=file_path, timestamp=timestamp, + metadata=metadata, ) except ValueError as e: @@ -395,13 +404,14 @@ async def _handle_single_relationship_extraction( chunk_key: str, timestamp: int, file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, ): if ( len(record_attributes) != 5 or "relation" not in record_attributes[0] ): # treat "relationship" and "relation" interchangeable if len(record_attributes) > 1 and "relation" in record_attributes[0]: logger.warning( - f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`" + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" ) logger.debug(record_attributes) return None @@ -458,6 +468,7 @@ async def _handle_single_relationship_extraction( source_id=edge_source_id, file_path=file_path, timestamp=timestamp, + metadata=metadata, ) except ValueError as e: @@ -862,6 +873,7 @@ async def _process_extraction_result( file_path: str = "unknown_source", tuple_delimiter: str = "<|#|>", completion_delimiter: str = "<|COMPLETE|>", + metadata: dict[str, Any] | None = None, ) -> tuple[dict, dict]: """Process a single extraction result (either initial or gleaning) Args: @@ -943,7 +955,7 @@ async def _process_extraction_result( # Try to parse as entity entity_data = await _handle_single_entity_extraction( - record_attributes, chunk_key, timestamp, file_path + record_attributes, chunk_key, timestamp, file_path, metadata ) if entity_data is not None: maybe_nodes[entity_data["entity_name"]].append(entity_data) @@ -951,7 +963,7 @@ async def _process_extraction_result( # Try to parse as relationship relationship_data = await _handle_single_relationship_extraction( - record_attributes, chunk_key, timestamp, file_path + record_attributes, chunk_key, timestamp, file_path, metadata ) if relationship_data is not None: maybe_edges[ @@ -1295,6 +1307,7 @@ async 
def _merge_nodes_then_upsert( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, + metadata: dict[str, Any] | None = None, ): """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] @@ -1382,6 +1395,7 @@ async def _merge_nodes_then_upsert( description=description, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ) await knowledge_graph_inst.upsert_node( @@ -1402,6 +1416,7 @@ async def _merge_edges_then_upsert( pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, added_entities: list = None, # New parameter to track entities added during edge processing + metadata: dict | None = None, ): if src_id == tgt_id: return None @@ -1534,6 +1549,7 @@ async def _merge_edges_then_upsert( "description": description, "entity_type": "UNKNOWN", "file_path": file_path, + "metadata": metadata, # Add metadata here "created_at": int(time.time()), } await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) @@ -1546,6 +1562,7 @@ async def _merge_edges_then_upsert( "description": description, "source_id": source_id, "file_path": file_path, + "metadata": metadata, # Add metadata here "created_at": int(time.time()), } added_entities.append(entity_data) @@ -1559,6 +1576,7 @@ async def _merge_edges_then_upsert( keywords=keywords, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ), ) @@ -1570,6 +1588,7 @@ async def _merge_edges_then_upsert( keywords=keywords, source_id=source_id, file_path=file_path, + metadata=metadata, # Add metadata here created_at=int(time.time()), ) @@ -1591,6 +1610,7 @@ async def merge_nodes_and_edges( current_file_number: int = 0, total_files: int = 0, file_path: str = "unknown_source", + metadata: dict | None = None, # Added metadata parameter ) -> None: """Two-phase merge: process all entities first, then all relationships @@ -1614,6 +1634,7 @@ async def merge_nodes_and_edges( current_file_number: Current file number for logging total_files: Total files for logging file_path: File path for logging + metadata: Document metadata to be attached to entities and relationships """ # Collect all nodes and edges from all chunks @@ -1667,6 +1688,7 @@ async def _locked_process_entity_name(entity_name, entities): pipeline_status, pipeline_status_lock, llm_response_cache, + metadata, ) # Vector database operation (equally critical, must succeed) @@ -1682,6 +1704,7 @@ async def _locked_process_entity_name(entity_name, entities): "file_path": entity_data.get( "file_path", "unknown_source" ), + "metadata": metadata, } } @@ -1797,7 +1820,8 @@ async def _locked_process_edges(edge_key, edges): pipeline_status, pipeline_status_lock, llm_response_cache, - added_entities, # Pass list to collect added entities + added_entities, + metadata, ) if edge_data is None: @@ -1818,6 +1842,7 @@ async def _locked_process_edges(edge_key, edges): "file_path", "unknown_source" ), "weight": edge_data.get("weight", 1.0), + "metadata": metadata, } } @@ -1967,13 +1992,14 @@ async def _locked_process_edges(edge_key, edges): pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - # Update storage + # Update storage with metadata if final_entity_names: await full_entities_storage.upsert( { doc_id: { "entity_names": list(final_entity_names), "count": len(final_entity_names), + "metadata": 
metadata, # Add metadata here } } ) @@ -1986,6 +2012,7 @@ async def _locked_process_edges(edge_key, edges): list(pair) for pair in final_relation_pairs ], "count": len(final_relation_pairs), + "metadata": metadata, # Add metadata here } } ) @@ -2010,10 +2037,12 @@ async def _locked_process_edges(edge_key, edges): async def extract_entities( chunks: dict[str, TextChunkSchema], global_config: dict[str, str], + metadata: dict[str, Any] | None = None, pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -2047,6 +2076,55 @@ async def extract_entities( processed_chunks = 0 total_chunks = len(ordered_chunks) + '''async def _process_extraction_result( + result: str, + chunk_key: str, + file_path: str = "unknown_source", + metadata: dict[str, Any] | None = None, + ): + """Process a single extraction result (either initial or gleaning) + Args: + result (str): The extraction result to process + chunk_key (str): The chunk key for source tracking + file_path (str): The file path for citation + metadata (dict, optional): Additional metadata to include in extracted entities/relationships. + Returns: + tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships + """ + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) + + records = split_string_by_multi_markers( + result, + [context_base["record_delimiter"], context_base["completion_delimiter"]], + ) + + for record in records: + record = re.search(r"\((.*)\)", record) + if record is None: + continue + record = record.group(1) + record_attributes = split_string_by_multi_markers( + record, [context_base["tuple_delimiter"]] + ) + + if_entities = await _handle_single_entity_extraction( + record_attributes, chunk_key, file_path, metadata + ) + if if_entities is not None: + maybe_nodes[if_entities["entity_name"]].append(if_entities) + continue + + if_relation = await _handle_single_relationship_extraction( + record_attributes, chunk_key, file_path, metadata + ) + if if_relation is not None: + maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( + if_relation + ) + + return maybe_nodes, maybe_edges''' + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): """Process a single chunk Args: @@ -2078,49 +2156,53 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): final_result, timestamp = await use_llm_func_with_cache( entity_extraction_user_prompt, - use_llm_func, + use_llm_func=use_llm_func, system_prompt=entity_extraction_system_prompt, - llm_response_cache=llm_response_cache, + hashing_kv=llm_response_cache, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, + token_tracker=token_tracker, ) history = pack_user_ass_to_openai_messages( entity_extraction_user_prompt, final_result ) - # Process initial extraction with file path + # Process initial extraction with file path and metadata maybe_nodes, maybe_edges = await _process_extraction_result( final_result, chunk_key, timestamp, - file_path, + file_path=file_path, tuple_delimiter=context_base["tuple_delimiter"], completion_delimiter=context_base["completion_delimiter"], + metadata=metadata, ) # Process additional gleaning results only 1 time when entity_extract_max_gleaning is greater 
than zero. if entity_extract_max_gleaning > 0: glean_result, timestamp = await use_llm_func_with_cache( entity_continue_extraction_user_prompt, - use_llm_func, + use_llm_func=use_llm_func, system_prompt=entity_extraction_system_prompt, llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, + token_tracker=token_tracker, ) - # Process gleaning result separately with file path + # Process gleaning result separately with file path and metadata glean_nodes, glean_edges = await _process_extraction_result( glean_result, chunk_key, timestamp, - file_path, + file_path=file_path, tuple_delimiter=context_base["tuple_delimiter"], completion_delimiter=context_base["completion_delimiter"], + metadata=metadata, ) # Merge results - compare description lengths to choose better version @@ -2225,7 +2307,7 @@ async def _process_with_semaphore(chunk): await asyncio.wait(pending) # Add progress prefix to the exception message - progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]" + progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]" # Re-raise the original exception with a prefix prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) @@ -2236,6 +2318,41 @@ async def _process_with_semaphore(chunk): return chunk_results +@overload +async def kg_query( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + chunks_vdb: BaseVectorStorage = None, + return_raw_data: Literal[True] = False, + token_tracker: TokenTracker | None = None, +) -> dict[str, Any]: ... + + +@overload +async def kg_query( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + chunks_vdb: BaseVectorStorage = None, + metadata_filters: list | None = None, + return_raw_data: Literal[False] = False, + token_tracker: TokenTracker | None = None, +) -> str | AsyncIterator[str]: ... + + async def kg_query( query: str, knowledge_graph_inst: BaseGraphStorage, @@ -2247,6 +2364,7 @@ async def kg_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, chunks_vdb: BaseVectorStorage = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult: """ Execute knowledge graph query and return unified QueryResult object. 
@@ -2286,8 +2404,35 @@ async def kg_query( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) + # Handle cache + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, + query_param.metadata_filter, + ) + cached_result = await handle_cache( + hashing_kv, args_hash, query, query_param.mode, cache_type="query" + ) + if ( + cached_result is not None + and not query_param.only_need_context + and not query_param.only_need_prompt + ): + cached_response, _ = cached_result # Extract content, ignore timestamp + return QueryResult(content=cached_response) + hl_keywords, ll_keywords = await get_keywords_from_query( - query, query_param, global_config, hashing_kv + query, query_param, global_config, hashing_kv, token_tracker ) logger.debug(f"High-level keywords: {hl_keywords}") @@ -2391,6 +2536,7 @@ async def kg_query( history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): @@ -2448,6 +2594,7 @@ async def get_keywords_from_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Retrieves high-level and low-level keywords for RAG operations. @@ -2470,7 +2617,7 @@ async def get_keywords_from_query( # Extract keywords using extract_keywords_only function which already supports conversation history hl_keywords, ll_keywords = await extract_keywords_only( - query, query_param, global_config, hashing_kv + query, query_param, global_config, hashing_kv, token_tracker ) return hl_keywords, ll_keywords @@ -2480,6 +2627,7 @@ async def extract_keywords_only( param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. @@ -2533,7 +2681,9 @@ async def extract_keywords_only( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - result = await use_model_func(kw_prompt, keyword_extraction=True) + result = await use_model_func( + kw_prompt, keyword_extraction=True, token_tracker=token_tracker + ) # 5. 
Parse out JSON from the LLM response result = remove_think_tags(result) @@ -2611,7 +2761,10 @@ async def _get_vector_context( cosine_threshold = chunks_vdb.cosine_better_than_threshold results = await chunks_vdb.query( - query, top_k=search_top_k, query_embedding=query_embedding + query, + top_k=search_top_k, + query_embedding=query_embedding, + metadata_filter=query_param.metadata_filter, ) if not results: logger.info( @@ -2628,6 +2781,7 @@ async def _get_vector_context( "file_path": result.get("file_path", "unknown_source"), "source_type": "vector", # Mark the source type "chunk_id": result.get("id"), # Add chunk_id for deduplication + "metadata": result.get("metadata"), } valid_chunks.append(chunk_with_metadata) @@ -2677,13 +2831,22 @@ async def _perform_kg_search( query_embedding = None if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): embedding_func_config = text_chunks_db.embedding_func - if embedding_func_config and embedding_func_config.func: + if embedding_func_config: try: - query_embedding = await embedding_func_config.func([query]) - query_embedding = query_embedding[ - 0 - ] # Extract first embedding from batch result - logger.debug("Pre-computed query embedding for all vector operations") + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + embedding_func = embedding_func_config.func + else: + # It's a plain callable function + embedding_func = embedding_func_config + + if embedding_func: + query_embedding = await embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + logger.debug("Pre-computed query embedding for all vector operations") except Exception as e: logger.warning(f"Failed to pre-compute query embedding: {e}") query_embedding = None @@ -3393,7 +3556,9 @@ async def _get_node_data( f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" ) - results = await entities_vdb.query(query, top_k=query_param.top_k) + results = await entities_vdb.query( + query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter + ) if not len(results): return [], [] @@ -3401,6 +3566,9 @@ async def _get_node_data( # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] + # Extract all entity IDs from your results list + node_ids = [r["entity_name"] for r in results] + # Call the batch node retrieval and degree functions concurrently. 
nodes_dict, degrees_dict = await asyncio.gather( knowledge_graph_inst.get_nodes_batch(node_ids), @@ -3589,7 +3757,13 @@ async def _find_related_text_unit_from_entities( kg_chunk_pick_method = "WEIGHT" else: try: - actual_embedding_func = embedding_func_config.func + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + actual_embedding_func = embedding_func_config.func + else: + # It's a plain callable function + actual_embedding_func = embedding_func_config selected_chunk_ids = None if actual_embedding_func: @@ -3669,7 +3843,9 @@ async def _get_edge_data( f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" ) - results = await relationships_vdb.query(keywords, top_k=query_param.top_k) + results = await relationships_vdb.query( + keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter + ) if not len(results): return [], [] @@ -3882,7 +4058,13 @@ async def _find_related_text_unit_from_relations( kg_chunk_pick_method = "WEIGHT" else: try: - actual_embedding_func = embedding_func_config.func + # Handle both EmbeddingFunc objects and plain callable functions + from .utils import EmbeddingFunc + if isinstance(embedding_func_config, EmbeddingFunc): + actual_embedding_func = embedding_func_config.func + else: + # It's a plain callable function + actual_embedding_func = embedding_func_config if actual_embedding_func: selected_chunk_ids = await pick_by_vector_similarity( @@ -3963,6 +4145,7 @@ async def naive_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, return_raw_data: Literal[True] = True, + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: ... @@ -3975,6 +4158,7 @@ async def naive_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, return_raw_data: Literal[False] = False, + token_tracker: TokenTracker | None = None, ) -> str | AsyncIterator[str]: ... @@ -3985,6 +4169,7 @@ async def naive_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult: """ Execute naive query and return unified QueryResult object. @@ -4180,6 +4365,7 @@ async def naive_query( history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): diff --git a/lightrag/types.py b/lightrag/types.py index a18f2d3cd8..e61910b6fc 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,7 +1,7 @@ from __future__ import annotations -from pydantic import BaseModel -from typing import Any, Optional +from pydantic import BaseModel, Field, validator +from typing import Any, Optional, List, Union, Dict class GPTKeywordExtractionFormat(BaseModel): @@ -27,3 +27,46 @@ class KnowledgeGraph(BaseModel): nodes: list[KnowledgeGraphNode] = [] edges: list[KnowledgeGraphEdge] = [] is_truncated: bool = False + + +class MetadataFilter(BaseModel): + """ + Represents a logical expression for metadata filtering. 
+ + Args: + operator: "AND", "OR", or "NOT" + operands: List of either simple key-value pairs or nested MetadataFilter objects + """ + operator: str = Field(..., description="Logical operator: AND, OR, or NOT") + operands: List[Union[Dict[str, Any], 'MetadataFilter']] = Field(default_factory=list, description="List of operands for filtering") + + @validator('operator') + def validate_operator(cls, v): + if v not in ["AND", "OR", "NOT"]: + raise ValueError('operator must be one of: "AND", "OR", "NOT"') + return v + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation.""" + return { + "operator": self.operator, + "operands": [ + operand.dict() if isinstance(operand, MetadataFilter) else operand + for operand in self.operands + ] + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'MetadataFilter': + """Create from dictionary representation.""" + operands = [] + for operand in data.get("operands", []): + if isinstance(operand, dict) and "operator" in operand: + operands.append(cls.from_dict(operand)) + else: + operands.append(operand) + return cls(operator=data.get("operator", "AND"), operands=operands) + + class Config: + """Pydantic configuration.""" + validate_assignment = True diff --git a/lightrag/utils.py b/lightrag/utils.py index 60542e43eb..bfea03395d 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -4,6 +4,7 @@ import asyncio import html import csv +import contextvars import json import logging import logging.handlers @@ -507,6 +508,7 @@ async def worker(): task_id, args, kwargs, + ctx, ) = await asyncio.wait_for(queue.get(), timeout=1.0) except asyncio.TimeoutError: continue @@ -536,11 +538,15 @@ async def worker(): try: # Execute function with timeout protection if max_execution_timeout is not None: + # Run the function in the captured context + task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs))) result = await asyncio.wait_for( - func(*args, **kwargs), timeout=max_execution_timeout + task, timeout=max_execution_timeout ) else: - result = await func(*args, **kwargs) + # Run the function in the captured context + task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs))) + result = await task # Set result if future is still valid if not task_state.future.done(): @@ -791,6 +797,9 @@ async def wait_func( future=future, start_time=asyncio.get_event_loop().time() ) + # Capture current context + ctx = contextvars.copy_context() + try: # Register task state async with task_states_lock: @@ -809,13 +818,13 @@ async def wait_func( if _queue_timeout is not None: await asyncio.wait_for( queue.put( - (_priority, current_count, task_id, args, kwargs) + (_priority, current_count, task_id, args, kwargs, ctx) ), timeout=_queue_timeout, ) else: await queue.put( - (_priority, current_count, task_id, args, kwargs) + (_priority, current_count, task_id, args, kwargs, ctx) ) except asyncio.TimeoutError: raise QueueFullError( @@ -1472,8 +1481,7 @@ async def aexport_data( else: raise ValueError( - f"Unsupported file format: {file_format}. " - f"Choose from: csv, excel, md, txt" + f"Unsupported file format: {file_format}. 
Choose from: csv, excel, md, txt" ) if file_format is not None: print(f"Data exported to: {output_path} with format: {file_format}") @@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache( cache_type: str = "extract", chunk_id: str | None = None, cache_keys_collector: list = None, + hashing_kv: "BaseKVStorage | None" = None, + token_tracker=None, ) -> tuple[str, int]: """Call LLM function with cache support and text sanitization @@ -1684,8 +1694,15 @@ async def use_llm_func_with_cache( if max_tokens is not None: kwargs["max_tokens"] = max_tokens + # Note: token_tracker is NOT passed here because the LLM wrapper functions + # (e.g., optimized_openai_alike_model_complete in lightrag_server.py) + # already handle token_tracker directly. Passing it here would cause + # "got multiple values for keyword argument 'token_tracker'" error. + res: str = await use_llm_func( - safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + safe_user_prompt, + system_prompt=safe_system_prompt, + **kwargs, ) res = remove_think_tags(res) @@ -1718,9 +1735,16 @@ async def use_llm_func_with_cache( if max_tokens is not None: kwargs["max_tokens"] = max_tokens + # Note: token_tracker is NOT passed here because the LLM wrapper functions + # (e.g., optimized_openai_alike_model_complete in lightrag_server.py) + # already handle token_tracker directly. Passing it here would cause + # "got multiple values for keyword argument 'token_tracker'" error. + try: res = await use_llm_func( - safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + safe_user_prompt, + system_prompt=safe_system_prompt, + **kwargs, ) except Exception as e: # Add [LLM func] prefix to error message @@ -2216,52 +2240,74 @@ async def pick_by_vector_similarity( return all_chunk_ids[:num_of_chunks] +from contextvars import ContextVar + + class TokenTracker: - """Track token usage for LLM calls.""" + """Track token usage for LLM calls using ContextVars for concurrency support.""" + + _usage_var: ContextVar[dict] = ContextVar("token_usage", default=None) def __init__(self): - self.reset() + # No instance state needed as we use ContextVar + pass def __enter__(self): self.reset() return self def __exit__(self, exc_type, exc_val, exc_tb): - print(self) + # Optional: Log usage on exit if needed + pass def reset(self): - self.prompt_tokens = 0 - self.completion_tokens = 0 - self.total_tokens = 0 - self.call_count = 0 + """Initialize/Reset token usage for the current context.""" + self._usage_var.set( + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } + ) - def add_usage(self, token_counts): + def _get_current_usage(self) -> dict: + """Get the usage dict for the current context, initializing if necessary.""" + usage = self._usage_var.get() + if usage is None: + usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } + self._usage_var.set(usage) + return usage + + def add_usage(self, token_counts: dict): """Add token usage from one LLM call. 
Args: token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens """ - self.prompt_tokens += token_counts.get("prompt_tokens", 0) - self.completion_tokens += token_counts.get("completion_tokens", 0) + usage = self._get_current_usage() + + usage["prompt_tokens"] += token_counts.get("prompt_tokens", 0) + usage["completion_tokens"] += token_counts.get("completion_tokens", 0) # If total_tokens is provided, use it directly; otherwise calculate the sum if "total_tokens" in token_counts: - self.total_tokens += token_counts["total_tokens"] + usage["total_tokens"] += token_counts["total_tokens"] else: - self.total_tokens += token_counts.get( + usage["total_tokens"] += token_counts.get( "prompt_tokens", 0 ) + token_counts.get("completion_tokens", 0) - self.call_count += 1 + usage["call_count"] += 1 def get_usage(self): """Get current usage statistics.""" - return { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens, - "call_count": self.call_count, - } + return self._get_current_usage().copy() def __str__(self): usage = self.get_usage() @@ -2273,6 +2319,26 @@ def __str__(self): ) +def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int: + """Estimate tokens for embedding operations based on text length. + + Most embedding APIs don't return token counts, so we estimate based on + the tokenizer encoding. This provides a reasonable approximation for tracking. + + Args: + texts: List of text strings to be embedded + tokenizer: Tokenizer instance for encoding + + Returns: + Estimated total token count for all texts + """ + total = 0 + for text in texts: + if text: # Skip empty strings + total += len(tokenizer.encode(text)) + return total + + async def apply_rerank_if_enabled( query: str, retrieved_docs: list[dict], diff --git a/pyproject.toml b/pyproject.toml index e850ce2c09..7e8d70f498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "future", "json_repair", "nano-vectordb", + "neo4j>=5.28.2", "networkx", "numpy", "pandas>=2.0.0",
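
For reference, a minimal end-to-end sketch of how the new hooks introduced in this diff might be wired together. This is illustrative only: the `LightRAG` constructor arguments (working directory, LLM/embedding configuration, storage initialization) are assumptions and are omitted here, not part of this change.

```python
# Illustrative sketch only. Shows the new token_tracker parameter on
# ainsert()/aquery() and the metadata argument on the enqueue pipeline,
# as added in this diff. Constructor/config details are assumed.
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.utils import TokenTracker


async def main() -> None:
    # Storage, LLM, and embedding configuration omitted (assumed elsewhere).
    rag = LightRAG(working_dir="./rag_storage")
    tracker = TokenTracker()

    # ainsert() now returns (track_id, usage) when a tracker is supplied.
    track_id, usage = await rag.ainsert(
        "Q1 engineering report ...",
        file_paths="reports/q1.md",
        token_tracker=tracker,
    )
    print(track_id, usage["total_tokens"])

    # To attach custom metadata, the pipeline functions can be called
    # directly; the metadata is persisted in the document status and
    # propagated to chunks, entities, and relationships during extraction.
    await rag.apipeline_enqueue_documents(
        "Q2 engineering report ...",
        file_paths="reports/q2.md",
        metadata={"department": "Engineering", "version": "1.0"},
    )
    await rag.apipeline_process_enqueue_documents(token_tracker=tracker)

    # Query-time LLM calls can reuse the same tracker; usage accumulates
    # per context via the ContextVar-backed TokenTracker.
    answer = await rag.aquery(
        "Summarize the engineering reports",
        param=QueryParam(mode="mix"),
        token_tracker=tracker,
    )
    print(answer)
    print(tracker.get_usage())  # cumulative prompt/completion/total tokens


if __name__ == "__main__":
    asyncio.run(main())
```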