Course Materials RAG System 的详细技术指南,专注于顺序工具调用架构的开发和扩展。
这是系统最核心的特性,允许AI智能地进行多轮工具调用来处理复杂查询。
class ToolCallSession:
    """Tracks the state of one bounded sequential tool-calling session.

    Fix: the original stacked ``@dataclass`` on top of a hand-written
    ``__init__``; with no annotated fields the decorator was a misleading
    no-op (dataclass ignores ``init=True`` when ``__init__`` exists), so it
    is removed.
    """

    def __init__(self, max_rounds: int = 2, timeout_per_round: float = 10.0):
        self.current_round = 0
        self.max_rounds = max_rounds
        self.timeout_per_round = timeout_per_round
        self.start_time = time.time()  # wall-clock start for the overall timeout
        self.tool_results = []  # one entry per completed tool round
        self.total_timeout = max_rounds * timeout_per_round

核心设计原则:
- 有界处理:最多2轮调用,总计10秒超时
- 状态追踪:记录每轮结果,支持上下文传递
- 优雅降级:超时或失败时返回已有结果
def _handle_sequential_tool_execution(self, messages, tools):
    """Core loop for sequential (multi-round) tool calling.

    Lets the AI decide each round whether another tool call is needed,
    folds every tool result into the message history, and stops on the
    round limit or the overall session timeout.
    """
    session = ToolCallSession()
    while session.current_round < session.max_rounds:
        # Abort once the whole session exceeds its time budget.
        if time.time() - session.start_time > session.total_timeout:
            break
        # Ask the AI whether a tool call is needed this round.
        response = self._call_ai_api(messages, tools)
        if not response.tool_calls:
            # The AI decided no further tool calls are required.
            return response.content, session.tool_results
        # Execute the requested tool and append the result to the context.
        tool_result = self._execute_tool_call(response.tool_calls[0])
        session.tool_results.append(tool_result)
        messages.append({"role": "assistant", "content": None, "tool_calls": response.tool_calls})
        messages.append({"role": "tool", "content": tool_result})
        session.current_round += 1
    # Round/timeout limit reached: request a final answer with no tools.
    final_response = self._call_ai_api(messages, tools=[])
    return final_response.content, session.tool_results

关键设计决策:
- AI主导:让AI决定是否需要更多工具调用
- 上下文保持:每轮结果都添加到消息历史中
- 时间界限:严格的超时控制防止无限循环
系统当前支持两个互补的工具:
# backend/search_tools.py
class CourseOutlineTool(Tool):
    """Queries course structure and outline information."""

    def __init__(self, vector_store):
        super().__init__()
        self.name = "course_outline"
        self.description = "获取完整的课程大纲和课程结构信息"
        self.store = vector_store

    def execute(self, course_name: str) -> str:
        # Resolve the user-supplied name to a canonical course title.
        course_title = self.store._resolve_course_name(course_name)
        # Return the formatted course structure
class CourseSearchTool(Tool):
"""语义内容搜索"""
def __init__(self, vector_store):
super().__init__()
self.name = "search_course_content"
self.description = "在课程内容中进行语义搜索"
self.store = vector_store
def execute(self, query: str, course_name: str = None) -> str:
# 执行语义相似度搜索工具协调策略:
- 优先级:CourseOutlineTool 注册在前,适合结构化查询
- 互补性:大纲工具提供结构,搜索工具提供内容
- 智能选择:AI根据查询意图自动选择合适工具
# backend/vector_store.py
class VectorStore:
def __init__(self):
self.client = chromadb.PersistentClient(path="./chroma_db")
# 课程目录集合:存储课程元数据
self.course_catalog = self.client.get_or_create_collection(
name="course_catalog",
metadata={"hnsw:space": "cosine"}
)
# 课程内容集合:存储文档块
self.course_content = self.client.get_or_create_collection(
name="course_content",
metadata={"hnsw:space": "cosine"}
)数据分层:
- 目录层:课程标题、链接、讲师信息
- 内容层:按课程分块的实际内容(800字符块,100重叠)
- 元数据:课程名、课程编号、块ID用于精确检索
def _resolve_course_name(self, user_input: str) -> str:
    """Fuzzy-match a user-supplied course name against the catalog.

    Returns the canonical course title of the closest catalog entry,
    falling back to the raw input when nothing matches.
    """
    catalog_results = self.course_catalog.query(
        query_texts=[user_input],
        n_results=1,
    )
    ids = catalog_results['ids']
    if ids and len(ids[0]) > 0:
        metadata = catalog_results['metadatas'][0][0]
        return metadata.get('course_title', user_input)
    # No catalog hit: fall back to the original input.
    return user_input

