Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions astrbot/cli/commands/cmd_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from pathlib import Path

import anyio
import click
from filelock import FileLock, Timeout

Expand Down Expand Up @@ -48,7 +48,7 @@ def init() -> None:

try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
anyio.run(initialize_astrbot, astrbot_root)
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")

Expand Down
6 changes: 3 additions & 3 deletions astrbot/cli/commands/cmd_run.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import asyncio
import os
import sys
import traceback
from pathlib import Path

import anyio
import click
from filelock import FileLock, Timeout

from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root


async def run_astrbot(astrbot_root: Path):
async def run_astrbot(astrbot_root: Path) -> None:
"""运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
Expand Down Expand Up @@ -53,7 +53,7 @@ def run(reload: bool, port: str) -> None:
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
anyio.run(run_astrbot, astrbot_root)
except KeyboardInterrupt:
click.echo("AstrBot 已关闭...")
except Timeout:
Expand Down
13 changes: 9 additions & 4 deletions astrbot/core/core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import threading
import time
import traceback
from asyncio import Queue

import anyio

from astrbot.core import LogBroker, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
Expand Down Expand Up @@ -104,7 +105,9 @@ async def initialize(self) -> None:
logger.error(traceback.format_exc())

# 初始化事件队列
self.event_queue = Queue()
self._event_queue_send, self.event_queue = anyio.create_memory_object_stream[
object
](0)

# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
Expand All @@ -118,7 +121,9 @@ async def initialize(self) -> None:
)

# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.platform_manager = PlatformManager(
self.astrbot_config, self._event_queue_send
)

# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
Expand All @@ -131,7 +136,7 @@ async def initialize(self) -> None:

# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
self._event_queue_send,
self.astrbot_config,
self.db,
self.provider_manager,
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.user_id) == user_id
col(ConversationV2.user_id) == user_id,
),
)

Expand Down
24 changes: 13 additions & 11 deletions astrbot/core/event_bus.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""事件总线, 用于处理事件的分发和处理
"""事件总线, 用于处理事件的分发和处理.

事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑

Expand All @@ -10,8 +11,8 @@
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""

import asyncio
from asyncio import Queue
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream

from astrbot.core import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
Expand All @@ -25,28 +26,29 @@ class EventBus:

def __init__(
self,
event_queue: Queue,
event_queue: MemoryObjectReceiveStream[AstrMessageEvent],
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
):
astrbot_config_mgr: AstrBotConfigManager | None = None,
) -> None:
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr

async def dispatch(self):
async def dispatch(self) -> None:
while True:
event: AstrMessageEvent = await self.event_queue.get()
event: AstrMessageEvent = await self.event_queue.receive()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
asyncio.create_task(scheduler.execute(event))
anyio.create_task(scheduler.execute(event))

def _print_event(self, event: AstrMessageEvent, conf_name: str):
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
"""用于记录事件信息

Args:
event (AstrMessageEvent): 事件对象
event: 事件对象
conf_name: 配置名称

"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
Expand Down
9 changes: 5 additions & 4 deletions astrbot/core/file_token_service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import asyncio
import os
import platform
import time
import uuid
from urllib.parse import unquote, urlparse

import anyio


class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""

def __init__(self, default_timeout: float = 300):
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
def __init__(self, default_timeout: float = 300) -> None:
self.lock = anyio.Lock()
self.staged_files: dict = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout

async def _cleanup_expired_tokens(self):
Expand Down
9 changes: 5 additions & 4 deletions astrbot/core/pipeline/rate_limit_check/stage.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta

import anyio

from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
from astrbot.core.platform.astr_message_event import AstrMessageEvent
Expand All @@ -19,11 +20,11 @@ class RateLimitStage(Stage):
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
"""

