11import os
2- from typing import Optional
2+ from typing import Optional , Sequence
33
4- from fastapi import Depends
4+ from fastapi import BackgroundTasks , Depends
55from sqlalchemy import select
66from sqlalchemy .ext .asyncio import AsyncSession
77
8- from app .db .models .knowledge_gen import RagKnowledgeBase
8+ from app .core .logging import get_logger
9+ from app .db .models .dataset_management import DatasetFiles
10+ from app .db .models .knowledge_gen import RagFile , RagKnowledgeBase
911from app .db .models .model_config import ModelConfig
10- from app .db .session import AsyncSessionLocal
12+ from app .db .session import get_db
13+ from app .module .shared .common .document_loaders import load_documents
1114from .graph_rag import (
1215 DEFAULT_WORKING_DIR ,
1316 build_embedding_func ,
1417 build_llm_model_func ,
1518 initialize_rag ,
1619)
1720
21+ logger = get_logger (__name__ )
22+
1823
1924class RAGService :
2025 def __init__ (
2126 self ,
22- db : AsyncSession = Depends (AsyncSessionLocal ),
27+ db : AsyncSession = Depends (get_db ),
28+ background_tasks : BackgroundTasks | None = None ,
2329 ):
2430 self .db = db
31+ self .background_tasks = background_tasks
2532 self .rag = None
2633
27-
28- async def get_unprocessed_files (self , knowledge_base_id : str ) -> list [str ]:
29- pass
34+ async def get_unprocessed_files (self , knowledge_base_id : str ) -> Sequence [RagFile ]:
35+ result = await self .db .execute (
36+ select (RagFile ).where (
37+ RagFile .knowledge_base_id == knowledge_base_id ,
38+ RagFile .status != "PROCESSED" ,
39+ )
40+ )
41+ return result .scalars ().all ()
3042
3143 async def init_graph_rag (self , knowledge_base_id : str ):
3244 kb = await self ._get_knowledge_base (knowledge_base_id )
@@ -45,8 +57,45 @@ async def init_graph_rag(self, knowledge_base_id: str):
4557
4658 kb_working_dir = os .path .join (DEFAULT_WORKING_DIR , kb .name )
4759 self .rag = await initialize_rag (llm_callable , embedding_callable , kb_working_dir )
60+
61+ if self .background_tasks is not None :
62+ self .background_tasks .add_task (self ._process_pending_files , knowledge_base_id )
63+ else :
64+ await self ._process_pending_files (knowledge_base_id )
65+
4866 return {"status" : "initialized" , "knowledge_base_id" : knowledge_base_id }
4967
68+ async def _process_pending_files (self , knowledge_base_id : str ):
69+ rag_files = await self .get_unprocessed_files (knowledge_base_id )
70+ if not rag_files :
71+ logger .info (f"No pending files to process for knowledge base { knowledge_base_id } " )
72+ return
73+
74+ for rag_file in rag_files :
75+ await self ._process_single_file (rag_file )
76+
77+ async def _process_single_file (self , rag_file : RagFile ):
78+ dataset_file = await self ._get_dataset_file (rag_file .file_id )
79+ documents = load_documents (dataset_file .file_path )
80+ for doc in documents :
81+ await self .rag .ainsert (text = doc .page_content )
82+ await self ._mark_file_processed (rag_file )
83+
84+ async def _get_dataset_file (self , file_id : str ) -> DatasetFiles :
85+ result = await self .db .execute (
86+ select (DatasetFiles ).where (DatasetFiles .id == file_id )
87+ )
88+ dataset_file = result .scalars ().first ()
89+ if not dataset_file :
90+ raise ValueError (f"Dataset file with ID { file_id } not found." )
91+ return dataset_file
92+
93+ async def _mark_file_processed (self , rag_file : RagFile ):
94+ rag_file .status = "PROCESSED"
95+ self .db .add (rag_file )
96+ await self .db .commit ()
97+ await self .db .refresh (rag_file )
98+
5099 async def _get_knowledge_base (self , knowledge_base_id : str ):
51100 result = await self .db .execute (
52101 select (RagKnowledgeBase ).where (RagKnowledgeBase .id == knowledge_base_id )
0 commit comments