# 克隆项目
git clone <repository-url>
cd course-materials-rag
# Set up the Python environment (uv is recommended)
uv sync
# Configure environment variables
cp .env.example .env
# Edit .env and add your API keys
# Verify the installation
cd backend
python -m uvicorn app:app --reload --port 8000

创建新工具时遵循以下模式:
# backend/search_tools.py
class NewCustomTool(Tool):
    """Template for implementing a new tool."""

    def __init__(self, vector_store):
        super().__init__()
        self.name = "new_tool_name"  # unique identifier
        self.description = "AI能理解的工具描述,说明何时使用"
        self.store = vector_store

    def execute(self, param1: str, param2: str = None) -> str:
        """Run the tool logic.

        Args:
            param1: required parameter.
            param2: optional parameter.

        Returns:
            A formatted string result; the AI bases its answer on it.
        """
        try:
            # Core tool logic.
            result = self._process_logic(param1, param2)
            return self._format_result(result)
        except Exception as e:
            return f"工具执行失败:{str(e)}"

    def _process_logic(self, param1, param2):
        """Private helper: core processing logic."""
        pass

    def _format_result(self, result):
        """Private helper: result formatting."""
        pass

注册新工具:
# backend/rag_system.py
def __init__(self, config: Config):
# ... existing initialization code
# Register the tools (order matters!)
self.tool_manager.register_tool(CourseOutlineTool(self.vector_store))
self.tool_manager.register_tool(CourseSearchTool(self.vector_store))
self.tool_manager.register_tool(NewCustomTool(self.vector_store)) # 添加新工具

系统支持DeepSeek和Anthropic Claude,添加新提供商:
# backend/ai_generator.py
class NewAIProviderGenerator(AIGenerator):
def __init__(self, config: Config):
super().__init__(config)
self.client = initialize_new_provider_client(config)
def _call_ai_api(self, messages, tools=None):
"""Implement the provider-specific API call."""
try:
response = self.client.chat.completions.create(
model=self.config.new_provider_model,
messages=messages,
tools=self._convert_tools_to_provider_format(tools) if tools else None,
# other provider-specific parameters
)
return self._parse_response(response)
except Exception as e:
raise Exception(f"New AI Provider API 调用失败:{str(e)}")

# backend/tests/test_new_feature.py
import pytest
from unittest.mock import Mock, patch
from rag_system import RAGSystem
from config import Config
class TestNewFeature:
@pytest.fixture
def mock_rag_system(self):
config = Config()
config.ai_provider = "deepseek" # use a mocked provider in the test environment
return RAGSystem(config)
def test_simple_tool_call(self, mock_rag_system):
"""Test a simple single tool call."""
query = "简单测试查询"
with patch.object(mock_rag_system.ai_generator, 'generate_response') as mock_generate:
mock_generate.return_value = ("测试回答", [{"text": "测试来源"}])
response, sources = mock_rag_system.query(query)
assert response == "测试回答"
assert len(sources) == 1
mock_generate.assert_called_once()
def test_sequential_tool_calls(self, mock_rag_system):
"""Test sequential (multi-round) tool calls."""
complex_query = "复杂的多部分查询"
# A complex query is expected to trigger multiple tool calls
with patch.object(mock_rag_system.ai_generator, 'generate_response') as mock_generate:
mock_generate.return_value = ("综合回答", [
{"text": "来源1"},
{"text": "来源2"}
])
response, sources = mock_rag_system.query(complex_query)
assert len(sources) >= 2 # multiple tool calls should yield multiple sources
assert "综合" in response # 多工具调用通常产生更综合的回答

# backend/tests/test_integration.py
class TestFullSystemIntegration:
def test_end_to_end_query(self):
"""Full end-to-end query test."""
config = Config()
rag = RAGSystem(config)
# Real query (requires a valid API key)
query = "What is the MCP course about?"
response, sources = rag.query(query)
assert len(response) > 50 # the answer should have enough content
assert len(sources) >= 1 # at least one source
assert any("MCP" in source.get('text', '') for source in sources)
def test_timeout_protection(self):
"""Test the timeout-protection mechanism."""
config = Config()
rag = RAGSystem(config)
start_time = time.time()
# Run a potentially long multi-round query
response, sources = rag.query("复杂的多轮分析查询")
end_time = time.time()
duration = end_time - start_time
# Should finish within a reasonable time (under 15 seconds)
assert duration < 15.0
assert response # 即使超时也应该有回答

