Skip to content

Latest commit

 

History

History
695 lines (536 loc) · 20.4 KB

File metadata and controls

695 lines (536 loc) · 20.4 KB

开发者指南

Course Materials RAG System 的详细技术指南,专注于顺序工具调用架构的开发和扩展。

🏗️ 架构深度解析

顺序工具调用系统 (Sequential Tool Calling)

这是系统最核心的特性,允许AI智能地进行多轮工具调用来处理复杂查询。

ToolCallSession 状态管理

from dataclasses import dataclass, field
import time


@dataclass
class ToolCallSession:
    """Mutable bookkeeping for one bounded sequence of tool-call rounds.

    Constructor arguments:
        max_rounds: Upper bound on the number of tool-call rounds.
        timeout_per_round: Time budget, in seconds, granted to each round.

    Attributes initialised automatically (not constructor arguments):
        current_round: Number of rounds completed so far; starts at 0.
        start_time: ``time.time()`` reading taken at construction.
        tool_results: Results collected from executed tool calls.
        total_timeout: ``max_rounds * timeout_per_round``.
    """

    max_rounds: int = 2
    timeout_per_round: float = 10.0
    current_round: int = field(default=0, init=False)
    start_time: float = field(default_factory=time.time, init=False)
    # default_factory gives every session its own list instead of a shared one.
    tool_results: list = field(default_factory=list, init=False)
    total_timeout: float = field(default=0.0, init=False)

    def __post_init__(self):
        # Derived here because it depends on the two constructor arguments.
        self.total_timeout = self.max_rounds * self.timeout_per_round

核心设计原则

  • 有界处理:默认最多2轮调用,每轮10秒超时(总时限 = 轮数 × 每轮超时,默认20秒)
  • 状态追踪:记录每轮结果,支持上下文传递
  • 优雅降级:超时或失败时返回已有结果

AI Generator 的顺序调用逻辑

def _handle_sequential_tool_execution(self, messages, tools):
    """Run bounded rounds of tool calls, then return the final answer.

    Each round asks the model for a response. If it requests tools, every
    requested call is executed and the results are appended to ``messages``
    so the next round can see them; otherwise the model's content is
    returned immediately. When the round budget or the time budget is
    exhausted, one last request is made with an empty tool list to produce
    the final response.

    Args:
        messages: Chat history; extended in place with assistant and tool
            turns.
        tools: Tool definitions forwarded to ``self._call_ai_api``.

    Returns:
        Tuple ``(content, tool_results)`` where ``tool_results`` holds the
        result of every tool call executed during the session.
    """
    session = ToolCallSession()

    while session.current_round < session.max_rounds:
        # Do not start another round once the overall time budget is spent.
        if time.time() - session.start_time > session.total_timeout:
            break

        # The model decides whether more tool calls are needed.
        response = self._call_ai_api(messages, tools)

        if not response.tool_calls:
            return response.content, session.tool_results

        # Execute every requested call, not just the first one: the assistant
        # turn below advertises all of them, so each needs a matching result.
        tool_messages = []
        for tool_call in response.tool_calls:
            tool_result = self._execute_tool_call(tool_call)
            session.tool_results.append(tool_result)

            tool_message = {"role": "tool", "content": tool_result}
            # OpenAI-style chat APIs pair a tool result with its request via
            # tool_call_id; only attach it when the call object exposes an id.
            call_id = getattr(tool_call, "id", None)
            if call_id is not None:
                tool_message["tool_call_id"] = call_id
            tool_messages.append(tool_message)

        # Append only after all tools ran, so a failing tool leaves the
        # history untouched.
        messages.append({"role": "assistant", "content": None, "tool_calls": response.tool_calls})
        messages.extend(tool_messages)
        session.current_round += 1

    # Final synthesis: no tools offered, so the model must answer in text.
    final_response = self._call_ai_api(messages, tools=[])
    return final_response.content, session.tool_results

关键设计决策

  • AI主导:让AI决定是否需要更多工具调用
  • 上下文保持:每轮结果都添加到消息历史中
  • 时间界限:严格的超时控制防止无限循环

工具架构设计

双工具系统

系统当前支持两个互补的工具:

# backend/search_tools.py

