Skip to content

Commit

Permalink
refactor: refine chat flow and retrievers (#589)
Browse files Browse the repository at this point in the history
close #565

- [x] Add VectorSearchRetrieverr (support single KB pure retrieve)
   - [x] metadata filters
   - [ ] similarity_threshold
- [x] Add VectorSearchFusionRetriever
   - [ ] support multiple KB route
   - [x] support query decompose
- [x] Add KnowledgeGraphRetriever (support single KB pure retrieve)
   - [x] Add `KnowledgeGraphNode` (llamaindex node)
   - [x] metadata filters
   - [ ] similarity_threshold
- [x] Add KnowledgeGraphFusionRetriever 
   - [ ] support multiple KB route
   - [x] support query decompose (old IntentAnalyzer)
- [x] Add retrieve APIs
  - [x]  `/api/v1/retrieve/chunks`
  - [x]  `/api/v1/retrieve/knowledge_graph`
  - [x]  `/api/v1/admin/knowledge_bases/{kb_id}/chunks/retrieve`
  - [x]  `/api/v1/admin/knowledge_bases/{kb_id}/graph/retrieve` 
- [x] Integration with ChatService
- [ ] tests and fix bug

Feature PRs:
- merge `TiDBGraphEditor` to `TiDBGraphStore`, TiDBGraphStore will only
retain storage layer operactions, so it shouldn't depend on DsPy or LLM.
- `TiDBKnowledgeGraphIndex` will handle the merge entities parts.
- `KnowledgeGraphRetriever` will handle query compose parts.
Mini256 authored Jan 24, 2025
1 parent 97cf600 commit 5562ed7
Showing 75 changed files with 2,512 additions and 1,585 deletions.
Empty file added backend/app/api/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions backend/app/api/admin_routes/embedding_model/routes.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
EmbeddingProviderOption,
embedding_provider_options,
)
from app.rag.embeddings.resolver import get_embed_model
from app.rag.embeddings.resolver import resolve_embed_model

