diff --git a/.github/workflows/pipeline-tests.yml b/.github/workflows/pipeline-tests.yml new file mode 100644 index 0000000..6b5ea2a --- /dev/null +++ b/.github/workflows/pipeline-tests.yml @@ -0,0 +1,128 @@ +name: Pipeline Tests + +on: + push: + branches: [ main, development ] + pull_request: + branches: [ main, development ] + +jobs: + test-minimal: + name: Minimal Pipeline Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: false # avoid .gitmodules errors + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.13' + + - name: Install dependencies + run: pip install -r tests/requirements-minimal.txt + + - name: Run minimal pipeline tests (no external services) + run: python -m pytest tests/pipeline/test_minimal_pipeline.py tests/pipeline/test_components.py -v --tb=short + + - name: Run configuration validation + run: | + python -c " + import yaml, sys + configs = ['config.yml', 'pipelines/configs/retrieval/ci_google_gemini.yml'] + for config in configs: + try: + with open(config) as f: + yaml.safe_load(f) + print(f'{config} is valid') + except Exception as e: + print(f'{config} failed: {e}') + sys.exit(1) + " + + test-integration: + name: Integration Tests with Qdrant + runs-on: ubuntu-latest + + services: + qdrant: + image: qdrant/qdrant:latest + ports: + - 6333:6333 # REST + - 6334:6334 # gRPC + env: + QDRANT__SERVICE__HTTP_PORT: 6333 + QDRANT__SERVICE__GRPC_PORT: 6334 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: false + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.13' + + - name: Install dependencies + run: pip install -r tests/requirements-minimal.txt + + - name: Wait for Qdrant to be ready (readiness + API check) + run: | + for i in {1..60}; do + if curl -fsS http://127.0.0.1:6333/readyz > /dev/null && \ + curl -fsS http://127.0.0.1:6333/collections > /dev/null; then + echo "Qdrant is ready!" + exit 0 + fi + echo "Waiting for Qdrant ($i/60)..." + sleep 2 + done + echo "Qdrant did not become ready in time" + exit 1 + + - name: Test Qdrant connectivity + run: python -m pytest tests/pipeline/test_qdrant_connectivity.py -v --tb=short + + - name: Run basic integration tests + run: python -m pytest tests/pipeline/ -v --tb=short -m "not requires_api" + + test-security: + name: Security and Config Validation + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: false + + - name: Check for hardcoded secrets + run: | + if grep -r "sk-" . --exclude-dir=.git --exclude="*.md" --exclude="*.yml"; then + echo "Found potential hardcoded API keys" + exit 1 + fi + if grep -r "google_api_key.*=" . --exclude-dir=.git --exclude="*.md" --exclude="*.yml" | grep -v "getenv\|environ"; then + echo "Found potential hardcoded Google API keys" + exit 1 + fi + echo "No hardcoded secrets found" + + - name: Validate configuration structure + run: | + python -c " + import yaml + with open('pipelines/configs/retrieval/ci_google_gemini.yml') as f: + config = yaml.safe_load(f) + assert 'retrieval_pipeline' in config + assert 'retriever' in config['retrieval_pipeline'] + assert 'embedding' in config['retrieval_pipeline']['retriever'] + assert 'google' == config['retrieval_pipeline']['retriever']['embedding']['dense']['provider'] + assert 'GOOGLE_API_KEY' == config['retrieval_pipeline']['retriever']['embedding']['dense']['api_key_env'] + print('Configuration structure is valid') + " diff --git a/.gitignore b/.gitignore index a8c09d1..a09d47b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,14 @@ climate-fever *.log __pycache__ sandbox/* -/__pycache__ \ No newline at end of file +/__pycache__ +synthetic_dataset\text_dataset_template.json +extraction_output/ +.idea/misc.xml +.idea/modules.xml +.idea/Thesis.iml +.idea/vcs.xml +.idea/inspectionProfiles/profiles_settings.xml +*.json +*.csv + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..219439b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +# Use a slim Python base image +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Install system packages if needed +RUN apt-get update && apt-get install -y \ + build-essential \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the full source code +COPY . . diff --git a/README.md b/README.md new file mode 100644 index 0000000..42c6a5e --- /dev/null +++ b/README.md @@ -0,0 +1,266 @@ +# Advanced RAG Retrieval System with LangGraph Agent + +A production-ready, modular RAG (Retrieval-Augmented Generation) system with configurable pipelines and LangGraph agent integration. + +## Key Features + +- **YAML-Configurable Pipelines**: Switch retrieval strategies without code changes +- **LangGraph Agent Integration**: Seamless agent workflows with rich metadata +- **Modular Components**: Easily extensible rerankers, filters, and retrievers +- **Multiple Retrieval Methods**: Dense, sparse, and hybrid retrieval +- **Production Ready**: Robust error handling, logging, and monitoring +- **A/B Testing Support**: Compare configurations easily +- **Rich Metadata**: Access scores, methods, and quality metrics + +## Architecture Overview + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ LangGraph │────│ Configurable │────│ Retrieval │ +│ Agent │ │ Retriever Agent │ │ Pipeline │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ + ┌────────────────────────────────┼────────────────────────────────┐ + │ │ │ + ┌─────▼─────┐ ┌───────▼────────┐ ┌─────▼─────┐ + │ Retrievers │ │ Rerankers │ │ Filters │ + │ │ │ │ │ │ + │• Dense │ │• CrossEncoder │ │• Score │ + │• Sparse │ │• BGE Reranker │ │• Content │ + │• Hybrid │ │• Multi-stage │ │• Custom │ + └───────────┘ └────────────────┘ └───────────┘ +``` + +## Quick Start + +### 1. Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### 2. Configure Environment + +```bash +# Copy example config +cp config.yml.example config.yml + +# Set up your API keys and database connections in config.yml +``` + +### 3. Start Using the System + +```python +# main.py - Chat with your agent +from agent.graph import graph + +state = {"question": "How to handle Python exceptions?"} +result = graph.invoke(state) +print(result["answer"]) +``` + +### 4. Switch Retrieval Configurations + +```bash +# List available configurations +python bin/switch_agent_config.py --list + +# Switch to advanced reranked pipeline +python bin/switch_agent_config.py advanced_reranked + +# Test the configuration +python test_agent_retriever_node.py +``` + +## Available Configurations + +| Configuration | Description | Components | Use Case | +|---------------|-------------|------------|----------| +| `basic_dense` | Simple dense retrieval | Dense retriever only | Development, testing | +| `advanced_reranked` | Production quality | Dense + CrossEncoder + filters | Production RAG | +| `hybrid_multistage` | Best performance | Hybrid + multi-stage reranking | High-quality results | +| `experimental` | Latest features | BGE reranker + custom filters | Experimentation | + +## 🔧 **Configuration Example** + +```yaml +# pipelines/configs/retrieval/advanced_reranked.yml +retrieval_pipeline: + retriever: + type: dense + top_k: 10 + + stages: + - type: reranker + config: + model_type: cross_encoder + model_name: "ms-marco-MiniLM-L-6-v2" + + - type: filter + config: + type: score + min_score: 0.5 + + - type: answer_enhancer + config: + boost_factor: 2.0 +``` + +## Project Structure + +``` +Thesis/ +├── agent/ # LangGraph agent implementation +│ ├── graph.py # Main agent graph +│ ├── schema.py # Agent state schemas +│ └── nodes/ # Agent nodes (retriever, generator, etc.) +│ +├── components/ # Modular retrieval components +│ ├── retrieval_pipeline.py # Main pipeline orchestrator +│ ├── rerankers.py # Reranking implementations +│ ├── filters.py # Filtering implementations +│ └── advanced_rerankers.py # Advanced reranking strategies +│ +├── pipelines/ # Data processing and configuration +│ ├── configs/retrieval/ # Retrieval pipeline configurations +│ ├── adapters/ # Dataset adapters (BEIR, etc.) +│ └── ingest/ # Data ingestion pipeline +│ +├── bin/ # Command-line utilities +│ ├── switch_agent_config.py # Configuration management +│ ├── agent_retriever.py # Configurable retriever agent +│ └── retrieval_pipeline.py # Direct pipeline usage +│ +├── docs/ # Documentation +│ ├── SYSTEM_EXTENSION_GUIDE.md # Complete extension guide +│ ├── AGENT_INTEGRATION.md # Agent integration details +│ ├── CODE_CLEANUP_SUMMARY.md # Code cleanup documentation +│ └── EXTENSIBILITY.md # Quick extensibility overview +│ +├── tests/ # Test suite +│ ├── retrieval/ # Retrieval pipeline tests +│ └── agent/ # Agent integration tests +│ +├── deprecated/ # Legacy code (organized) +│ ├── old_processors/ # Superseded by new pipeline +│ ├── old_debug_scripts/ # Legacy debugging tools +│ └── old_playground/ # Legacy test scripts +│ +├── database/ # Database controllers +├── embedding/ # Embedding utilities +├── retrievers/ # Base retrievers +├── examples/ # Usage examples +└── config/ # Configuration utilities +``` + +## Testing + +```bash +# Test agent integration +python test_agent_retriever_node.py + +# Run all tests +python tests/run_all_tests.py + +# Test specific components +python -m pytest tests/retrieval/ -v +``` + +## Documentation + +- **[System Extension Guide](docs/SYSTEM_EXTENSION_GUIDE.md)** - Complete guide to extending the system +- **[Agent Integration](docs/AGENT_INTEGRATION.md)** - How the agent uses configurable pipelines +- **[Code Cleanup Summary](docs/CODE_CLEANUP_SUMMARY.md)** - Professional code standards and cleanup details +- **[Extensibility Overview](docs/EXTENSIBILITY.md)** - Quick overview of extension capabilities +- **[Architecture](docs/MLOPS_PIPELINE_ARCHITECTURE.md)** - System architecture details + +## Extending the System + +### Add a Custom Reranker + +```python +# components/my_reranker.py +from .rerankers import BaseReranker + +class MyCustomReranker(BaseReranker): + def rerank(self, query: str, documents: List[Document]) -> List[Document]: + # Your custom reranking logic + for doc in documents: + doc.metadata["score"] = self.calculate_score(query, doc.page_content) + + return sorted(documents, key=lambda x: x.metadata["score"], reverse=True) +``` + +### Create a New Configuration + +```yaml +# pipelines/configs/retrieval/my_config.yml +retrieval_pipeline: + retriever: + type: hybrid + top_k: 15 + + stages: + - type: reranker + config: + model_type: my_custom + custom_param: "value" +``` + +### Switch and Test + +```bash +python bin/switch_agent_config.py my_config +python test_agent_retriever_node.py +``` + +## Production Usage + +The system is designed for production use with: + +- **Robust Error Handling**: Graceful degradation when components fail +- **Comprehensive Logging**: Monitor retrieval performance and quality +- **Configuration Management**: Easy deployment of different strategies +- **Performance Optimization**: Efficient batching and caching support +- **Monitoring Ready**: Built-in metrics and health checks + +## Use Cases + +- **Document Q&A Systems**: High-quality retrieval for knowledge bases +- **Research Assistants**: Multi-modal retrieval for academic content +- **Customer Support**: Context-aware response generation +- **Code Search**: Semantic search over codebases +- **Legal Research**: Precise retrieval from legal documents + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Add your extension following the patterns in `docs/SYSTEM_EXTENSION_GUIDE.md` +4. Add tests for your components +5. Submit a pull request + +## Performance + +The system supports various performance optimization strategies: + +- **Caching**: LRU caching for repeated queries +- **Batching**: Efficient batch processing for rerankers +- **Adaptive Top-K**: Dynamic result count based on query complexity +- **Multi-threading**: Parallel processing for pipeline stages + +## Migration from Legacy + +If you have existing code using the deprecated `processors/` system: + +1. Check `deprecated/old_processors/` for reference +2. Use the new pipeline configurations in `pipelines/configs/retrieval/` +3. Follow the migration patterns in `docs/AGENT_INTEGRATION.md` + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. + +--- + +**Ready to build amazing RAG systems?** Start with the [System Extension Guide](docs/SYSTEM_EXTENSION_GUIDE.md)! diff --git a/playground/__init__.py b/agent/__init__.py similarity index 100% rename from playground/__init__.py rename to agent/__init__.py diff --git a/agent/graph.py b/agent/graph.py new file mode 100644 index 0000000..a3c31b9 --- /dev/null +++ b/agent/graph.py @@ -0,0 +1,42 @@ +from langgraph.graph import StateGraph +from agent.nodes.query_interpreter import make_query_interpreter +from agent.nodes.retriever import make_configurable_retriever +from agent.nodes.generator import make_generator +from agent.nodes.memory_updater import memory_updater +from agent.schema import AgentState +from config.config_loader import load_config +from langchain_openai import ChatOpenAI + +# Load config +config = load_config("config.yml") + +# Setup LLM +llm_cfg = config["llm"] +llm = ChatOpenAI(model=llm_cfg.get("model", "gpt-4.1-mini"), + temperature=llm_cfg.get("temperature", 0.0)) + +# Setup configurable retriever node +retrieval_config_path = config.get("agent_retrieval", {}).get( + "config_path", "pipelines/configs/retrieval/modern_hybrid.yml") +retriever = make_configurable_retriever(config_path=retrieval_config_path) + +# Setup other nodes +generator = make_generator(llm) +query_interpreter = make_query_interpreter(llm) + +# Build the graph +builder = StateGraph(AgentState) +builder.add_node("query_interpreter", query_interpreter) +builder.add_node("retriever", retriever) +builder.add_node("generator", generator) +builder.add_node("memory_updater", memory_updater) +builder.set_entry_point("query_interpreter") + +builder.add_conditional_edges("query_interpreter", lambda state: state["next_node"], { + "retriever": "retriever", + "generator": "generator", +}) + +builder.add_edge("retriever", "generator") +builder.add_edge("generator", "memory_updater") +graph = builder.compile() diff --git a/agent/nodes/generator.py b/agent/nodes/generator.py new file mode 100644 index 0000000..e3cc4c6 --- /dev/null +++ b/agent/nodes/generator.py @@ -0,0 +1,57 @@ +from typing import Dict, Any +from langchain_core.prompts import PromptTemplate +from logs.utils.logger import get_logger + +from langchain_openai import ChatOpenAI + +logger = get_logger("generator") + +generator_prompt = PromptTemplate.from_template( + """You are a helpful assistant. Use the given context to answer the user's question. + +Context: +{context} + +Question: +{question} + +Answer in clear, professional natural language. +""" +) + + +def make_generator(llm): + """ + Returns a generator node function with the provided LLM injected. + """ + def generator(state: Dict[str, Any]) -> Dict[str, Any]: + question = state["question"] + + if "context" in state: + context = state["context"] + logger.info("[Generator] Generating from retrieved context.") + elif "answer" in state: + context = state["answer"] + logger.info("[Generator] Generating from existing answer.") + else: + context = "No information available." + logger.warning("[Generator] No context available.") + + try: + prompt = generator_prompt.format( + context=context, question=question) + response = llm.invoke(prompt) + final_answer = response.content.strip() + + logger.info("[Generator] Answer generated.") + return { + **state, + "answer": final_answer + } + except Exception as e: + logger.error(f"[Generator] LLM invocation failed: {str(e)}") + return { + **state, + "answer": "I'm sorry, I couldn't generate an answer due to an internal error." + } + return generator diff --git a/agent/nodes/memory_updater.py b/agent/nodes/memory_updater.py new file mode 100644 index 0000000..7b2049e --- /dev/null +++ b/agent/nodes/memory_updater.py @@ -0,0 +1,12 @@ +from typing import Dict, Any +AgentState = Dict[str, Any] + + +def memory_updater(state: AgentState) -> AgentState: + history = state.get("chat_history", []) + question = state.get("question", "") + answer = state.get("answer", "") + + history += [f"User: {question}", f"Assistant: {answer}"] + state["chat_history"] = history[-20:] + return state diff --git a/agent/nodes/query_interpreter.py b/agent/nodes/query_interpreter.py new file mode 100644 index 0000000..4f1cae3 --- /dev/null +++ b/agent/nodes/query_interpreter.py @@ -0,0 +1,93 @@ +import logging +import json +from pathlib import Path +from typing import Dict, Any +from datetime import datetime, timezone + +from langchain_core.runnables import Runnable +from langchain_core.prompts import PromptTemplate +from logs.utils.logger import get_logger + +logger = get_logger(__name__) + +# Optional: Move prompt to config for full modularity. +QUERY_INTERPRETER_PROMPT = """ +You are a planner agent for a modular RAG pipeline. + +Your job is to: +1. Understand the user's intent. +2. Decide if the answer requires accessing unstructured document chunks or can be answered directly without retrieval. +3. Output a plan, a query type ("text" or "none"), and a next_node to route to. + +Today's date is: {reference_date} + +Question: {question} + +--- + +Respond in valid JSON using this format: +{{ + "query_type": "text" | "none", + "next_node": "retriever" | "generator", + "plan": ["Step 1: ...", "Step 2: ..."], + "reasoning": "..." +}} + +Examples: +# Example 1: Direct answer (no retrieval needed) +{{ + "query_type": "none", + "next_node": "generator", + "plan": ["Recognize this as a chitchat or general info question.", "Answer directly."], + "reasoning": "No retrieval required for this question." +}} + +# Example 2: Document retrieval needed +{{ + "query_type": "text", + "next_node": "retriever", + "plan": ["Detect that document search is needed.", "Route to retriever."], + "reasoning": "This question requires searching through documents." +}} +""" + + +def make_query_interpreter(llm): + """ + Factory to return a query_interpreter node with the provided LLM. + """ + def query_interpreter(state: Dict[str, Any]) -> Dict[str, Any]: + question = state["question"] + reference_date = state.get("reference_date") or datetime.now( + timezone.utc).date().isoformat() + + prompt = QUERY_INTERPRETER_PROMPT.format( + question=question, reference_date=reference_date + ) + + try: + response = llm.invoke(prompt) + content = response.content.strip() + parsed = json.loads(content) + except Exception as e: + logger.error("Error in query_interpreter: %s", str(e)) + logger.error("LLM response: %s", + response.content if 'response' in locals() else 'None') + parsed = { + "query_type": "none", + "next_node": "generator", + "plan": ["Fallback to generator: answer directly."], + "reasoning": "Failed to parse model response. Defaulting to direct answer." + } + + logger.info("=== Question ===") + logger.info(question) + logger.info("=== Parsed Result ===") + logger.info(json.dumps(parsed, indent=2)) + + return { + "question": question, + "reference_date": reference_date, + **parsed + } + return query_interpreter diff --git a/agent/nodes/retriever.py b/agent/nodes/retriever.py new file mode 100644 index 0000000..55175c6 --- /dev/null +++ b/agent/nodes/retriever.py @@ -0,0 +1,153 @@ +from typing import Dict, Any, List +from langchain_core.documents import Document +from logs.utils.logger import get_logger +from bin.agent_retriever import ConfigurableRetrieverAgent + +logger = get_logger(__name__) + + +def make_configurable_retriever(config_path: str = None, cache_pipeline: bool = True): + """ + Factory to return a configurable retriever node. + + Args: + config_path (str, optional): Path to YAML configuration file for retrieval pipeline. + If None, will load from main config.yml + cache_pipeline (bool): Whether to cache the pipeline for reuse + + Returns: + function: Retriever node function that can be used in LangGraph agent + """ + # Load config path from main config if not provided + if config_path is None: + from config.config_loader import load_config + main_config = load_config() + config_path = main_config.get("agent_retrieval", {}).get("config_path", + "pipelines/configs/retrieval/modern_hybrid.yml") + + # Initialize the configurable retriever agent + agent = ConfigurableRetrieverAgent(config_path, cache_pipeline) + + # Log configuration info + config_info = agent.get_config_info() + logger.info(f"[Retriever] Initialized with config: {config_path}") + logger.info( + f"[Retriever] Pipeline: {config_info['retriever_type']} with {config_info['num_stages']} stages") + logger.info( + f"[Retriever] Components: {', '.join(config_info['stage_types'])}") + + def retriever(state: Dict[str, Any]) -> Dict[str, Any]: + """ + Retriever node function for LangGraph agent. + + Args: + state (Dict[str, Any]): Current agent state containing question and other context + + Returns: + Dict[str, Any]: Updated state with retrieved documents and metadata + """ + query = state["question"] + logger.info(f"[Retriever] Query: {query}") + + try: + # Get retrieval configuration + top_k = state.get("retrieval_top_k", + config_info.get("retriever_top_k", 5)) + + # Retrieve documents using configurable pipeline + docs_info = agent.retrieve(query, top_k=top_k) + + # Convert to context string and preserve metadata + context_parts = [] + retrieved_docs = [] + + for doc_info in docs_info: + # Add to context + content = doc_info["content"] + context_parts.append(content) + + # Create Document object with metadata + doc = Document( + page_content=content, + metadata={ + "score": doc_info["score"], + "retrieval_method": doc_info["retrieval_method"], + "question_title": doc_info["question_title"], + "tags": doc_info["tags"], + "external_id": doc_info["external_id"], + "enhanced": doc_info["enhanced"], + "answer_quality": doc_info["answer_quality"] + } + ) + retrieved_docs.append(doc) + + context = "\n\n".join(context_parts) + + logger.info( + f"[Retriever] Retrieved {len(docs_info)} documents using {config_info['retriever_type']}") + logger.info( + f"[Retriever] Pipeline components: {', '.join(config_info['stage_types'])}") + + # Return enhanced state with retrieval metadata + return { + **state, + "context": context, + "retrieved_documents": retrieved_docs, + "retrieval_metadata": { + "num_results": len(docs_info), + "retrieval_method": docs_info[0]["retrieval_method"] if docs_info else "none", + "pipeline_config": config_info, + "top_result_score": docs_info[0]["score"] if docs_info else 0.0 + } + } + + except Exception as e: + logger.error(f"[Retriever] Retrieval failed: {str(e)}") + return { + **state, + "context": "", + "error": f"Retriever failed: {str(e)}", + "retrieval_metadata": { + "num_results": 0, + "error": str(e) + } + } + + return retriever + + +def make_retriever(db, dense_embedder, sparse_embedder, top_k=5, strategy=None): + """ + Legacy retriever factory for backward compatibility. + Consider migrating to make_configurable_retriever for better flexibility. + """ + def retriever(state: Dict[str, Any]) -> Dict[str, Any]: + query = state["question"] + logger.info(f"[Retriever] Query: {query}") + if strategy: + logger.info(f"[Retriever] Retrieval strategy: {strategy}") + + try: + vectorstore = db.as_langchain_vectorstore( + dense_embedding=dense_embedder, + sparse_embedding=sparse_embedder, + ) + + docs: List[Document] = vectorstore.similarity_search( + query, k=top_k) + context = "\n\n".join([doc.page_content for doc in docs]) + + logger.info(f"[Retriever] Retrieved {len(docs)} documents.") + return { + **state, + "context": context + } + + except Exception as e: + logger.error(f"[Retriever] Retrieval failed: {str(e)}") + return { + **state, + "context": "", + "error": f"Retriever failed: {str(e)}" + } + return retriever diff --git a/agent/schema.py b/agent/schema.py new file mode 100644 index 0000000..9dfb4eb --- /dev/null +++ b/agent/schema.py @@ -0,0 +1,36 @@ +from typing import TypedDict, List, Optional, Union, Dict, Any +from langchain_core.documents import Document + + +class AgentState(TypedDict, total=False): + """ + Agent state schema that defines all possible state variables for the LangGraph agent. + + Attributes: + question (str): The user's input question + reference_date (str): Reference date for temporal queries + next_node (str): Next node to execute in the agent graph + context (str, optional): Contextual information for response generation + answer (str, optional): Final answer to return to user + chat_history (List[str]): Previous conversation history + + # Enhanced retrieval fields for configurable pipeline integration + retrieved_documents (List[Document], optional): Full document objects with metadata + retrieval_metadata (Dict[str, Any], optional): Pipeline info, scores, method details + retrieval_top_k (int, optional): Override default top_k for dynamic result count + error (str, optional): Error messages from any processing stage + """ + question: str + reference_date: str + next_node: str + context: Optional[str] + answer: Optional[str] + chat_history: List[str] + + # Enhanced retrieval fields + # Full document objects with metadata + retrieved_documents: Optional[List[Document]] + # Pipeline info, scores, etc. + retrieval_metadata: Optional[Dict[str, Any]] + retrieval_top_k: Optional[int] # Override default top_k + error: Optional[str] # Error messages diff --git a/benchmark/bedrock.py b/benchmark/bedrock.py deleted file mode 100644 index edd3c8a..0000000 --- a/benchmark/bedrock.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -""" -Shows how to create a list of action items from a meeting transcript -with the Amazon Titan Text model (on demand). -""" -import json -import logging -import boto3 - -from botocore.exceptions import ClientError - - -class ImageError(Exception): - "Custom exception for errors returned by Amazon Titan Text models" - - def __init__(self, message): - self.message = message - - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def generate_text(model_id, body): - """ - Generate text using Amazon Titan Text models on demand. - Args: - model_id (str): The model ID to use. - body (str) : The request body to use. - Returns: - response (json): The response from the model. - """ - - logger.info( - "Generating text with Amazon Titan Text model %s", model_id) - - bedrock = boto3.client(service_name='bedrock-runtime') - - accept = "application/json" - content_type = "application/json" - - response = bedrock.invoke_model( - body=body, modelId=model_id, accept=accept, contentType=content_type - ) - response_body = json.loads(response.get("body").read()) - - finish_reason = response_body.get("error") - - if finish_reason is not None: - raise ImageError(f"Text generation error. Error is {finish_reason}") - - logger.info( - "Successfully generated text with Amazon Titan Text model %s", model_id) - - return response_body - - -def main(): - """ - Entrypoint for Amazon Titan Text model example. - """ - try: - logging.basicConfig(level=logging.INFO, - format="%(levelname)s: %(message)s") - - # You can replace the model_id with any other Titan Text Models - # Titan Text Model family model_id is as mentioned below: - # amazon.titan-text-premier-v1:0, amazon.titan-text-express-v1, amazon.titan-text-lite-v1 - model_id = 'amazon.titan-text-premier-v1:0' - - prompt = """Meeting transcript: Miguel: Hi Brant, I want to discuss the workstream - for our new product launch Brant: Sure Miguel, is there anything in particular you want - to discuss? Miguel: Yes, I want to talk about how users enter into the product. - Brant: Ok, in that case let me add in Namita. Namita: Hey everyone - Brant: Hi Namita, Miguel wants to discuss how users enter into the product. - Miguel: its too complicated and we should remove friction. - for example, why do I need to fill out additional forms? - I also find it difficult to find where to access the product - when I first land on the landing page. Brant: I would also add that - I think there are too many steps. Namita: Ok, I can work on the - landing page to make the product more discoverable but brant - can you work on the additonal forms? Brant: Yes but I would need - to work with James from another team as he needs to unblock the sign up workflow. - Miguel can you document any other concerns so that I can discuss with James only once? - Miguel: Sure. - From the meeting transcript above, Create a list of action items for each person. """ - - body = json.dumps({ - "inputText": prompt, - "textGenerationConfig": { - "maxTokenCount": 3072, - "stopSequences": [], - "temperature": 0.7, - "topP": 0.9 - } - }) - - response_body = generate_text(model_id, body) - print(f"Input token count: {response_body['inputTextTokenCount']}") - - for result in response_body['results']: - print(f"Token count: {result['tokenCount']}") - print(f"Output text: {result['outputText']}") - print(f"Completion reason: {result['completionReason']}") - - except ClientError as err: - message = err.response["Error"]["Message"] - logger.error("A client error occurred: %s", message) - print("A client error occured: " + - format(message)) - except ImageError as err: - logger.error(err.message) - print(err.message) - - else: - print( - f"Finished generating text with the Amazon Titan Text Premier model {model_id}.") - - -if __name__ == "__main__": - main() diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py deleted file mode 100644 index 93fb3ef..0000000 --- a/benchmark/benchmark.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -from collections import defaultdict -from typing import Dict, List - -from beir.datasets.data_loader import GenericDataLoader -from beir.retrieval.evaluation import EvaluateRetrieval - -from mongodb_utils import connect_to_mongodb, MongoAtlasRetriever -from embeddings import TitanEmbeddingWrapper - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s" -) - - -def build_results(queries: Dict[str, str], retriever: MongoAtlasRetriever) -> Dict[str, Dict[str, float]]: - """Retrieve top documents for all queries and build a BEIR-compatible results dict.""" - results = {} - for qid, query_text in queries.items(): - logging.info(f"\nRunning query {qid}: {query_text}") - top_docs = retriever.retrieve(query_text) - - logging.info(f"Top {len(top_docs)} docs retrieved for query {qid}:") - for doc in top_docs: - logging.info( - f" doc_id={doc.get('doc_id')} score={doc.get('score')}") - - results[qid] = { - doc["doc_id"]: float( - doc["score"]) if doc["score"] is not None else 1.0 - for doc in top_docs - } - return results - - -def check_overlap(results: Dict[str, Dict[str, float]], qrels: Dict[str, Dict[str, int]]) -> Dict[str, List[str]]: - """Check which retrieved documents overlap with the ground-truth qrels.""" - overlaps = defaultdict(list) - for qid, retrieved_docs in results.items(): - if qid not in qrels: - continue - gt_doc_ids = set(qrels[qid].keys()) - retrieved_doc_ids = set(retrieved_docs.keys()) - common = gt_doc_ids.intersection(retrieved_doc_ids) - if common: - overlaps[qid] = list(common) - return overlaps - - -def main(): - """Main pipeline: Load data, connect to Mongo, retrieve results, evaluate, and save metrics.""" - dataset_path = "trec-covid" - corpus, queries, qrels = GenericDataLoader(dataset_path).load("test") - - client = connect_to_mongodb() - collection = client["aws_gen_ai"]["TrecCovid"] - - embedding_wrapper = TitanEmbeddingWrapper( - model="amazon.titan-embed-text-v2:0") - retriever = MongoAtlasRetriever( - collection=collection, - embedding_wrapper=embedding_wrapper, - index_name="vector_search", - top_k=500 - ) - - logging.info( - "\nRunning retrieval on ALL queries in the TREC-COVID test set...") - results = build_results(queries, retriever) - evaluator = EvaluateRetrieval() - k_values = [1, 3, 5, 10] - - # Evaluate all metrics - ndcg, _map, recall, precision = evaluator.evaluate( - qrels, results, k_values) - - -if __name__ == "__main__": - main() diff --git a/benchmark/download_datasets.py b/benchmark/download_datasets.py deleted file mode 100644 index 7848b0e..0000000 --- a/benchmark/download_datasets.py +++ /dev/null @@ -1,41 +0,0 @@ -import os -import pathlib - -from beir import util - - -def main(): - out_dir = pathlib.Path(__file__).parent.absolute() - - dataset_files = [ - "msmarco.zip", - "trec-covid.zip", - "nfcorpus.zip", - "nq.zip", - "hotpotqa.zip", - "fiqa.zip", - "arguana.zip", - "webis-touche2020.zip", - "cqadupstack.zip", - "quora.zip", - "dbpedia-entity.zip", - "scidocs.zip", - "fever.zip", - "climate-fever.zip", - "scifact.zip", - "germanquad.zip", - ] - - for dataset in dataset_files: - zip_file = os.path.join(out_dir, dataset) - url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}" - - print(f"Downloading {dataset} ...") - util.download_url(url, zip_file) - - print(f"Unzipping {dataset} ...") - util.unzip(zip_file, out_dir) - - -if __name__ == "__main__": - main() diff --git a/benchmark/embeddings.py b/benchmark/embeddings.py deleted file mode 100644 index d079366..0000000 --- a/benchmark/embeddings.py +++ /dev/null @@ -1,89 +0,0 @@ -import boto3 -import json -from typing import List - - -class TitanEmbeddingWrapper: - def __init__(self, model: str, region: str = "us-east-1"): - """ - Wrapper for generating embeddings using Amazon Titan Embed Text v2. - - Args: - model (str): The model ID to use for embedding. - region (str): AWS region (default: us-east-1). - """ - self.model = model - self.region = region - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ - Generate embeddings for a list of texts using the Titan embedding model. - - Args: - texts (List[str]): List of input texts. - - Returns: - List[List[float]]: List of embedding vectors. - """ - embeddings = [] - for text in texts: - try: - embedding = get_titan_embedding( - text, model=self.model, region=self.region) - embeddings.append(embedding) - except Exception as e: - print( - f"Failed to generate embedding for text: {text[:30]}... Error: {e}") - # Fallback to a zero vector if embedding fails - embeddings.append([0.0] * 1024) - return embeddings - - -def get_titan_embedding(text: str, model: str, region: str = "us-east-1") -> List[float]: - """ - Generate a text embedding using Amazon Titan Embed Text v2. - - Args: - text (str): The input text to embed. - model (str): The model ID to use for embedding. - region (str): AWS region (default: us-east-1). - - Returns: - List[float]: The embedding vector as a list of floats. - """ - try: - # Initialize the Bedrock runtime client - client = boto3.client("bedrock-runtime", region_name=region) - - # Prepare the request body - body = { - "inputText": text - } - - # Invoke the model - response = client.invoke_model( - body=json.dumps(body), - modelId=model, - accept="application/json", - contentType="application/json" - ) - - # Parse the response - response_body = json.loads(response["body"].read()) - return response_body["embedding"] - - except Exception as e: - raise Exception(f"Failed to generate embedding: {e}") - - -# Example usage -if __name__ == "__main__": - text = "COVID-19 is caused by the SARS-CoV-2 virus." - model_id = "amazon.titan-embed-text-v2:0" - - # Using the wrapper - wrapper = TitanEmbeddingWrapper(model=model_id) - embeddings = wrapper.embed_documents([text]) - - print("Embedding vector (first 10 values):", embeddings[0][:10], "...") - print("Embedding dimension:", len(embeddings[0])) # Should be 1536 diff --git a/benchmark/mongodb_utils.py b/benchmark/mongodb_utils.py deleted file mode 100644 index 5781256..0000000 --- a/benchmark/mongodb_utils.py +++ /dev/null @@ -1,167 +0,0 @@ -import os -import logging -from typing import List, Dict -from pymongo import MongoClient -from pymongo.collection import Collection -from dotenv import load_dotenv -from embeddings import get_titan_embedding - - -def load_env() -> None: - """Load environment variables from a .env file.""" - load_dotenv() - - -def get_mongodb_uri() -> str: - """Retrieve the MongoDB URI from environment variables.""" - uri = os.getenv("MONGODB_URI") - if not uri: - raise EnvironmentError( - "MONGODB_URI is not defined in the environment.") - return uri - - -def connect_to_mongodb() -> MongoClient: - """ - Connect to MongoDB using the URI defined in environment variables. - - Returns: - MongoClient: A connected MongoDB client instance. - """ - load_env() - uri = get_mongodb_uri() - try: - client = MongoClient(uri, serverSelectionTimeoutMS=5000) - client.admin.command("ping") # Test the connection - logging.info("Successfully connected to MongoDB.") - return client - except Exception as e: - logging.error("Failed to connect to MongoDB.") - raise ConnectionError(f"Failed to connect to MongoDB: {e}") - - -class MongoAtlasRetriever: - def __init__(self, collection: Collection, embedding_wrapper, index_name: str = "vector_search", top_k: int = 100): - """ - Args: - collection (Collection): pymongo collection instance - embedding_wrapper: An embedding wrapper instance (e.g., TitanEmbeddingWrapper). - index_name (str): Atlas Search vector index name. - top_k (int): Number of top results to return. - """ - self.collection = collection - self.embedding_wrapper = embedding_wrapper - self.index_name = index_name - self.top_k = top_k - - def retrieve(self, query: str) -> List[Dict]: - logging.info("Generating query vector using the embedding wrapper.") - query_vector = self.embedding_wrapper.embed_documents([query])[0] - - pipeline = [ - { - "$vectorSearch": { - "index": self.index_name, - "path": "embedding", - "queryVector": query_vector, - "numCandidates": 1000, - "limit": self.top_k - } - }, - { - "$project": { - "doc_id": 1, - "text": 1, - "score": {"$meta": "vectorSearchScore"} - } - } - - ] - - logging.info(f"Running vector search with top_k={self.top_k}.") - results = list(self.collection.aggregate(pipeline)) - - # Format for BEIR: doc_id and score required - return [ - { - "doc_id": doc.get("doc_id"), - "text": doc.get("text", ""), - "score": doc.get("score") - } - for doc in results - ] - - -def normalize_vector_score(score: float, min_val: float = 0.965, max_val: float = 1.0) -> float: - """ - Normalize a similarity score from vector search to a 0-1 scale. - - Args: - score (float): Original similarity score. - min_val (float): Minimum expected value. - max_val (float): Maximum expected value. - - Returns: - float: Normalized score. - """ - return max(0.0, min(1.0, (score - min_val) / (max_val - min_val))) - - -def insert_documents(collection: Collection, documents: List[dict]) -> None: - """ - Insert a list of documents into a MongoDB collection. - - Args: - collection (Collection): A MongoDB collection object. - documents (List[dict]): List of documents to insert. - - Raises: - ValueError: If documents list is empty. - """ - if not documents: - raise ValueError("No documents to insert.") - collection.insert_many(documents) - - -def main(): - try: - # Connect to MongoDB - client = connect_to_mongodb() - - # Retrieve database and collection names from environment variables - db_name = os.getenv("MONGODB_DATABASE", "aws_gen_ai") - collection_name = os.getenv("MONGODB_COLLECTION", "NQ") - - # Access the database and collection - db = client[db_name] - collection = db[collection_name] - - # Sample documents to insert - sample_documents = [ - { - "doc_id": "doc1", - "title": "Example Title 1", - "text": "Example content for document 1.", - "embedding": [0.1, 0.2, 0.3], - "source": "test" - }, - { - "doc_id": "doc2", - "title": "Example Title 2", - "text": "Example content for document 2.", - "embedding": [0.4, 0.5, 0.6], - "source": "test" - } - ] - - # Insert documents into the collection - insert_documents(collection, sample_documents) - print( - f"Successfully inserted {len(sample_documents)} documents into {collection.full_name}") - - except Exception as e: - print(f"Error: {e}") - - -if __name__ == "__main__": - main() diff --git a/benchmark/process_dataset.py b/benchmark/process_dataset.py deleted file mode 100644 index 9c21d6e..0000000 --- a/benchmark/process_dataset.py +++ /dev/null @@ -1,166 +0,0 @@ -from beir.datasets.data_loader import GenericDataLoader -from langchain_core.documents import Document -from langchain_experimental.text_splitter import SemanticChunker -from langchain.text_splitter import RecursiveCharacterTextSplitter -from pymongo.collection import Collection -from typing import List -from mongodb_utils import connect_to_mongodb -from embeddings import TitanEmbeddingWrapper -import logging - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", - handlers=[ - logging.StreamHandler(), - logging.FileHandler("process_dataset.log", mode="w") - ] -) - -DATABASE_NAME = "aws_gen_ai" -COLLECTION_NAME = "TrecCovid" - -embedding_wrapper = TitanEmbeddingWrapper(model="amazon.titan-embed-text-v2:0") - - -class BaseChunker: - def split_documents(self, documents: List[Document]) -> List[Document]: - raise NotImplementedError( - "This method should be implemented by subclasses.") - - -class SemanticChunkerWrapper(BaseChunker): - def __init__(self, embedding_wrapper): - self.splitter = SemanticChunker( - embeddings=embedding_wrapper, - breakpoint_threshold_type="gradient", - breakpoint_threshold_amount=0.8, - ) - - def split_documents(self, documents: List[Document]) -> List[Document]: - logging.info("Using SemanticChunker to split documents.") - return self.splitter.split_documents(documents) - - -class TextSplitterWrapper(BaseChunker): - def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): - self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, chunk_overlap=chunk_overlap) - - def split_documents(self, documents: List[Document]) -> List[Document]: - logging.info("Using TextSplitter to split documents.") - chunked_docs = [] - for doc in documents: - chunks = self.text_splitter.split_text(doc.page_content) - for chunk in chunks: - chunked_docs.append( - Document(page_content=chunk, metadata={**doc.metadata})) - return chunked_docs - - -class DocumentProcessor: - def __init__(self, chunker: BaseChunker, embedding_wrapper, collection: Collection): - self.chunker = chunker - self.embedding_wrapper = embedding_wrapper - self.collection = collection - - def process_documents(self, documents: List[Document], batch_size: int = 250): - logging.info( - f"Starting batch processing with batch size: {batch_size}") - - for batch_start in range(0, len(documents), batch_size): - batch = documents[batch_start:batch_start + batch_size] - logging.info( - f"Processing batch {batch_start // batch_size + 1} with {len(batch)} documents.") - - chunked_docs = self.chunker.split_documents(batch) - logging.info( - f"Batch {batch_start // batch_size + 1}: Split {len(batch)} documents into {len(chunked_docs)} chunks.") - - texts = [chunk.page_content for chunk in chunked_docs] - logging.info( - f"Batch {batch_start // batch_size + 1}: Starting embedding generation for {len(texts)} chunks...") - - try: - embeddings = self.embedding_wrapper.embed_documents(texts) - logging.info( - f"Batch {batch_start // batch_size + 1}: Successfully generated embeddings for {len(texts)} chunks.") - except Exception as e: - logging.error( - f"Batch {batch_start // batch_size + 1}: Failed to generate embeddings. Error: {e}") - continue - - to_insert = [] - for i, chunk in enumerate(chunked_docs): - text = chunk.page_content - metadata = chunk.metadata - doc_id = metadata.get("doc_id") - logging.info( - f"Processing chunk {i+1}/{len(chunked_docs)} for document ID: {doc_id}") - to_insert.append({ - "doc_id": doc_id, - "chunk_id": f"{doc_id}_{i}", - "text": text, - "embedding": embeddings[i], - "metadata": metadata, - "source": "trec-covid" - }) - - if to_insert: - try: - logging.info( - f"Batch {batch_start // batch_size + 1}: Inserting {len(to_insert)} chunks into MongoDB...") - self.collection.insert_many(to_insert) - logging.info( - f"Batch {batch_start // batch_size + 1}: Successfully inserted {len(to_insert)} chunks into the collection '{self.collection.name}'.") - except Exception as e: - logging.error( - f"Batch {batch_start // batch_size + 1}: Failed to insert chunks into MongoDB. Error: {e}") - - logging.info("Batch processing completed successfully.") - - -def load_all_documents(corpus_path: str) -> List[Document]: - logging.info(f"Loading all documents from the dataset at {corpus_path}.") - corpus, _, _ = GenericDataLoader( - data_folder=corpus_path).load(split="test") - documents = [] - - for doc_id, content in corpus.items(): - title = content.get("title", "") - text = content.get("text", "") - full_text = f"{title}. {text}".strip() - if not full_text: - logging.warning( - f"Document {doc_id} has no content and will be skipped.") - continue - documents.append(Document(page_content=full_text, metadata={ - "doc_id": doc_id, "title": title})) - - logging.info(f"Loaded {len(documents)} documents from the dataset.") - return documents - - -def main(): - logging.info("Starting the document processing pipeline.") - client = connect_to_mongodb() - collection = client[DATABASE_NAME][COLLECTION_NAME] - - docs = load_all_documents("trec-covid") - logging.info(f"Loaded {len(docs)} documents from the dataset.") - - use_text_splitter = True # Set to False to use SemanticChunker - if use_text_splitter: - chunker = TextSplitterWrapper(chunk_size=1000, chunk_overlap=200) - else: - chunker = SemanticChunkerWrapper(embedding_wrapper) - - processor = DocumentProcessor(chunker, embedding_wrapper, collection) - processor.process_documents(docs, batch_size=50) - - logging.info("Document processing pipeline completed successfully.") - - -if __name__ == "__main__": - main() diff --git a/benchmark/results/trec-covid_metrics.csv b/benchmark/results/trec-covid_metrics.csv deleted file mode 100644 index fc775b3..0000000 --- a/benchmark/results/trec-covid_metrics.csv +++ /dev/null @@ -1,4 +0,0 @@ -k,nDCG,Recall,Precision,MRR -1,0.0000,0.0000,0.0000,0.0000 -3,0.0000,0.0000,0.0000,0.0000 -5,0.0000,0.0000,0.0000,0.0000 diff --git a/benchmark/results/trec-covid_sample_results.csv b/benchmark/results/trec-covid_sample_results.csv deleted file mode 100644 index fc775b3..0000000 --- a/benchmark/results/trec-covid_sample_results.csv +++ /dev/null @@ -1,4 +0,0 @@ -k,nDCG,Recall,Precision,MRR -1,0.0000,0.0000,0.0000,0.0000 -3,0.0000,0.0000,0.0000,0.0000 -5,0.0000,0.0000,0.0000,0.0000 diff --git a/benchmark/retriever_test.py b/benchmark/retriever_test.py deleted file mode 100644 index 46e1846..0000000 --- a/benchmark/retriever_test.py +++ /dev/null @@ -1,39 +0,0 @@ -from mongodb_utils import connect_to_mongodb, MongoAtlasRetriever -from beir.datasets.data_loader import GenericDataLoader -from embeddings import TitanEmbeddingWrapper -import os - -# Define the dataset path -corpus_path = "trec-covid" - -# Ensure the dataset exists -if not os.path.exists(corpus_path): - raise FileNotFoundError( - f"Dataset not found at {corpus_path}. Please download the dataset.") - -# Load the queries from the BEIR dataset -_, queries, _ = GenericDataLoader(corpus_path).load(split="test") -first_qid, first_query = next(iter(queries.items())) - -# Connect to MongoDB -client = connect_to_mongodb() -collection = client["aws_gen_ai"]["TrecCovid"] - -# Initialize the embedding wrapper -embedding_wrapper = TitanEmbeddingWrapper(model="amazon.titan-embed-text-v2:0") - -# Initialize the MongoAtlasRetriever with the embedding wrapper -retriever = MongoAtlasRetriever( - collection, embedding_wrapper, index_name="vector_search", top_k=5) - -# Retrieve results for the first query -results = retriever.retrieve(first_query) - -# Display the results -print(f"Query: {first_query}") -print("Top results:") -for i, doc in enumerate(results, 1): - print(f"Rank {i}:") - print(f" doc_id: {doc.get('doc_id')}") - print(f" text: {doc.get('text')[:300]}...") # truncate for preview - print(f" score: {doc.get('score')}") diff --git a/benchmark/test_aws_connection.py b/benchmark/test_aws_connection.py deleted file mode 100644 index bed73bc..0000000 --- a/benchmark/test_aws_connection.py +++ /dev/null @@ -1,15 +0,0 @@ -import boto3 - - -def test_aws_connection(): - try: - client = boto3.client("sts") - response = client.get_caller_identity() - print("Successfully connected to AWS!") - print(f"Account: {response['Account']}, ARN: {response['Arn']}") - except Exception as e: - print(f"Failed to connect to AWS: {e}") - - -if __name__ == "__main__": - test_aws_connection() diff --git a/benchmark_scenarios/dense_baseline.yml b/benchmark_scenarios/dense_baseline.yml new file mode 100644 index 0000000..4797f71 --- /dev/null +++ b/benchmark_scenarios/dense_baseline.yml @@ -0,0 +1,34 @@ +# Dense Retrieval Optimization Scenario +description: "Dense retrieval with Google Gemini embeddings, top_k=10" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "dense" + top_k: 10 + score_threshold: 0.1 + +# Embedding configuration (override from main config) +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + strategy: dense + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 50 +experiment_name: "dense_baseline" diff --git a/benchmark_scenarios/dense_high_precision.yml b/benchmark_scenarios/dense_high_precision.yml new file mode 100644 index 0000000..2091b38 --- /dev/null +++ b/benchmark_scenarios/dense_high_precision.yml @@ -0,0 +1,34 @@ +# Dense Retrieval with High Precision (stricter threshold) +description: "Dense retrieval with higher score threshold for precision" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "dense" + top_k: 10 + score_threshold: 0.3 # Higher threshold for precision + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + strategy: dense + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 50 +experiment_name: "dense_high_precision" diff --git a/benchmark_scenarios/dense_high_recall.yml b/benchmark_scenarios/dense_high_recall.yml new file mode 100644 index 0000000..e6940b2 --- /dev/null +++ b/benchmark_scenarios/dense_high_recall.yml @@ -0,0 +1,34 @@ +# Dense Retrieval with Higher Recall (more results) +description: "Dense retrieval with higher top_k=20 for better recall" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "dense" + top_k: 20 + score_threshold: 0.05 # Lower threshold for more results + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + strategy: dense + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 20] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 50 +experiment_name: "dense_high_recall" diff --git a/benchmark_scenarios/hybrid_advanced.yml b/benchmark_scenarios/hybrid_advanced.yml new file mode 100644 index 0000000..169cae1 --- /dev/null +++ b/benchmark_scenarios/hybrid_advanced.yml @@ -0,0 +1,48 @@ +# Advanced Hybrid Retrieval Optimization +description: "Advanced hybrid retrieval with optimized fusion parameters" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "hybrid" + top_k: 20 # Higher top_k for better recall + score_threshold: 0.01 + fusion_method: rrf # Options: 'rrf', 'weighted_sum' + dense_weight: 0.7 # For weighted_sum method + sparse_weight: 0.3 # For weighted_sum method + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + strategy: hybrid + +# Fusion method configuration +fusion: + method: rrf # rrf or weighted_sum + rrf_k: 50 # Smaller k gives more emphasis to top results + dense_weight: 0.7 # For weighted_sum + sparse_weight: 0.3 + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 15, 20] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 100 # Larger sample size +experiment_name: "hybrid_advanced_rrf" diff --git a/benchmark_scenarios/hybrid_reranking.yml b/benchmark_scenarios/hybrid_reranking.yml new file mode 100644 index 0000000..78c7a19 --- /dev/null +++ b/benchmark_scenarios/hybrid_reranking.yml @@ -0,0 +1,57 @@ +# Advanced Reranking Hybrid Retrieval +description: "Hybrid retrieval with advanced reranking experiments" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "hybrid" + top_k: 30 # Get more candidates for reranking + score_threshold: 0.01 + fusion_method: rrf + dense_weight: 0.7 + sparse_weight: 0.3 + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + strategy: hybrid + +# Advanced reranking configuration +reranking: + enabled: true + model: "cross-encoder/ms-marco-TinyBERT-L-2-v2" # Faster alternative + # model: "cross-encoder/ms-marco-MiniLM-L-6-v2" # Current (balanced) + # model: "cross-encoder/ms-marco-MiniLM-L-12-v2" # Larger/better + top_k: 10 + batch_size: 16 + +# Fusion method configuration +fusion: + method: rrf + rrf_k: 50 + dense_weight: 0.7 + sparse_weight: 0.3 + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 15, 20] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 100 +experiment_name: "hybrid_advanced_reranking" diff --git a/benchmark_scenarios/hybrid_retrieval.yml b/benchmark_scenarios/hybrid_retrieval.yml new file mode 100644 index 0000000..979d42e --- /dev/null +++ b/benchmark_scenarios/hybrid_retrieval.yml @@ -0,0 +1,38 @@ +# Hybrid Retrieval Optimization Scenario +description: "Hybrid dense+sparse retrieval for best of both worlds" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "hybrid" + top_k: 15 + score_threshold: 0.1 + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + strategy: hybrid + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 15] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 50 +experiment_name: "hybrid_dense_sparse" diff --git a/benchmark_scenarios/hybrid_weighted.yml b/benchmark_scenarios/hybrid_weighted.yml new file mode 100644 index 0000000..48deefb --- /dev/null +++ b/benchmark_scenarios/hybrid_weighted.yml @@ -0,0 +1,47 @@ +# Weighted Sum Hybrid Retrieval Optimization +description: "Hybrid retrieval using weighted sum fusion with optimized weights" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "hybrid" + top_k: 20 + score_threshold: 0.01 + fusion_method: weighted_sum + dense_weight: 0.8 # Emphasize dense retrieval more + sparse_weight: 0.2 + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + strategy: hybrid + +# Fusion method configuration +fusion: + method: weighted_sum + dense_weight: 0.8 # Emphasize dense retrieval more + sparse_weight: 0.2 + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 15, 20] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 100 +experiment_name: "hybrid_weighted_sum" diff --git a/benchmark_scenarios/quick_test.yml b/benchmark_scenarios/quick_test.yml new file mode 100644 index 0000000..8f24a3e --- /dev/null +++ b/benchmark_scenarios/quick_test.yml @@ -0,0 +1,34 @@ +# Small Dataset Quick Test +description: "Quick test with small dataset for rapid iteration" + +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration +retrieval: + type: "dense" + top_k: 10 + score_threshold: 0.1 + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + strategy: dense + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 10 # Small for quick testing +experiment_name: "quick_test" diff --git a/benchmark_scenarios/sparse_bm25.yml b/benchmark_scenarios/sparse_bm25.yml new file mode 100644 index 0000000..0c2e78e --- /dev/null +++ b/benchmark_scenarios/sparse_bm25.yml @@ -0,0 +1,40 @@ +# DISABLED: Sparse Retrieval Optimization Scenario +# Note: Sparse retriever currently has configuration issues +description: "DISABLED - Sparse retrieval with BM25 (needs fixing)" + +# This scenario is disabled because sparse retrieval is not working +disabled: true +disabled_reason: "Sparse retriever has configuration/implementation issues" + +# Original configuration (for reference): +# Dataset configuration +dataset: + path: "/home/spiros/Desktop/Thesis/datasets/sosum/data" + use_ground_truth: true + +# Retrieval configuration - DISABLED +retrieval: + type: "dense" # Fallback to dense since sparse doesn't work + top_k: 15 + score_threshold: 0.1 + +# Embedding configuration +embedding: + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + strategy: dense # Changed from sparse to dense + +# Evaluation configuration +evaluation: + k_values: [1, 5, 10, 15] + metrics: + retrieval: ["precision@k", "recall@k", "mrr", "ndcg@k"] + +# Experiment parameters +max_queries: 50 +experiment_name: "sparse_bm25_disabled" diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/benchmark_contracts.py b/benchmarks/benchmark_contracts.py new file mode 100644 index 0000000..e3e5d52 --- /dev/null +++ b/benchmarks/benchmark_contracts.py @@ -0,0 +1,61 @@ +"""Benchmark contracts and interfaces.""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass +from enum import Enum +from pydantic import BaseModel, Field + + +class BenchmarkTask(str, Enum): + """Standard benchmark tasks.""" + RETRIEVAL = "retrieval" # How well does retrieval find relevant docs? + GENERATION = "generation" # How good are the generated answers? + END_TO_END = "end_to_end" # Complete RAG pipeline evaluation + RERANKING = "reranking" # How well does reranking improve results? + SEMANTIC_SEARCH = "semantic_search" # Pure semantic similarity + + +@dataclass +class BenchmarkQuery: + """A single benchmark query.""" + query_id: str + query_text: str + expected_answer: Optional[str] = None + relevant_doc_ids: Optional[List[str]] = None + difficulty: Optional[str] = None # easy, medium, hard + category: Optional[str] = None # domain-specific categories + metadata: Dict[str, Any] = None + + +@dataclass +class BenchmarkResult: + """Result of a single query evaluation.""" + query_id: str + retrieved_docs: List[str] + generated_answer: Optional[str] = None + retrieval_time_ms: float = 0.0 + generation_time_ms: float = 0.0 + scores: Dict[str, float] = None # metric_name -> score + + +class BenchmarkAdapter(ABC): + """Abstract adapter for different benchmark datasets.""" + + @property + @abstractmethod + def name(self) -> str: + """Dataset name.""" + + @property + @abstractmethod + def tasks(self) -> List[BenchmarkTask]: + """Supported benchmark tasks.""" + + @abstractmethod + def load_queries(self, split: str = "test") -> List[BenchmarkQuery]: + """Load benchmark queries.""" + + @abstractmethod + def get_ground_truth(self, query_id: str) -> Dict[str, Any]: + """Get ground truth for evaluation.""" diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py new file mode 100644 index 0000000..c4bb9d8 --- /dev/null +++ b/benchmarks/benchmark_optimizer.py @@ -0,0 +1,369 @@ +""" +Flexible benchmark runner with configurable optimization parameters. +Supports multiple benchmark scenarios for hyperparameter optimization. +""" + +from config.config_loader import load_config +from benchmarks.benchmark_contracts import BenchmarkQuery +from benchmarks.benchmarks_adapters import StackOverflowBenchmarkAdapter, FullDatasetAdapter +from benchmarks.benchmarks_runner import BenchmarkRunner +import sys +import os +import yaml +import argparse +import pandas as pd +from pathlib import Path +from typing import Dict, Any, List +sys.path.append('/home/spiros/Desktop/Thesis/Thesis') + + +class BenchmarkOptimizer: + """Flexible benchmark runner for optimization experiments.""" + + def __init__(self, base_config_path: str = "config.yml"): + """Initialize with base configuration.""" + self.base_config = load_config(base_config_path) + self.results_history = [] + + def load_benchmark_config(self, benchmark_config_path: str) -> Dict[str, Any]: + """Load benchmark-specific configuration.""" + with open(benchmark_config_path, 'r') as f: + benchmark_config = yaml.safe_load(f) + + # Merge with base config + config = self.base_config.copy() + config.update(benchmark_config) + + return config + + def run_optimization_scenario(self, scenario_name: str, config: Dict[str, Any]) -> Dict[str, Any]: + """Run a single optimization scenario.""" + print(f"\n🚀 Running optimization scenario: {scenario_name}") + print(f"📊 Config: {config.get('description', 'No description')}") + + # Setup benchmark runner + runner = BenchmarkRunner(config) + + # Setup data adapter based on config + dataset_config = config.get('dataset', {}) + dataset_path = dataset_config.get( + 'path', '/home/spiros/Desktop/Thesis/datasets/sosum/data') + + if dataset_config.get('use_ground_truth', True): + # Use FullDatasetAdapter for ground truth evaluation + adapter = FullDatasetAdapter(dataset_path) + else: + # Use custom adapter for real questions without ground truth + adapter = self._create_real_data_adapter(dataset_path) + + # Run benchmark + print(f"📈 Running with max_queries: {config.get('max_queries', 10)}") + results = runner.run_benchmark( + adapter=adapter, + max_queries=config.get('max_queries', 10) + ) + + # Add scenario metadata + results['scenario_name'] = scenario_name + results['scenario_config'] = config + + # Store results + self.results_history.append(results) + + return results + + def _create_real_data_adapter(self, dataset_path: str): + """Create adapter for real data without ground truth.""" + class RealDataAdapter(StackOverflowBenchmarkAdapter): + def load_queries(self, split: str = "test"): + import pandas as pd + + question_file = Path(dataset_path) / "question.csv" + df = pd.read_csv(question_file) + + queries = [] + for idx, row in df.iterrows(): + if idx >= 50: # Limit for testing + break + + if pd.isna(row['question_title']): + continue + + query = BenchmarkQuery( + query_id=f"real_so_{row['question_id']}", + query_text=str(row['question_title']), + expected_answer=None, + relevant_doc_ids=None, # No ground truth + difficulty="medium", + category="programming", + metadata={"source": "real_stackoverflow"} + ) + queries.append(query) + + return queries + + return RealDataAdapter(dataset_path) + + def run_multiple_scenarios(self, scenarios_dir: str = "benchmark_scenarios") -> List[Dict[str, Any]]: + """Run multiple optimization scenarios from a directory.""" + scenarios_path = Path(scenarios_dir) + if not scenarios_path.exists(): + print(f"❌ Scenarios directory not found: {scenarios_path}") + return [] + + results = [] + for scenario_file in scenarios_path.glob("*.yml"): + scenario_name = scenario_file.stem + config = self.load_benchmark_config(str(scenario_file)) + + try: + result = self.run_optimization_scenario(scenario_name, config) + results.append(result) + + # Print quick summary + self._print_scenario_summary(scenario_name, result) + + except Exception as e: + print(f"❌ Failed scenario {scenario_name}: {e}") + continue + + return results + + def _print_scenario_summary(self, scenario_name: str, results: Dict[str, Any]): + """Print a quick summary of scenario results.""" + print(f"\n📊 {scenario_name} Results:") + print(f" Queries: {results['config']['total_queries']}") + print( + f" Avg Time: {results['performance']['avg_retrieval_time_ms']:.2f}ms") + + # Print key metrics + metrics = results.get('metrics', {}) + for metric_name in ['precision@5', 'recall@5', 'mrr']: + if metric_name in metrics: + mean_val = metrics[metric_name]['mean'] + print(f" {metric_name}: {mean_val:.3f}") + + def compare_scenarios(self) -> Dict[str, Any]: + """Compare all run scenarios.""" + if not self.results_history: + print("❌ No scenarios run yet") + return {} + + print( + f"\n🔬 OPTIMIZATION COMPARISON ({len(self.results_history)} scenarios)") + print("="*80) + + comparison = { + 'scenarios': [], + 'best_precision': {'scenario': None, 'value': 0}, + 'best_recall': {'scenario': None, 'value': 0}, + 'best_mrr': {'scenario': None, 'value': 0}, + 'fastest': {'scenario': None, 'time': float('inf')} + } + + for result in self.results_history: + scenario_name = result['scenario_name'] + metrics = result.get('metrics', {}) + avg_time = result['performance']['avg_retrieval_time_ms'] + + scenario_summary = { + 'name': scenario_name, + 'precision@5': metrics.get('precision@5', {}).get('mean', 0), + 'recall@5': metrics.get('recall@5', {}).get('mean', 0), + 'mrr': metrics.get('mrr', {}).get('mean', 0), + 'avg_time_ms': avg_time, + 'config': result['scenario_config'] + } + + comparison['scenarios'].append(scenario_summary) + + # Track best performers + if scenario_summary['precision@5'] > comparison['best_precision']['value']: + comparison['best_precision'] = { + 'scenario': scenario_name, 'value': scenario_summary['precision@5']} + + if scenario_summary['recall@5'] > comparison['best_recall']['value']: + comparison['best_recall'] = { + 'scenario': scenario_name, 'value': scenario_summary['recall@5']} + + if scenario_summary['mrr'] > comparison['best_mrr']['value']: + comparison['best_mrr'] = { + 'scenario': scenario_name, 'value': scenario_summary['mrr']} + + if avg_time < comparison['fastest']['time']: + comparison['fastest'] = { + 'scenario': scenario_name, 'time': avg_time} + + # Print scenario details + print(f"📋 {scenario_name}:") + print(f" Precision@5: {scenario_summary['precision@5']:.3f}") + print(f" Recall@5: {scenario_summary['recall@5']:.3f}") + print(f" MRR: {scenario_summary['mrr']:.3f}") + print(f" Avg Time: {avg_time:.2f}ms") + print( + f" Config: {result['scenario_config'].get('description', 'N/A')}") + print() + + # Print best performers + print(f"🏆 BEST PERFORMERS:") + print( + f" Best Precision@5: {comparison['best_precision']['scenario']} ({comparison['best_precision']['value']:.3f})") + print( + f" Best Recall@5: {comparison['best_recall']['scenario']} ({comparison['best_recall']['value']:.3f})") + print( + f" Best MRR: {comparison['best_mrr']['scenario']} ({comparison['best_mrr']['value']:.3f})") + print( + f" Fastest: {comparison['fastest']['scenario']} ({comparison['fastest']['time']:.2f}ms)") + + return comparison + + def save_results(self, output_file: str = "benchmark_optimization_results.csv"): + """Save all results to a CSV file.""" + if not self.results_history: + print("❌ No results to save") + return + + # Prepare data for CSV + csv_data = [] + for result in self.results_history: + scenario_name = result['scenario_name'] + config = result.get('scenario_config', {}) + metrics = result.get('metrics', {}) + performance = result.get('performance', {}) + + row = { + 'scenario_name': scenario_name, + 'description': config.get('description', 'N/A'), + 'default_retriever': config.get('default_retriever', 'N/A'), + 'max_queries': config.get('max_queries', 0), + 'total_queries': result.get('config', {}).get('total_queries', 0), + 'avg_time_ms': performance.get('avg_retrieval_time_ms', 0), + 'min_time_ms': performance.get('min_retrieval_time_ms', 0), + 'max_time_ms': performance.get('max_retrieval_time_ms', 0), + 'precision@1_mean': metrics.get('precision@1', {}).get('mean', 0), + 'precision@1_std': metrics.get('precision@1', {}).get('std', 0), + 'precision@5_mean': metrics.get('precision@5', {}).get('mean', 0), + 'precision@5_std': metrics.get('precision@5', {}).get('std', 0), + 'precision@10_mean': metrics.get('precision@10', {}).get('mean', 0), + 'precision@10_std': metrics.get('precision@10', {}).get('std', 0), + 'recall@1_mean': metrics.get('recall@1', {}).get('mean', 0), + 'recall@1_std': metrics.get('recall@1', {}).get('std', 0), + 'recall@5_mean': metrics.get('recall@5', {}).get('mean', 0), + 'recall@5_std': metrics.get('recall@5', {}).get('std', 0), + 'recall@10_mean': metrics.get('recall@10', {}).get('mean', 0), + 'recall@10_std': metrics.get('recall@10', {}).get('std', 0), + 'mrr_mean': metrics.get('mrr', {}).get('mean', 0), + 'mrr_std': metrics.get('mrr', {}).get('std', 0), + 'ndcg@5_mean': metrics.get('ndcg@5', {}).get('mean', 0), + 'ndcg@5_std': metrics.get('ndcg@5', {}).get('std', 0), + 'ndcg@10_mean': metrics.get('ndcg@10', {}).get('mean', 0), + 'ndcg@10_std': metrics.get('ndcg@10', {}).get('std', 0), + } + + # Add configuration details + retrieval_config = config.get('retrieval', {}) + row['top_k'] = retrieval_config.get('top_k', 'N/A') + row['score_threshold'] = retrieval_config.get( + 'score_threshold', 'N/A') + + # Add embedding details + embedding_config = config.get('embedding', {}) + if isinstance(embedding_config, dict): + row['embedding_provider'] = embedding_config.get( + 'provider', 'N/A') + row['embedding_model'] = embedding_config.get('model', 'N/A') + else: + row['embedding_provider'] = 'N/A' + row['embedding_model'] = 'N/A' + + csv_data.append(row) + + # Create DataFrame and save to CSV + df = pd.DataFrame(csv_data) + df.to_csv(output_file, index=False) + + print(f"💾 Results saved to {output_file}") + print( + f"📊 Saved {len(csv_data)} scenarios with {len(df.columns)} columns") + + # Also save a summary CSV with just key metrics + summary_file = output_file.replace('.csv', '_summary.csv') + summary_columns = [ + 'scenario_name', 'description', 'default_retriever', 'total_queries', + 'avg_time_ms', 'precision@5_mean', 'recall@5_mean', 'mrr_mean' + ] + summary_df = df[summary_columns] + summary_df.to_csv(summary_file, index=False) + print(f"📋 Summary saved to {summary_file}") + + +def main(): + """Main function with CLI support.""" + parser = argparse.ArgumentParser( + description="Run benchmark optimization scenarios") + parser.add_argument('--scenario', type=str, + help='Single scenario config file') + parser.add_argument('--scenarios-dir', type=str, default='benchmark_scenarios', + help='Directory containing scenario configs') + parser.add_argument('--compare-only', action='store_true', + help='Only compare existing results') + + args = parser.parse_args() + + optimizer = BenchmarkOptimizer() + + if args.compare_only: + # Load existing results if available + try: + # Try to load from CSV first, then fallback to YAML + if os.path.exists('benchmark_optimization_results.csv'): + df = pd.read_csv('benchmark_optimization_results.csv') + # Convert CSV back to results format for comparison + optimizer.results_history = [] + for _, row in df.iterrows(): + result = { + 'scenario_name': row['scenario_name'], + 'scenario_config': { + 'description': row['description'], + 'default_retriever': row['default_retriever'], + 'max_queries': row['max_queries'] + }, + 'config': {'total_queries': row['total_queries']}, + 'performance': {'avg_retrieval_time_ms': row['avg_time_ms']}, + 'metrics': { + 'precision@5': {'mean': row['precision@5_mean']}, + 'recall@5': {'mean': row['recall@5_mean']}, + 'mrr': {'mean': row['mrr_mean']} + } + } + optimizer.results_history.append(result) + else: + # Fallback to YAML format + with open('benchmark_optimization_results.yml', 'r') as f: + data = yaml.safe_load(f) + optimizer.results_history = data.get('scenarios', []) + optimizer.compare_scenarios() + except FileNotFoundError: + print("❌ No existing results found (searched for .csv and .yml)") + return + + if args.scenario: + # Run single scenario + config = optimizer.load_benchmark_config(args.scenario) + result = optimizer.run_optimization_scenario( + Path(args.scenario).stem, config) + optimizer._print_scenario_summary(Path(args.scenario).stem, result) + else: + # Run multiple scenarios + results = optimizer.run_multiple_scenarios(args.scenarios_dir) + + if results: + # Compare all scenarios + optimizer.compare_scenarios() + + # Save results + optimizer.save_results() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmarks_adapters.py b/benchmarks/benchmarks_adapters.py new file mode 100644 index 0000000..a933d3b --- /dev/null +++ b/benchmarks/benchmarks_adapters.py @@ -0,0 +1,274 @@ +"""StackOverflow benchmark adapter.""" +import json +import os +from pathlib import Path +from typing import List, Union, Dict, Any +from benchmarks.benchmark_contracts import BenchmarkAdapter, BenchmarkTask, BenchmarkQuery + + +class StackOverflowBenchmarkAdapter(BenchmarkAdapter): + """Benchmark adapter for StackOverflow datasets.""" + + def __init__(self, dataset_path: str): + self.dataset_path = Path(dataset_path) + + @property + def name(self) -> str: + return "stackoverflow" + + @property + def tasks(self) -> List[BenchmarkTask]: + return [BenchmarkTask.RETRIEVAL, BenchmarkTask.END_TO_END] + + def load_queries(self, split: str = "test") -> List[BenchmarkQuery]: + """Convert SO questions into benchmark queries.""" + queries = [] + + # Try to find CSV or JSON files in the dataset directory + csv_files = list(self.dataset_path.glob("*.csv")) + json_files = list(self.dataset_path.glob("*.json")) + + if csv_files: + queries = self._load_from_csv(csv_files[0]) + elif json_files: + queries = self._load_from_json(json_files[0]) + else: + print(f"⚠️ No CSV or JSON files found in {self.dataset_path}") + # Return dummy queries for testing + return self._create_dummy_queries() + + print(f"✅ Loaded {len(queries)} queries from {split} split") + return queries[:100] # Limit for testing + + def _load_from_csv(self, csv_file: Path) -> List[BenchmarkQuery]: + """Load queries from CSV file.""" + import pandas as pd + + try: + df = pd.read_csv(csv_file) + queries = [] + + # Try different column name combinations + title_col = None + body_col = None + id_col = None + + for col in df.columns: + if 'title' in col.lower() or 'question' in col.lower(): + title_col = col + elif 'body' in col.lower() or 'text' in col.lower(): + body_col = col + elif 'id' in col.lower(): + id_col = col + + if not title_col: + print( + f"❌ No title column found. Available columns: {list(df.columns)}") + return self._create_dummy_queries() + + for idx, row in df.iterrows(): + if idx >= 100: # Limit for testing + break + + query_id = str(row[id_col]) if id_col else f"csv_{idx}" + title = str(row[title_col]) + body = str(row[body_col]) if body_col else "" + + if not title or title == 'nan': + continue + + query = BenchmarkQuery( + query_id=f"so_{query_id}", + query_text=title, + expected_answer=body[:500] if body and body != 'nan' else None, + relevant_doc_ids=None, + difficulty="medium", + category="programming", + metadata={ + "source": "stackoverflow_csv", + "row_index": idx + } + ) + queries.append(query) + + return queries + + except Exception as e: + print(f"❌ Error loading CSV {csv_file}: {e}") + return self._create_dummy_queries() + + def _load_from_json(self, json_file: Path) -> List[BenchmarkQuery]: + """Load queries from JSON file.""" + try: + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + queries = [] + + # Handle different JSON structures + if isinstance(data, list): + questions = data + elif isinstance(data, dict) and 'questions' in data: + questions = data['questions'] + else: + questions = [data] # Single question + + for i, question in enumerate(questions[:100]): # Limit for testing + query = self._create_query_from_question(question, i) + if query: + queries.append(query) + + return queries + + except Exception as e: + print(f"❌ Error loading JSON {json_file}: {e}") + return self._create_dummy_queries() + + def _create_query_from_question(self, question: Dict[str, Any], index: int) -> BenchmarkQuery: + """Create a benchmark query from a question.""" + + # Try different possible field names + title = question.get('title') or question.get( + 'question_title') or question.get('Title') + body = question.get('body') or question.get( + 'question_body') or question.get('Body') or "" + qid = question.get('id') or question.get( + 'question_id') or question.get('Id') or f"q_{index}" + + if not title: + return None + + return BenchmarkQuery( + query_id=f"so_{qid}", + query_text=title, + expected_answer=body[:500] if body else None, + relevant_doc_ids=None, + difficulty="medium", + category="programming", + metadata={ + "original_question": question, + "source": "stackoverflow_json" + } + ) + + def _create_dummy_queries(self) -> List[BenchmarkQuery]: + """Create dummy queries for testing.""" + dummy_questions = [ + "How to show error message box in .NET?", + "What is the difference between StringBuilder and String in C#?", + "How to convert string to int in Java?", + "How to handle null values in Python?", + "What is the best way to iterate over a dictionary in Python?", + "How to reverse a string in Python?", + "What is object-oriented programming?", + "How to use lambda functions in Python?", + "What is the difference between list and tuple?", + "How to handle exceptions in Python?" + ] + + queries = [] + for i, question in enumerate(dummy_questions): + query = BenchmarkQuery( + query_id=f"dummy_{i}", + query_text=question, + expected_answer=f"Programming answer for: {question}", + relevant_doc_ids=None, + difficulty="easy", + category="programming", + metadata={"source": "dummy"} + ) + queries.append(query) + + return queries + + def get_ground_truth(self, query_id: str) -> Dict[str, Any]: + """Get ground truth for evaluation.""" + return {"relevant_docs": [], "expected_answer": None} + + +class FullDatasetAdapter(StackOverflowBenchmarkAdapter): + """Adapter that uses the full dataset with ground truth for proper evaluation.""" + + def __init__(self, dataset_path: str): + super().__init__(dataset_path) + + @property + def name(self) -> str: + return "stackoverflow_full_dataset" + + def load_queries(self, split: str = "test") -> List[BenchmarkQuery]: + """Load queries with ground truth from the full dataset.""" + import pandas as pd + import ast + + question_file = self.dataset_path / "question.csv" + + if not question_file.exists(): + print(f"❌ Question file not found: {question_file}") + return self._create_dummy_queries() + + try: + print(f"📂 Loading questions from {question_file}") + df = pd.read_csv(question_file) + print(f"📊 Total questions in dataset: {len(df)}") + + # Filter for questions with ground truth (answer_posts) + df_with_gt = df[df['answer_posts'].notna()] + print(f"📊 Questions with ground truth: {len(df_with_gt)}") + + queries = [] + for idx, row in df_with_gt.iterrows(): + if pd.isna(row['question_title']) or not row['question_title'].strip(): + continue + + # Parse answer IDs from the answer_posts field + try: + if isinstance(row['answer_posts'], str): + # Try to parse as literal (list format) + answer_ids = ast.literal_eval(row['answer_posts']) + else: + # Could be a single ID or other format + answer_ids = [int(row['answer_posts'])] + + # Convert to document IDs with 'a_' prefix + relevant_doc_ids = [f"a_{aid}" for aid in answer_ids] + + if not relevant_doc_ids: + continue # Skip if no valid answer IDs + + except (ValueError, SyntaxError, TypeError) as e: + print( + f"⚠️ Failed to parse answer_posts for question {row['question_id']}: {e}") + continue + + query = BenchmarkQuery( + query_id=f"full_so_{row['question_id']}", + query_text=str(row['question_title']).strip(), + expected_answer=None, # We don't need the answer text for retrieval eval + relevant_doc_ids=relevant_doc_ids, + difficulty="medium", + category="programming", + metadata={ + "source": "full_dataset_with_ground_truth", + "original_question_id": row['question_id'], + "question_type": row.get('question_type', 'unknown'), + "tags": row.get('tags', ''), + "num_ground_truth_docs": len(relevant_doc_ids) + } + ) + queries.append(query) + + print( + f"✅ Successfully loaded {len(queries)} queries with ground truth") + return queries + + except Exception as e: + print(f"❌ Error loading full dataset: {e}") + import traceback + traceback.print_exc() + return self._create_dummy_queries() + + def get_ground_truth(self, query_id: str) -> Dict[str, Any]: + """Get ground truth for evaluation (override parent method).""" + # For this adapter, ground truth is already in the query's relevant_doc_ids + return {"relevant_docs": [], "expected_answer": None} diff --git a/benchmarks/benchmarks_metrics.py b/benchmarks/benchmarks_metrics.py new file mode 100644 index 0000000..70b0883 --- /dev/null +++ b/benchmarks/benchmarks_metrics.py @@ -0,0 +1,105 @@ +"""Comprehensive evaluation metrics for RAG systems.""" + +from typing import List, Dict, Any +import numpy as np + + +class BenchmarkMetrics: + """Collection of evaluation metrics for RAG systems.""" + + @staticmethod + def retrieval_metrics( + retrieved_docs: List[str], + relevant_docs: List[str], + k_values: List[int] = [1, 5, 10, 20] + ) -> Dict[str, float]: + """Compute retrieval metrics.""" + metrics = {} + + # If no ground truth is available, return NaN metrics to indicate unavailable evaluation + if not relevant_docs: + for k in k_values: + metrics[f"precision@{k}"] = float('nan') + metrics[f"recall@{k}"] = float('nan') + metrics[f"ndcg@{k}"] = float('nan') + metrics["mrr"] = float('nan') + return metrics + + # Precision@K + for k in k_values: + retrieved_k = retrieved_docs[:k] + if retrieved_k: + relevant_retrieved = len(set(retrieved_k) & set(relevant_docs)) + metrics[f"precision@{k}"] = relevant_retrieved / \ + len(retrieved_k) + else: + metrics[f"precision@{k}"] = 0.0 + + # Recall@K + for k in k_values: + retrieved_k = retrieved_docs[:k] + if relevant_docs: + relevant_retrieved = len(set(retrieved_k) & set(relevant_docs)) + metrics[f"recall@{k}"] = relevant_retrieved / \ + len(relevant_docs) + else: + metrics[f"recall@{k}"] = 0.0 + + # Mean Reciprocal Rank (MRR) + mrr = 0.0 + for i, doc_id in enumerate(retrieved_docs): + if doc_id in relevant_docs: + mrr = 1.0 / (i + 1) + break + metrics["mrr"] = mrr + + # NDCG@K (simplified binary relevance) + for k in k_values: + retrieved_k = retrieved_docs[:k] + if retrieved_k and relevant_docs: + # Binary relevance: 1 if relevant, 0 if not + relevance_scores = [ + 1.0 if doc in relevant_docs else 0.0 for doc in retrieved_k] + dcg = sum(rel / np.log2(i + 2) + for i, rel in enumerate(relevance_scores)) + + # Ideal DCG (best possible ordering) + ideal_relevance = sorted(relevance_scores, reverse=True) + idcg = sum(rel / np.log2(i + 2) + for i, rel in enumerate(ideal_relevance)) + + metrics[f"ndcg@{k}"] = dcg / idcg if idcg > 0 else 0.0 + else: + metrics[f"ndcg@{k}"] = 0.0 + + return metrics + + @staticmethod + def generation_metrics( + generated_answer: str, + reference_answer: str + ) -> Dict[str, float]: + """Compute simple text generation metrics.""" + metrics = {} + + if not reference_answer: + return {"length_ratio": 0.0, "character_overlap": 0.0} + + # Simple metrics without external dependencies + metrics["length_ratio"] = len(generated_answer) / len(reference_answer) + + # Character overlap ratio + gen_chars = set(generated_answer.lower()) + ref_chars = set(reference_answer.lower()) + overlap = len(gen_chars & ref_chars) + metrics["character_overlap"] = overlap / \ + len(ref_chars) if ref_chars else 0.0 + + # Word overlap ratio + gen_words = set(generated_answer.lower().split()) + ref_words = set(reference_answer.lower().split()) + word_overlap = len(gen_words & ref_words) + metrics["word_overlap"] = word_overlap / \ + len(ref_words) if ref_words else 0.0 + + return metrics diff --git a/benchmarks/benchmarks_runner.py b/benchmarks/benchmarks_runner.py new file mode 100644 index 0000000..05ff13e --- /dev/null +++ b/benchmarks/benchmarks_runner.py @@ -0,0 +1,284 @@ +"""Configuration-driven benchmark execution engine.""" + +import time +import numpy as np +from typing import List, Dict, Any, Optional +from tqdm import tqdm + +from benchmarks.benchmark_contracts import BenchmarkAdapter, BenchmarkQuery, BenchmarkResult +from benchmarks.benchmarks_metrics import BenchmarkMetrics +from components.retrieval_pipeline import RetrievalPipelineFactory +from config.config_loader import get_benchmark_config, get_retriever_config + + +class BenchmarkRunner: + """Execute benchmarks against configurable RAG systems.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + # Use config directly instead of get_benchmark_config + self.benchmark_config = config + self.metrics = BenchmarkMetrics() + + # Initialize retrieval engine based on unified config + self.retrieval_pipeline = self._init_retrieval_pipeline() + + # Initialize generation engine (optional) + self.generation_engine = self._init_generation_engine() + + def _init_retrieval_pipeline(self): + """Initialize retrieval pipeline from unified configuration.""" + # Try to get retriever type from multiple config locations + retrieval_type = None + + # Check if explicitly set in the config (for benchmark optimizer) + if 'default_retriever' in self.config: + retrieval_type = self.config['default_retriever'] + elif 'retrieval' in self.config: + retrieval_config = self.config.get("retrieval", {}) + retrieval_type = retrieval_config.get("type") + elif 'benchmark' in self.config and 'retrieval' in self.config['benchmark']: + benchmark_retrieval = self.config['benchmark']['retrieval'] + retrieval_type = benchmark_retrieval.get("strategy") + + # Use unified config factory (will use pipeline default if retrieval_type is None) + return RetrievalPipelineFactory.create_from_unified_config(self.config, retrieval_type) + + def _init_generation_engine(self): + """Initialize generation engine from configuration.""" + generation_config = self.config.get("generation", {}) + + if not generation_config.get("enabled", False): + return None + + # For now, return None - generation engine can be implemented later + return None + + def run_benchmark( + self, + adapter: BenchmarkAdapter, + tasks: List[str] = None, + max_queries: int = None + ) -> Dict[str, Any]: + """Run comprehensive benchmark with configurable components.""" + + print(f"🚀 Running benchmark: {adapter.name}") + + retrieval_type = self.config.get( + "retrieval", {}).get("type", "unknown") + print(f"🔍 Retrieval strategy: {retrieval_type}") + + if self.generation_engine: + print( + f"🤖 Generation provider: {self.generation_engine.provider_name}") + + # Load queries + queries = adapter.load_queries() + if max_queries: + queries = queries[:max_queries] + + print(f"📊 Evaluating {len(queries)} queries") + + results = [] + + # Process each query with progress bar + for query in tqdm(queries, desc="Processing queries"): + result = self._evaluate_query(query, adapter) + results.append(result) + + # Aggregate results + return self._aggregate_results(results, adapter.name) + + def _evaluate_query(self, query: BenchmarkQuery, adapter: BenchmarkAdapter) -> BenchmarkResult: + """Evaluate a single query with configurable components.""" + + # Retrieval evaluation using the pipeline + start_time = time.time() + + # Use the pipeline's run method + search_results = self.retrieval_pipeline.run( + query.query_text, + k=self.config.get("retrieval", {}).get("top_k", 20) + ) + + retrieval_time = (time.time() - start_time) * 1000 + + # Extract document IDs from results + retrieved_doc_ids = [] + for result in search_results: + doc_id = self._extract_document_id_from_result(result) + retrieved_doc_ids.append(str(doc_id)) + + # Compute retrieval metrics + retrieval_scores = {} + if query.relevant_doc_ids: + retrieval_scores = self.metrics.retrieval_metrics( + retrieved_doc_ids, + query.relevant_doc_ids, + k_values=self.config.get("evaluation", {}).get( + "k_values", [1, 5, 10, 20]) + ) + else: + # If no ground truth, return NaN metrics to indicate unavailable evaluation + k_values = self.config.get("evaluation", {}).get( + "k_values", [1, 5, 10, 20]) + for k in k_values: + retrieval_scores[f"precision@{k}"] = float('nan') + retrieval_scores[f"recall@{k}"] = float('nan') + retrieval_scores[f"ndcg@{k}"] = float('nan') + retrieval_scores["mrr"] = float('nan') + + # Generation evaluation (if enabled) + generation_scores = {} + generated_answer = None + generation_time = 0.0 + + if query.expected_answer and self.generation_engine: + start_time = time.time() + generated_answer = self.generation_engine.generate( + query=query.query_text, + context_docs=search_results[:self.config.get( + "generation", {}).get("context_limit", 5)] + ) + generation_time = (time.time() - start_time) * 1000 + + generation_scores = self.metrics.generation_metrics( + generated_answer, + query.expected_answer + ) + + # Combine all scores + all_scores = {**retrieval_scores, **generation_scores} + + return BenchmarkResult( + query_id=query.query_id, + retrieved_docs=retrieved_doc_ids, + generated_answer=generated_answer, + retrieval_time_ms=retrieval_time, + generation_time_ms=generation_time, + scores=all_scores + ) + + def _aggregate_results(self, results: List[BenchmarkResult], dataset_name: str) -> Dict[str, Any]: + """Aggregate individual results into final metrics.""" + + # Collect all scores + all_scores = {} + for result in results: + if result.scores: + for metric, score in result.scores.items(): + if metric not in all_scores: + all_scores[metric] = [] + all_scores[metric].append(score) + + # Get pipeline component names + component_names = [] + if hasattr(self.retrieval_pipeline, 'components'): + component_names = [ + comp.component_name for comp in self.retrieval_pipeline.components] + + # Compute averages and stats, handling NaN values + aggregated = { + "dataset": dataset_name, + "config": { + "retrieval_strategy": self.config.get("retrieval", {}).get("type", "unknown"), + "generation_enabled": self.generation_engine is not None, + "total_queries": len(results), + "components": component_names + }, + "performance": { + "avg_retrieval_time_ms": np.mean([r.retrieval_time_ms for r in results]), + "avg_generation_time_ms": np.mean([r.generation_time_ms for r in results]), + "total_time_ms": sum(r.retrieval_time_ms + r.generation_time_ms for r in results) + }, + "metrics": {} + } + + # Handle metrics with proper NaN handling + for metric, scores in all_scores.items(): + # Filter out NaN values for computation + valid_scores = [s for s in scores if not np.isnan(s)] + + if valid_scores: + aggregated["metrics"][metric] = { + "mean": np.mean(valid_scores), + "std": np.std(valid_scores), + "min": np.min(valid_scores), + "max": np.max(valid_scores), + "median": np.median(valid_scores), + "count": len(valid_scores), + "total_queries": len(scores) + } + else: + aggregated["metrics"][metric] = { + "mean": float('nan'), + "std": float('nan'), + "min": float('nan'), + "max": float('nan'), + "median": float('nan'), + "count": 0, + "total_queries": len(scores), + "note": "No ground truth available for evaluation" + } + + return aggregated + + def _extract_document_id_from_result(self, result) -> str: + """ + Extract document ID from retrieval result. + + For Qdrant, we need to get the external_id from the payload since + LangChain doesn't expose it in the document metadata. + """ + # First try: check if external_id is in document metadata + if hasattr(result, 'metadata') and result.metadata: + doc_id = result.metadata.get("external_id") + if doc_id: + return str(doc_id) + + # Second try: check document's metadata directly + if hasattr(result, 'page_content'): + # This is a Document object, check its metadata + if hasattr(result, 'metadata') and result.metadata: + doc_id = result.metadata.get("external_id") + if doc_id: + return str(doc_id) + + # Third try: if result has document attribute + if hasattr(result, 'document'): + if hasattr(result.document, 'metadata') and result.document.metadata: + doc_id = result.document.metadata.get("external_id") + if doc_id: + return str(doc_id) + + # Fourth try: For complex document IDs, try to extract the external_id part + # Look for patterns like "stackoverflow_sosum:a_123456:hash" -> "a_123456" + try: + # Check all possible metadata locations for any ID-like fields + metadata_sources = [] + + if hasattr(result, 'metadata') and result.metadata: + metadata_sources.append(result.metadata) + if hasattr(result, 'document') and hasattr(result.document, 'metadata'): + metadata_sources.append(result.document.metadata) + + for metadata in metadata_sources: + for key, value in metadata.items(): + if isinstance(value, str): + # Try to extract answer ID from complex document IDs + if ':a_' in value: + # Pattern: "stackoverflow_sosum:a_123456:hash" + parts = value.split(':') + for part in parts: + if part.startswith('a_'): + return part + + # Direct match for answer IDs + if value.startswith('a_') and value.replace('a_', '').replace('_', '').isdigit(): + return value + + except Exception as e: + pass + + # Fallback to unknown if no ID found + return "unknown" diff --git a/benchmarks/run_benchmark_optimization.py b/benchmarks/run_benchmark_optimization.py new file mode 100644 index 0000000..8cda60d --- /dev/null +++ b/benchmarks/run_benchmark_optimization.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Simple benchmark runner for easy optimization experiments. +""" + +import sys +import os +sys.path.append('/home/spiros/Desktop/Thesis/Thesis') + +from benchmarks.benchmark_optimizer import BenchmarkOptimizer + +def main(): + print("🔬 RAG Benchmark Optimizer") + print("="*50) + + optimizer = BenchmarkOptimizer() + + print("\nAvailable options:") + print("1. Run quick test (10 queries)") + print("2. Run single scenario") + print("3. Run all scenarios") + print("4. Compare previous results") + + choice = input("\nEnter choice (1-4): ").strip() + + if choice == "1": + # Quick test + print("\n🚀 Running quick test...") + config = optimizer.load_benchmark_config("benchmark_scenarios/quick_test.yml") + result = optimizer.run_optimization_scenario("quick_test", config) + optimizer._print_scenario_summary("quick_test", result) + + elif choice == "2": + # Single scenario + print("\nAvailable scenarios:") + scenarios = [ + "dense_baseline.yml", + "dense_high_recall.yml", + "dense_high_precision.yml", + "sparse_bm25.yml", + "hybrid_retrieval.yml", + "quick_test.yml" + ] + + for i, scenario in enumerate(scenarios, 1): + print(f"{i}. {scenario}") + + scenario_choice = input("\nEnter scenario number: ").strip() + try: + scenario_idx = int(scenario_choice) - 1 + scenario_file = scenarios[scenario_idx] + + print(f"\n🚀 Running scenario: {scenario_file}") + config = optimizer.load_benchmark_config(f"benchmark_scenarios/{scenario_file}") + result = optimizer.run_optimization_scenario(scenario_file.replace('.yml', ''), config) + optimizer._print_scenario_summary(scenario_file.replace('.yml', ''), result) + + except (ValueError, IndexError): + print("❌ Invalid scenario choice") + + elif choice == "3": + # All scenarios + print("\n🚀 Running all scenarios...") + results = optimizer.run_multiple_scenarios("benchmark_scenarios") + + if results: + optimizer.compare_scenarios() + optimizer.save_results() + + elif choice == "4": + # Compare results + print("\n📊 Comparing previous results...") + try: + import yaml + with open('benchmark_optimization_results.yml', 'r') as f: + data = yaml.safe_load(f) + optimizer.results_history = data.get('scenarios', []) + optimizer.compare_scenarios() + except FileNotFoundError: + print("❌ No previous results found. Run some benchmarks first!") + + else: + print("❌ Invalid choice") + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_real_benchmark.py b/benchmarks/run_real_benchmark.py new file mode 100644 index 0000000..43fd004 --- /dev/null +++ b/benchmarks/run_real_benchmark.py @@ -0,0 +1,116 @@ +"""Benchmark runner using real StackOverflow data.""" + +import sys +import os +sys.path.append('/home/spiros/Desktop/Thesis/Thesis') + +from benchmarks.benchmarks_runner import BenchmarkRunner +from benchmarks.benchmarks_adapters import StackOverflowBenchmarkAdapter +from benchmarks.benchmark_contracts import BenchmarkQuery +from config.config_loader import load_config + + +def run_real_stackoverflow_benchmark(): + """Run benchmark with real StackOverflow questions.""" + + print("🚀 Starting StackOverflow Benchmark with REAL Data") + + # Load your main configuration + config = load_config("config.yml") + + # Override for benchmarking + config["retrieval"] = { + "type": "dense", # Start with dense for simplicity + "top_k": 10, + "score_threshold": 0.1 + } + + config["evaluation"] = { + "k_values": [1, 5, 10], + "metrics": { + "retrieval": ["precision@k", "recall@k", "mrr", "ndcg@k"] + } + } + + # Create custom adapter for real data + class RealStackOverflowAdapter(StackOverflowBenchmarkAdapter): + def load_queries(self, split: str = "test"): + """Load from question.csv specifically.""" + import pandas as pd + + question_file = self.dataset_path / "question.csv" + print(f"📂 Loading from {question_file}") + + try: + df = pd.read_csv(question_file) + print(f"📊 Found {len(df)} questions in dataset") + + queries = [] + for idx, row in df.iterrows(): + if idx >= 20: # Limit to 20 real questions + break + + if pd.isna(row['question_title']) or not row['question_title']: + continue + + from benchmarks.benchmark_contracts import BenchmarkQuery + query = BenchmarkQuery( + query_id=f"real_so_{row['question_id']}", + query_text=str(row['question_title']), + expected_answer=str(row['question_body'])[:500] if not pd.isna( + row['question_body']) else None, + relevant_doc_ids=None, # No ground truth available + difficulty="medium", + category=str(row['tags']) if not pd.isna( + row['tags']) else "programming", + metadata={ + "original_question_id": row['question_id'], + "question_type": row['question_type'], + "tags": row['tags'], + "source": "real_stackoverflow" + } + ) + queries.append(query) + + print(f"✅ Loaded {len(queries)} real StackOverflow queries") + return queries + + except Exception as e: + print(f"❌ Error loading real data: {e}") + return self._create_dummy_queries() + + # Initialize components + runner = BenchmarkRunner(config) + adapter = RealStackOverflowAdapter( + dataset_path="/home/spiros/Desktop/Thesis/datasets/sosum/data" + ) + + # Run benchmark + print("📊 Running benchmark with real data...") + results = runner.run_benchmark( + adapter=adapter, + max_queries=10 # Test with 10 real questions + ) + + # Print results + print("\n📊 REAL STACKOVERFLOW BENCHMARK RESULTS:") + print(f"Dataset: {results['dataset']}") + print(f"Total Queries: {results['config']['total_queries']}") + print(f"Avg Time: {results['performance']['avg_retrieval_time_ms']:.2f}ms") + print(f"Components: {', '.join(results['config']['components'])}") + + print("\n🎯 Metrics:") + for metric_name in ['precision@5', 'precision@10', 'recall@5', 'recall@10', 'mrr']: + if metric_name in results['metrics']: + stats = results['metrics'][metric_name] + print( + f" {metric_name:12}: {stats['mean']:.3f} ± {stats['std']:.3f}") + + print("\n📋 Sample Query Results:") + # The results don't include individual queries, but we can see the overall performance + + return results + + +if __name__ == "__main__": + run_real_stackoverflow_benchmark() diff --git a/bin/__init__.py b/bin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bin/agent_retriever.py b/bin/agent_retriever.py new file mode 100644 index 0000000..31d7a5d --- /dev/null +++ b/bin/agent_retriever.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +""" +Agent wrapper for retrieval pipeline with configurable YAML. +Simple interface for agents to use any retrieval configuration. +""" + +from components.retrieval_pipeline import RetrievalPipelineFactory, RetrievalResult +import yaml +import logging +from pathlib import Path +from typing import List, Dict, Any, Optional +import sys +import os + +# Add project root to path +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + + +logger = logging.getLogger(__name__) + + +class ConfigurableRetrieverAgent: + """ + Agent that can use any YAML configuration for retrieval. + Provides a simple interface for agents to retrieve documents using + configurable pipelines without needing to know implementation details. + """ + + def __init__(self, config_path: str, cache_pipeline: bool = True): + """ + Initialize agent with a specific configuration. + + Args: + config_path (str): Path to YAML configuration file + cache_pipeline (bool): Whether to cache the pipeline for reuse + """ + self.config_path = config_path + self.cache_pipeline = cache_pipeline + self._pipeline = None + self._config = None + + # Load configuration + self._load_config() + + logger.info( + f"ConfigurableRetrieverAgent initialized with config: {config_path}") + + def _load_config(self): + """ + Load configuration from YAML file. + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If config is invalid + """ + config_file = Path(self.config_path) + + if not config_file.exists(): + raise FileNotFoundError( + f"Configuration file not found: {self.config_path}") + + with open(config_file, 'r') as f: + self._config = yaml.safe_load(f) + + logger.info(f"Loaded configuration: {self.config_path}") + + def _get_pipeline(self): + """ + Get or create the retrieval pipeline. + + Returns: + RetrievalPipeline: Configured retrieval pipeline + """ + if self._pipeline is None or not self.cache_pipeline: + logger.info("Creating retrieval pipeline from configuration...") + self._pipeline = RetrievalPipelineFactory.create_from_config( + self._config) + + components = [c.component_name for c in self._pipeline.components] + logger.info(f"Pipeline components: {components}") + + return self._pipeline + + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Retrieve documents for a query. + + Args: + query (str): Search query + top_k (int): Number of results to return + + Returns: + List[Dict[str, Any]]: List of dictionaries with document information + containing rank, score, content, metadata, etc. + """ + logger.info(f"Retrieving documents for query: '{query[:50]}...'") + + # Get pipeline and run retrieval + pipeline = self._get_pipeline() + results = pipeline.run(query, k=top_k) + + # Convert to simple dictionary format for agents + documents = [] + for i, result in enumerate(results): + labels = result.document.metadata.get('labels', {}) + + doc_info = { + 'rank': i + 1, + 'score': result.score, + 'content': result.document.page_content, + 'retrieval_method': result.retrieval_method, + 'question_title': labels.get('title', ''), + 'tags': labels.get('tags', []), + 'external_id': labels.get('external_id', ''), + 'enhanced': result.metadata.get('enhanced', False), + 'answer_quality': result.metadata.get('answer_quality', ''), + 'metadata': result.metadata + } + documents.append(doc_info) + + logger.info(f"Retrieved {len(documents)} documents") + return documents + + def get_config_info(self) -> Dict[str, Any]: + """Get information about the current configuration.""" + pipeline_config = self._config.get('retrieval_pipeline', {}) + retriever_config = pipeline_config.get('retriever', {}) + stages = pipeline_config.get('stages', []) + + return { + 'config_path': self.config_path, + 'retriever_type': retriever_config.get('type', 'unknown'), + 'retriever_top_k': retriever_config.get('top_k', 5), + 'num_stages': len(stages), + 'stage_types': [stage.get('type', 'unknown') for stage in stages], + 'embedding_strategy': self._config.get('embedding_strategy', 'unknown'), + 'collection': self._config.get('qdrant', {}).get('collection', 'unknown') + } + + def switch_config(self, new_config_path: str): + """ + Switch to a different configuration. + + Args: + new_config_path: Path to new YAML configuration + """ + logger.info( + f"Switching configuration from {self.config_path} to {new_config_path}") + + self.config_path = new_config_path + self._config = None + self._pipeline = None # Force recreation + + self._load_config() + logger.info("Configuration switched successfully") + + +def get_agent_with_config(config_name: str) -> ConfigurableRetrieverAgent: + """ + Convenience function to get an agent with a named configuration. + + Args: + config_name: Name of config file (e.g., 'basic_dense' for basic_dense.yml) + + Returns: + ConfigurableRetrieverAgent instance + """ + config_path = f"pipelines/configs/retrieval/{config_name}.yml" + return ConfigurableRetrieverAgent(config_path) + + +def demo_agent_usage(): + """Demonstrate how to use the configurable agent.""" + print("🤖 Configurable Retriever Agent Demo") + print("=" * 50) + + # Test queries + queries = [ + "How to handle Python exceptions?", + "Binary search algorithm implementation", + "What are Python metaclasses?" + ] + + # Test different configurations + configs = [ + "basic_dense", + "advanced_reranked", + "experimental" + ] + + for config_name in configs: + print(f"\n📋 Testing configuration: {config_name}") + print("-" * 40) + + try: + # Create agent with specific config + agent = get_agent_with_config(config_name) + + # Show config info + config_info = agent.get_config_info() + print(f"Retriever: {config_info['retriever_type']}") + print( + f"Stages: {config_info['num_stages']} ({', '.join(config_info['stage_types'])})") + + # Test a query + query = queries[0] + results = agent.retrieve(query, top_k=2) + + print(f"\nQuery: {query}") + print(f"Results: {len(results)}") + + for doc in results[:1]: # Show top result + print( + f" Score: {doc['score']:.3f} | Method: {doc['retrieval_method']}") + print(f" Question: {doc['question_title'][:50]}...") + + except Exception as e: + print(f"❌ Error with {config_name}: {e}") + + +if __name__ == "__main__": + # Setup logging for demo + logging.basicConfig(level=logging.INFO) + + # Run demonstration + demo_agent_usage() + + print("\n" + "=" * 50) + print("💡 Usage Examples:") + print("=" * 50) + print(""" +# Simple usage +agent = get_agent_with_config('basic_dense') +results = agent.retrieve('Python exceptions', top_k=5) + +# Switch configurations dynamically +agent.switch_config('pipelines/configs/retrieval/advanced_reranked.yml') +results = agent.retrieve('same query', top_k=5) + +# Get config information +info = agent.get_config_info() +print(f"Using {info['retriever_type']} with {info['num_stages']} stages") + """) diff --git a/bin/ingest.py b/bin/ingest.py new file mode 100755 index 0000000..36d75ac --- /dev/null +++ b/bin/ingest.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +""" +Ingestion pipeline CLI - single entrypoint for all ingestion operations. +Implements canary -> verify -> promote workflow with comprehensive logging. +""" +import argparse +import logging +import sys +import json +from pathlib import Path +from typing import List, Dict, Any + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from pipelines.ingest.pipeline import IngestionPipeline, BatchIngestionPipeline +from pipelines.eval.evaluator import RetrievalEvaluator +from pipelines.adapters.natural_questions import NaturalQuestionsAdapter +from pipelines.adapters.stackoverflow import StackOverflowAdapter +from pipelines.adapters.energy_papers import EnergyPapersAdapter +from pipelines.contracts import DatasetSplit +from config.config_loader import load_config + + +def setup_logging(verbose: bool = False): + """Setup logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + + # Create logs directory + logs_dir = Path("logs") + logs_dir.mkdir(exist_ok=True) + + # Configure logging + logging.basicConfig( + level=level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler(logs_dir / "ingestion.log") + ] + ) + + +def get_adapter(adapter_type: str, dataset_path: str, version: str = "1.0.0"): + """Factory function to create dataset adapters.""" + adapters = { + "natural_questions": NaturalQuestionsAdapter, + "stackoverflow": StackOverflowAdapter, + "energy_papers": EnergyPapersAdapter, + } + + if adapter_type not in adapters: + available = ", ".join(adapters.keys()) + raise ValueError(f"Unknown adapter type '{adapter_type}'. Available: {available}") + + adapter_class = adapters[adapter_type] + return adapter_class(dataset_path, version) + + +def cmd_ingest(args): + """Run ingestion pipeline.""" + logger = logging.getLogger("ingest") + logger.info(f"Starting ingestion: {args.adapter_type} from {args.dataset_path}") + + # Load configuration + config = load_config(args.config) if args.config else load_config() + + # Create adapter + adapter = get_adapter(args.adapter_type, args.dataset_path, args.version) + + # Create pipeline + pipeline = IngestionPipeline(config=config) + + # Parse split + split = DatasetSplit(args.split) + + # Run ingestion + try: + record = pipeline.ingest_dataset( + adapter=adapter, + split=split, + dry_run=args.dry_run, + max_documents=args.max_documents, + canary_mode=args.canary + ) + + # Print results + print(f"\n✓ Ingestion completed successfully!") + print(f" Dataset: {record.dataset_name} v{record.dataset_version}") + print(f" Documents: {record.total_documents}") + print(f" Chunks: {record.total_chunks}") + print(f" Successful: {record.successful_chunks}") + print(f" Failed: {record.failed_chunks}") + print(f" Success rate: {record.successful_chunks/record.total_chunks*100:.1f}%" if record.total_chunks > 0 else "N/A") + print(f" Run ID: {record.run_id}") + + if args.verify: + logger.info("Running verification...") + collection_info = pipeline.get_collection_status() + print(f"\n Collection Status:") + print(f" Name: {collection_info.get('collection_name', 'unknown')}") + print(f" Points: {collection_info.get('points_count', 0)}") + print(f" Status: {collection_info.get('status', 'unknown')}") + + return 0 + + except Exception as e: + logger.error(f"Ingestion failed: {e}") + print(f"\n✗ Ingestion failed: {e}") + return 1 + + +def cmd_batch_ingest(args): + """Run batch ingestion for multiple datasets.""" + logger = logging.getLogger("batch_ingest") + + # Load batch configuration + with open(args.batch_config, 'r') as f: + batch_config = json.load(f) + + datasets = batch_config.get("datasets", []) + if not datasets: + logger.error("No datasets specified in batch configuration") + return 1 + + logger.info(f"Starting batch ingestion of {len(datasets)} datasets") + + # Create adapters + adapters = [] + for dataset_config in datasets: + adapter = get_adapter( + dataset_config["type"], + dataset_config["path"], + dataset_config.get("version", "1.0.0") + ) + adapters.append(adapter) + + # Run batch ingestion + pipeline = BatchIngestionPipeline(args.config) + + try: + results = pipeline.ingest_multiple_datasets( + adapters=adapters, + split=DatasetSplit(args.split), + dry_run=args.dry_run, + max_documents=args.max_documents + ) + + # Print summary + summary = pipeline.get_summary() + print(f"\n✓ Batch ingestion completed!") + print(f" Datasets processed: {summary['total_datasets']}") + print(f" Total documents: {summary['total_documents']}") + print(f" Total chunks: {summary['total_chunks']}") + print(f" Overall success rate: {summary['success_rate']*100:.1f}%") + + return 0 + + except Exception as e: + logger.error(f"Batch ingestion failed: {e}") + print(f"\n✗ Batch ingestion failed: {e}") + return 1 + + +def cmd_evaluate(args): + """Run evaluation on ingested dataset.""" + logger = logging.getLogger("evaluate") + logger.info(f"Starting evaluation: {args.adapter_type}") + + # Load configuration + config = load_config(args.config) if args.config else load_config() + + # Create adapter + adapter = get_adapter(args.adapter_type, args.dataset_path, args.version) + + # Create retriever + from retrievers.router import RetrieverRouter + retriever = RetrieverRouter(config) + + # Create evaluator + evaluator = RetrievalEvaluator(config) + + try: + # Run evaluation + evaluation_run = evaluator.evaluate_dataset( + adapter=adapter, + retriever=retriever, + split=args.split + ) + + # Save results + output_dir = Path(args.output_dir) + evaluator.save_results(evaluation_run, output_dir) + + # Print summary + metrics = evaluation_run.metrics + print(f"\n✓ Evaluation completed!") + print(f" Dataset: {evaluation_run.dataset_name}") + print(f" Queries: {metrics.total_queries}") + print(f" Recall@5: {metrics.recall_at_k.get(5, 0):.3f}") + print(f" Precision@5: {metrics.precision_at_k.get(5, 0):.3f}") + print(f" NDCG@5: {metrics.ndcg_at_k.get(5, 0):.3f}") + print(f" MRR: {metrics.mrr:.3f}") + print(f" Results saved to: {output_dir}") + + return 0 + + except Exception as e: + logger.error(f"Evaluation failed: {e}") + print(f"\n✗ Evaluation failed: {e}") + return 1 + + +def cmd_status(args): + """Show collection and pipeline status.""" + config = load_config(args.config) if args.config else load_config() + pipeline = IngestionPipeline(config=config) + + try: + # Get collection info + collection_info = pipeline.get_collection_status() + + print(f"\nCollection Status:") + print(f" Name: {collection_info.get('collection_name', 'unknown')}") + print(f" Points: {collection_info.get('points_count', 0):,}") + print(f" Status: {collection_info.get('status', 'unknown')}") + + vectors_config = collection_info.get('vectors_config', {}) + sparse_vectors_config = collection_info.get('sparse_vectors_config', {}) + + if vectors_config: + print(f" Dense vectors: {len(vectors_config)}") + for name, config in vectors_config.items(): + print(f" {name}: {config.size} dims") + + if sparse_vectors_config: + print(f" Sparse vectors: {len(sparse_vectors_config)}") + + # Show recent lineage files + lineage_dir = Path("output/lineage") + if lineage_dir.exists(): + lineage_files = sorted(lineage_dir.glob("*.json"), key=lambda x: x.stat().st_mtime, reverse=True) + if lineage_files: + print(f"\n📝 Recent Ingestion Runs:") + for file_path in lineage_files[:5]: + with open(file_path, 'r') as f: + data = json.load(f) + record = data.get("ingestion_record", {}) + print(f" {record.get('dataset_name', 'unknown')} - {record.get('started_at', 'unknown')}") + + return 0 + + except Exception as e: + print(f"\n✗ Status check failed: {e}") + return 1 + + +def cmd_cleanup(args): + """Clean up canary collections and temporary files.""" + config = load_config(args.config) if args.config else load_config() + pipeline = IngestionPipeline(config=config) + + try: + pipeline.cleanup_canary_collections() + print("✓ Cleanup completed") + return 0 + except Exception as e: + print(f"✗ Cleanup failed: {e}") + return 1 + + +def main(): + """Main CLI entrypoint.""" + parser = argparse.ArgumentParser( + description="Ingestion Pipeline CLI", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Ingest Natural Questions dataset + python bin/ingest.py ingest natural_questions /path/to/nq --config config.yml + + # Dry run with limited documents + python bin/ingest.py ingest stackoverflow /path/to/so --dry-run --max-docs 100 + + # Canary ingestion + python bin/ingest.py ingest energy_papers papers/ --canary + + # Batch ingestion + python bin/ingest.py batch-ingest batch_config.json + + # Evaluate retrieval + python bin/ingest.py evaluate natural_questions /path/to/nq --output-dir results/ + + # Check status + python bin/ingest.py status + """ + ) + + parser.add_argument("--config", "-c", help="Configuration file path") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging") + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # Ingest command + ingest_parser = subparsers.add_parser("ingest", help="Ingest a single dataset") + ingest_parser.add_argument("adapter_type", choices=["natural_questions", "stackoverflow", "energy_papers"], + help="Dataset adapter type") + ingest_parser.add_argument("dataset_path", help="Path to dataset") + ingest_parser.add_argument("--version", default="1.0.0", help="Dataset version") + ingest_parser.add_argument("--split", choices=["train", "val", "test", "all"], default="all", + help="Dataset split to process") + ingest_parser.add_argument("--dry-run", action="store_true", help="Don't upload to vector store") + ingest_parser.add_argument("--max-docs", type=int, dest="max_documents", help="Maximum documents to process") + ingest_parser.add_argument("--canary", action="store_true", help="Use canary collection") + ingest_parser.add_argument("--verify", action="store_true", help="Run verification after ingestion") + ingest_parser.set_defaults(func=cmd_ingest) + + # Batch ingest command + batch_parser = subparsers.add_parser("batch-ingest", help="Ingest multiple datasets") + batch_parser.add_argument("batch_config", help="JSON file with batch configuration") + batch_parser.add_argument("--split", choices=["train", "val", "test", "all"], default="all") + batch_parser.add_argument("--dry-run", action="store_true", help="Don't upload to vector store") + batch_parser.add_argument("--max-docs", type=int, dest="max_documents", help="Maximum documents per dataset") + batch_parser.set_defaults(func=cmd_batch_ingest) + + # Evaluate command + eval_parser = subparsers.add_parser("evaluate", help="Evaluate retrieval performance") + eval_parser.add_argument("adapter_type", choices=["natural_questions", "stackoverflow", "energy_papers"]) + eval_parser.add_argument("dataset_path", help="Path to dataset") + eval_parser.add_argument("--version", default="1.0.0", help="Dataset version") + eval_parser.add_argument("--split", choices=["train", "val", "test"], default="test") + eval_parser.add_argument("--output-dir", default="output/evaluation", help="Output directory for results") + eval_parser.set_defaults(func=cmd_evaluate) + + # Status command + status_parser = subparsers.add_parser("status", help="Show pipeline status") + status_parser.set_defaults(func=cmd_status) + + # Cleanup command + cleanup_parser = subparsers.add_parser("cleanup", help="Clean up canary collections") + cleanup_parser.set_defaults(func=cmd_cleanup) + + # Parse and execute + args = parser.parse_args() + + if not args.command: + parser.print_help() + return 1 + + # Setup logging + setup_logging(args.verbose) + + # Execute command + return args.func(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/bin/qdrant_inspector.py b/bin/qdrant_inspector.py new file mode 100644 index 0000000..50fd852 --- /dev/null +++ b/bin/qdrant_inspector.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +""" +Qdrant Inspection Tool - Explore and query your vector database +""" +import argparse +import json +from typing import List, Dict, Any +from qdrant_client import QdrantClient +from qdrant_client.http.models import Filter, FieldCondition, Range, MatchValue + +def inspect_collections(client: QdrantClient): + """List all collections and their stats.""" + print("=== QDRANT COLLECTIONS ===") + collections = client.get_collections() + + if not collections.collections: + print("No collections found.") + return + + for collection in collections.collections: + info = client.get_collection(collection.name) + print(f"\n📁 Collection: {collection.name}") + print(f" Status: {info.status}") + print(f" Vectors: {info.vectors_count if info.vectors_count else 'Computing...'}") + + # Vector configuration + if hasattr(info.config.params, 'vectors'): + if isinstance(info.config.params.vectors, dict): + for name, config in info.config.params.vectors.items(): + print(f" Vector '{name}': {config.size}D, distance={config.distance}") + else: + print(f" Vector: {info.config.params.vectors.size}D, distance={info.config.params.vectors.distance}") + +def browse_data(client: QdrantClient, collection_name: str, limit: int = 10): + """Browse data in a collection.""" + print(f"\n=== BROWSING: {collection_name} ===") + + try: + # Get collection info + info = client.get_collection(collection_name) + print(f"Status: {info.status}") + print(f"Vectors: {info.vectors_count if info.vectors_count else 'Computing...'}") + + # Get sample points + points, next_page_offset = client.scroll( + collection_name=collection_name, + limit=limit, + with_payload=True, + with_vectors=False # Don't load vectors for browsing + ) + + print(f"\n📄 Sample documents ({len(points)} shown):") + for i, point in enumerate(points, 1): + payload = point.payload + print(f"\n{i}. ID: {point.id}") + print(f" Source: {payload.get('source', 'unknown')}") + print(f" External ID: {payload.get('external_id', 'unknown')}") + print(f" Split: {payload.get('split', 'unknown')}") + print(f" Chunk: {payload.get('chunk_index', 0)}/{payload.get('num_chunks', 1)}") + print(f" Model: {payload.get('embedding_model', 'unknown')}") + print(f" Text: {payload.get('text', '')[:200]}...") + + # Show labels if available + labels = payload.get('labels', {}) + if labels and isinstance(labels, dict): + interesting_labels = {k: v for k, v in labels.items() + if k in ['post_type', 'tags', 'doc_type', 'title']} + if interesting_labels: + print(f" Labels: {interesting_labels}") + + except Exception as e: + print(f"Error browsing collection: {e}") + +def search_collection(client: QdrantClient, collection_name: str, query: str, limit: int = 5): + """Search collection using vector similarity.""" + print(f"\n=== SEARCHING: {collection_name} ===") + print(f"Query: '{query}'") + + try: + # We need to embed the query first + # For now, let's try a simple vector search + # This is a simplified version - in production you'd embed the query properly + + # Try to get some results using vector search + from qdrant_client.http.models import SearchRequest + + # First, let's try a simple scroll to see what's there + points, _ = client.scroll( + collection_name=collection_name, + limit=limit, + with_payload=True, + with_vectors=False + ) + + if not points: + print("No data found in collection.") + return + + # For now, let's do a text-based filter as fallback + print(f"\n� Showing {len(points)} sample documents (vector search not implemented in inspector):") + for i, point in enumerate(points, 1): + payload = point.payload + text = payload.get('text', '') + + # Simple text matching + if query.lower() in text.lower(): + print(f"\n{i}. Match (ID: {point.id})") + print(f" External ID: {payload.get('external_id')}") + print(f" Source: {payload.get('source')}") + print(f" Text: {text[:300]}...") + + print(f"\n💡 Note: This is text-based search. For semantic search, use the retrieval pipeline.") + + except Exception as e: + print(f"Error searching collection: {e}") + +def filter_by_metadata(client: QdrantClient, collection_name: str, key: str, value: str, limit: int = 10): + """Filter points by metadata.""" + print(f"\n=== FILTERING: {collection_name} ===") + print(f"Filter: {key} = '{value}'") + + try: + points, _ = client.scroll( + collection_name=collection_name, + scroll_filter=Filter( + must=[ + FieldCondition( + key=key, + match=MatchValue(value=value) + ) + ] + ), + limit=limit, + with_payload=True, + with_vectors=False + ) + + print(f"\n📋 Found {len(points)} results:") + for i, point in enumerate(points, 1): + payload = point.payload + print(f"\n{i}. ID: {point.id}") + print(f" {key}: {payload.get(key)}") + print(f" Text: {payload.get('text', '')[:200]}...") + + except Exception as e: + print(f"Error filtering collection: {e}") + +def collection_stats(client: QdrantClient, collection_name: str): + """Show detailed statistics for a collection.""" + print(f"\n=== STATISTICS: {collection_name} ===") + + try: + # Get all points to compute stats + all_points, _ = client.scroll( + collection_name=collection_name, + limit=10000, # Adjust based on your collection size + with_payload=True, + with_vectors=False + ) + + if not all_points: + print("No data found.") + return + + # Compute statistics + sources = {} + splits = {} + models = {} + post_types = {} + + total_chars = 0 + chunk_counts = [] + + for point in all_points: + payload = point.payload + + # Count by source + source = payload.get('source', 'unknown') + sources[source] = sources.get(source, 0) + 1 + + # Count by split + split = payload.get('split', 'unknown') + splits[split] = splits.get(split, 0) + 1 + + # Count by model + model = payload.get('embedding_model', 'unknown') + models[model] = models.get(model, 0) + 1 + + # Count by post type (if available) + labels = payload.get('labels', {}) + if isinstance(labels, dict): + post_type = labels.get('post_type', 'unknown') + post_types[post_type] = post_types.get(post_type, 0) + 1 + + # Text statistics + text_len = len(payload.get('text', '')) + total_chars += text_len + + # Chunk info + num_chunks = payload.get('num_chunks', 1) + chunk_counts.append(num_chunks) + + # Print statistics + print(f"📊 Total documents: {len(all_points)}") + print(f"📊 Average text length: {total_chars / len(all_points):.1f} characters") + print(f"📊 Average chunks per document: {sum(chunk_counts) / len(chunk_counts):.1f}") + + print(f"\n📁 Sources:") + for source, count in sources.items(): + print(f" {source}: {count}") + + print(f"\n🔀 Splits:") + for split, count in splits.items(): + print(f" {split}: {count}") + + print(f"\n🤖 Models:") + for model, count in models.items(): + print(f" {model}: {count}") + + if post_types and any(pt != 'unknown' for pt in post_types.keys()): + print(f"\n📝 Post Types:") + for post_type, count in post_types.items(): + if post_type != 'unknown': + print(f" {post_type}: {count}") + + except Exception as e: + print(f"Error computing statistics: {e}") + +def main(): + parser = argparse.ArgumentParser(description="Qdrant Database Inspector") + parser.add_argument("--host", default="localhost", help="Qdrant host") + parser.add_argument("--port", type=int, default=6333, help="Qdrant port") + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # List collections + subparsers.add_parser("list", help="List all collections") + + # Browse data + browse_parser = subparsers.add_parser("browse", help="Browse collection data") + browse_parser.add_argument("collection", help="Collection name") + browse_parser.add_argument("--limit", type=int, default=10, help="Number of documents to show") + + # Search + search_parser = subparsers.add_parser("search", help="Search collection") + search_parser.add_argument("collection", help="Collection name") + search_parser.add_argument("query", help="Search query") + search_parser.add_argument("--limit", type=int, default=5, help="Number of results") + + # Filter + filter_parser = subparsers.add_parser("filter", help="Filter by metadata") + filter_parser.add_argument("collection", help="Collection name") + filter_parser.add_argument("key", help="Metadata key") + filter_parser.add_argument("value", help="Metadata value") + filter_parser.add_argument("--limit", type=int, default=10, help="Number of results") + + # Statistics + stats_parser = subparsers.add_parser("stats", help="Show collection statistics") + stats_parser.add_argument("collection", help="Collection name") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + # Connect to Qdrant + try: + client = QdrantClient(host=args.host, port=args.port) + + if args.command == "list": + inspect_collections(client) + + elif args.command == "browse": + browse_data(client, args.collection, args.limit) + + elif args.command == "search": + search_collection(client, args.collection, args.query, args.limit) + + elif args.command == "filter": + filter_by_metadata(client, args.collection, args.key, args.value, args.limit) + + elif args.command == "stats": + collection_stats(client, args.collection) + + except Exception as e: + print(f"Error connecting to Qdrant: {e}") + print("Make sure Qdrant is running: docker run -p 6333:6333 qdrant/qdrant") + +if __name__ == "__main__": + main() diff --git a/bin/retrieval_pipeline.py b/bin/retrieval_pipeline.py new file mode 100644 index 0000000..c106968 --- /dev/null +++ b/bin/retrieval_pipeline.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +Retrieval Pipeline CLI - Use any YAML configuration +Usage: python bin/retrieval_pipeline.py --config pipelines/configs/retrieval/basic_dense.yml --query "How to handle exceptions in Python?" +""" + +import argparse +import yaml +import logging +from pathlib import Path +import sys +import os + +# Add project root to path +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +from components.retrieval_pipeline import RetrievalPipelineFactory + +# Setup logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def load_config(config_path: str) -> dict: + """Load configuration from YAML file.""" + config_file = Path(config_path) + + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + logger.info(f"Loaded configuration from: {config_path}") + return config + + +def run_retrieval(config: dict, query: str, top_k: int = 5) -> list: + """Run retrieval with the specified configuration.""" + logger.info(f"Creating pipeline from configuration...") + + # Create pipeline from config + pipeline = RetrievalPipelineFactory.create_from_config(config) + + logger.info(f"Pipeline components: {[c.component_name for c in pipeline.components]}") + + # Run retrieval + logger.info(f"Running query: '{query}'") + results = pipeline.run(query, k=top_k) + + return results + + +def display_results(results: list, show_content: bool = False): + """Display retrieval results in a nice format.""" + print(f"\n🔍 Found {len(results)} results:") + print("=" * 80) + + for i, result in enumerate(results, 1): + labels = result.document.metadata.get('labels', {}) + + print(f"\n{i}. Score: {result.score:.4f} | Method: {result.retrieval_method}") + + # Show question title if available + title = labels.get('title', 'N/A') + if title != 'N/A': + print(f" 📝 Question: {title}") + + # Show tags if available + tags = labels.get('tags', []) + if tags: + print(f" 🏷️ Tags: {', '.join(tags[:5])}") # Show first 5 tags + + # Show enhancement info if available + if result.metadata.get('enhanced'): + quality = result.metadata.get('answer_quality', 'unknown') + print(f" ✨ Enhanced (Quality: {quality})") + + # Show content if requested + if show_content: + content = result.document.page_content[:200] + "..." if len(result.document.page_content) > 200 else result.document.page_content + print(f" 📄 Content: {content}") + + print("-" * 80) + + +def list_available_configs(): + """List all available configuration files.""" + config_dir = Path("pipelines/configs/retrieval") + + if not config_dir.exists(): + print("❌ No retrieval configurations found") + return + + print("\n📋 Available configurations:") + print("=" * 50) + + configs = list(config_dir.glob("*.yml")) + for config_file in sorted(configs): + try: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Extract pipeline info + pipeline_info = config.get('retrieval_pipeline', {}) + retriever_type = pipeline_info.get('retriever', {}).get('type', 'unknown') + stages = pipeline_info.get('stages', []) + + print(f"\n📁 {config_file.name}") + print(f" Retriever: {retriever_type}") + print(f" Stages: {len(stages)} components") + + if stages: + stage_types = [stage.get('type', 'unknown') for stage in stages] + print(f" Pipeline: {retriever_type} → {' → '.join(stage_types)}") + + except Exception as e: + print(f"❌ Error reading {config_file.name}: {e}") + + +def main(): + """Main CLI function.""" + parser = argparse.ArgumentParser( + description="Run retrieval pipeline with specified YAML configuration", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use basic dense retrieval + python bin/retrieval_pipeline.py --config pipelines/configs/retrieval/basic_dense.yml --query "Python exceptions" + + # Use advanced pipeline with reranking + python bin/retrieval_pipeline.py --config pipelines/configs/retrieval/advanced_reranked.yml --query "binary search algorithm" + + # Show full content of results + python bin/retrieval_pipeline.py --config pipelines/configs/retrieval/experimental.yml --query "metaclasses" --show-content + + # List available configurations + python bin/retrieval_pipeline.py --list-configs + """ + ) + + parser.add_argument( + '--config', '-c', + type=str, + help='Path to YAML configuration file' + ) + + parser.add_argument( + '--query', '-q', + type=str, + help='Search query' + ) + + parser.add_argument( + '--top-k', '-k', + type=int, + default=5, + help='Number of results to retrieve (default: 5)' + ) + + parser.add_argument( + '--show-content', + action='store_true', + help='Show document content in results' + ) + + parser.add_argument( + '--list-configs', + action='store_true', + help='List all available configuration files' + ) + + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Enable verbose logging' + ) + + args = parser.parse_args() + + # Set logging level + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + try: + # List configs and exit + if args.list_configs: + list_available_configs() + return + + # Validate required arguments + if not args.config: + print("❌ Error: --config is required (or use --list-configs to see available options)") + parser.print_help() + return + + if not args.query: + print("❌ Error: --query is required") + parser.print_help() + return + + print(f"🚀 Running retrieval pipeline") + print(f"📋 Config: {args.config}") + print(f"🔍 Query: {args.query}") + print(f"📊 Top-K: {args.top_k}") + + # Load configuration + config = load_config(args.config) + + # Run retrieval + results = run_retrieval(config, args.query, args.top_k) + + # Display results + display_results(results, show_content=args.show_content) + + print(f"\n✅ Retrieval completed successfully!") + + except KeyboardInterrupt: + print("\n❌ Interrupted by user") + except Exception as e: + print(f"❌ Error: {e}") + if args.verbose: + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/bin/switch_agent_config.py b/bin/switch_agent_config.py new file mode 100644 index 0000000..9d2c450 --- /dev/null +++ b/bin/switch_agent_config.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +Utility to switch the agent's retrieval configuration. +Usage: python bin/switch_agent_config.py [config_name] +""" + +import yaml +import sys +import argparse +from pathlib import Path + + +def list_available_configs(): + """ + List all available retrieval configurations. + + Returns: + list: List of tuples containing (config_name, description, file_path) + """ + config_dir = Path("pipelines/configs/retrieval") + if not config_dir.exists(): + print("Retrieval configs directory not found") + return [] + + configs = [] + for config_file in config_dir.glob("*.yml"): + config_name = config_file.stem + + # Read config to get description + try: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + retrieval_config = config.get('retrieval_pipeline', {}) + retriever_type = retrieval_config.get( + 'retriever', {}).get('type', 'unknown') + num_stages = len(retrieval_config.get('stages', [])) + + description = f"{retriever_type} retrieval with {num_stages} stages" + configs.append((config_name, description, str(config_file))) + + except Exception as e: + configs.append( + (config_name, f"Error reading config: {e}", str(config_file))) + + return configs + + +def switch_agent_config(config_name: str): + """ + Switch the agent's retrieval configuration. + + Args: + config_name (str): Name of the configuration to switch to + + Returns: + bool: True if successful, False otherwise + """ + # Check if config exists + config_path = Path(f"pipelines/configs/retrieval/{config_name}.yml") + if not config_path.exists(): + print(f"Configuration '{config_name}' not found at {config_path}") + return False + + # Load main config + main_config_path = Path("config.yml") + if not main_config_path.exists(): + print(f"Main config file not found: {main_config_path}") + return False + + try: + with open(main_config_path, 'r') as f: + config = yaml.safe_load(f) + + # Update retrieval config path + if 'agent_retrieval' not in config: + config['agent_retrieval'] = {} + + old_config = config['agent_retrieval'].get('config_path', 'not set') + config['agent_retrieval'][ + 'config_path'] = f"pipelines/configs/retrieval/{config_name}.yml" + config['agent_retrieval']['active_config'] = config_name + + # Write updated config + with open(main_config_path, 'w') as f: + yaml.dump(config, f, default_flow_style=False) + + print(f"Agent retrieval configuration switched:") + print(f" From: {old_config}") + print(f" To: pipelines/configs/retrieval/{config_name}.yml") + + # Show config details + with open(config_path, 'r') as f: + retrieval_config = yaml.safe_load(f) + + pipeline_config = retrieval_config.get('retrieval_pipeline', {}) + retriever_info = pipeline_config.get('retriever', {}) + stages = pipeline_config.get('stages', []) + + print(f"\nConfiguration Details:") + print( + f" Retriever: {retriever_info.get('type', 'unknown')} (top_k={retriever_info.get('top_k', 5)})") + print(f" Stages: {len(stages)}") + for i, stage in enumerate(stages, 1): + stage_type = stage.get('type', 'unknown') + print(f" {i}. {stage_type}") + + print(f"\nRestart your agent to apply the new configuration.") + return True + + except Exception as e: + print(f"Error updating configuration: {e}") + return False + + +def main(): + """ + Main function to handle command line arguments and execute configuration switching. + """ + parser = argparse.ArgumentParser( + description="Switch agent retrieval configuration") + parser.add_argument("config_name", nargs='?', + help="Name of the configuration to switch to") + parser.add_argument("--list", "-l", action="store_true", + help="List available configurations") + + args = parser.parse_args() + + if args.list or not args.config_name: + print("Available Retrieval Configurations:") + print("=" * 50) + + configs = list_available_configs() + if not configs: + print("No configurations found") + return + + for name, description, path in configs: + print(f"{name}") + print(f" {description}") + print(f" Path: {path}") + print() + + if not args.config_name: + print("Usage: python bin/switch_agent_config.py ") + return + + if args.config_name: + success = switch_agent_config(args.config_name) + if success: + print(f"\nTest the new configuration:") + print(f" python tests/test_agent_retrieval.py") + else: + print(f"\nAvailable configs:") + for name, _, _ in list_available_configs(): + print(f" - {name}") + + +if __name__ == "__main__": + main() diff --git a/components/__init__.py b/components/__init__.py new file mode 100644 index 0000000..daa47fe --- /dev/null +++ b/components/__init__.py @@ -0,0 +1,58 @@ +""" +Modular components for extensible retrieval pipelines. +""" + +from .retrieval_pipeline import ( + RetrievalPipeline, + RetrievalPipelineFactory, + RetrievalResult, + RetrievalComponent, + BaseRetriever, + Reranker, + ResultFilter, + PostProcessor +) + +from .rerankers import ( + CrossEncoderReranker, + SemanticReranker, + BM25Reranker, + EnsembleReranker +) + +from .filters import ( + ScoreFilter, + MetadataFilter, + TagFilter, + DuplicateFilter, + AnswerEnhancer, + ContextEnricher, + ResultLimiter +) + +__all__ = [ + # Core pipeline + 'RetrievalPipeline', + 'RetrievalPipelineFactory', + 'RetrievalResult', + 'RetrievalComponent', + 'BaseRetriever', + 'Reranker', + 'ResultFilter', + 'PostProcessor', + + # Rerankers + 'CrossEncoderReranker', + 'SemanticReranker', + 'BM25Reranker', + 'EnsembleReranker', + + # Filters and processors + 'ScoreFilter', + 'MetadataFilter', + 'TagFilter', + 'DuplicateFilter', + 'AnswerEnhancer', + 'ContextEnricher', + 'ResultLimiter', +] diff --git a/components/advanced_rerankers.py b/components/advanced_rerankers.py new file mode 100644 index 0000000..1284495 --- /dev/null +++ b/components/advanced_rerankers.py @@ -0,0 +1,300 @@ +""" +Advanced reranking components for the modular retrieval pipeline. +Demonstrates how easy it is to add new rerankers. +""" + +from typing import List, Dict, Any +import logging +import numpy as np +from components.retrieval_pipeline import Reranker, RetrievalResult + +logger = logging.getLogger(__name__) + + +class CohereBReranker(Reranker): + """ + Cohere Rerank API-based reranker. + High-quality reranking using Cohere's commercial models. + """ + + def __init__(self, api_key: str = None, model: str = "rerank-english-v2.0", + top_k: int = None): + self.api_key = api_key + self.model = model + self.top_k = top_k + self._client = None + + logger.info(f"Initialized CohereBReranker with model: {model}") + + @property + def component_name(self) -> str: + return f"cohere_reranker_{self.model.replace('-', '_')}" + + def _load_client(self): + """Lazy load the Cohere client.""" + if self._client is None: + try: + import cohere + self._client = cohere.Client(self.api_key) + logger.info(f"Loaded Cohere client with model: {self.model}") + except ImportError: + raise ImportError("cohere package is required for CohereBReranker") + except Exception as e: + logger.warning(f"Could not initialize Cohere client: {e}") + raise + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank results using Cohere Rerank API.""" + if not results: + return results + + try: + self._load_client() + + # Prepare documents for reranking + documents = [result.document.page_content for result in results] + + # Call Cohere Rerank API + response = self._client.rerank( + model=self.model, + query=query, + documents=documents, + top_k=self.top_k or len(results) + ) + + # Reorder results based on Cohere scores + reranked_results = [] + for rank_result in response.results: + original_result = results[rank_result.index] + + # Store original score in metadata + original_result.metadata["original_score"] = original_result.score + original_result.metadata["cohere_score"] = rank_result.relevance_score + + # Update score and method + original_result.score = rank_result.relevance_score + original_result.retrieval_method = f"{original_result.retrieval_method}+cohere" + + reranked_results.append(original_result) + + logger.info(f"Reranked {len(results)} results with Cohere, returning top {len(reranked_results)}") + return reranked_results + + except Exception as e: + logger.warning(f"Cohere reranking failed: {e}, returning original results") + return results + + +class BgeReranker(Reranker): + """ + BGE (BAAI General Embedding) reranker using local models. + Excellent for multilingual and general domain reranking. + """ + + def __init__(self, model_name: str = "BAAI/bge-reranker-base", + device: str = "cpu", top_k: int = None): + self.model_name = model_name + self.device = device + self.top_k = top_k + self._tokenizer = None + self._model = None + + logger.info(f"Initialized BgeReranker with model: {model_name}") + + @property + def component_name(self) -> str: + return f"bge_reranker_{self.model_name.split('/')[-1]}" + + def _load_model(self): + """Lazy load the BGE model.""" + if self._model is None: + try: + from transformers import AutoTokenizer, AutoModelForSequenceClassification + import torch + + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self._model = AutoModelForSequenceClassification.from_pretrained(self.model_name) + + if self.device == "cuda" and torch.cuda.is_available(): + self._model = self._model.cuda() + + logger.info(f"Loaded BGE model: {self.model_name}") + except ImportError: + raise ImportError("transformers and torch are required for BgeReranker") + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank results using BGE model.""" + if not results: + return results + + self._load_model() + + try: + import torch + + # Prepare query-document pairs + pairs = [[query, result.document.page_content] for result in results] + + # Tokenize and compute scores + with torch.no_grad(): + inputs = self._tokenizer(pairs, padding=True, truncation=True, + return_tensors='pt', max_length=512) + + if self.device == "cuda" and torch.cuda.is_available(): + inputs = {k: v.cuda() for k, v in inputs.items()} + + scores = self._model(**inputs, return_dict=True).logits.view(-1, ).float() + scores = torch.sigmoid(scores).cpu().numpy() + + # Create scored results + scored_results = [] + for i, result in enumerate(results): + # Store original score + result.metadata["original_score"] = result.score + result.metadata["bge_score"] = float(scores[i]) + + # Update score and method + result.score = float(scores[i]) + result.retrieval_method = f"{result.retrieval_method}+bge" + + scored_results.append((result, float(scores[i]))) + + # Sort by BGE score + scored_results.sort(key=lambda x: x[1], reverse=True) + + # Return top k + top_k = self.top_k or len(scored_results) + reranked_results = [result for result, _ in scored_results[:top_k]] + + logger.info(f"Reranked {len(results)} results with BGE, returning top {len(reranked_results)}") + return reranked_results + + except Exception as e: + logger.warning(f"BGE reranking failed: {e}, returning original results") + return results + + +class ColBERTReranker(Reranker): + """ + ColBERT-based reranker for late interaction reranking. + Highly effective for passage ranking tasks. + """ + + def __init__(self, model_name: str = "colbert-ir/colbertv2.0", + device: str = "cpu", top_k: int = None): + self.model_name = model_name + self.device = device + self.top_k = top_k + self._model = None + + logger.info(f"Initialized ColBERTReranker with model: {model_name}") + + @property + def component_name(self) -> str: + return f"colbert_reranker_{self.model_name.split('/')[-1]}" + + def _load_model(self): + """Lazy load the ColBERT model.""" + if self._model is None: + try: + from colbert.modeling.checkpoint import Checkpoint + + self._model = Checkpoint(self.model_name, colbert_config=None) + logger.info(f"Loaded ColBERT model: {self.model_name}") + except ImportError: + raise ImportError("colbert-ai package is required for ColBERTReranker") + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank results using ColBERT late interaction.""" + if not results: + return results + + try: + self._load_model() + + # Prepare documents + documents = [result.document.page_content for result in results] + + # Compute ColBERT scores + Q = self._model.queryFromText([query]) + D = self._model.docFromText(documents) + + # Late interaction scoring + scores = self._model.score(Q, D).squeeze().tolist() + if isinstance(scores, float): # Single document + scores = [scores] + + # Create scored results + scored_results = [] + for i, result in enumerate(results): + # Store original score + result.metadata["original_score"] = result.score + result.metadata["colbert_score"] = scores[i] + + # Update score and method + result.score = scores[i] + result.retrieval_method = f"{result.retrieval_method}+colbert" + + scored_results.append((result, scores[i])) + + # Sort by ColBERT score + scored_results.sort(key=lambda x: x[1], reverse=True) + + # Return top k + top_k = self.top_k or len(scored_results) + reranked_results = [result for result, _ in scored_results[:top_k]] + + logger.info(f"Reranked {len(results)} results with ColBERT, returning top {len(reranked_results)}") + return reranked_results + + except Exception as e: + logger.warning(f"ColBERT reranking failed: {e}, returning original results") + return results + + +class MultiStageReranker(Reranker): + """ + Multi-stage reranker that combines multiple reranking models. + First stage: Fast lightweight reranker (e.g., BGE-small) + Second stage: High-quality reranker (e.g., Cohere, CrossEncoder-large) + """ + + def __init__(self, stage1_reranker: Reranker, stage2_reranker: Reranker, + stage1_k: int = 20, stage2_k: int = None): + self.stage1_reranker = stage1_reranker + self.stage2_reranker = stage2_reranker + self.stage1_k = stage1_k + self.stage2_k = stage2_k + + logger.info(f"Initialized MultiStageReranker: {stage1_reranker.component_name} -> {stage2_reranker.component_name}") + + @property + def component_name(self) -> str: + return f"multistage_{self.stage1_reranker.component_name}_{self.stage2_reranker.component_name}" + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Apply two-stage reranking.""" + if not results: + return results + + logger.info(f"Multi-stage reranking: Stage 1 with {self.stage1_reranker.component_name}") + + # Stage 1: Fast reranking to reduce candidate set + stage1_results = self.stage1_reranker.rerank(query, results, **kwargs) + stage1_results = stage1_results[:self.stage1_k] + + logger.info(f"Multi-stage reranking: Stage 2 with {self.stage2_reranker.component_name}") + + # Stage 2: High-quality reranking on reduced set + stage2_results = self.stage2_reranker.rerank(query, stage1_results, **kwargs) + + if self.stage2_k: + stage2_results = stage2_results[:self.stage2_k] + + # Update retrieval method to reflect multi-stage process + for result in stage2_results: + result.retrieval_method = f"{result.retrieval_method}+multistage" + result.metadata["multistage_reranking"] = True + + logger.info(f"Multi-stage reranking completed: {len(results)} -> {self.stage1_k} -> {len(stage2_results)}") + return stage2_results diff --git a/components/filters.py b/components/filters.py new file mode 100644 index 0000000..07f8c85 --- /dev/null +++ b/components/filters.py @@ -0,0 +1,313 @@ +""" +Filter and post-processor components for the modular retrieval pipeline. +""" + +from typing import List, Dict, Any, Set +import logging +from components.retrieval_pipeline import ResultFilter, PostProcessor, RetrievalResult + +logger = logging.getLogger(__name__) + + +class ScoreFilter(ResultFilter): + """Filter results based on minimum score threshold.""" + + def __init__(self, min_score: float = 0.5): + self.min_score = min_score + logger.info(f"Initialized ScoreFilter with min_score={min_score}") + + @property + def component_name(self) -> str: + return f"score_filter_{self.min_score}" + + def filter(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Filter results below minimum score.""" + min_score = kwargs.get('min_score', self.min_score) + + filtered = [r for r in results if r.score >= min_score] + + logger.info(f"ScoreFilter: {len(results)} -> {len(filtered)} results (min_score={min_score})") + return filtered + + +class MetadataFilter(ResultFilter): + """Filter results based on metadata criteria.""" + + def __init__(self, filter_criteria: Dict[str, Any]): + self.filter_criteria = filter_criteria + logger.info(f"Initialized MetadataFilter with criteria: {filter_criteria}") + + @property + def component_name(self) -> str: + return "metadata_filter" + + def filter(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Filter results based on metadata.""" + criteria = kwargs.get('filter_criteria', self.filter_criteria) + + filtered = [] + for result in results: + doc_metadata = result.document.metadata + labels = doc_metadata.get('labels', {}) + + # Check each criterion + passes_filter = True + for key, expected_value in criteria.items(): + # Check in labels first, then in main metadata + actual_value = labels.get(key) or doc_metadata.get(key) + + if isinstance(expected_value, (list, set)): + # Check if actual value is in the expected set + if actual_value not in expected_value: + passes_filter = False + break + elif isinstance(expected_value, dict): + # Handle complex criteria like {"score": {">=": 0.5}} + if not self._check_complex_criteria(actual_value, expected_value): + passes_filter = False + break + else: + # Exact match + if actual_value != expected_value: + passes_filter = False + break + + if passes_filter: + filtered.append(result) + + logger.info(f"MetadataFilter: {len(results)} -> {len(filtered)} results") + return filtered + + def _check_complex_criteria(self, value, criteria): + """Check complex criteria like {">=": 0.5}.""" + for op, threshold in criteria.items(): + if op == ">=": + return value >= threshold + elif op == ">": + return value > threshold + elif op == "<=": + return value <= threshold + elif op == "<": + return value < threshold + elif op == "==": + return value == threshold + elif op == "!=": + return value != threshold + elif op == "in": + return value in threshold + elif op == "not_in": + return value not in threshold + return True + + +class TagFilter(ResultFilter): + """Filter results based on tags.""" + + def __init__(self, required_tags: List[str] = None, excluded_tags: List[str] = None): + self.required_tags = set(required_tags or []) + self.excluded_tags = set(excluded_tags or []) + logger.info(f"Initialized TagFilter (required: {required_tags}, excluded: {excluded_tags})") + + @property + def component_name(self) -> str: + return "tag_filter" + + def filter(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Filter results based on tags.""" + required = set(kwargs.get('required_tags', self.required_tags)) + excluded = set(kwargs.get('excluded_tags', self.excluded_tags)) + + filtered = [] + for result in results: + labels = result.document.metadata.get('labels', {}) + tags = set(labels.get('tags', [])) + + # Check required tags + if required and not required.issubset(tags): + continue + + # Check excluded tags + if excluded and excluded.intersection(tags): + continue + + filtered.append(result) + + logger.info(f"TagFilter: {len(results)} -> {len(filtered)} results") + return filtered + + +class DuplicateFilter(ResultFilter): + """Remove duplicate results based on external_id or content.""" + + def __init__(self, dedup_by: str = "external_id"): + self.dedup_by = dedup_by # "external_id", "content", or "both" + logger.info(f"Initialized DuplicateFilter (dedup_by={dedup_by})") + + @property + def component_name(self) -> str: + return f"duplicate_filter_{self.dedup_by}" + + def filter(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Remove duplicates.""" + dedup_by = kwargs.get('dedup_by', self.dedup_by) + + seen = set() + filtered = [] + + for result in results: + # Generate deduplication key + if dedup_by == "external_id": + key = result.document.metadata.get('external_id') + elif dedup_by == "content": + key = result.document.page_content[:200] # First 200 chars + elif dedup_by == "both": + external_id = result.document.metadata.get('external_id', '') + content_hash = hash(result.document.page_content) + key = (external_id, content_hash) + else: + key = result.document.page_content[:200] + + if key not in seen: + seen.add(key) + filtered.append(result) + + logger.info(f"DuplicateFilter: {len(results)} -> {len(filtered)} results") + return filtered + + +class AnswerEnhancer(PostProcessor): + """Enhance answer results with better formatting and context.""" + + def __init__(self): + logger.info("Initialized AnswerEnhancer") + + @property + def component_name(self) -> str: + return "answer_enhancer" + + def post_process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Enhance answer formatting and metadata.""" + enhanced = [] + + for result in results: + labels = result.document.metadata.get('labels', {}) + + # Extract answer information + title = labels.get('title', 'N/A') + tags = labels.get('tags', []) + has_context = labels.get('has_question_context', False) + + # Enhance metadata + enhanced_metadata = { + **result.metadata, + "question_title": title, + "programming_tags": tags, + "has_question_context": has_context, + "answer_quality": self._assess_answer_quality(result, labels), + "enhanced": True + } + + # Create enhanced result + enhanced_result = RetrievalResult( + document=result.document, + score=result.score, + retrieval_method=f"{result.retrieval_method}+enhanced", + metadata=enhanced_metadata + ) + enhanced.append(enhanced_result) + + logger.info(f"AnswerEnhancer: Enhanced {len(results)} results") + return enhanced + + def _assess_answer_quality(self, result: RetrievalResult, labels: Dict) -> str: + """Assess the quality of an answer.""" + content_length = len(result.document.page_content) + has_code = any(marker in result.document.page_content.lower() + for marker in ['```', '', 'def ', 'function', 'class ']) + has_context = labels.get('has_question_context', False) + tags_count = len(labels.get('tags', [])) + + score = 0 + if content_length > 200: + score += 1 + if has_code: + score += 1 + if has_context: + score += 1 + if tags_count >= 2: + score += 1 + + if score >= 3: + return "high" + elif score >= 2: + return "medium" + else: + return "low" + + +class ContextEnricher(PostProcessor): + """Enrich results with additional context information.""" + + def __init__(self): + logger.info("Initialized ContextEnricher") + + @property + def component_name(self) -> str: + return "context_enricher" + + def post_process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Add contextual information to results.""" + enriched = [] + + for i, result in enumerate(results): + labels = result.document.metadata.get('labels', {}) + + # Add positional context + context_info = { + "rank_position": i + 1, + "query_used": query, + "result_type": "answer" if labels.get('post_type') == 'answer' else "unknown", + "source_platform": "stackoverflow" if "stackoverflow" in labels.get('source', '') else "unknown", + "enriched_at": self._get_timestamp() + } + + # Merge with existing metadata + enriched_metadata = {**result.metadata, **context_info} + + enriched_result = RetrievalResult( + document=result.document, + score=result.score, + retrieval_method=f"{result.retrieval_method}+enriched", + metadata=enriched_metadata + ) + enriched.append(enriched_result) + + logger.info(f"ContextEnricher: Enriched {len(results)} results") + return enriched + + def _get_timestamp(self) -> str: + """Get current timestamp.""" + from datetime import datetime + return datetime.now().isoformat() + + +class ResultLimiter(PostProcessor): + """Limit the number of final results.""" + + def __init__(self, max_results: int = 10): + self.max_results = max_results + logger.info(f"Initialized ResultLimiter (max_results={max_results})") + + @property + def component_name(self) -> str: + return f"result_limiter_{self.max_results}" + + def post_process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Limit the number of results.""" + max_results = kwargs.get('max_results', self.max_results) + limited = results[:max_results] + + if len(results) > max_results: + logger.info(f"ResultLimiter: Limited {len(results)} -> {len(limited)} results") + + return limited diff --git a/components/rerankers.py b/components/rerankers.py new file mode 100644 index 0000000..0937a5a --- /dev/null +++ b/components/rerankers.py @@ -0,0 +1,349 @@ +""" +Reranking components for the modular retrieval pipeline. +""" + +from typing import List, Dict, Any +import logging +from components.retrieval_pipeline import Reranker, RetrievalResult + +logger = logging.getLogger(__name__) + + +class CrossEncoderReranker(Reranker): + """ + Cross-encoder based reranker using sentence-transformers. + + Uses transformer models trained for passage ranking tasks. + Supports models like: + - ms-marco-MiniLM-L-12-v2 + - ms-marco-MiniLM-L-6-v2 + - cross-encoder/ms-marco-TinyBERT-L-2-v2 + """ + + def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", + device: str = "cpu", top_k: int = None): + """ + Initialize the cross-encoder reranker. + + Args: + model_name (str): HuggingFace model name for cross-encoder + device (str): Device to run inference on ('cpu' or 'cuda') + top_k (int, optional): Maximum number of results to return + """ + self.model_name = model_name + self.device = device + self.top_k = top_k + self._model = None + + logger.info( + f"Initialized CrossEncoderReranker with model: {model_name}") + + @property + def component_name(self) -> str: + """Return component name for identification.""" + return f"cross_encoder_reranker_{self.model_name.split('/')[-1]}" + + def _load_model(self): + """ + Lazy load the cross-encoder model. + + Raises: + ImportError: If sentence-transformers is not installed + """ + if self._model is None: + try: + from sentence_transformers import CrossEncoder + self._model = CrossEncoder(self.model_name, device=self.device) + logger.info(f"Loaded CrossEncoder model: {self.model_name}") + except ImportError: + raise ImportError( + "sentence-transformers is required for CrossEncoderReranker") + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """ + Rerank results using cross-encoder model. + + Args: + query (str): The search query + results (List[RetrievalResult]): Results to rerank + **kwargs: Additional parameters + + Returns: + List[RetrievalResult]: Reranked results sorted by cross-encoder scores + """ + if not results: + return results + + self._load_model() + + # Get top_k from kwargs or use instance default + top_k = kwargs.get('top_k', self.top_k or len(results)) + + # Prepare query-document pairs + query_doc_pairs = [] + for result in results: + # Use document content for reranking + doc_text = result.document.page_content + query_doc_pairs.append([query, doc_text]) + + # Score with cross-encoder + try: + scores = self._model.predict(query_doc_pairs) + + # Update results with new scores + reranked_results = [] + for i, result in enumerate(results): + new_result = RetrievalResult( + document=result.document, + score=float(scores[i]), + retrieval_method=f"{result.retrieval_method}+cross_encoder", + metadata={ + **result.metadata, + "original_score": result.score, + "reranker_model": self.model_name, + "reranked": True + } + ) + reranked_results.append(new_result) + + # Sort by new scores and take top_k + reranked_results.sort(key=lambda x: x.score, reverse=True) + final_results = reranked_results[:top_k] + + logger.info( + f"Reranked {len(results)} results, returning top {len(final_results)}") + return final_results + + except Exception as e: + logger.error(f"Error in cross-encoder reranking: {e}") + # Fallback to original results + return results[:top_k] + + +class SemanticReranker(Reranker): + """ + Semantic similarity reranker using embeddings and cosine similarity. + """ + + def __init__(self, embedder=None, top_k: int = None): + self.embedder = embedder + self.top_k = top_k + + logger.info("Initialized SemanticReranker") + + @property + def component_name(self) -> str: + return "semantic_reranker" + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank using semantic similarity.""" + if not results: + return results + + if not self.embedder: + logger.warning("No embedder provided, skipping semantic reranking") + return results + + top_k = kwargs.get('top_k', self.top_k or len(results)) + + try: + import numpy as np + from sklearn.metrics.pairwise import cosine_similarity + + # Get query embedding + query_embedding = self.embedder.embed_query(query) + + # Get document embeddings + doc_texts = [result.document.page_content for result in results] + doc_embeddings = self.embedder.embed_documents(doc_texts) + + # Calculate similarities + similarities = cosine_similarity( + [query_embedding], doc_embeddings)[0] + + # Update results with new scores + reranked_results = [] + for i, result in enumerate(results): + new_result = RetrievalResult( + document=result.document, + score=float(similarities[i]), + retrieval_method=f"{result.retrieval_method}+semantic", + metadata={ + **result.metadata, + "original_score": result.score, + "reranked": True + } + ) + reranked_results.append(new_result) + + # Sort and return top_k + reranked_results.sort(key=lambda x: x.score, reverse=True) + return reranked_results[:top_k] + + except Exception as e: + logger.error(f"Error in semantic reranking: {e}") + return results[:top_k] + + +class BM25Reranker(Reranker): + """ + BM25-based reranker for keyword matching. + """ + + def __init__(self, k1: float = 1.2, b: float = 0.75, top_k: int = None): + self.k1 = k1 + self.b = b + self.top_k = top_k + + logger.info(f"Initialized BM25Reranker (k1={k1}, b={b})") + + @property + def component_name(self) -> str: + return "bm25_reranker" + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank using BM25 scoring.""" + if not results: + return results + + try: + from rank_bm25 import BM25Okapi + import nltk + from nltk.tokenize import word_tokenize + + # Download required NLTK data if not present + try: + nltk.data.find('tokenizers/punkt') + except LookupError: + nltk.download('punkt') + + top_k = kwargs.get('top_k', self.top_k or len(results)) + + # Tokenize documents + doc_texts = [result.document.page_content for result in results] + tokenized_docs = [word_tokenize(doc.lower()) for doc in doc_texts] + + # Create BM25 object + bm25 = BM25Okapi(tokenized_docs) + + # Tokenize query and get scores + tokenized_query = word_tokenize(query.lower()) + scores = bm25.get_scores(tokenized_query) + + # Update results with BM25 scores + reranked_results = [] + for i, result in enumerate(results): + new_result = RetrievalResult( + document=result.document, + score=float(scores[i]), + retrieval_method=f"{result.retrieval_method}+bm25", + metadata={ + **result.metadata, + "original_score": result.score, + "bm25_params": {"k1": self.k1, "b": self.b}, + "reranked": True + } + ) + reranked_results.append(new_result) + + # Sort and return top_k + reranked_results.sort(key=lambda x: x.score, reverse=True) + return reranked_results[:top_k] + + except ImportError as e: + logger.error(f"Missing dependency for BM25 reranking: {e}") + return results[:top_k] + except Exception as e: + logger.error(f"Error in BM25 reranking: {e}") + return results[:top_k] + + +class EnsembleReranker(Reranker): + """ + Ensemble reranker that combines multiple reranking strategies. + """ + + def __init__(self, rerankers: List[Reranker], weights: List[float] = None, top_k: int = None): + self.rerankers = rerankers + self.weights = weights or [1.0] * len(rerankers) + self.top_k = top_k + + if len(self.weights) != len(rerankers): + raise ValueError( + "Number of weights must match number of rerankers") + + logger.info( + f"Initialized EnsembleReranker with {len(rerankers)} rerankers") + + @property + def component_name(self) -> str: + return f"ensemble_reranker_{len(self.rerankers)}_models" + + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Rerank using ensemble of multiple rerankers.""" + if not results: + return results + + top_k = kwargs.get('top_k', self.top_k or len(results)) + + # Get scores from each reranker + all_scores = [] + for reranker in self.rerankers: + try: + reranked = reranker.rerank(query, results, **kwargs) + # Extract scores in same order as input + scores = [] + for orig_result in results: + # Find corresponding result in reranked list + orig_id = id(orig_result.document) + for reranked_result in reranked: + if id(reranked_result.document) == orig_id: + scores.append(reranked_result.score) + break + else: + scores.append(0.0) # Not found + all_scores.append(scores) + except Exception as e: + logger.error(f"Error in {reranker.component_name}: {e}") + # Use zero scores for failed reranker + all_scores.append([0.0] * len(results)) + + # Normalize scores and compute weighted average + import numpy as np + + normalized_scores = [] + for scores in all_scores: + scores_array = np.array(scores) + if scores_array.max() > scores_array.min(): + # Min-max normalization + normalized = (scores_array - scores_array.min()) / \ + (scores_array.max() - scores_array.min()) + else: + normalized = scores_array + normalized_scores.append(normalized) + + # Weighted combination + ensemble_scores = np.zeros(len(results)) + for i, (scores, weight) in enumerate(zip(normalized_scores, self.weights)): + ensemble_scores += weight * scores + + # Create final results + final_results = [] + for i, result in enumerate(results): + new_result = RetrievalResult( + document=result.document, + score=float(ensemble_scores[i]), + retrieval_method=f"{result.retrieval_method}+ensemble", + metadata={ + **result.metadata, + "original_score": result.score, + "ensemble_components": [r.component_name for r in self.rerankers], + "ensemble_weights": self.weights, + "reranked": True + } + ) + final_results.append(new_result) + + # Sort and return top_k + final_results.sort(key=lambda x: x.score, reverse=True) + return final_results[:top_k] diff --git a/components/retrieval_pipeline.py b/components/retrieval_pipeline.py new file mode 100644 index 0000000..0e220aa --- /dev/null +++ b/components/retrieval_pipeline.py @@ -0,0 +1,648 @@ +""" +Modular and extensible retrieval pipeline for RAG systems. +Supports easy addition of components like rerankers, filters, and post-processors. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union, Tuple +from dataclasses import dataclass +from langchain_core.documents import Document +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class RetrievalResult: + """ + Enhanced result structure for retrieval pipeline. + + Attributes: + document (Document): The retrieved document + score (float): Relevance score + retrieval_method (str): Method used for retrieval + metadata (Dict[str, Any]): Additional metadata + """ + document: Document + score: float + retrieval_method: str + metadata: Dict[str, Any] = None + + def __post_init__(self): + if self.metadata is None: + self.metadata = {} + + +class RetrievalComponent(ABC): + """ + Base class for all retrieval pipeline components. + All pipeline components (retrievers, rerankers, filters) inherit from this. + """ + + @property + @abstractmethod + def component_name(self) -> str: + """ + Return the name of this component. + + Returns: + str: Component name for identification and logging + """ + pass + + @abstractmethod + def process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """ + Process the query and/or results. + + Args: + query (str): The search query + results (List[RetrievalResult]): Current results to process + **kwargs: Additional parameters + + Returns: + List[RetrievalResult]: Processed results + """ + pass + + +class BaseRetriever(RetrievalComponent): + """ + Base retriever that generates initial results. + All specific retrievers (dense, sparse, hybrid) inherit from this. + """ + + @abstractmethod + def retrieve(self, query: str, k: int = 5) -> List[RetrievalResult]: + """ + Retrieve initial results. + + Args: + query (str): Search query + k (int): Number of results to retrieve + + Returns: + List[RetrievalResult]: Initial retrieval results + """ + pass + + def process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """ + For retrievers, generate new results (ignore input results). + + Args: + query (str): Search query + results (List[RetrievalResult]): Ignored for initial retrieval + **kwargs: Additional parameters including 'k' for result count + + Returns: + List[RetrievalResult]: Fresh retrieval results + """ + k = kwargs.get('k', 5) + return self.retrieve(query, k) + + +class Reranker(RetrievalComponent): + """ + Base class for reranking components. + Rerankers take existing results and reorder them based on improved relevance scoring. + """ + + @abstractmethod + def rerank(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """ + Rerank the results. + + Args: + query (str): The search query + results (List[RetrievalResult]): Results to rerank + **kwargs: Additional reranking parameters + + Returns: + List[RetrievalResult]: Reranked results + """ + pass + + def process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Process by reranking.""" + return self.rerank(query, results, **kwargs) + + +class ResultFilter(RetrievalComponent): + """Base class for filtering components.""" + + @abstractmethod + def filter(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Filter the results.""" + pass + + def process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Process by filtering.""" + return self.filter(query, results, **kwargs) + + +class PostProcessor(RetrievalComponent): + """Base class for post-processing components.""" + + @abstractmethod + def post_process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Post-process the results.""" + pass + + def process(self, query: str, results: List[RetrievalResult], **kwargs) -> List[RetrievalResult]: + """Process by post-processing.""" + return self.post_process(query, results, **kwargs) + + +class RetrievalPipeline: + """ + Modular retrieval pipeline that chains components together. + + Example usage: + pipeline = RetrievalPipeline([ + QdrantHybridRetriever(config), + CrossEncoderReranker(model="ms-marco-MiniLM-L-12-v2"), + MetadataFilter(min_score=0.5), + AnswerContextEnhancer() + ]) + + results = pipeline.run(query="How to count bits?", k=10) + """ + + def __init__(self, components: List[RetrievalComponent], config: Dict[str, Any] = None): + self.components = components + self.config = config or {} + self._validate_pipeline() + + logger.info(f"Initialized retrieval pipeline with {len(components)} components: " + f"{[comp.component_name for comp in components]}") + + def _validate_pipeline(self): + """Validate that the pipeline has at least one retriever.""" + has_retriever = any(isinstance(comp, BaseRetriever) + for comp in self.components) + if not has_retriever: + raise ValueError( + "Pipeline must contain at least one BaseRetriever component") + + def run(self, query: str, **kwargs) -> List[RetrievalResult]: + """ + Run the full retrieval pipeline. + + Args: + query: The search query + **kwargs: Additional parameters passed to components + + Returns: + List of RetrievalResult objects + """ + logger.info(f"Running retrieval pipeline for query: '{query[:50]}...'") + + results = [] + + for i, component in enumerate(self.components): + component_name = component.component_name + logger.debug(f"Step {i+1}: Running {component_name}") + + try: + # Merge component-specific config with runtime kwargs + component_kwargs = kwargs.copy() + component_config = self.config.get(component_name, {}) + component_kwargs.update(component_config) + + # Process with component + results = component.process(query, results, **component_kwargs) + + logger.debug( + f"{component_name} returned {len(results)} results") + + except Exception as e: + logger.error(f"Error in {component_name}: {e}") + # Decide whether to continue or fail + if self.config.get('fail_on_component_error', False): + raise + # Continue with previous results + + logger.info(f"Pipeline completed with {len(results)} final results") + return results + + def add_component(self, component: RetrievalComponent, position: int = -1): + """Add a component to the pipeline at the specified position.""" + if position == -1: + self.components.append(component) + else: + self.components.insert(position, component) + + self._validate_pipeline() + logger.info( + f"Added {component.component_name} to pipeline at position {position}") + + def remove_component(self, component_name: str) -> bool: + """Remove a component by name.""" + for i, comp in enumerate(self.components): + if comp.component_name == component_name: + removed = self.components.pop(i) + logger.info(f"Removed {removed.component_name} from pipeline") + self._validate_pipeline() + return True + return False + + def get_component(self, component_name: str) -> Optional[RetrievalComponent]: + """Get a component by name.""" + for comp in self.components: + if comp.component_name == component_name: + return comp + return None + + def to_langchain_retriever(self): + """Create a LangChain-compatible retriever interface.""" + class LangChainWrapper: + def __init__(self, pipeline: RetrievalPipeline): + self.pipeline = pipeline + + def get_relevant_documents(self, query: str) -> List[Document]: + results = self.pipeline.run(query) + return [r.document for r in results] + + def retrieve(self, query: str, k: int = 5) -> List[Document]: + results = self.pipeline.run(query, k=k) + return [r.document for r in results] + + return LangChainWrapper(self) + + +class RetrievalPipelineFactory: + """Factory for creating common retrieval pipeline configurations.""" + + @staticmethod + def create_dense_pipeline(config: Dict[str, Any]) -> RetrievalPipeline: + """Create a dense-only retrieval pipeline.""" + from retrievers.dense_retriever import QdrantDenseRetriever + + # Create modern dense retriever + retriever = QdrantDenseRetriever(config) + + return RetrievalPipeline([retriever], config) + + @staticmethod + def create_hybrid_pipeline(config: Dict[str, Any]) -> RetrievalPipeline: + """Create a hybrid retrieval pipeline.""" + from retrievers.hybrid_retriever import QdrantHybridRetriever + + # Create modern hybrid retriever + retriever = QdrantHybridRetriever(config) + + return RetrievalPipeline([retriever], config) + + @staticmethod + def create_reranked_pipeline(config: Dict[str, Any], reranker_model: str = None) -> RetrievalPipeline: + """Create a pipeline with retrieval + reranking.""" + # Start with hybrid if available, otherwise dense + if config.get("embedding", {}).get("sparse"): + pipeline = RetrievalPipelineFactory.create_hybrid_pipeline(config) + else: + pipeline = RetrievalPipelineFactory.create_dense_pipeline(config) + + # Add reranker if specified + if reranker_model: + from components.rerankers import CrossEncoderReranker + reranker = CrossEncoderReranker(model_name=reranker_model) + pipeline.add_component(reranker) + + return pipeline + + @staticmethod + def create_from_config(config: Dict[str, Any]) -> 'RetrievalPipeline': + """ + Create a retrieval pipeline from configuration. + + Args: + config: Configuration dictionary with 'retrieval_pipeline' section + + Returns: + Configured RetrievalPipeline + + Example config: + retrieval_pipeline: + retriever: + type: dense # or hybrid + top_k: 10 + stages: + - type: score_filter + config: + min_score: 0.3 + - type: reranker + config: + model_type: cross_encoder + model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2" + top_k: 5 + """ + pipeline_config = config.get("retrieval_pipeline", {}) + + if not pipeline_config: + raise ValueError("No 'retrieval_pipeline' section found in config") + + # Create retriever + retriever_config = pipeline_config.get("retriever", {}) + retriever = RetrievalPipelineFactory._create_retriever( + retriever_config, config) + + # Initialize pipeline with retriever + pipeline = RetrievalPipeline([retriever], config) + + # Add stages + stages = pipeline_config.get("stages", []) + for stage_config in stages: + component = RetrievalPipelineFactory._create_stage_component( + stage_config, config) + if component: + pipeline.add_component(component) + + logger.info( + f"Created pipeline from config with {len(pipeline.components)} components") + return pipeline + + @staticmethod + def _create_retriever(retriever_config: Dict[str, Any], global_config: Dict[str, Any]) -> BaseRetriever: + """Create retriever from configuration.""" + retriever_type = retriever_config.get("type", "dense") + + if retriever_type == "dense": + from retrievers.dense_retriever import QdrantDenseRetriever + return QdrantDenseRetriever(global_config) + elif retriever_type == "sparse": + from retrievers.sparse_retriever import QdrantSparseRetriever + return QdrantSparseRetriever(global_config) + elif retriever_type == "hybrid": + from retrievers.hybrid_retriever import QdrantHybridRetriever + return QdrantHybridRetriever(global_config) + elif retriever_type == "semantic": + from retrievers.semantic_retriever import SemanticRetriever + return SemanticRetriever(global_config) + else: + raise ValueError(f"Unknown retriever type: {retriever_type}") + + @staticmethod + def _create_stage_component(stage_config: Dict[str, Any], global_config: Dict[str, Any]) -> Optional[RetrievalComponent]: + """Create a pipeline stage component from configuration.""" + stage_type = stage_config.get("type") + config = stage_config.get("config", {}) + + try: + if stage_type == "score_filter": + from components.filters import ScoreFilter + return ScoreFilter(min_score=config.get("min_score", 0.3)) + + elif stage_type == "duplicate_filter": + from components.filters import DuplicateFilter + return DuplicateFilter(dedup_by=config.get("dedup_by", "external_id")) + + elif stage_type == "tag_filter": + from components.filters import TagFilter + return TagFilter( + required_tags=config.get("required_tags"), + excluded_tags=config.get("excluded_tags") + ) + + elif stage_type == "answer_enhancer": + from components.filters import AnswerEnhancer + return AnswerEnhancer() + + elif stage_type == "result_limiter": + from components.filters import ResultLimiter + return ResultLimiter(max_results=config.get("max_results", 5)) + + elif stage_type == "reranker": + return RetrievalPipelineFactory._create_reranker(config) + + else: + logger.warning(f"Unknown stage type: {stage_type}") + return None + + except ImportError as e: + logger.warning(f"Could not create {stage_type}: {e}") + return None + except Exception as e: + logger.error(f"Error creating {stage_type}: {e}") + return None + + @staticmethod + def _create_reranker(config: Dict[str, Any]) -> Optional[Reranker]: + """Create reranker from configuration.""" + model_type = config.get("model_type") + + try: + if model_type == "cross_encoder": + from components.rerankers import CrossEncoderReranker + return CrossEncoderReranker( + model_name=config.get( + "model_name", "cross-encoder/ms-marco-MiniLM-L-6-v2"), + top_k=config.get("top_k") + ) + + elif model_type == "bge": + from components.advanced_rerankers import BgeReranker + return BgeReranker( + model_name=config.get( + "model_name", "BAAI/bge-reranker-base"), + top_k=config.get("top_k") + ) + + elif model_type == "multistage": + stage1_config = config.get("stage1", {}) + stage2_config = config.get("stage2", {}) + + stage1_reranker = RetrievalPipelineFactory._create_reranker( + stage1_config) + stage2_reranker = RetrievalPipelineFactory._create_reranker( + stage2_config) + + if stage1_reranker and stage2_reranker: + from components.advanced_rerankers import MultiStageReranker + return MultiStageReranker( + stage1_reranker=stage1_reranker, + stage2_reranker=stage2_reranker, + stage1_k=stage1_config.get("top_k", 10), + stage2_k=stage2_config.get("top_k", 5) + ) + + elif model_type == "ensemble": + rerankers = [] + weights = [] + + for reranker_config in config.get("rerankers", []): + reranker = RetrievalPipelineFactory._create_reranker( + reranker_config) + if reranker: + rerankers.append(reranker) + weights.append(reranker_config.get("weight", 1.0)) + + if rerankers: + from components.rerankers import EnsembleReranker + return EnsembleReranker( + rerankers=rerankers, + weights=weights + ) + + else: + logger.warning(f"Unknown reranker type: {model_type}") + return None + + except ImportError as e: + logger.warning(f"Could not create {model_type} reranker: {e}") + return None + except Exception as e: + logger.error(f"Error creating {model_type} reranker: {e}") + return None + + @staticmethod + def create_sparse_pipeline(config: Dict[str, Any]) -> RetrievalPipeline: + """Create a sparse-only retrieval pipeline.""" + from retrievers.sparse_retriever import QdrantSparseRetriever + + # Create modern sparse retriever + retriever = QdrantSparseRetriever(config) + + return RetrievalPipeline([retriever], config) + + @staticmethod + def create_semantic_pipeline(config: Dict[str, Any]) -> RetrievalPipeline: + """Create a semantic retrieval pipeline with intelligent routing.""" + from retrievers.semantic_retriever import SemanticRetriever + + # Create semantic retriever + retriever = SemanticRetriever(config) + + return RetrievalPipeline([retriever], config) + + @staticmethod + def create_from_retriever_config(retriever_type: str, global_config: Dict[str, Any] = None) -> 'RetrievalPipeline': + """ + Create a retrieval pipeline from a retriever configuration file. + + Args: + retriever_type: Type of retriever (dense, sparse, hybrid, semantic) + global_config: Optional global configuration to merge with + + Returns: + Configured RetrievalPipeline + """ + try: + from pipelines.configs.retriever_config_loader import load_retriever_config + + # Load retriever-specific configuration + retriever_config = load_retriever_config(retriever_type) + + # Merge with global config if provided + if global_config: + from pipelines.configs.retriever_config_loader import RetrieverConfigLoader + loader = RetrieverConfigLoader() + merged_config = loader.merge_with_global_config( + retriever_config, global_config) + else: + merged_config = retriever_config + + # Create retriever from the merged config + retriever = RetrievalPipelineFactory._create_retriever( + merged_config['retriever'], merged_config + ) + + # Create pipeline + pipeline = RetrievalPipeline([retriever], merged_config) + + logger.info( + f"Created {retriever_type} pipeline from configuration file") + return pipeline + + except Exception as e: + logger.error( + f"Failed to create pipeline from {retriever_type} config: {e}") + raise + + @staticmethod + def list_available_retrievers() -> List[str]: + """ + List all available retriever types from configuration files. + + Returns: + List of available retriever types + """ + try: + from pipelines.configs.retriever_config_loader import RetrieverConfigLoader + loader = RetrieverConfigLoader() + return loader.get_available_configs() + except Exception as e: + logger.warning(f"Could not load retriever configs: {e}") + return [] + + @staticmethod + def create_from_unified_config(config: Dict[str, Any], retriever_type: str = None) -> 'RetrievalPipeline': + """ + Create a retrieval pipeline from unified configuration structure. + + Args: + config: Complete configuration dictionary with retriever configs embedded + retriever_type: Type of retriever to use (if not specified, uses pipeline default) + + Returns: + Configured RetrievalPipeline + + Example usage: + config = load_config("config.yml") + pipeline = RetrievalPipelineFactory.create_from_unified_config(config, "hybrid") + """ + from config.config_loader import get_retriever_config, get_pipeline_config + + # Get pipeline configuration + pipeline_config = get_pipeline_config(config) + + # Determine retriever type + if retriever_type is None: + retriever_type = pipeline_config.get("default_retriever", "hybrid") + + # Get retriever-specific configuration + retriever_config = get_retriever_config(config, retriever_type) + + # Create retriever using unified config + retriever = RetrievalPipelineFactory._create_retriever_from_unified_config( + retriever_config, config) + + # Initialize pipeline with retriever + pipeline = RetrievalPipeline([retriever], config) + + # Add components from pipeline config + components = pipeline_config.get("components", []) + for component_config in components: + if component_config.get("type") == "retriever": + # Skip retriever component as it's already added + continue + + component = RetrievalPipelineFactory._create_stage_component( + component_config, config) + if component: + pipeline.add_component(component) + + logger.info( + f"Created {retriever_type} pipeline from unified config with {len(pipeline.components)} components") + return pipeline + + @staticmethod + def _create_retriever_from_unified_config(retriever_config: Dict[str, Any], + global_config: Dict[str, Any]) -> BaseRetriever: + """Create retriever from unified configuration structure.""" + retriever_type = retriever_config.get("type") + + if retriever_type == "dense": + from retrievers.dense_retriever import QdrantDenseRetriever + return QdrantDenseRetriever(retriever_config) + elif retriever_type == "sparse": + from retrievers.sparse_retriever import QdrantSparseRetriever + return QdrantSparseRetriever(retriever_config) + elif retriever_type == "hybrid": + from retrievers.hybrid_retriever import QdrantHybridRetriever + return QdrantHybridRetriever(retriever_config) + elif retriever_type == "semantic": + from retrievers.semantic_retriever import SemanticRetriever + return SemanticRetriever(retriever_config) + else: + raise ValueError(f"Unknown retriever type: {retriever_type}") diff --git a/config.yml b/config.yml new file mode 100644 index 0000000..09277e9 --- /dev/null +++ b/config.yml @@ -0,0 +1,141 @@ +agent_retrieval: + active_config: fast_hybrid + config_path: pipelines/configs/retrieval/fast_hybrid.yml +benchmark: + evaluation: + k_values: + - 1 + - 5 + - 10 + - 20 + metrics: + - precision + - recall + - f1 + - mrr + - ndcg + retrieval: + search_params: + score_threshold: 0.0 + strategy: hybrid + top_k: 20 +embedding: + dense: + api_key_env: GOOGLE_API_KEY + batch_size: 32 + dimensions: 768 + model: models/embedding-001 + provider: google + vector_name: dense + sparse: + model: Qdrant/bm25 + provider: sparse + vector_name: sparse + strategy: hybrid +llm: + model: gpt-4.1-mini + provider: openai + temperature: 0.0 +qdrant: + collection: sosum_stackoverflow_hybrid_v1 + dense_vector_name: dense + sparse_vector_name: sparse +retrieval_pipeline: + components: + - config: + retriever_type: hybrid + type: retriever + - config: + min_score: 0.01 + type: score_filter + - config: + model_name: cross-encoder/ms-marco-MiniLM-L-6-v2 + model_type: cross_encoder + top_k: 10 + type: reranker + default_retriever: hybrid +retrievers: + dense: + embedding: + api_key_env: GOOGLE_API_KEY + dimensions: 768 + model: models/embedding-001 + provider: google + performance: + batch_size: 32 + enable_caching: true + lazy_initialization: true + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + vector_name: dense + score_threshold: 0.0 + top_k: 10 + type: dense + hybrid: + dense_weight: 0.6 + embedding: + dense: + api_key_env: GOOGLE_API_KEY + dimensions: 768 + model: models/embedding-001 + provider: google + sparse: + dimensions: null + model: Qdrant/bm25 + provider: sparse + strategy: hybrid + fusion: + dense_weight: 0.7 + method: rrf + rrf_k: 60 + sparse_weight: 0.3 + fusion_method: rrf + performance: + batch_size: 32 + enable_caching: true + lazy_initialization: true + parallel_search: false + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + dense_vector_name: dense + hybrid_config: + alpha: 0.5 + reciprocal_rank_constant: 60 + sparse_vector_name: sparse + score_threshold: 0.0 + sparse_weight: 0.4 + top_k: 10 + type: hybrid + semantic: + embedding: + api_key_env: GOOGLE_API_KEY + dimensions: 768 + model: models/embedding-001 + provider: google + performance: + batch_size: 32 + enable_caching: true + lazy_initialization: true + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + vector_name: dense + score_threshold: 0.0 + semantic_config: + context_window: 3 + similarity_threshold: 0.8 + top_k: 10 + type: semantic + sparse: + embedding: + model: Qdrant/bm25 + provider: sparse + performance: + batch_size: 32 + enable_caching: true + lazy_initialization: true + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + vector_name: sparse + score_threshold: 0.0 + top_k: 10 + type: sparse diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..0125546 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1 @@ +from config.config_loader import load_config diff --git a/config/config_loader.py b/config/config_loader.py new file mode 100644 index 0000000..4afcebf --- /dev/null +++ b/config/config_loader.py @@ -0,0 +1,181 @@ +import yaml +import os +from pathlib import Path +from typing import Dict, Any, Optional +import logging + +logger = logging.getLogger(__name__) + + +def load_config(config_path: str = "config.yml") -> Dict[str, Any]: + """ + Load unified YAML configuration for the pipeline. + + Args: + config_path: Path to the main configuration file + + Returns: + Complete configuration dictionary + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If configuration is invalid + """ + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + logger.info(f"Loaded configuration from {config_path}") + return config + + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in {config_path}: {e}") + except Exception as e: + raise ValueError(f"Error loading config {config_path}: {e}") + + +def get_retriever_config(config: Dict[str, Any], retriever_type: str) -> Dict[str, Any]: + """ + Extract retriever-specific configuration from unified config. + + Args: + config: Main configuration dictionary + retriever_type: Type of retriever (dense, sparse, hybrid, semantic) + + Returns: + Retriever-specific configuration + + Raises: + ValueError: If retriever type not found + """ + retrievers_config = config.get("retrievers", {}) + + if retriever_type not in retrievers_config: + available_types = list(retrievers_config.keys()) + raise ValueError( + f"Retriever type '{retriever_type}' not found. Available: {available_types}") + + retriever_config = retrievers_config[retriever_type].copy() + + # Merge with global settings + if "embedding" in config and "embedding" not in retriever_config: + retriever_config["embedding"] = config["embedding"] + + if "qdrant" in config and "qdrant" not in retriever_config: + retriever_config["qdrant"] = config["qdrant"] + + return retriever_config + + +def get_benchmark_config(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Extract benchmark configuration with defaults. + + Args: + config: Main configuration dictionary + + Returns: + Benchmark configuration + """ + benchmark_config = config.get("benchmark", {}) + + # Set defaults + defaults = { + "evaluation": { + "k_values": [1, 5, 10, 20], + "metrics": ["precision", "recall", "f1", "mrr", "ndcg"] + }, + "retrieval": { + "strategy": "hybrid", + "top_k": 20, + "search_params": { + "score_threshold": 0.0 + } + } + } + + # Merge defaults with provided config + for key, default_value in defaults.items(): + if key not in benchmark_config: + benchmark_config[key] = default_value + elif isinstance(default_value, dict): + for sub_key, sub_default in default_value.items(): + if sub_key not in benchmark_config[key]: + benchmark_config[key][sub_key] = sub_default + + return benchmark_config + + +def get_pipeline_config(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Extract retrieval pipeline configuration. + + Args: + config: Main configuration dictionary + + Returns: + Pipeline configuration + """ + pipeline_config = config.get("retrieval_pipeline", {}) + + # Set defaults + if "default_retriever" not in pipeline_config: + pipeline_config["default_retriever"] = "hybrid" + + if "components" not in pipeline_config: + pipeline_config["components"] = [ + {"type": "retriever", "config": { + "retriever_type": pipeline_config["default_retriever"]}} + ] + + return pipeline_config + + +def load_config_with_overrides(config_path: str = "config.yml", + overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Load configuration with optional overrides. + + Args: + config_path: Path to the main configuration file + overrides: Optional dictionary of configuration overrides + + Returns: + Configuration with overrides applied + """ + config = load_config(config_path) + + if overrides: + # Deep merge overrides + config = _deep_merge(config, overrides) + logger.info("Applied configuration overrides") + + return config + + +def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + """ + Deep merge two dictionaries. + + Args: + base: Base dictionary + override: Override dictionary + + Returns: + Merged dictionary + """ + result = base.copy() + + for key, value in override.items(): + if (key in result and + isinstance(result[key], dict) and + isinstance(value, dict)): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + + return result diff --git a/database/__init__.py b/database/__init__.py index 0c9b570..e1451c5 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -1 +1,2 @@ from .qdrant_controller import QdrantVectorDB +from .postgres_controller import PostgresController diff --git a/database/postgres_controller.py b/database/postgres_controller.py new file mode 100644 index 0000000..c34de49 --- /dev/null +++ b/database/postgres_controller.py @@ -0,0 +1,156 @@ +import os +import uuid +import datetime +import logging +from dotenv import load_dotenv + +from sqlalchemy import create_engine, Column, String, Integer, Text, DateTime, text +from sqlalchemy.orm import declarative_base, sessionmaker, Session +from sqlalchemy.exc import OperationalError +from logs.utils.logger import get_logger +logger = get_logger("postgres_controller") + +# Load environment variables from .env +load_dotenv(override=True) + +# Define SQLAlchemy Base +Base = declarative_base() + + +class ImageAsset(Base): + """ + ORM model for storing image assets metadata in the 'image_assets' table. + """ + __tablename__ = 'image_assets' + + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + doc_id = Column(String, nullable=False) + page_number = Column(Integer, nullable=False) + file_path = Column(String, nullable=False) + caption = Column(Text) + extracted_text = Column(Text) + created_at = Column(DateTime, default=datetime.datetime.utcnow) + + +class TableAsset(Base): + """ + ORM model for storing table assets metadata in the 'table_assets' table. + """ + __tablename__ = 'table_assets' + + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + doc_id = Column(String, nullable=False) + page_number = Column(Integer, nullable=False) + table_json = Column(Text, nullable=False) + caption = Column(Text) + created_at = Column(DateTime, default=datetime.datetime.utcnow) + + +class PostgresController: + """ + Controller for managing PostgreSQL database connections and asset insertions. + Loads DB config from a config dict (or environment variables), handles session management, + and provides methods for inserting image/table asset metadata. + """ + engine = None + SessionLocal = None + + def __init__(self, db_config: dict = None): + """ + Initialize the database connection using config dict or environment variables. + Raises ValueError if configuration is missing, or ConnectionError if DB is unreachable. + """ + db_config = db_config or {} + + # Prefer config dict; fallback to env vars if not set + user = db_config.get('user') or os.getenv('POSTGRES_USER') + password = db_config.get('password') or os.getenv('POSTGRES_PASSWORD') + host = db_config.get('host') or os.getenv('POSTGRES_HOST') + port = db_config.get('port') or os.getenv('POSTGRES_PORT') + database = db_config.get('database') or os.getenv('POSTGRES_DB') + + if not all([user, password, host, port, database]): + logger.error( + "One or more required Postgres configuration variables are missing.") + raise ValueError( + "One or more required Postgres configuration variables are missing.") + + connection_str = f'postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}' + logger.info(f"Connecting to Postgres: {connection_str}") + + try: + self.engine = create_engine(connection_str) + Base.metadata.create_all(self.engine) + self.SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=self.engine) + logger.info( + "PostgreSQL engine and tables initialized successfully.") + except OperationalError as e: + logger.error(f"Failed to connect to the database: {e}") + + def get_session(self) -> Session: + """ + Create and return a new SQLAlchemy session. + Caller is responsible for closing or using with-statement. + """ + return self.SessionLocal() + + def insert_image_asset(self, doc_id: str, page_number: int, file_path: str, + caption: str = None, extracted_text: str = None): + """ + Insert a new image asset record into the database. + Args: + doc_id (str): Document identifier. + page_number (int): Page number where image appears. + file_path (str): File path of the stored image. + caption (str, optional): Caption text. + extracted_text (str, optional): Extracted OCR text from image. + """ + with self.get_session() as session: + asset = ImageAsset( + doc_id=doc_id, + page_number=page_number, + file_path=file_path, + caption=caption, + extracted_text=extracted_text, + ) + session.add(asset) + session.commit() + logger.info( + f"Inserted ImageAsset for doc_id={doc_id}, page_number={page_number}") + + def insert_table_asset(self, doc_id: str, page_number: int, table_json: str, + caption: str = None): + """ + Insert a new table asset record into the database. + Args: + doc_id (str): Document identifier. + page_number (int): Page number where table appears. + table_json (str): Serialized JSON for table contents. + caption (str, optional): Caption text. + """ + with self.get_session() as session: + asset = TableAsset( + doc_id=doc_id, + page_number=page_number, + table_json=table_json, + caption=caption, + ) + session.add(asset) + session.commit() + logger.info( + f"Inserted TableAsset for doc_id={doc_id}, page_number={page_number}") + + +if __name__ == "__main__": + """ + Main execution block for connectivity testing. + Attempts to connect to PostgreSQL and run a simple test query. + """ + try: + controller = PostgresController() + with controller.get_session() as session: + session.execute(text("SELECT 1")) + logger.info("PostgreSQL connection successful. Tables are ready.") + except Exception as e: + logger.error(f"Connection failed: {str(e)}") diff --git a/database/qdrant_controller.py b/database/qdrant_controller.py index 0afcf4b..95c7e1c 100644 --- a/database/qdrant_controller.py +++ b/database/qdrant_controller.py @@ -1,48 +1,85 @@ -from qdrant_client.http.models import VectorParams, SparseVectorParams -from qdrant_client import models as qmodels -from qdrant_client.http.models import Distance import os import uuid import logging -from typing import List, Optional -from dotenv import load_dotenv +from typing import List, Optional, Dict, Any +from dotenv import load_dotenv from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_qdrant import QdrantVectorStore, RetrievalMode - -from qdrant_client import QdrantClient +from qdrant_client import QdrantClient, models as qmodels from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams - +from logs.utils.logger import get_logger from .base import BaseVectorDB -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) + +logger = get_logger(__name__) class QdrantVectorDB(BaseVectorDB): - def __init__(self): - load_dotenv(override=True) - self.host: str = os.getenv("QDRANT_HOST") - self.port: int = int(os.getenv("QDRANT_PORT")) - self.api_key: Optional[str] = os.getenv("QDRANT_API_KEY") - self.collection_name: str = os.getenv("QDRANT_COLLECTION") - self.dense_vector_name: str = os.getenv("DENSE_VECTOR_NAME", "dense") - self.sparse_vector_name: str = os.getenv( - "SPARSE_VECTOR_NAME", "sparse") + def __init__(self, strategy: str = "dense", config: Optional[Dict[str, Any]] = None): + # Only load .env if it exists, don't fail if it doesn't + try: + load_dotenv(override=True) + except Exception: + logger.debug( + "No .env file found, using environment variables and defaults") + + self.strategy = strategy.lower() + + # Use config if provided, otherwise fall back to environment variables with defaults + if config and "qdrant" in config: + qdrant_config = config["qdrant"] + self.host = qdrant_config.get( + "host", os.getenv("QDRANT_HOST", "localhost")) + self.port = int(qdrant_config.get( + "port", os.getenv("QDRANT_PORT", "6333"))) + self.api_key = qdrant_config.get( + "api_key", os.getenv("QDRANT_API_KEY")) + self.collection_name = qdrant_config.get("collection", qdrant_config.get( + "collection_name", os.getenv("QDRANT_COLLECTION", "default_collection"))) + self.dense_vector_name = qdrant_config.get( + "dense_vector_name", os.getenv("DENSE_VECTOR_NAME", "dense")) + self.sparse_vector_name = qdrant_config.get( + "sparse_vector_name", os.getenv("SPARSE_VECTOR_NAME", "sparse")) + else: + # Fall back to environment variables with sensible defaults + self.host = os.getenv("QDRANT_HOST", "localhost") + self.port = int(os.getenv("QDRANT_PORT", "6333")) + # Can be None for local instances + self.api_key = os.getenv("QDRANT_API_KEY") + self.collection_name = os.getenv( + "QDRANT_COLLECTION", "default_collection") + self.dense_vector_name = os.getenv("DENSE_VECTOR_NAME", "dense") + self.sparse_vector_name = os.getenv("SPARSE_VECTOR_NAME", "sparse") logger.info(f"Qdrant collection: {self.collection_name}") logger.info(f"Dense vector: {self.dense_vector_name}") logger.info(f"Sparse vector: {self.sparse_vector_name}") - - self.client: QdrantClient = QdrantClient( - host=self.host, - port=self.port, - api_key=self.api_key or None, - ) + logger.info(f"Connecting to Qdrant at {self.host}:{self.port}") + + # Validate required configuration + if not self.host: + raise ValueError("QDRANT_HOST is required but not provided") + if not self.collection_name: + raise ValueError("QDRANT_COLLECTION is required but not provided") + + try: + self.client = QdrantClient( + host=self.host, + port=self.port, + api_key=self.api_key or None, + ) + logger.info("Successfully connected to Qdrant") + except Exception as e: + logger.error(f"Failed to connect to Qdrant: {e}") + raise def init_collection(self, dense_vector_size: int) -> None: """ - Create (or recreate) the collection for dense and sparse vectors. + Initialize (or re-create) a Qdrant collection for dense and sparse vectors. + Deletes existing collection if already present. + Args: + dense_vector_size (int): The dimensionality of the dense vector. """ if self.client.collection_exists(self.collection_name): logger.info( @@ -50,18 +87,15 @@ def init_collection(self, dense_vector_size: int) -> None: ) self.client.delete_collection(self.collection_name) - # Create with separate configs for dense & sparse self.client.create_collection( collection_name=self.collection_name, vectors_config={ - # only your dense side here self.dense_vector_name: VectorParams( size=dense_vector_size, distance=Distance.COSINE, ) }, sparse_vectors_config={ - # only your sparse side here, using SparseVectorParams self.sparse_vector_name: SparseVectorParams( index=qmodels.SparseIndexParams(on_disk=False) ) @@ -76,19 +110,13 @@ def init_collection(self, dense_vector_size: int) -> None: def get_client(self) -> QdrantClient: """ - Get the Qdrant client instance. - - Returns: - QdrantClient: The initialized Qdrant client. + Return the initialized Qdrant client instance. """ return self.client def get_collection_name(self) -> str: """ - Get the name of the current Qdrant collection. - - Returns: - str: The collection name. + Return the name of the current Qdrant collection. """ return self.collection_name @@ -98,13 +126,50 @@ def insert_documents( dense_embedder: Optional[Embeddings] = None, sparse_embedder: Optional[Embeddings] = None, ) -> None: + """ + Insert a list of LangChain Documents into the configured Qdrant collection, + initializing the collection if needed (using dense_embedder for dimension). + Args: + documents (List[Document]): The documents to insert. + dense_embedder (Optional[Embeddings]): Embedder for dense vectors. + sparse_embedder (Optional[Embeddings]): Embedder for sparse vectors. + """ + # Initialize collection only if needed and if dense_embedder is provided + if not self.client.collection_exists(self.collection_name) and dense_embedder: + sample_embedding = dense_embedder.embed_query("test") + dense_dim = len(sample_embedding) + self.init_collection(dense_vector_size=dense_dim) + vectorstore = self.as_langchain_vectorstore( dense_embedding=dense_embedder, sparse_embedding=sparse_embedder, ) - ids = [str(uuid.uuid4()) for _ in documents] - vectorstore.add_documents(documents=documents, ids=ids) + # Use external_id from metadata if available, otherwise generate UUID + ids = [] + processed_documents = [] + for doc in documents: + external_id = doc.metadata.get("external_id") + if external_id: + ids.append(str(external_id)) + # Ensure external_id is preserved in the document metadata + # Create a copy of the document with external_id explicitly in metadata + doc_copy = Document( + page_content=doc.page_content, + metadata={**doc.metadata, "external_id": str(external_id)} + ) + processed_documents.append(doc_copy) + else: + generated_id = str(uuid.uuid4()) + ids.append(generated_id) + # Add the generated ID to metadata as well + doc_copy = Document( + page_content=doc.page_content, + metadata={**doc.metadata, "external_id": generated_id} + ) + processed_documents.append(doc_copy) + + vectorstore.add_documents(documents=processed_documents, ids=ids) logger.info( f"Inserted {len(documents)} documents into '{self.collection_name}' " @@ -115,8 +180,12 @@ def as_langchain_vectorstore( self, dense_embedding: Optional[Embeddings] = None, sparse_embedding: Optional[Embeddings] = None, + strategy: Optional[str] = None ) -> QdrantVectorStore: - strategy = os.getenv("EMBEDDING_STRATEGY", "dense").lower() + """ + Returns a LangChain-compatible QdrantVectorStore based on the selected retrieval strategy. + """ + strategy = (strategy or self.strategy or "dense").lower() if strategy == "dense": return QdrantVectorStore( @@ -128,7 +197,6 @@ def as_langchain_vectorstore( sparse_vector_name=self.sparse_vector_name, retrieval_mode=RetrievalMode.DENSE, ) - elif strategy == "sparse": return QdrantVectorStore( client=self.client, @@ -137,7 +205,6 @@ def as_langchain_vectorstore( sparse_vector_name=self.sparse_vector_name, retrieval_mode=RetrievalMode.SPARSE, ) - elif strategy == "hybrid": return QdrantVectorStore( client=self.client, @@ -148,6 +215,6 @@ def as_langchain_vectorstore( sparse_vector_name=self.sparse_vector_name, retrieval_mode=RetrievalMode.HYBRID, ) - else: + logger.error(f"Invalid EMBEDDING_STRATEGY: {strategy}") raise ValueError(f"Invalid EMBEDDING_STRATEGY: {strategy}") diff --git a/datasets/sosum b/datasets/sosum new file mode 160000 index 0000000..4587714 --- /dev/null +++ b/datasets/sosum @@ -0,0 +1 @@ +Subproject commit 4587714b1efabc75725751ee5f1fde64c6480734 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ab4c6dd --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,46 @@ +version: '3.8' + +services: + qdrant: + image: qdrant/qdrant + ports: + - "6333:6333" + volumes: + - qdrant_data:/qdrant/storage + + postgres: + image: postgres:14 + environment: + POSTGRES_USER: admin + POSTGRES_PASSWORD: admin + POSTGRES_DB: tableDB + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + + app: + build: + context: . + dockerfile: Dockerfile + environment: + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + - POSTGRES_HOST=postgres + - POSTGRES_PORT=5432 + - POSTGRES_USER=admin + - POSTGRES_PASSWORD=admin + - POSTGRES_DB=tableDB + volumes: + - .:/app + depends_on: + - qdrant + - postgres + ports: + - "8000:8000" # optional, if you add an API later + working_dir: /app + command: ["tail", "-f", "/dev/null"] # Replace with your real entrypoint + +volumes: + qdrant_data: + postgres_data: diff --git a/docs/MLOPS_PIPELINE_ARCHITECTURE.md b/docs/MLOPS_PIPELINE_ARCHITECTURE.md new file mode 100644 index 0000000..f75b4ff --- /dev/null +++ b/docs/MLOPS_PIPELINE_ARCHITECTURE.md @@ -0,0 +1,756 @@ +# MLOps Pipeline Architecture for RAG Systems + +## Table of Contents +1. [Introduction to MLOps for RAG](#introduction) +2. [Overall Architecture](#overall-architecture) +3. [Core Pipeline Components](#core-components) +4. [Data Flow and Processing](#data-flow) +5. [MLOps Principles Implementation](#mlops-principles) +6. [Configuration Management](#configuration) +7. [Reproducibility and Versioning](#reproducibility) +8. [Monitoring and Observability](#monitoring) +9. [Advantages and Trade-offs](#advantages-tradeoffs) +10. [How to Reproduce in Other Projects](#reproduction-guide) + +## 1. Introduction to MLOps for RAG {#introduction} + +### What is MLOps? +MLOps (Machine Learning Operations) is a set of practices that combines Machine Learning, DevOps, and Data Engineering to deploy and maintain ML systems in production reliably and efficiently. + +### Why MLOps for RAG Systems? +Retrieval-Augmented Generation (RAG) systems have unique challenges: +- **Data Pipeline Complexity**: Multiple data sources, formats, and processing steps +- **Model Dependencies**: Embedding models, chunking strategies, retrieval algorithms +- **Evaluation Complexity**: Measuring retrieval quality and generation performance +- **Version Management**: Dataset versions, model versions, configuration versions +- **Experimentation**: Comparing different embedding models and retrieval strategies + +Our pipeline addresses these challenges through systematic MLOps practices. + +## 2. Overall Architecture {#overall-architecture} + +``` +┌────────────────────────────────────────────────────────────────────┐ +│ RAG MLOps Pipeline │ +├────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Raw Data │───>│ Adapters │───>│ Validation │ │ +│ │ (Multiple │ │ (Dataset │ │ & Quality │ │ +│ │ Sources) │ │ Specific) │ │ Checks │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Vector Store│◀──│ Embedder │◀─ | Chunker │ │ +│ │ (Qdrant) │ │ (Multiple │ │ (Strategy │ │ +│ │ │ │ Strategies) │ │ Based) │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Evaluation │◀──│ Smoke Tests │◀─ │ Lineage │ │ +│ │ Framework │ │ & Quality │ │ Tracking │ │ +│ │ │ │ Assurance │ │ │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ │ +└────────────────────────────────────────────────────────────────────| +``` + +### Key Design Principles: +1. **Modularity**: Each component has a single responsibility +2. **Extensibility**: Easy to add new datasets, embedders, or evaluation metrics +3. **Reproducibility**: Deterministic IDs, versioning, and configuration management +4. **Observability**: Comprehensive logging, metrics, and lineage tracking +5. **Safety**: Dry-run modes, canary deployments, and validation checks + +## 3. Core Pipeline Components {#core-components} + +### 3.1 Dataset Adapters (`pipelines/adapters/`) + +**Purpose**: Convert raw dataset formats into standardized LangChain Documents + +**Architecture**: +```python +class DatasetAdapter(ABC): + @abstractmethod + def read_rows(self, split: DatasetSplit) -> Iterable[BaseRow]: + """Read raw dataset rows""" + + @abstractmethod + def to_documents(self, rows: List[BaseRow], split: DatasetSplit) -> List[Document]: + """Convert to standardized documents""" + + @abstractmethod + def get_evaluation_queries(self, split: DatasetSplit) -> List[Dict[str, Any]]: + """Provide evaluation queries""" +``` + +**Advantages**: +- ✅ **Dataset Agnostic**: Same pipeline works for any dataset +- ✅ **Type Safety**: Pydantic schemas ensure data validity +- ✅ **Extensibility**: Easy to add new datasets +- ✅ **Evaluation Integration**: Built-in evaluation query generation + +**Trade-offs**: +- ❌ **Initial Overhead**: Requires implementing adapter for each dataset +- ❌ **Memory Usage**: Loads entire dataset into memory (can be optimized with streaming) + +**Example Implementation** (StackOverflow): +```python +class StackOverflowAdapter(DatasetAdapter): + def __init__(self, data_path: str): + self.data_path = Path(data_path) + # Load questions and answers from CSV files + + def to_documents(self, rows: List[StackOverflowRow], split: DatasetSplit) -> List[Document]: + documents = [] + for row in rows: + # Create question document + doc = Document( + page_content=f"Question: {row.title}\n\n{row.body}", + metadata={ + "source": self.source_name, + "external_id": f"q_{row.question_id}", + "split": split.value, + "type": "question", + "tags": row.tags + } + ) + documents.append(doc) + return documents +``` + +### 3.2 Validation System (`pipelines/ingest/validator.py`) + +**Purpose**: Ensure data quality before processing + +**Components**: +- **Character Validation**: Check for problematic characters +- **Content Validation**: Ensure minimum content requirements +- **Metadata Validation**: Verify required fields exist + +**Advantages**: +- ✅ **Early Error Detection**: Catch issues before expensive embedding generation +- ✅ **Data Quality Assurance**: Consistent quality across datasets +- ✅ **Configurable Rules**: Different validation rules per dataset type + +**Trade-offs**: +- ❌ **Processing Overhead**: Additional validation step +- ❌ **False Positives**: May flag valid content (e.g., HTML in code examples) + +**Example Validation Rules**: +```python +def validate_document(self, doc: Document) -> ValidationResult: + errors = [] + + # Check minimum length + if len(doc.page_content.strip()) < self.min_length: + errors.append(f"Content too short: {len(doc.page_content)} < {self.min_length}") + + # Check for required metadata + if not doc.metadata.get("external_id"): + errors.append("Missing external_id in metadata") + + return ValidationResult( + valid=len(errors) == 0, + doc_id=doc.metadata.get("external_id", "unknown"), + errors=errors + ) +``` + +### 3.3 Chunking System (`pipelines/ingest/chunker.py`) + +**Purpose**: Split documents into optimal chunks for embedding and retrieval + +**Strategies**: +- **Recursive Character Splitting**: Split by paragraphs, then sentences, then characters +- **Token-based Splitting**: Split based on token count for transformer models +- **Semantic Splitting**: Future enhancement for content-aware splitting + +**Advantages**: +- ✅ **Strategy Flexibility**: Multiple chunking approaches +- ✅ **Deterministic Results**: Same input always produces same chunks +- ✅ **Metadata Preservation**: Chunk metadata tracks source document + +**Trade-offs**: +- ❌ **Context Loss**: Splitting may break semantic coherence +- ❌ **Parameter Sensitivity**: Chunk size affects retrieval quality + +**Configuration Example**: +```yaml +chunking: + strategy: "recursive_character" + chunk_size: 1000 + chunk_overlap: 200 + separators: ["\n\n", "\n", " ", ""] +``` + +### 3.4 Embedding System (`pipelines/ingest/embedder.py`) + +**Purpose**: Generate dense and sparse embeddings for semantic search + +**Strategies**: +- **Dense Embeddings**: Sentence transformers (e.g., MiniLM, BGE, E5) +- **Sparse Embeddings**: TF-IDF, BM25 (future: SPLADE) +- **Hybrid Embeddings**: Combination of dense and sparse + +**Advantages**: +- ✅ **Multiple Strategies**: Compare different embedding approaches +- ✅ **Caching**: Avoid recomputing embeddings +- ✅ **Batch Processing**: Efficient GPU utilization +- ✅ **Error Handling**: Graceful fallbacks for failed embeddings + +**Trade-offs**: +- ❌ **Computational Cost**: Embedding generation is expensive +- ❌ **Model Dependencies**: Different models require different environments +- ❌ **Storage Requirements**: Embeddings consume significant space + +**Example Configuration**: +```yaml +embedding: + strategy: "hybrid" # dense, sparse, or hybrid + dense: + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + sparse: + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + batch_size: 32 + cache_enabled: true +``` + +### 3.5 Vector Store Integration (`pipelines/ingest/uploader.py`) + +**Purpose**: Upload embeddings to vector database with proper indexing + +**Features**: +- **Collection Management**: Automatic collection creation and configuration +- **Batch Uploads**: Efficient bulk operations +- **Canary Deployments**: Safe testing with temporary collections +- **Metadata Storage**: Rich metadata for filtering and retrieval + +**Advantages**: +- ✅ **Scalability**: Handles large datasets efficiently +- ✅ **Safety**: Canary mode prevents affecting production +- ✅ **Flexibility**: Multiple vector stores supported (Qdrant primary) + +**Trade-offs**: +- ❌ **Infrastructure Dependency**: Requires vector database setup +- ❌ **Network Overhead**: Upload time depends on network and data size + +### 3.6 Lineage Tracking (`pipelines/contracts.py`) + +**Purpose**: Track complete provenance of processed data + +**Information Tracked**: +- **Data Provenance**: Source dataset, version, and split +- **Processing History**: Chunking strategy, embedding model, configuration +- **Code Provenance**: Git commit hash, configuration hash +- **Quality Metrics**: Success/failure counts, validation results + +**Advantages**: +- ✅ **Reproducibility**: Complete history for debugging and reproduction +- ✅ **Compliance**: Audit trail for data governance +- ✅ **Debugging**: Easy to trace issues to specific configurations + +**Trade-offs**: +- ❌ **Storage Overhead**: Additional metadata storage +- ❌ **Complexity**: More fields to maintain and track + +## 4. Data Flow and Processing {#data-flow} + +### Step-by-Step Processing Flow: + +``` +1. Configuration Loading + ├── Load YAML config file + ├── Validate configuration schema + └── Initialize components with config + +2. Data Ingestion + ├── Adapter reads raw data files + ├── Convert to standardized BaseRow objects + └── Generate LangChain Documents + +3. Validation + ├── Check document content quality + ├── Validate required metadata fields + └── Filter out invalid documents + +4. Chunking + ├── Split documents using configured strategy + ├── Generate deterministic chunk IDs + └── Preserve metadata and provenance + +5. Embedding Generation + ├── Process chunks in batches + ├── Generate dense/sparse embeddings + └── Cache results for efficiency + +6. Vector Store Upload + ├── Create/configure collection + ├── Upload chunks with embeddings + └── Verify upload success + +7. Quality Assurance + ├── Run smoke tests + ├── Validate retrieval functionality + └── Generate quality reports + +8. Lineage Recording + ├── Save complete processing history + ├── Record configuration and results + └── Enable reproduction and debugging +``` + +### ID Generation Strategy: + +```python +# Deterministic Document ID +doc_hash = sha256(normalized_content).hexdigest()[:12] +doc_id = f"{source}:{external_id}:{doc_hash}" + +# Deterministic Chunk ID +chunk_id = f"{doc_id}#c{chunk_index:04d}" +``` + +**Benefits**: +- ✅ **Idempotency**: Rerunning pipeline produces same IDs +- ✅ **Deduplication**: Same content gets same ID across runs +- ✅ **Traceability**: Easy to trace chunks back to source documents + +## 5. MLOps Principles Implementation {#mlops-principles} + +### 5.1 Reproducibility + +**Implementation**: +- **Deterministic IDs**: Content-based hashing ensures same results +- **Configuration Versioning**: YAML configs tracked in git +- **Environment Specification**: Requirements.txt pins exact versions +- **Data Versioning**: Dataset versions tracked in metadata + +**Example**: +```yaml +# Configuration is versioned and tracked +dataset: + name: "stackoverflow" + version: "1.0.0" + path: "/data/sosum" + +embedding: + model: "sentence-transformers/all-MiniLM-L6-v2" + # Exact model version ensures reproducibility +``` + +### 5.2 Experimentation + +**A/B Testing Support**: +```yaml +# Different configs for comparing embedding models +collection_name: "sosum_stackoverflow_minilm_v1" # MiniLM experiment +collection_name: "sosum_stackoverflow_bge_large_v1" # BGE Large experiment +collection_name: "sosum_stackoverflow_e5_large_v1" # E5 Large experiment +``` + +**Benefits**: +- ✅ **Safe Comparison**: Each experiment uses separate collection +- ✅ **Parallel Testing**: Multiple configurations can run simultaneously +- ✅ **Easy Rollback**: Keep previous versions available + +### 5.3 Monitoring and Observability + +**Logging Strategy**: +```python +logger.info(f"Processing {len(documents)} documents with {strategy} strategy") +logger.warning(f"Validation errors found: {validation_errors}") +logger.error(f"Embedding generation failed: {error}") +``` + +**Metrics Tracked**: +- Processing times per component +- Success/failure rates +- Data quality metrics +- Embedding generation statistics + +### 5.4 Quality Assurance + +**Multi-layer Validation**: +1. **Input Validation**: Check raw data quality +2. **Processing Validation**: Verify each transformation step +3. **Output Validation**: Smoke tests on final results +4. **End-to-end Testing**: Retrieval quality evaluation + +## 6. Configuration Management {#configuration} + +### Configuration Schema: +```yaml +# Dataset Configuration +dataset: + name: "stackoverflow" + version: "1.0.0" + adapter: "stackoverflow" + path: "/path/to/data" + +# Processing Configuration +chunking: + strategy: "recursive_character" + chunk_size: 1000 + chunk_overlap: 200 + +embedding: + strategy: "dense" # dense, sparse, hybrid + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + batch_size: 32 + +# Infrastructure Configuration +vector_store: + provider: "qdrant" + collection_name: "sosum_stackoverflow_minilm_v1" + distance_metric: "cosine" + +# Experiment Configuration +experiment: + name: "minilm_baseline" + description: "Baseline with MiniLM embeddings" + canary: false + max_documents: null # null = no limit +``` + +### Configuration Benefits: +- ✅ **Declarative**: Infrastructure as code approach +- ✅ **Version Controlled**: Track configuration changes +- ✅ **Environment Specific**: Different configs for dev/staging/prod +- ✅ **Validation**: Schema validation prevents configuration errors + +## 7. Reproducibility and Versioning {#reproducibility} + +### Version Management Strategy: + +```python +class ChunkMeta(BaseModel): + # Identity and Content + doc_id: str # Deterministic based on content + chunk_id: str # Deterministic based on doc_id + index + doc_sha256: str # Content hash for integrity + + # Provenance Tracking + source: str # Dataset name + dataset_version: str # Dataset version + git_commit: str # Code version when processed + config_hash: str # Configuration hash + + # Processing Metadata + embedding_model: str # Exact model used + chunk_strategy: dict # Chunking parameters used +``` + +### Reproduction Steps: +1. **Checkout Code**: Use git commit from lineage record +2. **Load Configuration**: Use exact config from lineage record +3. **Install Dependencies**: Use requirements.txt from that commit +4. **Run Pipeline**: Should produce identical results + +### Benefits: +- ✅ **Full Traceability**: Know exactly how any chunk was created +- ✅ **Bug Investigation**: Reproduce issues from production +- ✅ **Compliance**: Meet audit requirements for data processing + +## 8. Monitoring and Observability {#monitoring} + +### Observability Stack: + +```python +# Structured Logging +logger.info( + "Embedding generation completed", + extra={ + "component": "embedder", + "strategy": "dense", + "model": "all-MiniLM-L6-v2", + "batch_size": 32, + "chunks_processed": 1500, + "processing_time": 45.2, + "success_rate": 0.998 + } +) +``` + +### Key Metrics: +- **Throughput**: Documents/chunks processed per minute +- **Quality**: Validation success rates, embedding generation success +- **Performance**: Processing time per component +- **Resource Usage**: Memory, CPU, GPU utilization +- **Error Rates**: Failed documents, failed embeddings + +### Alerts and Monitoring: +- High failure rates in validation or embedding +- Processing time exceeding thresholds +- Resource utilization issues +- Data quality degradation + +## 9. Advantages and Trade-offs {#advantages-tradeoffs} + +### Overall Architecture Advantages: + +✅ **Modularity**: +- Easy to replace components (e.g., switch from Qdrant to Pinecone) +- Test components in isolation +- Parallel development by different team members + +✅ **Reproducibility**: +- Deterministic results across runs +- Complete provenance tracking +- Easy debugging and issue reproduction + +✅ **Scalability**: +- Batch processing for efficiency +- Horizontal scaling of individual components +- Streaming support for large datasets (future enhancement) + +✅ **Experimentation**: +- A/B testing different configurations +- Safe canary deployments +- Easy comparison of approaches + +✅ **Quality Assurance**: +- Multi-layer validation +- Automated testing and smoke tests +- Continuous monitoring + +### Trade-offs and Limitations: + +❌ **Complexity**: +- More complex than simple scripts +- Requires understanding of MLOps concepts +- More code to maintain + +❌ **Initial Setup Cost**: +- Significant upfront investment +- Infrastructure dependencies (vector database, etc.) +- Learning curve for team members + +❌ **Resource Requirements**: +- Embedding generation requires computational resources +- Vector storage requires significant disk space +- Caching increases memory usage + +❌ **Vendor Dependencies**: +- Qdrant for vector storage +- HuggingFace for embedding models +- Specific Python version and libraries + +### When to Use This Architecture: + +**Good Fit**: +- Multiple datasets to process +- Need for experimentation and comparison +- Production RAG systems requiring reliability +- Teams needing reproducibility and compliance +- Long-term projects requiring maintenance + +**Not Ideal For**: +- One-off experiments or prototypes +- Very small datasets (< 1000 documents) +- Teams without MLOps experience +- Projects with tight deadlines +- Limited computational resources + +## 10. How to Reproduce in Other Projects {#reproduction-guide} + +### 10.1 Project Structure Setup + +``` +your_project/ +├── pipelines/ +│ ├── contracts.py # Base schemas and interfaces +│ ├── adapters/ # Dataset-specific adapters +│ │ ├── your_dataset.py +│ │ └── another_dataset.py +│ ├── ingest/ # Core processing components +│ │ ├── validator.py +│ │ ├── chunker.py +│ │ ├── embedder.py +│ │ ├── uploader.py +│ │ └── pipeline.py +│ ├── configs/ # Configuration files +│ │ ├── baseline.yml +│ │ └── experiment.yml +│ └── eval/ # Evaluation framework +│ └── evaluator.py +├── bin/ +│ └── ingest.py # CLI interface +├── docs/ +│ └── architecture.md +└── requirements.txt +``` + +### 10.2 Implementation Steps + +#### Step 1: Define Core Contracts +```python +# pipelines/contracts.py +from abc import ABC, abstractmethod +from pydantic import BaseModel +from enum import Enum + +class DatasetSplit(str, Enum): + TRAIN = "train" + TEST = "test" + ALL = "all" + +class BaseRow(BaseModel): + external_id: str + class Config: + extra = "allow" + +class DatasetAdapter(ABC): + @abstractmethod + def read_rows(self, split: DatasetSplit) -> Iterable[BaseRow]: + pass + + @abstractmethod + def to_documents(self, rows: List[BaseRow], split: DatasetSplit) -> List[Document]: + pass +``` + +#### Step 2: Implement Dataset Adapter +```python +# pipelines/adapters/your_dataset.py +class YourDatasetAdapter(DatasetAdapter): + def __init__(self, data_path: str): + self.data_path = Path(data_path) + + @property + def source_name(self) -> str: + return "your_dataset" + + def read_rows(self, split: DatasetSplit) -> Iterable[YourDatasetRow]: + # Load your data format (CSV, JSON, etc.) + for item in self._load_data(): + yield YourDatasetRow(**item) + + def to_documents(self, rows: List[YourDatasetRow], split: DatasetSplit) -> List[Document]: + documents = [] + for row in rows: + doc = Document( + page_content=row.content, + metadata={ + "source": self.source_name, + "external_id": row.id, + "split": split.value, + # Add your specific metadata + } + ) + documents.append(doc) + return documents +``` + +#### Step 3: Configure Processing Pipeline +```yaml +# pipelines/configs/your_config.yml +dataset: + name: "your_dataset" + version: "1.0.0" + adapter: "your_dataset" + path: "/path/to/your/data" + +chunking: + strategy: "recursive_character" + chunk_size: 1000 + chunk_overlap: 200 + +embedding: + strategy: "dense" + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + +vector_store: + provider: "qdrant" + collection_name: "your_dataset_v1" +``` + +#### Step 4: Run Pipeline +```bash +# Install dependencies +pip install -r requirements.txt + +# Start vector database (if using Qdrant) +docker run -p 6333:6333 qdrant/qdrant + +# Run ingestion +python bin/ingest.py --config pipelines/configs/your_config.yml \ + ingest your_dataset /path/to/data --dry-run --max-docs 100 + +# Run without dry-run when ready +python bin/ingest.py --config pipelines/configs/your_config.yml \ + ingest your_dataset /path/to/data +``` + +### 10.3 Customization Points + +#### Custom Validation Rules: +```python +def validate_document(self, doc: Document) -> ValidationResult: + errors = [] + + # Your domain-specific validation + if "required_field" not in doc.metadata: + errors.append("Missing required field") + + # Custom content checks + if len(doc.page_content.split()) < 10: + errors.append("Content too short") + + return ValidationResult( + valid=len(errors) == 0, + errors=errors + ) +``` + +#### Custom Embedding Provider: +```python +class CustomEmbedder: + def __init__(self, config: Dict[str, Any]): + self.model = load_your_model(config["model_path"]) + + def embed_query(self, text: str) -> List[float]: + return self.model.encode(text).tolist() +``` + +#### Custom Evaluation Metrics: +```python +def evaluate_retrieval(self, queries: List[str], ground_truth: List[List[str]]) -> Dict[str, float]: + # Implement your evaluation logic + results = {} + for k in [1, 3, 5, 10]: + results[f"recall_at_{k}"] = compute_recall_at_k(predictions, ground_truth, k) + return results +``` + +### 10.4 Best Practices for Adaptation + +1. **Start Simple**: Begin with basic adapter and gradually add features +2. **Use Type Hints**: Leverage Pydantic for data validation and documentation +3. **Test Components**: Write unit tests for each component +4. **Configuration First**: Make everything configurable from YAML files +5. **Log Everything**: Add comprehensive logging for debugging +6. **Version Everything**: Track dataset versions, model versions, and code versions +7. **Validate Early**: Catch data quality issues as early as possible +8. **Plan for Scale**: Consider memory and compute requirements +9. **Document Decisions**: Explain why certain approaches were chosen +10. **Monitor in Production**: Set up alerts and monitoring for production systems + +### 10.5 Common Pitfalls to Avoid + +1. **Hard-coded Paths**: Use configuration files instead +2. **Missing Error Handling**: Plan for failures in each component +3. **No Rollback Strategy**: Always have a way to revert changes +4. **Insufficient Testing**: Test with small datasets first +5. **Ignoring Resource Limits**: Monitor memory and disk usage +6. **No Backup Strategy**: Plan for data and model backup +7. **Vendor Lock-in**: Design for portability between providers +8. **Poor Documentation**: Document configuration options and troubleshooting + +This architecture provides a solid foundation for MLOps in RAG systems. The key is to start with the core components and gradually add complexity based on your specific needs. The modular design ensures you can adapt it to different domains, datasets, and requirements while maintaining the benefits of reproducibility, scalability, and quality assurance. diff --git a/docs/PROJECT_STRUCTURE.md b/docs/PROJECT_STRUCTURE.md new file mode 100644 index 0000000..ce14846 --- /dev/null +++ b/docs/PROJECT_STRUCTURE.md @@ -0,0 +1,232 @@ +# Project Structure Documentation + +This document describes the current organization of the RAG retrieval pipeline project after the cleanup and reorganization. + +## 📁 Core Project Structure + +### Main Application +``` +├── main.py # Main application entry point +├── config.yml # Main configuration file +├── .env # Environment variables +├── requirements.txt # Python dependencies +└── README.md # Project documentation +``` + +### Agent System +``` +agent/ +├── __init__.py +├── graph.py # LangGraph agent workflow +├── schema.py # Agent state schema +└── nodes/ + └── retriever.py # Configurable retriever node +``` + +### Components (Modular Pipeline System) +``` +components/ +├── retrieval_pipeline.py # Core pipeline framework +├── rerankers.py # Reranking components +├── filters.py # Filtering components +└── advanced_rerankers.py # Advanced reranking implementations +``` + +### Configuration Management +``` +config/ +├── __init__.py +└── config_loader.py # Configuration loading utilities +``` + +### Database Controllers +``` +database/ +├── __init__.py +├── base.py # Base database interface +├── postgres_controller.py # PostgreSQL controller +└── qdrant_controller.py # Qdrant vector database controller +``` + +### Embedding System +``` +embedding/ +├── __init__.py +├── factory.py # Embedding factory +├── bedrock_embeddings.py # AWS Bedrock embeddings +├── hf_embedder.py # HuggingFace embeddings +├── processor.py # Embedding processing +├── recursive_splitter.py # Document splitting +├── sparse_embedder.py # Sparse embeddings +├── splitter.py # Text splitting utilities +└── utils.py # Embedding utilities +``` + +### Pipeline Configurations +``` +pipelines/ +├── configs/ +│ └── retrieval/ # YAML retrieval configurations +│ ├── stackoverflow_minilm.yml +│ ├── hybrid_basic.yml +│ └── advanced_ensemble.yml +├── adapters/ # Data adapters +└── ingest/ # Ingestion pipelines +``` + +### CLI Tools +``` +bin/ +├── agent_retriever.py # CLI agent retriever +├── switch_agent_config.py # Configuration switching utility +└── qdrant_inspector.py # Qdrant inspection tool +``` + +### Examples +``` +examples/ +├── simple_qa_agent.py # Simple Q&A agent example +└── (other examples...) +``` + +## 🧪 Test Organization + +All tests are now organized under the `tests/` directory with clear categorization: + +### Test Structure +``` +tests/ +├── run_all_tests.py # Main test runner +├── test_agent_retrieval.py # Agent integration tests +├── agent/ # Agent-specific tests +│ └── test_retriever_node.py +├── components/ # Component unit tests +│ ├── test_retrieval_pipeline.py +│ └── test_rerankers.py +├── retrieval/ # Retrieval system tests +│ ├── test_extensibility.py +│ ├── test_modular_pipeline.py +│ ├── test_advanced_rerankers.py +│ └── test_answer_retrieval.py +├── ingestion/ # Data ingestion tests +│ ├── test_new_adapter.py +│ └── test_adapter_qa.py +├── embedding/ # Embedding system tests +│ └── test_sparse_embeddings.py +├── examples/ # Example tests +│ ├── test_sosum_minimal.py +│ └── test_sosum_adapter.py +├── pipelines/ # Pipeline tests +│ └── smoke_tests.py +└── benchmarks/ # Performance tests + ├── retriever_test.py + └── test_aws_connection.py +``` + +### Test Categories + +1. **Unit Tests** (`tests/components/`): Test individual components in isolation +2. **Integration Tests** (`tests/agent/`, `tests/retrieval/`): Test component interactions +3. **System Tests** (`tests/examples/`, `tests/pipelines/`): End-to-end testing +4. **Performance Tests** (`tests/benchmarks/`): Performance and load testing + +## 🗂️ Deprecated Code + +All obsolete code has been moved to the `deprecated/` directory: + +``` +deprecated/ +├── old_debug_scripts/ # Debug and analysis scripts +├── old_playground/ # Experimental code +├── old_processors/ # Legacy processor implementations +└── old_tests/ # Superseded test files +``` + +## 📋 Running Tests + +### Run All Tests +```bash +python tests/run_all_tests.py +``` + +### Run Specific Test Categories +```bash +# Component tests +python -m pytest tests/components/ + +# Agent tests +python -m pytest tests/agent/ + +# Retrieval tests +python -m pytest tests/retrieval/ + +# Integration tests +python tests/test_agent_retrieval.py +``` + +### Run Individual Tests +```bash +python tests/components/test_rerankers.py +python tests/agent/test_retriever_node.py +``` + +## 🔧 Configuration Management + +### Pipeline Configurations +- **Location**: `pipelines/configs/retrieval/` +- **Format**: YAML files defining retrieval pipelines +- **Switching**: Use `bin/switch_agent_config.py` + +### Environment Configuration +- **Main Config**: `config.yml` +- **Environment Variables**: `.env` +- **Loading**: Via `config/config_loader.py` + +## 📚 Documentation + +### User Guides +``` +docs/ +├── AGENT_INTEGRATION.md # Agent integration guide +├── EXTENSIBILITY.md # How to extend the system +├── SYSTEM_EXTENSION_GUIDE.md # System extension guide +└── CODE_CLEANUP_SUMMARY.md # Cleanup summary +``` + +### API Documentation +- Docstrings in all major components +- Type hints throughout codebase +- Configuration examples in YAML files + +## 🚀 Getting Started + +1. **Install Dependencies**: + ```bash + pip install -r requirements.txt + ``` + +2. **Configure Environment**: + ```bash + cp .env_example .env + # Edit .env with your settings + ``` + +3. **Run Basic Tests**: + ```bash + python tests/run_all_tests.py + ``` + +4. **Start the Agent**: + ```bash + python main.py + ``` + +## 🎯 Key Features + +- **Modular Design**: Easy to add/remove components +- **YAML Configuration**: Flexible pipeline configuration +- **Comprehensive Testing**: Full test coverage +- **Clear Documentation**: Extensive guides and examples +- **Clean Architecture**: Well-organized codebase +- **Type Safety**: Full type hints +- **Extensible**: Easy to add new components diff --git a/docs/QUICK_START_GUIDE.md b/docs/QUICK_START_GUIDE.md new file mode 100644 index 0000000..5f0c4c1 --- /dev/null +++ b/docs/QUICK_START_GUIDE.md @@ -0,0 +1,658 @@ +# Quick Start Guide: Implementing MLOps Pipeline for RAG + +This guide provides a step-by-step walkthrough for implementing the MLOps pipeline architecture in your own projects. + +## Prerequisites + +- Python 3.9+ +- Docker (for vector database) +- Git (for version control) +- Basic understanding of ML and RAG concepts + +## 1. Project Initialization (15 minutes) + +### Create Project Structure +```bash +mkdir my-rag-project +cd my-rag-project + +# Create directory structure +mkdir -p {pipelines/{adapters,ingest,configs,eval},bin,docs,tests} +mkdir -p {embedding,database,logs/utils,examples,scripts} + +# Initialize git repository +git init +``` + +### Setup Python Environment +```bash +# Create virtual environment +python -m venv .venv +source .venv/bin/activate # Linux/Mac +# .venv\Scripts\activate # Windows + +# Install core dependencies +pip install pydantic langchain langchain-core qdrant-client sentence-transformers pandas pyyaml python-dotenv +``` + +## 2. Implement Core Contracts (30 minutes) + +### Create Base Contracts (`pipelines/contracts.py`) +```python +"""Core contracts for the RAG pipeline.""" +import hashlib +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Dict, List, Optional, Any, Iterable +from enum import Enum +from pathlib import Path + +from pydantic import BaseModel, Field +from langchain_core.documents import Document + +class DatasetSplit(str, Enum): + TRAIN = "train" + VALIDATION = "val" + TEST = "test" + ALL = "all" + +class BaseRow(BaseModel): + """Base schema for dataset rows.""" + external_id: str = Field(..., description="Unique identifier from source") + + class Config: + extra = "allow" + +class ChunkMeta(BaseModel): + """Metadata for processed chunks.""" + # Identity + doc_id: str + chunk_id: str + doc_sha256: str + text: str + + # Source + source: str + dataset_version: str + external_id: str + + # Processing + chunk_index: int + num_chunks: int + char_count: int + split: DatasetSplit + + # Pipeline metadata + ingested_at: datetime = Field(default_factory=datetime.utcnow) + git_commit: Optional[str] = None + config_hash: Optional[str] = None + + # Embeddings + embedding_model: Optional[str] = None + embedding_dim: Optional[int] = None + dense_embedding: Optional[List[float]] = None + sparse_embedding: Optional[Dict[int, float]] = None + + # Additional metadata + labels: Dict[str, Any] = Field(default_factory=dict) + +class DatasetAdapter(ABC): + """Abstract adapter for datasets.""" + + @property + @abstractmethod + def source_name(self) -> str: + pass + + @property + @abstractmethod + def version(self) -> str: + pass + + @abstractmethod + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[BaseRow]: + pass + + @abstractmethod + def to_documents(self, rows: List[BaseRow], split: DatasetSplit) -> List[Document]: + pass + + @abstractmethod + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + pass + +# Utility functions +def compute_content_hash(text: str) -> str: + """Compute SHA256 hash of normalized text.""" + normalized = " ".join(text.strip().split()) + return hashlib.sha256(normalized.encode('utf-8')).hexdigest() + +def build_doc_id(source: str, external_id: str, content_hash: str) -> str: + """Build deterministic document ID.""" + return f"{source}:{external_id}:{content_hash[:12]}" + +def build_chunk_id(doc_id: str, chunk_index: int) -> str: + """Build deterministic chunk ID.""" + return f"{doc_id}#c{chunk_index:04d}" +``` + +## 3. Implement Your First Dataset Adapter (45 minutes) + +### Example: CSV Dataset Adapter (`pipelines/adapters/csv_dataset.py`) +```python +"""CSV dataset adapter example.""" +import pandas as pd +from pathlib import Path +from typing import Iterable, List, Dict, Any + +from langchain_core.documents import Document +from pipelines.contracts import DatasetAdapter, BaseRow, DatasetSplit + +class CSVRow(BaseRow): + """Schema for CSV dataset rows.""" + title: str + content: str + category: Optional[str] = None + +class CSVDatasetAdapter(DatasetAdapter): + """Adapter for CSV-based datasets.""" + + def __init__(self, data_path: str, text_column: str = "content", id_column: str = "id"): + self.data_path = Path(data_path) + self.text_column = text_column + self.id_column = id_column + + if not self.data_path.exists(): + raise FileNotFoundError(f"Dataset not found: {self.data_path}") + + @property + def source_name(self) -> str: + return "csv_dataset" + + @property + def version(self) -> str: + return "1.0.0" + + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[CSVRow]: + """Read rows from CSV file.""" + if self.data_path.is_file(): + # Single CSV file + df = pd.read_csv(self.data_path) + else: + # Directory with split files + split_file = self.data_path / f"{split.value}.csv" + if not split_file.exists() and split == DatasetSplit.ALL: + # Try common filenames + for filename in ["data.csv", "dataset.csv", "train.csv"]: + split_file = self.data_path / filename + if split_file.exists(): + break + + if not split_file.exists(): + raise FileNotFoundError(f"Split file not found: {split_file}") + + df = pd.read_csv(split_file) + + for _, row in df.iterrows(): + yield CSVRow( + external_id=str(row[self.id_column]), + title=row.get("title", ""), + content=row[self.text_column], + category=row.get("category") + ) + + def to_documents(self, rows: List[CSVRow], split: DatasetSplit) -> List[Document]: + """Convert rows to LangChain documents.""" + documents = [] + + for row in rows: + # Combine title and content + if row.title: + content = f"{row.title}\n\n{row.content}" + else: + content = row.content + + doc = Document( + page_content=content, + metadata={ + "source": self.source_name, + "external_id": row.external_id, + "split": split.value, + "title": row.title, + "category": row.category, + "dataset_version": self.version + } + ) + documents.append(doc) + + return documents + + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + """Generate evaluation queries.""" + # Simple approach: use titles as queries + queries = [] + for row in self.read_rows(split): + if row.title: + queries.append({ + "query": row.title, + "expected_doc_id": row.external_id, + "category": row.category + }) + + return queries[:100] # Limit for testing +``` + +## 4. Create Configuration System (20 minutes) + +### Configuration Schema (`pipelines/configs/config_schema.py`) +```python +"""Configuration schema validation.""" +from pydantic import BaseModel, Field +from typing import Dict, Any, Optional + +class DatasetConfig(BaseModel): + name: str + version: str + adapter: str + path: str + +class ChunkingConfig(BaseModel): + strategy: str = "recursive_character" + chunk_size: int = 1000 + chunk_overlap: int = 200 + +class EmbeddingConfig(BaseModel): + strategy: str = "dense" # dense, sparse, hybrid + provider: str = "hf" + model: str = "sentence-transformers/all-MiniLM-L6-v2" + batch_size: int = 32 + +class VectorStoreConfig(BaseModel): + provider: str = "qdrant" + collection_name: str + host: str = "localhost" + port: int = 6333 + +class PipelineConfig(BaseModel): + dataset: DatasetConfig + chunking: ChunkingConfig + embedding: EmbeddingConfig + vector_store: VectorStoreConfig + + # Optional experiment settings + experiment: Optional[Dict[str, Any]] = None + max_documents: Optional[int] = None + dry_run: bool = False +``` + +### Example Configuration (`pipelines/configs/csv_example.yml`) +```yaml +dataset: + name: "my_csv_dataset" + version: "1.0.0" + adapter: "csv_dataset" + path: "/path/to/your/data.csv" + +chunking: + strategy: "recursive_character" + chunk_size: 1000 + chunk_overlap: 200 + +embedding: + strategy: "dense" + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + batch_size: 16 + +vector_store: + provider: "qdrant" + collection_name: "my_csv_dataset_v1" + host: "localhost" + port: 6333 + +experiment: + name: "baseline" + description: "Initial baseline with MiniLM embeddings" + canary: false + +max_documents: null # null = no limit +dry_run: false +``` + +## 5. Implement Core Processing Components (60 minutes) + +### Simple Chunker (`pipelines/ingest/chunker.py`) +```python +"""Document chunking functionality.""" +import logging +from typing import List, Dict, Any + +from langchain_core.documents import Document +from langchain_text_splitters import RecursiveCharacterTextSplitter + +logger = logging.getLogger(__name__) + +class DocumentChunker: + """Chunks documents using configurable strategies.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.strategy = config.get("strategy", "recursive_character") + + if self.strategy == "recursive_character": + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=config.get("chunk_size", 1000), + chunk_overlap=config.get("chunk_overlap", 200), + separators=config.get("separators", ["\n\n", "\n", " ", ""]) + ) + else: + raise ValueError(f"Unknown chunking strategy: {self.strategy}") + + def chunk_documents(self, documents: List[Document]) -> List[Document]: + """Split documents into chunks.""" + logger.info(f"Chunking {len(documents)} documents with {self.strategy} strategy") + + chunked_docs = [] + + for doc in documents: + chunks = self.splitter.split_documents([doc]) + + # Add chunk metadata + for i, chunk in enumerate(chunks): + chunk.metadata.update({ + "chunk_index": i, + "num_chunks": len(chunks), + "chunk_strategy": self.strategy, + "original_doc_id": doc.metadata.get("external_id") + }) + chunked_docs.append(chunk) + + logger.info(f"Generated {len(chunked_docs)} chunks from {len(documents)} documents") + return chunked_docs +``` + +### Simple Embedder (`pipelines/ingest/embedder.py`) +```python +"""Embedding generation.""" +import logging +from typing import List, Dict, Any, Optional + +from langchain_core.documents import Document +from sentence_transformers import SentenceTransformer + +from pipelines.contracts import ChunkMeta, compute_content_hash, build_doc_id, build_chunk_id, DatasetSplit + +logger = logging.getLogger(__name__) + +class EmbeddingPipeline: + """Generate embeddings for documents.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.strategy = config.get("strategy", "dense") + + if self.strategy in ["dense", "hybrid"]: + model_name = config.get("model", "sentence-transformers/all-MiniLM-L6-v2") + self.model = SentenceTransformer(model_name) + logger.info(f"Loaded embedding model: {model_name}") + + def process_documents(self, documents: List[Document]) -> List[ChunkMeta]: + """Convert documents to ChunkMeta with embeddings.""" + logger.info(f"Processing {len(documents)} documents for embeddings") + + chunk_metas = [] + texts = [doc.page_content for doc in documents] + + # Generate embeddings in batch + if self.strategy in ["dense", "hybrid"]: + embeddings = self.model.encode(texts, convert_to_tensor=False) + logger.info(f"Generated {len(embeddings)} embeddings") + + # Convert to ChunkMeta + for i, doc in enumerate(documents): + chunk_meta = self._document_to_chunk_meta(doc) + + if self.strategy in ["dense", "hybrid"]: + chunk_meta.dense_embedding = embeddings[i].tolist() + chunk_meta.embedding_dim = len(embeddings[i]) + chunk_meta.embedding_model = self.config.get("model") + + chunk_metas.append(chunk_meta) + + return chunk_metas + + def _document_to_chunk_meta(self, doc: Document) -> ChunkMeta: + """Convert Document to ChunkMeta.""" + text = doc.page_content + metadata = doc.metadata + + # Generate deterministic IDs + doc_sha256 = compute_content_hash(text) + source = metadata.get("source", "unknown") + external_id = metadata.get("external_id", "unknown") + + doc_id = build_doc_id(source, external_id, doc_sha256) + chunk_index = metadata.get("chunk_index", 0) + chunk_id = build_chunk_id(doc_id, chunk_index) + + return ChunkMeta( + doc_id=doc_id, + chunk_id=chunk_id, + doc_sha256=doc_sha256, + text=text, + source=source, + dataset_version=metadata.get("dataset_version", "unknown"), + external_id=external_id, + chunk_index=chunk_index, + num_chunks=metadata.get("num_chunks", 1), + char_count=len(text), + split=DatasetSplit(metadata.get("split", "all")), + labels=metadata + ) +``` + +## 6. Create Simple CLI Interface (30 minutes) + +### CLI Script (`bin/ingest.py`) +```python +#!/usr/bin/env python3 +"""Simple ingestion CLI.""" +import argparse +import yaml +import logging +from pathlib import Path +import sys +import importlib + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from pipelines.ingest.chunker import DocumentChunker +from pipelines.ingest.embedder import EmbeddingPipeline + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def load_adapter(adapter_name: str, data_path: str): + """Dynamically load dataset adapter.""" + module_name = f"pipelines.adapters.{adapter_name}" + module = importlib.import_module(module_name) + + # Find adapter class (assumes pattern: XxxDatasetAdapter) + adapter_class = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if (isinstance(attr, type) and + hasattr(attr, 'source_name') and + attr_name.endswith('Adapter')): + adapter_class = attr + break + + if not adapter_class: + raise ValueError(f"No adapter class found in {module_name}") + + return adapter_class(data_path) + +def main(): + parser = argparse.ArgumentParser(description="Simple RAG Ingestion Pipeline") + parser.add_argument("config", help="Configuration file path") + parser.add_argument("--dry-run", action="store_true", help="Run without uploading") + parser.add_argument("--max-docs", type=int, help="Limit number of documents") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Load configuration + with open(args.config, 'r') as f: + config = yaml.safe_load(f) + + logger.info(f"Starting ingestion with config: {args.config}") + + try: + # Load dataset adapter + adapter = load_adapter( + config["dataset"]["adapter"], + config["dataset"]["path"] + ) + logger.info(f"Loaded adapter: {adapter.source_name}") + + # Read data + rows = list(adapter.read_rows()) + if args.max_docs: + rows = rows[:args.max_docs] + logger.info(f"Read {len(rows)} rows") + + # Convert to documents + documents = adapter.to_documents(rows, split="all") + logger.info(f"Created {len(documents)} documents") + + # Chunk documents + chunker = DocumentChunker(config["chunking"]) + chunked_docs = chunker.chunk_documents(documents) + logger.info(f"Created {len(chunked_docs)} chunks") + + # Generate embeddings + embedder = EmbeddingPipeline(config["embedding"]) + chunk_metas = embedder.process_documents(chunked_docs) + logger.info(f"Generated embeddings for {len(chunk_metas)} chunks") + + if args.dry_run: + logger.info("DRY RUN - Would upload to vector store") + logger.info(f"Sample chunk: {chunk_metas[0].chunk_id}") + else: + # TODO: Implement vector store upload + logger.info("Vector store upload not implemented yet") + + logger.info("Ingestion completed successfully!") + + except Exception as e: + logger.error(f"Ingestion failed: {e}") + return 1 + + return 0 + +if __name__ == "__main__": + sys.exit(main()) +``` + +## 7. Test Your Implementation (15 minutes) + +### Create Test Data (`test_data.csv`) +```csv +id,title,content,category +1,"Introduction to Python","Python is a programming language that lets you work quickly and integrate systems more effectively.","programming" +2,"Machine Learning Basics","Machine learning is a method of data analysis that automates analytical model building.","ai" +3,"Data Science Overview","Data science is an interdisciplinary field that uses scientific methods to extract knowledge from data.","data" +``` + +### Test Configuration (`test_config.yml`) +```yaml +dataset: + name: "test_dataset" + version: "1.0.0" + adapter: "csv_dataset" + path: "test_data.csv" + +chunking: + strategy: "recursive_character" + chunk_size: 500 + chunk_overlap: 50 + +embedding: + strategy: "dense" + provider: "hf" + model: "sentence-transformers/all-MiniLM-L6-v2" + batch_size: 2 + +vector_store: + provider: "qdrant" + collection_name: "test_v1" +``` + +### Run Test +```bash +# Create test data +echo 'id,title,content,category +1,"Introduction to Python","Python is a programming language that lets you work quickly and integrate systems more effectively.","programming" +2,"Machine Learning Basics","Machine learning is a method of data analysis that automates analytical model building.","ai"' > test_data.csv + +# Test the pipeline +python bin/ingest.py test_config.yml --dry-run --max-docs 2 --verbose +``` + +## 8. Next Steps and Extensions + +### Immediate Improvements +1. **Add Vector Store Integration**: Implement actual upload to Qdrant +2. **Add Validation**: Input validation and quality checks +3. **Add Error Handling**: Robust error handling and recovery +4. **Add Logging**: Structured logging with metrics + +### Advanced Features +1. **Evaluation Framework**: Implement retrieval evaluation +2. **Configuration Validation**: Schema validation for configs +3. **Experiment Tracking**: MLflow or W&B integration +4. **Monitoring**: Prometheus metrics and Grafana dashboards +5. **Streaming**: Process large datasets without loading into memory + +### Production Readiness +1. **Docker Containers**: Containerize the application +2. **CI/CD Pipeline**: Automated testing and deployment +3. **Infrastructure as Code**: Terraform for cloud resources +4. **Security**: Authentication, authorization, and encryption +5. **Backup and Recovery**: Data backup and disaster recovery + +## Troubleshooting + +### Common Issues + +**Import Errors**: +```bash +# Make sure project root is in Python path +export PYTHONPATH="${PYTHONPATH}:$(pwd)" +``` + +**Missing Dependencies**: +```bash +# Install additional packages as needed +pip install sentence-transformers pandas pyyaml +``` + +**Configuration Errors**: +- Check YAML syntax +- Verify file paths exist +- Ensure adapter names match module names + +**Memory Issues**: +- Reduce batch_size in embedding config +- Use --max-docs to limit dataset size +- Consider streaming implementation for large datasets + +This quick start guide should get you up and running with a basic MLOps pipeline for RAG systems. Start with this foundation and gradually add more sophisticated features as your needs grow. diff --git a/docs/SOSUM_INGESTION.md b/docs/SOSUM_INGESTION.md new file mode 100644 index 0000000..b0a60f9 --- /dev/null +++ b/docs/SOSUM_INGESTION.md @@ -0,0 +1,213 @@ +# Ingesting SOSum Stack Overflow Dataset + +This guide shows how to ingest the SOSum dataset using the pipeline. + +## About SOSum + +**SOSum** is a dataset of extractive summaries of Stack Overflow posts from: +https://github.com/BonanKou/SOSum-A-Dataset-of-Extractive-Summaries-of-Stack-Overflow-Posts-and-labeling-tools + +**Dataset Statistics:** +- 506 popular Stack Overflow questions +- 2,278 total posts (questions + answers) +- 669 unique tags covered +- Median view count: 253K +- Median post score: 17 +- Manual extractive summaries for answers + +## Dataset Format + +SOSum comes in two CSV files: + +### `question.csv` +| Field | Description | +|-------|-------------| +| Question Id | Post ID of the SO question | +| Question Type | 1=conceptual, 2=how-to, 3=debug-corrective | +| Question Title | Question title as string | +| Question Body | List of sentences from question content | +| Tags | SO tags associated with question | +| Answer Posts | Comma-separated answer post IDs | + +### `answer.csv` +| Field | Description | +|-------|-------------| +| Answer Id | Post ID of SO answer | +| Answer Body | List of sentences from answer content | +| Summary | Extractive summative sentences | + +## Quick Start + +### 1. Download the Dataset + +```bash +# Clone the SOSum repository +git clone https://github.com/BonanKou/SOSum-A-Dataset-of-Extractive-Summaries-of-Stack-Overflow-Posts-and-labeling-tools.git sosum + +# The CSV files are in sosum/data/ directory +ls sosum/data/ +# Should show: question.csv answer.csv +``` + +### 2. Test the Adapter + +```bash +# Run the example script to test everything works +python examples/ingest_sosum_example.py +``` + +### 3. Dry Run Ingestion + +```bash +# Test with a small sample (no upload to vector store) +python bin/ingest.py ingest stackoverflow sosum/ --dry-run --max-docs 10 --verbose +``` + +### 4. Canary Ingestion + +```bash +# Safe test with real upload to canary collection +python bin/ingest.py ingest stackoverflow sosum/ --canary --max-docs 100 --verify +``` + +### 5. Check Status + +```bash +python bin/ingest.py status +``` + +### 6. Full Ingestion + +```bash +# Ingest all data +python bin/ingest.py ingest stackoverflow sosum/ --config pipelines/configs/stackoverflow.yml +``` + +### 7. Evaluate Retrieval + +```bash +# Test retrieval performance +python bin/ingest.py evaluate stackoverflow sosum/ --output-dir results/sosum/ +``` + +## What Gets Ingested + +### Document Types + +1. **Questions**: Combined title + body content + - ID format: `q_{question_id}` + - Content: "Title: {title}\n\nQuestion: {body}" + - Metadata: tags, question_type, related_posts + +2. **Answers**: Answer body + summary (if available) + - ID format: `a_{answer_id}` + - Content: "Answer: {body}\n\nSummary: {summary}" (if summary exists) + - Metadata: has_summary, summary + +### Metadata Fields + +- `external_id`: Unique identifier (q_123 or a_456) +- `source`: "stackoverflow_sosum" +- `post_type`: "question" or "answer" +- `doc_type`: "question" or "answer" +- `tags`: List of SO tags (questions only) +- `title`: Question title (questions only) +- `question_type`: 1, 2, or 3 (questions only) +- `has_summary`: Boolean (answers only) +- `summary`: Extractive summary text (answers only) + +### Evaluation Queries + +The adapter automatically generates evaluation queries: + +1. **Question titles** → Should retrieve the question document +2. **Short question queries** → First 5 words of title +3. **Answer summaries** → Should retrieve the answer document + +## Configuration + +The pipeline uses `pipelines/configs/stackoverflow.yml`: + +- **Code-aware chunking**: Preserves code blocks and functions +- **Hybrid embedding**: Dense + sparse vectors for better code retrieval +- **Smaller validation limits**: Handles extractive summaries (shorter content) +- **SOSum-specific collection**: `sosum_stackoverflow_v1` + +## Expected Results + +After successful ingestion: + +- **Documents**: ~2,278 documents (506 questions + ~1,772 answers) +- **Chunks**: Depends on chunking strategy (likely 3,000-5,000 chunks) +- **Vectors**: Hybrid (dense + sparse) for each chunk +- **Collection**: Named `sosum_stackoverflow_v1` in Qdrant + +## Troubleshooting + +### Common Issues + +1. **File not found**: + ```bash + # Make sure files exist + ls sosum/data/question.csv sosum/data/answer.csv + ``` + +2. **Parsing errors**: + ```bash + # Check CSV format + head -5 sosum/data/question.csv + head -5 sosum/data/answer.csv + ``` + +3. **Import errors**: + ```bash + # Check dependencies + pip install pandas pydantic langchain-core + ``` + +4. **Qdrant connection**: + ```bash + # Check if Qdrant is running + python bin/ingest.py status + ``` + +### Debug Commands + +```bash +# Verbose logging +python bin/ingest.py ingest stackoverflow sosum/ --dry-run --verbose + +# Check logs +tail -f logs/ingestion.log + +# Test specific number of docs +python bin/ingest.py ingest stackoverflow sosum/ --dry-run --max-docs 5 +``` + +## Integration with Retrieval + +After ingestion, you can test retrieval: + +```python +from retrievers.router import RetrieverRouter +from config.config_loader import load_config + +config = load_config("pipelines/configs/stackoverflow.yml") +retriever = RetrieverRouter(config) + +# Test queries +results = retriever.search("Python list comprehension example", top_k=5) +for result in results: + print(f"Score: {result['score']:.3f}") + print(f"Doc: {result['metadata']['external_id']}") + print(f"Content: {result['content'][:100]}...") + print() +``` + +## Next Steps + +1. **Add more datasets**: Use the same adapter pattern for other SO datasets +2. **Custom evaluation**: Add domain-specific evaluation queries +3. **Tune chunking**: Experiment with chunk sizes for code content +4. **Hybrid weights**: Tune dense vs sparse retrieval weights +5. **Summary utilization**: Use extractive summaries for enhanced retrieval diff --git a/embedding/hf_embedder.py b/embedding/embeddings.py similarity index 66% rename from embedding/hf_embedder.py rename to embedding/embeddings.py index c1c5c7a..5c0e67b 100644 --- a/embedding/hf_embedder.py +++ b/embedding/embeddings.py @@ -1,4 +1,5 @@ -from langchain_huggingface import HuggingFaceEmbeddings +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings class HuggingFaceEmbedder(HuggingFaceEmbeddings): diff --git a/embedding/factory.py b/embedding/factory.py index 75928e4..3e58d29 100644 --- a/embedding/factory.py +++ b/embedding/factory.py @@ -1,47 +1,52 @@ -import os -from embedding.hf_embedder import HuggingFaceEmbedder +from embedding.embeddings import HuggingFaceEmbedder from embedding.bedrock_embeddings import TitanEmbedder from embedding.sparse_embedder import SparseEmbedder -import dotenv from langchain_qdrant import FastEmbedSparse - -dotenv.load_dotenv() +from langchain_google_genai import GoogleGenerativeAIEmbeddings +import os -def get_embedder(name: str = None, **kwargs): +def get_embedder(cfg: dict): """ - Factory to return a LangChain-compatible embedder instance. + Factory to return a LangChain-compatible embedder instance, based on YAML config. Args: - name (str, optional): Embedder name. If not provided, will fetch from ENV. - kwargs: Additional model configuration. + cfg (dict): Embedder configuration dictionary. Returns: A LangChain-compatible embedder object. """ + provider = cfg.get("provider", "hf").strip().lower() - name = (name or os.getenv("DENSE_EMBEDDER") - or os.getenv("SPARSE_EMBEDDER")).strip().lower() - - if name == "hf": - model_name = kwargs.get("model_name") or os.getenv( - "HF_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2" - ) - return HuggingFaceEmbedder(model_name=model_name) + if provider == "hf": + model_name = cfg.get( + "model_name", "sentence-transformers/all-MiniLM-L6-v2") + device = cfg.get("device", "cpu") + return HuggingFaceEmbedder(model_name=model_name, device=device) - elif name == "titan": - model = kwargs.get("model") or os.getenv( - "TITAN_MODEL", "amazon.titan-embed-text-v2:0" - ) - region = kwargs.get("region") or os.getenv("TITAN_REGION", "us-east-1") + elif provider == "titan": + model = cfg.get("model_name", "amazon.titan-embed-text-v2:0") + region = cfg.get("region", "us-east-1") return TitanEmbedder(model=model, region=region) - elif name == "fastembed": - model_name = kwargs.get("model_name") or os.getenv( - "FASTEMBED_MODEL", "BAAI/bge-small-en-v1.5" + elif provider == "fastembed": + model_name = cfg.get("model_name", "BAAI/bge-small-en-v1.5") + return FastEmbedSparse(model_name=model_name) + + elif provider == "google": + model_name = cfg.get("model", "models/embedding-001") + return GoogleGenerativeAIEmbeddings( + model=model_name, + google_api_key=os.getenv("GOOGLE_API_KEY") ) - return FastEmbedSparse(model_name=os.getenv("SPARSE_MODEL_NAME", "Qdrant/bm25")) + + elif provider == "sparse": + # Support both 'model' and 'model_name' for consistency with other providers + model_name = cfg.get("model") or cfg.get("model_name") or "Qdrant/bm25" + device = cfg.get("device", "cpu") + return SparseEmbedder(model_name=model_name, device=device) else: raise ValueError( - f"Unsupported embedder name: '{name}'. Supported: hf, titan, fastembed") + f"Unsupported embedder provider: '{provider}'. Supported: hf, titan, fastembed, sparse, google" + ) diff --git a/embedding/recursive_splitter.py b/embedding/recursive_splitter.py index 29846e1..0e74040 100644 --- a/embedding/recursive_splitter.py +++ b/embedding/recursive_splitter.py @@ -13,3 +13,9 @@ def __init__(self, chunk_size=500, chunk_overlap=50): def split(self, docs: List[Document]) -> List[Document]: return self.splitter.split_documents(docs) + + def create_documents(self, texts: List[str], metadatas: List[dict] = None) -> List[Document]: + """ + Create Document objects from a list of texts (and optional metadatas). + """ + return self.splitter.create_documents(texts, metadatas) diff --git a/embedding/sparse_embedder.py b/embedding/sparse_embedder.py index 0f597c7..d6fc810 100644 --- a/embedding/sparse_embedder.py +++ b/embedding/sparse_embedder.py @@ -1,6 +1,6 @@ import logging from typing import List, Dict -from fastembed import TextEmbedding +from fastembed import SparseTextEmbedding from langchain_core.embeddings import Embeddings logger = logging.getLogger(__name__) @@ -9,22 +9,21 @@ class SparseEmbedder(Embeddings): """ - Embedder that produces sparse vectors using FastEmbed. + Embedder that produces sparse vectors using FastEmbed SparseTextEmbedding. """ - def __init__(self, model_name: str = "BAAI/bge-base-en", device: str = "cuda"): + def __init__(self, model_name: str = "Qdrant/bm25", device: str = "cpu"): """ Args: - model_name (str): Name of the sparse model to load (e.g., BGE sparse models). + model_name (str): Name of the sparse model to load (e.g., "Qdrant/bm25"). device (str): Device to run the model ("cpu" or "cuda"). """ - self.model = TextEmbedding( - model_name=model_name, - embedding_type="sparse", - device=device + self.model = SparseTextEmbedding( + model_name=model_name ) + self.model_name = model_name logger.info( - f"Initialized SparseEmbedder with model: {model_name}, device: {device}") + f"Initialized SparseEmbedder with model: {model_name}") def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]: """ @@ -37,7 +36,14 @@ def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]: List[Dict[int, float]]: List of sparse embeddings (one per text). """ logger.info(f"Embedding {len(texts)} documents (sparse).") - return list(self.model.embed(texts)) + embeddings = [] + for embedding in self.model.embed(texts): + # Convert SparseEmbedding to dict + sparse_dict = {} + for idx, val in zip(embedding.indices, embedding.values): + sparse_dict[int(idx)] = float(val) + embeddings.append(sparse_dict) + return embeddings def embed_query(self, text: str) -> Dict[int, float]: """ @@ -49,4 +55,5 @@ def embed_query(self, text: str) -> Dict[int, float]: Returns: Dict[int, float]: Sparse vector for query. """ - return next(self.model.embed([text])) + embeddings = self.embed_documents([text]) + return embeddings[0] diff --git a/logs/utils/__init__.py b/logs/utils/__init__.py new file mode 100644 index 0000000..c29da6e --- /dev/null +++ b/logs/utils/__init__.py @@ -0,0 +1 @@ +from logs.utils.logger import get_logger diff --git a/logs/utils/logger.py b/logs/utils/logger.py new file mode 100644 index 0000000..1aa18fd --- /dev/null +++ b/logs/utils/logger.py @@ -0,0 +1,33 @@ +import logging +from pathlib import Path + +# Ensure log directory exists +Path("logs").mkdir(exist_ok=True) + + +def get_logger(name: str) -> logging.Logger: + """ + Returns a logger instance writing to logs/agent.log. + Ensures handlers are not duplicated. + Args: + name (str): Logger name, usually __name__. + Returns: + logging.Logger: Configured logger object. + """ + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + # Only add handler if none exist for this logger + if not any(isinstance(h, logging.FileHandler) and h.baseFilename.endswith("agent.log") + for h in logger.handlers): + handler = logging.FileHandler("logs/agent.log") + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + +# Example usage in other modules: +# logger = get_logger(__name__) +# logger.info("Logger is working!") diff --git a/main.py b/main.py new file mode 100644 index 0000000..a6a0e48 --- /dev/null +++ b/main.py @@ -0,0 +1,59 @@ +""" +Main application entry point for the RAG agent. +Provides an interactive chat interface for the LangGraph agent with configurable retrieval. +""" + +from agent.graph import graph +from logs.utils.logger import get_logger + +logger = get_logger("chat") + + +def main(): + """ + Main chat loop for the RAG agent. + Handles user input, agent invocation, and response display. + """ + chat_history = [] + + print("RAG Agent - Interactive Chat") + print("Type 'exit' or 'quit' to end the conversation") + print("-" * 50) + + while True: + user_input = input("You: ") + if user_input.lower() in {"exit", "quit"}: + print("Goodbye!") + break + + state = { + "question": user_input, + "chat_history": chat_history + } + + try: + final_state = graph.invoke(state) + answer = final_state.get("answer", "[No answer returned]") + chat_history = final_state.get("chat_history", []) + + print("\n---") + print(f"Agent: {answer}") + + # Log to file + logger.info(f"User: {user_input}") + logger.info(f"Agent: {answer}") + + if "error" in final_state: + logger.error(f"Execution error: {final_state['error']}") + print(f"[Error occurred: {final_state['error']}]") + + print("---\n") + + except Exception as e: + logger.error(f"Agent invocation failed: {e}") + print(f"[Error: Agent failed to process your request: {e}]") + print("---\n") + + +if __name__ == "__main__": + main() diff --git a/pipelines/README.md b/pipelines/README.md new file mode 100644 index 0000000..4cc5ec3 --- /dev/null +++ b/pipelines/README.md @@ -0,0 +1,373 @@ +# Ingestion Pipeline + +A comprehensive, theory-backed ingestion pipeline for retrieval-augmented generation (RAG) systems. This pipeline implements deterministic IDs, idempotent loads, dataset adapters, and comprehensive evaluation. + +## 🎯 Mission + +This pipeline guarantees: +- **Reproducibility**: Same raw data → same chunk IDs and vectors +- **Idempotency**: Reruns don't duplicate anything; only changed content is updated +- **Portability**: New datasets plug in without touching downstream code +- **Observability**: You can prove what you loaded, when, with which config/code + +## 🏗️ Architecture + +``` +pipelines/ +├── contracts.py # Base schemas and interfaces +├── adapters/ # Dataset-specific adapters +│ ├── natural_questions.py +│ ├── stackoverflow.py +│ └── energy_papers.py +├── ingest/ # Core ingestion components +│ ├── validator.py # Document validation and cleaning +│ ├── chunker.py # Advanced chunking strategies +│ ├── embedder.py # Embedding pipeline with caching +│ ├── uploader.py # Vector store uploader +│ ├── smoke_tests.py # Post-ingestion verification +│ └── pipeline.py # Main orchestrator +├── eval/ # Evaluation framework +│ └── evaluator.py # Unified retrieval evaluation +└── configs/ # Dataset-specific configurations + ├── natural_questions.yml + ├── stackoverflow.yml + └── energy_papers.yml +``` + +## 🚀 Quick Start + +### 1. Install Dependencies + +The pipeline uses your existing dependencies plus these optional packages: +```bash +pip install numpy # For evaluation metrics +``` + +### 2. Configure Your Setup + +Copy and modify a configuration template: +```bash +cp pipelines/configs/energy_papers.yml my_config.yml +# Edit my_config.yml for your needs +``` + +### 3. Ingest Your First Dataset + +```bash +# Ingest your energy papers (using existing papers/ directory) +python bin/ingest.py ingest energy_papers papers/ --config my_config.yml + +# Dry run with limited documents for testing +python bin/ingest.py ingest energy_papers papers/ --dry-run --max-docs 10 + +# Canary deployment (test with separate collection) +python bin/ingest.py ingest energy_papers papers/ --canary --verify +``` + +### 4. Check Status + +```bash +python bin/ingest.py status +``` + +### 5. Evaluate Retrieval + +```bash +python bin/ingest.py evaluate energy_papers papers/ --output-dir results/ +``` + +## 📖 Core Concepts + +### Dataset Adapters + +Each dataset needs an adapter that implements: +- `read_rows()`: Read raw dataset into standardized format +- `to_documents()`: Convert to LangChain Documents +- `get_evaluation_queries()`: Provide evaluation queries + +### Deterministic IDs + +- **Document ID**: `{source}:{external_id}:{content_hash[:12]}` +- **Chunk ID**: `{doc_id}#c{chunk_index:04d}` + +This ensures identical content always gets the same ID. + +### Chunking Strategies + +- **Recursive**: General-purpose character-based chunking +- **Semantic**: Sentence-boundary aware chunking +- **Code Aware**: Preserves code blocks and functions +- **Table Aware**: Preserves table structure +- **Auto**: Automatically selects strategy based on content + +### Embedding Strategies + +- **Dense**: Semantic embeddings (HuggingFace, Titan) +- **Sparse**: Keyword-based embeddings (BM25, FastEmbed) +- **Hybrid**: Both dense and sparse for optimal recall + +## 🎛️ Advanced Usage + +### Batch Processing + +```bash +# Create batch configuration +cat > batch_config.json << EOF +{ + "datasets": [ + {"type": "energy_papers", "path": "papers/", "version": "1.0.0"}, + {"type": "stackoverflow", "path": "/path/to/stackoverflow", "version": "1.0.0"} + ] +} +EOF + +# Run batch ingestion +python bin/ingest.py batch-ingest batch_config.json --max-docs 100 +``` + +### Custom Dataset Adapter + +```python +from pipelines.contracts import DatasetAdapter, BaseRow +from pipelines.adapters.base import MyCustomRow + +class MyDatasetAdapter(DatasetAdapter): + @property + def source_name(self) -> str: + return "my_dataset" + + def read_rows(self, split) -> Iterable[MyCustomRow]: + # Your data reading logic + pass + + def to_documents(self, rows, split) -> List[Document]: + # Convert to LangChain Documents + pass + + def get_evaluation_queries(self, split) -> List[Dict]: + # Return evaluation queries + pass +``` + +### Canary → Promote Workflow + +```bash +# 1. Canary deployment +python bin/ingest.py ingest my_dataset /path/to/data --canary + +# 2. Verify canary collection +python bin/ingest.py evaluate my_dataset /path/to/data --output-dir canary_results/ + +# 3. If good, promote (re-run without --canary) +python bin/ingest.py ingest my_dataset /path/to/data + +# 4. Clean up canary +python bin/ingest.py cleanup +``` + +## 📊 Evaluation Framework + +The pipeline includes a comprehensive evaluation system: + +### Metrics +- **Recall@k**: Fraction of relevant docs in top-k +- **Precision@k**: Fraction of top-k that are relevant +- **NDCG@k**: Normalized Discounted Cumulative Gain +- **MRR**: Mean Reciprocal Rank +- **MAP**: Mean Average Precision + +### Usage +```bash +python bin/ingest.py evaluate energy_papers papers/ \ + --split test \ + --output-dir evaluation_results/ +``` + +## 🔧 Configuration + +### Main Configuration (config.yml) +```yaml +embedding_strategy: hybrid + +embedding: + dense: + provider: hf + model_name: sentence-transformers/all-MiniLM-L6-v2 + sparse: + provider: fastembed + model_name: Qdrant/bm25 + +chunking: + strategy: semantic + chunk_size: 500 + chunk_overlap: 50 + +qdrant: + collection: my_collection + dense_vector_name: dense + sparse_vector_name: sparse +``` + +### Dataset-Specific Configs +Each dataset can have specialized settings in `pipelines/configs/`. + +## 📈 Monitoring & Observability + +### Lineage Tracking +Every ingestion run creates a lineage record: +```json +{ + "run_id": "uuid", + "dataset_name": "energy_papers", + "config_hash": "abc123", + "git_commit": "def456", + "total_documents": 100, + "successful_chunks": 850, + "sample_doc_ids": ["doc1", "doc2"], + "environment": {...} +} +``` + +### Logs +- Console output for real-time monitoring +- Structured logs in `logs/ingestion.log` +- Per-run lineage in `output/lineage/` + +### Smoke Tests +Automatic post-ingestion validation: +- Collection exists and is populated +- Vector dimensions are consistent +- Golden queries return reasonable results +- Embedding quality metrics + +## 🛠️ Extending the Pipeline + +### Adding a New Dataset + +1. **Create an adapter** in `pipelines/adapters/my_dataset.py` +2. **Add configuration** in `pipelines/configs/my_dataset.yml` +3. **Register in CLI** by adding to `get_adapter()` in `bin/ingest.py` +4. **Test** with dry run: `python bin/ingest.py ingest my_dataset /path --dry-run` + +### Adding a New Chunking Strategy + +```python +from pipelines.ingest.chunker import ChunkingStrategy + +class MyChunkingStrategy(ChunkingStrategy): + @property + def strategy_name(self) -> str: + return "my_strategy" + + def chunk(self, documents) -> List[Document]: + # Your chunking logic + pass + +# Register in ChunkingStrategyFactory.STRATEGIES +``` + +### Adding a New Smoke Test + +```python +from pipelines.ingest.smoke_tests import SmokeTest, SmokeTestResult + +class MyCustomTest(SmokeTest): + @property + def test_name(self) -> str: + return "my_test" + + def run(self, config) -> SmokeTestResult: + # Your test logic + pass +``` + +## 🎯 Best Practices + +### Development Workflow +1. **Start with dry runs** to validate data processing +2. **Use canary deployments** for new configurations +3. **Run evaluations** to measure retrieval quality +4. **Monitor lineage** for reproducibility + +### Production Deployment +1. **Pin configuration versions** with git commits +2. **Use deterministic seeds** for dataset splits +3. **Archive lineage records** for compliance +4. **Monitor embedding quality** over time + +### Performance Optimization +1. **Enable embedding caching** for repeated runs +2. **Tune batch sizes** based on your hardware +3. **Use appropriate chunking** for your content type +4. **Monitor vector store performance** + +## 🔍 Troubleshooting + +### Common Issues + +**Embeddings are all zeros** +```bash +# Check embedding configuration +python bin/ingest.py ingest my_dataset /path --dry-run --max-docs 1 -v +``` + +**Collection not found** +```bash +# Check Qdrant connection and collection settings +python bin/ingest.py status +``` + +**Evaluation shows zero recall** +```bash +# Verify evaluation queries match ingested content +python bin/ingest.py evaluate my_dataset /path --max-docs 10 +``` + +**Import errors** +Make sure you're running from the project root and all dependencies are installed. + +## 📝 Development Notes + +This pipeline is designed to integrate seamlessly with your existing: +- Qdrant vector database setup +- Embedding factory (HF, Titan, FastEmbed) +- Configuration system (YAML-based) +- Document processing pipeline + +The theory-backed design ensures that you can confidently: +- Add new datasets without touching core code +- Compare retrieval quality across configurations +- Reproduce any ingestion run exactly +- Scale to production workloads + +## 🎉 Example: Complete Workflow + +```bash +# 1. Setup +git add . && git commit -m "Setup ingestion pipeline" + +# 2. Test with dry run +python bin/ingest.py ingest energy_papers papers/ \ + --config pipelines/configs/energy_papers.yml \ + --dry-run --max-docs 5 --verbose + +# 3. Canary deployment +python bin/ingest.py ingest energy_papers papers/ \ + --canary --max-docs 50 + +# 4. Evaluate canary +python bin/ingest.py evaluate energy_papers papers/ \ + --output-dir canary_eval/ + +# 5. Full deployment (if canary looks good) +python bin/ingest.py ingest energy_papers papers/ + +# 6. Production evaluation +python bin/ingest.py evaluate energy_papers papers/ \ + --output-dir production_eval/ + +# 7. Check final status +python bin/ingest.py status +``` + +This gives you a production-ready, theory-backed ingestion pipeline that scales across datasets and maintains full lineage! 🚀 diff --git a/pipelines/__init__.py b/pipelines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/adapters/__init__.py b/pipelines/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/adapters/beir_base.py b/pipelines/adapters/beir_base.py new file mode 100644 index 0000000..055434c --- /dev/null +++ b/pipelines/adapters/beir_base.py @@ -0,0 +1,118 @@ +""" +Base adapter functionality for BEIR datasets. +""" +import os +from pathlib import Path +from typing import List, Dict, Any, Iterable +from beir.datasets.data_loader import GenericDataLoader + +from pipelines.contracts import BaseRow, DatasetAdapter, DatasetSplit +from langchain_core.documents import Document + + +class BeirBaseAdapter(DatasetAdapter): + """Base adapter for BEIR datasets.""" + + def __init__(self, dataset_path: str, dataset_name: str, version: str = "1.0.0"): + self.dataset_path = Path(dataset_path) + self.dataset_name = dataset_name + self._version = version + + if not self.dataset_path.exists(): + raise FileNotFoundError(f"Dataset not found at {self.dataset_path}") + + @property + def source_name(self) -> str: + return self.dataset_name + + @property + def version(self) -> str: + return self._version + + def _load_beir_data(self, split: DatasetSplit): + """Load BEIR dataset components.""" + split_name = "test" if split in [DatasetSplit.TEST, DatasetSplit.ALL] else split.value + + try: + corpus, queries, qrels = GenericDataLoader( + str(self.dataset_path) + ).load(split=split_name) + return corpus, queries, qrels + except Exception as e: + # Fallback for datasets without train/val splits + corpus, queries, qrels = GenericDataLoader( + str(self.dataset_path) + ).load(split="test") + return corpus, queries, qrels + + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + """Return evaluation queries with relevance judgments.""" + _, queries, qrels = self._load_beir_data(split) + + eval_queries = [] + for qid, query_text in queries.items(): + relevant_docs = list(qrels.get(qid, {}).keys()) if qid in qrels else [] + + eval_queries.append({ + "query_id": qid, + "query": query_text, + "relevant_doc_ids": relevant_docs, + "relevance_scores": qrels.get(qid, {}) + }) + + return eval_queries + + +class BeirRow(BaseRow): + """Row schema for BEIR datasets.""" + title: str = "" + text: str = "" + metadata: Dict[str, Any] = {} + + +class GenericBeirAdapter(BeirBaseAdapter): + """Generic adapter for any BEIR dataset.""" + + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[BeirRow]: + """Read corpus documents as rows.""" + corpus, _, _ = self._load_beir_data(split) + + for doc_id, content in corpus.items(): + yield BeirRow( + external_id=doc_id, + title=content.get("title", ""), + text=content.get("text", ""), + metadata=content.get("metadata", {}) + ) + + def to_documents(self, rows: List[BeirRow], split: DatasetSplit) -> List[Document]: + """Convert rows to LangChain Documents.""" + documents = [] + + for row in rows: + # Combine title and text + content_parts = [] + if row.title: + content_parts.append(row.title) + if row.text: + content_parts.append(row.text) + + full_text = ". ".join(content_parts).strip() + if not full_text: + continue + + metadata = { + "external_id": row.external_id, + "title": row.title, + "source": self.source_name, + "dataset_version": self.version, + "split": split.value, + **row.metadata + } + + documents.append(Document( + page_content=full_text, + metadata=metadata + )) + + return documents diff --git a/pipelines/adapters/energy_papers.py b/pipelines/adapters/energy_papers.py new file mode 100644 index 0000000..d10ae07 --- /dev/null +++ b/pipelines/adapters/energy_papers.py @@ -0,0 +1,175 @@ +""" +Adapter for energy research papers (PDF documents). +""" +import os +from pathlib import Path +from typing import List, Dict, Any, Iterable + +from pipelines.contracts import BaseRow, DatasetAdapter, DatasetSplit +from langchain_core.documents import Document + + +class EnergyPaperRow(BaseRow): + """Row schema for energy research papers.""" + title: str + file_path: str + content: str = "" + authors: List[str] = [] + abstract: str = "" + keywords: List[str] = [] + year: int = 0 + + class Config: + extra = "allow" + + +class EnergyPapersAdapter(DatasetAdapter): + """Adapter for energy research papers dataset.""" + + def __init__(self, papers_path: str, version: str = "1.0.0"): + self.papers_path = Path(papers_path) + self._version = version + + if not self.papers_path.exists(): + raise FileNotFoundError(f"Papers directory not found at {self.papers_path}") + + @property + def source_name(self) -> str: + return "energy_papers" + + @property + def version(self) -> str: + return self._version + + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[EnergyPaperRow]: + """Read PDF files from papers directory.""" + pdf_files = list(self.papers_path.glob("*.pdf")) + + # Simple split logic based on filename patterns or random split + total_files = len(pdf_files) + if split == DatasetSplit.TRAIN: + pdf_files = pdf_files[:int(0.7 * total_files)] + elif split == DatasetSplit.VALIDATION: + pdf_files = pdf_files[int(0.7 * total_files):int(0.85 * total_files)] + elif split == DatasetSplit.TEST: + pdf_files = pdf_files[int(0.85 * total_files):] + + for pdf_path in pdf_files: + try: + yield self._extract_paper_info(pdf_path) + except Exception as e: + print(f"Error processing {pdf_path}: {e}") + continue + + def _extract_paper_info(self, pdf_path: Path) -> EnergyPaperRow: + """Extract basic information from PDF file.""" + # Extract title from filename (clean it up) + title = pdf_path.stem + title = title.replace("_", " ").replace("-", " ") + # Remove common patterns like "v1", "v2", etc. + import re + title = re.sub(r'\s+v\d+.*$', '', title) + title = title.strip() + + # Try to extract more metadata if available + # For now, use basic file-based extraction + # In a real implementation, you'd use PyMuPDF, pdfplumber, etc. + + content = "" + authors = [] + abstract = "" + keywords = [] + year = 0 + + # Extract year from filename if present + year_match = re.search(r'20\d{2}', pdf_path.name) + if year_match: + year = int(year_match.group()) + + # For this example, we'll simulate content extraction + # In practice, you'd integrate with your existing PDF processing + try: + # Placeholder for actual PDF text extraction + # You could integrate with your existing PDF processors here + content = f"Content from {pdf_path.name} would be extracted here" + except Exception as e: + print(f"Could not extract content from {pdf_path}: {e}") + + return EnergyPaperRow( + external_id=pdf_path.stem, + title=title, + file_path=str(pdf_path), + content=content, + authors=authors, + abstract=abstract, + keywords=keywords, + year=year + ) + + def to_documents(self, rows: List[EnergyPaperRow], split: DatasetSplit) -> List[Document]: + """Convert paper rows to Documents.""" + documents = [] + + for row in rows: + # Create document content + content_parts = [] + if row.title: + content_parts.append(f"Title: {row.title}") + if row.abstract: + content_parts.append(f"Abstract: {row.abstract}") + if row.content: + content_parts.append(row.content) + + full_text = "\n\n".join(content_parts) + if not full_text.strip(): + continue + + metadata = { + "external_id": row.external_id, + "title": row.title, + "file_path": row.file_path, + "authors": row.authors, + "keywords": row.keywords, + "year": row.year, + "source": self.source_name, + "dataset_version": self.version, + "split": split.value, + "doc_type": "research_paper" + } + + documents.append(Document( + page_content=full_text, + metadata=metadata + )) + + return documents + + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + """Return evaluation queries for energy papers.""" + eval_queries = [] + + # Generate queries from paper titles and abstracts + common_energy_queries = [ + "renewable energy optimization", + "solar panel efficiency", + "wind turbine design", + "energy storage systems", + "smart grid technology", + "carbon emission reduction", + "energy management systems", + "power system reliability", + "sustainable energy development", + "energy efficiency improvements" + ] + + for i, query in enumerate(common_energy_queries): + # For energy papers, relevance would need to be determined + # by semantic similarity or keyword matching + eval_queries.append({ + "query_id": f"energy_query_{i}", + "query": query, + "relevant_doc_ids": [], # Would need manual annotation + "domain": "energy" + }) + + return eval_queries diff --git a/pipelines/adapters/natural_questions.py b/pipelines/adapters/natural_questions.py new file mode 100644 index 0000000..60207c0 --- /dev/null +++ b/pipelines/adapters/natural_questions.py @@ -0,0 +1,157 @@ +""" +Adapter for Natural Questions dataset. +""" +import json +from pathlib import Path +from typing import List, Dict, Any, Iterable + +from pipelines.contracts import BaseRow, DatasetAdapter, DatasetSplit +from langchain_core.documents import Document + + +class NaturalQuestionsRow(BaseRow): + """Row schema for Natural Questions dataset.""" + question: str + answer: str = "" + context: str = "" + long_answer: str = "" + short_answers: List[str] = [] + + class Config: + extra = "allow" + + +class NaturalQuestionsAdapter(DatasetAdapter): + """Adapter for Natural Questions dataset.""" + + def __init__(self, dataset_path: str, version: str = "1.0.0"): + self.dataset_path = Path(dataset_path) + self._version = version + + if not self.dataset_path.exists(): + raise FileNotFoundError(f"Natural Questions dataset not found at {self.dataset_path}") + + @property + def source_name(self) -> str: + return "natural_questions" + + @property + def version(self) -> str: + return self._version + + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[NaturalQuestionsRow]: + """Read Natural Questions rows from JSONL files.""" + # Common NQ file patterns + file_patterns = { + DatasetSplit.TRAIN: ["train*.jsonl", "nq-train-*.jsonl"], + DatasetSplit.VALIDATION: ["dev*.jsonl", "nq-dev-*.jsonl", "val*.jsonl"], + DatasetSplit.TEST: ["test*.jsonl", "nq-test-*.jsonl"] + } + + files_to_read = [] + if split == DatasetSplit.ALL: + for patterns in file_patterns.values(): + for pattern in patterns: + files_to_read.extend(self.dataset_path.glob(pattern)) + else: + for pattern in file_patterns.get(split, []): + files_to_read.extend(self.dataset_path.glob(pattern)) + + # Fallback: read any JSONL files + if not files_to_read: + files_to_read = list(self.dataset_path.glob("*.jsonl")) + + for file_path in files_to_read: + with open(file_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + try: + data = json.loads(line.strip()) + yield self._parse_nq_item(data, f"{file_path.name}:{line_num}") + except (json.JSONDecodeError, KeyError) as e: + print(f"Skipping malformed line in {file_path}:{line_num}: {e}") + continue + + def _parse_nq_item(self, data: Dict[str, Any], external_id: str) -> NaturalQuestionsRow: + """Parse a Natural Questions item.""" + # Handle different NQ formats + question = data.get("question", data.get("question_text", "")) + + # Extract answers - NQ has complex answer structures + short_answers = [] + long_answer = "" + + if "annotations" in data: + for annotation in data["annotations"]: + if "short_answers" in annotation: + for sa in annotation["short_answers"]: + if "text" in sa: + short_answers.append(sa["text"]) + + if "long_answer" in annotation and "candidate_text" in annotation["long_answer"]: + long_answer = annotation["long_answer"]["candidate_text"] + + # Fallback for simpler formats + if not short_answers and "answer" in data: + if isinstance(data["answer"], list): + short_answers = data["answer"] + else: + short_answers = [str(data["answer"])] + + context = data.get("document_text", data.get("context", "")) + + return NaturalQuestionsRow( + external_id=external_id, + question=question, + answer=short_answers[0] if short_answers else "", + context=context, + long_answer=long_answer, + short_answers=short_answers + ) + + def to_documents(self, rows: List[NaturalQuestionsRow], split: DatasetSplit) -> List[Document]: + """Convert NQ rows to Documents - treating contexts as retrievable documents.""" + documents = [] + + for row in rows: + if not row.context or not row.context.strip(): + continue + + # Create document from context + metadata = { + "external_id": row.external_id, + "question": row.question, + "answers": row.short_answers, + "long_answer": row.long_answer, + "source": self.source_name, + "dataset_version": self.version, + "split": split.value, + "doc_type": "context" + } + + documents.append(Document( + page_content=row.context, + metadata=metadata + )) + + return documents + + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + """Return evaluation queries for Natural Questions.""" + eval_queries = [] + + for row in self.read_rows(split): + if not row.question: + continue + + # For NQ, relevant docs are the contexts that contain answers + relevant_docs = [row.external_id] if row.context and row.short_answers else [] + + eval_queries.append({ + "query_id": row.external_id, + "query": row.question, + "relevant_doc_ids": relevant_docs, + "gold_answers": row.short_answers, + "long_answer": row.long_answer + }) + + return eval_queries diff --git a/pipelines/adapters/stackoverflow.py b/pipelines/adapters/stackoverflow.py new file mode 100644 index 0000000..c9e37fe --- /dev/null +++ b/pipelines/adapters/stackoverflow.py @@ -0,0 +1,444 @@ +""" +Adapter for Stack Overflow dataset (SOSum format). +Handles the SOSum dataset from: https://github.com/BonanKou/SOSum-A-Dataset-of-Extractive-Summaries-of-Stack-Overflow-Posts-and-labeling-tools +""" +import json +import csv +import ast +from pathlib import Path +from typing import List, Dict, Any, Iterable, Optional + +from pipelines.contracts import BaseRow, DatasetAdapter, DatasetSplit +from langchain_core.documents import Document + + +class StackOverflowRow(BaseRow): + """Row schema for Stack Overflow posts (SOSum format).""" + title: str + body: str # For questions: question body, for answers: answer body + tags: List[str] = [] + post_type: str = "question" # "question" or "answer" + # 1=conceptual, 2=how-to, 3=debug-corrective + question_type: Optional[int] = None + summary: Optional[str] = None # For answers: extractive summary + related_posts: List[str] = [] # For questions: answer post IDs + + class Config: + extra = "allow" + + +class StackOverflowAdapter(DatasetAdapter): + """Adapter for Stack Overflow dataset (SOSum format).""" + + def __init__(self, dataset_path: str, version: str = "1.0.0"): + self.dataset_path = Path(dataset_path) + self._version = version + + if not self.dataset_path.exists(): + raise FileNotFoundError( + f"SOSum dataset not found at {self.dataset_path}") + + # Check for SOSum format files + self.question_file = self.dataset_path / "question.csv" + self.answer_file = self.dataset_path / "answer.csv" + + # Also check for data subfolder (common in SOSum) + if not self.question_file.exists(): + data_dir = self.dataset_path / "data" + if data_dir.exists(): + self.question_file = data_dir / "question.csv" + self.answer_file = data_dir / "answer.csv" + + if not self.question_file.exists() or not self.answer_file.exists(): + raise FileNotFoundError( + f"SOSum format files not found. Expected question.csv and answer.csv in {self.dataset_path} or {self.dataset_path}/data/" + ) + + @property + def source_name(self) -> str: + return "stackoverflow_sosum" + + @property + def version(self) -> str: + return self._version + + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[StackOverflowRow]: + """Read SOSum Stack Overflow posts from CSV files.""" + # Read questions first + yield from self._read_questions() + + # Then read answers + yield from self._read_answers() + + def _read_questions(self) -> Iterable[StackOverflowRow]: + """Read questions from question.csv.""" + try: + with open(self.question_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + try: + yield self._parse_question_row(row, f"q{row_num}") + except Exception as e: + print(f"Error parsing question row {row_num}: {e}") + continue + except Exception as e: + print(f"Error reading questions file {self.question_file}: {e}") + + def _read_answers(self) -> Iterable[StackOverflowRow]: + """Read answers from answer.csv.""" + try: + with open(self.answer_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + try: + yield self._parse_answer_row(row, f"a{row_num}") + except Exception as e: + print(f"Error parsing answer row {row_num}: {e}") + continue + except Exception as e: + print(f"Error reading answers file {self.answer_file}: {e}") + + def _parse_question_row(self, row: Dict[str, str], external_id: str) -> StackOverflowRow: + """Parse a question row from question.csv.""" + question_id = row.get( + "Question Id", row.get("question_id", external_id)) + title = row.get("Question Title", row.get("question_title", "")) + body = row.get("Question Body", row.get("question_body", "")) + + # Parse question body if it's a list representation + if body.startswith('[') and body.endswith(']'): + try: + body_list = ast.literal_eval(body) + body = " ".join(body_list) if isinstance( + body_list, list) else body + except: + pass # Keep original if parsing fails + + # Parse tags + tags_str = row.get("Tags", row.get("tags", "")) + tags = [] + if tags_str: + # Handle different tag formats + if tags_str.startswith('[') and tags_str.endswith(']'): + try: + tags = ast.literal_eval(tags_str) + except: + tags = tags_str.strip('[]').replace("'", "").split(',') + else: + tags = tags_str.split(',') + tags = [tag.strip() for tag in tags if tag.strip()] + + # Parse question type + question_type = None + type_str = row.get("Question Type", row.get("question_type")) + if type_str: + try: + question_type = int(type_str) + except: + pass + + # Parse related answer posts + answer_posts = [] + answer_posts_str = row.get("Answer Posts", row.get("answer_posts", "")) + if answer_posts_str: + # Handle list format like [315365] or [315365, 123456] + if answer_posts_str.startswith('[') and answer_posts_str.endswith(']'): + try: + answer_posts_list = ast.literal_eval(answer_posts_str) + answer_posts = [str(post).strip() + for post in answer_posts_list if str(post).strip()] + except: + # Fallback to manual parsing + clean_str = answer_posts_str.strip('[]') + answer_posts = [post.strip() + for post in clean_str.split(',') if post.strip()] + else: + answer_posts = [ + post.strip() for post in answer_posts_str.split(',') if post.strip()] + + return StackOverflowRow( + external_id=f"q_{question_id}", + title=title, + body=body, + tags=tags, + post_type="question", + question_type=question_type, + related_posts=answer_posts + ) + + def _parse_answer_row(self, row: Dict[str, str], external_id: str) -> StackOverflowRow: + """Parse an answer row from answer.csv.""" + answer_id = row.get("Answer Id", row.get("answer_id", external_id)) + body = row.get("Answer Body", row.get("answer_body", "")) + summary = row.get("Summary", row.get("summary", "")) + + # Parse answer body if it's a list representation + if body.startswith('[') and body.endswith(']'): + try: + body_list = ast.literal_eval(body) + body = " ".join(body_list) if isinstance( + body_list, list) else body + except: + pass + + # Parse summary if it's a list representation + if summary.startswith('[') and summary.endswith(']'): + try: + summary_list = ast.literal_eval(summary) + summary = " ".join(summary_list) if isinstance( + summary_list, list) else summary + except: + pass + + return StackOverflowRow( + external_id=f"a_{answer_id}", + title="", # Answers don't have titles + body=body, + tags=[], # Answers inherit tags from questions + post_type="answer", + summary=summary + ) + + def to_documents(self, rows: Iterable[StackOverflowRow], split: DatasetSplit = DatasetSplit.ALL) -> List[Document]: + """Convert SOSum rows to LangChain documents. + + RAG-optimized approach: Only answers are ingested as retrievable documents. + Questions are stored as metadata/context for their corresponding answers. + This ensures users retrieve valuable answers, not just questions. + """ + documents = [] + questions_map = {} # Map question_id -> question data + + # First pass: collect all rows and build questions map + all_rows = list(rows) + + # Build questions map for context + for row in all_rows: + if row.post_type == "question": + questions_map[row.external_id] = row + + # Second pass: create documents only from answers, with question context + for row in all_rows: + if row.post_type != "answer": + continue # Only process answers as primary documents + + # Skip empty answers + if not row.body.strip(): + continue + + # Build answer content + answer_content = row.body.strip() + if row.summary and row.summary.strip(): + answer_content = f"{answer_content}\n\n[Summary: {row.summary.strip()}]" + + # Find the corresponding question for context + question_context = None + question_title = "" + question_tags = [] + + # Extract the numeric answer ID (remove 'a_' prefix) + answer_id_num = row.external_id.replace('a_', '') + + # Try to find matching question by looking for related posts + for q_id, question in questions_map.items(): + # Check if this answer ID is in the question's related_posts + if answer_id_num in question.related_posts: + question_context = question.body + question_title = question.title + question_tags = question.tags + break + + # If no direct link found, try to find question with similar ID + # This is a fallback heuristic for datasets where linking isn't perfect + if not question_context: + for q_id, question in questions_map.items(): + question_id_num = q_id.replace('q_', '') + if question_id_num == answer_id_num: + question_context = question.body + question_title = question.title + question_tags = question.tags + break + + # Create the final document content + # The answer is the primary content, question provides context + content_parts = [] + if question_title: + content_parts.append(f"Q: {question_title}") + if question_context and question_context.strip(): + content_parts.append( + f"Question Details: {question_context.strip()}") + content_parts.append(f"Answer: {answer_content}") + + content = "\n\n".join(content_parts) + + metadata = { + "external_id": row.external_id, + "source": self.source_name, + "post_type": "answer", # Always answer since we only ingest answers + "doc_type": "answer", # Always answer + "tags": question_tags, # Use question tags + "title": question_title if question_title else None, + "split": split.value, + "answer_body": row.body, # Store pure answer separately + } + + # Add answer-specific metadata + if row.summary: + metadata["summary"] = row.summary + metadata["has_summary"] = True + + # Add question context as metadata + if question_context: + metadata["question_context"] = question_context + metadata["has_question_context"] = True + + # Remove None values + metadata = {k: v for k, v in metadata.items() if v is not None} + + documents.append(Document( + page_content=content, + metadata=metadata + )) + + return documents + + def get_evaluation_queries(self) -> List[Dict[str, Any]]: + """Get evaluation queries optimized for answer retrieval. + + Creates queries based on question titles/content that should retrieve + the corresponding answers, not the questions themselves. + """ + evaluation_queries = [] + + # Build a map of questions to their related answers + question_to_answers = {} + answer_ids = set() + + # First, collect all answer IDs + try: + with open(self.answer_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + answer_id = row.get("Answer Id", row.get( + "answer_id", f"a{row_num}")) + answer_ids.add(f"a_{answer_id}") + except Exception as e: + print(f"Error reading answer IDs: {e}") + + # Read questions and create queries that should retrieve answers + try: + with open(self.question_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + title = row.get("Question Title", + row.get("question_title", "")) + body = row.get("Question Body", + row.get("question_body", "")) + question_id = row.get("Question Id", row.get( + "question_id", f"q{row_num}")) + + # Parse related answer posts + answer_posts_str = row.get( + "Answer Posts", row.get("answer_posts", "")) + related_answers = [] + if answer_posts_str: + related_answers = [ + f"a_{post.strip()}" for post in answer_posts_str.split(',') if post.strip()] + # Filter to only include answers that actually exist + related_answers = [ + aid for aid in related_answers if aid in answer_ids] + + # If no explicit related answers, try to infer by ID similarity + if not related_answers: + potential_answer = f"a_{question_id}" + if potential_answer in answer_ids: + related_answers = [potential_answer] + + # Only create queries if we have answers to retrieve + if related_answers and title and len(title) > 10: + # Query from question title - should retrieve answers + evaluation_queries.append({ + "query": title, + "expected_docs": related_answers, + "query_type": "question_title_to_answer", + "query_id": f"eval_q2a_{question_id}", + "description": f"Question title should retrieve answer(s)" + }) + + # Query from shortened title + short_query = " ".join(title.split()[:6]) + if len(short_query) > 10: + evaluation_queries.append({ + "query": short_query, + "expected_docs": related_answers, + "query_type": "question_short_to_answer", + "query_id": f"eval_q2a_short_{question_id}", + "description": f"Short question should retrieve answer(s)" + }) + + # Query from question body (first sentence) + if body and len(body) > 20: + # Parse body if it's a list + if body.startswith('[') and body.endswith(']'): + try: + body_list = ast.literal_eval(body) + body = " ".join(body_list) if isinstance( + body_list, list) else body + except: + pass + + # Take first sentence or first 100 chars + first_sentence = body.split( + '.')[0] if '.' in body else body[:100] + if len(first_sentence) > 20: + evaluation_queries.append({ + "query": first_sentence, + "expected_docs": related_answers, + "query_type": "question_body_to_answer", + "query_id": f"eval_qbody2a_{question_id}", + "description": f"Question body should retrieve answer(s)" + }) + + # Limit to reasonable number for testing + if len(evaluation_queries) >= 60: + break + except Exception as e: + print(f"Error creating evaluation queries from questions: {e}") + + # Add queries based on answer summaries - these should retrieve the same answers + try: + with open(self.answer_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + summary = row.get("Summary", row.get("summary", "")) + answer_id = row.get("Answer Id", row.get( + "answer_id", f"a{row_num}")) + expected_answer_id = f"a_{answer_id}" + + if summary and len(summary) > 20 and expected_answer_id in answer_ids: + # Parse summary if it's a list + if summary.startswith('[') and summary.endswith(']'): + try: + summary_list = ast.literal_eval(summary) + # Use first 2 sentences + summary = " ".join(summary_list[:2]) + except: + # Fallback to first 150 chars + summary = summary[:150] + + if len(summary) > 15: + evaluation_queries.append({ + "query": summary, + "expected_docs": [expected_answer_id], + "query_type": "answer_summary_to_answer", + "query_id": f"eval_a2a_{answer_id}", + "description": f"Answer summary should retrieve the full answer" + }) + + # Limit total queries + if len(evaluation_queries) >= 100: + break + except Exception as e: + print(f"Error creating evaluation queries from answers: {e}") + + return evaluation_queries[:100] # Return max 100 queries diff --git a/pipelines/configs/README.md b/pipelines/configs/README.md new file mode 100644 index 0000000..698db4b --- /dev/null +++ b/pipelines/configs/README.md @@ -0,0 +1,156 @@ +# Pipeline Configuration Directory + +This directory contains all configuration files for different pipeline components, organized by purpose and usage. + +## 📁 Directory Structure + +``` +pipelines/configs/ +├── README.md # This documentation +├── retriever_config_loader.py # Configuration loading utilities +├── datasets/ # Dataset-specific pipeline configurations +├── retrieval/ # Agent retrieval configurations +├── examples/ # Example configurations and templates +└── legacy/ # Deprecated/old configurations +``` + +## 🗂️ Configuration Categories + +### 📊 `datasets/` - Dataset Pipeline Configurations +Contains configurations for processing and ingesting different datasets. + +- `stackoverflow.yml` - Main SOSum Stack Overflow dataset configuration +- `stackoverflow_hybrid.yml` - Hybrid embedding variant for Stack Overflow +- `natural_questions.yml` - Google Natural Questions dataset configuration +- `energy_papers.yml` - Energy papers dataset configuration + +**Purpose**: Data ingestion, chunking, embedding, and indexing pipelines. + +### 🤖 `retrieval/` - Agent Retrieval Configurations +Contains configurations for the agent's retrieval system. + +- `modern_hybrid.yml` - Advanced hybrid retrieval with RRF fusion and reranking +- `modern_dense.yml` - Dense retrieval with neural reranking +- `fast_hybrid.yml` - Speed-optimized hybrid retrieval + +**Purpose**: Agent question-answering and document retrieval. + +### 📚 `examples/` - Example Configurations +Contains template and example configuration files. + +- `batch_example.json` - Example batch processing configuration + +**Purpose**: Templates for creating new configurations. + +### 🗄️ `legacy/` - Deprecated Configurations +Contains older configuration files kept for compatibility. + +- `stackoverflow_bge_large.yml` - BGE large embedding configuration +- `stackoverflow_e5_large.yml` - E5 large embedding configuration +- `stackoverflow_minilm.yml` - MiniLM embedding configuration + +**Purpose**: Backward compatibility and reference. + +## 🚀 Usage + +### For Dataset Processing +```python +# Load dataset configuration +from pipelines.configs.retriever_config_loader import load_config +config = load_config("datasets/stackoverflow.yml") +``` + +### For Agent Retrieval +```bash +# Switch agent retrieval configuration +python bin/switch_agent_config.py modern_hybrid +``` + +### For Benchmarking +```python +# Use different retrieval configurations +configs = [ + "retrieval/modern_hybrid.yml", + "retrieval/modern_dense.yml", + "retrieval/fast_hybrid.yml" +] +``` + +## 🔧 Configuration Types + +### Dataset Configuration Schema +```yaml +dataset: + name: "dataset_name" + version: "1.0.0" + description: "Dataset description" + +embedding: + strategy: "hybrid" # dense, sparse, or hybrid + dense: + provider: "provider_name" + model: "model_name" + sparse: + provider: "provider_name" + model: "model_name" + +chunking: + strategy: "recursive" + chunk_size: 512 + chunk_overlap: 50 +``` + +### Agent Retrieval Configuration Schema +```yaml +description: "Configuration description" + +retrieval_pipeline: + retriever: + type: "hybrid" # dense, sparse, or hybrid + top_k: 20 + fusion_method: "rrf" + + stages: + - type: "retriever" + - type: "score_filter" + config: + min_score: 0.01 + - type: "reranker" + config: + model_type: "cross_encoder" +``` + +## 📝 Best Practices + +1. **Naming Convention**: + - Dataset configs: `{dataset_name}.yml` + - Retrieval configs: `{strategy}_{variant}.yml` + - Legacy configs: `{original_name}.yml` (in legacy/) + +2. **Documentation**: + - Include `description` field in all configs + - Add comments for complex parameters + - Document expected performance characteristics + +3. **Version Control**: + - Keep legacy configs for reproducibility + - Version dataset configurations + - Test new configs before deployment + +## 🔍 Finding Configurations + +- **For data processing**: Look in `datasets/` +- **For agent retrieval**: Look in `retrieval/` +- **For examples/templates**: Look in `examples/` +- **For old/deprecated**: Look in `legacy/` + +## 🛠️ Maintenance + +- Regularly review and clean up unused configurations +- Move outdated configs to `legacy/` instead of deleting +- Update documentation when adding new configuration types +- Test configuration loading after structural changes + +--- + +*Last updated: August 30, 2025* diff --git a/pipelines/configs/__init__.py b/pipelines/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/configs/datasets/energy_papers.yml b/pipelines/configs/datasets/energy_papers.yml new file mode 100644 index 0000000..90d6655 --- /dev/null +++ b/pipelines/configs/datasets/energy_papers.yml @@ -0,0 +1,81 @@ +# Energy Research Papers Configuration +dataset: + name: "energy_papers" + version: "1.0.0" + description: "Energy research papers collection" + +# Embedding strategy +embedding_strategy: hybrid + +embedding: + dense: + provider: hf + model_name: sentence-transformers/all-MiniLM-L6-v2 + batch_size: 16 # Smaller batches for academic content + device: cuda + vector_name: dense + sparse: + provider: fastembed + model_name: Qdrant/bm25 + vector_name: sparse + +# Chunking configuration +chunking: + strategy: semantic # Best for academic papers + chunk_size: 600 # Larger chunks for academic context + chunk_overlap: 100 + max_chunk_size: 1000 + sentence_overlap: 2 # More overlap for academic continuity + +# Validation settings +validation: + min_char_length: 100 + max_char_length: 50000 # Allow very long content for papers + remove_duplicates: true + clean_html: false # Academic papers may have minimal HTML + preserve_code_blocks: false + +# Retriever configuration +retriever: + type: qdrant + top_k: 10 + +# Qdrant settings +qdrant: + collection: energy_papers_hybrid_v1 + dense_vector_name: dense + sparse_vector_name: sparse + +# Upload settings +upload: + batch_size: 25 # Smaller batches for large academic content + wait: true + versioning: true + +# Evaluation settings +evaluation: + k_values: [1, 3, 5, 10] + similarity_threshold: 0.75 + +# Smoke tests +smoke_tests: + min_success_rate: 0.7 # Lower threshold for specialized domain + golden_queries: + - query: "renewable energy optimization" + min_recall: 0.1 + - query: "solar panel efficiency" + min_recall: 0.1 + - query: "wind turbine design" + min_recall: 0.1 + - query: "energy storage systems" + min_recall: 0.1 + - query: "smart grid technology" + min_recall: 0.1 + +# Output configuration +output_dir: "output/energy_papers" + +# Embedding cache +embedding_cache: + enabled: true + dir: "cache/embeddings/energy_papers" diff --git a/pipelines/configs/datasets/natural_questions.yml b/pipelines/configs/datasets/natural_questions.yml new file mode 100644 index 0000000..51219a1 --- /dev/null +++ b/pipelines/configs/datasets/natural_questions.yml @@ -0,0 +1,79 @@ +# Natural Questions Dataset Configuration +dataset: + name: "natural_questions" + version: "1.0.0" + description: "Google Natural Questions dataset" + +# Embedding strategy: dense, sparse, or hybrid +embedding_strategy: hybrid + +embedding: + dense: + provider: hf + model_name: sentence-transformers/all-MiniLM-L6-v2 + batch_size: 32 + device: cuda + vector_name: dense + sparse: + provider: fastembed + model_name: Qdrant/bm25 + vector_name: sparse + +# Chunking configuration +chunking: + strategy: semantic # Best for Q&A content + chunk_size: 400 # Smaller chunks for precise Q&A + chunk_overlap: 50 + max_chunk_size: 600 + sentence_overlap: 1 + +# Validation settings +validation: + min_char_length: 30 + max_char_length: 10000 + remove_duplicates: true + clean_html: true + preserve_code_blocks: false + +# Retriever configuration +retriever: + type: qdrant + top_k: 10 + +# Qdrant settings +qdrant: + collection: nq_hybrid_v1 + dense_vector_name: dense + sparse_vector_name: sparse + +# Upload settings +upload: + batch_size: 100 + wait: true + versioning: true + +# Evaluation settings +evaluation: + k_values: [1, 3, 5, 10, 20] + similarity_threshold: 0.8 + semantic_matching: true + +# Smoke tests +smoke_tests: + min_success_rate: 0.8 + min_overall_success: 0.7 + golden_queries: + - query: "What is the capital of France?" + relevant_doc_ids: [] + min_recall: 0.1 + - query: "How does photosynthesis work?" + relevant_doc_ids: [] + min_recall: 0.1 + +# Output configuration +output_dir: "output/natural_questions" + +# Embedding cache +embedding_cache: + enabled: true + dir: "cache/embeddings/nq" diff --git a/pipelines/configs/datasets/stackoverflow.yml b/pipelines/configs/datasets/stackoverflow.yml new file mode 100644 index 0000000..9b56e33 --- /dev/null +++ b/pipelines/configs/datasets/stackoverflow.yml @@ -0,0 +1,81 @@ +# SOSum Stack Overflow Dataset Configuration +# Dataset: https://github.com/BonanKou/SOSum-A-Dataset-of-Extractive-Summaries-of-Stack-Overflow-Posts-and-labeling-tools +dataset: + name: "stackoverflow_sosum" + version: "1.0.0" + description: "SOSum: Extractive summaries of Stack Overflow posts (506 questions, 2278 posts)" + +# Embedding strategy +embedding_strategy: hybrid + +embedding: + dense: + provider: hf + model_name: sentence-transformers/all-MiniLM-L6-v2 + batch_size: 32 + device: cuda + vector_name: dense + sparse: + provider: fastembed + model_name: Qdrant/bm25 + vector_name: sparse + +# Chunking configuration +chunking: + strategy: code_aware # Best for code-heavy content + chunk_size: 800 # Larger chunks for code context + chunk_overlap: 100 + preserve_functions: true + preserve_code_blocks: true + +# Validation settings +validation: + min_char_length: 30 # SOSum has shorter summaries + max_char_length: 50000 # Allow very long content for complex SO posts + remove_duplicates: true + clean_html: true + preserve_code_blocks: true + allowed_languages: ["en"] + +# Retriever configuration +retriever: + type: qdrant + top_k: 15 + +# Qdrant settings +qdrant: + collection: sosum_stackoverflow_v1 + dense_vector_name: dense + sparse_vector_name: sparse + +# Upload settings +upload: + batch_size: 25 # Smaller batches for potentially large posts + wait: true + versioning: true + +# Evaluation settings +evaluation: + k_values: [1, 3, 5, 10, 15] + similarity_threshold: 0.7 + +# Smoke tests +smoke_tests: + min_success_rate: 0.8 + golden_queries: + - query: "Python list comprehension example" + min_recall: 0.1 + - query: "JavaScript async function" + min_recall: 0.1 + - query: "How to solve error in code" + min_recall: 0.1 + - query: "Best practice programming" + min_recall: 0.1 + +# Output configuration +output_dir: "output/sosum_stackoverflow" + +# Embedding cache +embedding_cache: + enabled: true + dir: "cache/embeddings/sosum_stackoverflow" diff --git a/pipelines/configs/datasets/stackoverflow_hybrid.yml b/pipelines/configs/datasets/stackoverflow_hybrid.yml new file mode 100644 index 0000000..6361a04 --- /dev/null +++ b/pipelines/configs/datasets/stackoverflow_hybrid.yml @@ -0,0 +1,44 @@ +# Configuration for Hybrid Dense + Sparse Embeddings + +dataset: + name: "stackoverflow_sosum" + version: "v1.0.0" + adapter: "stackoverflow" # REQUIRED: This was missing! + +chunking: + strategy: "recursive" # FIXED: Strategy name + chunk_size: 512 + chunk_overlap: 50 + separators: ["\n\n", "\n", " ", ""] # REQUIRED for recursive chunking + +embedding: + strategy: "hybrid" # FIXED: Moved from top level + dense: + provider: "google" + model: "models/embedding-001" # FIXED: changed from model_name to model + batch_size: 32 + sparse: + provider: "sparse" + model: "Qdrant/bm25" # FIXED: changed from model_name to model + batch_size: 32 + +qdrant: + collection: "sosum_stackoverflow_hybrid_v1" + dense_vector_name: "dense" + sparse_vector_name: "sparse" + distance_metric: "cosine" # REQUIRED: Added missing field + +upload: + batch_size: 50 + wait: true + versioning: true # ADDED: For proper versioning + +validation: + enabled: true # REQUIRED: Added missing field + max_text_length: 10000 # FIXED: changed from max_char_length + min_text_length: 10 # FIXED: changed from min_char_length + +smoke_tests: + enabled: true + sample_size: 5 # REQUIRED: Added missing field + min_success_rate: 0.7 diff --git a/pipelines/configs/examples/dataset_template.yml b/pipelines/configs/examples/dataset_template.yml new file mode 100644 index 0000000..e9ad91b --- /dev/null +++ b/pipelines/configs/examples/dataset_template.yml @@ -0,0 +1,95 @@ +# Template for Dataset Pipeline Configuration +# Copy this file and modify for your specific dataset + +dataset: + name: "your_dataset_name" + version: "1.0.0" + description: "Description of your dataset" + adapter: "adapter_name" # Required: specify which adapter to use + +# Chunking strategy for document processing +chunking: + strategy: "recursive" # Options: recursive, fixed, semantic + chunk_size: 512 # Size of each chunk in tokens/characters + chunk_overlap: 50 # Overlap between chunks + separators: ["\n\n", "\n", " ", ""] # For recursive chunking + +# Embedding configuration +embedding: + strategy: "hybrid" # Options: dense, sparse, hybrid + + # Dense embeddings (semantic similarity) + dense: + provider: "google" # Options: google, hf, openai + model: "models/embedding-001" # Model identifier + dimensions: 768 # Embedding dimensions + batch_size: 32 # Batch size for processing + vector_name: "dense" # Vector field name in database + + # Sparse embeddings (keyword matching) + sparse: + provider: "sparse" # Usually "sparse" or "fastembed" + model: "Qdrant/bm25" # Sparse model identifier + vector_name: "sparse" # Vector field name in database + +# Database configuration +qdrant: + collection_name: "your_collection_name" + host: "localhost" + port: 6333 + + # Vector configuration + dense_vector_name: "dense" + sparse_vector_name: "sparse" + + # Performance settings + shard_number: 1 + replication_factor: 1 + +# Processing pipeline stages +pipeline: + # Document loading and parsing + loader: + type: "custom" # Depends on your data format + + # Text processing and cleaning + processor: + clean_text: true + remove_special_chars: false + normalize_whitespace: true + + # Chunking and splitting + chunker: + use_overlap: true + preserve_context: true + + # Embedding generation + embedder: + parallel_processing: true + cache_embeddings: true + + # Database insertion + indexer: + batch_size: 100 + create_if_not_exists: true + +# Performance and optimization +performance: + lazy_initialization: true + enable_caching: true + cache_ttl: 3600 # Cache time-to-live in seconds + parallel_workers: 4 # Number of parallel workers + memory_limit: "2GB" # Memory limit for processing + +# Logging and monitoring +logging: + level: "INFO" # DEBUG, INFO, WARNING, ERROR + log_embeddings: false # Log embedding generation details + log_performance: true # Log performance metrics + +# Validation and quality checks +validation: + check_duplicates: true + min_chunk_size: 10 # Minimum chunk size in characters + max_chunk_size: 2048 # Maximum chunk size in characters + validate_embeddings: true # Validate embedding generation diff --git a/pipelines/configs/examples/retrieval_template.yml b/pipelines/configs/examples/retrieval_template.yml new file mode 100644 index 0000000..5d4bff6 --- /dev/null +++ b/pipelines/configs/examples/retrieval_template.yml @@ -0,0 +1,118 @@ +# Template for Agent Retrieval Configuration +# Copy this file and modify for your specific retrieval needs + +description: "Template for agent retrieval configuration" + +# Main retrieval pipeline configuration +retrieval_pipeline: + # Primary retriever configuration + retriever: + type: "hybrid" # Options: dense, sparse, hybrid + top_k: 20 # Number of candidates to retrieve + score_threshold: 0.01 # Minimum score threshold + + # For hybrid retrieval + fusion_method: "rrf" # Options: rrf, weighted_sum + + # Fusion parameters + fusion: + method: "rrf" # Reciprocal rank fusion + rrf_k: 60 # RRF parameter (higher = more democratic) + dense_weight: 0.7 # Weight for dense results (weighted_sum) + sparse_weight: 0.3 # Weight for sparse results (weighted_sum) + + # Embedding configuration + embedding: + strategy: "hybrid" + dense: + provider: "google" + model: "models/embedding-001" + dimensions: 768 + api_key_env: "GOOGLE_API_KEY" + batch_size: 32 + vector_name: "dense" + sparse: + provider: "sparse" + model: "Qdrant/bm25" + vector_name: "sparse" + + # Database configuration + qdrant: + collection_name: "your_collection_name" + dense_vector_name: "dense" + sparse_vector_name: "sparse" + + # Performance settings + performance: + lazy_initialization: true + batch_size: 32 + enable_caching: true + parallel_search: false + + # Processing stages (ordered pipeline) + stages: + # Stage 1: Primary retrieval (automatically added) + - type: "retriever" + name: "primary_retriever" + + # Stage 2: Score filtering + - type: "score_filter" + name: "score_filter" + config: + min_score: 0.01 # Filter out low-quality results + max_results: 15 # Limit results for efficiency + + # Stage 3: Neural reranking + - type: "reranker" + name: "neural_reranker" + config: + model_type: "cross_encoder" + model_name: "cross-encoder/ms-marco-MiniLM-L-6-v2" + top_k: 10 # Final number of results + batch_size: 16 # Reranking batch size + +# Global configuration (inherited by all components) +embedding_strategy: "hybrid" + +# Quality and validation settings +quality: + min_relevance_score: 0.1 # Minimum relevance threshold + diversity_threshold: 0.8 # Avoid too similar results + +# Performance optimization +optimization: + cache_embeddings: true + cache_ttl: 1800 # 30 minutes + prefetch_size: 50 # Prefetch candidates + +# Monitoring and logging +monitoring: + log_retrieval_stats: true + log_pipeline_timing: true + track_query_patterns: false + +# Advanced settings +advanced: + # Query expansion + query_expansion: + enabled: false + method: "synonyms" # Options: synonyms, embeddings + max_expansions: 3 + + # Result diversification + diversification: + enabled: false + method: "mmr" # Maximal Marginal Relevance + lambda_param: 0.7 # Balance between relevance and diversity + + # Custom scoring + custom_scoring: + enabled: false + boost_recent: false # Boost more recent documents + boost_popular: false # Boost popular documents + +# Error handling +error_handling: + fallback_to_dense: true # If hybrid fails, use dense only + max_retries: 3 # Retry attempts for failed queries + timeout_seconds: 30 # Query timeout diff --git a/pipelines/configs/legacy/stackoverflow_bge_large.yml b/pipelines/configs/legacy/stackoverflow_bge_large.yml new file mode 100644 index 0000000..f855c65 --- /dev/null +++ b/pipelines/configs/legacy/stackoverflow_bge_large.yml @@ -0,0 +1,21 @@ +# Configuration for BGE Large +dataset: + name: "stackoverflow_sosum" + version: "1.0.0" + +embedding_strategy: hybrid + +embedding: + dense: + provider: hf + model_name: BAAI/bge-large-en-v1.5 + batch_size: 16 # Smaller batch for larger model + +qdrant: + collection: sosum_stackoverflow_bge_large_v1 # Different collection + dense_vector_name: dense + sparse_vector_name: sparse + +output_dir: "output/sosum_bge_large" +embedding_cache: + dir: "cache/embeddings/sosum_bge_large" diff --git a/pipelines/configs/legacy/stackoverflow_e5_large.yml b/pipelines/configs/legacy/stackoverflow_e5_large.yml new file mode 100644 index 0000000..4100eac --- /dev/null +++ b/pipelines/configs/legacy/stackoverflow_e5_large.yml @@ -0,0 +1,20 @@ +# Configuration for E5 Large +dataset: + name: "stackoverflow_sosum" + version: "1.0.0" + +embedding_strategy: dense # Dense only for comparison + +embedding: + dense: + provider: hf + model_name: intfloat/e5-large-v2 + batch_size: 8 # Even smaller batch + +qdrant: + collection: sosum_stackoverflow_e5_large_v1 # Another collection + dense_vector_name: dense + +output_dir: "output/sosum_e5_large" +embedding_cache: + dir: "cache/embeddings/sosum_e5_large" diff --git a/pipelines/configs/legacy/stackoverflow_minilm.yml b/pipelines/configs/legacy/stackoverflow_minilm.yml new file mode 100644 index 0000000..9f1edb6 --- /dev/null +++ b/pipelines/configs/legacy/stackoverflow_minilm.yml @@ -0,0 +1,21 @@ +# Configuration for Sentence Transformers +dataset: + name: "stackoverflow_sosum" + version: "1.0.0" + +embedding_strategy: dense + +embedding: + dense: + provider: hf + model_name: sentence-transformers/all-MiniLM-L6-v2 + batch_size: 32 + +qdrant: + collection: sosum_stackoverflow_minilm_v1 # Unique collection name + dense_vector_name: dense + sparse_vector_name: sparse + +output_dir: "output/sosum_minilm" +embedding_cache: + dir: "cache/embeddings/sosum_minilm" diff --git a/pipelines/configs/retrieval/ci_google_gemini.yml b/pipelines/configs/retrieval/ci_google_gemini.yml new file mode 100644 index 0000000..a224f16 --- /dev/null +++ b/pipelines/configs/retrieval/ci_google_gemini.yml @@ -0,0 +1,47 @@ +description: "Google Gemini only retrieval config for CI testing" + +retrieval_pipeline: + retriever: + type: "dense" + top_k: 5 + score_threshold: 0.1 + + # Embedding config needs to be inside the retriever config + embedding: + strategy: dense + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 16 + vector_name: dense + + # Qdrant config needs to be inside the retriever config + qdrant: + collection_name: test_ci_collection + dense_vector_name: dense + host: localhost + port: 6333 + + performance: + lazy_initialization: true + batch_size: 16 + enable_caching: false + + stages: + - type: retriever + name: primary_retriever + - type: score_filter + name: score_filter + config: + min_score: 0.1 + max_results: 5 + +# Global configs for backward compatibility +embedding_strategy: dense + +qdrant: + collection: test_ci_collection + host: localhost + port: 6333 diff --git a/pipelines/configs/retrieval/fast_hybrid.yml b/pipelines/configs/retrieval/fast_hybrid.yml new file mode 100644 index 0000000..82f2ba6 --- /dev/null +++ b/pipelines/configs/retrieval/fast_hybrid.yml @@ -0,0 +1,74 @@ +# High-Performance Retrieval Configuration for Agent +# Optimized for speed with minimal reranking + +description: "Fast hybrid retrieval optimized for agent response speed" + +# Retrieval Pipeline Configuration +retrieval_pipeline: + retriever: + type: hybrid + top_k: 10 # Fewer candidates for speed + score_threshold: 0.05 # Higher threshold for speed + fusion_method: rrf + + # Fusion configuration + fusion: + method: rrf + rrf_k: 50 # Slightly more aggressive ranking + dense_weight: 0.8 # Favor dense for speed + sparse_weight: 0.2 + + # Embedding configuration + embedding: + strategy: hybrid + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 16 # Smaller batches for faster response + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + + # Database configuration + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + dense_vector_name: dense + sparse_vector_name: sparse + + # Performance settings + performance: + lazy_initialization: true + batch_size: 16 + enable_caching: true + parallel_search: true # Enable for speed + + # Pipeline stages (minimal for speed) + stages: + - type: retriever + name: hybrid_retriever + config: + retriever_type: hybrid + + - type: score_filter + name: score_filter + config: + min_score: 0.01 # Lower threshold for better results + max_results: 8 + + - type: reranker + name: cross_encoder_reranker + config: + model_type: cross_encoder + model_name: cross-encoder/ms-marco-TinyBERT-L-2-v2 # Faster model + top_k: 5 # Fewer final results for speed + batch_size: 8 + +# Agent-specific settings +agent: + max_context_length: 6000 # Shorter for faster processing + include_metadata: false # Less data to process + response_format: concise diff --git a/pipelines/configs/retrieval/modern_dense.yml b/pipelines/configs/retrieval/modern_dense.yml new file mode 100644 index 0000000..da95bf2 --- /dev/null +++ b/pipelines/configs/retrieval/modern_dense.yml @@ -0,0 +1,58 @@ +# Modern Dense Retrieval Configuration for Agent +# Uses the improved dense retriever with neural reranking + +description: "Dense semantic retrieval with Google embeddings and neural reranking" + +# Retrieval Pipeline Configuration +retrieval_pipeline: + retriever: + type: dense + top_k: 15 # Get candidates for reranking + score_threshold: 0.0 + + # Embedding configuration + embedding: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + + # Database configuration + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + vector_name: dense + + # Performance settings + performance: + lazy_initialization: true + batch_size: 32 + enable_caching: true + + # Pipeline stages (ordered processing) + stages: + - type: retriever + name: dense_retriever + config: + retriever_type: dense + + - type: score_filter + name: score_filter + config: + min_score: 0.0 + max_results: 12 + + - type: reranker + name: cross_encoder_reranker + config: + model_type: cross_encoder + model_name: cross-encoder/ms-marco-MiniLM-L-6-v2 + top_k: 10 # Final number of results + batch_size: 16 + +# Agent-specific settings +agent: + max_context_length: 8000 + include_metadata: true + response_format: detailed diff --git a/pipelines/configs/retrieval/modern_hybrid.yml b/pipelines/configs/retrieval/modern_hybrid.yml new file mode 100644 index 0000000..0aa7ec4 --- /dev/null +++ b/pipelines/configs/retrieval/modern_hybrid.yml @@ -0,0 +1,74 @@ +# Modern Hybrid Retrieval Configuration for Agent +# Uses the improved hybrid retriever with RRF fusion and CrossEncoder reranking + +description: "Advanced hybrid retrieval with dense+sparse fusion and neural reranking" + +# Retrieval Pipeline Configuration +retrieval_pipeline: + retriever: + type: hybrid + top_k: 20 # Get more candidates for reranking + score_threshold: 0.01 # Low threshold for RRF compatibility + fusion_method: rrf + + # Fusion configuration + fusion: + method: rrf + rrf_k: 60 # Standard RRF parameter + dense_weight: 0.7 # For weighted_sum fallback + sparse_weight: 0.3 + + # Embedding configuration + embedding: + strategy: hybrid + dense: + provider: google + model: models/embedding-001 + dimensions: 768 + api_key_env: GOOGLE_API_KEY + batch_size: 32 + vector_name: dense + sparse: + provider: sparse + model: Qdrant/bm25 + vector_name: sparse + + # Database configuration + qdrant: + collection_name: sosum_stackoverflow_hybrid_v1 + dense_vector_name: dense + sparse_vector_name: sparse + + # Performance settings + performance: + lazy_initialization: true + batch_size: 32 + enable_caching: true + parallel_search: false + + # Pipeline stages (ordered processing) + stages: + - type: retriever + name: hybrid_retriever + config: + retriever_type: hybrid + + - type: score_filter + name: score_filter + config: + min_score: 0.01 # Compatible with RRF scores + max_results: 15 + + - type: reranker + name: cross_encoder_reranker + config: + model_type: cross_encoder + model_name: cross-encoder/ms-marco-MiniLM-L-6-v2 + top_k: 10 # Final number of results + batch_size: 16 + +# Agent-specific settings +agent: + max_context_length: 8000 + include_metadata: true + response_format: detailed diff --git a/pipelines/configs/retriever_config_loader.py b/pipelines/configs/retriever_config_loader.py new file mode 100644 index 0000000..29b80ac --- /dev/null +++ b/pipelines/configs/retriever_config_loader.py @@ -0,0 +1,160 @@ +""" +Retriever configuration loader utility. +Loads and validates retriever configurations from YAML files. +""" + +import os +import yaml +from pathlib import Path +from typing import Dict, Any, Optional, List +import logging + +logger = logging.getLogger(__name__) + + +class RetrieverConfigLoader: + """ + Loads and manages retriever configurations. + """ + + def __init__(self, config_dir: str = None): + """ + Initialize the config loader. + + Args: + config_dir: Directory containing retriever configs (defaults to pipelines/configs/retrievers) + """ + if config_dir is None: + # Default to project's retriever configs directory + project_root = Path(__file__).parent.parent.parent + self.config_dir = project_root / "pipelines" / "configs" / "retrievers" + else: + self.config_dir = Path(config_dir) + + logger.info( + f"RetrieverConfigLoader initialized with config_dir: {self.config_dir}") + + def load_config(self, retriever_type: str) -> Dict[str, Any]: + """ + Load configuration for a specific retriever type. + + Args: + retriever_type: Type of retriever (dense, sparse, hybrid, semantic) + + Returns: + Configuration dictionary + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If configuration is invalid + """ + config_file = self.config_dir / f"{retriever_type}_retriever.yml" + + if not config_file.exists(): + raise FileNotFoundError( + f"Configuration file not found: {config_file}") + + try: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + logger.info(f"Loaded configuration for {retriever_type} retriever") + + # Validate basic structure + self._validate_config(config, retriever_type) + + return config + + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in {config_file}: {e}") + except Exception as e: + raise ValueError(f"Error loading config {config_file}: {e}") + + def _validate_config(self, config: Dict[str, Any], retriever_type: str): + """Validate configuration structure.""" + required_keys = ['retriever'] + + for key in required_keys: + if key not in config: + raise ValueError(f"Missing required configuration key: {key}") + + # Validate retriever type matches + configured_type = config['retriever'].get('type') + if configured_type != retriever_type: + logger.warning( + f"Config type mismatch: expected {retriever_type}, got {configured_type}") + + def get_available_configs(self) -> List[str]: + """ + Get list of available retriever configurations. + + Returns: + List of available retriever types + """ + if not self.config_dir.exists(): + return [] + + configs = [] + for file_path in self.config_dir.glob("*_retriever.yml"): + retriever_type = file_path.stem.replace("_retriever", "") + configs.append(retriever_type) + + return sorted(configs) + + def load_all_configs(self) -> Dict[str, Dict[str, Any]]: + """ + Load all available retriever configurations. + + Returns: + Dictionary mapping retriever type to configuration + """ + configs = {} + + for retriever_type in self.get_available_configs(): + try: + configs[retriever_type] = self.load_config(retriever_type) + except Exception as e: + logger.warning(f"Failed to load {retriever_type} config: {e}") + + return configs + + def merge_with_global_config(self, retriever_config: Dict[str, Any], + global_config: Dict[str, Any]) -> Dict[str, Any]: + """ + Merge retriever-specific config with global pipeline config. + + Args: + retriever_config: Retriever-specific configuration + global_config: Global pipeline configuration + + Returns: + Merged configuration with global config as base and retriever config taking precedence + """ + merged_config = global_config.copy() + + # Update with retriever-specific settings + for key, value in retriever_config.items(): + if isinstance(value, dict) and key in merged_config and isinstance(merged_config[key], dict): + # Deep merge for nested dictionaries + merged_config[key].update(value) + else: + # Direct assignment for non-dict values or new keys + merged_config[key] = value + + return merged_config + + +# Convenience function for easy access +def load_retriever_config(retriever_type: str, config_dir: str = None) -> Dict[str, Any]: + """ + Convenience function to load a retriever configuration. + + Args: + retriever_type: Type of retriever to load config for + config_dir: Optional custom config directory + + Returns: + Configuration dictionary + """ + loader = RetrieverConfigLoader(config_dir) + return loader.load_config(retriever_type) diff --git a/pipelines/contracts.py b/pipelines/contracts.py new file mode 100644 index 0000000..fe3f96c --- /dev/null +++ b/pipelines/contracts.py @@ -0,0 +1,227 @@ +""" +Contracts and schemas for the ingestion pipeline. +Defines the social contract between raw data and the pipeline. +""" +import hashlib +import uuid +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Dict, List, Optional, Any, Iterable, Union +from enum import Enum +from pathlib import Path + +from pydantic import BaseModel, Field, validator +from langchain_core.documents import Document + + +class DatasetSplit(str, Enum): + """Standardized dataset splits.""" + TRAIN = "train" + VALIDATION = "val" + TEST = "test" + ALL = "all" + + +class BaseRow(BaseModel): + """Base schema for dataset-specific rows. All adapters must extend this.""" + external_id: str = Field(..., description="Original dataset identifier") + + class Config: + extra = "allow" # Allow dataset-specific fields + + +class ChunkMeta(BaseModel): + """Dataset-agnostic metadata for processed chunks.""" + # Identity + doc_id: str = Field(..., description="Deterministic document ID") + chunk_id: str = Field(..., description="Deterministic chunk ID") + + # Content provenance + doc_sha256: str = Field(..., description="SHA256 of normalized document content") + text: str = Field(..., description="Chunk text content") + + # Source tracking + source: str = Field(..., description="Dataset/source name") + dataset_version: str = Field(..., description="Dataset version") + external_id: str = Field(..., description="Original dataset identifier") + uri: Optional[str] = Field(None, description="Source URI/path") + + # Processing metadata + chunk_index: int = Field(..., description="0-based chunk index within document") + num_chunks: int = Field(..., description="Total chunks in document") + + # Content metadata + token_count: Optional[int] = Field(None, description="Estimated token count") + char_count: int = Field(..., description="Character count") + + # Dataset metadata + split: DatasetSplit = Field(..., description="Dataset split") + labels: Dict[str, Any] = Field(default_factory=dict, description="Dataset labels/annotations") + + # Pipeline metadata + ingested_at: datetime = Field(default_factory=datetime.utcnow) + git_commit: Optional[str] = Field(None, description="Git commit of ingestion code") + config_hash: Optional[str] = Field(None, description="Hash of ingestion config") + + # Embedding metadata + embedding_model: Optional[str] = Field(None, description="Embedding model used") + embedding_dim: Optional[int] = Field(None, description="Embedding dimension") + dense_embedding: Optional[List[float]] = Field(None, description="Dense embedding vector") + sparse_embedding: Optional[Dict[int, float]] = Field(None, description="Sparse embedding vector") + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for vector store payload.""" + return self.dict() + + +class IngestionRecord(BaseModel): + """Record of an ingestion run for lineage tracking.""" + run_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + dataset_name: str + dataset_version: str + config_hash: str + git_commit: Optional[str] + + # Counts + total_documents: int + total_chunks: int + successful_chunks: int + failed_chunks: int + + # Timing + started_at: datetime + completed_at: Optional[datetime] = None + + # Sample IDs for verification + sample_doc_ids: List[str] = Field(default_factory=list) + sample_chunk_ids: List[str] = Field(default_factory=list) + + # Additional metadata + metadata: Dict[str, Any] = Field(default_factory=dict) + + # Configuration + chunk_strategy: Dict[str, Any] = Field(default_factory=dict) + embedding_strategy: Dict[str, Any] = Field(default_factory=dict) + + def mark_complete(self): + """Mark the ingestion as completed.""" + self.completed_at = datetime.utcnow() + + +def normalize_text(text: str) -> str: + """Normalize text for consistent hashing.""" + return " ".join(text.strip().split()) + + +def compute_content_hash(text: str) -> str: + """Compute SHA256 hash of normalized text.""" + normalized = normalize_text(text) + return hashlib.sha256(normalized.encode('utf-8')).hexdigest() + + +def build_doc_id(source: str, external_id: str, content_hash: str) -> str: + """Build deterministic document ID.""" + return f"{source}:{external_id}:{content_hash[:12]}" + + +def build_chunk_id(doc_id: str, chunk_index: int) -> str: + """Build deterministic chunk ID.""" + return f"{doc_id}#c{chunk_index:04d}" + + +class DatasetAdapter(ABC): + """Abstract adapter interface for dataset-specific processing.""" + + @property + @abstractmethod + def source_name(self) -> str: + """Return the source/dataset name.""" + pass + + @property + @abstractmethod + def version(self) -> str: + """Return the dataset version.""" + pass + + @abstractmethod + def read_rows(self, split: DatasetSplit = DatasetSplit.ALL) -> Iterable[BaseRow]: + """Read raw dataset rows.""" + pass + + @abstractmethod + def to_documents(self, rows: List[BaseRow], split: DatasetSplit) -> List[Document]: + """Convert rows to LangChain Documents with metadata.""" + pass + + @abstractmethod + def get_evaluation_queries(self, split: DatasetSplit = DatasetSplit.TEST) -> List[Dict[str, Any]]: + """Return evaluation queries for this dataset.""" + pass + + +class ValidationResult(BaseModel): + """Result of document validation.""" + valid: bool + doc_id: str + errors: List[str] = Field(default_factory=list) + warnings: List[str] = Field(default_factory=list) + + +class SmokeTestResult(BaseModel): + """Result of post-ingestion smoke tests.""" + passed: bool + test_name: str + details: Dict[str, Any] = Field(default_factory=dict) + errors: List[str] = Field(default_factory=list) + + +class RetrievalMetrics(BaseModel): + """Standard retrieval evaluation metrics.""" + recall_at_k: Dict[int, float] = Field(default_factory=dict) + precision_at_k: Dict[int, float] = Field(default_factory=dict) + ndcg_at_k: Dict[int, float] = Field(default_factory=dict) + mrr: float = 0.0 + map_score: float = 0.0 + + # Additional metrics + total_queries: int = 0 + total_relevant: int = 0 + + def add_k_metrics(self, k: int, recall: float, precision: float, ndcg: float): + """Add metrics for a specific k value.""" + self.recall_at_k[k] = recall + self.precision_at_k[k] = precision + self.ndcg_at_k[k] = ndcg + + +class EvaluationRun(BaseModel): + """Complete evaluation run results.""" + run_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + dataset_name: str + dataset_version: str + collection_name: str + + # Configuration + retriever_config: Dict[str, Any] + embedding_config: Dict[str, Any] + + # Results + metrics: RetrievalMetrics + per_query_results: List[Dict[str, Any]] = Field(default_factory=list) + + # Metadata + evaluated_at: datetime = Field(default_factory=datetime.utcnow) + git_commit: Optional[str] = None + + def save_to_file(self, path: Path): + """Save evaluation results to JSON file.""" + import json + + with open(path, 'w') as f: + json.dump( + self.dict(), + f, + indent=2, + default=str # Handle datetime serialization + ) diff --git a/pipelines/eval/__init__.py b/pipelines/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/eval/evaluator.py b/pipelines/eval/evaluator.py new file mode 100644 index 0000000..c0cba31 --- /dev/null +++ b/pipelines/eval/evaluator.py @@ -0,0 +1,351 @@ +""" +Unified retrieval evaluation runner for consistent metrics across datasets. +Implements standard IR metrics with flexible gold standard handling. +""" +import json +import logging +from typing import List, Dict, Any, Optional, Tuple, Set +from pathlib import Path +from collections import defaultdict +import numpy as np + +from pipelines.contracts import DatasetAdapter, RetrievalMetrics, EvaluationRun + + +logger = logging.getLogger(__name__) + + +class RetrievalEvaluator: + """Unified evaluation runner for retrieval systems.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.k_values = config.get("evaluation", {}).get("k_values", [1, 3, 5, 10]) + self.similarity_threshold = config.get("evaluation", {}).get("similarity_threshold", 0.8) + self.enable_semantic_matching = config.get("evaluation", {}).get("semantic_matching", False) + + def evaluate_dataset( + self, + adapter: DatasetAdapter, + retriever: Any, + split: str = "test" + ) -> EvaluationRun: + """Evaluate retriever on a dataset using its adapter.""" + logger.info(f"Evaluating {adapter.source_name} dataset (split: {split})") + + # Get evaluation queries from adapter + eval_queries = adapter.get_evaluation_queries(split) + if not eval_queries: + raise ValueError(f"No evaluation queries found for {adapter.source_name}") + + logger.info(f"Running evaluation on {len(eval_queries)} queries") + + # Run retrieval for all queries + query_results = [] + all_metrics = defaultdict(list) + + for i, query_info in enumerate(eval_queries): + if i % 100 == 0: + logger.info(f"Processing query {i+1}/{len(eval_queries)}") + + query_result = self._evaluate_single_query(query_info, retriever) + query_results.append(query_result) + + # Aggregate metrics + for metric_name, value in query_result.get("metrics", {}).items(): + all_metrics[metric_name].append(value) + + # Compute aggregate metrics + aggregate_metrics = RetrievalMetrics() + + # Calculate averages for each k value + for k in self.k_values: + recall_values = all_metrics[f"recall_at_{k}"] + precision_values = all_metrics[f"precision_at_{k}"] + ndcg_values = all_metrics[f"ndcg_at_{k}"] + + if recall_values: + aggregate_metrics.recall_at_k[k] = np.mean(recall_values) + aggregate_metrics.precision_at_k[k] = np.mean(precision_values) + aggregate_metrics.ndcg_at_k[k] = np.mean(ndcg_values) + + # Calculate MRR and MAP + mrr_values = all_metrics["mrr"] + map_values = all_metrics["map"] + + aggregate_metrics.mrr = np.mean(mrr_values) if mrr_values else 0.0 + aggregate_metrics.map_score = np.mean(map_values) if map_values else 0.0 + aggregate_metrics.total_queries = len(eval_queries) + aggregate_metrics.total_relevant = sum( + len(q.get("relevant_doc_ids", [])) for q in eval_queries + ) + + # Create evaluation run + evaluation_run = EvaluationRun( + dataset_name=adapter.source_name, + dataset_version=adapter.version, + collection_name=self.config.get("qdrant", {}).get("collection", "unknown"), + retriever_config=self._extract_retriever_config(), + embedding_config=self.config.get("embedding", {}), + metrics=aggregate_metrics, + per_query_results=query_results + ) + + logger.info(f"Evaluation completed. Average recall@5: {aggregate_metrics.recall_at_k.get(5, 0):.3f}") + return evaluation_run + + def _evaluate_single_query(self, query_info: Dict[str, Any], retriever: Any) -> Dict[str, Any]: + """Evaluate a single query.""" + query = query_info["query"] + query_id = query_info.get("query_id", "unknown") + relevant_doc_ids = set(query_info.get("relevant_doc_ids", [])) + relevance_scores = query_info.get("relevance_scores", {}) + + try: + # Retrieve documents + retrieved_docs = retriever.retrieve(query) + + # Extract document IDs and scores + if isinstance(retrieved_docs, list) and retrieved_docs: + if isinstance(retrieved_docs[0], tuple): + # Handle (doc, score) format + doc_ids = [doc.metadata.get("external_id", doc.metadata.get("doc_id", "unknown")) + for doc, _ in retrieved_docs] + scores = [score for _, score in retrieved_docs] + else: + # Handle doc list format + doc_ids = [doc.metadata.get("external_id", doc.metadata.get("doc_id", "unknown")) + for doc in retrieved_docs] + scores = [1.0] * len(doc_ids) # Default scores + else: + doc_ids = [] + scores = [] + + # Calculate metrics for all k values + metrics = {} + + for k in self.k_values: + # Limit to top-k + top_k_ids = doc_ids[:k] + top_k_scores = scores[:k] + + # Calculate recall@k + if relevant_doc_ids: + relevant_retrieved = len(set(top_k_ids).intersection(relevant_doc_ids)) + recall_k = relevant_retrieved / len(relevant_doc_ids) + else: + recall_k = 0.0 + + # Calculate precision@k + precision_k = relevant_retrieved / k if k > 0 and top_k_ids else 0.0 + + # Calculate NDCG@k + ndcg_k = self._calculate_ndcg(top_k_ids, relevance_scores, k) + + metrics[f"recall_at_{k}"] = recall_k + metrics[f"precision_at_{k}"] = precision_k + metrics[f"ndcg_at_{k}"] = ndcg_k + + # Calculate MRR + mrr = self._calculate_mrr(doc_ids, relevant_doc_ids) + metrics["mrr"] = mrr + + # Calculate MAP + map_score = self._calculate_map(doc_ids, relevant_doc_ids) + metrics["map"] = map_score + + return { + "query_id": query_id, + "query": query, + "retrieved_count": len(doc_ids), + "relevant_count": len(relevant_doc_ids), + "metrics": metrics, + "success": True + } + + except Exception as e: + logger.error(f"Error evaluating query '{query}': {e}") + return { + "query_id": query_id, + "query": query, + "error": str(e), + "success": False, + "metrics": {f"{metric}_at_{k}": 0.0 for metric in ["recall", "precision", "ndcg"] for k in self.k_values} + } + + def _calculate_ndcg(self, retrieved_ids: List[str], relevance_scores: Dict[str, float], k: int) -> float: + """Calculate Normalized Discounted Cumulative Gain@k.""" + if not retrieved_ids or not relevance_scores: + return 0.0 + + # Calculate DCG + dcg = 0.0 + for i, doc_id in enumerate(retrieved_ids[:k]): + relevance = relevance_scores.get(doc_id, 0.0) + if i == 0: + dcg += relevance + else: + dcg += relevance / np.log2(i + 1) + + # Calculate IDCG (ideal DCG) + ideal_relevances = sorted(relevance_scores.values(), reverse=True)[:k] + idcg = 0.0 + for i, relevance in enumerate(ideal_relevances): + if i == 0: + idcg += relevance + else: + idcg += relevance / np.log2(i + 1) + + return dcg / idcg if idcg > 0 else 0.0 + + def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: Set[str]) -> float: + """Calculate Mean Reciprocal Rank.""" + for i, doc_id in enumerate(retrieved_ids): + if doc_id in relevant_ids: + return 1.0 / (i + 1) + return 0.0 + + def _calculate_map(self, retrieved_ids: List[str], relevant_ids: Set[str]) -> float: + """Calculate Mean Average Precision.""" + if not relevant_ids: + return 0.0 + + relevant_count = 0 + precision_sum = 0.0 + + for i, doc_id in enumerate(retrieved_ids): + if doc_id in relevant_ids: + relevant_count += 1 + precision_at_i = relevant_count / (i + 1) + precision_sum += precision_at_i + + return precision_sum / len(relevant_ids) if relevant_ids else 0.0 + + def _extract_retriever_config(self) -> Dict[str, Any]: + """Extract retriever configuration for lineage.""" + return { + "strategy": self.config.get("embedding_strategy", "unknown"), + "retriever_type": self.config.get("retriever", {}).get("type", "unknown"), + "top_k": self.config.get("retriever", {}).get("top_k", 10), + "dense_model": self.config.get("embedding", {}).get("dense", {}).get("model_name", "unknown"), + "sparse_model": self.config.get("embedding", {}).get("sparse", {}).get("model_name", "unknown") + } + + def save_results(self, evaluation_run: EvaluationRun, output_dir: Path): + """Save evaluation results to files.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # Save main results + results_file = output_dir / f"{evaluation_run.dataset_name}_evaluation.json" + evaluation_run.save_to_file(results_file) + + # Save metrics summary + metrics_file = output_dir / f"{evaluation_run.dataset_name}_metrics.csv" + self._save_metrics_csv(evaluation_run.metrics, metrics_file) + + # Save per-query results + query_results_file = output_dir / f"{evaluation_run.dataset_name}_per_query.csv" + self._save_query_results_csv(evaluation_run.per_query_results, query_results_file) + + logger.info(f"Evaluation results saved to {output_dir}") + + def _save_metrics_csv(self, metrics: RetrievalMetrics, file_path: Path): + """Save metrics summary as CSV.""" + import csv + + with open(file_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(["k", "nDCG", "Recall", "Precision", "MRR"]) + + for k in sorted(metrics.recall_at_k.keys()): + writer.writerow([ + k, + f"{metrics.ndcg_at_k.get(k, 0):.4f}", + f"{metrics.recall_at_k.get(k, 0):.4f}", + f"{metrics.precision_at_k.get(k, 0):.4f}", + f"{metrics.mrr:.4f}" if k == min(metrics.recall_at_k.keys()) else "" + ]) + + def _save_query_results_csv(self, query_results: List[Dict[str, Any]], file_path: Path): + """Save per-query results as CSV.""" + import csv + + if not query_results: + return + + with open(file_path, 'w', newline='') as f: + # Get all metric names from first successful query + sample_metrics = {} + for result in query_results: + if result.get("success") and result.get("metrics"): + sample_metrics = result["metrics"] + break + + fieldnames = ["query_id", "query", "success"] + list(sample_metrics.keys()) + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for result in query_results: + row = { + "query_id": result.get("query_id", ""), + "query": result.get("query", ""), + "success": result.get("success", False) + } + + # Add metrics + metrics = result.get("metrics", {}) + for metric_name in sample_metrics.keys(): + row[metric_name] = f"{metrics.get(metric_name, 0):.4f}" + + writer.writerow(row) + + +class MetricsComparator: + """Compare evaluation results across different configurations.""" + + @staticmethod + def compare_runs(runs: List[EvaluationRun]) -> Dict[str, Any]: + """Compare multiple evaluation runs.""" + if not runs: + return {} + + comparison = { + "run_count": len(runs), + "datasets": [run.dataset_name for run in runs], + "configurations": [ + { + "dataset": run.dataset_name, + "collection": run.collection_name, + "embedding_strategy": run.embedding_config.get("strategy", "unknown") + } + for run in runs + ], + "metrics_comparison": {} + } + + # Compare metrics across runs + k_values = set() + for run in runs: + k_values.update(run.metrics.recall_at_k.keys()) + + for k in sorted(k_values): + comparison["metrics_comparison"][f"recall_at_{k}"] = [ + run.metrics.recall_at_k.get(k, 0) for run in runs + ] + comparison["metrics_comparison"][f"precision_at_{k}"] = [ + run.metrics.precision_at_k.get(k, 0) for run in runs + ] + comparison["metrics_comparison"][f"ndcg_at_{k}"] = [ + run.metrics.ndcg_at_k.get(k, 0) for run in runs + ] + + comparison["metrics_comparison"]["mrr"] = [run.metrics.mrr for run in runs] + comparison["metrics_comparison"]["map"] = [run.metrics.map_score for run in runs] + + return comparison + + @staticmethod + def save_comparison(comparison: Dict[str, Any], output_file: Path): + """Save comparison results to JSON.""" + with open(output_file, 'w') as f: + json.dump(comparison, f, indent=2, default=str) diff --git a/pipelines/ingest/__init__.py b/pipelines/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/ingest/chunker.py b/pipelines/ingest/chunker.py new file mode 100644 index 0000000..742b242 --- /dev/null +++ b/pipelines/ingest/chunker.py @@ -0,0 +1,475 @@ +""" +Advanced chunking strategies for different content types. +Matches embedder receptive field and preserves semantic coherence. +""" +import re +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod + +from langchain_core.documents import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from embedding.recursive_splitter import RecursiveSplitter + + +class ChunkingStrategy(ABC): + """Abstract base for chunking strategies.""" + + @abstractmethod + def chunk(self, documents: List[Document]) -> List[Document]: + """Chunk documents into smaller pieces.""" + pass + + @property + @abstractmethod + def strategy_name(self) -> str: + """Return strategy name for metadata.""" + pass + + +class RecursiveChunkingStrategy(ChunkingStrategy): + """Recursive character-based chunking for general text.""" + + def __init__(self, config: Dict[str, Any]): + self.chunk_size = config.get("chunk_size", 500) + self.chunk_overlap = config.get("chunk_overlap", 50) + self.separators = config.get("separators", None) + + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + separators=self.separators, + length_function=len, + is_separator_regex=False, + ) + + def chunk(self, documents: List[Document]) -> List[Document]: + """Split documents using recursive character splitting.""" + chunks = self.splitter.split_documents(documents) + + # Add chunk metadata + for i, chunk in enumerate(chunks): + chunk.metadata["chunk_index"] = i + chunk.metadata["chunking_strategy"] = self.strategy_name + chunk.metadata["chunk_size"] = self.chunk_size + chunk.metadata["chunk_overlap"] = self.chunk_overlap + + return chunks + + @property + def strategy_name(self) -> str: + return "recursive_character" + + +class SemanticChunkingStrategy(ChunkingStrategy): + """Semantic chunking that preserves sentence boundaries.""" + + def __init__(self, config: Dict[str, Any]): + self.target_chunk_size = config.get("chunk_size", 500) + self.max_chunk_size = config.get("max_chunk_size", 800) + self.sentence_overlap = config.get("sentence_overlap", 1) + + def chunk(self, documents: List[Document]) -> List[Document]: + """Split documents at sentence boundaries.""" + chunks = [] + + for doc in documents: + doc_chunks = self._chunk_document(doc) + chunks.extend(doc_chunks) + + return chunks + + def _chunk_document(self, doc: Document) -> List[Document]: + """Chunk a single document at sentence boundaries.""" + text = doc.page_content + sentences = self._split_into_sentences(text) + + if not sentences: + return [doc] + + chunks = [] + current_chunk = "" + chunk_sentences = [] + + for i, sentence in enumerate(sentences): + # Check if adding this sentence would exceed target size + potential_chunk = current_chunk + " " + sentence if current_chunk else sentence + + if len(potential_chunk) <= self.target_chunk_size or not current_chunk: + current_chunk = potential_chunk + chunk_sentences.append(sentence) + else: + # Save current chunk + if current_chunk: + chunks.append(self._create_chunk(doc, current_chunk, len(chunks))) + + # Start new chunk (with overlap) + overlap_start = max(0, len(chunk_sentences) - self.sentence_overlap) + overlap_sentences = chunk_sentences[overlap_start:] + + current_chunk = " ".join(overlap_sentences + [sentence]) + chunk_sentences = overlap_sentences + [sentence] + + # Add final chunk + if current_chunk: + chunks.append(self._create_chunk(doc, current_chunk, len(chunks))) + + return chunks + + def _split_into_sentences(self, text: str) -> List[str]: + """Split text into sentences using regex.""" + # Simple sentence splitting - can be enhanced with NLTK/spaCy + sentence_pattern = r'(?<=[.!?])\s+' + sentences = re.split(sentence_pattern, text) + return [s.strip() for s in sentences if s.strip()] + + def _create_chunk(self, original_doc: Document, chunk_text: str, chunk_index: int) -> Document: + """Create a chunk document with metadata.""" + metadata = original_doc.metadata.copy() + metadata.update({ + "chunk_index": chunk_index, + "chunking_strategy": self.strategy_name, + "char_count": len(chunk_text) + }) + + return Document(page_content=chunk_text, metadata=metadata) + + @property + def strategy_name(self) -> str: + return "semantic_sentence" + + +class CodeAwareChunkingStrategy(ChunkingStrategy): + """Chunking strategy that preserves code blocks and functions.""" + + def __init__(self, config: Dict[str, Any]): + self.chunk_size = config.get("chunk_size", 800) + self.preserve_functions = config.get("preserve_functions", True) + self.preserve_code_blocks = config.get("preserve_code_blocks", True) + + # Fallback to recursive splitting for non-code content + self.fallback_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=50 + ) + + def chunk(self, documents: List[Document]) -> List[Document]: + """Split documents preserving code structure.""" + chunks = [] + + for doc in documents: + if self._has_code_content(doc.page_content): + doc_chunks = self._chunk_code_document(doc) + else: + doc_chunks = self.fallback_splitter.split_documents([doc]) + + # Add metadata + for i, chunk in enumerate(doc_chunks): + chunk.metadata["chunk_index"] = i + chunk.metadata["chunking_strategy"] = self.strategy_name + + chunks.extend(doc_chunks) + + return chunks + + def _has_code_content(self, text: str) -> bool: + """Detect if text contains code.""" + code_indicators = [ + r'```', # Markdown code blocks + r'def \w+\(', # Python functions + r'function \w+\(', # JavaScript functions + r'class \w+', # Class definitions + r'import \w+', # Import statements + r'\{\s*\n.*\n\s*\}', # Brace blocks + ] + + for pattern in code_indicators: + if re.search(pattern, text, re.MULTILINE): + return True + + return False + + def _chunk_code_document(self, doc: Document) -> List[Document]: + """Chunk document with code-aware splitting.""" + text = doc.page_content + chunks = [] + + # Find code blocks + code_blocks = list(re.finditer(r'```[\w]*\n(.*?)\n```', text, re.DOTALL)) + + if not code_blocks: + # No explicit code blocks, use function-based splitting + return self._split_by_functions(doc) + + last_end = 0 + for i, match in enumerate(code_blocks): + # Add text before code block + before_code = text[last_end:match.start()].strip() + if before_code: + chunks.append(self._create_chunk(doc, before_code, len(chunks))) + + # Add code block (keep intact if not too large) + code_content = match.group(0) + if len(code_content) <= self.chunk_size: + chunks.append(self._create_chunk(doc, code_content, len(chunks))) + else: + # Split large code blocks + code_chunks = self.fallback_splitter.create_documents([code_content]) + for code_chunk in code_chunks: + code_chunk.metadata = doc.metadata.copy() + chunks.append(code_chunk) + + last_end = match.end() + + # Add remaining text + remaining_text = text[last_end:].strip() + if remaining_text: + chunks.append(self._create_chunk(doc, remaining_text, len(chunks))) + + return chunks + + def _split_by_functions(self, doc: Document) -> List[Document]: + """Split by function/class boundaries.""" + text = doc.page_content + + # Find function/class definitions + function_pattern = r'^(def |class |function |async def )' + lines = text.split('\n') + + chunks = [] + current_chunk_lines = [] + + for line in lines: + if re.match(function_pattern, line.strip()) and current_chunk_lines: + # Start new chunk at function boundary + chunk_text = '\n'.join(current_chunk_lines) + if chunk_text.strip(): + chunks.append(self._create_chunk(doc, chunk_text, len(chunks))) + current_chunk_lines = [line] + else: + current_chunk_lines.append(line) + + # Check size limit + chunk_text = '\n'.join(current_chunk_lines) + if len(chunk_text) > self.chunk_size: + chunks.append(self._create_chunk(doc, chunk_text, len(chunks))) + current_chunk_lines = [] + + # Add final chunk + if current_chunk_lines: + chunk_text = '\n'.join(current_chunk_lines) + if chunk_text.strip(): + chunks.append(self._create_chunk(doc, chunk_text, len(chunks))) + + return chunks if chunks else [doc] + + def _create_chunk(self, original_doc: Document, chunk_text: str, chunk_index: int) -> Document: + """Create chunk with metadata.""" + metadata = original_doc.metadata.copy() + metadata.update({ + "chunk_index": chunk_index, + "chunking_strategy": self.strategy_name, + "char_count": len(chunk_text), + "has_code": self._has_code_content(chunk_text) + }) + + return Document(page_content=chunk_text, metadata=metadata) + + @property + def strategy_name(self) -> str: + return "code_aware" + + +class TableAwareChunkingStrategy(ChunkingStrategy): + """Chunking strategy that preserves table structure.""" + + def __init__(self, config: Dict[str, Any]): + self.chunk_size = config.get("chunk_size", 1000) + self.preserve_headers = config.get("preserve_headers", True) + self.max_table_size = config.get("max_table_size", 2000) + + self.fallback_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=50 + ) + + def chunk(self, documents: List[Document]) -> List[Document]: + """Split documents preserving table structure.""" + chunks = [] + + for doc in documents: + if self._has_tables(doc.page_content): + doc_chunks = self._chunk_table_document(doc) + else: + doc_chunks = self.fallback_splitter.split_documents([doc]) + + # Add metadata + for i, chunk in enumerate(doc_chunks): + chunk.metadata["chunk_index"] = i + chunk.metadata["chunking_strategy"] = self.strategy_name + + chunks.extend(doc_chunks) + + return chunks + + def _has_tables(self, text: str) -> bool: + """Detect if text contains tables.""" + table_indicators = [ + r'\|.*\|.*\|', # Markdown tables + r'\t.*\t.*\t', # Tab-separated + r'┌─+┬─+┐', # ASCII tables + r'', # HTML tables + ] + + for pattern in table_indicators: + if re.search(pattern, text, re.MULTILINE): + return True + + return False + + def _chunk_table_document(self, doc: Document) -> List[Document]: + """Chunk document with table awareness.""" + text = doc.page_content + + # Find markdown tables + table_pattern = r'(\|.*\|.*\n)+(\|[-:\s]+\|.*\n)?(\|.*\|.*\n)+' + tables = list(re.finditer(table_pattern, text, re.MULTILINE)) + + if not tables: + return self.fallback_splitter.split_documents([doc]) + + chunks = [] + last_end = 0 + + for table_match in tables: + # Add text before table + before_table = text[last_end:table_match.start()].strip() + if before_table: + chunks.extend(self._split_text_chunk(doc, before_table, len(chunks))) + + # Process table + table_content = table_match.group(0) + if len(table_content) <= self.max_table_size: + # Keep table intact + chunks.append(self._create_chunk(doc, table_content, len(chunks))) + else: + # Split large table by rows + chunks.extend(self._split_large_table(doc, table_content, len(chunks))) + + last_end = table_match.end() + + # Add remaining text + remaining_text = text[last_end:].strip() + if remaining_text: + chunks.extend(self._split_text_chunk(doc, remaining_text, len(chunks))) + + return chunks + + def _split_text_chunk(self, doc: Document, text: str, start_index: int) -> List[Document]: + """Split non-table text using fallback splitter.""" + temp_doc = Document(page_content=text, metadata=doc.metadata.copy()) + text_chunks = self.fallback_splitter.split_documents([temp_doc]) + + for i, chunk in enumerate(text_chunks): + chunk.metadata["chunk_index"] = start_index + i + chunk.metadata["chunking_strategy"] = self.strategy_name + + return text_chunks + + def _split_large_table(self, doc: Document, table_content: str, start_index: int) -> List[Document]: + """Split large table by rows while preserving header.""" + lines = table_content.split('\n') + header_lines = [] + data_lines = [] + + # Identify header (first line + separator if exists) + if lines: + header_lines.append(lines[0]) + if len(lines) > 1 and re.match(r'\|[-:\s]+\|', lines[1]): + header_lines.append(lines[1]) + data_lines = lines[2:] + else: + data_lines = lines[1:] + + chunks = [] + current_rows = header_lines.copy() if self.preserve_headers else [] + + for line in data_lines: + current_rows.append(line) + chunk_text = '\n'.join(current_rows) + + if len(chunk_text) > self.chunk_size: + # Save current chunk + if len(current_rows) > len(header_lines): + chunks.append(self._create_chunk(doc, chunk_text, start_index + len(chunks))) + + # Start new chunk with headers + current_rows = header_lines.copy() + [line] if self.preserve_headers else [line] + + # Add final chunk + if len(current_rows) > len(header_lines): + chunk_text = '\n'.join(current_rows) + chunks.append(self._create_chunk(doc, chunk_text, start_index + len(chunks))) + + return chunks + + def _create_chunk(self, original_doc: Document, chunk_text: str, chunk_index: int) -> Document: + """Create chunk with metadata.""" + metadata = original_doc.metadata.copy() + metadata.update({ + "chunk_index": chunk_index, + "chunking_strategy": self.strategy_name, + "char_count": len(chunk_text), + "has_table": self._has_tables(chunk_text) + }) + + return Document(page_content=chunk_text, metadata=metadata) + + @property + def strategy_name(self) -> str: + return "table_aware" + + +class ChunkingStrategyFactory: + """Factory for creating chunking strategies.""" + + STRATEGIES = { + "recursive": RecursiveChunkingStrategy, + "semantic": SemanticChunkingStrategy, + "code_aware": CodeAwareChunkingStrategy, + "table_aware": TableAwareChunkingStrategy, + } + + @classmethod + def create_strategy(cls, strategy_name: str, config: Dict[str, Any]) -> ChunkingStrategy: + """Create chunking strategy by name.""" + if strategy_name not in cls.STRATEGIES: + available = ", ".join(cls.STRATEGIES.keys()) + raise ValueError(f"Unknown chunking strategy '{strategy_name}'. Available: {available}") + + strategy_class = cls.STRATEGIES[strategy_name] + return strategy_class(config) + + @classmethod + def get_strategy_for_content(cls, content: str, config: Dict[str, Any]) -> ChunkingStrategy: + """Auto-select chunking strategy based on content analysis.""" + # Simple heuristics for auto-selection + if cls._has_code_content(content): + return cls.create_strategy("code_aware", config) + elif cls._has_tables(content): + return cls.create_strategy("table_aware", config) + elif config.get("use_semantic", False): + return cls.create_strategy("semantic", config) + else: + return cls.create_strategy("recursive", config) + + @staticmethod + def _has_code_content(text: str) -> bool: + """Detect code content.""" + code_patterns = [r'```', r'def \w+\(', r'function \w+\(', r'class \w+'] + return any(re.search(pattern, text) for pattern in code_patterns) + + @staticmethod + def _has_tables(text: str) -> bool: + """Detect table content.""" + table_patterns = [r'\|.*\|.*\|', r'\t.*\t.*\t'] + return any(re.search(pattern, text, re.MULTILINE) for pattern in table_patterns) diff --git a/pipelines/ingest/embedder.py b/pipelines/ingest/embedder.py new file mode 100644 index 0000000..68894e2 --- /dev/null +++ b/pipelines/ingest/embedder.py @@ -0,0 +1,357 @@ +""" +Enhanced embedding pipeline with strategy selection and batch processing. +Supports dense, sparse, and hybrid embeddings with caching and error handling. +""" +import os +import json +import hashlib +from typing import List, Dict, Any, Optional, Tuple, Union +from pathlib import Path +import logging +from datetime import datetime + +from tqdm import tqdm +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings + +from embedding.factory import get_embedder +from pipelines.contracts import ChunkMeta, build_doc_id, build_chunk_id, compute_content_hash, DatasetSplit + + +logger = logging.getLogger(__name__) + + +class EmbeddingPipeline: + """Enhanced embedding pipeline with strategy selection and caching.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + + # Check for strategy in multiple locations for backward compatibility + self.embedding_strategy = ( + config.get("embedding_strategy") or # Top level (legacy) + # Under embedding (current) + config.get("embedding", {}).get("strategy") or + "dense" # Default fallback + ).lower() + + logger.info(f"Using embedding strategy: {self.embedding_strategy}") + + # Initialize embedders based on strategy + self.dense_embedder = None + self.sparse_embedder = None + + if self.embedding_strategy in ["dense", "hybrid"]: + dense_config = config.get("embedding", {}).get("dense", {}) + self.dense_embedder = get_embedder(dense_config) + logger.info( + f"Initialized dense embedder: {dense_config.get('provider', 'unknown')}") + + if self.embedding_strategy in ["sparse", "hybrid"]: + sparse_config = config.get("embedding", {}).get("sparse", {}) + self.sparse_embedder = get_embedder(sparse_config) + logger.info( + f"Initialized sparse embedder: {sparse_config.get('provider', 'unknown')}") + + # Caching configuration + self.enable_cache = config.get( + "embedding_cache", {}).get("enabled", True) + self.cache_dir = Path(config.get( + "embedding_cache", {}).get("dir", "cache/embeddings")) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Batch processing + self.batch_size = config.get("embedding", {}).get("batch_size", 32) + self.max_retries = config.get("embedding", {}).get("max_retries", 3) + + # Error handling + self.fail_fast = config.get("embedding", {}).get("fail_fast", False) + self.fallback_embedding_dim = config.get( + "embedding", {}).get("fallback_dim", 384) + + def process_documents(self, documents: List[Document]) -> List[ChunkMeta]: + """Process documents into ChunkMeta with embeddings.""" + if not documents: + return [] + + logger.info( + f"Processing {len(documents)} documents with {self.embedding_strategy} strategy") + + chunk_metas = [] + + # Convert documents to ChunkMeta objects + print("📄 Converting documents to chunks...") + for doc in tqdm(documents, desc="Converting documents", unit="doc"): + chunk_meta = self._document_to_chunk_meta(doc) + chunk_metas.append(chunk_meta) + + # Generate embeddings in batches + if self.dense_embedder: + print("🔤 Generating dense embeddings...") + logger.info("Generating dense embeddings...") + dense_embeddings = self._generate_embeddings( + [meta.text for meta in chunk_metas], + self.dense_embedder, + "dense" + ) + + print("🔗 Attaching dense embeddings...") + for meta, embedding in tqdm(zip(chunk_metas, dense_embeddings), + desc="Attaching dense embeddings", + total=len(chunk_metas), unit="chunk"): + meta.dense_embedding = embedding + if embedding: + meta.embedding_dim = len(embedding) + meta.embedding_model = self._get_model_name( + self.dense_embedder) + + if self.sparse_embedder: + print("🕸️ Generating sparse embeddings...") + logger.info("Generating sparse embeddings...") + sparse_embeddings = self._generate_embeddings( + [meta.text for meta in chunk_metas], + self.sparse_embedder, + "sparse" + ) + + print("🔗 Attaching sparse embeddings...") + for meta, embedding in tqdm(zip(chunk_metas, sparse_embeddings), + desc="Attaching sparse embeddings", + total=len(chunk_metas), unit="chunk"): + meta.sparse_embedding = embedding + + logger.info(f"Successfully processed {len(chunk_metas)} chunk metas") + return chunk_metas + + def _document_to_chunk_meta(self, doc: Document) -> ChunkMeta: + """Convert Document to ChunkMeta with deterministic IDs.""" + text = doc.page_content + metadata = doc.metadata + + # Generate deterministic IDs + doc_sha256 = compute_content_hash(text) + source = metadata.get("source", "unknown") + external_id = metadata.get("external_id", "unknown") + + doc_id = build_doc_id(source, external_id, doc_sha256) + chunk_index = metadata.get("chunk_index", 0) + chunk_id = build_chunk_id(doc_id, chunk_index) + + # Extract git commit and config hash from environment/config + git_commit = self._get_git_commit() + config_hash = self._compute_config_hash() + + chunk_meta = ChunkMeta( + doc_id=doc_id, + chunk_id=chunk_id, + doc_sha256=doc_sha256, + text=text, + source=source, + dataset_version=metadata.get("dataset_version", "unknown"), + external_id=external_id, + uri=metadata.get("uri"), + chunk_index=chunk_index, + num_chunks=metadata.get("num_chunks", 1), + token_count=metadata.get("token_estimate"), + char_count=len(text), + split=DatasetSplit(metadata.get("split", "all")), + labels=metadata.get("labels", {}), + git_commit=git_commit, + config_hash=config_hash + ) + + # Copy additional metadata + for key, value in metadata.items(): + if key not in chunk_meta.dict(): + chunk_meta.labels[key] = value + + return chunk_meta + + def _generate_embeddings(self, texts: List[str], embedder: Embeddings, embedding_type: str) -> List[Optional[Union[List[float], Dict[int, float]]]]: + """Generate embeddings with caching and error handling.""" + embeddings = [] + + # Calculate total batches for progress tracking + total_batches = (len(texts) + self.batch_size - 1) // self.batch_size + + with tqdm(total=total_batches, desc=f"Processing {embedding_type} batches", unit="batch") as pbar: + for i in range(0, len(texts), self.batch_size): + batch_texts = texts[i:i + self.batch_size] + batch_embeddings = self._process_embedding_batch( + batch_texts, embedder, embedding_type) + embeddings.extend(batch_embeddings) + pbar.update(1) + + # Update description with current progress + cached_count = sum( + 1 for e in batch_embeddings if e is not None) + pbar.set_postfix({ + 'texts': f"{min(i + self.batch_size, len(texts))}/{len(texts)}", + 'cached': f"{cached_count}/{len(batch_embeddings)}" + }) + + return embeddings + + def _process_embedding_batch(self, texts: List[str], embedder: Embeddings, embedding_type: str) -> List[Optional[Union[List[float], Dict[int, float]]]]: + """Process a batch of texts for embedding.""" + batch_embeddings = [] + + for text in tqdm(texts, desc=f"Embedding {embedding_type}", unit="text", leave=False): + try: + # Check cache first + if self.enable_cache: + cached_embedding = self._get_cached_embedding( + text, embedder, embedding_type) + if cached_embedding is not None: + batch_embeddings.append(cached_embedding) + continue + + # Generate new embedding + embedding = self._generate_single_embedding( + text, embedder, embedding_type) + + # Cache the result + if self.enable_cache and embedding is not None: + self._cache_embedding( + text, embedding, embedder, embedding_type) + + batch_embeddings.append(embedding) + + except Exception as e: + logger.error( + f"Error generating {embedding_type} embedding for text: {text[:100]}... Error: {e}") + + if self.fail_fast: + raise + + # Add fallback embedding + fallback_embedding = self._create_fallback_embedding( + embedding_type) + batch_embeddings.append(fallback_embedding) + + return batch_embeddings + + def _generate_single_embedding(self, text: str, embedder: Embeddings, embedding_type: str) -> Optional[Union[List[float], Dict[int, float]]]: + """Generate embedding for a single text with retries.""" + for attempt in range(self.max_retries): + try: + embedding = embedder.embed_query(text) + + # Validate embedding based on type + if embedding_type == "dense": + if embedding and isinstance(embedding, list) and len(embedding) > 0: + return embedding + else: # sparse + # Handle different sparse embedding formats + if embedding is not None: + if isinstance(embedding, dict) and len(embedding) > 0: + return embedding + elif hasattr(embedding, '__len__') and len(embedding) > 0: + # Convert array-like objects to dict if needed + logger.warning( + f"Converting sparse embedding type {type(embedding)} to dict") + return dict(embedding) if hasattr(embedding, 'items') else {} + else: + logger.warning( + f"Unexpected sparse embedding type: {type(embedding)}") + return {} + + logger.warning( + f"Empty or invalid {embedding_type} embedding returned for text: {text[:50]}...") + return None + + except Exception as e: + logger.warning( + f"Attempt {attempt + 1}/{self.max_retries} failed for {embedding_type} embedding: {e}") + if attempt == self.max_retries - 1: + raise + + return None + + def _get_cached_embedding(self, text: str, embedder: Embeddings, embedding_type: str) -> Optional[Union[List[float], Dict[int, float]]]: + """Retrieve cached embedding.""" + cache_key = self._compute_cache_key(text, embedder, embedding_type) + cache_file = self.cache_dir / f"{cache_key}.json" + + if cache_file.exists(): + try: + with open(cache_file, 'r') as f: + data = json.load(f) + embedding = data.get("embedding") + + # Convert string keys back to int for sparse embeddings + if embedding_type == "sparse" and isinstance(embedding, dict): + return {int(k): v for k, v in embedding.items()} + + return embedding + except Exception as e: + logger.warning(f"Error reading cache file {cache_file}: {e}") + + return None + + def _cache_embedding(self, text: str, embedding: Union[List[float], Dict[int, float]], embedder: Embeddings, embedding_type: str): + """Cache embedding result.""" + cache_key = self._compute_cache_key(text, embedder, embedding_type) + cache_file = self.cache_dir / f"{cache_key}.json" + + try: + # Convert int keys to string for JSON serialization + serializable_embedding = embedding + if embedding_type == "sparse" and isinstance(embedding, dict): + serializable_embedding = { + str(k): v for k, v in embedding.items()} + + cache_data = { + "text_hash": hashlib.sha256(text.encode()).hexdigest(), + "embedding_type": embedding_type, + "model": self._get_model_name(embedder), + "embedding": serializable_embedding, + "created_at": str(datetime.now()) + } + + with open(cache_file, 'w') as f: + json.dump(cache_data, f) + + except Exception as e: + logger.warning(f"Error caching embedding: {e}") + + def _compute_cache_key(self, text: str, embedder: Embeddings, embedding_type: str) -> str: + """Compute cache key for text and embedder.""" + model_name = self._get_model_name(embedder) + content = f"{text}:{embedding_type}:{model_name}" + return hashlib.sha256(content.encode()).hexdigest()[:16] + + def _get_model_name(self, embedder: Embeddings) -> str: + """Extract model name from embedder.""" + if hasattr(embedder, 'model_name'): + return embedder.model_name + elif hasattr(embedder, 'model'): + return embedder.model + else: + return embedder.__class__.__name__ + + def _create_fallback_embedding(self, embedding_type: str) -> Union[List[float], Dict[int, float]]: + """Create fallback embedding for failed cases.""" + if embedding_type == "sparse": + # For sparse embeddings, return empty dict + return {} + else: + # For dense embeddings, return zero vector + return [0.0] * self.fallback_embedding_dim + + def _get_git_commit(self) -> Optional[str]: + """Get current git commit hash.""" + try: + import subprocess + result = subprocess.run(['git', 'rev-parse', 'HEAD'], + capture_output=True, text=True, cwd='.') + if result.returncode == 0: + return result.stdout.strip() + except Exception: + pass + return None + + def _compute_config_hash(self) -> str: + """Compute hash of current configuration.""" + config_str = json.dumps(self.config, sort_keys=True, default=str) + return hashlib.sha256(config_str.encode()).hexdigest()[:12] diff --git a/pipelines/ingest/pipeline.py b/pipelines/ingest/pipeline.py new file mode 100644 index 0000000..b6a84f4 --- /dev/null +++ b/pipelines/ingest/pipeline.py @@ -0,0 +1,431 @@ +""" +Main ingestion pipeline that orchestrates all components. +Implements the complete theory-backed pipeline with lineage tracking. +""" +import os +import json +import logging +from typing import Dict, Any, List, Optional +from pathlib import Path +from datetime import datetime + +from tqdm import tqdm +from pipelines.contracts import ( + DatasetAdapter, ChunkMeta, IngestionRecord, DatasetSplit +) +from pipelines.ingest.validator import DocumentValidator +from pipelines.ingest.chunker import ChunkingStrategyFactory +from pipelines.ingest.embedder import EmbeddingPipeline +from pipelines.ingest.uploader import VectorStoreUploader +from pipelines.ingest.smoke_tests import SmokeTestRunner +from config.config_loader import load_config + + +logger = logging.getLogger(__name__) + + +class IngestionPipeline: + """ + Main ingestion pipeline orchestrating all components. + Implements deterministic IDs, idempotent loads, and comprehensive lineage. + """ + + def __init__(self, config_path: Optional[str] = None, config: Optional[Dict[str, Any]] = None): + """Initialize pipeline with configuration.""" + if config: + self.config = config + elif config_path: + self.config = load_config(config_path) + else: + self.config = load_config() # Default config.yml + + # Initialize components + self.validator = DocumentValidator(self.config.get("validation", {})) + self.embedding_pipeline = EmbeddingPipeline(self.config) + self.uploader = VectorStoreUploader(self.config) + self.smoke_test_runner = SmokeTestRunner(self.config) + + # Pipeline configuration + self.dry_run = False + self.max_documents = None + self.canary_mode = False + + # Output directories + self.output_dir = Path(self.config.get("output_dir", "output")) + self.lineage_dir = self.output_dir / "lineage" + self.lineage_dir.mkdir(parents=True, exist_ok=True) + + logger.info("Ingestion pipeline initialized") + + def ingest_dataset( + self, + adapter: DatasetAdapter, + split: DatasetSplit = DatasetSplit.ALL, + dry_run: bool = False, + max_documents: Optional[int] = None, + canary_mode: bool = False + ) -> IngestionRecord: + """ + Main ingestion method - processes a complete dataset. + + Args: + adapter: Dataset adapter implementing DatasetAdapter interface + split: Dataset split to process + dry_run: If True, don't upload to vector store + max_documents: Limit number of documents (for testing) + canary_mode: If True, use canary collection name + """ + self.dry_run = dry_run + self.max_documents = max_documents + self.canary_mode = canary_mode + + logger.info( + f"Starting ingestion: {adapter.source_name} v{adapter.version} (split: {split.value})") + if dry_run: + logger.info("DRY RUN MODE - No uploads will be performed") + if canary_mode: + logger.info("CANARY MODE - Using canary collection") + + # Create ingestion record + record = IngestionRecord( + dataset_name=adapter.source_name, + dataset_version=adapter.version, + config_hash=self._compute_config_hash(), + git_commit=self._get_git_commit(), + total_documents=0, + total_chunks=0, + successful_chunks=0, + failed_chunks=0, + started_at=datetime.utcnow(), + chunk_strategy=self.config.get("chunking", {}), + embedding_strategy=self.config.get("embedding", {}) + ) + + try: + # Step 1: Read and validate documents + logger.info("Step 1: Reading and validating documents...") + documents = self._read_and_validate_documents( + adapter, split, record) + + if not documents: + logger.warning("No valid documents found") + record.mark_complete() + return record + + # Step 2: Chunk documents + logger.info("Step 2: Chunking documents...") + chunks = self._chunk_documents(documents, record) + + # Step 3: Generate embeddings and create ChunkMeta + logger.info("Step 3: Generating embeddings...") + chunk_metas = self._process_chunks(chunks, record) + + # Step 4: Upload to vector store (unless dry run) + if not dry_run: + logger.info("Step 4: Uploading to vector store...") + upload_record = self._upload_chunks(chunk_metas) + + # Update record with upload results + record.successful_chunks = upload_record.successful_chunks + record.failed_chunks = upload_record.failed_chunks + + # Step 5: Run smoke tests + logger.info("Step 5: Running smoke tests...") + smoke_results = self._run_smoke_tests(chunk_metas) + record.metadata = {"smoke_test_results": smoke_results} + + else: + logger.info("Step 4: Skipped (dry run)") + record.successful_chunks = len(chunk_metas) + record.failed_chunks = 0 + + # Complete and save record + record.mark_complete() + self._save_lineage(record) + + logger.info( + f"Ingestion completed: {record.successful_chunks} successful, {record.failed_chunks} failed") + return record + + except Exception as e: + logger.error(f"Ingestion failed: {e}") + record.mark_complete() + record.metadata = {"error": str(e)} + self._save_lineage(record) + raise + + def _read_and_validate_documents( + self, + adapter: DatasetAdapter, + split: DatasetSplit, + record: IngestionRecord + ) -> List[Any]: + """Read documents from adapter and validate them.""" + # Read rows from adapter (don't limit here - let adapter process all rows for proper mapping) + rows = list(adapter.read_rows(split)) + + # Convert to documents + documents = adapter.to_documents(rows, split) + + # Apply max_documents limit AFTER document creation + if self.max_documents: + documents = documents[:self.max_documents] + logger.info( + f"Limited to {self.max_documents} documents for testing") + + record.total_documents = len(documents) + + logger.info( + f"Read {len(documents)} documents from {adapter.source_name}") + + # Validate documents + print("✅ Validating documents...") + validation_results = [] + for doc in tqdm(documents, desc="Validating documents", unit="doc"): + result = self.validator.validate_document(doc) + validation_results.append(result) + + # Filter valid documents and clean them + valid_documents = [] + validation_errors = [] + + print("🧹 Cleaning valid documents...") + for doc, validation_result in tqdm(zip(documents, validation_results), + desc="Processing validation", + total=len(documents), unit="doc"): + if validation_result.valid: + cleaned_doc = self.validator.clean_document(doc) + valid_documents.append(cleaned_doc) + else: + validation_errors.append({ + "doc_id": validation_result.doc_id, + "errors": validation_result.errors + }) + + logger.info( + f"Validation: {len(valid_documents)} valid, {len(validation_errors)} invalid") + + if validation_errors: + # Show first 5 + logger.warning(f"Validation errors found: {validation_errors[:5]}") + + return valid_documents + + def _chunk_documents(self, documents: List[Any], record: IngestionRecord) -> List[Any]: + """Chunk documents using configured strategy.""" + chunking_config = self.config.get("chunking", {}) + strategy_name = chunking_config.get("strategy", "recursive") + + # Auto-select strategy if needed + if strategy_name == "auto": + # Analyze first document to determine strategy + sample_content = documents[0].page_content if documents else "" + strategy = ChunkingStrategyFactory.get_strategy_for_content( + sample_content, chunking_config) + else: + strategy = ChunkingStrategyFactory.create_strategy( + strategy_name, chunking_config) + + logger.info(f"Using chunking strategy: {strategy.strategy_name}") + + # Chunk all documents + print("✂️ Chunking documents...") + chunks = [] + for doc in tqdm(documents, desc="Chunking documents", unit="doc"): + # Chunk one document at a time for progress + doc_chunks = strategy.chunk([doc]) + chunks.extend(doc_chunks) + + # Update chunk metadata with document totals + print("📊 Updating chunk metadata...") + doc_chunk_counts = {} + for chunk in tqdm(chunks, desc="Counting chunks per doc", unit="chunk"): + doc_id = chunk.metadata.get("external_id", "unknown") + doc_chunk_counts[doc_id] = doc_chunk_counts.get(doc_id, 0) + 1 + + # Update num_chunks for each chunk + for chunk in tqdm(chunks, desc="Updating metadata", unit="chunk"): + doc_id = chunk.metadata.get("external_id", "unknown") + chunk.metadata["num_chunks"] = doc_chunk_counts.get(doc_id, 1) + + record.total_chunks = len(chunks) + logger.info( + f"Generated {len(chunks)} chunks from {len(documents)} documents") + + return chunks + + def _process_chunks(self, chunks: List[Any], record: IngestionRecord) -> List[ChunkMeta]: + """Process chunks through embedding pipeline.""" + # Generate embeddings and convert to ChunkMeta + chunk_metas = self.embedding_pipeline.process_documents(chunks) + + # Add sample IDs to record + record.sample_doc_ids = list( + set(meta.doc_id for meta in chunk_metas[:10])) + record.sample_chunk_ids = [meta.chunk_id for meta in chunk_metas[:10]] + + logger.info( + f"Processed {len(chunk_metas)} chunk metas with embeddings") + return chunk_metas + + def _upload_chunks(self, chunk_metas: List[ChunkMeta]) -> IngestionRecord: + """Upload chunks to vector store.""" + if self.canary_mode: + # Modify collection name for canary + original_collection = self.uploader.collection_name + self.uploader.collection_name = f"{original_collection}_canary" + logger.info( + f"Using canary collection: {self.uploader.collection_name}") + + print(f"📤 Uploading {len(chunk_metas)} chunks to vector store...") + upload_record = self.uploader.upload_chunks(chunk_metas) + + # Verify upload with sample + if chunk_metas: + print("🔍 Verifying upload...") + sample_ids = [meta.chunk_id for meta in chunk_metas[:5]] + verification = self.uploader.verify_upload(sample_ids) + upload_record.metadata = {"verification": verification} + + if not verification.get("verification_passed", False): + logger.warning(f"Upload verification failed: {verification}") + + return upload_record + + def _run_smoke_tests(self, chunk_metas: List[ChunkMeta]) -> Dict[str, Any]: + """Run post-ingestion smoke tests.""" + # Pass the actual collection name being used + actual_collection = self.uploader.collection_name + results = self.smoke_test_runner.run_smoke_tests( + collection_name=actual_collection, chunk_metas=chunk_metas) + return {"test_results": results, "passed": all(r.passed for r in results)} + + def _save_lineage(self, record: IngestionRecord): + """Save ingestion lineage for reproducibility.""" + lineage_file = self.lineage_dir / \ + f"{record.dataset_name}_{record.run_id}.json" + + lineage_data = { + "ingestion_record": record.dict(), + "config": self.config, + "environment": { + "python_version": self._get_python_version(), + "git_commit": record.git_commit, + "timestamp": str(datetime.utcnow()), + "hostname": os.uname().nodename if hasattr(os, 'uname') else "unknown" + } + } + + with open(lineage_file, 'w') as f: + json.dump(lineage_data, f, indent=2, default=str) + + logger.info(f"Lineage saved to {lineage_file}") + + def _compute_config_hash(self) -> str: + """Compute hash of current configuration.""" + import hashlib + config_str = json.dumps(self.config, sort_keys=True, default=str) + return hashlib.sha256(config_str.encode()).hexdigest()[:12] + + def _get_git_commit(self) -> Optional[str]: + """Get current git commit hash.""" + try: + import subprocess + result = subprocess.run(['git', 'rev-parse', 'HEAD'], + capture_output=True, text=True, cwd='.') + if result.returncode == 0: + return result.stdout.strip() + except Exception: + pass + return None + + def _get_python_version(self) -> str: + """Get Python version.""" + import sys + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + def get_collection_status(self) -> Dict[str, Any]: + """Get current collection status.""" + return self.uploader.get_collection_info() + + def cleanup_canary_collections(self): + """Clean up canary collections after testing.""" + client = self.uploader.client + collections = client.get_collections().collections + + canary_collections = [ + c.name for c in collections if "_canary" in c.name] + + for collection_name in canary_collections: + try: + client.delete_collection(collection_name) + logger.info(f"Deleted canary collection: {collection_name}") + except Exception as e: + logger.error( + f"Error deleting canary collection {collection_name}: {e}") + + +class BatchIngestionPipeline: + """Pipeline for processing multiple datasets in sequence.""" + + def __init__(self, config_path: Optional[str] = None): + self.pipeline = IngestionPipeline(config_path) + self.results = [] + + def ingest_multiple_datasets( + self, + adapters: List[DatasetAdapter], + **kwargs + ) -> List[IngestionRecord]: + """Ingest multiple datasets in sequence.""" + logger.info(f"Starting batch ingestion of {len(adapters)} datasets") + + for i, adapter in enumerate(adapters): + logger.info( + f"Processing dataset {i+1}/{len(adapters)}: {adapter.source_name}") + + try: + record = self.pipeline.ingest_dataset(adapter, **kwargs) + self.results.append(record) + + # Brief pause between datasets + import time + time.sleep(1) + + except Exception as e: + logger.error(f"Failed to process {adapter.source_name}: {e}") + # Continue with next dataset + continue + + logger.info( + f"Batch ingestion completed: {len(self.results)} datasets processed") + return self.results + + def get_summary(self) -> Dict[str, Any]: + """Get summary of batch ingestion results.""" + if not self.results: + return {"error": "No results available"} + + total_documents = sum(r.total_documents for r in self.results) + total_chunks = sum(r.total_chunks for r in self.results) + successful_chunks = sum(r.successful_chunks for r in self.results) + failed_chunks = sum(r.failed_chunks for r in self.results) + + return { + "total_datasets": len(self.results), + "total_documents": total_documents, + "total_chunks": total_chunks, + "successful_chunks": successful_chunks, + "failed_chunks": failed_chunks, + "success_rate": successful_chunks / total_chunks if total_chunks > 0 else 0, + "datasets": [ + { + "name": r.dataset_name, + "version": r.dataset_version, + "documents": r.total_documents, + "chunks": r.total_chunks, + "success_rate": r.successful_chunks / r.total_chunks if r.total_chunks > 0 else 0 + } + for r in self.results + ] + } diff --git a/pipelines/ingest/smoke_tests.py b/pipelines/ingest/smoke_tests.py new file mode 100644 index 0000000..bf41863 --- /dev/null +++ b/pipelines/ingest/smoke_tests.py @@ -0,0 +1,327 @@ +""" +Smoke tests for validating ingestion pipeline results. +Provides basic sanity checks after data ingestion. +""" +import logging +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +from datetime import datetime + +from qdrant_client import QdrantClient +from database.qdrant_controller import QdrantVectorDB +from pipelines.contracts import ChunkMeta + +logger = logging.getLogger(__name__) + + +@dataclass +class SmokeTestResult: + """Result of a smoke test execution.""" + test_name: str + passed: bool + message: str + details: Dict[str, Any] = None + + def __post_init__(self): + if self.details is None: + self.details = {} + + +class SmokeTestRunner: + """Runs smoke tests after ingestion to validate data quality.""" + + def __init__(self, config: Dict[str, Any]): + """ + Initialize smoke test runner. + + Args: + config: Configuration dictionary containing test settings + """ + self.config = config + self.sample_size = config.get("sample_size", 5) + self.min_success_rate = config.get("min_success_rate", 0.8) + self.vector_db = QdrantVectorDB() + self.client = self.vector_db.get_client() + + def run_smoke_tests(self, collection_name: str, chunk_metas: List[ChunkMeta]) -> List[SmokeTestResult]: + """ + Run all smoke tests on the ingested data. + + Args: + collection_name: Name of the Qdrant collection to test + chunk_metas: List of ingested chunk metadata + + Returns: + List of smoke test results + """ + logger.info(f"Running smoke tests on collection: {collection_name}") + + results = [] + + # Test 1: Collection exists and has data + results.append(self._test_collection_exists(collection_name)) + + # Test 2: Data retrieval works + results.append(self._test_data_retrieval(collection_name)) + + # Test 3: Vector search works + results.append(self._test_vector_search(collection_name)) + + # Test 4: Sample data quality + if chunk_metas: + results.append(self._test_data_quality(chunk_metas)) + + # Summary + passed_tests = sum(1 for r in results if r.passed) + total_tests = len(results) + success_rate = passed_tests / total_tests if total_tests > 0 else 0 + + logger.info( + f"Smoke tests completed: {passed_tests}/{total_tests} passed ({success_rate:.1%})") + + return results + + def _test_collection_exists(self, collection_name: str) -> SmokeTestResult: + """Test if collection exists and has data.""" + try: + collections = self.client.get_collections() + collection_names = [col.name for col in collections.collections] + + if collection_name not in collection_names: + return SmokeTestResult( + test_name="collection_exists", + passed=False, + message=f"Collection {collection_name} not found", + details={"available_collections": collection_names} + ) + + # Check if collection has data + collection_info = self.client.get_collection(collection_name) + point_count = collection_info.points_count + + if point_count == 0: + return SmokeTestResult( + test_name="collection_exists", + passed=False, + message=f"Collection {collection_name} exists but has no data", + details={"point_count": point_count} + ) + + return SmokeTestResult( + test_name="collection_exists", + passed=True, + message=f"Collection {collection_name} exists with {point_count} points", + details={"point_count": point_count} + ) + + except Exception as e: + return SmokeTestResult( + test_name="collection_exists", + passed=False, + message=f"Error checking collection: {str(e)}" + ) + + def _test_data_retrieval(self, collection_name: str) -> SmokeTestResult: + """Test basic data retrieval.""" + try: + # Try to scroll through some points + points = self.client.scroll( + collection_name=collection_name, + limit=self.sample_size, + with_payload=True, + with_vectors=False + )[0] # Get the points from the tuple + + if not points: + return SmokeTestResult( + test_name="data_retrieval", + passed=False, + message="No points could be retrieved from collection" + ) + + # Check if points have required payload fields + sample_point = points[0] + required_fields = ["text", "doc_id", "chunk_id"] + missing_fields = [] + + for field in required_fields: + if field not in sample_point.payload: + missing_fields.append(field) + + if missing_fields: + return SmokeTestResult( + test_name="data_retrieval", + passed=False, + message=f"Missing required payload fields: {missing_fields}", + details={"available_fields": list( + sample_point.payload.keys())} + ) + + return SmokeTestResult( + test_name="data_retrieval", + passed=True, + message=f"Successfully retrieved {len(points)} points with valid payload", + details={"retrieved_count": len(points)} + ) + + except Exception as e: + return SmokeTestResult( + test_name="data_retrieval", + passed=False, + message=f"Error retrieving data: {str(e)}" + ) + + def _test_vector_search(self, collection_name: str) -> SmokeTestResult: + """Test vector search functionality.""" + try: + # Get collection info to determine vector dimensions + collection_info = self.client.get_collection(collection_name) + vector_config = collection_info.config.params.vectors + + # Handle both named vectors and single vector config + if hasattr(vector_config, 'dense'): + # Named vectors (hybrid setup) + vector_size = vector_config.dense.size + vector_name = "dense" + else: + # Single vector config + vector_size = vector_config.size + vector_name = None + + # Create a dummy query vector + query_vector = [0.1] * vector_size + + # Perform search + search_kwargs = { + "collection_name": collection_name, + "query_vector": query_vector, + "limit": 3 + } + + if vector_name: + search_kwargs["using"] = vector_name + + search_results = self.client.search(**search_kwargs) + + if not search_results: + return SmokeTestResult( + test_name="vector_search", + passed=False, + message="Vector search returned no results" + ) + + return SmokeTestResult( + test_name="vector_search", + passed=True, + message=f"Vector search successful, found {len(search_results)} results", + details={ + "result_count": len(search_results), + "vector_size": vector_size, + "top_score": search_results[0].score if search_results else None + } + ) + + except Exception as e: + return SmokeTestResult( + test_name="vector_search", + passed=False, + message=f"Error in vector search: {str(e)}" + ) + + def _test_data_quality(self, chunk_metas: List[ChunkMeta]) -> SmokeTestResult: + """Test data quality of ingested chunks.""" + try: + if not chunk_metas: + return SmokeTestResult( + test_name="data_quality", + passed=False, + message="No chunk metadata provided for quality testing" + ) + + # Sample some chunks for testing + sample_size = min(self.sample_size, len(chunk_metas)) + sample_chunks = chunk_metas[:sample_size] + + issues = [] + + for chunk in sample_chunks: + # Check text length + if not chunk.text or len(chunk.text.strip()) < 10: + issues.append( + f"Chunk {chunk.chunk_id} has insufficient text") + + # Check required fields + if not chunk.doc_id: + issues.append(f"Chunk {chunk.chunk_id} missing doc_id") + + if not chunk.source: + issues.append(f"Chunk {chunk.chunk_id} missing source") + + if issues: + return SmokeTestResult( + test_name="data_quality", + passed=False, + message=f"Found {len(issues)} data quality issues", + details={"issues": issues[:10]} # Limit to first 10 issues + ) + + return SmokeTestResult( + test_name="data_quality", + passed=True, + message=f"Data quality check passed for {sample_size} chunks", + details={"checked_chunks": sample_size} + ) + + except Exception as e: + return SmokeTestResult( + test_name="data_quality", + passed=False, + message=f"Error in data quality check: {str(e)}" + ) + + def run_all_tests(self, collection_name: str, chunk_metas: Optional[List[ChunkMeta]] = None) -> Dict[str, Any]: + """ + Run all smoke tests and return summary results. + + Args: + collection_name: Name of the Qdrant collection to test + chunk_metas: Optional list of ingested chunk metadata + + Returns: + Dictionary with test results and summary + """ + if chunk_metas is None: + chunk_metas = [] + + test_results = self.run_smoke_tests(collection_name, chunk_metas) + + # Calculate summary statistics + passed_tests = sum(1 for r in test_results if r.passed) + total_tests = len(test_results) + success_rate = passed_tests / total_tests if total_tests > 0 else 0 + + # Check if overall success rate meets minimum threshold + overall_passed = success_rate >= self.min_success_rate + + summary = { + "overall_passed": overall_passed, + "success_rate": success_rate, + "passed_tests": passed_tests, + "total_tests": total_tests, + "min_success_rate": self.min_success_rate, + "test_results": [ + { + "test_name": result.test_name, + "passed": result.passed, + "message": result.message, + "details": result.details + } + for result in test_results + ], + "timestamp": datetime.utcnow().isoformat() + } + + logger.info( + f"Smoke tests summary: {passed_tests}/{total_tests} passed ({success_rate:.1%}) - Overall: {'PASS' if overall_passed else 'FAIL'}") + + return summary diff --git a/pipelines/ingest/uploader.py b/pipelines/ingest/uploader.py new file mode 100644 index 0000000..f40aa3b --- /dev/null +++ b/pipelines/ingest/uploader.py @@ -0,0 +1,286 @@ +""" +Vector store uploader with idempotent operations and versioning. +Handles dense, sparse, and hybrid vector upserts to Qdrant. +""" +import uuid +import logging +import hashlib +from typing import List, Dict, Any, Optional +from datetime import datetime + +from qdrant_client import QdrantClient +from qdrant_client.http.models import PointStruct, VectorParams, SparseVectorParams, Distance +from qdrant_client.http.models import SparseVector +from qdrant_client import models as qmodels + +from database.qdrant_controller import QdrantVectorDB +from pipelines.contracts import ChunkMeta, IngestionRecord + + +logger = logging.getLogger(__name__) + + +def string_to_uuid(text: str) -> str: + """Convert a string to a deterministic UUID.""" + # Create a deterministic UUID from string using SHA256 + hash_bytes = hashlib.sha256(text.encode('utf-8')).digest()[:16] + return str(uuid.UUID(bytes=hash_bytes)) + + +class VectorStoreUploader: + """Handles idempotent vector store uploads with versioning.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.qdrant_config = config.get("qdrant", {}) + + # Collection configuration + self.collection_name = self.qdrant_config.get("collection", "default_collection") + self.dense_vector_name = self.qdrant_config.get("dense_vector_name", "dense") + self.sparse_vector_name = self.qdrant_config.get("sparse_vector_name", "sparse") + + # Upload configuration + self.batch_size = config.get("upload", {}).get("batch_size", 100) + self.wait_for_completion = config.get("upload", {}).get("wait", True) + self.enable_versioning = config.get("upload", {}).get("versioning", True) + + # Initialize Qdrant + self.vector_db = QdrantVectorDB() + self.client = self.vector_db.get_client() + + logger.info(f"Initialized uploader for collection: {self.collection_name}") + + def upload_chunks(self, chunk_metas: List[ChunkMeta]) -> IngestionRecord: + """Upload chunk metas to vector store with full lineage tracking.""" + if not chunk_metas: + logger.warning("No chunks provided for upload") + return self._create_empty_record() + + logger.info(f"Starting upload of {len(chunk_metas)} chunks to {self.collection_name}") + + # Create ingestion record + record = IngestionRecord( + dataset_name=chunk_metas[0].source, + dataset_version=chunk_metas[0].dataset_version, + config_hash=chunk_metas[0].config_hash or "unknown", + git_commit=chunk_metas[0].git_commit, + total_documents=len(set(meta.doc_id for meta in chunk_metas)), + total_chunks=len(chunk_metas), + successful_chunks=0, + failed_chunks=0, + started_at=datetime.utcnow() + ) + + # Prepare collection + self._ensure_collection_exists(chunk_metas) + + # Upload in batches + successful_count = 0 + failed_count = 0 + + for i in range(0, len(chunk_metas), self.batch_size): + batch = chunk_metas[i:i + self.batch_size] + batch_success, batch_failed = self._upload_batch(batch) + successful_count += batch_success + failed_count += batch_failed + + logger.info(f"Batch {i//self.batch_size + 1}: {batch_success} successful, {batch_failed} failed") + + # Update record + record.successful_chunks = successful_count + record.failed_chunks = failed_count + record.mark_complete() + + # Add sample IDs for verification + record.sample_doc_ids = list(set(meta.doc_id for meta in chunk_metas[:10])) + record.sample_chunk_ids = [meta.chunk_id for meta in chunk_metas[:10]] + + logger.info(f"Upload completed: {successful_count} successful, {failed_count} failed") + return record + + def _upload_batch(self, chunk_metas: List[ChunkMeta]) -> tuple[int, int]: + """Upload a batch of chunk metas.""" + points = [] + successful_count = 0 + failed_count = 0 + + for meta in chunk_metas: + try: + point = self._chunk_meta_to_point(meta) + if point: + points.append(point) + successful_count += 1 + else: + failed_count += 1 + except Exception as e: + logger.error(f"Error converting chunk {meta.chunk_id} to point: {e}") + failed_count += 1 + + # Upsert points + if points: + try: + self.client.upsert( + collection_name=self.collection_name, + points=points, + wait=self.wait_for_completion + ) + logger.debug(f"Successfully upserted {len(points)} points") + except Exception as e: + logger.error(f"Error upserting batch: {e}") + # All points in batch failed + failed_count += successful_count + successful_count = 0 + + return successful_count, failed_count + + def _chunk_meta_to_point(self, meta: ChunkMeta) -> Optional[PointStruct]: + """Convert ChunkMeta to Qdrant PointStruct.""" + try: + # Extract embeddings from ChunkMeta fields + dense_embedding = meta.dense_embedding + sparse_embedding = meta.sparse_embedding + + # Prepare vectors + vectors = {} + + if dense_embedding: + vectors[self.dense_vector_name] = dense_embedding + + if sparse_embedding: + # Handle different sparse formats + if isinstance(sparse_embedding, dict): + # Convert {token_id: weight} to SparseVector + indices = list(sparse_embedding.keys()) + values = list(sparse_embedding.values()) + vectors[self.sparse_vector_name] = SparseVector( + indices=indices, + values=values + ) + elif hasattr(sparse_embedding, 'indices') and hasattr(sparse_embedding, 'values'): + # Already in correct format + vectors[self.sparse_vector_name] = sparse_embedding + + # Prepare payload (exclude embeddings to save space) + payload = meta.to_dict() + payload.pop("metadata", None) # Remove nested metadata + + # Add computed fields + payload["embedding_types"] = list(vectors.keys()) + payload["has_dense"] = self.dense_vector_name in vectors + payload["has_sparse"] = self.sparse_vector_name in vectors + + return PointStruct( + id=string_to_uuid(meta.chunk_id), # Convert string ID to UUID + vector=vectors, + payload=payload + ) + + except Exception as e: + logger.error(f"Error creating point for chunk {meta.chunk_id}: {e}") + return None + + def _ensure_collection_exists(self, chunk_metas: List[ChunkMeta]): + """Ensure collection exists with proper configuration.""" + if self.client.collection_exists(self.collection_name): + logger.info(f"Collection {self.collection_name} already exists") + return + + # Determine vector dimensions + dense_dim = None + for meta in chunk_metas: + dense_embedding = meta.dense_embedding + if dense_embedding: + dense_dim = len(dense_embedding) + break + + if dense_dim is None: + logger.warning("No dense embeddings found, using default dimension") + dense_dim = 384 # Default dimension + + # Create collection with vector configurations + vectors_config = {} + sparse_vectors_config = {} + + # Add dense vector config if needed + if any(meta.dense_embedding for meta in chunk_metas): + vectors_config[self.dense_vector_name] = VectorParams( + size=dense_dim, + distance=Distance.COSINE + ) + + # Add sparse vector config if needed + if any(meta.sparse_embedding for meta in chunk_metas): + sparse_vectors_config[self.sparse_vector_name] = SparseVectorParams( + index=qmodels.SparseIndexParams(on_disk=False) + ) + + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=vectors_config, + sparse_vectors_config=sparse_vectors_config + ) + + logger.info(f"Created collection {self.collection_name} with {len(vectors_config)} dense and {len(sparse_vectors_config)} sparse vectors") + + def _create_empty_record(self) -> IngestionRecord: + """Create empty ingestion record for failed cases.""" + return IngestionRecord( + dataset_name="unknown", + dataset_version="unknown", + config_hash="unknown", + git_commit=None, + total_documents=0, + total_chunks=0, + successful_chunks=0, + failed_chunks=0, + started_at=datetime.utcnow() + ) + + def get_collection_info(self) -> Dict[str, Any]: + """Get information about the current collection.""" + try: + info = self.client.get_collection(self.collection_name) + return { + "collection_name": self.collection_name, + "points_count": info.points_count, + "vectors_config": info.config.params.vectors, + "sparse_vectors_config": info.config.params.sparse_vectors, + "status": info.status + } + except Exception as e: + logger.error(f"Error getting collection info: {e}") + return {} + + def verify_upload(self, sample_chunk_ids: List[str]) -> Dict[str, Any]: + """Verify that sample chunks were uploaded correctly.""" + verification_results = { + "total_samples": len(sample_chunk_ids), + "found_samples": 0, + "missing_samples": [], + "verification_passed": False + } + + try: + for chunk_id in sample_chunk_ids: + # Convert chunk_id to UUID (same as during upload) + uuid_id = string_to_uuid(chunk_id) + points = self.client.retrieve( + collection_name=self.collection_name, + ids=[uuid_id] + ) + + if points: + verification_results["found_samples"] += 1 + else: + verification_results["missing_samples"].append(chunk_id) + + # Consider verification passed if at least 90% of samples found + success_rate = verification_results["found_samples"] / verification_results["total_samples"] + verification_results["verification_passed"] = success_rate >= 0.9 + verification_results["success_rate"] = success_rate + + except Exception as e: + logger.error(f"Error during verification: {e}") + verification_results["error"] = str(e) + + return verification_results diff --git a/pipelines/ingest/validator.py b/pipelines/ingest/validator.py new file mode 100644 index 0000000..33f06e0 --- /dev/null +++ b/pipelines/ingest/validator.py @@ -0,0 +1,239 @@ +""" +Document validation and cleaning pipeline. +Implements "fail fast" validation with comprehensive checks. +""" +import re +import unicodedata +from typing import List, Set, Dict, Any +from collections import Counter + +from langchain_core.documents import Document +from pipelines.contracts import ValidationResult, normalize_text, compute_content_hash + + +class DocumentValidator: + """Validates and cleans documents before ingestion.""" + + def __init__(self, config: Dict[str, Any]): + self.min_char_length = config.get("min_char_length", 50) + self.max_char_length = config.get("max_char_length", 1_000_000) + self.min_token_estimate = config.get("min_token_estimate", 10) + self.max_token_estimate = config.get("max_token_estimate", 100_000) + self.allowed_languages = set(config.get("allowed_languages", ["en"])) + self.remove_duplicates = config.get("remove_duplicates", True) + self.normalize_unicode = config.get("normalize_unicode", True) + self.clean_html = config.get("clean_html", True) + self.preserve_code_blocks = config.get("preserve_code_blocks", True) + + # Content patterns to check + self.suspicious_patterns = [ + r"(?i)lorem ipsum", # Placeholder text + r"^.{0,10}$", # Too short content + r"^(.)\1{50,}", # Repeated characters + ] + + # Character sets - be more permissive for HTML/code content + self.allowed_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + self.allowed_chars.update(" .,!?;:()[]{}\"'-\n\t") + self.allowed_chars.update("<>/&=#@%*+_|\\`~$^") # Add common HTML/code characters + + # Seen document hashes for deduplication + self.seen_hashes: Set[str] = set() + + def validate_batch(self, documents: List[Document]) -> List[ValidationResult]: + """Validate a batch of documents.""" + results = [] + + for doc in documents: + result = self.validate_document(doc) + results.append(result) + + return results + + def validate_document(self, doc: Document) -> ValidationResult: + """Validate a single document.""" + doc_id = doc.metadata.get("external_id", "unknown") + errors = [] + warnings = [] + + # Basic content checks + content = doc.page_content or "" + if not content.strip(): + errors.append("Empty content") + + # Length checks + char_count = len(content) + if char_count < self.min_char_length: + errors.append(f"Content too short: {char_count} < {self.min_char_length} chars") + + if char_count > self.max_char_length: + errors.append(f"Content too long: {char_count} > {self.max_char_length} chars") + + # Token estimate + token_estimate = self._estimate_tokens(content) + if token_estimate < self.min_token_estimate: + warnings.append(f"Low token count estimate: {token_estimate}") + + if token_estimate > self.max_token_estimate: + warnings.append(f"High token count estimate: {token_estimate}") + + # Content quality checks + errors.extend(self._check_content_quality(content)) + warnings.extend(self._check_content_warnings(content)) + + # Duplication check + if self.remove_duplicates: + content_hash = compute_content_hash(content) + if content_hash in self.seen_hashes: + errors.append(f"Duplicate content (hash: {content_hash[:12]})") + else: + self.seen_hashes.add(content_hash) + + # Metadata validation + errors.extend(self._validate_metadata(doc.metadata)) + + return ValidationResult( + valid=len(errors) == 0, + doc_id=doc_id, + errors=errors, + warnings=warnings + ) + + def clean_document(self, doc: Document) -> Document: + """Clean and normalize document content.""" + content = doc.page_content or "" + + # Unicode normalization + if self.normalize_unicode: + content = unicodedata.normalize('NFKC', content) + + # HTML cleaning + if self.clean_html: + content = self._clean_html(content) + + # Text normalization + content = self._normalize_text(content) + + # Update document + cleaned_doc = Document( + page_content=content, + metadata=doc.metadata.copy() + ) + + # Update character count in metadata + cleaned_doc.metadata["char_count"] = len(content) + cleaned_doc.metadata["token_estimate"] = self._estimate_tokens(content) + + return cleaned_doc + + def _estimate_tokens(self, text: str) -> int: + """Rough token estimation (1 token ≈ 4 characters for English).""" + return len(text) // 4 + + def _check_content_quality(self, content: str) -> List[str]: + """Check for content quality issues.""" + errors = [] + + # Check for suspicious patterns + for pattern in self.suspicious_patterns: + if re.search(pattern, content): + errors.append(f"Suspicious pattern detected: {pattern}") + + # Check character distribution + char_counts = Counter(content) + total_chars = len(content) + + # Check for excessive repetition + if total_chars > 0: + max_char_ratio = max(count / total_chars for count in char_counts.values()) + if max_char_ratio > 0.5: # More than 50% of content is single character + errors.append(f"Excessive character repetition: {max_char_ratio:.2%}") + + # Check for valid character set + invalid_chars = set(content) - self.allowed_chars + if invalid_chars: + warnings_chars = list(invalid_chars)[:5] # Show first 5 + errors.append(f"Invalid characters found: {warnings_chars}") + + return errors + + def _check_content_warnings(self, content: str) -> List[str]: + """Check for content quality warnings.""" + warnings = [] + + # Check for very long lines (might be malformed) + lines = content.split('\n') + long_lines = [i for i, line in enumerate(lines) if len(line) > 1000] + if long_lines: + warnings.append(f"Very long lines found at positions: {long_lines[:5]}") + + # Check for unusual whitespace patterns + if re.search(r'\s{10,}', content): + warnings.append("Excessive whitespace found") + + # Check language (basic heuristic) + if self.allowed_languages and "en" in self.allowed_languages: + english_ratio = self._estimate_english_ratio(content) + if english_ratio < 0.7: + warnings.append(f"Low English content ratio: {english_ratio:.2%}") + + return warnings + + def _validate_metadata(self, metadata: Dict[str, Any]) -> List[str]: + """Validate required metadata fields.""" + errors = [] + + required_fields = ["external_id", "source"] + for field in required_fields: + if field not in metadata or not metadata[field]: + errors.append(f"Missing required metadata field: {field}") + + return errors + + def _clean_html(self, content: str) -> str: + """Basic HTML cleaning.""" + # Remove HTML tags + content = re.sub(r'<[^>]+>', ' ', content) + + # Decode common HTML entities + html_entities = { + '&': '&', + '<': '<', + '>': '>', + '"': '"', + ''': "'", + ' ': ' ' + } + + for entity, replacement in html_entities.items(): + content = content.replace(entity, replacement) + + return content + + def _normalize_text(self, content: str) -> str: + """Normalize text content.""" + # Fix multiple whitespaces + content = re.sub(r'\s+', ' ', content) + + # Fix multiple newlines (preserve paragraph structure) + content = re.sub(r'\n\s*\n\s*\n+', '\n\n', content) + + # Strip leading/trailing whitespace + content = content.strip() + + return content + + def _estimate_english_ratio(self, content: str) -> float: + """Rough estimation of English content ratio.""" + # Simple heuristic: check for common English words + english_words = { + 'the', 'and', 'to', 'of', 'a', 'in', 'is', 'it', 'you', 'that', + 'he', 'was', 'for', 'on', 'are', 'as', 'with', 'his', 'they', 'i' + } + + words = re.findall(r'\b\w+\b', content.lower()) + if not words: + return 0.0 + + english_count = sum(1 for word in words if word in english_words) + return english_count / len(words) diff --git a/playground/test_db_controller.py b/playground/test_db_controller.py deleted file mode 100644 index 48ca6db..0000000 --- a/playground/test_db_controller.py +++ /dev/null @@ -1,5 +0,0 @@ -from database import QdrantVectorDB - -if __name__ == "__main__": - db = QdrantVectorDB() - db.init_collection(vector_size=384) diff --git a/playground/test_dense_retriever.py b/playground/test_dense_retriever.py deleted file mode 100644 index ddddc66..0000000 --- a/playground/test_dense_retriever.py +++ /dev/null @@ -1,39 +0,0 @@ -from retrievers.dense_retriever import QdrantDenseRetriever -from database.qdrant_controller import QdrantVectorDB -from embedding.factory import get_embedder -from langchain_core.documents import Document -import os -import dotenv - -# Load environment variables -dotenv.load_dotenv(override=True) - -if __name__ == "__main__": - # Load embedder from env - embedder = get_embedder(name=os.getenv("DENSE_EMBEDDER")) - print(f"Using embedder: {embedder}") - # Initialize DB and LangChain-compatible vectorstore - db = QdrantVectorDB() - vectorstore = db.as_langchain_vectorstore(dense_embedding=embedder) - - # Create retriever - retriever = QdrantDenseRetriever( - embedding=embedder, - vectorstore=vectorstore, - top_k=5 - ) - - # Run query - query = "Advanced RAG Models with Graph Structures: Optimizing Complex Knowledge Reasoning and Text Generation" - results = retriever.get_relevant_documents(query) - - if not results: - print("No results found.") - else: - for i, (doc, score) in enumerate(results): - print(f"[{i+1}] Score: {score:.4f}") - print( - f"Doc ID: {doc.metadata.get('doc_id', 'N/A')} | Chunk ID: {doc.metadata.get('chunk_id', 'N/A')}") - print("Text:") - print(doc.page_content[:200]) - print("-" * 50) diff --git a/playground/test_embedding_pipeline.py b/playground/test_embedding_pipeline.py deleted file mode 100644 index cb46b2d..0000000 --- a/playground/test_embedding_pipeline.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -import uuid -import logging -from typing import List -import dotenv - -from langchain.schema import Document -from processors.dispatcher import ProcessorDispatcher -from embedding.factory import get_embedder -from embedding.recursive_splitter import RecursiveSplitter -from database.qdrant_controller import QdrantVectorDB - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def prepare_documents(texts: List[str], original_docs: List[Document]) -> List[Document]: - enriched = [] - for i, text in enumerate(texts): - src = original_docs[i % len(original_docs)] - enriched.append( - Document( - page_content=text, - metadata={ - "source": src.metadata.get("source", "unknown"), - "doc_id": src.metadata.get("doc_id", str(uuid.uuid4())), - "chunk_id": i - } - ) - ) - return enriched - - -def run_embedding_and_insert(): - dotenv.load_dotenv(override=True) - - # 1. Load + chunk - processor = ProcessorDispatcher(chunk_size=300, chunk_overlap=30) - raw_docs = processor.process_directory("sandbox") - splitter = RecursiveSplitter(chunk_size=300, chunk_overlap=30) - chunks = splitter.split(raw_docs) - documents = prepare_documents( - texts=[c.page_content for c in chunks], - original_docs=raw_docs - ) - - # 2. embedders - dense_embedder = get_embedder(os.getenv("DENSE_EMBEDDER", "hf")) - sparse_embedder = get_embedder(os.getenv("SPARSE_EMBEDDER", "bm25")) - - # 3. init Qdrant - db = QdrantVectorDB() - - # compute your dense dimension once: - dq = dense_embedder.embed_query("test") - if hasattr(dq, "shape"): - dense_dim = dq.shape[-1] - else: - dense_dim = len(dq) - - db.init_collection(dense_vector_size=dense_dim) - - # 4. insert - db.insert_documents( - documents=documents, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder - ) - print(f"Inserted {len(documents)} documents into Qdrant.") - - -if __name__ == "__main__": - run_embedding_and_insert() diff --git a/playground/test_hybrid_retriever.py b/playground/test_hybrid_retriever.py deleted file mode 100644 index 56c585e..0000000 --- a/playground/test_hybrid_retriever.py +++ /dev/null @@ -1,53 +0,0 @@ -# playground/test_hybrid_retriever.py - -import os -import dotenv - -from database.qdrant_controller import QdrantVectorDB -from retrievers.hybrid_retriever import QdrantHybridRetriever -from embedding.factory import get_embedder - -if __name__ == "__main__": - # 1. Load .env and initialize embedders - dotenv.load_dotenv(override=True) - dense_embedder_name = os.getenv("DENSE_EMBEDDER", "hf") - sparse_embedder_name = os.getenv("SPARSE_EMBEDDER", "bm25") - - dense_embedder = get_embedder(name=dense_embedder_name) - sparse_embedder = get_embedder(name=sparse_embedder_name) - print(f"Using dense embedder: {dense_embedder}") - print(f"Using sparse embedder: {sparse_embedder}") - - # 2. Initialize Qdrant database - qdrant_db = QdrantVectorDB() - qdrant_client = qdrant_db.client - collection_name = qdrant_db.get_collection_name() - - # 3. Build the hybrid retriever - hybrid_retriever = QdrantHybridRetriever( - client=qdrant_client, - collection_name=collection_name, - dense_embedding=dense_embedder, - sparse_embedding=sparse_embedder, - top_k=5 - ) - - # 4. Run a test query - test_query = ( - "Advanced RAG Models with Graph Structures:" - " Optimizing Complex Knowledge Reasoning and Text Generation" - ) - results = hybrid_retriever.retrieve(test_query) - - if not results: - print("No results found.") - else: - for i, (doc, score) in enumerate(results, start=1): - metadata = doc.metadata or {} - print(f"[{i}] score={score:.4f}") - print(f" source : {metadata.get('source', 'N/A')}") - print(f" doc_id : {metadata.get('doc_id', 'N/A')}") - print(f" chunk_id: {metadata.get('chunk_id', 'N/A')}") - print(" excerpt :", - doc.page_content[:200].replace("\n", " "), "…") - print("-" * 80) diff --git a/processors/__init__.py b/processors/__init__.py deleted file mode 100644 index b8b2b34..0000000 --- a/processors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .pdf_processor import PDFProcessor -from processors.base import BaseProcessor -from .dispatcher import ProcessorDispatcher diff --git a/processors/base.py b/processors/base.py deleted file mode 100644 index ffb2de3..0000000 --- a/processors/base.py +++ /dev/null @@ -1,36 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List -from langchain.schema import Document -from langchain.text_splitter import RecursiveCharacterTextSplitter - - -class BaseProcessor(ABC): - def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.splitter = RecursiveCharacterTextSplitter( - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - ) - self.paths: List[str] = [] - - def add_file(self, path: str): - self.paths.append(path) - - @abstractmethod - def load_single(self, path: str) -> List[Document]: - """Load one file (PDF, CSV, etc).""" - pass - - def load(self) -> List[Document]: - """Load all files added so far.""" - docs = [] - for path in self.paths: - docs.extend(self.load_single(path)) - return docs - - def chunk(self, docs: List[Document]) -> List[Document]: - return self.splitter.split_documents(docs) - - def process(self) -> List[Document]: - return self.chunk(self.load()) diff --git a/processors/dispatcher.py b/processors/dispatcher.py deleted file mode 100644 index 8c9a7a7..0000000 --- a/processors/dispatcher.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -from typing import List, Dict -from langchain.schema import Document -from processors.pdf_processor import PDFProcessor -from processors.base import BaseProcessor - - -class ProcessorDispatcher: - def __init__(self, chunk_size=500, chunk_overlap=50): - self.registry: Dict[str, BaseProcessor] = { - ".pdf": PDFProcessor(chunk_size=chunk_size, chunk_overlap=chunk_overlap), - # ".csv": CSVProcessor(...), - # ".txt": TextProcessor(...), - } - - def process_directory(self, root: str) -> List[Document]: - for dirpath, _, filenames in os.walk(root): - for file in filenames: - ext = os.path.splitext(file)[1].lower() - processor = self.registry.get(ext) - if processor: - full_path = os.path.join(dirpath, file) - processor.add_file(full_path) - - all_docs = [] - for processor in self.registry.values(): - all_docs.extend(processor.process()) - return all_docs diff --git a/processors/pdf_processor.py b/processors/pdf_processor.py deleted file mode 100644 index 16dacfd..0000000 --- a/processors/pdf_processor.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import List -from langchain_community.document_loaders import PyPDFLoader -from langchain.schema import Document -from processors.base import BaseProcessor - - -class PDFProcessor(BaseProcessor): - def load_single(self, path: str) -> List[Document]: - loader = PyPDFLoader(path) - return loader.load() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d2c15ad --- /dev/null +++ b/pytest.ini @@ -0,0 +1,34 @@ +[pytest] +markers = + integration: marks tests as integration tests (requires external services) + slow: marks tests as slow running + requires_db: marks tests that require database connectivity + requires_api: marks tests that require API keys/external services + end_to_end: marks tests that test complete pipeline functionality + +# Test discovery +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output options +addopts = + -v + --strict-markers + --tb=short + --maxfail=5 + +# Logging +log_cli = false +log_cli_level = INFO +log_cli_format = %(asctime)s [%(levelname)8s] %(name)s: %(message)s +log_cli_date_format = %Y-%m-%d %H:%M:%S + +# Coverage options (when using --cov) +norecursedirs = .git .venv venv env build dist *.egg + +# Filtering +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/requirements.txt b/requirements.txt index 8d1e49f..3a3b775 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/retrievers/__init__.py b/retrievers/__init__.py index d1d995d..db5a7d4 100644 --- a/retrievers/__init__.py +++ b/retrievers/__init__.py @@ -1,2 +1,21 @@ -from .dense_retriever import QdrantDenseRetriever -from .hybrid_retriever import QdrantHybridRetriever +""" +Modern retriever implementations that integrate with the retrieval pipeline. +These retrievers implement the BaseRetriever interface from components.retrieval_pipeline. +""" + +from .base_retriever import ModernBaseRetriever +from .dense_retriever import QdrantDenseRetriever, DenseRetriever +from .sparse_retriever import QdrantSparseRetriever, SparseRetriever +from .hybrid_retriever import QdrantHybridRetriever, HybridRetriever +from .semantic_retriever import SemanticRetriever + +__all__ = [ + "ModernBaseRetriever", + "QdrantDenseRetriever", + "DenseRetriever", + "QdrantSparseRetriever", + "SparseRetriever", + "QdrantHybridRetriever", + "HybridRetriever", + "SemanticRetriever" +] diff --git a/retrievers/base.py b/retrievers/base.py deleted file mode 100644 index 5ec2cff..0000000 --- a/retrievers/base.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List -from langchain_core.documents import Document - - -class BaseRetriever(ABC): - """ - Abstract base class for all retrievers. - """ - - @abstractmethod - def retrieve(self, query: str, k: int = 5) -> List[Document]: - """ - Retrieve relevant documents for a given query. - - Args: - query (str): The input query string. - k (int, optional): Number of documents to retrieve. Defaults to 5. - - Returns: - List[Document]: List of retrieved documents. - """ - pass - - def get_relevant_documents(self, query: str) -> List[Document]: - """ - Compatibility method for LangChain retriever interface. - - Args: - query (str): The input query string. - - Returns: - List[Document]: List of retrieved documents. - """ - return self.retrieve(query=query) diff --git a/retrievers/base_retriever.py b/retrievers/base_retriever.py new file mode 100644 index 0000000..ce69838 --- /dev/null +++ b/retrievers/base_retriever.py @@ -0,0 +1,170 @@ +""" +Modern base retriever that integrates with the retrieval pipeline architecture. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from components.retrieval_pipeline import BaseRetriever, RetrievalResult +from langchain_core.documents import Document +import logging + +logger = logging.getLogger(__name__) + + +class ModernBaseRetriever(BaseRetriever): + """ + Modern base retriever implementing the retrieval pipeline interface. + All modern retrievers should inherit from this class. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize the retriever with configuration. + + Args: + config: Configuration dictionary containing retrieval parameters + """ + self.config = config + self.top_k = config.get("top_k", 5) + self.score_threshold = config.get("score_threshold", 0.0) + self._initialized = False # Track initialization state + + # Initialize any common components here + self._initialize_components() + + def _initialize_components(self): + """Initialize common components. Override in subclasses.""" + # Enable lazy initialization by default + self._lazy_init = self.config.get('lazy_initialization', True) + + # If lazy initialization is disabled, subclasses should override this + if not self._lazy_init: + # Actual initialization happens in subclasses + pass + + @abstractmethod + def _perform_search(self, query: str, k: int) -> List[RetrievalResult]: + """ + Perform the actual search operation. + + Args: + query: Search query + k: Number of results to retrieve + + Returns: + List of RetrievalResult objects + """ + pass + + def retrieve(self, query: str, k: int = None) -> List[RetrievalResult]: + """ + Retrieve documents for the given query. + + Args: + query: Search query + k: Number of results to retrieve (defaults to configured top_k) + + Returns: + List of RetrievalResult objects + """ + if k is None: + k = self.top_k + + logger.debug(f"Retrieving {k} results for query: {query[:50]}...") + + try: + # Perform the search + results = self._perform_search(query, k) + + # Apply score threshold filtering + if self.score_threshold > 0: + results = [r for r in results if r.score >= + self.score_threshold] + logger.debug( + f"Filtered to {len(results)} results above threshold {self.score_threshold}") + + # Ensure we don't return more than requested + results = results[:k] + + logger.debug(f"Retrieved {len(results)} results") + return results + + except Exception as e: + logger.error(f"Error during retrieval: {e}") + return [] + + def _create_retrieval_result( + self, + document: Document, + score: float, + additional_metadata: Dict[str, Any] = None + ) -> RetrievalResult: + """ + Create a RetrievalResult object with proper metadata. + + Args: + document: The retrieved document + score: Relevance score + additional_metadata: Additional metadata to include + + Returns: + RetrievalResult object + """ + metadata = { + "retriever_config": self.config, + "retrieval_timestamp": None, # Can add timestamp if needed + } + + if additional_metadata: + metadata.update(additional_metadata) + + return RetrievalResult( + document=document, + score=score, + retrieval_method=self.component_name, + metadata=metadata + ) + + def _normalize_scores(self, results: List[RetrievalResult]) -> List[RetrievalResult]: + """ + Normalize scores to [0, 1] range. + + Args: + results: List of retrieval results + + Returns: + Results with normalized scores + """ + if not results: + return results + + scores = [r.score for r in results] + min_score = min(scores) + max_score = max(scores) + + if max_score == min_score: + # All scores are the same + for result in results: + result.score = 1.0 + else: + # Normalize to [0, 1] + for result in results: + result.score = (result.score - min_score) / \ + (max_score - min_score) + + return results + + def _validate_config(self, required_keys: List[str]): + """ + Validate that required configuration keys are present. + + Args: + required_keys: List of required configuration keys + + Raises: + ValueError: If required keys are missing + """ + missing_keys = [key for key in required_keys if key not in self.config] + if missing_keys: + raise ValueError( + f"Missing required configuration keys: {missing_keys}") diff --git a/retrievers/dense_retriever.py b/retrievers/dense_retriever.py index ee2b5f9..cc1da03 100644 --- a/retrievers/dense_retriever.py +++ b/retrievers/dense_retriever.py @@ -1,24 +1,169 @@ -from qdrant_client import QdrantClient +""" +Modern dense retriever that integrates with the retrieval pipeline architecture. +""" + +from typing import List, Dict, Any, Optional from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain_qdrant import QdrantVectorStore, RetrievalMode -from typing import List, Tuple -from .base import BaseRetriever - - -class QdrantDenseRetriever(BaseRetriever): - def __init__( - self, - embedding: Embeddings, - vectorstore: QdrantVectorStore, - top_k: int = 5, - ): - self.embedding = embedding - self.vectorstore = vectorstore - self.top_k = top_k - - def retrieve(self, query: str, k: int = None) -> List[Tuple[Document, float]]: - return self.vectorstore.similarity_search_with_score(query, k=k or self.top_k) - - def get_relevant_documents(self, query: str) -> List[Tuple[Document, float]]: - return self.retrieve(query) +from langchain_core.documents import Document +from langchain_qdrant import QdrantVectorStore +from components.retrieval_pipeline import RetrievalResult +from .base_retriever import ModernBaseRetriever +import logging + +logger = logging.getLogger(__name__) + + +class QdrantDenseRetriever(ModernBaseRetriever): + """ + Dense vector retriever using Qdrant and LangChain. + Performs semantic similarity search using dense embeddings. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize the dense retriever. + + Args: + config: Configuration dictionary containing: + - embedding: Dense embedding configuration + - qdrant: Qdrant database configuration + - top_k: Number of results to retrieve (default: 5) + - score_threshold: Minimum score threshold (default: 0.0) + """ + super().__init__(config) + + # Validate required configuration + if 'embedding' not in config: + logger.warning("No embedding config found, using default") + if 'qdrant' not in config: + logger.warning("No qdrant config found, using default") + + # Initialize components lazily + self.embedding = None + self.vectorstore = None + self._initialized = False + + def _initialize_components(self): + """Initialize embedding and vector store components.""" + if self._initialized: + return + + try: + # Initialize embedding - get the dense embedding config + from embedding.factory import get_embedder + + # Extract dense embedding config from the main config + embedding_section = self.config.get('embedding', {}) + if 'dense' in embedding_section: + embedding_config = embedding_section['dense'] + else: + # Fallback to legacy config or default + embedding_config = embedding_section or { + 'type': 'sentence_transformers', + 'model': 'sentence-transformers/all-MiniLM-L6-v2' + } + + self.embedding = get_embedder(embedding_config) + + # Initialize Qdrant vector store + from database.qdrant_controller import QdrantVectorDB + qdrant_db = QdrantVectorDB(config=self.config) + self.vectorstore = qdrant_db.as_langchain_vectorstore( + dense_embedding=self.embedding + ) + + self._initialized = True + logger.info( + f"Dense retriever initialized with embedding: {type(self.embedding).__name__}") + + except Exception as e: + logger.error( + f"Failed to initialize dense retriever components: {e}") + # Don't raise, just mark as failed to initialize + self._initialized = False + + @property + def component_name(self) -> str: + return "dense_retriever" + + def _perform_search(self, query: str, k: int) -> List[RetrievalResult]: + """ + Perform dense similarity search using direct Qdrant API to preserve external_id. + + Args: + query: Search query + k: Number of results to retrieve + + Returns: + List of RetrievalResult objects + """ + if not self._initialized: + self._initialize_components() + + if not self._initialized: + logger.warning( + "Dense retriever not properly initialized, returning empty results") + return [] + + try: + # Get query embedding + query_vector = self.embedding.embed_query(query) + + # Get Qdrant database instance + from database.qdrant_controller import QdrantVectorDB + qdrant_db = QdrantVectorDB(config=self.config) + + # Direct Qdrant search to preserve external_id + from qdrant_client.models import NamedVector + + search_result = qdrant_db.client.search( + collection_name=qdrant_db.collection_name, + query_vector=NamedVector( + name=qdrant_db.dense_vector_name, + vector=query_vector + ), + limit=k, + with_payload=True # Include all payload data including external_id + ) + + # Convert to RetrievalResult objects + retrieval_results = [] + for result in search_result: + payload = result.payload or {} + + # Create document with preserved external_id + document = Document( + page_content=payload.get('page_content', ''), + metadata={ + **payload.get('metadata', {}), + # Ensure external_id is in metadata + 'external_id': payload.get('external_id'), + # Also store the Qdrant UUID for reference + 'qdrant_id': str(result.id) + } + ) + + retrieval_result = self._create_retrieval_result( + document=document, + score=result.score, + additional_metadata={ + 'search_type': 'dense_similarity', + 'embedding_model': type(self.embedding).__name__, + # Also add to retrieval metadata + 'external_id': payload.get('external_id') + } + ) + retrieval_results.append(retrieval_result) + + # Normalize scores for consistency + retrieval_results = self._normalize_scores(retrieval_results) + + return retrieval_results + + except Exception as e: + logger.error(f"Error during dense search: {e}") + return [] + + +# Backward compatibility alias +DenseRetriever = QdrantDenseRetriever diff --git a/retrievers/hybrid_retriever.py b/retrievers/hybrid_retriever.py index ef1271f..b05b34e 100644 --- a/retrievers/hybrid_retriever.py +++ b/retrievers/hybrid_retriever.py @@ -1,62 +1,401 @@ -from typing import List, Tuple, Optional -from qdrant_client import QdrantClient +""" +Modern hybrid retriever that integrates with the retrieval pipeline architecture. +""" + +from typing import List, Dict, Any, Optional from langchain.embeddings.base import Embeddings -from langchain.schema import Document +from langchain_core.documents import Document from langchain_qdrant import QdrantVectorStore, RetrievalMode -from .base import BaseRetriever +from components.retrieval_pipeline import RetrievalResult +from .base_retriever import ModernBaseRetriever +import logging + +logger = logging.getLogger(__name__) -class QdrantHybridRetriever(BaseRetriever): +class QdrantHybridRetriever(ModernBaseRetriever): """ - A retriever that uses Qdrant's hybrid (dense + sparse) search. + Hybrid retriever combining dense and sparse vector search using Qdrant. + Leverages both semantic similarity (dense) and keyword matching (sparse). """ - def __init__( - self, - client: QdrantClient, - collection_name: str, - dense_embedding: Embeddings, - sparse_embedding: Embeddings, - *, - top_k: int = 5, - dense_vector_name: str = "dense", - sparse_vector_name: str = "sparse", - ): + def __init__(self, config: Dict[str, Any]): """ + Initialize the hybrid retriever. + Args: - client: an initialized QdrantClient - collection_name: the Qdrant collection to query - dense_embedding: the Embeddings instance for dense vectors - sparse_embedding: the Embeddings instance for sparse vectors - top_k: number of hits to return by default - dense_vector_name: name of the dense vector field in Qdrant - sparse_vector_name: name of the sparse vector field in Qdrant - """ - self.top_k = top_k - self.vs = QdrantVectorStore( - client=client, - collection_name=collection_name, - embedding=dense_embedding, - vector_name=dense_vector_name, - sparse_embedding=sparse_embedding, - sparse_vector_name=sparse_vector_name, - retrieval_mode=RetrievalMode.HYBRID, - ) - - def retrieve( - self, query: str, *, k: Optional[int] = None - ) -> List[Tuple[Document, float]]: + config: Configuration dictionary containing: + - embedding: Configuration with both dense and sparse embeddings + - qdrant: Qdrant database configuration + - top_k: Number of results to retrieve (default: 5) + - score_threshold: Minimum score threshold (default: 0.0) + - fusion_method: How to combine dense/sparse scores (default: 'rrf') """ - Returns top-k (Document, score) pairs according to hybrid search. - """ - return self.vs.similarity_search_with_score( - query, - k=k or self.top_k - ) + super().__init__(config) + + # Validate required configuration + if 'embedding' not in config: + logger.warning("No embedding config found, using defaults") + if 'qdrant' not in config: + logger.warning("No qdrant config found, using defaults") + + # Initialize components + self.dense_embedding = None + self.sparse_embedding = None + self.vectorstore = None + self.fusion_method = config.get( + 'fusion_method', 'rrf') # Reciprocal Rank Fusion + + # Load fusion parameters from config + fusion_config = config.get('fusion', {}) + self.rrf_k = fusion_config.get('rrf_k', 60) # Standard RRF constant + self.dense_weight = fusion_config.get('dense_weight', 0.5) + self.sparse_weight = fusion_config.get('sparse_weight', 0.5) + self._initialized = False + + def _initialize_components(self): + """Initialize embeddings and vector store components.""" + if self._initialized: + return + + try: + # Initialize embeddings + from embedding.factory import get_embedder + + embedding_section = self.config.get('embedding', {}) + + # Extract dense and sparse embedding configs + if 'dense' in embedding_section: + dense_config = embedding_section['dense'] + else: + # Default dense embedding config + dense_config = { + 'provider': 'google', + 'model': 'models/embedding-001', + 'dimensions': 768, + 'api_key_env': 'GOOGLE_API_KEY' + } + + if 'sparse' in embedding_section: + sparse_config = embedding_section['sparse'] + else: + # Default sparse embedding config + sparse_config = { + 'provider': 'sparse', + 'model': 'Qdrant/bm25', + 'vector_name': 'sparse' + } + + self.dense_embedding = get_embedder(dense_config) + self.sparse_embedding = get_embedder(sparse_config) + + # Initialize Qdrant database + from database.qdrant_controller import QdrantVectorDB + qdrant_db = QdrantVectorDB(config=self.config) + + # Store qdrant_db for direct API access + self.qdrant_db = qdrant_db - def get_relevant_documents(self, query: str) -> List[Document]: + self._initialized = True + logger.info(f"Hybrid retriever initialized with dense: {type(self.dense_embedding).__name__}, " + f"sparse: {type(self.sparse_embedding).__name__}") + + except Exception as e: + logger.error( + f"Failed to initialize hybrid retriever components: {e}") + import traceback + traceback.print_exc() + self._initialized = False + + @property + def component_name(self) -> str: + return "hybrid_retriever" + + def _perform_search(self, query: str, k: int) -> List[RetrievalResult]: """ - Returns just the Documents (no scores). + Perform hybrid search combining dense and sparse retrieval. + + Args: + query: Search query + k: Number of results to retrieve + + Returns: + List of RetrievalResult objects """ - hits = self.retrieve(query) - return [doc for doc, _ in hits] + if not self._initialized: + self._initialize_components() + + if not self._initialized: + logger.warning( + "Hybrid retriever not properly initialized, returning empty results") + return [] + + try: + # Perform separate dense and sparse searches then combine + dense_results = self._perform_dense_search(query, k) + sparse_results = self._perform_sparse_search(query, k) + + # Combine results using Reciprocal Rank Fusion (RRF) + combined_results = self._fuse_results( + dense_results, sparse_results, k) + + return combined_results + + except Exception as e: + logger.error(f"Error during hybrid search: {e}") + import traceback + traceback.print_exc() + return [] + + def _perform_dense_search(self, query: str, k: int) -> List[RetrievalResult]: + """Perform dense search using direct Qdrant API.""" + try: + # Get dense query vector + query_vector = self.dense_embedding.embed_query(query) + + # Direct Qdrant search for dense vectors + from qdrant_client.models import NamedVector + + search_result = self.qdrant_db.client.search( + collection_name=self.qdrant_db.collection_name, + query_vector=NamedVector( + name=self.qdrant_db.dense_vector_name, + vector=query_vector + ), + limit=k, + with_payload=True + ) + + # Convert to RetrievalResult objects + results = [] + for result in search_result: + payload = result.payload or {} + + document = Document( + page_content=payload.get('page_content', ''), + metadata={ + **payload.get('metadata', {}), + 'external_id': payload.get('external_id'), + 'qdrant_id': str(result.id) + } + ) + + retrieval_result = self._create_retrieval_result( + document=document, + score=result.score, + additional_metadata={ + 'search_type': 'dense_component', + 'embedding_model': type(self.dense_embedding).__name__, + 'external_id': payload.get('external_id') + } + ) + results.append(retrieval_result) + + return results + + except Exception as e: + logger.error(f"Dense search component failed: {e}") + return [] + + def _perform_sparse_search(self, query: str, k: int) -> List[RetrievalResult]: + """Perform sparse search using direct Qdrant API.""" + try: + # Get sparse query vector + if hasattr(self.sparse_embedding, 'embed_query'): + query_vector = self.sparse_embedding.embed_query(query) + else: + query_vector = self.sparse_embedding.embed_documents([query])[ + 0] + + # Convert sparse dict to Qdrant sparse vector format for named sparse vectors + if isinstance(query_vector, dict): + from qdrant_client.models import NamedSparseVector + + search_result = self.qdrant_db.client.search( + collection_name=self.qdrant_db.collection_name, + query_vector=NamedSparseVector( + name=self.qdrant_db.sparse_vector_name, + vector={"indices": list(query_vector.keys()), "values": list( + query_vector.values())} + ), + limit=k, + with_payload=True + ) + else: + # Dense vector format (list) - fallback + from qdrant_client.models import NamedVector + + search_result = self.qdrant_db.client.search( + collection_name=self.qdrant_db.collection_name, + query_vector=NamedVector( + name=self.qdrant_db.sparse_vector_name, + vector=query_vector + ), + limit=k, + with_payload=True + ) + + # Convert to RetrievalResult objects + results = [] + for result in search_result: + payload = result.payload or {} + + document = Document( + page_content=payload.get('page_content', ''), + metadata={ + **payload.get('metadata', {}), + 'external_id': payload.get('external_id'), + 'qdrant_id': str(result.id) + } + ) + + retrieval_result = self._create_retrieval_result( + document=document, + score=result.score, + additional_metadata={ + 'search_type': 'sparse_component', + 'embedding_model': type(self.sparse_embedding).__name__, + 'external_id': payload.get('external_id') + } + ) + results.append(retrieval_result) + + return results + + except Exception as e: + logger.error(f"Sparse search component failed: {e}") + return [] + + def _fuse_results(self, dense_results: List[RetrievalResult], sparse_results: List[RetrievalResult], k: int) -> List[RetrievalResult]: + """Combine dense and sparse results using standard fusion methods.""" + try: + if self.fusion_method == 'rrf': + return self._fuse_with_rrf(dense_results, sparse_results, k) + elif self.fusion_method == 'weighted_sum': + return self._fuse_with_weighted_sum(dense_results, sparse_results, k) + else: + logger.warning( + f"Unknown fusion method: {self.fusion_method}, falling back to RRF") + return self._fuse_with_rrf(dense_results, sparse_results, k) + except Exception as e: + logger.error(f"Result fusion failed: {e}") + # Fallback to dense results + return dense_results[:k] + + def _fuse_with_rrf(self, dense_results: List[RetrievalResult], sparse_results: List[RetrievalResult], k: int) -> List[RetrievalResult]: + """Standard Reciprocal Rank Fusion (Cormack et al. 2009).""" + doc_scores = {} + rrf_k = self.rrf_k # Use configurable RRF constant + + # Add dense results with standard RRF scoring + for rank, result in enumerate(dense_results, 1): + doc_id = result.document.metadata.get('external_id') + if doc_id: + rrf_score = 1.0 / (rrf_k + rank) # Standard RRF formula + if doc_id not in doc_scores: + doc_scores[doc_id] = { + 'result': result, 'dense_score': rrf_score, 'sparse_score': 0} + else: + doc_scores[doc_id]['dense_score'] = rrf_score + + # Add sparse results with standard RRF scoring + for rank, result in enumerate(sparse_results, 1): + doc_id = result.document.metadata.get('external_id') + if doc_id: + rrf_score = 1.0 / (rrf_k + rank) # Standard RRF formula + if doc_id not in doc_scores: + doc_scores[doc_id] = { + 'result': result, 'dense_score': 0, 'sparse_score': rrf_score} + else: + doc_scores[doc_id]['sparse_score'] = rrf_score + + # Combine scores and sort + combined_results = [] + for doc_id, scores in doc_scores.items(): + combined_score = scores['dense_score'] + scores['sparse_score'] + result = scores['result'] + result.score = combined_score + result.metadata.update({ + 'search_type': 'hybrid_rrf', + 'dense_rrf_score': scores['dense_score'], + 'sparse_rrf_score': scores['sparse_score'], + 'fusion_method': 'rrf', + 'rrf_k': rrf_k + }) + combined_results.append(result) + + # Sort by combined score and return top k + combined_results.sort(key=lambda x: x.score, reverse=True) + return combined_results[:k] + + def _fuse_with_weighted_sum(self, dense_results: List[RetrievalResult], sparse_results: List[RetrievalResult], k: int) -> List[RetrievalResult]: + """Weighted sum fusion with score normalization.""" + # Get weights from config + dense_weight = self.dense_weight + sparse_weight = self.sparse_weight + + # Normalize weights to sum to 1 + total_weight = dense_weight + sparse_weight + if total_weight > 0: + dense_weight /= total_weight + sparse_weight /= total_weight + else: + dense_weight = sparse_weight = 0.5 + + # Normalize scores using min-max normalization + def normalize_scores(results): + if not results: + return {} + scores = [r.score for r in results] + min_score, max_score = min(scores), max(scores) + score_range = max_score - min_score + + normalized = {} + for result in results: + doc_id = result.document.metadata.get('external_id') + if doc_id and score_range > 0: + normalized[doc_id] = { + 'result': result, + 'score': (result.score - min_score) / score_range + } + elif doc_id: + normalized[doc_id] = {'result': result, 'score': 1.0} + return normalized + + dense_normalized = normalize_scores(dense_results) + sparse_normalized = normalize_scores(sparse_results) + + # Combine normalized scores + doc_scores = {} + all_doc_ids = set(dense_normalized.keys()) | set( + sparse_normalized.keys()) + + for doc_id in all_doc_ids: + dense_score = dense_normalized.get(doc_id, {}).get('score', 0.0) + sparse_score = sparse_normalized.get(doc_id, {}).get('score', 0.0) + + combined_score = dense_weight * dense_score + sparse_weight * sparse_score + + # Use the result from whichever retriever found this document + result = (dense_normalized.get(doc_id) + or sparse_normalized.get(doc_id))['result'] + result.score = combined_score + result.metadata.update({ + 'search_type': 'hybrid_weighted', + 'dense_weight': dense_weight, + 'sparse_weight': sparse_weight, + 'dense_norm_score': dense_score, + 'sparse_norm_score': sparse_score, + 'fusion_method': 'weighted_sum' + }) + + doc_scores[doc_id] = result + + # Sort by combined score and return top k + combined_results = list(doc_scores.values()) + combined_results.sort(key=lambda x: x.score, reverse=True) + return combined_results[:k] + + +# Backward compatibility alias +HybridRetriever = QdrantHybridRetriever diff --git a/retrievers/semantic_retriever.py b/retrievers/semantic_retriever.py new file mode 100644 index 0000000..296df83 --- /dev/null +++ b/retrievers/semantic_retriever.py @@ -0,0 +1,277 @@ +""" +Modern semantic retriever that integrates with the retrieval pipeline architecture. +""" + +from typing import List, Dict, Any, Optional +from langchain_core.documents import Document +from components.retrieval_pipeline import RetrievalResult +from .base_retriever import ModernBaseRetriever +import logging + +logger = logging.getLogger(__name__) + + +class SemanticRetriever(ModernBaseRetriever): + """ + Advanced semantic retriever that can use multiple retrieval strategies + and combine them intelligently based on query analysis. + + This retriever can: + - Analyze query intent and complexity + - Route to appropriate retrieval strategies + - Combine multiple retrieval methods + - Apply semantic post-processing + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize the semantic retriever. + + Args: + config: Configuration dictionary containing: + - strategies: List of retrieval strategies to use + - query_analyzer: Configuration for query analysis + - routing_rules: Rules for strategy selection + - top_k: Number of results to retrieve (default: 5) + - score_threshold: Minimum score threshold (default: 0.0) + """ + super().__init__(config) + + # Initialize retrieval strategies + self.strategies = {} + self.strategy_weights = {} + self.query_analyzer = None + + # Configuration + self.routing_rules = config.get('routing_rules', {}) + self.default_strategy = config.get('default_strategy', 'hybrid') + + def _initialize_components(self): + """Initialize retrieval strategies and query analyzer.""" + try: + strategies_config = self.config.get('strategies', {}) + + # Initialize available strategies + for strategy_name, strategy_config in strategies_config.items(): + if strategy_config.get('enabled', True): + strategy = self._create_strategy( + strategy_name, strategy_config) + if strategy: + self.strategies[strategy_name] = strategy + self.strategy_weights[strategy_name] = strategy_config.get( + 'weight', 1.0) + + # Initialize query analyzer if configured + analyzer_config = self.config.get('query_analyzer', {}) + if analyzer_config.get('enabled', False): + self.query_analyzer = self._create_query_analyzer( + analyzer_config) + + logger.info( + f"Semantic retriever initialized with strategies: {list(self.strategies.keys())}") + + except Exception as e: + logger.error( + f"Failed to initialize semantic retriever components: {e}") + # Initialize empty strategies to prevent further errors + if not hasattr(self, 'strategies'): + self.strategies = {} + if not hasattr(self, 'strategy_weights'): + self.strategy_weights = {} + + def _create_strategy(self, strategy_name: str, strategy_config: Dict[str, Any]) -> Optional[ModernBaseRetriever]: + """Create a retrieval strategy instance.""" + try: + # Merge global config with strategy-specific config + merged_config = self.config.copy() + merged_config.update(strategy_config) + + if strategy_name == 'dense': + from .dense_retriever import QdrantDenseRetriever + return QdrantDenseRetriever(merged_config) + elif strategy_name == 'sparse': + from .sparse_retriever import QdrantSparseRetriever + return QdrantSparseRetriever(merged_config) + elif strategy_name == 'hybrid': + from .hybrid_retriever import QdrantHybridRetriever + return QdrantHybridRetriever(merged_config) + else: + logger.warning(f"Unknown strategy: {strategy_name}") + return None + + except Exception as e: + logger.warning(f"Failed to create strategy {strategy_name}: {e}") + return None + + def _create_query_analyzer(self, analyzer_config: Dict[str, Any]): + """Create query analyzer for intelligent routing.""" + # Placeholder for advanced query analysis + # Could integrate with LLMs, query classification models, etc. + return None + + @property + def component_name(self) -> str: + return "semantic_retriever" + + def _perform_search(self, query: str, k: int) -> List[RetrievalResult]: + """ + Perform semantic search using intelligent strategy selection. + + Args: + query: Search query + k: Number of results to retrieve + + Returns: + List of RetrievalResult objects + """ + if not self.strategies: + self._initialize_components() + + try: + # Analyze query to determine best strategies + selected_strategies = self._select_strategies(query) + + if not selected_strategies: + logger.warning("No strategies selected, using default") + selected_strategies = [self.default_strategy] if self.default_strategy in self.strategies else list( + self.strategies.keys())[:1] + + # If still no strategies, return empty + if not selected_strategies: + logger.warning("No strategies available") + return [] + + # Perform retrieval with selected strategies + if len(selected_strategies) == 1: + # Single strategy + strategy_name = selected_strategies[0] + if strategy_name not in self.strategies: + logger.warning(f"Strategy {strategy_name} not available") + return [] + + strategy = self.strategies[strategy_name] + results = strategy._perform_search(query, k) + + # Update retrieval method in metadata + for result in results: + result.retrieval_method = f"semantic_{strategy_name}" + + return results + else: + # Multiple strategies - combine results + return self._combine_strategy_results(query, selected_strategies, k) + + except Exception as e: + logger.error(f"Error during semantic search: {e}") + return [] + + def _select_strategies(self, query: str) -> List[str]: + """ + Select appropriate retrieval strategies based on query analysis. + + Args: + query: Search query + + Returns: + List of strategy names to use + """ + # Simple rule-based selection for now + # Could be enhanced with ML-based query classification + + query_lower = query.lower() + query_length = len(query.split()) + + # Default to hybrid for most queries + selected = [] + + # Rule-based strategy selection + if query_length <= 3: + # Short queries - prefer sparse (keyword matching) + if 'sparse' in self.strategies: + selected.append('sparse') + elif any(keyword in query_lower for keyword in ['how to', 'what is', 'explain', 'describe']): + # Conceptual queries - prefer dense (semantic similarity) + if 'dense' in self.strategies: + selected.append('dense') + else: + # Default to hybrid for balanced approach + if 'hybrid' in self.strategies: + selected.append('hybrid') + elif 'dense' in self.strategies and 'sparse' in self.strategies: + # Fallback to combining dense and sparse + selected.extend(['dense', 'sparse']) + + # Apply routing rules from config + for rule_query, rule_strategies in self.routing_rules.items(): + if rule_query.lower() in query_lower: + selected = rule_strategies + break + + # Ensure selected strategies exist + selected = [s for s in selected if s in self.strategies] + + logger.debug( + f"Selected strategies for query '{query[:50]}...': {selected}") + return selected + + def _combine_strategy_results(self, query: str, strategies: List[str], k: int) -> List[RetrievalResult]: + """ + Combine results from multiple strategies using fusion techniques. + + Args: + query: Search query + strategies: List of strategy names + k: Number of final results + + Returns: + Combined and ranked results + """ + all_results = {} # document_id -> RetrievalResult + strategy_results = {} + + # Collect results from each strategy + for strategy_name in strategies: + if strategy_name not in self.strategies: + continue + + strategy = self.strategies[strategy_name] + results = strategy._perform_search( + query, k * 2) # Get more results for fusion + strategy_results[strategy_name] = results + + for result in results: + doc_id = self._get_document_id(result.document) + if doc_id not in all_results: + all_results[doc_id] = result + # Update metadata to reflect semantic fusion + result.retrieval_method = f"semantic_fusion" + if 'fusion_strategies' not in result.metadata: + result.metadata['fusion_strategies'] = [] + result.metadata['fusion_strategies'].append(strategy_name) + else: + # Combine scores using weighted average + existing = all_results[doc_id] + weight1 = self.strategy_weights.get( + existing.metadata['fusion_strategies'][-1], 1.0) + weight2 = self.strategy_weights.get(strategy_name, 1.0) + + combined_score = ( + existing.score * weight1 + result.score * weight2) / (weight1 + weight2) + existing.score = combined_score + existing.metadata['fusion_strategies'].append( + strategy_name) + + # Rank and return top k results + final_results = list(all_results.values()) + final_results.sort(key=lambda x: x.score, reverse=True) + + return final_results[:k] + + def _get_document_id(self, document: Document) -> str: + """Get a unique identifier for a document.""" + # Use external_id if available, otherwise use content hash + if hasattr(document, 'metadata') and 'external_id' in document.metadata: + return document.metadata['external_id'] + else: + import hashlib + return hashlib.md5(document.page_content.encode()).hexdigest() diff --git a/retrievers/sparse_retriever.py b/retrievers/sparse_retriever.py new file mode 100644 index 0000000..496f7e3 --- /dev/null +++ b/retrievers/sparse_retriever.py @@ -0,0 +1,188 @@ +""" +Modern sparse retriever that integrates with the retrieval pipeline architecture. +""" + +from typing import List, Dict, Any, Optional +from langchain.embeddings.base import Embeddings +from langchain_core.documents import Document +from langchain_qdrant import QdrantVectorStore, RetrievalMode +from components.retrieval_pipeline import RetrievalResult +from .base_retriever import ModernBaseRetriever +import logging + +logger = logging.getLogger(__name__) + + +class QdrantSparseRetriever(ModernBaseRetriever): + """ + Sparse vector retriever using Qdrant. + Performs keyword-based search using sparse embeddings (e.g., SPLADE, BGE-M3). + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize the sparse retriever. + + Args: + config: Configuration dictionary containing: + - embedding: Sparse embedding configuration + - qdrant: Qdrant database configuration + - top_k: Number of results to retrieve (default: 5) + - score_threshold: Minimum score threshold (default: 0.0) + """ + super().__init__(config) + + # Validate required configuration + if 'embedding' not in config: + logger.warning("No embedding config found, using default") + if 'qdrant' not in config: + logger.warning("No qdrant config found, using default") + + # Initialize components lazily + self.embedding = None + self.vectorstore = None + self._initialized = False + + def _initialize_components(self): + """Initialize embedding and vector store components.""" + if self._initialized: + return + + try: + # Initialize sparse embedding - get the sparse embedding config + from embedding.factory import get_embedder + + # Extract sparse embedding config from the main config + embedding_section = self.config.get('embedding', {}) + if 'sparse' in embedding_section: + embedding_config = embedding_section['sparse'] + else: + # Default sparse embedding config + embedding_config = { + 'provider': 'sparse', + 'model': 'Qdrant/bm25', + 'vector_name': 'sparse' + } + + self.embedding = get_embedder(embedding_config) + + # Initialize Qdrant database + from database.qdrant_controller import QdrantVectorDB + qdrant_db = QdrantVectorDB(config=self.config) + + # Store qdrant_db for direct API access + self.qdrant_db = qdrant_db + + self._initialized = True + logger.info( + f"Sparse retriever initialized with embedding: {type(self.embedding).__name__}") + + except Exception as e: + logger.error( + f"Failed to initialize sparse retriever components: {e}") + import traceback + traceback.print_exc() + self._initialized = False + + @property + def component_name(self) -> str: + return "sparse_retriever" + + def _perform_search(self, query: str, k: int) -> List[RetrievalResult]: + """ + Perform sparse similarity search using direct Qdrant API to preserve external_id. + + Args: + query: Search query + k: Number of results to retrieve + + Returns: + List of RetrievalResult objects + """ + if not self._initialized: + self._initialize_components() + + if not self._initialized: + logger.warning( + "Sparse retriever not properly initialized, returning empty results") + return [] + + try: + # Get sparse query vector + if hasattr(self.embedding, 'embed_query'): + query_vector = self.embedding.embed_query(query) + else: + # For BM25/sparse embeddings that might not have embed_query + query_vector = self.embedding.embed_documents([query])[0] + + # Convert sparse dict to Qdrant sparse vector format for named sparse vectors + if isinstance(query_vector, dict): + from qdrant_client.models import NamedSparseVector + + search_result = self.qdrant_db.client.search( + collection_name=self.qdrant_db.collection_name, + query_vector=NamedSparseVector( + name=self.qdrant_db.sparse_vector_name, + vector={"indices": list(query_vector.keys()), "values": list( + query_vector.values())} + ), + limit=k, + with_payload=True + ) + else: + # Dense vector format (list) - fallback + from qdrant_client.models import NamedVector + + search_result = self.qdrant_db.client.search( + collection_name=self.qdrant_db.collection_name, + query_vector=NamedVector( + name=self.qdrant_db.sparse_vector_name, + vector=query_vector + ), + limit=k, + with_payload=True + ) + + # Convert to RetrievalResult objects + retrieval_results = [] + for result in search_result: + payload = result.payload or {} + + # Create document with preserved external_id + document = Document( + page_content=payload.get('page_content', ''), + metadata={ + **payload.get('metadata', {}), + # Ensure external_id is in metadata + 'external_id': payload.get('external_id'), + # Also store the Qdrant UUID for reference + 'qdrant_id': str(result.id) + } + ) + + retrieval_result = self._create_retrieval_result( + document=document, + score=result.score, + additional_metadata={ + 'search_type': 'sparse_similarity', + 'embedding_model': type(self.embedding).__name__, + # Also add to retrieval metadata + 'external_id': payload.get('external_id') + } + ) + retrieval_results.append(retrieval_result) + + # Normalize scores for consistency + retrieval_results = self._normalize_scores(retrieval_results) + + return retrieval_results + + except Exception as e: + logger.error(f"Error during sparse search: {e}") + import traceback + traceback.print_exc() + return [] + + +# Backward compatibility alias +SparseRetriever = QdrantSparseRetriever diff --git a/scripts/setup_sosum.sh b/scripts/setup_sosum.sh new file mode 100755 index 0000000..cb7520e --- /dev/null +++ b/scripts/setup_sosum.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# SOSum Dataset Quick Setup Script +# This script downloads and sets up the SOSum dataset for ingestion + +set -e + +echo "🔧 SOSum Dataset Setup" +echo "======================" + +# Configuration +SOSUM_DIR="datasets/sosum" +GITHUB_URL="https://github.com/BonanKou/SOSum-A-Dataset-of-Extractive-Summaries-of-Stack-Overflow-Posts-and-labeling-tools.git" + +# Create datasets directory +echo "Creating datasets directory..." +mkdir -p datasets/ + +# Download SOSum dataset +if [ ! -d "$SOSUM_DIR" ]; then + echo "Downloading SOSum dataset..." + git clone "$GITHUB_URL" "$SOSUM_DIR" +else + echo "SOSum dataset already exists at $SOSUM_DIR" +fi + +# Verify files exist +echo "Verifying dataset files..." +if [ -f "$SOSUM_DIR/data/question.csv" ] && [ -f "$SOSUM_DIR/data/answer.csv" ]; then + echo "Found question.csv and answer.csv" + + # Show file stats + echo "Dataset statistics:" + echo " Questions: $(tail -n +2 $SOSUM_DIR/data/question.csv | wc -l) rows" + echo " Answers: $(tail -n +2 $SOSUM_DIR/data/answer.csv | wc -l) rows" +else + echo "Required CSV files not found in $SOSUM_DIR/data/" + exit 1 +fi + +echo "" +echo " Ready for ingestion! Run:" +echo " # Test the adapter" +echo " python examples/ingest_sosum_example.py" +echo "" +echo " # Dry run ingestion" +echo " python bin/ingest.py ingest stackoverflow $SOSUM_DIR --dry-run --max-docs 10" +echo "" +echo " # Canary ingestion (safe test)" +echo " python bin/ingest.py ingest stackoverflow $SOSUM_DIR --canary --max-docs 100" +echo "" +echo " # Full ingestion" +echo " python bin/ingest.py ingest stackoverflow $SOSUM_DIR" +echo "" +echo " # Evaluate retrieval" +echo " python bin/ingest.py evaluate stackoverflow $SOSUM_DIR --output-dir results/sosum/" diff --git a/scripts/standalone_sosum_processor.py b/scripts/standalone_sosum_processor.py new file mode 100644 index 0000000..d1b1547 --- /dev/null +++ b/scripts/standalone_sosum_processor.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +""" +Standalone SOSum ingestion script. +This script works without the full pipeline dependencies. +""" +import csv +import json +import hashlib +from pathlib import Path +from typing import List, Dict, Any +from datetime import datetime + +def compute_hash(text: str) -> str: + """Compute SHA256 hash of text.""" + return hashlib.sha256(text.encode('utf-8')).hexdigest()[:12] + +def parse_list_field(field_value: str) -> List[str]: + """Parse a field that might be a string representation of a list.""" + if not field_value: + return [] + + if field_value.startswith('[') and field_value.endswith(']'): + try: + import ast + result = ast.literal_eval(field_value) + if isinstance(result, list): + return [str(item).strip() for item in result if str(item).strip()] + except: + # Fallback: strip brackets and split by comma + return [item.strip().strip("'\"") for item in field_value.strip('[]').split(',') if item.strip()] + + return [field_value.strip()] + +def process_sosum_dataset(dataset_path: str, output_dir: str = "output/sosum_processed"): + """Process SOSum dataset and export as JSON.""" + dataset_path = Path(dataset_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Find CSV files + question_file = dataset_path / "question.csv" + answer_file = dataset_path / "answer.csv" + + if not question_file.exists(): + data_dir = dataset_path / "data" + if data_dir.exists(): + question_file = data_dir / "question.csv" + answer_file = data_dir / "answer.csv" + + if not question_file.exists() or not answer_file.exists(): + raise FileNotFoundError(f"CSV files not found in {dataset_path}") + + print(f"📂 Processing SOSum dataset from: {dataset_path}") + print(f"📄 Questions: {question_file}") + print(f"📄 Answers: {answer_file}") + + documents = [] + evaluation_queries = [] + + # Process questions + print("\n🔍 Processing questions...") + question_count = 0 + with open(question_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + try: + question_id = row.get("Question Id", f"q{row_num}") + title = row.get("Question Title", "") + body_raw = row.get("Question Body", "") + tags_raw = row.get("Tags", "") + question_type = row.get("Question Type", "") + + # Parse body (might be list of sentences) + body_list = parse_list_field(body_raw) + body = " ".join(body_list) if body_list else "" + + # Parse tags + tags = parse_list_field(tags_raw) if tags_raw else [] + + # Create document content + content = f"Title: {title}\n\nQuestion: {body}" if title and body else (title or body) + + if content.strip(): + doc_id = f"q_{question_id}" + content_hash = compute_hash(content) + + document = { + "id": doc_id, + "content": content, + "content_hash": content_hash, + "metadata": { + "external_id": doc_id, + "source": "stackoverflow_sosum", + "post_type": "question", + "title": title, + "tags": tags, + "question_type": int(question_type) if question_type.isdigit() else None, + "char_count": len(content), + "processed_at": datetime.now().isoformat() + } + } + documents.append(document) + question_count += 1 + + # Create evaluation query from title + if title and len(title) > 10: + evaluation_queries.append({ + "query_id": f"eval_q_{question_id}", + "query": title, + "expected_docs": [doc_id], + "query_type": "question_title" + }) + + except Exception as e: + print(f"Error processing question {row_num}: {e}") + continue + + print(f"✅ Processed {question_count} questions") + + # Process answers + print("\n🔍 Processing answers...") + answer_count = 0 + with open(answer_file, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader): + try: + answer_id = row.get("Answer Id", f"a{row_num}") + body_raw = row.get("Answer Body", "") + summary_raw = row.get("Summary", "") + + # Parse body (might be list of sentences) + body_list = parse_list_field(body_raw) + body = " ".join(body_list) if body_list else "" + + # Parse summary (might be list of sentences) + summary_list = parse_list_field(summary_raw) if summary_raw else [] + summary = " ".join(summary_list) if summary_list else "" + + # Create document content + content = body + if summary: + content = f"Answer: {body}\n\nSummary: {summary}" + + if content.strip(): + doc_id = f"a_{answer_id}" + content_hash = compute_hash(content) + + document = { + "id": doc_id, + "content": content, + "content_hash": content_hash, + "metadata": { + "external_id": doc_id, + "source": "stackoverflow_sosum", + "post_type": "answer", + "has_summary": bool(summary), + "summary": summary if summary else None, + "char_count": len(content), + "processed_at": datetime.now().isoformat() + } + } + documents.append(document) + answer_count += 1 + + # Create evaluation query from summary + if summary and len(summary) > 20: + # Use first sentence of summary as query + summary_sentences = summary.split('.') + query = summary_sentences[0].strip() + "." if summary_sentences else summary[:50] + + if len(query) > 10: + evaluation_queries.append({ + "query_id": f"eval_a_{answer_id}", + "query": query, + "expected_docs": [doc_id], + "query_type": "answer_summary" + }) + + except Exception as e: + print(f"Error processing answer {row_num}: {e}") + continue + + print(f"✅ Processed {answer_count} answers") + + # Save results + documents_file = output_dir / "documents.json" + queries_file = output_dir / "evaluation_queries.json" + stats_file = output_dir / "stats.json" + + with open(documents_file, 'w', encoding='utf-8') as f: + json.dump(documents, f, indent=2, ensure_ascii=False) + + with open(queries_file, 'w', encoding='utf-8') as f: + json.dump(evaluation_queries, f, indent=2, ensure_ascii=False) + + # Generate statistics + stats = { + "dataset_name": "SOSum Stack Overflow", + "processed_at": datetime.now().isoformat(), + "total_documents": len(documents), + "questions": question_count, + "answers": answer_count, + "evaluation_queries": len(evaluation_queries), + "files": { + "documents": str(documents_file), + "evaluation_queries": str(queries_file), + }, + "content_stats": { + "avg_char_count": sum(d["metadata"]["char_count"] for d in documents) / len(documents) if documents else 0, + "min_char_count": min(d["metadata"]["char_count"] for d in documents) if documents else 0, + "max_char_count": max(d["metadata"]["char_count"] for d in documents) if documents else 0, + } + } + + with open(stats_file, 'w', encoding='utf-8') as f: + json.dump(stats, f, indent=2, ensure_ascii=False) + + print(f"\n📊 Processing complete!") + print(f" Total documents: {len(documents)}") + print(f" Questions: {question_count}") + print(f" Answers: {answer_count}") + print(f" Evaluation queries: {len(evaluation_queries)}") + print(f" Avg content length: {stats['content_stats']['avg_char_count']:.0f} chars") + print(f"\n💾 Files saved to: {output_dir}") + print(f" Documents: {documents_file}") + print(f" Queries: {queries_file}") + print(f" Stats: {stats_file}") + + return documents, evaluation_queries, stats + +def main(): + """Main function.""" + import sys + + if len(sys.argv) < 2: + print("Usage: python standalone_sosum_processor.py [output_dir]") + print("Example: python standalone_sosum_processor.py ../datasets/sosum") + return 1 + + dataset_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else "output/sosum_processed" + + try: + documents, queries, stats = process_sosum_dataset(dataset_path, output_dir) + + print(f"\n🎯 Ready for ingestion!") + print(f" You now have {len(documents)} processed documents") + print(f" Each document has a unique ID and content hash for deduplication") + print(f" {len(queries)} evaluation queries are available for testing") + + return 0 + + except Exception as e: + print(f"❌ Error: {e}") + return 1 + +if __name__ == "__main__": + exit(main()) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..afd2e22 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package for the RAG pipeline diff --git a/tests/pipeline/__init__.py b/tests/pipeline/__init__.py new file mode 100644 index 0000000..0bdbfff --- /dev/null +++ b/tests/pipeline/__init__.py @@ -0,0 +1 @@ +# Minimal pipeline tests package diff --git a/tests/pipeline/run_tests.py b/tests/pipeline/run_tests.py new file mode 100644 index 0000000..b41e670 --- /dev/null +++ b/tests/pipeline/run_tests.py @@ -0,0 +1,229 @@ +""" +Comprehensive Test Runner for Minimal Pipeline Tests + +Runs all minimal pipeline tests with proper error handling and reporting. +Designed to be CI-friendly and avoid local model dependencies. +""" + +import pytest +import sys +import os +from pathlib import Path +import importlib.util + +# Add project root to path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + + +def run_all_tests(): + """ + Run all minimal pipeline tests with comprehensive reporting. + """ + print("🧪 Comprehensive Pipeline Test Suite") + print("=" * 60) + print("🎯 Testing core pipeline functionality without local models") + + # Check API key availability + has_api_key = bool(os.getenv("GOOGLE_API_KEY")) + has_qdrant = check_qdrant_availability() + + print(f"🔑 Google API Key: {'✅ Available' if has_api_key else '❌ Not set'}") + print(f"🗄️ Qdrant Service: {'✅ Running' if has_qdrant else '❌ Not available'}") + print("=" * 60) + + # Test files to run + test_files = [ + ("test_minimal_pipeline.py", "Minimal Pipeline Tests", False, False), + ("test_components.py", "Component Integration Tests", False, False), + ("test_qdrant_connectivity.py", "Qdrant Connectivity Tests", True, False) + ] + + # Add end-to-end tests if API key is available + if has_api_key and has_qdrant: + test_files.append(("test_end_to_end.py", "End-to-End Pipeline Tests", True, True)) + print("� Full test suite - including end-to-end tests") + elif has_api_key: + print("⚠️ API key available but Qdrant not running - skipping end-to-end tests") + elif has_qdrant: + print("⚠️ Qdrant running but no API key - skipping end-to-end tests") + else: + print("⚠️ Running minimal test suite only") + + results = {} + + for test_file, description, requires_qdrant, requires_api in test_files: + print(f"\n🚀 Running {description}") + print("-" * 40) + + # Skip if requirements not met + if requires_qdrant and not has_qdrant: + print(f"⏭️ Skipping {test_file} - Qdrant not available") + results[test_file] = "SKIPPED_NO_QDRANT" + continue + + if requires_api and not has_api_key: + print(f"⏭️ Skipping {test_file} - API key not available") + results[test_file] = "SKIPPED_NO_API" + continue + + test_path = Path(__file__).parent / test_file + + if not test_path.exists(): + print(f"❌ Test file not found: {test_file}") + results[test_file] = "FILE_NOT_FOUND" + continue + + try: + # Build pytest command + pytest_args = [ + str(test_path), + "-v", + "--tb=short" + ] + + # Add markers for end-to-end tests + if requires_api: + pytest_args.extend(["-m", "requires_api"]) + + # Run pytest on specific file + exit_code = pytest.main(pytest_args) + + if exit_code == 0: + print(f"✅ {description} passed") + results[test_file] = "PASSED" + else: + print(f"❌ {description} failed") + results[test_file] = "FAILED" + + except Exception as e: + print(f"💥 Error running {test_file}: {e}") + results[test_file] = f"ERROR: {e}" + + # Summary + print("\n" + "=" * 60) + print("📊 COMPREHENSIVE TEST SUMMARY") + print("=" * 60) + + passed = sum(1 for result in results.values() if result == "PASSED") + skipped = sum(1 for result in results.values() if result.startswith("SKIPPED")) + failed = sum(1 for result in results.values() if result == "FAILED" or result.startswith("ERROR")) + total = len(results) + + for test_file, result in results.items(): + if result == "PASSED": + status_emoji = "✅" + elif result.startswith("SKIPPED"): + status_emoji = "⏭️ " + else: + status_emoji = "❌" + print(f"{status_emoji} {test_file}: {result}") + + print(f"\n🎯 Results: {passed} passed, {skipped} skipped, {failed} failed ({total} total)") + + if failed == 0: + if passed > 0: + print("🎉 ALL AVAILABLE TESTS PASSED!") + if has_api_key and has_qdrant: + print("✅ Complete pipeline validation successful") + else: + print("✅ Available components validated successfully") + if not has_api_key: + print("💡 Set GOOGLE_API_KEY for end-to-end tests") + if not has_qdrant: + print("💡 Start Qdrant for database tests") + else: + print("⚠️ No tests were executed") + return True + else: + print("⚠️ Some tests failed. Check the output above for details.") + return False + + +def check_qdrant_availability(): + """Check if Qdrant is available.""" + try: + import requests + response = requests.get("http://localhost:6333/collections", timeout=3) + return response.status_code == 200 + except: + return False + + +def check_dependencies(): + """ + Check if required dependencies are available. + """ + print("🔍 Checking dependencies...") + + required_modules = [ + "pytest", + "yaml", + "requests", + "langchain_core" + ] + + missing = [] + + for module in required_modules: + try: + spec = importlib.util.find_spec(module) + if spec is None: + missing.append(module) + except ImportError: + missing.append(module) + + if missing: + print(f"❌ Missing dependencies: {', '.join(missing)}") + print("💡 Install with: pip install " + " ".join(missing)) + return False + else: + print("✅ All dependencies available") + return True + + +def check_environment(): + """ + Check environment setup for tests. + """ + print("🌍 Checking environment...") + + # Check if we're in the right directory + if not Path("config.yml").exists(): + print("❌ config.yml not found. Make sure you're in the project root directory.") + return False + + # Check CI Google config exists + ci_config = Path("pipelines/configs/retrieval/ci_google_gemini.yml") + if not ci_config.exists(): + print(f"❌ CI Google config not found: {ci_config}") + return False + + print("✅ Environment ready") + return True + + +def main(): + """ + Main test runner entry point. + """ + print("🎬 Starting Minimal Pipeline Test Suite") + print(f"📁 Working directory: {os.getcwd()}") + print(f"🐍 Python: {sys.version}") + + # Pre-flight checks + if not check_dependencies(): + sys.exit(1) + + if not check_environment(): + sys.exit(1) + + # Run tests + success = run_all_tests() + + # Exit with appropriate code + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/pipeline/test_components.py b/tests/pipeline/test_components.py new file mode 100644 index 0000000..b636659 --- /dev/null +++ b/tests/pipeline/test_components.py @@ -0,0 +1,181 @@ +""" +Component Integration Tests + +Tests for individual pipeline components without requiring local models. +Tests component initialization and configuration validation. +""" + +import pytest +import sys +from pathlib import Path + +# Add project root to path for imports +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + + +class TestComponentIntegration: + """Test pipeline component integration.""" + + def test_retrieval_component_base_import(self): + """Test that base retrieval components can be imported.""" + from components.retrieval_pipeline import RetrievalComponent, BaseRetriever, RetrievalResult + + assert RetrievalComponent is not None + assert BaseRetriever is not None + assert RetrievalResult is not None + + def test_retrieval_result_dataclass(self): + """Test RetrievalResult dataclass functionality.""" + from components.retrieval_pipeline import RetrievalResult + from langchain_core.documents import Document + + doc = Document(page_content="test content", metadata={"test": "value"}) + result = RetrievalResult( + document=doc, + score=0.95, + retrieval_method="dense", + metadata={"extra": "info"} + ) + + assert result.document.page_content == "test content" + assert result.score == 0.95 + assert result.retrieval_method == "dense" + assert result.metadata["extra"] == "info" + + def test_pipeline_factory_import(self): + """Test that pipeline factory can be imported.""" + from components.retrieval_pipeline import RetrievalPipelineFactory + + assert RetrievalPipelineFactory is not None + + # Test that it has required methods + assert hasattr(RetrievalPipelineFactory, 'create_from_config') + + def test_filters_import(self): + """Test that filter components can be imported.""" + from components.filters import ScoreFilter, ResultLimiter + + assert ScoreFilter is not None + assert ResultLimiter is not None + + def test_score_filter_functionality(self): + """Test score filter without actual retrieval.""" + from components.filters import ScoreFilter + from components.retrieval_pipeline import RetrievalResult + from langchain_core.documents import Document + + # Create filter + filter_component = ScoreFilter(min_score=0.5) + + # Create test results + high_score_doc = Document(page_content="high relevance") + low_score_doc = Document(page_content="low relevance") + + results = [ + RetrievalResult(high_score_doc, 0.8, "dense"), + RetrievalResult(low_score_doc, 0.3, "dense") + ] + + # Filter results + filtered = filter_component.filter("test query", results) + + # Should only keep high score result + assert len(filtered) == 1 + assert filtered[0].score == 0.8 + + def test_limit_filter_functionality(self): + """Test result limiter without actual retrieval.""" + from components.filters import ResultLimiter + from components.retrieval_pipeline import RetrievalResult + from langchain_core.documents import Document + + # Create limiter + limiter = ResultLimiter(max_results=2) + + # Create test results + results = [] + for i in range(5): + doc = Document(page_content=f"content {i}") + results.append(RetrievalResult(doc, 0.9 - i*0.1, "dense")) + + # Limit results + limited = limiter.post_process("test query", results) + + # Should limit to 2 results + assert len(limited) == 2 + assert limited[0].score == 0.9 + assert limited[1].score == 0.8 + + def test_database_controller_import(self): + """Test that database controllers can be imported.""" + from database.qdrant_controller import QdrantVectorDB + + assert QdrantVectorDB is not None + + # Test that it has required methods + db = QdrantVectorDB() + assert hasattr(db, 'init_collection') + assert hasattr(db, 'get_client') + assert hasattr(db, 'insert_documents') + + def test_embedding_factory_import(self): + """Test that embedding factory can be imported.""" + from embedding.factory import get_embedder + + assert get_embedder is not None + + def test_agent_nodes_import(self): + """Test that agent nodes can be imported.""" + from agent.nodes.retriever import make_configurable_retriever + from agent.nodes.generator import make_generator + from agent.nodes.query_interpreter import make_query_interpreter + + assert make_configurable_retriever is not None + assert make_generator is not None + assert make_query_interpreter is not None + + +class TestConfigurationValidation: + """Test configuration validation without actual initialization.""" + + def test_config_loader_imports(self): + """Test config loading utilities.""" + from config.config_loader import load_config, get_retriever_config + + assert load_config is not None + assert get_retriever_config is not None + + def test_retrieval_config_structure(self): + """Test retrieval configuration structure.""" + import yaml + + config_path = "pipelines/configs/retrieval/ci_google_gemini.yml" + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + # Test required structure + assert "retrieval_pipeline" in config + pipeline_config = config["retrieval_pipeline"] + + assert "retriever" in pipeline_config + retriever_config = pipeline_config["retriever"] + + # Test retriever config + assert "type" in retriever_config + assert "embedding" in retriever_config + assert "qdrant" in retriever_config + + # Test embedding config + embedding_config = retriever_config["embedding"] + assert "dense" in embedding_config + + dense_config = embedding_config["dense"] + assert dense_config["provider"] == "google" + assert "model" in dense_config + assert "dimensions" in dense_config + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v"]) diff --git a/tests/pipeline/test_config.py b/tests/pipeline/test_config.py new file mode 100644 index 0000000..47f0fd4 --- /dev/null +++ b/tests/pipeline/test_config.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" +Configuration validation tests for pipeline. +Tests YAML structure and required fields without loading models. +""" + +import yaml +import sys +from pathlib import Path +from typing import Dict, Any, List + + +def test_yaml_validity() -> bool: + """Test that all YAML configuration files are valid.""" + print("🔍 Testing YAML file validity...") + + config_dirs = [ + 'pipelines/configs/retrieval', + 'pipelines/configs/datasets', + 'pipelines/configs/examples', + 'pipelines/configs/legacy' + ] + + total_files = 0 + errors = [] + + for config_dir in config_dirs: + config_path = Path(config_dir) + if not config_path.exists(): + continue + + for yaml_file in config_path.glob('*.yml'): + total_files += 1 + try: + with open(yaml_file, 'r') as f: + yaml.safe_load(f) + except Exception as e: + errors.append(f"{yaml_file}: {e}") + + if errors: + print(f"❌ YAML validation failed:") + for error in errors: + print(f" {error}") + return False + else: + print(f"✅ All {total_files} YAML files are valid") + return True + + +def test_retrieval_config_structure() -> bool: + """Test retrieval configuration structure.""" + print("🔍 Testing retrieval configuration structure...") + + retrieval_dir = Path('pipelines/configs/retrieval') + if not retrieval_dir.exists(): + print("❌ Retrieval configs directory not found") + return False + + required_fields = ['retrieval_pipeline'] + required_retriever_fields = ['type', 'top_k'] + + config_count = 0 + for config_file in retrieval_dir.glob('*.yml'): + config_count += 1 + + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Check top-level fields + for field in required_fields: + if field not in config: + print(f"❌ {config_file.name}: Missing {field}") + return False + + # Check retriever fields + retriever = config.get('retrieval_pipeline', {}).get('retriever', {}) + for field in required_retriever_fields: + if field not in retriever: + print(f"❌ {config_file.name}: Missing retriever.{field}") + return False + + # Check that retriever type is valid + valid_types = ['dense', 'sparse', 'hybrid'] + if retriever.get('type') not in valid_types: + print(f"❌ {config_file.name}: Invalid retriever type") + return False + + if config_count == 0: + print("❌ No retrieval configs found") + return False + + print(f"✅ All {config_count} retrieval configs have valid structure") + return True + + +def test_google_embeddings_in_configs() -> bool: + """Test that Google embeddings are properly configured.""" + print("🔍 Testing Google embeddings configuration...") + + retrieval_dir = Path('pipelines/configs/retrieval') + if not retrieval_dir.exists(): + print("❌ Retrieval configs directory not found") + return False + + google_configs = 0 + for config_file in retrieval_dir.glob('*.yml'): + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Look for Google embeddings configuration + retriever = config.get('retrieval_pipeline', {}).get('retriever', {}) + embedding = retriever.get('embedding', {}) + + if embedding: + dense_config = embedding.get('dense', {}) + if dense_config.get('provider') == 'google': + google_configs += 1 + + # Validate Google-specific fields + required_google_fields = ['model', 'dimensions', 'api_key_env'] + for field in required_google_fields: + if field not in dense_config: + print(f"❌ {config_file.name}: Missing Google embedding field: {field}") + return False + + # Check model format + model = dense_config.get('model', '') + if not model.startswith('models/'): + print(f"❌ {config_file.name}: Invalid Google model format: {model}") + return False + + print(f"✅ Found {google_configs} configs with valid Google embeddings") + return True + + +def test_main_config_structure() -> bool: + """Test main configuration file structure.""" + print("🔍 Testing main configuration structure...") + + main_config_path = Path('config.yml') + if not main_config_path.exists(): + print("❌ Main config file not found") + return False + + try: + with open(main_config_path, 'r') as f: + config = yaml.safe_load(f) + except Exception as e: + print(f"❌ Failed to load main config: {e}") + return False + + required_sections = ['llm', 'qdrant', 'agent_retrieval'] + for section in required_sections: + if section not in config: + print(f"❌ Missing main config section: {section}") + return False + + # Check agent_retrieval section + agent_retrieval = config.get('agent_retrieval', {}) + if 'config_path' not in agent_retrieval: + print("❌ Missing agent_retrieval.config_path") + return False + + # Check that the referenced config file exists + config_path = agent_retrieval.get('config_path') + if not Path(config_path).exists(): + print(f"❌ Referenced config file not found: {config_path}") + return False + + print("✅ Main configuration structure is valid") + return True + + +def run_config_validation_tests() -> bool: + """Run all configuration validation tests.""" + print("🔧 Configuration Validation Tests") + print("=" * 40) + + tests = [ + ("YAML Validity", test_yaml_validity), + ("Retrieval Config Structure", test_retrieval_config_structure), + ("Google Embeddings Config", test_google_embeddings_in_configs), + ("Main Config Structure", test_main_config_structure), + ] + + passed = 0 + failed_tests = [] + + for test_name, test_func in tests: + print(f"\n📋 {test_name}") + print("-" * 25) + + if test_func(): + passed += 1 + else: + failed_tests.append(test_name) + + total = len(tests) + print("\n" + "=" * 40) + print("📊 CONFIGURATION VALIDATION RESULTS") + print("=" * 40) + + if passed == total: + print("🎉 ALL CONFIGURATION TESTS PASSED!") + return True + else: + print(f"❌ {total - passed} of {total} tests failed") + print("Failed tests:") + for test in failed_tests: + print(f" • {test}") + return False + + +def main(): + """Main function.""" + success = run_config_validation_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pipeline/test_end_to_end.py b/tests/pipeline/test_end_to_end.py new file mode 100644 index 0000000..ffe5a9c --- /dev/null +++ b/tests/pipeline/test_end_to_end.py @@ -0,0 +1,365 @@ +""" +End-to-End Pipeline Tests + +These tests actually run the complete pipeline with real data. +Requires GOOGLE_API_KEY and Qdrant with test data. +Designed for GitHub Actions with API key secrets. +""" + +import pytest +import os +import sys +import requests +import time +from pathlib import Path +from typing import List, Dict, Any + +# Add project root to path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + + +# Sample documents for testing +SAMPLE_DOCUMENTS = [ + { + "content": "Python exception handling is done using try-except blocks. You can catch specific exceptions like ValueError or use a general except clause. Always handle exceptions gracefully to prevent crashes.", + "title": "Python Exception Handling Basics", + "tags": ["python", "error-handling", "exceptions"], + "external_id": "doc_001" + }, + { + "content": "Binary search is an efficient algorithm for finding an item from a sorted list of items. It works by repeatedly dividing the search interval in half. Time complexity is O(log n).", + "title": "Binary Search Algorithm Explained", + "tags": ["algorithms", "search", "binary-search", "complexity"], + "external_id": "doc_002" + }, + { + "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed. Common types include supervised, unsupervised, and reinforcement learning.", + "title": "Introduction to Machine Learning", + "tags": ["machine-learning", "ai", "supervised-learning"], + "external_id": "doc_003" + }, + { + "content": "REST APIs are architectural style for designing networked applications. They use HTTP methods like GET, POST, PUT, DELETE to perform operations. REST APIs are stateless and use JSON for data exchange.", + "title": "Understanding REST API Design", + "tags": ["api", "rest", "web-development", "http"], + "external_id": "doc_004" + }, + { + "content": "Docker containers provide a lightweight way to package applications with their dependencies. Containers are isolated, portable, and can run consistently across different environments.", + "title": "Docker Container Fundamentals", + "tags": ["docker", "containers", "devops", "deployment"], + "external_id": "doc_005" + } +] + + +@pytest.mark.requires_api +class TestEndToEndPipeline: + """Test complete pipeline execution with real API and data.""" + + @pytest.fixture(scope="class") + def qdrant_config(self): + """Qdrant configuration for tests.""" + host = os.getenv("QDRANT_HOST", "127.0.0.1") + port = int(os.getenv("QDRANT_PORT", "6333")) + return { + "host": host, + "port": port, + "url": f"http://{host}:{port}", + "collection_name": "test_e2e_collection" + } + + @pytest.fixture(scope="class") + def setup_test_collection(self, qdrant_config): + """Set up Qdrant collection with sample documents for testing.""" + if not os.getenv("GOOGLE_API_KEY"): + pytest.skip("GOOGLE_API_KEY not set - skipping end-to-end test") + + collection_name = qdrant_config["collection_name"] + base_url = qdrant_config["url"] + + # Clean up any existing collection + try: + requests.delete( + f"{base_url}/collections/{collection_name}", timeout=10) + except requests.RequestException: + pass + time.sleep(0.5) + + # Create collection with Google embedding dimensions + create_payload = { + "vectors": { + "size": 768, # Google embeddings size + "distance": "Cosine" + } + } + + response = requests.put( + f"{base_url}/collections/{collection_name}", + json=create_payload, + timeout=15 + ) + assert response.status_code in ( + 200, 201), f"Failed to create collection: {response.text}" + + # Insert sample documents with embeddings + self._insert_sample_documents(qdrant_config) + + yield qdrant_config + + # Cleanup after tests + try: + requests.delete( + f"{base_url}/collections/{collection_name}", timeout=10) + except requests.RequestException: + pass + + def _insert_sample_documents(self, qdrant_config): + """Insert sample documents into Qdrant collection.""" + from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings + from qdrant_client import QdrantClient + from qdrant_client.models import PointStruct + + # Initialize Google embeddings (force REST in CI if provided) + embeddings = GoogleGenerativeAIEmbeddings( + model="models/embedding-001", + google_api_key=os.getenv("GOOGLE_API_KEY"), + transport=os.getenv("GENAI_TRANSPORT", "rest") + ) + + # Initialize Qdrant client + client = QdrantClient( + host=qdrant_config["host"], + port=qdrant_config["port"] + ) + + # Generate embeddings and create points + points = [] + for i, doc in enumerate(SAMPLE_DOCUMENTS): + vector = embeddings.embed_query(doc["content"]) + point = PointStruct( + id=i + 1, + vector=vector, + payload={ + "content": doc["content"], + "labels": { + "title": doc["title"], + "tags": doc["tags"], + "external_id": doc["external_id"] + } + } + ) + points.append(point) + + # Insert points + client.upsert( + collection_name=qdrant_config["collection_name"], + points=points + ) + + # Wait for indexing + time.sleep(1.5) + print( + f"✅ Inserted {len(points)} sample documents into {qdrant_config['collection_name']}") + + def _update_config_for_test(self, collection_name: str): + """Create a test config file with the test collection name and Gemini embeddings.""" + import yaml + + # Load base CI config + with open("pipelines/configs/retrieval/ci_google_gemini.yml", 'r') as f: + config = yaml.safe_load(f) + + # Ensure the dict structure exists + config.setdefault("qdrant", {}) + config.setdefault("retrieval_pipeline", {}) + config["retrieval_pipeline"].setdefault("retriever", {}) + config["retrieval_pipeline"]["retriever"].setdefault("qdrant", {}) + + # Point to the test collection + config["qdrant"]["collection"] = collection_name + config["retrieval_pipeline"]["retriever"]["qdrant"]["collection_name"] = collection_name + config["retrieval_pipeline"]["retriever"]["qdrant"]["force_recreate"] = False + + # Force Gemini embeddings for query-time embedding + # Adjust path if your agent expects a different key + config["retrieval_pipeline"]["retriever"]["embedding"] = { + "dense": { + "provider": "google", + "model": "models/embedding-001", + "api_key_env": "GOOGLE_API_KEY", + "transport": "rest" + } + } + + # Save test config + test_config_path = "pipelines/configs/retrieval/ci_google_gemini_test.yml" + with open(test_config_path, 'w') as f: + yaml.dump(config, f, default_flow_style=False) + + print(f"✅ Created test config with collection: {collection_name}") + + @pytest.mark.integration + @pytest.mark.requires_api + def test_full_retrieval_pipeline(self, setup_test_collection): + """Test complete retrieval pipeline with real query and data.""" + from bin.agent_retriever import ConfigurableRetrieverAgent + + qdrant_config = setup_test_collection + + # Update CI config to use test collection and Gemini embeddings + self._update_config_for_test(qdrant_config["collection_name"]) + + # Create agent with test config + agent = ConfigurableRetrieverAgent( + "pipelines/configs/retrieval/ci_google_gemini_test.yml") + + # Test queries that should match our sample documents + test_cases = [ + { + "query": "How to handle errors in Python?", + "expected_keywords": ["exception", "python", "try", "except"], + "min_score": 0.3 + }, + { + "query": "What is binary search algorithm?", + "expected_keywords": ["binary", "search", "algorithm", "sorted"], + "min_score": 0.3 + }, + { + "query": "Machine learning introduction", + "expected_keywords": ["machine learning", "artificial intelligence", "supervised"], + "min_score": 0.3 + } + ] + + for test_case in test_cases: + results = agent.retrieve(test_case["query"], top_k=3) + + # Validate results structure + assert isinstance(results, list), "Results should be a list" + assert len( + results) > 0, f"Should return results for query: {test_case['query']}" + assert len( + results) <= 3, "Should not return more than requested top_k" + + # Validate result structure + result = results[0] + assert "score" in result + assert "content" in result + assert "retrieval_method" in result + assert "question_title" in result + assert "tags" in result + + # Validate score and content quality + assert isinstance(result["score"], (int, float)) + assert result["score"] >= test_case[ + "min_score"], f"Score too low: {result['score']}" + assert len(result["content"]) > 20 + + # Check for at least one expected keyword + content_lower = result["content"].lower() + title_lower = result["question_title"].lower() + assert any(k.lower() in content_lower or k.lower() + in title_lower for k in test_case["expected_keywords"]) + + print( + f"✅ Query: '{test_case['query']}' -> Score: {result['score']:.3f}, Title: '{result['question_title']}'") + + @pytest.mark.integration + @pytest.mark.requires_api + def test_retrieval_ranking(self, setup_test_collection): + """Test that retrieval returns results in correct ranking order.""" + from bin.agent_retriever import ConfigurableRetrieverAgent + + qdrant_config = setup_test_collection + self._update_config_for_test(qdrant_config["collection_name"]) + + agent = ConfigurableRetrieverAgent( + "pipelines/configs/retrieval/ci_google_gemini_test.yml") + + # Query that should strongly match the Python exceptions document + results = agent.retrieve( + "Python try except error handling ValueError", top_k=5) + + assert len(results) > 1, "Should return multiple results" + + # Scores should be in descending order + scores = [result["score"] for result in results] + assert scores == sorted( + scores, reverse=True), "Results should be sorted by score (descending)" + + # Top result should be highly relevant + top_result = results[0] + assert top_result["score"] > 0.5, f"Top result score too low: {top_result['score']}" + + # Check that top result is about Python exceptions + content_and_title = ( + top_result["content"] + " " + top_result["question_title"]).lower() + assert "python" in content_and_title, "Top result should mention Python" + assert ("exception" in content_and_title or "error" in content_and_title), "Top result should be about error handling" + + @pytest.mark.integration + @pytest.mark.requires_api + def test_config_switching_with_data(self, setup_test_collection): + """Test configuration switching works with real data.""" + from bin.agent_retriever import ConfigurableRetrieverAgent + + qdrant_config = setup_test_collection + self._update_config_for_test(qdrant_config["collection_name"]) + + agent = ConfigurableRetrieverAgent( + "pipelines/configs/retrieval/ci_google_gemini_test.yml") + + # Test with initial config + query = "machine learning basics" + results1 = agent.retrieve(query, top_k=2) + + # Get config info + config_info1 = agent.get_config_info() + + # Verify initial results + assert len(results1) > 0, "Should return results with initial config" + assert config_info1["retriever_type"] == "dense", "Should be using dense retriever" + assert config_info1["collection"] == qdrant_config["collection_name"], "Should use test collection" + + print( + f"✅ Config switching test completed - Retrieved {len(results1)} results") + + @pytest.mark.integration + def test_pipeline_error_handling_with_real_setup(self, setup_test_collection): + """Test pipeline error handling with real Qdrant setup.""" + from bin.agent_retriever import ConfigurableRetrieverAgent + + # Test with non-existent config + with pytest.raises(FileNotFoundError): + ConfigurableRetrieverAgent("nonexistent_config.yml") + + # Test with valid agent + qdrant_config = setup_test_collection + self._update_config_for_test(qdrant_config["collection_name"]) + + agent = ConfigurableRetrieverAgent( + "pipelines/configs/retrieval/ci_google_gemini_test.yml") + + # Test empty query (should handle gracefully) + try: + results = agent.retrieve("", top_k=1) + assert isinstance( + results, list), "Should return list even for empty query" + except Exception as e: + assert len(str(e)) > 0, "Error message should be informative" + + # Test very long query (should not crash) + long_query = "test " * 100 # 500 character query + try: + results = agent.retrieve(long_query, top_k=1) + assert isinstance(results, list), "Should handle long queries" + except Exception as e: + print(f"Long query failed (acceptable): {e}") + + +if __name__ == "__main__": + # Run with specific markers + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/tests/pipeline/test_minimal.py b/tests/pipeline/test_minimal.py new file mode 100644 index 0000000..3905eb8 --- /dev/null +++ b/tests/pipeline/test_minimal.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +""" +Minimal pipeline tests using only Google Gemini embeddings. +No local models like sentence transformers. +""" + +import os +import sys +import yaml +import tempfile +from pathlib import Path +from typing import Dict, Any + +# Add project root to path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +# Set environment variables +os.environ.setdefault('OPENAI_API_KEY', 'test-key-for-ci') +os.environ.setdefault('GOOGLE_API_KEY', 'test-key-for-ci') + + +def create_google_only_config() -> Dict[str, Any]: + """Create a configuration that only uses Google Gemini embeddings.""" + return { + "description": "Google Gemini only configuration for CI", + "retrieval_pipeline": { + "retriever": { + "type": "dense", + "top_k": 5, + "score_threshold": 0.1, + "embedding": { + "strategy": "dense", + "dense": { + "provider": "google", + "model": "models/embedding-001", + "dimensions": 768, + "api_key_env": "GOOGLE_API_KEY", + "batch_size": 16, + "vector_name": "dense" + } + }, + "qdrant": { + "collection_name": "test_ci_google_only", # Unique collection name + "dense_vector_name": "dense", + "host": "localhost", + "port": 6333, + "force_recreate": True # Force recreation to avoid conflicts + }, + "performance": { + "lazy_initialization": True, + "enable_caching": False + } + }, + "stages": [ + { + "type": "retriever", + "name": "google_dense_retriever" + } + ] + } + } + + +def test_config_loading() -> bool: + """Test that main configuration loads without errors.""" + print("🔍 Testing configuration loading...") + + try: + from config.config_loader import load_config + config = load_config() + + required_keys = ['llm', 'qdrant', 'agent_retrieval'] + for key in required_keys: + if key not in config: + print(f"❌ Missing config key: {key}") + return False + + print("✅ Configuration loads successfully") + return True + except Exception as e: + print(f"❌ Configuration loading failed: {e}") + return False + + +def test_agent_schema() -> bool: + """Test agent schema works correctly.""" + print("🔍 Testing agent schema...") + + try: + from agent.schema import AgentState + + state = AgentState( + question="Test question", + reference_date="2024-01-01", + chat_history=[] + ) + + # Check required fields + if 'question' not in state: + print("❌ Missing question field") + return False + + if 'reference_date' not in state: + print("❌ Missing reference_date field") + return False + + # Ensure SQL field was removed + if 'sql' in state: + print("❌ SQL field should not exist (was removed)") + return False + + print("✅ Agent schema works correctly") + return True + except Exception as e: + print(f"❌ Agent schema test failed: {e}") + return False + + +def test_google_embeddings_config() -> bool: + """Test Google embeddings configuration without actually calling the API.""" + print("🔍 Testing Google embeddings configuration...") + + try: + from embedding.factory import get_embedder + + # Test config structure + embedding_config = { + "provider": "google", + "model": "models/embedding-001", + "dimensions": 768, + "api_key_env": "GOOGLE_API_KEY" + } + + # This should not fail even without real API key + # We're just testing the configuration structure + try: + embedder = get_embedder(embedding_config) + # If we get here, the config structure is correct + print("✅ Google embeddings configuration is valid") + return True + except Exception as e: + # Expected if no real API key - check if it's an auth error + error_str = str(e).lower() + if any(keyword in error_str for keyword in ['api', 'key', 'auth', 'credential']): + print("✅ Google embeddings configuration is valid (auth expected in CI)") + return True + else: + print(f"❌ Unexpected embeddings error: {e}") + return False + except Exception as e: + print(f"❌ Google embeddings test failed: {e}") + return False + + +def test_agent_retriever_with_google() -> bool: + """Test agent retriever with Google-only configuration.""" + print("🔍 Testing agent retriever with Google embeddings...") + + try: + from bin.agent_retriever import ConfigurableRetrieverAgent + + # Create temporary config file with Google-only setup + config = create_google_only_config() + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + yaml.dump(config, f) + temp_config_path = f.name + + try: + # Initialize agent + agent = ConfigurableRetrieverAgent(temp_config_path) + + # Get config info + config_info = agent.get_config_info() + + if not isinstance(config_info, dict): + print("❌ Config info is not a dict") + return False + + if 'retriever_type' not in config_info: + print("❌ Missing retriever_type in config info") + return False + + print("✅ Agent retriever with Google embeddings works") + return True + + finally: + os.unlink(temp_config_path) + except Exception as e: + print(f"❌ Agent retriever test failed: {e}") + return False + + +def test_pipeline_factory_google_only() -> bool: + """Test pipeline factory with Google-only configuration.""" + print("🔍 Testing pipeline factory with Google embeddings...") + + try: + from components.retrieval_pipeline import RetrievalPipelineFactory + + config = create_google_only_config() + + # Try to create pipeline - this may fail due to missing API key + # but should not fail due to local model loading + try: + pipeline = RetrievalPipelineFactory.create_from_config(config) + print("✅ Pipeline factory works with Google embeddings") + return True + except Exception as e: + error_str = str(e).lower() + # These are acceptable errors in CI + acceptable_errors = [ + 'api', 'key', 'auth', 'credential', 'quota', 'permission', + 'qdrant', 'collection', 'database', 'connection' + ] + + if any(keyword in error_str for keyword in acceptable_errors): + print("✅ Pipeline factory handles missing services correctly") + return True + else: + print(f"❌ Unexpected pipeline factory error: {e}") + return False + except Exception as e: + print(f"❌ Pipeline factory test failed: {e}") + return False + + +def test_config_switching() -> bool: + """Test configuration switching mechanism.""" + print("🔍 Testing configuration switching...") + + try: + from bin.switch_agent_config import list_available_configs + + configs = list_available_configs() + + if len(configs) == 0: + print("❌ No configurations found") + return False + + # Verify each config file exists and has valid structure + for config_name, description, path in configs: + if not Path(path).exists(): + print(f"❌ Config file missing: {path}") + return False + + with open(path, 'r') as f: + config = yaml.safe_load(f) + + if 'retrieval_pipeline' not in config: + print(f"❌ Invalid config structure: {config_name}") + return False + + print(f"✅ Configuration switching works with {len(configs)} configs") + return True + except Exception as e: + print(f"❌ Configuration switching test failed: {e}") + return False + + +def run_minimal_pipeline_tests() -> bool: + """Run minimal pipeline tests without local models.""" + print("🧪 Minimal Pipeline Tests (Google Gemini Only)") + print("=" * 50) + + tests = [ + ("Configuration Loading", test_config_loading), + ("Agent Schema", test_agent_schema), + ("Google Embeddings Config", test_google_embeddings_config), + ("Agent Retriever (Google)", test_agent_retriever_with_google), + ("Pipeline Factory (Google)", test_pipeline_factory_google_only), + ("Configuration Switching", test_config_switching), + ] + + passed = 0 + failed_tests = [] + + for test_name, test_func in tests: + print(f"\n📋 {test_name}") + print("-" * 30) + + if test_func(): + passed += 1 + else: + failed_tests.append(test_name) + + total = len(tests) + print("\n" + "=" * 50) + print("📊 MINIMAL PIPELINE TEST RESULTS") + print("=" * 50) + + if passed == total: + print("🎉 ALL TESTS PASSED!") + print("✅ Pipeline works with Google Gemini embeddings") + return True + else: + print(f"❌ {total - passed} of {total} tests failed") + print("Failed tests:") + for test in failed_tests: + print(f" • {test}") + return False + + +def main(): + """Main function.""" + success = run_minimal_pipeline_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pipeline/test_minimal_pipeline.py b/tests/pipeline/test_minimal_pipeline.py new file mode 100644 index 0000000..343a135 --- /dev/null +++ b/tests/pipeline/test_minimal_pipeline.py @@ -0,0 +1,189 @@ +""" +Minimal Pipeline Tests + +This module contains minimal tests for the RAG pipeline that: +1. Don't require local model downloads (no sentence transformers) +2. Only use Google Gemini embeddings for CI compatibility +3. Test core pipeline functionality and configuration +4. Validate Qdrant connectivity (optional) + +Tests are designed to be runnable in CI/CD environments. +""" + +import pytest +import os +import yaml +import sys +from pathlib import Path + +# Add project root to path for imports +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +class TestMinimalPipeline: + """Test core pipeline functionality without local models.""" + + def test_config_loading(self): + """Test that main config.yml loads correctly.""" + from config.config_loader import load_config + + config = load_config("config.yml") + assert config is not None + assert isinstance(config, dict) + + # Check required sections exist + assert "embedding" in config + assert "retrievers" in config + assert "llm" in config + + def test_agent_schema_import(self): + """Test that agent schema imports correctly.""" + from agent.schema import AgentState + + # Test that schema doesn't require SQL fields + state_annotations = AgentState.__annotations__ + assert "question" in state_annotations + assert "answer" in state_annotations + assert "chat_history" in state_annotations + + def test_google_embedding_config(self): + """Test Google embeddings configuration without initialization.""" + from config.config_loader import load_config + + config = load_config("config.yml") + + # Check Google embeddings are configured in main config + embedding_config = config.get("embedding", {}) + dense_config = embedding_config.get("dense", {}) + + assert dense_config.get("provider") == "google" + assert dense_config.get("model") == "models/embedding-001" + assert dense_config.get("dimensions") == 768 + assert "api_key_env" in dense_config + + def test_ci_google_config_loads(self): + """Test that CI Google Gemini config loads correctly.""" + ci_config_path = "pipelines/configs/retrieval/ci_google_gemini.yml" + + with open(ci_config_path, 'r') as f: + config = yaml.safe_load(f) + + assert config is not None + retriever_config = config["retrieval_pipeline"]["retriever"] + + # Verify Google embeddings + embedding_config = retriever_config["embedding"]["dense"] + assert embedding_config["provider"] == "google" + assert embedding_config["model"] == "models/embedding-001" + assert embedding_config["dimensions"] == 768 + + def test_agent_retriever_config_load(self): + """Test agent retriever can load CI Google config.""" + from bin.agent_retriever import ConfigurableRetrieverAgent + + # Use CI Google config + ci_config_path = "pipelines/configs/retrieval/ci_google_gemini.yml" + + # Should not raise exception when loading config + agent = ConfigurableRetrieverAgent(ci_config_path, cache_pipeline=False) + config_info = agent.get_config_info() + + assert config_info["retriever_type"] == "dense" + assert "num_stages" in config_info + + def test_pipeline_factory_google_config(self): + """Test pipeline factory with Google config (no retrieval).""" + from components.retrieval_pipeline import RetrievalPipelineFactory + + # Load CI Google config + ci_config_path = "pipelines/configs/retrieval/ci_google_gemini.yml" + with open(ci_config_path, 'r') as f: + config = yaml.safe_load(f) + + # Should be able to parse config without errors + # Don't actually create pipeline to avoid requiring API keys + pipeline_config = config["retrieval_pipeline"] + assert pipeline_config["retriever"]["type"] == "dense" + assert pipeline_config["retriever"]["embedding"]["dense"]["provider"] == "google" + + def test_config_switching(self): + """Test that configuration switching mechanism works.""" + from config.config_loader import load_config, get_retriever_config + + main_config = load_config("config.yml") + + # Test extracting different retriever configs + dense_config = get_retriever_config(main_config, "dense") + hybrid_config = get_retriever_config(main_config, "hybrid") + + assert dense_config["type"] == "dense" + assert hybrid_config["type"] == "hybrid" + + # Both should have Google embeddings + assert dense_config["embedding"]["provider"] == "google" + assert hybrid_config["embedding"]["dense"]["provider"] == "google" + + +class TestConfigValidation: + """Test configuration file validation.""" + + def test_yaml_files_valid(self): + """Test that all YAML files in the project are valid.""" + yaml_files = [] + + # Find all YAML files + for root, dirs, files in os.walk("."): + for file in files: + if file.endswith(('.yml', '.yaml')): + yaml_files.append(os.path.join(root, file)) + + assert len(yaml_files) > 0, "No YAML files found" + + for yaml_file in yaml_files: + try: + with open(yaml_file, 'r') as f: + yaml.safe_load(f) + except yaml.YAMLError as e: + pytest.fail(f"Invalid YAML in {yaml_file}: {e}") + + def test_google_embeddings_config_complete(self): + """Test that Google embeddings configurations are complete.""" + config_files = [ + "config.yml", + "pipelines/configs/retrieval/ci_google_gemini.yml" + ] + + for config_file in config_files: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + # Find Google embedding configs + google_configs = self._find_google_configs(config) + assert len(google_configs) > 0, f"No Google configs found in {config_file}" + + for google_config in google_configs: + assert google_config.get("provider") == "google" + assert "model" in google_config + assert "dimensions" in google_config + assert google_config.get("dimensions") == 768 + + def _find_google_configs(self, config, path=[]): + """Recursively find Google embedding configurations.""" + google_configs = [] + + if isinstance(config, dict): + for key, value in config.items(): + if isinstance(value, dict) and value.get("provider") == "google": + google_configs.append(value) + else: + google_configs.extend(self._find_google_configs(value, path + [key])) + elif isinstance(config, list): + for i, item in enumerate(config): + google_configs.extend(self._find_google_configs(item, path + [i])) + + return google_configs + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v"]) diff --git a/tests/pipeline/test_qdrant.py b/tests/pipeline/test_qdrant.py new file mode 100644 index 0000000..733beaa --- /dev/null +++ b/tests/pipeline/test_qdrant.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +Simple Qdrant connectivity test for CI environments. +No embedding models, just basic database operations. +""" + +import requests +import json +import sys +import time +from typing import Dict, Any + + +def wait_for_qdrant(max_attempts: int = 30) -> bool: + """Wait for Qdrant to be ready.""" + print(f"⏳ Waiting for Qdrant to be ready...") + + for attempt in range(1, max_attempts + 1): + try: + # Try collections endpoint instead of health (more reliable) + response = requests.get( + 'http://localhost:6333/collections', timeout=2) + if response.status_code == 200: + print(f"✅ Qdrant is ready! (attempt {attempt})") + return True + except: + pass + + if attempt < max_attempts: + print(f"⏳ Attempt {attempt}/{max_attempts}...") + time.sleep(2) + + print("❌ Qdrant failed to start") + return False + + +def test_qdrant_connectivity() -> bool: + """Test basic Qdrant connectivity using multiple endpoints.""" + print("🔍 Testing Qdrant connectivity...") + + # Try multiple endpoints to verify connectivity + endpoints_to_test = [ + ("/collections", "Collections endpoint"), + ("/health", "Health endpoint (optional)"), + ] + + working_endpoints = 0 + + for endpoint, description in endpoints_to_test: + try: + response = requests.get( + f'http://localhost:6333{endpoint}', timeout=5) + if response.status_code == 200: + print(f"✅ {description} working") + working_endpoints += 1 + else: + print( + f"⚠️ {description} returned {response.status_code} (may be normal)") + except Exception as e: + print(f"⚠️ {description} error: {e}") + + # We only need at least one endpoint to work (collections is essential) + if working_endpoints > 0: + print("✅ Qdrant connectivity confirmed") + return True + else: + print("❌ No Qdrant endpoints accessible") + return False + + +def test_qdrant_collections_endpoint() -> bool: + """Test Qdrant collections endpoint.""" + print("🔍 Testing Qdrant collections endpoint...") + + try: + response = requests.get('http://localhost:6333/collections', timeout=5) + if response.status_code == 200: + print("✅ Collections endpoint accessible") + return True + else: + print(f"❌ Collections endpoint failed: {response.status_code}") + return False + except Exception as e: + print(f"❌ Collections endpoint error: {e}") + return False + + +def test_create_delete_collection() -> bool: + """Test creating and deleting a simple test collection.""" + print("🔍 Testing collection creation/deletion...") + + collection_name = "test_ci_minimal" + collection_config = { + 'vectors': { + 'size': 384, # Small vector size for testing + 'distance': 'Cosine' + } + } + + try: + # Clean up if collection exists + requests.delete( + f'http://localhost:6333/collections/{collection_name}', timeout=5) + + # Create collection + create_response = requests.put( + f'http://localhost:6333/collections/{collection_name}', + json=collection_config, + timeout=10 + ) + + if create_response.status_code not in [200, 201]: + print( + f"❌ Failed to create collection: {create_response.status_code}") + return False + + # Verify collection exists + info_response = requests.get( + f'http://localhost:6333/collections/{collection_name}', + timeout=5 + ) + + if info_response.status_code != 200: + print( + f"❌ Failed to get collection info: {info_response.status_code}") + return False + + # Delete collection + delete_response = requests.delete( + f'http://localhost:6333/collections/{collection_name}', + timeout=5 + ) + + if delete_response.status_code not in [200, 404]: + print( + f"❌ Failed to delete collection: {delete_response.status_code}") + return False + + print("✅ Collection operations successful") + return True + + except Exception as e: + print(f"❌ Collection operations error: {e}") + return False + + +def run_qdrant_tests() -> bool: + """Run basic Qdrant connectivity tests.""" + print("🗄️ Qdrant Connectivity Tests") + print("=" * 35) + + tests = [ + ("Basic Connectivity", test_qdrant_connectivity), + ("Collections Endpoint", test_qdrant_collections_endpoint), + ("Collection Operations", test_create_delete_collection), + ] + + passed = 0 + failed_tests = [] + + for test_name, test_func in tests: + print(f"\n📋 {test_name}") + print("-" * 20) + + if test_func(): + passed += 1 + else: + failed_tests.append(test_name) + + total = len(tests) + print("\n" + "=" * 35) + print("📊 QDRANT TEST RESULTS") + print("=" * 35) + + if passed == total: + print("🎉 ALL QDRANT TESTS PASSED!") + return True + else: + print(f"❌ {total - passed} of {total} tests failed") + print("Failed tests:") + for test in failed_tests: + print(f" • {test}") + return False + + +def main(): + """Main function.""" + import argparse + + parser = argparse.ArgumentParser(description="Test Qdrant connectivity") + parser.add_argument("--wait", action="store_true", + help="Wait for Qdrant first") + parser.add_argument("--max-wait", type=int, default=30, + help="Max wait attempts") + + args = parser.parse_args() + + if args.wait: + if not wait_for_qdrant(args.max_wait): + return 1 + + success = run_qdrant_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pipeline/test_qdrant_connectivity.py b/tests/pipeline/test_qdrant_connectivity.py new file mode 100644 index 0000000..08a9fda --- /dev/null +++ b/tests/pipeline/test_qdrant_connectivity.py @@ -0,0 +1,122 @@ +""" +Qdrant Database Connectivity Tests + +Tests basic Qdrant connectivity and operations without requiring embeddings. +These tests are optional and can be skipped if Qdrant is not available. +""" + +import pytest +import requests +import os +from typing import Dict, Any + + +class TestQdrantConnectivity: + """Test Qdrant database connectivity (optional).""" + + @pytest.fixture + def qdrant_config(self) -> Dict[str, Any]: + """Get Qdrant configuration.""" + return { + "host": os.getenv("QDRANT_HOST", "localhost"), + "port": int(os.getenv("QDRANT_PORT", "6333")), + "url": f"http://{os.getenv('QDRANT_HOST', 'localhost')}:{os.getenv('QDRANT_PORT', '6333')}" + } + + @pytest.mark.integration + def test_qdrant_health_endpoint(self, qdrant_config): + """Test Qdrant health endpoint is accessible.""" + try: + health_url = f"{qdrant_config['url']}/health" + response = requests.get(health_url, timeout=5) + + # Qdrant returns 200 with health status, but sometimes health endpoint is different + # If 404, try the collections endpoint instead as a health check + if response.status_code == 404: + collections_url = f"{qdrant_config['url']}/collections" + response = requests.get(collections_url, timeout=5) + assert response.status_code == 200 + else: + assert response.status_code == 200 + + except requests.ConnectionError: + pytest.skip("Qdrant not available - skipping connectivity tests") + except requests.Timeout: + pytest.skip("Qdrant connection timeout - skipping connectivity tests") + + @pytest.mark.integration + def test_qdrant_collections_endpoint(self, qdrant_config): + """Test Qdrant collections endpoint is accessible.""" + try: + collections_url = f"{qdrant_config['url']}/collections" + response = requests.get(collections_url, timeout=5) + + # Should return 200 with collections list (might be empty) + assert response.status_code == 200 + + data = response.json() + assert "result" in data + assert "collections" in data["result"] + + except requests.ConnectionError: + pytest.skip("Qdrant not available - skipping connectivity tests") + except requests.Timeout: + pytest.skip("Qdrant connection timeout - skipping connectivity tests") + + @pytest.mark.integration + def test_qdrant_collection_creation_deletion(self, qdrant_config): + """Test basic collection creation and deletion (no embeddings).""" + try: + base_url = qdrant_config['url'] + test_collection = "test_minimal_collection" + + # Create collection with minimal config + create_url = f"{base_url}/collections/{test_collection}" + create_payload = { + "vectors": { + "size": 768, # Google embeddings size + "distance": "Cosine" + } + } + + # Clean up first if exists + requests.delete(create_url, timeout=5) + + # Create collection + response = requests.put(create_url, json=create_payload, timeout=5) + assert response.status_code in [200, 201] + + # Verify collection exists + list_response = requests.get(f"{base_url}/collections", timeout=5) + assert list_response.status_code == 200 + + collections = list_response.json()["result"]["collections"] + collection_names = [c["name"] for c in collections] + assert test_collection in collection_names + + # Clean up + delete_response = requests.delete(create_url, timeout=5) + assert delete_response.status_code == 200 + + except requests.ConnectionError: + pytest.skip("Qdrant not available - skipping connectivity tests") + except requests.Timeout: + pytest.skip("Qdrant connection timeout - skipping connectivity tests") + + @pytest.mark.integration + def test_qdrant_client_import(self): + """Test that Qdrant client can be imported.""" + try: + from qdrant_client import QdrantClient + + # Should be able to create client instance (doesn't connect yet) + client = QdrantClient(host="localhost", port=6333) + assert client is not None + + except ImportError: + pytest.skip("qdrant-client not installed - skipping client tests") + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v"]) diff --git a/tests/pipeline/test_runner.py b/tests/pipeline/test_runner.py new file mode 100644 index 0000000..598843f --- /dev/null +++ b/tests/pipeline/test_runner.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Runner for all minimal pipeline tests. +Combines configuration, pipeline, and database tests. +""" + +import sys +import os +from pathlib import Path + +# Add project root to path +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +# Import test functions directly +def import_test_functions(): + """Import test functions dynamically.""" + import importlib.util + + # Import config tests + config_spec = importlib.util.spec_from_file_location( + "test_config", Path(__file__).parent / "test_config.py" + ) + config_module = importlib.util.module_from_spec(config_spec) + config_spec.loader.exec_module(config_module) + + # Import minimal tests + minimal_spec = importlib.util.spec_from_file_location( + "test_minimal", Path(__file__).parent / "test_minimal.py" + ) + minimal_module = importlib.util.module_from_spec(minimal_spec) + minimal_spec.loader.exec_module(minimal_module) + + # Import qdrant tests + qdrant_spec = importlib.util.spec_from_file_location( + "test_qdrant", Path(__file__).parent / "test_qdrant.py" + ) + qdrant_module = importlib.util.module_from_spec(qdrant_spec) + qdrant_spec.loader.exec_module(qdrant_module) + + return ( + config_module.run_config_validation_tests, + minimal_module.run_minimal_pipeline_tests, + qdrant_module.run_qdrant_tests, + qdrant_module.wait_for_qdrant + ) + + +def run_all_pipeline_tests() -> bool: + """Run all minimal pipeline tests.""" + print("🧪 Complete Minimal Pipeline Test Suite") + print("=" * 50) + print("🎯 Using only Google Gemini embeddings (no local models)") + print("=" * 50) + + # Import test functions + (run_config_validation_tests, + run_minimal_pipeline_tests, + run_qdrant_tests, + wait_for_qdrant) = import_test_functions() + + # Test categories + test_suites = [ + ("Configuration Validation", run_config_validation_tests), + ("Minimal Pipeline Tests", run_minimal_pipeline_tests), + ] + + # Add Qdrant tests if enabled + if os.getenv('CI_RUN_DB_TESTS'): + print("🗄️ Database tests enabled") + if wait_for_qdrant(60): + test_suites.append(("Qdrant Connectivity", run_qdrant_tests)) + else: + print("⚠️ Qdrant not available, skipping database tests") + else: + print("⚠️ Database tests disabled (CI_RUN_DB_TESTS not set)") + + # Run test suites + passed_suites = 0 + failed_suites = [] + + for suite_name, test_func in test_suites: + print(f"\n🚀 Running {suite_name}") + print("=" * 50) + + if test_func(): + passed_suites += 1 + print(f"✅ {suite_name} PASSED") + else: + failed_suites.append(suite_name) + print(f"❌ {suite_name} FAILED") + + # Final summary + total_suites = len(test_suites) + print("\n" + "=" * 50) + print("📊 FINAL PIPELINE TEST RESULTS") + print("=" * 50) + + if passed_suites == total_suites: + print("🎉 ALL PIPELINE TESTS PASSED!") + print("✅ Pipeline is ready for production") + print("✅ Google Gemini embeddings properly configured") + print("✅ No local models required") + return True + else: + print(f"❌ {total_suites - passed_suites} of {total_suites} test suites failed") + print("Failed test suites:") + for suite in failed_suites: + print(f" • {suite}") + print("\n🔧 Please fix the issues above") + return False + + +def main(): + """Main function.""" + import argparse + + parser = argparse.ArgumentParser(description="Run minimal pipeline tests") + parser.add_argument("--with-db", action="store_true", + help="Enable database tests") + parser.add_argument("--wait", action="store_true", + help="Wait for Qdrant if database tests enabled") + + args = parser.parse_args() + + if args.with_db: + os.environ['CI_RUN_DB_TESTS'] = '1' + + if args.wait and os.getenv('CI_RUN_DB_TESTS'): + print("⏳ Waiting for Qdrant...") + _, _, _, wait_for_qdrant = import_test_functions() + wait_for_qdrant(60) + + success = run_all_pipeline_tests() + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/requirements-minimal.txt b/tests/requirements-minimal.txt new file mode 100644 index 0000000..d6df22c --- /dev/null +++ b/tests/requirements-minimal.txt @@ -0,0 +1,100 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +annotated-types==0.7.0 +anyio==4.10.0 +attrs==25.3.0 +boto3==1.40.21 +botocore==1.40.21 +cachetools==5.5.2 +certifi==2025.8.3 +charset-normalizer==3.4.3 +coloredlogs==15.0.1 +dataclasses-json==0.6.7 +distro==1.9.0 +dotenv==0.9.9 +fastembed==0.7.3 +filelock==3.19.1 +filetype==1.2.0 +flatbuffers==25.2.10 +frozenlist==1.7.0 +fsspec==2025.7.0 +google-ai-generativelanguage==0.6.18 +google-api-core==2.25.1 +google-auth==2.40.3 +googleapis-common-protos==1.70.0 +greenlet==3.2.4 +grpcio==1.74.0 +grpcio-status==1.74.0 +h11==0.16.0 +h2==4.3.0 +hf-xet==1.1.9 +hpack==4.1.0 +httpcore==1.0.9 +httpx==0.28.1 +httpx-sse==0.4.1 +huggingface-hub==0.34.4 +humanfriendly==10.0 +hyperframe==6.1.0 +idna==3.10 +iniconfig==2.1.0 +jiter==0.10.0 +jmespath==1.0.1 +jsonpatch==1.33 +jsonpointer==3.0.0 +langchain==0.3.27 +langchain-community==0.3.29 +langchain-core==0.3.75 +langchain-google-genai==2.1.10 +langchain-openai==0.3.32 +langchain-qdrant==0.2.0 +langchain-text-splitters==0.3.10 +langsmith==0.4.21 +loguru==0.7.3 +marshmallow==3.26.1 +mmh3==5.2.0 +mpmath==1.3.0 +multidict==6.6.4 +mypy_extensions==1.1.0 +numpy==2.3.2 +onnxruntime==1.22.1 +openai==1.102.0 +orjson==3.11.3 +packaging==25.0 +pillow==11.3.0 +pluggy==1.6.0 +portalocker==3.2.0 +propcache==0.3.2 +proto-plus==1.26.1 +protobuf==6.32.0 +py_rust_stemmers==0.1.5 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pydantic==2.11.7 +pydantic-settings==2.10.1 +pydantic_core==2.33.2 +Pygments==2.19.2 +pytest==8.4.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +PyYAML==6.0.2 +qdrant-client==1.15.1 +regex==2025.8.29 +requests==2.32.5 +requests-toolbelt==1.0.0 +rsa==4.9.1 +s3transfer==0.13.1 +six==1.17.0 +sniffio==1.3.1 +SQLAlchemy==2.0.43 +sympy==1.14.0 +tenacity==9.1.2 +tiktoken==0.11.0 +tokenizers==0.22.0 +tqdm==4.67.1 +typing-inspect==0.9.0 +typing-inspection==0.4.1 +typing_extensions==4.15.0 +urllib3==2.5.0 +yarl==1.20.1 +zstandard==0.24.0