class CourseOutlineTool(Tool):
    """Tool for querying course structure and outlines."""
    def __init__(self, vector_store):
        super().__init__()
        self.name = "course_outline"
        self.description = "获取完整的课程大纲和课程结构信息"
        self.store = vector_store
        
    def execute(self, course_name: str) -> str:
        # Map the user-supplied name to a stored course title first.
        course_title = self.store._resolve_course_name(course_name)
        # Return the formatted course structure (formatting elided in this excerpt)
        
class CourseSearchTool(Tool):
    """Tool for semantic search over course content."""  
    def __init__(self, vector_store):
        super().__init__()
        self.name = "search_course_content" 
        self.description = "在课程内容中进行语义搜索"
        self.store = vector_store
        
    def execute(self, query: str, course_name: str = None) -> str:
        # Run the semantic similarity search (implementation elided in this excerpt)

工具协调策略

  • 优先级:CourseOutlineTool 注册在前,适合结构化查询
  • 互补性:大纲工具提供结构,搜索工具提供内容
  • 智能选择:AI根据查询意图自动选择合适工具

向量存储架构 (ChromaDB)

双集合设计

# backend/vector_store.py

class VectorStore:
    """ChromaDB-backed store split into a catalog and a content collection."""

    def __init__(self):
        """Open the persistent client and bind both cosine-space collections."""
        self.client = chromadb.PersistentClient(path="./chroma_db")

        def open_collection(collection_name):
            # Every call builds its own metadata dict, as the original did.
            return self.client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
            )

        # Catalog collection: course-level metadata.
        self.course_catalog = open_collection("course_catalog")
        # Content collection: document chunks.
        self.course_content = open_collection("course_content")

数据分层

  • 目录层:课程标题、链接、讲师信息
  • 内容层:按课程分块的实际内容(800字符块,100重叠)
  • 元数据:课程名、课程编号、块ID用于精确检索

智能课程名解析

def _resolve_course_name(self, user_input: str) -> str:
    """Resolve free-form user input to the closest catalogued course title.

    Queries ``self.course_catalog`` for the single nearest entry and returns
    the ``course_title`` value from that entry's metadata.

    Args:
        user_input: Course name as supplied by the user.

    Returns:
        The matched course title, or ``user_input`` unchanged when the input
        is blank, the catalog returns no hit, or the hit has no usable
        ``course_title``.
    """
    # A blank query cannot identify a course; skip the lookup entirely.
    if not user_input or not user_input.strip():
        return user_input

    catalog_results = self.course_catalog.query(
        query_texts=[user_input],
        n_results=1,
    )

    # Tolerate missing or empty result lists instead of raising.
    ids = catalog_results.get("ids") or [[]]
    metadatas = catalog_results.get("metadatas") or [[]]
    if not ids[0] or not metadatas[0]:
        return user_input  # fall back to the raw input

    # A hit may carry no metadata at all (None), so normalise before reading.
    metadata = metadatas[0][0] or {}
    return metadata.get("course_title") or user_input

🔧 开发工作流

1. 环境设置

# 克隆项目
git clone <repository-url>
cd course-materials-rag

# 设置Python环境(推荐使用uv)
uv sync

# 配置环境变量
cp .env.example .env
# 编辑.env添加API密钥

# 验证安装
cd backend
python -m uvicorn app:app --reload --port 8000

2. 添加新工具

创建新工具时遵循以下模式:

# backend/search_tools.py

class NewCustomTool(Tool):
    """Template for a new tool; replace this with the tool's own description."""
    
    def __init__(self, vector_store):
        super().__init__()
        self.name = "new_tool_name"  # unique identifier
        self.description = "AI能理解的工具描述,说明何时使用"
        self.store = vector_store
        
    def execute(self, param1: str, param2: str = None) -> str:
        """
        Run the tool logic.
        
        Args:
            param1: Description of the required parameter.
            param2: Description of the optional parameter.
            
        Returns:
            A formatted string result that the AI uses to generate its answer.
            Any exception is converted into an error-message string instead
            of being raised.
        """
        try:
            # Tool execution logic
            result = self._process_logic(param1, param2)
            return self._format_result(result)
        except Exception as e:
            return f"工具执行失败:{str(e)}"
            
    def _process_logic(self, param1, param2):
        """Private method: core processing logic (placeholder)."""
        pass
        
    def _format_result(self, result):
        """Private method: result formatting (placeholder)."""
        pass