router = APIRouter()
logger = logging.getLogger(__name__)
@@ -57,7 +57,7 @@ def test_embedding_model(
create: EmbeddingModelCreate,
) -> EmbeddingModelTestResult:
try:
embed_model = get_embed_model(
embed_model = resolve_embed_model(
provider=create.provider,
model=create.model,
config=create.config,
Empty file.
13 changes: 13 additions & 0 deletions backend/app/api/admin_routes/knowledge_base/chunk/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel

from app.rag.retrievers.chunk.schema import VectorSearchRetrieverConfig


class KBChunkRetrievalConfig(BaseModel):
vector_search: VectorSearchRetrieverConfig
# TODO: add fulltext and knowledge graph search config


class KBRetrieveChunksRequest(BaseModel):
query: str
retrieval_config: KBChunkRetrievalConfig
38 changes: 38 additions & 0 deletions backend/app/api/admin_routes/knowledge_base/chunk/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

from fastapi import APIRouter
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.retrievers.chunk.simple_retriever import (
ChunkSimpleRetriever,
)
from app.rag.retrievers.chunk.schema import ChunksRetrievalResult

from app.exceptions import InternalServerError, KBNotFound
from .models import KBRetrieveChunksRequest

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/admin/knowledge_base/{kb_id}/chunks/retrieve")
def retrieve_chunks(
db_session: SessionDep,
user: CurrentSuperuserDep,
kb_id: int,
request: KBRetrieveChunksRequest,
) -> ChunksRetrievalResult:
try:
vector_search_config = request.retrieval_config.vector_search
retriever = ChunkSimpleRetriever(
db_session=db_session,
knowledge_base_id=kb_id,
config=vector_search_config,
)
return retriever.retrieve_chunks(
request.query,
)
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from fastapi import HTTPException
from starlette import status

from app.api.admin_routes.knowledge_base.graph.models import (
KnowledgeRequest,
KnowledgeNeighborRequest,
KnowledgeChunkRequest,
)
from app.api.admin_routes.knowledge_base.graph.routes import router, logger
from app.api.deps import SessionDep
from app.exceptions import KBNotFound, InternalServerError
from app.rag.knowledge_base.index_store import get_kb_tidb_graph_store
from app.repositories import knowledge_base_repo


# Experimental interface


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge")
def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeRequest):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_graph_data(
request.query,
request.top_k,
request.similarity_threshold,
)
return {
"entities": data["entities"],
"relationships": data["relationships"],
}
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/neighbors")
def retrieve_knowledge_neighbors(
session: SessionDep, kb_id: int, request: KnowledgeNeighborRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_neighbors(
request.entities_ids,
request.query,
request.max_depth,
request.max_neighbors,
request.similarity_threshold,
)
return data
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/chunks")
def retrieve_knowledge_chunks(
session: SessionDep, kb_id: int, request: KnowledgeChunkRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.get_chunks_by_relationships(request.relationships_ids)
if not data:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No chunks found for the given relationships",
)
return data
except KBNotFound as e:
raise e
except HTTPException as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
20 changes: 20 additions & 0 deletions backend/app/api/admin_routes/knowledge_base/graph/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import List, Optional
from pydantic import BaseModel, model_validator

from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrieverConfig,
)


class SynopsisEntityCreate(BaseModel):
name: str
@@ -36,6 +40,22 @@ class GraphSearchRequest(BaseModel):
relationship_meta_filters: dict = {}


# Knowledge Graph Retrieval


class KBKnowledgeGraphRetrievalConfig(BaseModel):
knowledge_graph: KnowledgeGraphRetrieverConfig


class KBRetrieveKnowledgeGraphRequest(BaseModel):
query: str
llm_id: int
retrival_config: KBKnowledgeGraphRetrievalConfig


### Experimental


class KnowledgeRequest(BaseModel):
query: str
similarity_threshold: float = 0.55
103 changes: 31 additions & 72 deletions backend/app/api/admin_routes/knowledge_base/graph/routes.py
Original file line number Diff line number Diff line change
@@ -7,21 +7,25 @@
SynopsisEntityCreate,
EntityUpdate,
RelationshipUpdate,
KBRetrieveKnowledgeGraphRequest,
GraphSearchRequest,
KnowledgeRequest,
KnowledgeNeighborRequest,
KnowledgeChunkRequest,
)
from app.api.deps import SessionDep
from app.exceptions import KBNotFound, InternalServerError
from app.models import (
EntityPublic,
RelationshipPublic,
)
from app.rag.retrievers.knowledge_graph.schema import (
KnowledgeGraphRetrievalResult,
)
from app.rag.knowledge_base.index_store import (
get_kb_tidb_graph_editor,
get_kb_tidb_graph_store,
)
from app.rag.retrievers.knowledge_graph.simple_retriever import (
KnowledgeGraphSimpleRetriever,
)
from app.repositories import knowledge_base_repo