# backend/tests/test_performance.py
import asyncio
import concurrent.futures
def test_concurrent_queries():
"""Test concurrent query handling."""
config = Config()
rag = RAGSystem(config)
queries = [
"What is MCP?",
"Advanced Retrieval course outline?",
"Compare different courses",
"Find information about neural networks"
]
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(rag.query, q) for q in queries]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
# Verify that every query completed successfully
assert len(results) == len(queries)
assert all(len(r[0]) > 0 for r in results) # 所有回答都非空

# backend/rag_system.py
import logging
import time
# Configure detailed logging for the backend
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
class RAGSystem:
def __init__(self, config):
self.logger = logging.getLogger(__name__)
# ... initialization code
def query(self, query_text: str, session_id: str = None):
"""Answer a query, logging timing and result sizes."""
# NOTE(review): `history` below is not defined in this snippet —
# presumably the per-session conversation history; confirm in the real module.
self.logger.info(f"收到查询:{query_text[:100]}...")
start_time = time.time()
try:
response, sources = self.ai_generator.generate_response(
query_text, history, self.tool_manager.get_tools()
)
end_time = time.time()
self.logger.info(f"查询完成,用时:{end_time - start_time:.2f}秒")
self.logger.debug(f"回答长度:{len(response)},来源数量:{len(sources)}")
return response, sources
except Exception as e:
self.logger.error(f"查询失败:{str(e)}", exc_info=True)
raise

症状:期望多轮调用但只看到单次调用
诊断步骤:
# Add debug instrumentation to ai_generator.py
def _handle_sequential_tool_execution(self, messages, tools):
# Debug variant: print a trace line at each step of the loop.
print(f"🔍 开始顺序工具调用,工具数量:{len(tools)}")
session = ToolCallSession()
while session.current_round < session.max_rounds:
print(f"🔄 第 {session.current_round + 1} 轮工具调用")
response = self._call_ai_api(messages, tools)
print(f"🤖 AI响应:工具调用={bool(response.tool_calls)}")
if response.tool_calls:
print(f"🛠️ 执行工具:{response.tool_calls[0].function.name}")
# ... 继续处理

可能原因及解决:
- AI提供商不支持工具调用:检查API配置
- 工具描述不清晰:优化工具描述使AI更容易选择
- 查询不够复杂:尝试更复杂的多部分查询
症状:查询等待时间超过预期
诊断:
# 添加详细的性能监控
import time
class PerformanceMonitor:
    """Lightweight named-timer helper for ad-hoc performance diagnosis."""

    def __init__(self):
        # operation name -> start timestamp
        self.times = {}

    def start(self, operation: str):
        """Record the start time of *operation*."""
        self.times[operation] = time.time()

    def end(self, operation: str):
        """Print and return elapsed seconds for *operation*; 0 if never started."""
        if operation not in self.times:
            return 0
        duration = time.time() - self.times[operation]
        print(f"⏱️ {operation}: {duration:.2f}秒")
        return duration
# 在关键点使用
monitor = PerformanceMonitor()
monitor.start("AI_API_CALL")
response = self._call_ai_api(messages, tools)
monitor.end("AI_API_CALL")

症状:工具调用出错或返回空结果
调试方法:
def execute_tool_call(self, tool_call):
"""Debug helper: execute one tool call with verbose tracing of arguments and results."""
try:
print(f"🛠️ 执行工具:{tool_call.function.name}")
print(f"📝 参数:{tool_call.function.arguments}")
tool = self.get_tool(tool_call.function.name)
result = tool.execute_with_args(tool_call.function.arguments)
print(f"✅ 工具执行结果长度:{len(result)}")
print(f"📄 结果预览:{result[:200]}...")
return result
except Exception as e:
print(f"❌ 工具执行失败:{str(e)}")
import traceback
traceback.print_exc()
return f"工具执行出错:{str(e)}"

# 优化ChromaDB查询性能
def optimize_vector_search(self):
"""Tune vector-search settings."""
# Adjust collection parameters for better performance
self.course_content = self.client.get_or_create_collection(
name="course_content",
metadata={
"hnsw:space": "cosine",
"hnsw:construction_ef": 200, # higher construction-time ef
"hnsw:M": 16, # more graph connections
"hnsw:search_ef": 100 # query-time ef
}
)

from functools import lru_cache
class VectorStore:
@lru_cache(maxsize=100)
def search_content_cached(self, query: str, course_name: str = None):
"""缓存常见搜索结果"""
return self.search_content(query, course_name)
def clear_cache(self):
"""清除缓存(在数据更新后调用)"""
self.search_content_cached.cache_clear()- 创建新的Generator类:
# backend/ai_generator.py
class NewProviderGenerator(AIGenerator):
"""Generator subclass wired to the new provider's API client."""
def __init__(self, config: Config):
super().__init__(config)
self.client = new_provider_client.Client(api_key=config.new_provider_api_key)