注册新工具

# backend/rag_system.py

def __init__(self, config: Config):
    # ... 现有初始化代码
    
    # 注册工具(顺序很重要!)
    self.tool_manager.register_tool(CourseOutlineTool(self.vector_store))
    self.tool_manager.register_tool(CourseSearchTool(self.vector_store)) 
    self.tool_manager.register_tool(NewCustomTool(self.vector_store))  # 添加新工具

3. 多AI提供商支持

系统支持DeepSeek和Anthropic Claude,添加新提供商:

# backend/ai_generator.py

class NewAIProviderGenerator(AIGenerator):
    """Generator backed by a new provider's chat-completions style client."""

    def __init__(self, config: Config):
        super().__init__(config)
        self.client = initialize_new_provider_client(config)

    def _call_ai_api(self, messages, tools=None):
        """Send one chat request to the provider and return the parsed reply.

        Args:
            messages: Chat messages to send.
            tools: Optional tool definitions; converted with
                ``self._convert_tools_to_provider_format`` when non-empty,
                otherwise ``None`` is sent.

        Returns:
            The value produced by ``self._parse_response`` for the response.

        Raises:
            RuntimeError: If tool conversion, the request, or response
                parsing fails. The original exception is attached as
                ``__cause__`` so the root traceback is preserved.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.config.new_provider_model,
                messages=messages,
                tools=self._convert_tools_to_provider_format(tools) if tools else None,
                # other provider-specific parameters
            )
            return self._parse_response(response)
        except Exception as e:
            # RuntimeError is still caught by ``except Exception`` callers;
            # ``from e`` makes the causal chain explicit.
            raise RuntimeError(f"New AI Provider API 调用失败:{str(e)}") from e

🧪 测试策略

单元测试

# backend/tests/test_new_feature.py

import pytest
from unittest.mock import Mock, patch
from rag_system import RAGSystem
from config import Config

class TestNewFeature:
    @pytest.fixture
    def mock_rag_system(self):
        config = Config()
        config.ai_provider = "deepseek"  # 测试环境使用mock
        return RAGSystem(config)
    
    def test_simple_tool_call(self, mock_rag_system):
        """测试简单的单工具调用"""
        query = "简单测试查询"
        
        with patch.object(mock_rag_system.ai_generator, 'generate_response') as mock_generate:
            mock_generate.return_value = ("测试回答", [{"text": "测试来源"}])
            
            response, sources = mock_rag_system.query(query)
            
            assert response == "测试回答"
            assert len(sources) == 1
            mock_generate.assert_called_once()
    
    def test_sequential_tool_calls(self, mock_rag_system):
        """测试顺序工具调用"""
        complex_query = "复杂的多部分查询"
        
        # 模拟复杂查询应该触发多次工具调用
        with patch.object(mock_rag_system.ai_generator, 'generate_response') as mock_generate:
            mock_generate.return_value = ("综合回答", [
                {"text": "来源1"}, 
                {"text": "来源2"}
            ])
            
            response, sources = mock_rag_system.query(complex_query)
            
            assert len(sources) >= 2  # 多工具调用应该产生多个来源
            assert "综合" in response  # 多工具调用通常产生更综合的回答

集成测试

# backend/tests/test_integration.py

class TestFullSystemIntegration:
    def test_end_to_end_query(self):
        """完整的端到端查询测试"""
        config = Config()
        rag = RAGSystem(config)
        
        # 测试真实查询(需要有效的API密钥)
        query = "What is the MCP course about?"
        response, sources = rag.query(query)
        
        assert len(response) > 50  # 回答应该有足够内容
        assert len(sources) >= 1   # 至少一个来源
        assert any("MCP" in source.get('text', '') for source in sources)
        
    def test_timeout_protection(self):
        """测试超时保护机制"""
        config = Config()
        rag = RAGSystem(config)
        
        start_time = time.time()
        
        # 执行可能耗时的查询
        response, sources = rag.query("复杂的多轮分析查询")
        
        end_time = time.time()
        duration = end_time - start_time
        
        # 应该在合理时间内完成(15秒以内)
        assert duration < 15.0
        assert response  # 即使超时也应该有回答

负载测试

# backend/tests/test_performance.py

import asyncio
import concurrent.futures

def test_concurrent_queries():
    """测试并发查询处理"""
    config = Config()
    rag = RAGSystem(config)
    
    queries = [
        "What is MCP?",
        "Advanced Retrieval course outline?", 
        "Compare different courses",
        "Find information about neural networks"
    ]
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(rag.query, q) for q in queries]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]
    
    # 验证所有查询都成功完成
    assert len(results) == len(queries)
    assert all(len(r[0]) > 0 for r in results)  # 所有回答都非空

🔍 调试和故障排除

日志记录策略

# backend/rag_system.py

import logging
import time

# 设置详细的日志记录
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

class RAGSystem:
    """Entry point that answers queries through the AI generator and tools."""

    def __init__(self, config):
        self.logger = logging.getLogger(__name__)
        # ... remaining initialisation elided in this excerpt

    def query(self, query_text: str, session_id: str = None):
        """Answer ``query_text`` and log timing and result sizes.

        Args:
            query_text: The user's question.
            session_id: Optional conversation identifier. Not consumed yet;
                see the TODO in the body.

        Returns:
            The ``(response, sources)`` pair returned by
            ``self.ai_generator.generate_response``.

        Raises:
            Exception: Any failure is logged with its traceback and re-raised
                unchanged.
        """
        self.logger.info(f"收到查询:{query_text[:100]}...")
        start_time = time.time()

        # ``history`` was referenced without ever being assigned, so every
        # call failed with NameError. Start without history until the
        # session lookup is wired in.
        # TODO: load the conversation history that belongs to session_id.
        history = None

        try:
            response, sources = self.ai_generator.generate_response(
                query_text, history, self.tool_manager.get_tools()
            )

            end_time = time.time()
            self.logger.info(f"查询完成,用时:{end_time - start_time:.2f}秒")
            self.logger.debug(f"回答长度:{len(response)},来源数量:{len(sources)}")

            return response, sources

        except Exception as e:
            self.logger.error(f"查询失败:{str(e)}", exc_info=True)
            raise

常见问题诊断

1. 顺序工具调用不生效

症状:期望多轮调用但只看到单次调用

诊断步骤

# 添加调试代码到 ai_generator.py
def _handle_sequential_tool_execution(self, messages, tools):
    print(f"🔍 开始顺序工具调用,工具数量:{len(tools)}")
    session = ToolCallSession()
    
    while session.current_round < session.max_rounds:
        print(f"🔄 第 {session.current_round + 1} 轮工具调用")
        
        response = self._call_ai_api(messages, tools)
        print(f"🤖 AI响应:工具调用={bool(response.tool_calls)}")
        
        if response.tool_calls:
            print(f"🛠️ 执行工具:{response.tool_calls[0].function.name}")
            # ... 继续处理

可能原因及解决

  • AI提供商不支持工具调用:检查API配置
  • 工具描述不清晰:优化工具描述使AI更容易选择
  • 查询不够复杂:尝试更复杂的多部分查询

2. 响应时间过长

症状:查询等待时间超过预期

诊断

# 添加详细的性能监控
import time

class PerformanceMonitor:
    """Stopwatch that times named operations and prints their duration."""

    def __init__(self):
        # operation name -> perf_counter reading captured by start()
        self.times = {}

    def start(self, operation: str):
        """Record the start instant for ``operation``, restarting it if present."""
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or inflated durations the way time.time() can.
        self.times[operation] = time.perf_counter()

    def end(self, operation: str):
        """Print and return the seconds elapsed since ``start(operation)``.

        Returns:
            The elapsed time in seconds, or 0 (without printing) when
            ``operation`` was never started.
        """
        started_at = self.times.get(operation)
        if started_at is None:
            return 0
        duration = time.perf_counter() - started_at
        print(f"⏱️ {operation}: {duration:.2f}秒")
        return duration

# 在关键点使用
monitor = PerformanceMonitor()
monitor.start("AI_API_CALL")
response = self._call_ai_api(messages, tools)
monitor.end("AI_API_CALL")

3. 工具执行失败

症状:工具调用出错或返回空结果

调试方法

def execute_tool_call(self, tool_call):
    """Run one tool call with verbose tracing and return its result.

    Any exception is printed together with its traceback and reported back
    as an error string instead of being raised.
    """
    import traceback

    try:
        requested = tool_call.function
        print(f"🛠️ 执行工具:{requested.name}")
        print(f"📝 参数:{requested.arguments}")

        output = self.get_tool(requested.name).execute_with_args(requested.arguments)

        print(f"✅ 工具执行结果长度:{len(output)}")
        print(f"📄 结果预览:{output[:200]}...")
    except Exception as e:
        print(f"❌ 工具执行失败:{str(e)}")
        traceback.print_exc()
        return f"工具执行出错:{str(e)}"
    return output

性能优化建议

1. 向量搜索优化

# 优化ChromaDB查询性能
def optimize_vector_search(self):
    """Re-bind ``self.course_content`` using tuned HNSW collection metadata."""
    hnsw_settings = {
        "hnsw:space": "cosine",
        "hnsw:construction_ef": 200,  # higher ef while building the index
        "hnsw:M": 16,  # more connections
        "hnsw:search_ef": 100,  # ef used when searching
    }
    self.course_content = self.client.get_or_create_collection(
        name="course_content",
        metadata=hnsw_settings,
    )

2. 缓存策略

from functools import lru_cache

class VectorStore:
    """Vector store excerpt: memoised content search with a per-instance cache."""

    def _search_cache(self):
        """Return this instance's bounded cache wrapper, creating it on first use.

        Decorating the method itself with ``lru_cache`` would key entries on
        ``self`` in one class-wide cache: instances would stay alive for the
        cache's lifetime, and clearing it would wipe every instance's entries.
        """
        cached = getattr(self, "_cached_search", None)
        if cached is None:
            cached = lru_cache(maxsize=100)(self.search_content)
            self._cached_search = cached
        return cached

    def search_content_cached(self, query: str, course_name: str = None):
        """Return ``self.search_content(query, course_name)``, cached per instance.

        At most 100 distinct ``(query, course_name)`` pairs are retained.
        """
        return self._search_cache()(query, course_name)

    def clear_cache(self):
        """Drop this instance's cached results (call after the data changes)."""
        self._search_cache().cache_clear()