def __init__(self):
def __init__(self) -> None:
# 存储每个会话的请求时间队列
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
# 为每个会话设置一个锁,避免并发冲突
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.locks: defaultdict[str, anyio.Lock] = defaultdict(anyio.Lock)
# 限流参数
self.rate_limit_count: int = 0
self.rate_limit_time: timedelta = timedelta(0)
Expand Down Expand Up @@ -74,7 +75,7 @@ async def process(
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
)
await asyncio.sleep(stall_duration)
await anyio.sleep(stall_duration)
now = datetime.now()
case RateLimitStrategy.DISCARD.value:
logger.info(
Expand Down
5 changes: 3 additions & 2 deletions astrbot/core/platform/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import traceback
from asyncio import Queue

from anyio.streams.memory import MemoryObjectSendStream

from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
Expand All @@ -12,7 +13,7 @@


class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue):
def __init__(self, config: AstrBotConfig, event_queue: MemoryObjectSendStream):
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""

Expand Down
7 changes: 4 additions & 3 deletions astrbot/core/platform/platform.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import abc
import uuid
from asyncio import Queue
from collections.abc import Awaitable
from typing import Any

from anyio.streams.memory import MemoryObjectSendStream

from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.metrics import Metric

Expand All @@ -13,7 +14,7 @@


class Platform(abc.ABC):
def __init__(self, event_queue: Queue):
def __init__(self, event_queue: MemoryObjectSendStream):
super().__init__()
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
Expand Down Expand Up @@ -45,7 +46,7 @@ async def send_by_session(

def commit_event(self, event: AstrMessageEvent):
"""提交一个事件到事件队列。"""
self._event_queue.put_nowait(event)
self._event_queue.send_nowait(event)

def get_client(self):
"""获取平台的客户端对象。"""
2 changes: 1 addition & 1 deletion astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ async def handle_msg(self, abm: AstrBotMessage):
client=self.client,
)

self._event_queue.put_nowait(event)
self._event_queue.send_nowait(event)

async def run(self):
# await self.client_.start()
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/platform/sources/lark/lark_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def handle_msg(self, abm: AstrBotMessage):
bot=self.lark_api,
)

self._event_queue.put_nowait(event)
self._event_queue.send_nowait(event)

async def run(self):
# self.client.start()
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/platform/sources/slack/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import hashlib
import hmac
import json
import logging
from collections.abc import Callable

import anyio
from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
logging.getLogger("quart.app").setLevel(logging.WARNING)
logging.getLogger("quart.serving").setLevel(logging.WARNING)

self.shutdown_event = asyncio.Event()
self.shutdown_event = anyio.Event()

def _setup_routes(self):
"""设置路由"""
Expand Down
3 changes: 2 additions & 1 deletion astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""企业微信智能机器人 API 客户端
"""企业微信智能机器人 API 客户端.

处理消息加密解密、API 调用等
"""

Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
处理企业微信智能机器人的 HTTP 回调请求
"""

import asyncio
from collections.abc import Callable
from typing import Any

import anyio
import quart

from astrbot.api import logger
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
self.app = quart.Quart(__name__)
self._setup_routes()

self.shutdown_event = asyncio.Event()
self.shutdown_event = anyio.Event()

def _setup_routes(self):
"""设置 Quart 路由"""
Expand Down
13 changes: 7 additions & 6 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

import aiohttp
import anyio

from astrbot import logger
from astrbot.core import sp
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self) -> None:
self.func_list: list[FuncTool] = []
self.mcp_client_dict: dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: dict[str, asyncio.Event] = {}
self.mcp_client_event: dict[str, anyio.Event] = {}

def empty(self) -> bool:
return len(self.func_list) == 0
Expand Down Expand Up @@ -206,7 +207,7 @@ async def init_mcp_clients(self) -> None:
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
event = anyio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
Expand All @@ -216,7 +217,7 @@ async def _init_mcp_client_task_wrapper(
self,
name: str,
cfg: dict,
event: asyncio.Event,
event: anyio.Event,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
Expand Down Expand Up @@ -307,7 +308,7 @@ async def enable_mcp_server(
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
event: anyio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
Expand All @@ -316,7 +317,7 @@ async def enable_mcp_server(
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
event (anyio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.

Expand All @@ -326,7 +327,7 @@ async def enable_mcp_server(

"""
if not event:
event = asyncio.Event()
event = anyio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
Expand Down
Loading
Loading