Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/basic_usage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Basic usage example for insta_rag library."""

import os
from pathlib import Path

from dotenv import load_dotenv
Expand Down Expand Up @@ -93,7 +92,7 @@ def main():
metadata={"project": "insta_rag_demo"},
)

print(f"\n✓ Documents processed successfully!")
print("\n✓ Documents processed successfully!")
print(f" - Documents processed: {response.documents_processed}")
print(f" - Total chunks created: {response.total_chunks}")
print(f" - Total tokens: {response.processing_stats.total_tokens}")
Expand All @@ -103,12 +102,12 @@ def main():
print(f" - Total time: {response.processing_stats.total_time_ms:.2f}ms")

if response.errors:
print(f"\n⚠ Errors encountered:")
print("\n⚠ Errors encountered:")
for error in response.errors:
print(f" - {error}")

# Display chunk information
print(f"\nChunk Details:")
print("\nChunk Details:")
for i, chunk in enumerate(response.chunks[:3]): # Show first 3 chunks
print(f"\nChunk {i + 1}:")
print(f" - ID: {chunk.chunk_id}")
Expand Down
2 changes: 1 addition & 1 deletion src/insta_rag/chunking/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ..utils.exceptions import ChunkingError
from insta_rag.utils.exceptions import ChunkingError
from ..models.chunk import Chunk, ChunkMetadata
from .base import BaseChunker
from .utils import (
Expand Down
8 changes: 5 additions & 3 deletions src/insta_rag/chunking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def split_into_sentences(text: str) -> List[str]:
"""
# Use regex to split on sentence boundaries
# This handles common cases but may not be perfect for all texts
sentence_endings = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')
sentence_endings = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")
sentences = sentence_endings.split(text)

# Clean up sentences
Expand All @@ -113,7 +113,7 @@ def split_into_paragraphs(text: str) -> List[str]:
List of paragraphs
"""
# Split on double newlines
paragraphs = re.split(r'\n\s*\n', text)
paragraphs = re.split(r"\n\s*\n", text)

# Clean up paragraphs
paragraphs = [p.strip() for p in paragraphs if p.strip()]
Expand Down Expand Up @@ -156,7 +156,9 @@ def validate_chunk_quality(chunk: str) -> bool:
return True


def add_overlap_to_chunks(chunks: List[str], overlap_percentage: float = 0.2) -> List[str]:
def add_overlap_to_chunks(
chunks: List[str], overlap_percentage: float = 0.2
) -> List[str]:
"""Add overlap between consecutive chunks.

Args:
Expand Down
9 changes: 6 additions & 3 deletions src/insta_rag/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from ..chunking.semantic import SemanticChunker
from ..embedding.openai import OpenAIEmbedder
from ..utils.exceptions import ValidationError, VectorDBError
from insta_rag.utils.exceptions import ValidationError, VectorDBError
from ..models.document import DocumentInput, SourceType
from ..models.response import (
AddDocumentsResponse,
ProcessingStats,
UpdateDocumentsResponse,
)
from ..utils.pdf_processing import extract_text_from_pdf
from insta_rag.utils.pdf_processing import extract_text_from_pdf
from ..vectordb.qdrant import QdrantVectorDB
from .config import RAGConfig