🚀 系统扩展指南

添加新的AI提供商

  1. 创建新的Generator类
# backend/ai_generator.py
class NewProviderGenerator(AIGenerator):
    def __init__(self, config: Config):
        super().__init__(config)
        self.client = new_provider_client.Client(api_key=config.new_provider_api_key)
  2. 更新Config类
# backend/config.py  
class Config:
    def __init__(self):
        if self.ai_provider == "new_provider":
            self.new_provider_api_key = os.getenv("NEW_PROVIDER_API_KEY")
            self.new_provider_model = os.getenv("NEW_PROVIDER_MODEL", "default-model")
  3. 更新Factory
# backend/rag_system.py
def _create_ai_generator(self):
    if self.config.ai_provider == "new_provider":
        return NewProviderGenerator(self.config)

添加新的数据源

  1. 扩展DocumentProcessor
# backend/document_processor.py
class ExtendedDocumentProcessor(DocumentProcessor):
    def process_new_format(self, file_path: str):
        """处理新的文档格式"""
        if file_path.endswith('.pdf'):
            return self._process_pdf(file_path)
        elif file_path.endswith('.docx'):
            return self._process_docx(file_path)
        # ... 其他格式
  2. 更新文档加载逻辑
def load_documents(self):
    """加载所有支持格式的文档"""
    supported_extensions = ['.txt', '.pdf', '.docx', '.md']
    
    for ext in supported_extensions:
        files = glob.glob(f"docs/*{ext}")
        for file_path in files:
            processed = self.document_processor.process_file(file_path)
            self.vector_store.add_documents(processed)