router = APIRouter()
@@ -195,92 +199,47 @@ def update_relationship(
raise e


@router.post("/admin/knowledge_bases/{kb_id}/graph/search")
def search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest):
@router.post("/admin/knowledge_bases/{kb_id}/graph/retrieve")
def retrieve_kb_knowledge_graph(
db_session: SessionDep, kb_id: int, request: KBRetrieveKnowledgeGraphRequest
) -> KnowledgeGraphRetrievalResult:
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
entities, relations, _ = graph_store.retrieve_with_weight(
request.query,
[],
request.depth,
request.include_meta,
request.with_degree,
False,
request.relationship_meta_filters,
retriever = KnowledgeGraphSimpleRetriever(
db_session=db_session,
knowledge_base_id=kb_id,
config=request.retrival_config.knowledge_graph,
)
knowledge_graph = retriever.retrieve_knowledge_graph(request.query)
return KnowledgeGraphRetrievalResult(
entities=knowledge_graph.entities,
relationships=knowledge_graph.relationships,
)
return {
"entities": entities,
"relationships": relations,
}
except KBNotFound as e:
raise e
except Exception as e:
# TODO: throw InternalServerError
raise e


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge")
def retrieve_knowledge(session: SessionDep, kb_id: int, request: KnowledgeRequest):
@router.post("/admin/knowledge_bases/{kb_id}/graph/search", deprecated=True)
def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchRequest):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_graph_data(
entities, relations = graph_store.retrieve_with_weight(
request.query,
request.top_k,
request.similarity_threshold,
[],
request.depth,
request.include_meta,
request.with_degree,
request.relationship_meta_filters,
)
return {
"entities": data["entities"],
"relationships": data["relationships"],
"entities": entities,
"relationships": relations,
}
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/neighbors")
def retrieve_knowledge_neighbors(
session: SessionDep, kb_id: int, request: KnowledgeNeighborRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.retrieve_neighbors(
request.entities_ids,
request.query,
request.max_depth,
request.max_neighbors,
request.similarity_threshold,
)
return data
except KBNotFound as e:
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()


@router.post("/admin/knowledge_bases/{kb_id}/graph/knowledge/chunks")
def retrieve_knowledge_chunks(
session: SessionDep, kb_id: int, request: KnowledgeChunkRequest
):
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
data = graph_store.get_chunks_by_relationships(request.relationships_ids)
if not data:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No chunks found for the given relationships",
)
return data
except KBNotFound as e:
raise e
except HTTPException as e:
# TODO: throw InternalServerError
raise e
except Exception as e:
logger.exception(e)
raise InternalServerError()
4 changes: 2 additions & 2 deletions backend/app/api/admin_routes/llm/routes.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
LLMProviderOption,
llm_provider_options,
)
from app.rag.llms.resolver import get_llm
from app.rag.llms.resolver import resolve_llm

router = APIRouter()
logger = logging.getLogger(__name__)
@@ -59,7 +59,7 @@ def test_llm(
user: CurrentSuperuserDep,
) -> LLMTestResult:
try:
llm = get_llm(
llm = resolve_llm(
provider=db_llm.provider,
model=db_llm.model,
config=db_llm.config,
4 changes: 2 additions & 2 deletions backend/app/api/admin_routes/reranker_model/routes.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from app.models import RerankerModel, AdminRerankerModel, ChatEngine
from app.repositories.reranker_model import reranker_model_repo
from app.rag.rerankers.provider import RerankerProviderOption, reranker_provider_options
from app.rag.rerankers.resolver import get_reranker_model
from app.rag.rerankers.resolver import resolve_reranker

router = APIRouter()
logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ def test_reranker_model(
db_reranker_model: RerankerModel, user: CurrentSuperuserDep
) -> LLMTestResult:
try:
reranker = get_reranker_model(
reranker = resolve_reranker(
provider=db_reranker_model.provider,
model=db_reranker_model.model,
# for testing purpose, we only rerank 2 nodes
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)


@router.get("/admin/retrieve/documents")
@router.get("/admin/retrieve/documents", deprecated=True)
def retrieve_documents(
session: SessionDep,
user: CurrentSuperuserDep,
@@ -42,7 +42,7 @@ def retrieve_documents(
raise InternalServerError()


@router.get("/admin/embedding_retrieve")
@router.get("/admin/embedding_retrieve", deprecated=True)
def embedding_retrieve(
session: SessionDep,
user: CurrentSuperuserDep,
@@ -71,7 +71,7 @@ def embedding_retrieve(
raise InternalServerError()


@router.post("/admin/embedding_retrieve")
@router.post("/admin/embedding_retrieve", deprecated=True)
def embedding_search(
session: SessionDep,
user: CurrentSuperuserDep,
Loading

0 comments on commit 5562ed7

Please sign in to comment.