Expand Down Expand Up @@ -325,7 +325,10 @@ def update_documents(
NoDocumentsFoundError: No documents match criteria (for delete/replace)
VectorDBError: Qdrant operation failures
"""
from ..utils.exceptions import CollectionNotFoundError, NoDocumentsFoundError
from insta_rag.utils.exceptions import (
CollectionNotFoundError,
NoDocumentsFoundError,
)

start_time = time.time()
errors = []
Expand Down
2 changes: 1 addition & 1 deletion src/insta_rag/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from ..utils.exceptions import ConfigurationError
from insta_rag.utils.exceptions import ConfigurationError


@dataclass
Expand Down
9 changes: 6 additions & 3 deletions src/insta_rag/core/retrieval_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import time
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional

from ..models.response import (
RetrievalResponse,
Expand Down Expand Up @@ -138,7 +138,7 @@ def retrieve(
... collection_name="knowledge_base",
... filters={"user_id": "user_123"},
... top_k=10,
... enable_reranking=True
... enable_reranking=True,
... )
>>> for chunk in response.chunks:
... print(f"Score: {chunk.relevance_score:.4f}")
Expand Down Expand Up @@ -230,7 +230,10 @@ def retrieve(
chunk_dict = {}
for chunk in all_chunks:
chunk_id = chunk.chunk_id
if chunk_id not in chunk_dict or chunk.score > chunk_dict[chunk_id].score:
if (
chunk_id not in chunk_dict
or chunk.score > chunk_dict[chunk_id].score
):
chunk_dict[chunk_id] = chunk
unique_chunks = list(chunk_dict.values())
else:
Expand Down
2 changes: 1 addition & 1 deletion src/insta_rag/embedding/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Optional

from ..utils.exceptions import EmbeddingError
from insta_rag.utils.exceptions import EmbeddingError
from .base import BaseEmbedder


Expand Down
4 changes: 3 additions & 1 deletion src/insta_rag/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def from_binary(
def get_source_path(self) -> Optional[Path]:
"""Get source as Path if it's a file, None otherwise."""
if self.source_type == SourceType.FILE:
return Path(self.source) if not isinstance(self.source, Path) else self.source
return (
Path(self.source) if not isinstance(self.source, Path) else self.source
)
return None

def get_source_text(self) -> Optional[str]:
Expand Down
4 changes: 3 additions & 1 deletion src/insta_rag/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ class UpdateDocumentsResponse:
chunks_added: int = 0
chunks_updated: int = 0
updated_document_ids: List[str] = field(default_factory=list)
chunks: List[Chunk] = field(default_factory=list) # NEW: For external storage (e.g., MongoDB)
chunks: List[Chunk] = field(
default_factory=list
) # NEW: For external storage (e.g., MongoDB)
errors: List[str] = field(default_factory=list)

def to_dict(self) -> Dict[str, Any]:
Expand Down
22 changes: 12 additions & 10 deletions src/insta_rag/retrieval/keyword_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,18 @@ def _build_corpus(self):
mongodb_id = point.payload.get("mongodb_id")
if mongodb_id and self.rag_client.mongodb:
try:
mongo_doc = self.rag_client.mongodb.get_chunk_content_by_mongo_id(
str(mongodb_id)
mongo_doc = (
self.rag_client.mongodb.get_chunk_content_by_mongo_id(
str(mongodb_id)
)
)
if mongo_doc:
content = mongo_doc.get("content", "")
mongodb_fetch_count += 1
except Exception as e:
print(f" Warning: Failed to fetch content from MongoDB for chunk {point.payload.get('chunk_id')}: {e}")
print(
f" Warning: Failed to fetch content from MongoDB for chunk {point.payload.get('chunk_id')}: {e}"
)
skipped_count += 1
continue

Expand All @@ -112,7 +116,9 @@ def _build_corpus(self):
)

if mongodb_fetch_count > 0:
print(f" Fetched content for {mongodb_fetch_count} chunks from MongoDB")
print(
f" Fetched content for {mongodb_fetch_count} chunks from MongoDB"
)
if skipped_count > 0:
print(f" Skipped {skipped_count} chunks without content")

Expand All @@ -122,7 +128,7 @@ def _build_corpus(self):
print(f" ✓ BM25 corpus built: {len(self.corpus)} documents indexed")
else:
self.bm25 = None
print(f" ⚠️ BM25 corpus is empty - no documents indexed")
print(" ⚠️ BM25 corpus is empty - no documents indexed")

except ImportError:
print(
Expand Down Expand Up @@ -171,11 +177,7 @@ def search(
if filters:
match = True
for key, value in filters.items():
if (
value is not None
and value != ""
and value != {}
):
if value is not None and value != "" and value != {}:
if chunk_data["metadata"].get(key) != value:
match = False
break
Expand Down
2 changes: 1 addition & 1 deletion src/insta_rag/retrieval/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from typing import Dict

from ..utils.exceptions import QueryGenerationError
from insta_rag.utils.exceptions import QueryGenerationError


class HyDEQueryGenerator:
Expand Down
62 changes: 22 additions & 40 deletions src/insta_rag/retrieval/reranker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Reranking implementations for improving retrieval results."""

import time
import json
from typing import Any, Dict, List, Tuple
import requests
Expand Down Expand Up @@ -30,7 +29,7 @@ def __init__(
api_key: str,
api_url: str = "http://118.67.212.45:8000/rerank",
normalize: bool = False,
timeout: int = 30
timeout: int = 30,
):
"""Initialize BGE reranker.

Expand All @@ -46,10 +45,7 @@ def __init__(
self.timeout = timeout

def rerank(
self,
query: str,
chunks: List[Tuple[str, Dict[str, Any]]],
top_k: int
self, query: str, chunks: List[Tuple[str, Dict[str, Any]]], top_k: int
) -> List[Tuple[int, float]]:
"""Rerank chunks based on relevance to query using BGE reranker.

Expand All @@ -75,22 +71,19 @@ def rerank(
"query": query,
"documents": documents,
"top_k": min(top_k, len(documents)), # Don't request more than available
"normalize": self.normalize
"normalize": self.normalize,
}

headers = {
"accept": "application/json",
"X-API-Key": self.api_key,
"Content-Type": "application/json"
"Content-Type": "application/json",
}

try:
# Make API request
response = requests.post(
self.api_url,
json=request_data,
headers=headers,
timeout=self.timeout
self.api_url, json=request_data, headers=headers, timeout=self.timeout
)
response.raise_for_status()

Expand Down Expand Up @@ -130,10 +123,7 @@ def __init__(self, api_key: str, model: str = "rerank-english-v3.0"):
self.model = model

def rerank(
self,
query: str,
chunks: List[Tuple[str, Dict[str, Any]]],
top_k: int
self, query: str, chunks: List[Tuple[str, Dict[str, Any]]], top_k: int
) -> List[Tuple[int, float]]:
"""Rerank chunks using Cohere API.

Expand Down Expand Up @@ -167,7 +157,7 @@ def __init__(
api_key: str,
base_url: str,
model: str = "gpt-oss-120b",
timeout: int = 60
timeout: int = 60,
):
"""Initialize LLM reranker.

Expand All @@ -181,17 +171,10 @@ def __init__(
self.base_url = base_url
self.model = model
self.timeout = timeout
self.client = OpenAI(
base_url=base_url,
api_key=api_key,
timeout=timeout
)
self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout)

def rerank(
self,
query: str,
chunks: List[Tuple[str, Dict[str, Any]]],
top_k: int
self, query: str, chunks: List[Tuple[str, Dict[str, Any]]], top_k: int
) -> List[Tuple[int, float]]:
"""Rerank chunks based on relevance to query using LLM.

Expand Down Expand Up @@ -253,31 +236,28 @@ def rerank(
messages=[
{
"role": "system",
"content": "You are a relevance scoring system that returns only valid JSON arrays. Never include explanations, only return the JSON array."
"content": "You are a relevance scoring system that returns only valid JSON arrays. Never include explanations, only return the JSON array.",
},
{
"role": "user",
"content": prompt
}
{"role": "user", "content": prompt},
],
temperature=0.0,
max_tokens=2000
max_tokens=2000,
)

# Extract the response content
response_text = response.choices[0].message.content.strip()

# Print the OSS model response for logging
print(f"\n{'='*80}")
print(f"OSS MODEL RESPONSE (gpt-oss-120b):")
print(f"{'='*80}")
print(f"\n{'=' * 80}")
print("OSS MODEL RESPONSE (gpt-oss-120b):")
print(f"{'=' * 80}")
print(response_text)
print(f"{'='*80}\n")
print(f"{'=' * 80}\n")

# Try to find JSON array in the response
# Sometimes LLM might add extra text, so we look for the JSON array
start_idx = response_text.find('[')
end_idx = response_text.rfind(']') + 1
start_idx = response_text.find("[")
end_idx = response_text.rfind("]") + 1

if start_idx == -1 or end_idx == 0:
raise ValueError("No JSON array found in LLM response")
Expand Down Expand Up @@ -312,6 +292,8 @@ def rerank(
return reranked_results

except json.JSONDecodeError as e:
raise Exception(f"LLM reranker failed to parse JSON response: {str(e)}. Response: {response_text[:200]}")
raise Exception(
f"LLM reranker failed to parse JSON response: {str(e)}. Response: {response_text[:200]}"
)
except Exception as e:
raise Exception(f"LLM reranker API request failed: {str(e)}")
raise Exception(f"LLM reranker API request failed: {str(e)}")
3 changes: 1 addition & 2 deletions src/insta_rag/vectordb/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import uuid
from typing import Any, Dict, List, Optional

from ..utils.exceptions import CollectionNotFoundError, VectorDBError
from insta_rag.utils.exceptions import CollectionNotFoundError, VectorDBError
from .base import BaseVectorDB, VectorSearchResult


Expand Down