实时数据更新

# backend/rag_system.py
import watchdog.observers
import watchdog.events

class DocumentUpdateHandler(watchdog.events.FileSystemEventHandler):
    def __init__(self, rag_system):
        self.rag_system = rag_system
        
    def on_modified(self, event):
        if event.src_path.endswith('.txt'):
            print(f"📄 检测到文档更新:{event.src_path}")
            self.rag_system.reload_document(event.src_path)

class RAGSystem:
    """RAG system excerpt: live reloading of documents under ``docs/``."""

    def start_file_watcher(self):
        """Start watching ``docs/`` recursively for document changes.

        The observer is kept on ``self._observer``; without a reference it
        could never be stopped or joined once started. Calling this again
        while a watcher is active returns the existing observer instead of
        starting a second one.

        Returns:
            The running observer.
        """
        active = getattr(self, "_observer", None)
        if active is not None:
            return active

        event_handler = DocumentUpdateHandler(self)
        observer = watchdog.observers.Observer()
        observer.schedule(event_handler, "docs/", recursive=True)
        observer.start()
        self._observer = observer
        return observer

    def stop_file_watcher(self):
        """Stop the watcher started by ``start_file_watcher``, if any."""
        observer = getattr(self, "_observer", None)
        if observer is None:
            return
        observer.stop()
        observer.join()  # wait for the observer thread to finish
        self._observer = None

    def reload_document(self, file_path: str):
        """Reload one document (placeholder).

        Intended steps, not implemented in this excerpt:
        remove the old version from the vector store, then process and add
        the new version.
        """
        pass

