From df43fd100015aead44af03a7cf95330934b04ce4 Mon Sep 17 00:00:00 2001 From: Mini256 Date: Mon, 13 Jan 2025 10:26:26 +0800 Subject: [PATCH] init --- .../app/api/admin_routes/retrieve/models.py | 28 +++++++++ .../app/api/admin_routes/retrieve/routes.py | 35 ++++++++++++ .../app/rag/retrievers/KBHybirdRetriever.py | 0 .../retrievers/KBKnowledgeGraphRetriever.py | 0 .../rag/retrievers/KBVectorSearchRetriever.py | 57 +++++++++++++++++++ 5 files changed, 120 insertions(+) create mode 100644 backend/app/api/admin_routes/retrieve/models.py create mode 100644 backend/app/api/admin_routes/retrieve/routes.py create mode 100644 backend/app/rag/retrievers/KBHybirdRetriever.py create mode 100644 backend/app/rag/retrievers/KBKnowledgeGraphRetriever.py create mode 100644 backend/app/rag/retrievers/KBVectorSearchRetriever.py diff --git a/backend/app/api/admin_routes/retrieve/models.py b/backend/app/api/admin_routes/retrieve/models.py new file mode 100644 index 000000000..3df65d0a9 --- /dev/null +++ b/backend/app/api/admin_routes/retrieve/models.py @@ -0,0 +1,28 @@ +from app.models import Chunk +from typing import List +from pydantic import BaseModel + + +class AppConfig(BaseModel): + pass + + +class RetrievalConfig(BaseModel): + top_k: int = 10 + similarity_top_k: int = None + metadata_filters: dict = {} + oversampling_factor: int = 5 + + +class RetrieveRequest(BaseModel): + query: str + retrieval_config: RetrievalConfig = RetrievalConfig() + + +class RetrievedChunk(BaseModel): + chunk: Chunk + score: float + + +class RetrieveResponse(BaseModel): + chunks: List[RetrievedChunk] diff --git a/backend/app/api/admin_routes/retrieve/routes.py b/backend/app/api/admin_routes/retrieve/routes.py new file mode 100644 index 000000000..1d01f4b12 --- /dev/null +++ b/backend/app/api/admin_routes/retrieve/routes.py @@ -0,0 +1,35 @@ +import logging +from typing import List + +from fastapi import APIRouter +from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest +from app.api.deps import SessionDep, CurrentSuperuserDep +from llama_index.core.schema import NodeWithScore +from app.rag.retrieve import retrieve_service + +from app.exceptions import InternalServerError, KBNotFound + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/admin/retrieve/chunks") +def retrieve_chunks( + session: SessionDep, + user: CurrentSuperuserDep, + request: ChatEngineBasedRetrieveRequest, +) -> List[NodeWithScore]: + try: + return retrieve_service.chat_engine_retrieve_chunks( + session, + request.query, + top_k=request.top_k, + similarity_top_k=request.similarity_top_k, + oversampling_factor=request.oversampling_factor, + enable_kg_enhance_query_refine=request.enable_kg_enhance_query_refine, + ) + except KBNotFound as e: + raise e + except Exception as e: + logger.exception(e) + raise InternalServerError() diff --git a/backend/app/rag/retrievers/KBHybirdRetriever.py b/backend/app/rag/retrievers/KBHybirdRetriever.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/rag/retrievers/KBKnowledgeGraphRetriever.py b/backend/app/rag/retrievers/KBKnowledgeGraphRetriever.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/rag/retrievers/KBVectorSearchRetriever.py b/backend/app/rag/retrievers/KBVectorSearchRetriever.py new file mode 100644 index 000000000..b5f695ed3 --- /dev/null +++ b/backend/app/rag/retrievers/KBVectorSearchRetriever.py @@ -0,0 +1,57 @@ +import logging +from llama_index.core.retrievers import BaseRetriever +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import BaseModel, List + +from app.rag.postprocessors.metadata_post_filter import MetadataFilters + + +logger = logging.getLogger(__name__) + + +class VectorSearchRerankerConfig(BaseModel): + enable: bool = True + + +class VectorSearchMetadataFilterConfig(BaseModel): + enable: bool = True + filters: MetadataFilters = None + + +class VectorSearchRetrieverConfig(BaseModel): + enable: bool = True + top_k: int = 10 + similarity_top_k: int = None + oversampling_factor: int = 5 + reranker: VectorSearchRerankerConfig = None + metadata_filter: VectorSearchMetadataFilterConfig = None + + +class KnowledgeGraphRetrieverConfig(BaseModel): + enable: bool = False + + +class KnowledgeBaseConfig(BaseModel): + linked_knowledge_base: LinkedKnowledgeBaseConfig + + +class RetrieverConfig(BaseModel): + knowledge_base: KnowledgeBaseConfig + vector_search: VectorSearchRetrieverConfig + knowledge_graph: KnowledgeGraphRetrieverConfig + + +class VectorSearchRetriever(BaseRetriever): + def __init__(self, config: VectorSearchRetrieverConfig): + pass + + +class AppRetriever(BaseRetriever): + def __init__( + self, + config: RetrieverConfig, + ): + pass + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + pass