- 更新Config类:
# backend/config.py
class Config:
def __init__(self):
# Load provider-specific settings only for the selected provider
if self.ai_provider == "new_provider":
self.new_provider_api_key = os.getenv("NEW_PROVIDER_API_KEY")
self.new_provider_model = os.getenv("NEW_PROVIDER_MODEL", "default-model")

- 更新Factory:
# backend/rag_system.py
def _create_ai_generator(self):
# Factory: choose the generator class for the configured provider
if self.config.ai_provider == "new_provider":
return NewProviderGenerator(self.config)

- 扩展DocumentProcessor:
# backend/document_processor.py
class ExtendedDocumentProcessor(DocumentProcessor):
def process_new_format(self, file_path: str):
"""Dispatch to a format-specific handler based on the file extension."""
if file_path.endswith('.pdf'):
return self._process_pdf(file_path)
elif file_path.endswith('.docx'):
return self._process_docx(file_path)
# ... 其他格式

- 更新文档加载逻辑:
def load_documents(self):
"""Load every document in docs/ whose extension is supported."""
supported_extensions = ['.txt', '.pdf', '.docx', '.md']
for ext in supported_extensions:
files = glob.glob(f"docs/*{ext}")
for file_path in files:
processed = self.document_processor.process_file(file_path)
self.vector_store.add_documents(processed)

# backend/rag_system.py
import watchdog.observers
import watchdog.events
class DocumentUpdateHandler(watchdog.events.FileSystemEventHandler):
"""Watchdog handler that reloads documents when they change on disk."""
def __init__(self, rag_system):
self.rag_system = rag_system
def on_modified(self, event):
# Only .txt documents trigger a reload in this snippet
if event.src_path.endswith('.txt'):
print(f"📄 检测到文档更新:{event.src_path}")
self.rag_system.reload_document(event.src_path)
class RAGSystem:
def start_file_watcher(self):
"""Start watching docs/ for document changes."""
event_handler = DocumentUpdateHandler(self)
observer = watchdog.observers.Observer()
observer.schedule(event_handler, "docs/", recursive=True)
observer.start()
def reload_document(self, file_path: str):
"""Reload a single document."""
# Remove the old version from the vector store
# Process the new version and add it
pass

# backend/metrics.py
import psutil
import time
from dataclasses import dataclass
from typing import Dict, List
@dataclass
class QueryMetrics:
    """Per-query measurement record."""

    query_time: float        # seconds spent answering the query
    tool_calls_count: int    # number of tool rounds executed
    response_length: int     # characters in the final answer
    sources_count: int       # number of sources returned
    ai_provider: str         # provider used for this query
    success: bool            # whether the query completed without error


class MetricsCollector:
    """Accumulates QueryMetrics and derives aggregate performance stats."""

    def __init__(self):
        self.query_history: List[QueryMetrics] = []

    def record_query(self, metrics: QueryMetrics):
        """Append one query's metrics to the history."""
        self.query_history.append(metrics)

    def get_performance_stats(self) -> Dict:
        """Return aggregate stats over the most recent 100 queries.

        Fix: the original divided by ``len(recent_queries)`` unconditionally
        and raised ZeroDivisionError when no queries had been recorded yet;
        averages now default to 0.0 on an empty history.
        """
        recent_queries = self.query_history[-100:]  # last 100 queries
        count = len(recent_queries)
        return {
            "average_response_time": sum(q.query_time for q in recent_queries) / count if count else 0.0,
            "success_rate": sum(1 for q in recent_queries if q.success) / count if count else 0.0,
            "average_tool_calls": sum(q.tool_calls_count for q in recent_queries) / count if count else 0.0,
            "memory_usage": psutil.Process().memory_info().rss / 1024 / 1024,  # MB
            "cpu_usage": psutil.cpu_percent()
        }

# backend/app.py
@app.get("/health")
async def health_check():
"""System health-check endpoint."""
try:
# Probe the database connection
test_query = rag_system.vector_store.course_catalog.peek(limit=1)
# Probe the AI provider connection
test_response = rag_system.ai_generator._call_ai_api([
{"role": "user", "content": "test"}
])
metrics = metrics_collector.get_performance_stats()
return {
"status": "healthy",
"timestamp": time.time(),
"database": "connected",
"ai_provider": "connected",
"performance": metrics
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
"timestamp": time.time()
}

这个开发者指南涵盖了系统的核心架构、开发工作流、测试策略、调试方法和扩展指南。对于想要深入了解或扩展这个顺序工具调用RAG系统的开发者来说,这是一个全面的技术参考。