📊 监控和指标

系统监控

# backend/metrics.py
import psutil
import time
from dataclasses import dataclass
from typing import Dict, List

@dataclass  
class QueryMetrics:
    """Measurements recorded for a single query."""
    query_time: float  # query duration; presumably seconds — TODO confirm unit
    tool_calls_count: int  # number of tool calls made for the query
    response_length: int  # length of the generated response
    sources_count: int  # number of sources returned with the response
    ai_provider: str  # name of the AI provider that served the query
    success: bool  # whether the query completed successfully

class MetricsCollector:
    """Accumulates per-query metrics and summarises recent performance."""

    def __init__(self):
        self.query_history: List[QueryMetrics] = []

    def record_query(self, metrics: QueryMetrics):
        """Append one query's metrics to the history."""
        self.query_history.append(metrics)

    def get_performance_stats(self) -> Dict:
        """Summarise the 100 most recent queries plus current process load.

        Returns:
            Dict with ``average_response_time``, ``success_rate``,
            ``average_tool_calls``, ``memory_usage`` (resident set size in
            MB) and ``cpu_usage``. The three query-derived values are 0.0
            when no query has been recorded yet.
        """
        recent_queries = self.query_history[-100:]  # last 100 queries
        # Divide by at least 1: with an empty history every sum is 0, so the
        # averages come out as 0.0 instead of raising ZeroDivisionError.
        count = len(recent_queries) or 1

        return {
            "average_response_time": sum(q.query_time for q in recent_queries) / count,
            "success_rate": sum(1 for q in recent_queries if q.success) / count,
            "average_tool_calls": sum(q.tool_calls_count for q in recent_queries) / count,
            "memory_usage": psutil.Process().memory_info().rss / 1024 / 1024,  # MB
            "cpu_usage": psutil.cpu_percent(),
        }

健康检查端点

# backend/app.py
@app.get("/health")
async def health_check():
    """Report system health.

    Probes the vector store and the AI provider, then returns a status
    payload. Any exception raised along the way is turned into an
    "unhealthy" payload carrying the error text rather than propagating.
    """
    try:
        # Probe the database connection; the result is unused, the call only
        # has to succeed.
        test_query = rag_system.vector_store.course_catalog.peek(limit=1)
        
        # Probe the AI provider connection; result likewise unused.
        # NOTE(review): this sends a real request on every health check —
        # confirm the cost and latency are acceptable for a polled endpoint.
        test_response = rag_system.ai_generator._call_ai_api([
            {"role": "user", "content": "test"}
        ])
        
        metrics = metrics_collector.get_performance_stats()
        
        return {
            "status": "healthy",
            "timestamp": time.time(),
            "database": "connected",
            "ai_provider": "connected", 
            "performance": metrics
        }
        
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "timestamp": time.time()
        }

这个开发者指南涵盖了系统的核心架构、开发工作流、测试策略、调试方法和扩展指南。对于想要深入了解或扩展这个顺序工具调用RAG系统的开发者来说,这是一个全面的技术参考。