77 changes: 61 additions & 16 deletions app/agent/base.py
@@ -4,13 +4,16 @@

from pydantic import BaseModel, Field, model_validator

import torch
from sentence_transformers import SentenceTransformer, util

from app.llm import LLM
from app.logger import logger
from app.sandbox.client import SANDBOX_CLIENT
from app.schema import ROLE_TYPE, AgentState, Memory, Message


class BaseAgent(BaseModel, ABC):
class BaseAgent(BaseModel, ABC): # ABC enforces the subclass contract: abstract methods such as step() must be implemented, or the subclass cannot be instantiated
"""Abstract base class for managing agent state and execution.

Provides foundational functionality for state transitions, memory management,
@@ -30,23 +33,30 @@ class BaseAgent(BaseModel, ABC):
)

# Dependencies
llm: LLM = Field(default_factory=LLM, description="Language model instance")
memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
state: AgentState = Field(
default=AgentState.IDLE, description="Current agent state"
llm: LLM = Field(default_factory=LLM, description="Language model instance") # note: default_factory and default are not interchangeable
memory: Memory = Field(default_factory=Memory, description="Agent's memory store") # a plain default is evaluated once, at class-definition
state: AgentState = Field( # time, whereas default_factory is called on every BaseAgent instantiation, giving each instance a fresh Memory
default=AgentState.IDLE, description="Current agent state" # plain defaults are risky for mutable objects (dict, list, custom classes, ...)
)

# Execution control
max_steps: int = Field(default=10, description="Maximum steps before termination")
current_step: int = Field(default=0, description="Current step in execution")

duplicate_threshold: int = 2
duplicate_count: int = 1
embedding_model: SentenceTransformer = Field(
default_factory=lambda: SentenceTransformer("all-MiniLM-L6-v2")
)
# cache the embedding of the previous message per role, so it is not recomputed:
last_msg_emb: dict = Field(default_factory=dict) # default_factory, per the note on mutable defaults above

class Config:
arbitrary_types_allowed = True
extra = "allow" # Allow extra fields for flexibility in subclasses
class Config: # Pydantic's inner configuration class, not to be confused with this project's own Config class in app/config.py
arbitrary_types_allowed = True # beyond standard types (int, str, ...) and registered Pydantic models, custom classes such as LLM above may be used as field types
extra = "allow" # Allow extra fields for flexibility in subclasses (with "forbid", passing a field the model
# does not define would raise an error; with "allow", such fields are kept in .model_extra)

@model_validator(mode="after")
@model_validator(mode="after") # 在模型初始化完成后执行额外逻辑。
def initialize_agent(self) -> "BaseAgent":
"""Initialize agent with default settings if not provided."""
if self.llm is None or not isinstance(self.llm, LLM):
@@ -55,7 +65,7 @@ def initialize_agent(self) -> "BaseAgent":
self.memory = Memory()
return self

@asynccontextmanager
@asynccontextmanager # decorator that defines an async context manager, used via async with
async def state_context(self, new_state: AgentState):
"""Context manager for safe agent state transitions.

@@ -74,11 +84,13 @@ async def state_context(self, new_state: AgentState):
previous_state = self.state
self.state = new_state
try:
yield
except Exception as e:
yield # the heart of the context manager:
# control passes to the body of the async with block (i.e. the async with self.state_context(AgentState.RUNNING) below);
# while that body runs, the agent's state is `new_state`
except Exception as e: # if the async with body raises, record the error state and re-raise
self.state = AgentState.ERROR # Transition to ERROR on failure
raise e
finally:
finally: # whether or not an exception occurred, restore the original state
self.state = previous_state # Revert to previous state

def update_memory(
@@ -141,19 +153,24 @@ async def run(self, request: Optional[str] = None) -> str:
step_result = await self.step()

# Check for stuck state
if self.is_stuck():
# if self.is_stuck():
# self.handle_stuck_state()
self.update_duplicate_count()
if self.duplicate_count == 2:
self.handle_stuck_state()
elif self.duplicate_count > 2:
self.state = AgentState.FINISHED

results.append(f"Step {self.current_step}: {step_result}")

self.state = AgentState.IDLE
if self.current_step >= self.max_steps:
self.current_step = 0
self.state = AgentState.IDLE
results.append(f"Terminated: Reached max steps ({self.max_steps})")
await SANDBOX_CLIENT.cleanup()
return "\n".join(results) if results else "No steps executed"

@abstractmethod
@abstractmethod # abstract method: the parent class only defines the interface, providing no logic; subclasses must implement it
async def step(self) -> str:
"""Execute a single step in the agent's workflow.

@@ -167,6 +184,34 @@ def handle_stuck_state(self):
self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}"
logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")

def update_duplicate_count(self):
"""Update the duplicate count of the agent's memory."""
if len(self.memory.messages) < 2:
return

last_message = self.memory.messages[-1]
if not last_message.content:
return

pre_content = None
if last_message.role not in self.last_msg_emb:
for msg in reversed(self.memory.messages[:-1]):
if msg.role == last_message.role:
pre_content = msg.content
break
if not pre_content:
return
self.last_msg_emb[last_message.role] = self.embedding_model.encode(pre_content, convert_to_tensor=True)

# compute the semantic similarity of the last two messages from the same role:
latest_emb = self.embedding_model.encode(last_message.content, convert_to_tensor=True)
similarity = util.cos_sim(self.last_msg_emb[last_message.role], latest_emb).item()
if similarity > 0.8:
self.duplicate_count += 1
elif self.duplicate_count != 1: # reset when no duplicate content is detected
self.duplicate_count = 1
self.last_msg_emb[last_message.role] = latest_emb

def is_stuck(self) -> bool:
"""Check if the agent is stuck in a loop by detecting duplicate content"""
if len(self.memory.messages) < 2:
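A minimal standalone sketch (not part of this diff) of the default/default_factory behaviour the field annotations above describe; the class name WithFactory is invented for illustration.

from pydantic import BaseModel, Field

class WithFactory(BaseModel):
    items: list = Field(default_factory=list)  # list() is called once per instance

a, b = WithFactory(), WithFactory()
a.items.append(1)
assert b.items == []  # each instance received its own fresh list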
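A toy sketch of the same pattern state_context uses (all names here are invented): asynccontextmanager runs the async with body at the yield and restores state in finally even when the body raises.

import asyncio
from contextlib import asynccontextmanager

class Thing:
    mode = "idle"

@asynccontextmanager
async def temporary_mode(thing, value):
    previous = thing.mode
    thing.mode = value
    try:
        yield  # the body of the async with block executes here
    finally:
        thing.mode = previous  # restored on success and on error alike

async def main():
    t = Thing()
    async with temporary_mode(t, "running"):
        assert t.mode == "running"
    assert t.mode == "idle"

asyncio.run(main())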
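And a minimal sketch of the similarity check update_duplicate_count performs, assuming the all-MiniLM-L6-v2 weights can be downloaded; 0.8 is the threshold the PR uses.

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")
prev = model.encode("I will search the web for that.", convert_to_tensor=True)
curr = model.encode("Let me search the web for that.", convert_to_tensor=True)
similarity = util.cos_sim(prev, curr).item()  # cosine similarity in [-1, 1]
if similarity > 0.8:
    print(f"near-duplicate ({similarity:.2f}): bump duplicate_count")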
4 changes: 3 additions & 1 deletion app/agent/toolcall.py
@@ -160,6 +160,8 @@ async def act(self) -> str:
)
self.memory.add_message(tool_msg)
results.append(result)
if self.state == AgentState.FINISHED: # if the tool just executed was Terminate, stop running the remaining tools.
break

return "\n\n".join(results)

@@ -245,6 +247,6 @@ async def cleanup(self):
async def run(self, request: Optional[str] = None) -> str:
"""Run the agent with cleanup when done."""
try:
return await super().run(request)
return await super().run(request) # the return value is computed first, then finally runs cleanup(), and only then is the value handed back
finally:
await self.cleanup()
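A tiny demo of the return/finally ordering noted above: the return value is evaluated, the finally block runs, and only then does the caller receive the value.

def demo():
    try:
        return "result"        # evaluated first...
    finally:
        print("cleanup runs")  # ...then finally, before the caller sees "result"

assert demo() == "result"      # prints "cleanup runs" before returning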
13 changes: 7 additions & 6 deletions app/config.py
@@ -106,8 +106,9 @@ class SandboxSettings(BaseModel):


class DaytonaSettings(BaseModel):
# BaseModel: Pydantic's core class; it constructs an instance from the fields of a passed-in dict, filling any missing fields with their defaults
daytona_api_key: str
daytona_server_url: Optional[str] = Field(
daytona_server_url: Optional[str] = Field( # Field: Pydantic helper for setting defaults, metadata (e.g. description), and validation rules
"https://app.daytona.io/api", description=""
)
daytona_target: Optional[str] = Field("us", description="enum ['eu', 'us']")
@@ -196,17 +197,17 @@ class Config:

class Config:
_instance = None
_lock = threading.Lock()
_lock = threading.Lock() # thread lock: ensures that in a multi-threaded environment only one thread enters the critical section, preventing concurrent creation of multiple instances
_initialized = False

def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
with cls._lock: # enter the lock-protected region; other threads must wait for the lock to be released
if cls._instance is None: # double check
cls._instance = super().__new__(cls) # create the new instance
return cls._instance

def __init__(self):
def __init__(self): # called automatically after __new__
if not self._initialized:
with self._lock:
if not self._initialized:
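The double-checked-locking singleton above, reduced to a standalone sketch (class name invented).

import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:              # fast path: no locking once created
            with cls._lock:                    # contested only during first creation
                if cls._instance is None:      # re-check: another thread may have
                    cls._instance = super().__new__(cls)  # created it while we waited
        return cls._instance

assert Singleton() is Singleton()              # every call yields the same object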
21 changes: 11 additions & 10 deletions app/llm.py
@@ -172,21 +172,21 @@ def count_message_tokens(self, messages: List[dict]) -> int:


class LLM:
_instances: Dict[str, "LLM"] = {}
_instances: Dict[str, "LLM"] = {} # 一个类级别的缓存字典,键是config_name,值是对应的LLM实例,以实现“每个配置只创建一次LLM实例”

def __new__(
def __new__( # called before __init__
cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
):
if config_name not in cls._instances:
instance = super().__new__(cls)
instance.__init__(config_name, llm_config)
cls._instances[config_name] = instance
return cls._instances[config_name]
if config_name not in cls._instances: # if no instance exists for this config name yet:
instance = super().__new__(cls) # create a new instance
instance.__init__(config_name, llm_config) # initialize it eagerly (Python still calls __init__ again after __new__ returns; the hasattr guard below keeps that harmless)
cls._instances[config_name] = instance # cache it
return cls._instances[config_name] # return the cached instance

def __init__(
self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
):
if not hasattr(self, "client"): # Only initialize if not already initialized
if not hasattr(self, "client"): # Only initialize if not already initialized(因为__new__可能会多次返回同一个实例)
llm_config = llm_config or config.llm
llm_config = llm_config.get(config_name, llm_config["default"])
self.model = llm_config.model
@@ -635,9 +635,10 @@ async def ask_with_images(
raise

@retry(
wait=wait_random_exponential(min=1, max=60),
wait=wait_random_exponential(min=1, max=60), # exponential backoff + random jitter: each wait grows exponentially but is capped to [1s, 60s],
# with jitter added so clients that failed together do not all retry at the same instant, i.e. the waits are not strict exponential values
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
retry=retry_if_exception_type( # retry only when one of the following errors occurs
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
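A minimal sketch of the tenacity policy annotated above, with a made-up flaky() function; the decorator arguments mirror the ones on ask_with_images.

from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
                      wait_random_exponential)

@retry(
    wait=wait_random_exponential(min=1, max=60),  # exponential backoff + jitter, capped at 60s
    stop=stop_after_attempt(6),                   # give up after six attempts
    retry=retry_if_exception_type(ValueError),    # only ValueError triggers a retry
)
def flaky():
    raise ValueError("transient failure")         # retried six times, then re-raised

# flaky()  # uncomment to watch the waits grow between attempts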
6 changes: 3 additions & 3 deletions app/schema.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel, Field


class Role(str, Enum):
class Role(str, Enum): # defines an enum class whose members are all str instances
"""Message role options"""

SYSTEM = "system"
@@ -96,12 +96,12 @@ def to_dict(self) -> dict:
message["base64_image"] = self.base64_image
return message

@classmethod
@classmethod # a factory method whose first parameter must be cls, the current class. It can be called on the class or an instance, but it cannot access instance attributes, only class-level ones
def user_message(
cls, content: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a user message"""
return cls(role=Role.USER, content=content, base64_image=base64_image)
return cls(role=Role.USER, content=content, base64_image=base64_image) # cls(...) constructs an instance of the current class

@classmethod
def system_message(cls, content: str) -> "Message":
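The factory-classmethod idiom above in miniature (Note and its fields are invented).

from pydantic import BaseModel

class Note(BaseModel):
    role: str
    content: str

    @classmethod
    def user(cls, content: str) -> "Note":
        return cls(role="user", content=content)  # cls(...) builds the current class

print(Note.user("hello"))  # role='user' content='hello'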
2 changes: 1 addition & 1 deletion app/tool/base.py
@@ -38,7 +38,7 @@
class ToolResult(BaseModel):
"""Represents the result of a tool execution."""

output: Any = Field(default=None)
output: Any = Field(default=None) # a plain default is fine here: None is immutable, so nothing is shared or accumulated across instances (the default/default_factory caveat only applies to mutable defaults)
error: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
system: Optional[str] = Field(default=None)
20 changes: 12 additions & 8 deletions app/tool/browser_use_tool.py
@@ -383,14 +383,18 @@ async def execute(

content = markdownify.markdownify(await page.content())

prompt = f"""\
Your task is to extract the content of the page. You will be given a page and a goal, and you should extract all relevant information around this goal from the page. If the goal is vague, summarize the page. Respond in json format.
Extraction goal: {goal}

Page content:
{content[:max_content_length]}
"""
messages = [{"role": "system", "content": prompt}]
system_prompt = """Your task is to extract the content of the page. You will be given a page and \
a goal, and you should extract all relevant information around this goal from the page. If the goal\
is vague, summarize the page. Respond in json format."""
user_prompt = f"""Extraction goal: {goal}

Page content:
{content[:max_content_length]}"""

messages = [ # if the LLM Manus uses is Gemini, the messages list must contain a user message, hence the system/user split.
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]

# Define extraction function schema
extraction_function = {
22 changes: 16 additions & 6 deletions app/tool/create_chat_completion.py
@@ -6,7 +6,7 @@


class CreateChatCompletion(BaseTool):
name: str = "create_chat_completion"
name: str = "create_chat_completion" # 标识工具名称与用途,供智能体注册和调用
description: str = (
"Creates a structured completion with specified output formatting."
)
@@ -21,7 +21,10 @@ class CreateChatCompletion(BaseTool):
list: "array",
}
response_type: Optional[Type] = None
required: List[str] = Field(default_factory=lambda: ["response"])
required: List[str] = Field(default_factory=lambda: ["response"]) # default_factory takes a callable,
# hence the lambda: assigning ["response"] directly would require default=, and as noted in agent/base.py,
# a plain default is evaluated once, which is risky for mutable values, so default_factory + lambda is used
# here `required` names the keys the model is required to output

def __init__(self, response_type: Optional[Type] = str):
"""Initialize with a specific response type."""
@@ -43,9 +46,9 @@ def _build_parameters(self) -> dict:
"required": self.required,
}

if isinstance(self.response_type, type) and issubclass(
if isinstance(self.response_type, type) and issubclass( # (every class object is an instance of type)
self.response_type, BaseModel
):
): # if self.response_type is a Pydantic model class, build the parameter definition from its JSON schema
schema = self.response_type.model_json_schema()
return {
"type": "object",
Expand All @@ -57,7 +60,7 @@ def _build_parameters(self) -> dict:

def _create_type_schema(self, type_hint: Type) -> dict:
"""Create a JSON schema for the given type."""
origin = get_origin(type_hint)
origin = get_origin(type_hint) # typing introspection: fetch the outer container type and its type arguments
args = get_args(type_hint)

# Handle primitive types
@@ -105,7 +108,14 @@ def _create_type_schema(self, type_hint: Type) -> dict:
if origin is Union:
return self._create_union_schema(args)

return self._build_parameters()
# return self._build_parameters() # if self.response_type is none of str/None/list/dict/Union, this call would recurse forever (back through _create_type_schema)
return {
"type": "object",
"properties": {
"response": {"type": "string", "description": "Fallback schema"}
},
"required": self.required,
}

def _get_type_info(self, type_hint: Type) -> dict:
"""Get type information for a single type."""
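A quick demo of the typing introspection _create_type_schema relies on.

from typing import Dict, List, Union, get_args, get_origin

assert get_origin(List[int]) is list            # outer container type
assert get_args(List[int]) == (int,)            # its type arguments
assert get_origin(Dict[str, int]) is dict
assert get_args(Union[int, str]) == (int, str)  # Union members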
1 change: 1 addition & 0 deletions config/.gitignore
@@ -1,2 +1,3 @@
# prevent the local config file from being uploaded to the remote repository
config.toml
mcp.json
6 changes: 6 additions & 0 deletions config/config.example.toml
@@ -95,6 +95,12 @@ temperature = 0.0 # Controls randomness for vision mode
#timeout = 300
#network_enabled = true

[daytona]
daytona_api_key = "YOUR API KEY"
daytona_server_url = "https://app.daytona.io/api"
daytona_target = "us"


# MCP (Model Context Protocol) configuration
[mcp]
server_reference = "app.mcp.server" # default server module reference
10 changes: 10 additions & 0 deletions config/mcp.example.json
@@ -3,6 +3,16 @@
"server1": {
"type": "sse",
"url": "http://localhost:8000/sse"
},

"filesystem": {
"type": "stdio",
"command": "YOUR PATH TO npx.cmd",
"args": [
"-y",
"@modelcontextprotocol/server-filesystem",
"YOUR PATH TO THE workspace FOLDER"
]
}
}
}