Skip to content

Commit e6659fe

Browse files
committed
feat: enhance RAG service with background processing for unprocessed files and improve file handling logic
1 parent: 39e50ce · commit: e6659fe

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

runtime/datamate-python/app/module/rag/service/graph_rag.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from lightrag.utils import setup_logger, EmbeddingFunc
99

1010
setup_logger("lightrag", level="DEBUG")
11-
DEFAULT_WORKING_DIR = "/rag_storage"
11+
DEFAULT_WORKING_DIR = os.path.join(os.getcwd(), "rag_storage")
1212

1313

1414
async def build_llm_model_func(model_name: str, base_url: str, api_key: str) -> Callable[..., Awaitable[str]]:
@@ -33,7 +33,7 @@ async def build_embedding_func(
3333
model_name: str, base_url: str, api_key: str, embedding_dim: int
3434
) -> EmbeddingFunc:
3535
async def _embedding_func(texts: list[str]) -> np.ndarray:
36-
return await openai_embed.func(
36+
return await openai_embed(
3737
texts,
3838
model=model_name,
3939
api_key=api_key,

runtime/datamate-python/app/module/rag/service/rag_service.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
11
import os
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

4-
from fastapi import Depends
4+
from fastapi import BackgroundTasks, Depends
55
from sqlalchemy import select
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8-
from app.db.models.knowledge_gen import RagKnowledgeBase
8+
from app.core.logging import get_logger
9+
from app.db.models.dataset_management import DatasetFiles
10+
from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase
911
from app.db.models.model_config import ModelConfig
10-
from app.db.session import AsyncSessionLocal
12+
from app.db.session import get_db
13+
from app.module.shared.common.document_loaders import load_documents
1114
from .graph_rag import (
1215
DEFAULT_WORKING_DIR,
1316
build_embedding_func,
1417
build_llm_model_func,
1518
initialize_rag,
1619
)
1720

21+
logger = get_logger(__name__)
22+
1823

1924
class RAGService:
2025
def __init__(
2126
self,
22-
db: AsyncSession = Depends(AsyncSessionLocal),
27+
db: AsyncSession = Depends(get_db),
28+
background_tasks: BackgroundTasks | None = None,
2329
):
2430
self.db = db
31+
self.background_tasks = background_tasks
2532
self.rag = None
2633

27-
28-
async def get_unprocessed_files(self, knowledge_base_id: str) -> list[str]:
29-
pass
34+
async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]:
    """Return the RagFile rows of a knowledge base that are not yet processed.

    Args:
        knowledge_base_id: Value matched against ``RagFile.knowledge_base_id``.

    Returns:
        Every matching row whose ``status`` differs from ``"PROCESSED"``.
    """
    # NOTE(review): in SQL a NULL status does not satisfy `!=`, so such rows
    # would be left out -- confirm the status column can never be NULL.
    pending_filter = (
        RagFile.knowledge_base_id == knowledge_base_id,
        RagFile.status != "PROCESSED",
    )
    stmt = select(RagFile).where(*pending_filter)
    rows = await self.db.execute(stmt)
    return rows.scalars().all()
3042

3143
async def init_graph_rag(self, knowledge_base_id: str):
3244
kb = await self._get_knowledge_base(knowledge_base_id)
@@ -45,8 +57,45 @@ async def init_graph_rag(self, knowledge_base_id: str):
4557

4658
kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name)
4759
self.rag = await initialize_rag(llm_callable, embedding_callable, kb_working_dir)
60+
61+
if self.background_tasks is not None:
62+
self.background_tasks.add_task(self._process_pending_files, knowledge_base_id)
63+
else:
64+
await self._process_pending_files(knowledge_base_id)
65+
4866
return {"status": "initialized", "knowledge_base_id": knowledge_base_id}
4967

68+
async def _process_pending_files(self, knowledge_base_id: str):
    """Process every not-yet-processed file of a knowledge base, one at a time.

    Files are handled sequentially in the order returned by
    ``get_unprocessed_files``; an exception raised for one file propagates
    and stops the remaining files from being processed.

    Args:
        knowledge_base_id: Identifier of the knowledge base whose files are ingested.
    """
    pending = await self.get_unprocessed_files(knowledge_base_id)
    if pending:
        for entry in pending:
            await self._process_single_file(entry)
    else:
        logger.info(f"No pending files to process for knowledge base {knowledge_base_id}")
76+
77+
async def _process_single_file(self, rag_file: RagFile):
    """Ingest one pending file into the RAG index and mark it processed.

    The dataset file referenced by ``rag_file.file_id`` is looked up, loaded
    into documents, and each document's ``page_content`` is inserted into
    ``self.rag``. The file is flagged as processed only after every document
    has been inserted.

    Args:
        rag_file: Row whose ``file_id`` identifies the dataset file to ingest.

    Raises:
        RuntimeError: If ``self.rag`` is still ``None`` (``init_graph_rag``
            has not set it yet).
        ValueError: Propagated from ``_get_dataset_file`` when no dataset
            file matches ``rag_file.file_id``.
    """
    if self.rag is None:
        # Fail with an explicit message instead of an AttributeError on None.
        raise RuntimeError("RAG instance is not initialized; call init_graph_rag() first.")

    dataset_file = await self._get_dataset_file(rag_file.file_id)
    documents = load_documents(dataset_file.file_path)
    for doc in documents:
        # Pass the text positionally: LightRAG's ``ainsert`` names its first
        # parameter ``input``, so a ``text=`` keyword is rejected, while a
        # positional argument does not depend on the parameter name.
        # NOTE(review): confirm against the installed LightRAG version.
        await self.rag.ainsert(doc.page_content)
    await self._mark_file_processed(rag_file)
83+
84+
async def _get_dataset_file(self, file_id: str) -> DatasetFiles:
    """Fetch the DatasetFiles row whose ``id`` equals ``file_id``.

    Args:
        file_id: Identifier compared against ``DatasetFiles.id``.

    Returns:
        The first matching row.

    Raises:
        ValueError: If the query yields no row.
    """
    lookup = select(DatasetFiles).where(DatasetFiles.id == file_id)
    rows = await self.db.execute(lookup)
    record = rows.scalars().first()
    if record:
        return record
    raise ValueError(f"Dataset file with ID {file_id} not found.")
92+
93+
async def _mark_file_processed(self, rag_file: RagFile):
    """Set the file's status to ``"PROCESSED"`` and persist the change.

    Args:
        rag_file: Row to update; it is committed and then refreshed through
            ``self.db``.
    """
    session = self.db
    rag_file.status = "PROCESSED"
    session.add(rag_file)
    # Commit first, then reload the row's state from the database.
    await session.commit()
    await session.refresh(rag_file)
98+
5099
async def _get_knowledge_base(self, knowledge_base_id: str):
51100
result = await self.db.execute(
52101
select(RagKnowledgeBase).where(RagKnowledgeBase.id == knowledge_base_id)

0 commit comments

Comments
 (0)