diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index fea654f20..090b28ec6 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -2,7 +2,9 @@ import click import hashlib import zoneinfo -from typing import Any, Callable +from typing import Any + +from collections.abc import Callable from ..utils import get_astrbot_root, check_astrbot_root @@ -100,7 +102,7 @@ def _save_config(config: dict[str, Any]) -> None: ) -def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: +def _set_nested_item(obj: dict[str, Any], path: str, value: object) -> None: """设置嵌套字典中的值""" parts = path.split(".") for part in parts[:-1]: @@ -114,7 +116,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: obj[parts[-1]] = value -def _get_nested_item(obj: dict[str, Any], path: str) -> Any: +def _get_nested_item(obj: dict[str, Any], path: str) -> object: """获取嵌套字典中的值""" parts = path.split(".") for part in parts: @@ -123,7 +125,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -146,7 +148,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS.keys(): raise click.ClickException(f"不支持的配置项: {key}") @@ -175,7 +177,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py index d9a42f822..f9724c55a 100644 --- a/astrbot/cli/commands/cmd_init.py +++ b/astrbot/cli/commands/cmd_init.py @@ -1,3 +1,5 @@ +from pathlib import Path + import asyncio import click @@ -6,7 +8,7 @@ from ..utils import check_dashboard, get_astrbot_root -async def initialize_astrbot(astrbot_root) -> None: +async def initialize_astrbot(astrbot_root: Path) -> None: """执行 AstrBot 初始化逻辑""" dot_astrbot = astrbot_root / ".astrbot" diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index b250ede4b..d554fff59 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -16,7 +16,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" pass @@ -30,7 +30,11 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins( + plugins: list[dict], + title: str | None = None, + color: int | tuple[int, int, int] | str | None = None, +) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -47,7 +51,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -86,7 +90,7 @@ def new(name: str): f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n") # 重写 main.py - with open(plug_path / "main.py", "r", encoding="utf-8") as f: + with open(plug_path / "main.py", encoding="utf-8") as f: content = f.read() new_content = content.replace( @@ -102,7 +106,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -143,7 +147,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -166,7 +170,7 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -189,7 +193,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -227,7 +231,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 38113744f..d707e401c 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -11,7 +11,7 @@ from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import logger, LogManager, LogBroker, db_helper from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd1fcd97b..12d0af03c 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -19,7 +19,7 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index fecab885e..505d01e0d 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -17,7 +17,7 @@ def compare_version(v1: str, v2: str) -> int: v1 = v1.lower().replace("v", "") v2 = v2.lower().replace("v", "") - def split_version(version): + def split_version(version: str) -> tuple[list[int], list[int | str] | None]: match = re.match( r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$", version, @@ -79,7 +79,7 @@ def split_version(version): return 0 # 数字部分和预发布标签都相同 @staticmethod - def _split_prerelease(prerelease): + def _split_prerelease(prerelease: str) -> list[int | str] | None: if not prerelease: return None parts = prerelease.split(".") diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index d26463147..d218bd503 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,4 +1,4 @@ -from typing import Generic +from typing import Any, Generic from .tool import FunctionTool from .agent import Agent from .run_context import TContext @@ -8,8 +8,11 @@ class HandoffTool(FunctionTool, Generic[TContext]): """Handoff tool for delegating tasks to another agent.""" def __init__( - self, agent: Agent[TContext], parameters: dict | None = None, **kwargs - ): + self, + agent: Agent[TContext], + parameters: dict | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> None: self.agent = agent super().__init__( name=f"transfer_to_{agent.name}", diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 884fe6bd4..f26baa45e 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -8,20 +8,20 @@ @dataclass class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 8db9d6f26..7bbb67235 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -1,7 +1,6 @@ import asyncio import logging from datetime import timedelta -from typing import Optional from contextlib import AsyncExitStack from astrbot import logger from astrbot.core.utils.log_pipe import LogPipe @@ -94,9 +93,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects - self.session: Optional[mcp.ClientSession] = None + self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() self.name: str | None = None @@ -105,7 +104,7 @@ def __init__(self): self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """连接到 MCP 服务器 如果 `url` 参数存在: @@ -118,7 +117,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str): """ cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # 处理 MCP 服务的错误日志 print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -188,7 +187,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # 处理 MCP 服务的错误日志 self.server_errlogs.append(msg) @@ -218,7 +217,7 @@ async def list_tools_and_save(self) -> mcp.ListToolsResult: self.tools = response.tools return response - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources""" await self.exit_stack.aclose() self.running_event.set() # Set the running event to indicate cleanup is done diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 83821ae29..182e498d5 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -26,7 +26,7 @@ async def reset( run_context: ContextWrapper[TContext], tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], - **kwargs: T.Any, + **kwargs: T.Any, # noqa: ANN401 ) -> None: """ Reset the agent to its initial state. diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 33298e895..dc0230e29 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -69,7 +69,7 @@ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: yield await self.provider.text_chat(**self.req.__dict__) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ Process a single step of the agent. This method should return the result of the step. diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index ae0ab761c..cdaeba028 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,10 @@ from dataclasses import dataclass from deprecated import deprecated -from typing import Awaitable, Callable, Literal, Any, Optional +from typing import Literal, Any + +from collections.abc import Iterator + +from collections.abc import Awaitable, Callable from .mcp_client import MCPClient @@ -30,10 +34,10 @@ class FunctionTool: mcp_client: MCPClient | None = None """MCP 客户端,当 origin 为 mcp 时有效""" - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" - def __dict__(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """将 FunctionTool 转换为字典格式""" return { "name": self.name, @@ -51,14 +55,14 @@ class ToolSet: This class provides methods to add, remove, and retrieve tools, as well as convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" - def __init__(self, tools: list[FunctionTool] | None = None): + def __init__(self, tools: list[FunctionTool] | None = None) -> None: self.tools: list[FunctionTool] = tools or [] def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -67,11 +71,11 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] - def get_tool(self, name: str) -> Optional[FunctionTool]: + def get_tool(self, name: str) -> FunctionTool | None: """Get a tool by its name.""" for tool in self.tools: if tool.name == name: @@ -85,7 +89,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -105,7 +109,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) @@ -236,32 +240,34 @@ def convert_schema(schema: dict) -> dict: return declarations @deprecated(reason="Use openai_schema() instead", version="4.0.0") - def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False): + def get_func_desc_openai_style( + self, omit_empty_parameter_field: bool = False + ) -> list[dict]: return self.openai_schema(omit_empty_parameter_field) @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") - def get_func_desc_anthropic_style(self): + def get_func_desc_anthropic_style(self) -> list[dict]: return self.anthropic_schema() @deprecated(reason="Use google_schema() instead", version="4.0.0") - def get_func_desc_google_genai_style(self): + def get_func_desc_google_genai_style(self) -> dict: return self.google_schema() def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 34a2f5e77..39578c4ed 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,5 +1,7 @@ import mcp -from typing import Any, Generic, AsyncGenerator +from typing import Any, Generic + +from collections.abc import AsyncGenerator from .run_context import TContext, ContextWrapper from .tool import FunctionTool @@ -7,5 +9,8 @@ class BaseFunctionToolExecutor(Generic[TContext]): @classmethod async def execute( - cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args + cls, + tool: FunctionTool, + run_context: ContextWrapper[TContext], + **tool_args: Any, # noqa: ANN401 ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 0ee3f4fe6..c99a67296 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -35,7 +35,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} @@ -52,7 +52,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 5d1f6fbe7..cd8a89196 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -3,7 +3,6 @@ import logging import enum from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP -from typing import Dict from astrbot.core.utils.astrbot_path import get_astrbot_data_path ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") @@ -23,14 +22,19 @@ class AstrBotConfig(dict): - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 """ + # Class-level 属性注解,帮助类型检查器正确推断实例属性类型 + config_path: str + default_config: dict + schema: dict | None + first_deploy: bool | None + def __init__( self, config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, - schema: dict = None, - ): + schema: dict | None = None, + ) -> None: super().__init__() - # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 object.__setattr__(self, "config_path", config_path) object.__setattr__(self, "default_config", default_config) @@ -45,7 +49,7 @@ def __init__( json.dump(default_config, f, indent=4, ensure_ascii=False) object.__setattr__(self, "first_deploy", True) # 标记第一次部署 - with open(config_path, "r", encoding="utf-8-sig") as f: + with open(config_path, encoding="utf-8-sig") as f: conf_str = f.read() conf = json.loads(conf_str) @@ -61,7 +65,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( @@ -82,7 +86,9 @@ def _parse_schema(schema: dict, conf: dict): return conf - def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): + def check_config_integrity( + self, refer_conf: dict, conf: dict, path: str = "" + ) -> bool: """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -140,7 +146,7 @@ def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): return has_new - def save_config(self, replace_config: Dict = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config @@ -150,20 +156,20 @@ def save_config(self, replace_config: Dict = None): with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) - def __getattr__(self, item): + def __getattr__(self, item: str) -> object: try: return self[item] except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: object) -> None: self[key] = value def check_exist(self) -> bool: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 8f8e2e0e9..a4262c998 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,8 +6,10 @@ """ import json +from typing import Any from astrbot.core import sp -from typing import Dict, List, Callable, Awaitable + +from collections.abc import Callable, Awaitable from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -15,13 +17,13 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): - self.session_conversations: Dict[str, str] = {} + def __init__(self, db_helper: BaseDatabase) -> None: + self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 # 会话删除回调函数列表(用于级联清理,如知识库配置) - self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = [] + self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = [] def register_on_session_deleted( self, callback: Callable[[str], Awaitable[None]] @@ -100,7 +102,9 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> None: """切换会话的对话 Args: @@ -112,7 +116,7 @@ async def switch_conversation(self, unified_msg_origin: str, conversation_id: st async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -128,7 +132,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: @@ -182,7 +186,7 @@ async def get_conversation( async def get_conversations( self, unified_msg_origin: str | None = None, platform_id: str | None = None - ) -> List[Conversation]: + ) -> list[Conversation]: """获取对话列表 Args: @@ -206,7 +210,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: Any, # noqa: ANN401 ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表 @@ -238,7 +242,7 @@ async def update_conversation( history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, - ): + ) -> None: """更新会话的对话 Args: @@ -259,7 +263,7 @@ async def update_conversation( async def update_conversation_title( self, unified_msg_origin: str, title: str, conversation_id: str | None = None - ): + ) -> None: """更新会话的对话标题 Args: @@ -280,7 +284,7 @@ async def update_conversation_persona_id( unified_msg_origin: str, persona_id: str, conversation_id: str | None = None, - ): + ) -> None: """更新会话的对话 Persona ID Args: @@ -297,8 +301,12 @@ async def update_conversation_persona_id( ) async def get_human_readable_context( - self, unified_msg_origin, conversation_id, page=1, page_size=10 - ): + self, + unified_msg_origin: str, + conversation_id: str, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[str], int]: """获取人类可读的上下文 Args: diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 0abd3ad49..cb2269235 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -35,7 +35,7 @@ def __init__(self) -> None: self.engine, class_=AsyncSession, expire_on_commit=False ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" pass @@ -100,7 +100,7 @@ async def get_conversations( ... @abc.abstractmethod - async def get_conversation_by_id(self, cid: str) -> ConversationV2: + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: """Get a specific conversation by its ID.""" ... @@ -118,7 +118,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: T.Any, # noqa: ANN401 ) -> tuple[list[ConversationV2], int]: """Get conversations filtered by platform IDs and search query.""" ... @@ -145,7 +145,7 @@ async def update_conversation( title: str | None = None, persona_id: str | None = None, content: list[dict] | None = None, - ) -> None: + ) -> ConversationV2 | None: """Update a conversation's history.""" ... @@ -167,7 +167,7 @@ async def insert_platform_message_history( content: dict, sender_id: str | None = None, sender_name: str | None = None, - ) -> None: + ) -> PlatformMessageHistory | None: """Insert a new platform message history record.""" ... @@ -195,12 +195,12 @@ async def insert_attachment( path: str, type: str, mime_type: str, - ): + ) -> Attachment: """Insert a new attachment record.""" ... @abc.abstractmethod - async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" ... @@ -216,7 +216,7 @@ async def insert_persona( ... @abc.abstractmethod - async def get_persona_by_id(self, persona_id: str) -> Persona: + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" ... @@ -249,7 +249,9 @@ async def insert_preference_or_update( ... @abc.abstractmethod - async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by scope ID and key.""" ... diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 901cdc4ed..257d69ab6 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -32,7 +32,7 @@ async def do_migration_v4( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], astrbot_config: AstrBotConfig, -): +) -> None: """ 执行数据库迁移 迁移旧的 webchat_conversation 表到新的 conversation 表。 diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 4aa5082db..21874aa12 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -37,7 +37,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") ) @@ -91,7 +91,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") ) @@ -166,7 +166,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") @@ -219,7 +219,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig -): +) -> None: """ 迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 @@ -261,7 +261,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index 8a1dc5de7..c01bdfc1d 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 6a661bd3d..e137078ee 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -7,39 +7,39 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path: str | None = None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path self._data = self._load_preferences() - def _load_preferences(self): + def _load_preferences(self) -> dict: if os.path.exists(self.path): try: - with open(self.path, "r") as f: + with open(self.path) as f: return json.load(f) except json.JSONDecodeError: os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default: _VT = None) -> _VT: + def get(self, key: str, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key: str, value: object) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key: str) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index ad86c51f3..4a258f4c9 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -1,7 +1,7 @@ import sqlite3 import time from astrbot.core.db.po import Platform, Stats -from typing import Tuple, List, Dict, Any +from typing import Any from dataclasses import dataclass @@ -126,7 +126,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: Tuple = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() @@ -143,7 +143,7 @@ def _exec_sql(self, sql: str, params: Tuple = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -152,7 +152,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -225,7 +225,9 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: return Stats(platform, [], []) - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + def get_conversation_by_user_id( + self, user_id: str, cid: str + ) -> Conversation | None: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -246,7 +248,7 @@ def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -257,7 +259,7 @@ def new_conversation(self, user_id: str, cid: str): (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> Tuple: + def get_conversations(self, user_id: str) -> list[Conversation]: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -284,7 +286,7 @@ def get_conversations(self, user_id: str) -> Tuple: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -294,7 +296,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -302,7 +304,9 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id( + self, user_id: str, cid: str, persona_id: str + ) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -310,7 +314,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? @@ -320,7 +324,7 @@ def delete_conversation(self, user_id: str, cid: str): def get_all_conversations( self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> tuple[list[dict[str, Any]], int]: """获取所有对话,支持分页,按更新时间降序排序""" try: c = self.conn.cursor() @@ -381,12 +385,12 @@ def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, - search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: + platforms: list[str] | None = None, + message_types: list[str] | None = None, + search_query: str | None = None, + exclude_ids: list[str] | None = None, + exclude_platforms: list[str] | None = None, + ) -> tuple[list[dict[str, Any]], int]: """获取筛选后的对话列表""" try: c = self.conn.cursor() diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 24a05f947..fa6fe8fee 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -9,7 +9,7 @@ UniqueConstraint, Field, ) -from typing import Optional, TypedDict +from typing import TypedDict class PlatformStat(SQLModel, table=True): @@ -50,14 +50,14 @@ class ConversationV2(SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) - content: Optional[list] = Field(default=None, sa_type=JSON) + content: list | None = Field(default=None, sa_type=JSON) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) - title: Optional[str] = Field(default=None, max_length=255) - persona_id: Optional[str] = Field(default=None) + title: str | None = Field(default=None, max_length=255) + persona_id: str | None = Field(default=None) __table_args__ = ( UniqueConstraint( @@ -80,9 +80,9 @@ class Persona(SQLModel, table=True): ) persona_id: str = Field(max_length=255, nullable=False) system_prompt: str = Field(sa_type=Text, nullable=False) - begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) + begin_dialogs: list | None = Field(default=None, sa_type=JSON) """a list of strings, each representing a dialog to start with""" - tools: Optional[list] = Field(default=None, sa_type=JSON) + tools: list | None = Field(default=None, sa_type=JSON) """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( @@ -142,10 +142,8 @@ class PlatformMessageHistory(SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) # An id of group, user in platform - sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform - sender_name: Optional[str] = Field( - default=None - ) # Name of the sender in the platform + sender_id: str | None = Field(default=None) # ID of the sender in the platform + sender_name: str | None = Field(default=None) # Name of the sender in the platform content: dict = Field(sa_type=JSON, nullable=False) # a message chain list created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index f9faede19..2c469c92b 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -46,10 +46,10 @@ async def initialize(self) -> None: async def insert_platform_stats( self, - platform_id, - platform_type, - count=1, - timestamp=None, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime | None = None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: @@ -108,7 +108,9 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformSt # Conversation Management # ==== - async def get_conversations(self, user_id=None, platform_id=None): + async def get_conversations( + self, user_id: str | None = None, platform_id: str | None = None + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) @@ -121,16 +123,18 @@ async def get_conversations(self, user_id=None, platform_id=None): query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type: ignore - async def get_conversation_by_id(self, cid): + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2).where(ConversationV2.conversation_id == cid) result = await session.execute(query) return result.scalar_one_or_none() - async def get_all_conversations(self, page=1, page_size=20): + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size @@ -140,16 +144,16 @@ async def get_all_conversations(self, page=1, page_size=20): .offset(offset) .limit(page_size) ) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_filtered_conversations( self, - page=1, - page_size=20, - platform_ids=None, - search_query="", - **kwargs, - ): + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs: T.Any, # noqa:ANN401 + ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession # Build the base query with filters @@ -194,19 +198,19 @@ async def get_filtered_conversations( result = await session.execute(result_query) conversations = result.scalars().all() - return conversations, total + return conversations, total # type:ignore async def create_conversation( self, - user_id, - platform_id, - content=None, - title=None, - persona_id=None, - cid=None, - created_at=None, - updated_at=None, - ): + user_id: str, + platform_id: str, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> ConversationV2: kwargs = {} if cid: kwargs["conversation_id"] = cid @@ -228,7 +232,13 @@ async def create_conversation( session.add(new_conversation) return new_conversation - async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async def update_conversation( + self, + cid: str, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, + ) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -248,7 +258,7 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid: str) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -268,10 +278,10 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None: async def get_session_conversations( self, - page=1, - page_size=20, - search_query=None, - platform=None, + page: int = 1, + page_size: int = 20, + search_query: str | None = None, + platform: str | None = None, ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details.""" async with self.get_db() as session: @@ -375,12 +385,12 @@ async def get_session_conversations( async def insert_platform_message_history( self, - platform_id, - user_id, - content, - sender_id=None, - sender_name=None, - ): + platform_id: str, + user_id: str, + content: dict, + sender_id: str | None = None, + sender_name: str | None = None, + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" async with self.get_db() as session: session: AsyncSession @@ -396,8 +406,8 @@ async def insert_platform_message_history( return new_history async def delete_platform_message_offset( - self, platform_id, user_id, offset_sec=86400 - ): + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" async with self.get_db() as session: session: AsyncSession @@ -413,8 +423,8 @@ async def delete_platform_message_offset( ) async def get_platform_message_history( - self, platform_id, user_id, page=1, page_size=20 - ): + self, platform_id: str, user_id: str, page: int = 1, page_size: int = 20 + ) -> list[PlatformMessageHistory]: """Get platform message history records.""" async with self.get_db() as session: session: AsyncSession @@ -428,9 +438,11 @@ async def get_platform_message_history( .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def insert_attachment(self, path, type, mime_type): + async def insert_attachment( + self, path: str, type: str, mime_type: str + ) -> Attachment: """Insert a new attachment record.""" async with self.get_db() as session: session: AsyncSession @@ -443,7 +455,7 @@ async def insert_attachment(self, path, type, mime_type): session.add(new_attachment) return new_attachment - async def get_attachment_by_id(self, attachment_id): + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -452,8 +464,12 @@ async def get_attachment_by_id(self, attachment_id): return result.scalar_one_or_none() async def insert_persona( - self, persona_id, system_prompt, begin_dialogs=None, tools=None - ): + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona: """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession @@ -467,7 +483,7 @@ async def insert_persona( session.add(new_persona) return new_persona - async def get_persona_by_id(self, persona_id): + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -475,17 +491,21 @@ async def get_persona_by_id(self, persona_id): result = await session.execute(query) return result.scalar_one_or_none() - async def get_personas(self): + async def get_personas(self) -> list[Persona]: """Get all personas for a specific bot.""" async with self.get_db() as session: session: AsyncSession query = select(Persona) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore async def update_persona( - self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN - ): + self, + persona_id: str, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona | None: """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: session: AsyncSession @@ -504,7 +524,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id: str) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -513,7 +533,9 @@ async def delete_persona(self, persona_id): delete(Persona).where(col(Persona.persona_id) == persona_id) ) - async def insert_preference_or_update(self, scope, scope_id, key, value): + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession @@ -534,7 +556,9 @@ async def insert_preference_or_update(self, scope, scope_id, key, value): session.add(new_preference) return existing_preference or new_preference - async def get_preference(self, scope, scope_id, key): + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession @@ -546,7 +570,9 @@ async def get_preference(self, scope, scope_id, key): result = await session.execute(query) return result.scalar_one_or_none() - async def get_preferences(self, scope, scope_id=None, key=None): + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: """Get all preferences for a specific scope ID or key.""" async with self.get_db() as session: session: AsyncSession @@ -556,9 +582,9 @@ async def get_preferences(self, scope, scope_id=None, key=None): if key is not None: query = query.where(Preference.key == key) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -572,7 +598,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope: str, scope_id: str) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession @@ -589,10 +615,10 @@ async def clear_preferences(self, scope, scope_id): # Deprecated Methods # ==== - def get_base_stats(self, offset_sec=86400): + def get_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: """Get base statistics within the specified offset in seconds.""" - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -614,19 +640,19 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_total_message_count(self): + def get_total_message_count(self) -> int: """Get the total message count from platform statistics.""" - async def _inner(): + async def _inner() -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( @@ -637,18 +663,18 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_grouped_base_stats(self, offset_sec=86400): + def get_grouped_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: # group by platform_id - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -672,11 +698,11 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 27fc9f3fb..46242881d 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -1,5 +1,6 @@ import abc from dataclasses import dataclass +from types import FunctionType @dataclass @@ -9,7 +10,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """ 初始化向量数据库 """ @@ -33,7 +34,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> int: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -74,4 +75,4 @@ async def delete(self, doc_id: str) -> bool: ... @abc.abstractmethod - async def close(self): ... + async def close(self) -> None: ... diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 265c0cc43..629c895c9 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -2,6 +2,7 @@ import json from datetime import datetime from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator from sqlalchemy import Text, Column from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine @@ -30,7 +31,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None @@ -39,7 +40,7 @@ def __init__(self, db_path: str): os.path.dirname(__file__), "sqlite_init.sql" ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -76,7 +77,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -91,7 +92,7 @@ async def connect(self): ) # type: ignore @asynccontextmanager - async def get_session(self): + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """Context manager for database sessions.""" async with self.async_session_maker() as session: # type: ignore yield session @@ -203,7 +204,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -220,7 +221,7 @@ async def delete_document_by_doc_id(self, doc_id: str): if document: await session.delete(document) - async def get_document_by_doc_id(self, doc_id: str): + async def get_document_by_doc_id(self, doc_id: str) -> dict | None: """Retrieve a document by its doc_id. Args: @@ -240,7 +241,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -260,7 +261,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -351,7 +352,7 @@ def _document_to_dict(self, document: Document) -> dict: else document.updated_at, } - async def tuple_to_dict(self, row): + async def tuple_to_dict(self, row: tuple) -> dict: """Convert a tuple to a dictionary. Args: @@ -371,7 +372,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 2c0cc8dfe..d2e7a0f06 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -9,7 +9,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None @@ -19,7 +19,7 @@ def __init__(self, dimension: int, path: str | None = None): base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -36,7 +36,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -67,7 +67,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -78,7 +78,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 8a21538ec..3a08ed3f7 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,3 +1,4 @@ +from types import FunctionType import uuid import time import numpy as np @@ -20,7 +21,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider @@ -31,7 +32,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -61,7 +62,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> list[int]: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -158,7 +159,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """ 删除一条文档块(chunk) """ @@ -172,7 +173,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -187,7 +188,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """ 根据元数据过滤器删除文档 """ diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index 5aaf84ba1..3cf1796dc 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -4,6 +4,7 @@ """ from abc import ABC, abstractmethod +from typing import Any class BaseChunker(ABC): @@ -13,7 +14,7 @@ class BaseChunker(ABC): """ @abstractmethod - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """将文本分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index c9b35d7d8..f36fb86e4 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -3,6 +3,7 @@ 按照固定的字符数将文本分块,支持重叠区域。 """ +from typing import Any from .base import BaseChunker @@ -12,7 +13,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: @@ -22,7 +23,7 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """固定大小分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 21b76cba5..11b31b165 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any from .base import BaseChunker @@ -10,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """ 初始化递归字符文本分割器 @@ -38,7 +39,7 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """ 递归地将文本分割成块 diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 827d621d3..cc8dd6d0c 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,5 +1,6 @@ from contextlib import asynccontextmanager from pathlib import Path +from collections.abc import AsyncGenerator from sqlmodel import col, desc from sqlalchemy import text, func, select, update, delete @@ -45,7 +46,7 @@ def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: ) @asynccontextmanager - async def get_db(self): + async def get_db(self) -> AsyncGenerator[AsyncSession, None]: """获取数据库会话 用法: @@ -249,7 +250,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session: diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 09b9c9fc8..793e41a8d 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -1,3 +1,4 @@ +from types import FunctionType import uuid import aiofiles import json @@ -24,7 +25,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager @@ -38,7 +39,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -82,7 +83,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -90,7 +91,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -104,7 +105,7 @@ async def upload_document( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -180,7 +181,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current: int, total: int) -> None: if progress_callback: await progress_callback("embedding", current, total) @@ -245,7 +246,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -257,7 +258,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -268,7 +269,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index c1c63d08a..1f0a8b68d 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -28,14 +28,14 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -60,13 +60,13 @@ async def initialize(self): logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -269,7 +269,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 278e4da20..8d9447f42 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -6,7 +6,6 @@ import time from dataclasses import dataclass -from typing import List from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion @@ -45,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: @@ -61,11 +60,11 @@ def __init__( async def retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_id_helper_map: dict[str, KBHelper], top_k_fusion: int = 20, top_m_final: int = 5, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """混合检索 流程: @@ -188,9 +187,9 @@ async def retrieve( async def _dense_retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_options: dict, - ): + ) -> list[Result]: """稠密检索 (向量相似度) 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 @@ -233,10 +232,10 @@ async def _dense_retrieve( async def _rerank( self, query: str, - results: List[RetrievalResult], + results: list[RetrievalResult], top_k: int, rerank_provider: RerankProvider, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """Rerank 重排序 Args: diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 3ceba4ff8..2e1b90d62 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 315930b3e..03735348a 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -32,7 +32,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3a1c50371..30699dc0b 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -28,7 +28,6 @@ import sys from collections import deque from asyncio import Queue -from typing import List # 日志缓存大小 CACHED_SIZE = 200 @@ -87,7 +86,7 @@ class LogBroker: def __init__(self): self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 - self.subscribers: List[Queue] = [] # 订阅者列表 + self.subscribers: list[Queue] = [] # 订阅者列表 def register(self) -> Queue: """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index d9ec4b41b..386e944a7 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -72,7 +72,7 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def toString(self): + def toString(self) -> str: output = f"[CQ:{self.type.lower()}" for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -81,7 +81,7 @@ def toString(self): k = "type" if isinstance(v, bool): v = 1 if v else 0 - output += ",%s=%s" % ( + output += ",{}={}".format( k, str(v) .replace("&", "&") @@ -92,7 +92,7 @@ def toString(self): output += "]" return output - def toDict(self): + def toDict(self) -> dict: data = {} for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -112,20 +112,20 @@ class Plain(BaseMessageComponent): text: str convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息 - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_: T.Any) -> None: # noqa: ANN401 super().__init__(text=text, convert=convert, **_) - def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 + def toString(self) -> str: # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 if not self.convert: return self.text return ( self.text.replace("&", "&").replace("[", "[").replace("]", "]") ) - def toDict(self): + def toDict(self) -> dict: return {"type": "text", "data": {"text": self.text.strip()}} - async def to_dict(self): + async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} @@ -133,7 +133,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -148,7 +148,7 @@ class Record(BaseMessageComponent): # 额外 path: T.Optional[str] - def __init__(self, file: T.Optional[str], **_): + def __init__(self, file: T.Optional[str], **_: T.Any) -> None: # noqa: ANN401 for k in _.keys(): if k == "url": pass @@ -156,17 +156,17 @@ def __init__(self, file: T.Optional[str], **_): super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Record": # noqa: ANN401 return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Record": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromBase64(bs64_data: str, **_): + def fromBase64(bs64_data: str, **_: T.Any) -> "Record": # noqa: ANN401 return Record(file=f"base64://{bs64_data}", **_) async def convert_to_file_path(self) -> str: @@ -250,15 +250,15 @@ class Video(BaseMessageComponent): # 额外 path: T.Optional[str] = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_: T.Any) -> None: # noqa: ANN401 super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Video": # noqa: ANN401 return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Video": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Video(file=url, **_) raise Exception("not a valid url") @@ -285,7 +285,7 @@ async def convert_to_file_path(self) -> str: else: raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """ 将视频注册到文件服务。 @@ -308,7 +308,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = self.file if url_or_path.startswith("http"): @@ -333,10 +333,10 @@ class At(BaseMessageComponent): qq: T.Union[int, str] # 此处str为all时代表所有人 name: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) - def toDict(self): + def toDict(self) -> dict: return { "type": "at", "data": {"qq": str(self.qq)}, @@ -346,28 +346,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -375,7 +375,7 @@ class Anonymous(BaseMessageComponent): # TODO type = ComponentType.Anonymous ignore: T.Optional[bool] = False - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -386,7 +386,7 @@ class Share(BaseMessageComponent): content: T.Optional[str] = "" image: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -395,7 +395,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: T.Optional[int] = 0 - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -406,7 +406,7 @@ class Location(BaseMessageComponent): # TODO title: T.Optional[str] = "" content: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -420,7 +420,7 @@ class Music(BaseMessageComponent): content: T.Optional[str] = "" image: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -440,29 +440,29 @@ class Image(BaseMessageComponent): path: T.Optional[str] = "" file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: T.Optional[str], **_): + def __init__(self, file: T.Optional[str], **_: T.Any) -> None: # noqa: ANN401 super().__init__(file=file, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Image": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Image(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Image": # noqa: ANN401 return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromBase64(base64: str, **_): + def fromBase64(base64: str, **_: T.Any) -> "Image": # noqa: ANN401 return Image(f"base64://{base64}", **_) @staticmethod - def fromBytes(byte: bytes): + def fromBytes(byte: bytes) -> "Image": return Image.fromBase64(base64.b64encode(byte).decode()) @staticmethod - def fromIO(IO): + def fromIO(IO: T.BinaryIO) -> "Image": return Image.fromBytes(IO.read()) async def convert_to_file_path(self) -> str: @@ -562,7 +562,7 @@ class Reply(BaseMessageComponent): seq: T.Optional[int] = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -570,16 +570,16 @@ class RedBag(BaseMessageComponent): type = ComponentType.RedBag title: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Poke(BaseMessageComponent): - type: str = ComponentType.Poke + type = ComponentType.Poke id: T.Optional[int] = 0 qq: T.Optional[int] = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_: T.Any) -> None: # noqa: ANN401 type = f"Poke:{type}" super().__init__(type=type, **_) @@ -588,7 +588,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -603,13 +603,13 @@ class Node(BaseMessageComponent): seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_: T.Any) -> None: # noqa: ANN401 if isinstance(content, Node): # back content = [content] super().__init__(content=content, **_) - async def to_dict(self): + async def to_dict(self) -> dict: data_content = [] for comp in self.content: if isinstance(comp, (Image, Record)): @@ -650,10 +650,10 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: T.List[Node] - def __init__(self, nodes: T.List[Node], **_): + def __init__(self, nodes: T.List[Node], **_: T.Any) -> None: # noqa: ANN401 super().__init__(nodes=nodes, **_) - def toDict(self): + def toDict(self) -> dict: """Deprecated. Use to_dict instead""" ret = { "messages": [], @@ -663,7 +663,7 @@ def toDict(self): ret["messages"].append(d) return ret - async def to_dict(self): + async def to_dict(self) -> dict: """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" ret = {"messages": []} for node in self.nodes: @@ -677,7 +677,7 @@ class Xml(BaseMessageComponent): data: str resid: T.Optional[int] = 0 - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -686,7 +686,7 @@ class Json(BaseMessageComponent): data: T.Union[str, dict] resid: T.Optional[int] = 0 - def __init__(self, data, **_): + def __init__(self, data: T.Union[str, dict], **_: T.Any) -> None: # noqa: ANN401 if isinstance(data, dict): data = json.dumps(data) super().__init__(data=data, **_) @@ -703,11 +703,11 @@ class CardImage(BaseMessageComponent): source: T.Optional[str] = "" icon: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "CardImage": # noqa: ANN401 return CardImage(file=f"file:///{os.path.abspath(path)}", **_) @@ -715,7 +715,7 @@ class TTS(BaseMessageComponent): type = ComponentType.TTS text: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -723,7 +723,7 @@ class Unknown(BaseMessageComponent): type = ComponentType.Unknown text: str - def toString(self): + def toString(self) -> str: return "" @@ -737,7 +737,7 @@ class File(BaseMessageComponent): file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: # noqa: ANN401 """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -757,11 +757,9 @@ def file(self) -> str: loop = asyncio.get_event_loop() if loop.is_running(): logger.warning( - ( - "不可以在异步上下文中同步等待下载! " - "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" - "请使用 await get_file() 代替直接获取 .file 字段" - ) + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段" ) return "" else: @@ -776,7 +774,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """ 向前兼容, 设置file属性, 传入的参数可能是文件路径或URL @@ -809,7 +807,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" download_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(download_dir, exist_ok=True) @@ -817,7 +815,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """ 将文件注册到文件服务。 @@ -840,7 +838,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = await self.get_file(allow_return_url=True) if url_or_path.startswith("http"): @@ -867,7 +865,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: T.Optional[int] = 0 cdnurl: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 7bfdd34c8..030118e90 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,8 @@ import enum -from typing import List, Optional, Union, AsyncGenerator +from typing_extensions import Self + +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from astrbot.core.message.components import ( BaseMessageComponent, @@ -22,12 +24,12 @@ class MessageChain: `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 """ - chain: List[BaseMessageComponent] = field(default_factory=list) - use_t2i_: Optional[bool] = None # None 为跟随用户设置 - type: Optional[str] = None + chain: list[BaseMessageComponent] = field(default_factory=list) + use_t2i_: bool | None = None # None 为跟随用户设置 + type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" - def message(self, message: str): + def message(self, message: str) -> Self: """添加一条文本消息到消息链 `chain` 中。 Example: @@ -39,7 +41,7 @@ def message(self, message: str): self.chain.append(Plain(message)) return self - def at(self, name: str, qq: Union[str, int]): + def at(self, name: str, qq: str | int) -> Self: """添加一条 At 消息到消息链 `chain` 中。 Example: @@ -51,7 +53,7 @@ def at(self, name: str, qq: Union[str, int]): self.chain.append(At(name=name, qq=qq)) return self - def at_all(self): + def at_all(self) -> Self: """添加一条 AtAll 消息到消息链 `chain` 中。 Example: @@ -64,7 +66,7 @@ def at_all(self): return self @deprecated("请使用 message 方法代替。") - def error(self, message: str): + def error(self, message: str) -> Self: """添加一条错误消息到消息链 `chain` 中 Example: @@ -75,7 +77,7 @@ def error(self, message: str): self.chain.append(Plain(message)) return self - def url_image(self, url: str): + def url_image(self, url: str) -> Self: """添加一条图片消息(https 链接)到消息链 `chain` 中。 Note: @@ -89,7 +91,7 @@ def url_image(self, url: str): self.chain.append(Image.fromURL(url)) return self - def file_image(self, path: str): + def file_image(self, path: str) -> Self: """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 Note: @@ -100,7 +102,7 @@ def file_image(self, path: str): self.chain.append(Image.fromFileSystem(path)) return self - def base64_image(self, base64_str: str): + def base64_image(self, base64_str: str) -> Self: """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 Example: @@ -109,7 +111,7 @@ def base64_image(self, base64_str: str): self.chain.append(Image.fromBase64(base64_str)) return self - def use_t2i(self, use_t2i: bool): + def use_t2i(self, use_t2i: bool) -> Self: """设置是否使用文本转图片服务。 Args: @@ -122,10 +124,10 @@ def get_plain_text(self) -> str: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - def squash_plain(self): + def squash_plain(self) -> Self | None: """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: - return + return None new_chain = [] first_plain = None @@ -183,23 +185,23 @@ class MessageEventResult(MessageChain): `result_type` (EventResultType): 事件处理的结果类型。 """ - result_type: Optional[EventResultType] = field( + result_type: EventResultType | None = field( default_factory=lambda: EventResultType.CONTINUE ) - result_content_type: Optional[ResultContentType] = field( + result_content_type: ResultContentType | None = field( default_factory=lambda: ResultContentType.GENERAL_RESULT ) - async_stream: Optional[AsyncGenerator] = None + async_stream: AsyncGenerator | None = None """异步流""" - def stop_event(self) -> "MessageEventResult": + def stop_event(self) -> Self: """终止事件传播。""" self.result_type = EventResultType.STOP return self - def continue_event(self) -> "MessageEventResult": + def continue_event(self) -> Self: """继续事件传播。""" self.result_type = EventResultType.CONTINUE return self @@ -210,12 +212,12 @@ def is_stopped(self) -> bool: """ return self.result_type == EventResultType.STOP - def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + def set_async_stream(self, stream: AsyncGenerator) -> Self: """设置异步流。""" self.async_stream = stream return self - def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": + def set_result_content_type(self, typ: ResultContentType) -> Self: """设置事件处理的结果类型。 Args: diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index e6ecd995c..d94e97ea9 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -1,4 +1,4 @@ -from typing import Union, AsyncGenerator +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -14,13 +14,13 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) async def process( self, event: AstrMessageEvent, check_text: str | None = None - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py index 5701f0634..f0a34e73f 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py @@ -1,8 +1,7 @@ import abc -from typing import Tuple class ContentSafetyStrategy(abc.ABC): @abc.abstractmethod - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: raise NotImplementedError diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index af960328f..27494d455 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -1,11 +1,10 @@ from . import ContentSafetyStrategy -from typing import List, Tuple from astrbot import logger class StrategySelector: def __init__(self, config: dict) -> None: - self.enabled_strategies: List[ContentSafetyStrategy] = [] + self.enabled_strategies: list[ContentSafetyStrategy] = [] if config["internal_keywords"]["enable"]: from .keywords import KeywordsStrategy @@ -26,7 +25,7 @@ def __init__(self, config: dict) -> None: ) ) - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: for strategy in self.enabled_strategies: ok, info = strategy.check(content) if not ok: diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index e7ac120b7..10582c69a 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -11,8 +11,8 @@ async def call_handler( event: AstrMessageEvent, handler: T.Callable[..., T.Awaitable[T.Any]], - *args, - **kwargs, + *args: T.Any, # noqa: ANN401 + **kwargs: T.Any, # noqa: ANN401 ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 @@ -73,8 +73,8 @@ async def call_handler( async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, - *args, - **kwargs, + *args: T.Any, # noqa: ANN401 + **kwargs: T.Any, # noqa: ANN401 ) -> bool: """调用事件钩子函数 diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 5c075687f..936e9f3e7 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -1,7 +1,8 @@ import traceback import asyncio import random -from typing import Union, AsyncGenerator + +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -21,7 +22,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """在处理事件之前的预处理""" # 平台特异配置:platform_specific..pre_ack_emoji supported = {"telegram", "lark"} diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 703b3681c..747b5562b 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -6,6 +6,7 @@ import copy import json import traceback +import typing as T from datetime import timedelta from collections.abc import AsyncGenerator from astrbot.core.conversation_mgr import Conversation @@ -50,7 +51,12 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @classmethod - async def execute(cls, tool, run_context, **tool_args): + async def execute( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args: T.Any, # noqa:ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: """执行函数调用。 Args: @@ -82,8 +88,8 @@ async def _execute_handoff( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: input_ = tool_args.get("input", "agent") agent_runner = AgentRunner() @@ -172,8 +178,8 @@ async def _execute_local( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: if not run_context.event: raise ValueError("Event must be provided for local function tools.") @@ -220,8 +226,8 @@ async def _execute_mcp( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: if not tool.mcp_client: raise ValueError("MCP client is not available for MCP function tools.") @@ -241,7 +247,9 @@ async def _execute_mcp( class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done( + self, run_context: ContextWrapper[AstrAgentContext], llm_response: LLMResponse + ) -> None: # 执行事件钩子 await call_event_hook( run_context.event, EventType.OnLLMResponseEvent, llm_response @@ -338,7 +346,7 @@ async def initialize(self, ctx: PipelineContext) -> None: self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent): + def _select_provider(self, event: AstrMessageEvent) -> Provider | None: """选择使用的 LLM 提供商""" sel_provider = event.get_extra("selected_provider") _ctx = self.ctx.plugin_manager.context @@ -565,7 +573,7 @@ async def process( async def _handle_webchat( self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider - ): + ) -> None: """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return @@ -623,7 +631,7 @@ async def _save_to_history( event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 42990aae5..7a9d2045f 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -4,7 +4,9 @@ from ...context import PipelineContext, call_handler from ..stage import Stage -from typing import Dict, Any, List, AsyncGenerator, Union +from typing import Any + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core import logger @@ -22,11 +24,11 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: - activated_handlers: List[StarHandlerMetadata] = event.get_extra( + ) -> None | AsyncGenerator[None, None]: + activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers" ) - handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra( + handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra( "handlers_parsed_params" ) if not handlers_parsed_params: diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index f653a9fb9..87906d653 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -1,4 +1,4 @@ -from typing import List, Union, AsyncGenerator +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from .method.llm_request import LLMRequestSubStage @@ -23,9 +23,9 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """处理事件""" - activated_handlers: List[StarHandlerMetadata] = event.get_extra( + activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers" ) # 有插件 Handler 被激活 diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index b36a2fbd0..e16a1ca48 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -1,7 +1,9 @@ import asyncio from datetime import datetime, timedelta from collections import defaultdict, deque -from typing import DefaultDict, Deque, Union, AsyncGenerator +from typing import DefaultDict, Deque + +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -18,7 +20,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 @@ -43,7 +45,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """ 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index dc6a67e2f..98afecefe 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -2,7 +2,8 @@ import asyncio import math import astrbot.core.message.components as Comp -from typing import Union, AsyncGenerator + +from collections.abc import AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext, call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -34,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) @@ -92,7 +93,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: @@ -134,7 +135,7 @@ def _extract_comp( raw_chain: list[BaseMessageComponent], extract_types: set[ComponentType], modify_raw_chain: bool = True, - ): + ) -> list[BaseMessageComponent]: extracted = [] if modify_raw_chain: remaining = [] @@ -151,7 +152,7 @@ def _extract_comp( async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None: return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index c1f893baf..9dddd6d1a 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -1,7 +1,8 @@ import re import time import traceback -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply @@ -18,7 +19,7 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ @@ -72,7 +73,7 @@ async def initialize(self, ctx: PipelineContext): async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None or not result.chain: return diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 7a38ec03f..1b2f7dccc 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -1,7 +1,7 @@ from . import STAGES_ORDER from .stage import registered_stages from .context import PipelineContext -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.core.platform import AstrMessageEvent from astrbot.core import logger @@ -9,21 +9,23 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__) ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages( + self, event: AstrMessageEvent, from_stage: int = 0 + ) -> None: """依次执行各个阶段 Args: @@ -65,7 +67,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/pipeline/session_status_check/stage.py b/astrbot/core/pipeline/session_status_check/stage.py index 3c451e26a..731441f3a 100644 --- a/astrbot/core/pipeline/session_status_check/stage.py +++ b/astrbot/core/pipeline/session_status_check/stage.py @@ -1,6 +1,7 @@ from ..stage import Stage, register_stage from ..context import PipelineContext -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core import logger @@ -16,7 +17,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: # 检查会话是否整体启用 if not SessionServiceManager.is_session_enabled(event.unified_msg_origin): logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。") diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index c4550495a..fb7dec9e9 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,13 +1,14 @@ from __future__ import annotations import abc -from typing import List, AsyncGenerator, Union, Type + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext -registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 +registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 -def register_stage(cls): +def register_stage(cls: type[Stage]) -> type[Stage]: """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" registered_stages.append(cls) return cls @@ -26,9 +27,7 @@ async def initialize(self, ctx: PipelineContext) -> None: raise NotImplementedError @abc.abstractmethod - async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + async def process(self, event: AstrMessageEvent) -> None | AsyncGenerator[None]: """处理事件 Args: diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index de6ad5e35..161b206a9 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator from astrbot import logger from astrbot.core.message.components import At, AtAll, Reply @@ -49,7 +49,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: if ( self.ignore_bot_self_message and event.get_self_id() == event.get_sender_id() diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py index b140d23ba..ce3616b27 100644 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -1,6 +1,7 @@ from ..stage import Stage, register_stage from ..context import PipelineContext -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core import logger @@ -28,7 +29,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: if not self.enable_whitelist_check: # 白名单检查未启用 return diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3a4b8c128..448a0733f 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -4,7 +4,9 @@ import hashlib import uuid -from typing import List, Union, Optional, AsyncGenerator, Any +from typing import Any + +from collections.abc import AsyncGenerator from astrbot import logger from astrbot.core.db.po import Conversation @@ -90,7 +92,7 @@ def get_message_str(self) -> str: """ return self.message_str - def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str: + def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: outline = "" if not chain: return outline @@ -127,7 +129,7 @@ def get_message_outline(self) -> str: """ return self._outline_chain(self.message_obj.message) - def get_messages(self) -> List[BaseMessageComponent]: + def get_messages(self) -> list[BaseMessageComponent]: """ 获取消息链。 """ @@ -240,7 +242,7 @@ async def _pre_send(self): async def _post_send(self): """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: Union[MessageEventResult, str]): + def set_result(self, result: MessageEventResult | str): """设置消息事件的结果。 Note: @@ -344,7 +346,7 @@ def image_result(self, url_or_path: str) -> MessageEventResult: return MessageEventResult().url_image(url_or_path) return MessageEventResult().file_image(url_or_path) - def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult: + def chain_result(self, chain: list[BaseMessageComponent]) -> MessageEventResult: """ 创建一个空的消息事件结果,包含指定的消息链。 """ @@ -359,8 +361,8 @@ def request_llm( prompt: str, func_tool_manager=None, session_id: str = None, - image_urls: List[str] = [], - contexts: List = [], + image_urls: list[str] = [], + contexts: list = [], system_prompt: str = "", conversation: Conversation = None, ) -> ProviderRequest: @@ -427,7 +429,7 @@ async def react(self, emoji: str): """ await self.send(MessageChain([Plain(emoji)])) - async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]: + async def get_group(self, group_id: str = None, **kwargs) -> Group | None: """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 适配情况: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 1808c2911..33ae120e0 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -1,5 +1,4 @@ import time -from typing import List from dataclasses import dataclass from astrbot.core.message.components import BaseMessageComponent from .message_type import MessageType @@ -28,9 +27,9 @@ class Group: """群头像""" group_owner: str = None """群主 id""" - group_admins: List[str] = None + group_admins: list[str] = None """群管理员 id""" - members: List[MessageMember] = None + members: list[MessageMember] = None """所有群成员""" def __str__(self): @@ -57,7 +56,7 @@ class AstrBotMessage: message_id: str # 消息id group: Group # 群组 sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 raw_message: object timestamp: int # 消息时间戳 diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 7090c669c..436185296 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -2,7 +2,6 @@ import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .platform import Platform -from typing import List from asyncio import Queue from .register import platform_cls_map from astrbot.core import logger @@ -12,7 +11,7 @@ class PlatformManager: def __init__(self, config: AstrBotConfig, event_queue: Queue): - self.platform_insts: List[Platform] = [] + self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" self._inst_map = {} diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c109f29b4..31afba005 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,6 +1,8 @@ import abc import uuid -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from asyncio import Queue from .platform_metadata import PlatformMetadata from .astr_message_event import AstrMessageEvent diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 97c33a43e..b2a43cf15 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -1,10 +1,9 @@ -from typing import List, Dict, Type from .platform_metadata import PlatformMetadata from astrbot.core import logger -platform_registry: List[PlatformMetadata] = [] +platform_registry: list[PlatformMetadata] = [] """维护了通过装饰器注册的平台适配器""" -platform_cls_map: Dict[str, Type] = {} +platform_cls_map: dict[str, type] = {} """维护了平台适配器名称和适配器类的映射""" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index b8bb723d5..50dc4e070 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,6 +1,7 @@ import asyncio import re -from typing import AsyncGenerator, Dict, List + +from collections.abc import AsyncGenerator from aiocqhttp import CQHttp, Event from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( @@ -198,7 +199,7 @@ async def get_group(self, group_id=None, **kwargs): group_id=group_id, ) - members: List[Dict] = await self.bot.call_action( + members: list[dict] = await self.bot.call_action( "get_group_member_list", group_id=group_id, ) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d1992b6c3..2bbd885e9 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -3,7 +3,9 @@ import logging import uuid import itertools -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from aiocqhttp import CQHttp, Event from astrbot.api.platform import ( Platform, diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 07e712161..79b89704a 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -1,5 +1,4 @@ import discord -from typing import List from astrbot.api.message_components import BaseMessageComponent @@ -18,7 +17,7 @@ def __init__( thumbnail: str = None, image: str = None, footer: str = None, - fields: List[dict] = None, + fields: list[dict] = None, ): self.title = title self.description = description @@ -96,7 +95,7 @@ class DiscordView(BaseMessageComponent): type: str = "discord_view" def __init__( - self, components: List[BaseMessageComponent] = None, timeout: float = None + self, components: list[BaseMessageComponent] = None, timeout: float = None ): self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 6764eda61..fa82d0da7 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -19,7 +19,7 @@ from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent -from typing import Any, Tuple +from typing import Any from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.star import star_map @@ -420,7 +420,7 @@ async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None): @staticmethod def _extract_command_info( event_filter: Any, handler_metadata: StarHandlerMetadata - ) -> Tuple[str, str, CommandFilter] | None: + ) -> tuple[str, str, CommandFilter] | None: """从事件过滤器中提取指令信息""" cmd_name = None # is_group = False diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 2c8d055fc..46a5df51b 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -3,7 +3,6 @@ import base64 from io import BytesIO from pathlib import Path -from typing import Optional import sys from astrbot.api.event import AstrMessageEvent, MessageChain @@ -41,7 +40,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: DiscordBotClient, - interaction_followup_webhook: Optional[discord.Webhook] = None, + interaction_followup_webhook: discord.Webhook | None = None, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -98,7 +97,7 @@ async def send(self, message: MessageChain): await super().send(message) - async def _get_channel(self) -> Optional[discord.abc.Messageable]: + async def _get_channel(self) -> discord.abc.Messageable | None: """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) @@ -112,7 +111,7 @@ async def _get_channel(self) -> Optional[discord.abc.Messageable]: async def _parse_to_discord( self, message: MessageChain, - ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: + ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" content = "" files = [] diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 2174c497c..d9e96566e 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -4,7 +4,6 @@ import base64 import lark_oapi as lark from io import BytesIO -from typing import List from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.core.utils.io import download_image_by_url @@ -21,7 +20,7 @@ def __init__( self.bot = bot @staticmethod - async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List: + async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list: ret = [] _stage = [] for comp in message.chain: diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 981d05c82..8d8cc18b8 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,6 +1,8 @@ import asyncio import random -from typing import Dict, Any, Optional, Awaitable, List +from typing import Any + +from collections.abc import Awaitable from astrbot.api import logger from astrbot.api.event import MessageChain @@ -87,7 +89,7 @@ def __init__( self.unique_session = platform_settings["unique_session"] - self.api: Optional[MisskeyAPI] = None + self.api: MisskeyAPI | None = None self._running = False self.client_self_id = "" self._bot_username = "" @@ -168,7 +170,7 @@ async def _send_text_only_message( from .misskey_utils import extract_user_id_from_session_id user_id = extract_user_id_from_session_id(session_id) - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} await self.api.send_message(payload) elif session_id and is_valid_room_session_id(session_id): from .misskey_utils import extract_room_id_from_session_id @@ -180,7 +182,7 @@ async def _send_text_only_message( return await super().send_by_session(session, message_chain) def _process_poll_data( - self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str] + self, message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str] ): """处理投票数据,将其添加到消息中""" try: @@ -196,7 +198,7 @@ def _process_poll_data( message.message.append(Comp.Plain(poll_text)) message_parts.append(poll_text) - def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]: + def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: """从会话和消息链中提取额外字段""" fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None} @@ -267,7 +269,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: Dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]): try: notification_type = data.get("type") logger.debug( @@ -291,7 +293,7 @@ async def _handle_notification(self, data: Dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: Dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]): try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "") @@ -326,13 +328,13 @@ async def _handle_chat_message(self, data: Dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: Dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]): event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}" ) - def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool: + def _is_bot_mentioned(self, note: dict[str, Any]) -> bool: text = note.get("text", "") if not text: return False @@ -400,8 +402,8 @@ async def send_by_session( if len(text) > self.max_message_length: text = text[: self.max_message_length] + "..." - file_ids: List[str] = [] - fallback_urls: List[str] = [] + file_ids: list[str] = [] + fallback_urls: list[str] = [] if not self.enable_file_upload: return await self._send_text_only_message( @@ -417,7 +419,7 @@ async def send_by_session( upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY) sem = asyncio.Semaphore(upload_concurrency) - async def _upload_comp(comp) -> Optional[object]: + async def _upload_comp(comp) -> object | None: """组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)""" from .misskey_utils import ( resolve_component_url_or_path, @@ -540,7 +542,7 @@ async def _upload_comp(comp) -> Optional[object]: if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toRoomId": room_id, "text": text} + payload: dict[str, Any] = {"toRoomId": room_id, "text": text} if file_ids: payload["fileIds"] = file_ids await self.api.send_room_message(payload) @@ -555,7 +557,7 @@ async def _upload_comp(comp) -> Optional[object]: if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} if file_ids: # 聊天消息只支持单个文件,使用 fileId 而不是 fileIds payload["fileId"] = file_ids[0] @@ -610,7 +612,7 @@ async def _upload_comp(comp) -> Optional[object]: return await super().send_by_session(session, message_chain) - async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 贴文数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=False) message = create_base_message( @@ -652,7 +654,7 @@ async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: ) return message - async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_chat_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 聊天消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) message = create_base_message( @@ -676,7 +678,7 @@ async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage message.message_str = raw_text if raw_text else "" return message - async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 群聊消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) room_id = raw_data.get("toRoomId", "") diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 4b920508f..dae02b183 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -1,7 +1,9 @@ import json import random import asyncio -from typing import Any, Optional, Dict, List, Callable, Awaitable +from typing import Any + +from collections.abc import Callable, Awaitable import uuid try: @@ -54,11 +56,11 @@ class StreamingClient: def __init__(self, instance_url: str, access_token: str): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self.websocket: Optional[Any] = None + self.websocket: Any | None = None self.is_connected = False - self.message_handlers: Dict[str, Callable] = {} - self.channels: Dict[str, str] = {} - self.desired_channels: Dict[str, Optional[Dict]] = {} + self.message_handlers: dict[str, Callable] = {} + self.channels: dict[str, str] = {} + self.desired_channels: dict[str, dict | None] = {} self._running = False self._last_pong = None @@ -104,7 +106,7 @@ async def disconnect(self): logger.info("[Misskey WebSocket] 连接已断开") async def subscribe_channel( - self, channel_type: str, params: Optional[Dict] = None + self, channel_type: str, params: dict | None = None ) -> str: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -136,7 +138,7 @@ async def unsubscribe_channel(self, channel_id: str): self.desired_channels.pop(channel_type, None) def add_message_handler( - self, event_type: str, handler: Callable[[Dict], Awaitable[None]] + self, event_type: str, handler: Callable[[dict], Awaitable[None]] ): self.message_handlers[event_type] = handler @@ -188,11 +190,11 @@ async def listen(self): except Exception: pass - async def _handle_message(self, data: Dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]): message_type = data.get("type") body = data.get("body", {}) - def _build_channel_summary(message_type: Optional[str], body: Any) -> str: + def _build_channel_summary(message_type: str | None, body: Any) -> str: try: if not isinstance(body, dict): return f"[Misskey WebSocket] 收到消息类型: {message_type}" @@ -334,12 +336,12 @@ def __init__( allow_insecure_downloads: bool = False, download_timeout: int = 15, chunk_size: int = 64 * 1024, - max_download_bytes: Optional[int] = None, + max_download_bytes: int | None = None, ): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self._session: Optional[aiohttp.ClientSession] = None - self.streaming: Optional[StreamingClient] = None + self._session: aiohttp.ClientSession | None = None + self.streaming: StreamingClient | None = None # download options self.allow_insecure_downloads = allow_insecure_downloads self.download_timeout = download_timeout @@ -456,7 +458,7 @@ async def _process_response( retryable_exceptions=(APIConnectionError, APIRateLimitError), ) async def _make_request( - self, endpoint: str, data: Optional[Dict[str, Any]] = None + self, endpoint: str, data: dict[str, Any] | None = None ) -> Any: url = f"{self.instance_url}/api/{endpoint}" payload = {"i": self.access_token} @@ -472,24 +474,24 @@ async def _make_request( async def create_note( self, - text: Optional[str] = None, + text: str | None = None, visibility: str = "public", - reply_id: Optional[str] = None, - visible_user_ids: Optional[List[str]] = None, - file_ids: Optional[List[str]] = None, + reply_id: str | None = None, + visible_user_ids: list[str] | None = None, + file_ids: list[str] | None = None, local_only: bool = False, - cw: Optional[str] = None, - poll: Optional[Dict[str, Any]] = None, - renote_id: Optional[str] = None, - channel_id: Optional[str] = None, - reaction_acceptance: Optional[str] = None, - no_extract_mentions: Optional[bool] = None, - no_extract_hashtags: Optional[bool] = None, - no_extract_emojis: Optional[bool] = None, - media_ids: Optional[List[str]] = None, - ) -> Dict[str, Any]: + cw: str | None = None, + poll: dict[str, Any] | None = None, + renote_id: str | None = None, + channel_id: str | None = None, + reaction_acceptance: str | None = None, + no_extract_mentions: bool | None = None, + no_extract_hashtags: bool | None = None, + no_extract_emojis: bool | None = None, + media_ids: list[str] | None = None, + ) -> dict[str, Any]: """Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API.""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} if text is not None: data["text"] = text @@ -537,9 +539,9 @@ async def create_note( async def upload_file( self, file_path: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, - ) -> Dict[str, Any]: + name: str | None = None, + folder_id: str | None = None, + ) -> dict[str, Any]: """Upload a file to Misskey drive/files/create and return a dict containing id and raw result.""" if not file_path: raise APIError("No file path provided for upload") @@ -574,7 +576,7 @@ async def upload_file( logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e - async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: + async def find_files_by_hash(self, md5_hash: str) -> list[dict[str, Any]]: """Find files by MD5 hash""" if not md5_hash: raise APIError("No MD5 hash provided for find-by-hash") @@ -593,13 +595,13 @@ async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: raise async def find_files_by_name( - self, name: str, folder_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, name: str, folder_id: str | None = None + ) -> list[dict[str, Any]]: """Find files by name""" if not name: raise APIError("No name provided for find") - data: Dict[str, Any] = {"name": name} + data: dict[str, Any] = {"name": name} if folder_id: data["folderId"] = folder_id @@ -617,11 +619,11 @@ async def find_files_by_name( async def find_files( self, limit: int = 10, - folder_id: Optional[str] = None, - type: Optional[str] = None, - ) -> List[Dict[str, Any]]: + folder_id: str | None = None, + type: str | None = None, + ) -> list[dict[str, Any]]: """List files with optional filters""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if folder_id is not None: data["folderId"] = folder_id if type is not None: @@ -642,7 +644,7 @@ async def find_files( async def _download_with_existing_session( self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + ) -> bytes | None: """使用现有会话下载文件""" if not (hasattr(self, "session") and self.session): raise APIConnectionError("No existing session available") @@ -656,7 +658,7 @@ async def _download_with_existing_session( async def _download_with_temp_session( self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + ) -> bytes | None: """使用临时会话下载文件""" connector = aiohttp.TCPConnector(ssl=ssl_verify) async with aiohttp.ClientSession(connector=connector) as temp_session: @@ -670,11 +672,11 @@ async def _download_with_temp_session( async def upload_and_find_file( self, url: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, + name: str | None = None, + folder_id: str | None = None, max_wait_time: float = 30.0, check_interval: float = 2.0, - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """ 简化的文件上传:尝试 URL 上传,失败则下载后本地上传 @@ -732,13 +734,13 @@ async def upload_and_find_file( return None - async def get_current_user(self) -> Dict[str, Any]: + async def get_current_user(self) -> dict[str, Any]: """获取当前用户信息""" return await self._make_request("i", {}) async def send_message( - self, user_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, user_id_or_payload: Any, text: str | None = None + ) -> dict[str, Any]: """发送聊天消息。 Accepts either (user_id: str, text: str) or a single dict payload prepared by caller. @@ -754,8 +756,8 @@ async def send_message( return result async def send_room_message( - self, room_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, room_id_or_payload: Any, text: str | None = None + ) -> dict[str, Any]: """发送房间消息。 Accepts either (room_id: str, text: str) or a single dict payload. @@ -771,10 +773,10 @@ async def send_room_message( return result async def get_messages( - self, user_id: str, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, user_id: str, limit: int = 10, since_id: str | None = None + ) -> list[dict[str, Any]]: """获取聊天消息历史""" - data: Dict[str, Any] = {"userId": user_id, "limit": limit} + data: dict[str, Any] = {"userId": user_id, "limit": limit} if since_id: data["sinceId"] = since_id @@ -785,10 +787,10 @@ async def get_messages( return [] async def get_mentions( - self, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, limit: int = 10, since_id: str | None = None + ) -> list[dict[str, Any]]: """获取提及通知""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if since_id: data["sinceId"] = since_id data["includeTypes"] = ["mention", "reply", "quote"] @@ -806,11 +808,11 @@ async def send_message_with_media( self, message_type: str, target_id: str, - text: Optional[str] = None, - media_urls: Optional[List[str]] = None, - local_files: Optional[List[str]] = None, + text: str | None = None, + media_urls: list[str] | None = None, + local_files: list[str] | None = None, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ 通用消息发送函数:统一处理文本+媒体发送 @@ -846,7 +848,7 @@ async def send_message_with_media( message_type, target_id, text, file_ids, **kwargs ) - async def _process_media_urls(self, urls: List[str]) -> List[str]: + async def _process_media_urls(self, urls: list[str]) -> list[str]: """处理远程媒体文件URL列表,返回文件ID列表""" file_ids = [] for url in urls: @@ -863,7 +865,7 @@ async def _process_media_urls(self, urls: List[str]) -> List[str]: continue return file_ids - async def _process_local_files(self, file_paths: List[str]) -> List[str]: + async def _process_local_files(self, file_paths: list[str]) -> list[str]: """处理本地文件路径列表,返回文件ID列表""" file_ids = [] for file_path in file_paths: @@ -883,10 +885,10 @@ async def _dispatch_message( self, message_type: str, target_id: str, - text: Optional[str], - file_ids: List[str], + text: str | None, + file_ids: list[str], **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """根据消息类型分发到对应的发送方法""" if message_type == "chat": # 聊天消息使用 fileId (单数) diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index cd737f78e..0b4a7b876 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -1,6 +1,6 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import PlatformMetadata, AstrBotMessage diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index ebc95d8d7..135744167 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -1,6 +1,6 @@ """Misskey 平台适配器通用工具函数""" -from typing import Dict, Any, List, Tuple, Optional, Union +from typing import Any import astrbot.api.message_components as Comp from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType @@ -9,7 +9,7 @@ class FileIDExtractor: """从 API 响应中提取文件 ID 的帮助类(无状态)。""" @staticmethod - def extract_file_id(result: Any) -> Optional[str]: + def extract_file_id(result: Any) -> str | None: if not isinstance(result, dict): return None @@ -34,8 +34,8 @@ class MessagePayloadBuilder: @staticmethod def build_chat_payload( - user_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + user_id: str, text: str | None, file_id: str | None = None + ) -> dict[str, Any]: payload = {"toUserId": user_id} if text: payload["text"] = text @@ -45,8 +45,8 @@ def build_chat_payload( @staticmethod def build_room_payload( - room_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + room_id: str, text: str | None, file_id: str | None = None + ) -> dict[str, Any]: payload = {"toRoomId": room_id} if text: payload["text"] = text @@ -56,9 +56,9 @@ def build_room_payload( @staticmethod def build_note_payload( - text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs - ) -> Dict[str, Any]: - payload: Dict[str, Any] = {} + text: str | None, file_ids: list[str] | None = None, **kwargs + ) -> dict[str, Any]: + payload: dict[str, Any] = {} if text: payload["text"] = text if file_ids: @@ -67,7 +67,7 @@ def build_note_payload( return payload -def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: +def serialize_message_chain(chain: list[Any]) -> tuple[str, bool]: """将消息链序列化为文本字符串""" text_parts = [] has_at = False @@ -113,12 +113,12 @@ def process_component(component): def resolve_message_visibility( - user_id: Optional[str] = None, - user_cache: Optional[Dict[str, Any]] = None, - self_id: Optional[str] = None, - raw_message: Optional[Dict[str, Any]] = None, + user_id: str | None = None, + user_cache: dict[str, Any] | None = None, + self_id: str | None = None, + raw_message: dict[str, Any] | None = None, default_visibility: str = "public", -) -> Tuple[str, Optional[List[str]]]: +) -> tuple[str, list[str] | None]: """解析 Misskey 消息的可见性设置 可以从 user_cache 或 raw_message 中解析,支持两种调用方式: @@ -169,13 +169,13 @@ def resolve_message_visibility( # 保留旧函数名作为向后兼容的别名 def resolve_visibility_from_raw_message( - raw_message: Dict[str, Any], self_id: Optional[str] = None -) -> Tuple[str, Optional[List[str]]]: + raw_message: dict[str, Any], self_id: str | None = None +) -> tuple[str, list[str] | None]: """从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)""" return resolve_message_visibility(raw_message=raw_message, self_id=self_id) -def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_user_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -189,7 +189,7 @@ def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_room_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的房间 session_id (仅限room%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -203,7 +203,7 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_chat_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -236,7 +236,7 @@ def extract_room_id_from_session_id(session_id: str) -> str: def add_at_mention_if_needed( - text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False + text: str, user_info: dict[str, Any] | None, has_at: bool = False ) -> str: """如果需要且没有@用户,则添加@用户 @@ -258,7 +258,7 @@ def add_at_mention_if_needed( return text -def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]: +def create_file_component(file_info: dict[str, Any]) -> tuple[Any, str]: """创建文件组件和描述文本""" file_url = file_info.get("url", "") file_name = file_info.get("name", "未知文件") @@ -287,7 +287,7 @@ def process_files( return file_parts -def format_poll(poll: Dict[str, Any]) -> str: +def format_poll(poll: dict[str, Any]) -> str: """将 Misskey 的 poll 对象格式化为可读字符串。""" if not poll or not isinstance(poll, dict): return "" @@ -304,8 +304,8 @@ def format_poll(poll: Dict[str, Any]) -> str: def extract_sender_info( - raw_data: Dict[str, Any], is_chat: bool = False -) -> Dict[str, Any]: + raw_data: dict[str, Any], is_chat: bool = False +) -> dict[str, Any]: """提取发送者信息""" if is_chat: sender = raw_data.get("fromUser", {}) @@ -323,11 +323,11 @@ def extract_sender_info( def create_base_message( - raw_data: Dict[str, Any], - sender_info: Dict[str, Any], + raw_data: dict[str, Any], + sender_info: dict[str, Any], client_self_id: str, is_chat: bool = False, - room_id: Optional[str] = None, + room_id: str | None = None, unique_session: bool = False, ) -> AstrBotMessage: """创建基础消息对象""" @@ -367,7 +367,7 @@ def create_base_message( def process_at_mention( message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str -) -> Tuple[List[str], str]: +) -> tuple[list[str], str]: """处理@提及逻辑,返回消息部分列表和处理后的文本""" message_parts = [] @@ -389,9 +389,9 @@ def process_at_mention( def cache_user_info( - user_cache: Dict[str, Any], - sender_info: Dict[str, Any], - raw_data: Dict[str, Any], + user_cache: dict[str, Any], + sender_info: dict[str, Any], + raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, ): @@ -417,7 +417,7 @@ def cache_user_info( def cache_room_info( - user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str + user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str ): """缓存房间信息""" room_data = raw_data.get("toRoom") @@ -437,7 +437,7 @@ def cache_room_info( async def resolve_component_url_or_path( comp: Any, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """尝试从组件解析可上传的远程 URL 或本地路径。 返回 (url_candidate, local_path)。两者可能都为 None。 @@ -503,7 +503,7 @@ async def _get_str_value(coro_or_val): return url_candidate, local_path -def summarize_component_for_log(comp: Any) -> Dict[str, Any]: +def summarize_component_for_log(comp: Any) -> dict[str, Any]: """生成适合日志的组件属性字典(尽量不抛异常)。""" attrs = {} for a in ("file", "url", "path", "src", "source", "name"): @@ -519,9 +519,9 @@ def summarize_component_for_log(comp: Any) -> Dict[str, Any]: async def upload_local_with_retries( api: Any, local_path: str, - preferred_name: Optional[str], - folder_id: Optional[str], -) -> Optional[str]: + preferred_name: str | None, + folder_id: str | None, +) -> str | None: """尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。""" try: res = await api.upload_file(local_path, preferred_name, folder_id) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2096237ce..8dc061812 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -16,7 +16,6 @@ from astrbot.api import logger from botpy.types.message import Media from botpy.types import message -from typing import Optional import random import uuid import os @@ -196,7 +195,7 @@ async def upload_group_and_c2c_image( async def upload_group_and_c2c_record( self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs - ) -> Optional[Media]: + ) -> Media | None: """ 上传媒体文件 """ diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index d5285f759..9acd822b3 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -19,7 +19,6 @@ ) from astrbot import logger from astrbot.api.event import MessageChain -from typing import Union, List from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .qqofficial_message_event import QQOfficialMessageEvent @@ -33,7 +32,7 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialPlatformAdapter"): + def set_platform(self, platform: QQOfficialPlatformAdapter): self.platform = platform # 收到群消息 @@ -133,7 +132,7 @@ def meta(self) -> PlatformMetadata: @staticmethod def _parse_from_qqofficial( - message: Union[botpy.message.Message, botpy.message.GroupMessage], + message: botpy.message.Message | botpy.message.GroupMessage, message_type: MessageType, ): abm = AstrBotMessage() @@ -142,7 +141,7 @@ def _parse_from_qqofficial( abm.raw_message = message abm.message_id = message.id abm.tag = "qq_official" - msg: List[BaseMessageComponent] = [] + msg: list[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance( message, botpy.message.C2CMessage diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index a3f4f53ec..06de8bd7f 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -3,7 +3,6 @@ import time import websockets from websockets.asyncio.client import connect -from typing import Optional from aiohttp import ClientSession, ClientTimeout from websockets.asyncio.client import ClientConnection from astrbot.api import logger @@ -57,12 +56,12 @@ def __init__( id=self.config["id"], ) - self.ws: Optional[ClientConnection] = None - self.session: Optional[ClientSession] = None + self.ws: ClientConnection | None = None + self.session: ClientSession | None = None self.sequence = 0 self.logins = [] self.running = False - self.heartbeat_task: Optional[asyncio.Task] = None + self.heartbeat_task: asyncio.Task | None = None self.ready_received = False async def send_by_session( @@ -295,10 +294,10 @@ async def convert_satori_message( message: dict, user: dict, channel: dict, - guild: Optional[dict], + guild: dict | None, login: dict, - timestamp: Optional[int] = None, - ) -> Optional[AstrBotMessage]: + timestamp: int | None = None, + ) -> AstrBotMessage | None: try: abm = AstrBotMessage() abm.message_id = message.get("id", "") @@ -438,7 +437,7 @@ def _extract_namespace_prefixes(self, content: str) -> set: return prefixes - async def _extract_quote_element(self, content: str) -> Optional[dict]: + async def _extract_quote_element(self, content: str) -> dict | None: """提取标签信息""" try: # 处理命名空间前缀问题 @@ -506,7 +505,7 @@ async def _extract_quote_element(self, content: str) -> Optional[dict]: logger.error(f"提取标签时发生错误: {e}") return None - async def _extract_quote_with_regex(self, content: str) -> Optional[dict]: + async def _extract_quote_with_regex(self, content: str) -> dict | None: """使用正则表达式提取quote标签信息""" import re @@ -529,7 +528,7 @@ async def _extract_quote_with_regex(self, content: str) -> Optional[dict]: "content_without_quote": content_without_quote, } - async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]: + async def _convert_quote_message(self, quote: dict) -> AstrBotMessage | None: """转换引用消息""" try: quote_abm = AstrBotMessage() diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 7877e4f52..74b4570ff 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -3,7 +3,8 @@ import hashlib import asyncio import logging -from typing import Callable, Optional + +from collections.abc import Callable from quart import Quart, request, Response from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.socket_mode.aiohttp import SocketModeClient @@ -22,7 +23,7 @@ def __init__( host: str = "0.0.0.0", port: int = 3000, path: str = "/slack/events", - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.signing_secret = signing_secret @@ -119,7 +120,7 @@ def __init__( self, web_client: AsyncWebClient, app_token: str, - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.app_token = app_token diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 7e75f3c20..4ad8433cf 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -4,7 +4,9 @@ import aiohttp import re import base64 -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.socket_mode.request import SocketModeRequest from astrbot.api.platform import ( diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 86f9f9764..692c03be2 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -1,6 +1,6 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from slack_sdk.web.async_client import AsyncWebClient from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index faec122ac..16de7edfa 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -2,7 +2,9 @@ import asyncio import uuid import os -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.platform import ( Platform, AstrBotMessage, diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 6b835ecb5..be979a78e 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -4,7 +4,6 @@ import os import traceback import time -from typing import Optional import aiohttp import anyio @@ -137,7 +136,7 @@ def load_credentials(self): """ if os.path.exists(self.credentials_file): try: - with open(self.credentials_file, "r") as f: + with open(self.credentials_file) as f: credentials = json.load(f) logger.info("成功加载 WeChatPadPro 凭据。") return credentials @@ -540,7 +539,7 @@ async def _process_chat_type( async def _get_group_member_nickname( self, group_id: str, member_wxid: str - ) -> Optional[str]: + ) -> str | None: """ 通过接口获取群成员的昵称。 """ @@ -896,7 +895,7 @@ async def get_contact_list(self): async def get_contact_details_list( self, room_wx_id_list: list[str] = None, user_names: list[str] = None - ) -> Optional[dict]: + ) -> dict | None: """ 获取联系人详情列表。 """ diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py index 118667975..8d838da97 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 5332942b9..4a0508051 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- encoding:utf-8 -*- """对企业微信发送给企业后台的消息加解密示例代码. @copyright: Copyright (c) 1998-2020 Tencent Inc. @@ -136,7 +135,7 @@ def decode(self, decrypted): return decrypted[:-pad] -class Prpcrypt(object): +class Prpcrypt: """提供接收和推送给企业微信消息的加解密接口""" def __init__(self, key): @@ -210,7 +209,7 @@ def get_random_str(self): return str(random.randint(1000000000000000, 9999999999999999)).encode() -class WXBizJsonMsgCrypt(object): +class WXBizJsonMsgCrypt: # 构造函数 def __init__(self, sToken, sEncodingAESKey, sReceiveId): try: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py index cc1bf221e..0df14a505 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- ######################################################################### # Author: jonyqin # Created Time: Thu 11 Sep 2014 01:53:58 PM CST diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 830d8de58..3198dc28f 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -9,7 +9,9 @@ import uuid import hashlib import base64 -from typing import Awaitable, Any, Dict, Optional, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.api.platform import ( @@ -151,8 +153,8 @@ async def _handle_queued_message(self, data: dict): logger.error(f"处理队列消息时发生异常: {e}") async def _process_message( - self, message_data: Dict[str, Any], callback_params: Dict[str, str] - ) -> Optional[str]: + self, message_data: dict[str, Any], callback_params: dict[str, str] + ) -> str | None: """处理接收到的消息 Args: @@ -278,15 +280,15 @@ async def _process_message( return None pass - def _extract_session_id(self, message_data: Dict[str, Any]) -> str: + def _extract_session_id(self, message_data: dict[str, Any]) -> str: """从消息数据中提取会话ID""" user_id = message_data.get("from", {}).get("userid", "default_user") return format_session_id("wecomai", user_id) async def _enqueue_message( self, - message_data: Dict[str, Any], - callback_params: Dict[str, str], + message_data: dict[str, Any], + callback_params: dict[str, str], stream_id: str, session_id: str, ): diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 540bf06b6..ea34dee32 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -6,7 +6,7 @@ import json import base64 import hashlib -from typing import Dict, Any, Optional, Tuple, Union +from typing import Any from Crypto.Cipher import AES import aiohttp @@ -31,7 +31,7 @@ def __init__(self, token: str, encoding_aes_key: str): async def decrypt_message( self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str - ) -> Tuple[int, Optional[Dict[str, Any]]]: + ) -> tuple[int, dict[str, Any] | None]: """解密企业微信消息 Args: @@ -71,7 +71,7 @@ async def decrypt_message( async def encrypt_message( self, plain_message: str, nonce: str, timestamp: str - ) -> Optional[str]: + ) -> str | None: """加密消息 Args: @@ -127,8 +127,8 @@ def verify_url( return "verify fail" async def process_encrypted_image( - self, image_url: str, aes_key_base64: Optional[str] = None - ) -> Tuple[bool, Union[bytes, str]]: + self, image_url: str, aes_key_base64: str | None = None + ) -> tuple[bool, bytes | str]: """下载并解密加密图片 Args: @@ -292,7 +292,7 @@ class WecomAIBotMessageParser: """企业微信智能机器人消息解析器""" @staticmethod - def parse_text_message(data: Dict[str, Any]) -> Optional[str]: + def parse_text_message(data: dict[str, Any]) -> str | None: """解析文本消息 Args: @@ -308,7 +308,7 @@ def parse_text_message(data: Dict[str, Any]) -> Optional[str]: return None @staticmethod - def parse_image_message(data: Dict[str, Any]) -> Optional[str]: + def parse_image_message(data: dict[str, Any]) -> str | None: """解析图片消息 Args: @@ -324,7 +324,7 @@ def parse_image_message(data: Dict[str, Any]) -> Optional[str]: return None @staticmethod - def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_stream_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析流消息 Args: @@ -346,7 +346,7 @@ def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: return None @staticmethod - def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: + def parse_mixed_message(data: dict[str, Any]) -> list | None: """解析混合消息 Args: @@ -362,7 +362,7 @@ def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: return None @staticmethod - def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_event_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析事件消息 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 1367301c9..3086e9513 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -5,7 +5,7 @@ """ import asyncio -from typing import Dict, Any, Optional +from typing import Any from astrbot.api import logger @@ -13,13 +13,13 @@ class WecomAIQueueMgr: """企业微信智能机器人队列管理器""" def __init__(self) -> None: - self.queues: Dict[str, asyncio.Queue] = {} + self.queues: dict[str, asyncio.Queue] = {} """StreamID 到输入队列的映射 - 用于接收用户消息""" - self.back_queues: Dict[str, asyncio.Queue] = {} + self.back_queues: dict[str, asyncio.Queue] = {} """StreamID 到输出队列的映射 - 用于发送机器人响应""" - self.pending_responses: Dict[str, Dict[str, Any]] = {} + self.pending_responses: dict[str, dict[str, Any]] = {} """待处理的响应缓存,用于流式响应""" def get_or_create_queue(self, session_id: str) -> asyncio.Queue: @@ -90,7 +90,7 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: Dict[str, str]): + def set_pending_response(self, session_id: str, callback_params: dict[str, str]): """设置待处理的响应参数 Args: @@ -103,7 +103,7 @@ def set_pending_response(self, session_id: str, callback_params: Dict[str, str]) } logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") - def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]: + def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """获取待处理的响应参数 Args: @@ -131,7 +131,7 @@ def cleanup_expired_responses(self, max_age_seconds: int = 300): del self.pending_responses[session_id] logger.debug(f"[WecomAI] 清理过期响应: {session_id}") - def get_stats(self) -> Dict[str, int]: + def get_stats(self) -> dict[str, int]: """获取队列统计信息 Returns: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index bbb69d041..27b78dd78 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -4,7 +4,9 @@ """ import asyncio -from typing import Dict, Any, Optional, Callable +from typing import Any + +from collections.abc import Callable import quart from astrbot.api import logger @@ -21,9 +23,8 @@ def __init__( host: str, port: int, api_client: WecomAIBotAPIClient, - message_handler: Optional[ - Callable[[Dict[str, Any], Dict[str, str]], Any] - ] = None, + message_handler: None + | (Callable[[dict[str, Any], dict[str, str]], Any]) = None, ): """初始化服务器 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index dccb2e260..cc4361581 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -10,7 +10,7 @@ import aiohttp import asyncio from Crypto.Cipher import AES -from typing import Any, Tuple +from typing import Any from astrbot.api import logger @@ -91,7 +91,7 @@ def format_session_id(session_type: str, session_id: str) -> str: return f"wecom_ai_bot_{session_type}_{session_id}" -def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: +def parse_session_id(formatted_session_id: str) -> tuple[str, str]: """解析格式化的会话 ID Args: @@ -145,7 +145,7 @@ def format_error_response(error_code: int, error_msg: str) -> str: async def process_encrypted_image( image_url: str, aes_key_base64: str -) -> Tuple[bool, str]: +) -> tuple[bool, str]: """下载并解密加密图片 Args: diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 85687c417..f59365283 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -4,7 +4,7 @@ from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field -from typing import List, Dict, Type, Any +from typing import Any from astrbot.core.agent.tool import ToolSet from openai.types.chat.chat_completion import ChatCompletion from google.genai.types import GenerateContentResponse @@ -32,7 +32,7 @@ class ProviderMetaData: desc: str = "" """提供商适配器描述.""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: Type | None = None + cls_type: type | None = None default_config_tmpl: dict | None = None """平台的默认配置模板""" @@ -61,7 +61,7 @@ class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" content: str | None = None - tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) + tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list) role: str = "assistant" def to_dict(self): @@ -84,10 +84,10 @@ class ToolCallsResult: tool_calls_info: AssistantMessageSegment """函数调用的信息""" - tool_calls_result: List[ToolCallMessageSegment] + tool_calls_result: list[ToolCallMessageSegment] """函数调用的结果""" - def to_openai_messages(self) -> List[Dict]: + def to_openai_messages(self) -> list[dict]: ret = [ self.tool_calls_info.to_dict(), *[item.to_dict() for item in self.tool_calls_result], @@ -175,7 +175,7 @@ def _print_friendly_context(self): return result_parts - async def assemble_context(self) -> Dict: + async def assemble_context(self) -> dict: """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" if self.image_urls: user_content = { @@ -219,15 +219,15 @@ class LLMResponse: """角色, assistant, tool, err""" result_chain: MessageChain | None = None """返回的消息链""" - tools_call_args: List[Dict[str, Any]] = field(default_factory=list) + tools_call_args: list[dict[str, Any]] = field(default_factory=list) """工具调用参数""" - tools_call_name: List[str] = field(default_factory=list) + tools_call_name: list[str] = field(default_factory=list) """工具调用名称""" - tools_call_ids: List[str] = field(default_factory=list) + tools_call_ids: list[str] = field(default_factory=list) """工具调用 ID""" raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None - _new_record: Dict[str, Any] | None = None + _new_record: dict[str, Any] | None = None _completion_text: str = "" @@ -239,11 +239,11 @@ def __init__( role: str, completion_text: str = "", result_chain: MessageChain | None = None, - tools_call_args: List[Dict[str, Any]] | None = None, - tools_call_name: List[str] | None = None, - tools_call_ids: List[str] | None = None, + tools_call_args: list[dict[str, Any]] | None = None, + tools_call_name: list[str] | None = None, + tools_call_ids: list[str] | None = None, raw_completion: ChatCompletion | None = None, - _new_record: Dict[str, Any] | None = None, + _new_record: dict[str, Any] | None = None, is_chunk: bool = False, ): """初始化 LLMResponse @@ -291,7 +291,7 @@ def completion_text(self, value): else: self._completion_text = value - def to_openai_tool_calls(self) -> List[Dict]: + def to_openai_tool_calls(self) -> list[dict]: """将工具调用信息转换为 OpenAI 格式""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 51cde0eb9..ceba07076 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,7 +4,9 @@ import asyncio import aiohttp -from typing import Dict, List, Awaitable, Callable, Any +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot import logger from astrbot.core import sp @@ -96,10 +98,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class FunctionToolManager: def __init__(self) -> None: - self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str, MCPClient] = {} + self.func_list: list[FuncTool] = [] + self.mcp_client_dict: dict[str, MCPClient] = {} """MCP 服务列表""" - self.mcp_client_event: Dict[str, asyncio.Event] = {} + self.mcp_client_event: dict[str, asyncio.Event] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -202,8 +204,8 @@ async def init_mcp_clients(self) -> None: logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return - mcp_server_json_obj: Dict[str, Dict] = json.load( - open(mcp_json_file, "r", encoding="utf-8") + mcp_server_json_obj: dict[str, dict] = json.load( + open(mcp_json_file, encoding="utf-8") )["mcpServers"] for name in mcp_server_json_obj.keys(): @@ -479,7 +481,7 @@ def load_mcp_config(self): return DEFAULT_MCP_CONFIG try: - with open(self.mcp_config_path, "r", encoding="utf-8") as f: + with open(self.mcp_config_path, encoding="utf-8") as f: return json.load(f) except Exception as e: logger.error(f"加载 MCP 配置失败: {e}") diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 9953e9f17..8e48a11a3 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,7 +1,6 @@ import abc import asyncio -from typing import List -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.core.agent.tool import ToolSet from astrbot.core.provider.entities import ( LLMResponse, @@ -67,7 +66,7 @@ def __init__( def get_current_key(self) -> str: raise NotImplementedError() - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: """获得提供商 Key""" keys = self.provider_config.get("key", [""]) return keys or [""] @@ -77,7 +76,7 @@ def set_key(self, key: str): raise NotImplementedError() @abc.abstractmethod - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: """获得支持的模型列表""" raise NotImplementedError() @@ -140,7 +139,7 @@ async def text_chat_stream( """ ... - async def pop_record(self, context: List): + async def pop_record(self, context: list): """ 弹出 context 第一条非系统提示词对话记录 """ diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 02d7934d1..31b82bb8a 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,11 +1,10 @@ -from typing import List, Dict from .entities import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall -provider_registry: List[ProviderMetaData] = [] +provider_registry: list[ProviderMetaData] = [] """维护了通过装饰器注册的 Provider""" -provider_cls_map: Dict[str, ProviderMetaData] = {} +provider_cls_map: dict[str, ProviderMetaData] = {} """维护了 Provider 类型名称和 ProviderMetadata 的映射""" llm_tools = FuncCall() diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index cd4206ce7..8b3a8772a 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,7 +1,6 @@ import json import anthropic import base64 -from typing import List from mimetypes import guess_type from anthropic import AsyncAnthropic @@ -13,7 +12,7 @@ from astrbot.core.provider.func_tool_manager import ToolSet from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse -from typing import AsyncGenerator +from collections.abc import AsyncGenerator @register_provider_adapter( @@ -33,7 +32,7 @@ def __init__( ) self.chosen_api_key: str = "" - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) @@ -326,7 +325,7 @@ async def text_chat_stream( async for llm_response in self._query_stream(payloads, func_tool): yield llm_response - async def assemble_context(self, text: str, image_urls: List[str] | None = None): + async def assemble_context(self, text: str, image_urls: list[str] | None = None): """组装上下文,支持文本和图片""" if not image_urls: return {"role": "user", "content": text} @@ -384,7 +383,7 @@ async def encode_image_bs64(self, image_url: str) -> str: def get_current_key(self) -> str: return self.chosen_api_key - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: models_str = [] models = await self.client.models.list() models = sorted(models.data, key=lambda x: x.id) diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 6ddf452d4..5ef101b20 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -6,7 +6,6 @@ import random import asyncio from pathlib import Path -from typing import Dict from xml.sax.saxutils import escape from httpx import AsyncClient, Timeout @@ -21,7 +20,7 @@ class OTTSProvider: - def __init__(self, config: Dict): + def __init__(self, config: dict): self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -58,7 +57,7 @@ async def _generate_signature(self) -> str: path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/" return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}" - async def get_audio(self, text: str, voice_params: Dict) -> str: + async def get_audio(self, text: str, voice_params: dict) -> str: file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav" signature = await self._generate_signature() for attempt in range(self.retry_count): diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/provider/sources/coze_api_client.py index a768979c6..3fe02d8e5 100644 --- a/astrbot/core/provider/sources/coze_api_client.py +++ b/astrbot/core/provider/sources/coze_api_client.py @@ -2,7 +2,9 @@ import asyncio import aiohttp import io -from typing import Dict, List, Any, AsyncGenerator +from typing import Any + +from collections.abc import AsyncGenerator from astrbot.core import logger @@ -117,12 +119,12 @@ async def chat_messages( self, bot_id: str, user_id: str, - additional_messages: List[Dict] | None = None, + additional_messages: list[dict] | None = None, conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, timeout: float = 120, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 Args: diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py index 639af0814..93887900c 100644 --- a/astrbot/core/provider/sources/coze_source.py +++ b/astrbot/core/provider/sources/coze_source.py @@ -2,7 +2,8 @@ import os import base64 import hashlib -from typing import AsyncGenerator, Dict + +from collections.abc import AsyncGenerator from astrbot.core.message.message_event_result import MessageChain import astrbot.core.message.components as Comp from astrbot.api.provider import Provider @@ -44,8 +45,8 @@ def __init__( if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.auto_save_history = provider_config.get("auto_save_history", True) - self.conversation_ids: Dict[str, str] = {} - self.file_id_cache: Dict[str, Dict[str, str]] = {} + self.conversation_ids: dict[str, str] = {} + self.file_id_cache: dict[str, dict[str, str]] = {} # 创建 API 客户端 self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index efda31ca9..54e043f6a 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,6 @@ import logging import os import uuid -from typing import Optional, Tuple import aiohttp import dashscope from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer @@ -80,7 +79,7 @@ def _call_qwen_tts(self, model: str, text: str): async def _synthesize_with_qwen_tts( self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, self._call_qwen_tts, model, text) audio_bytes = await self._extract_audio_from_response(response) @@ -91,7 +90,7 @@ async def _synthesize_with_qwen_tts( ext = ".wav" return audio_bytes, ext - async def _extract_audio_from_response(self, response) -> Optional[bytes]: + async def _extract_audio_from_response(self, response) -> bytes | None: output = getattr(response, "output", None) audio_obj = getattr(output, "audio", None) if output is not None else None if not audio_obj: @@ -110,7 +109,7 @@ async def _extract_audio_from_response(self, response) -> Optional[bytes]: return await self._download_audio_from_url(url) return None - async def _download_audio_from_url(self, url: str) -> Optional[bytes]: + async def _download_audio_from_url(self, url: str) -> bytes | None: if not url: return None timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20 @@ -126,7 +125,7 @@ async def _download_audio_from_url(self, url: str) -> Optional[bytes]: async def _synthesize_with_cosyvoice( self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: synthesizer = SpeechSynthesizer( model=model, voice=self.voice, diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b14a9bdcb..1865741e0 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,6 @@ import json import logging import random -from typing import Optional, List from collections.abc import AsyncGenerator from google import genai @@ -60,11 +59,11 @@ def __init__( provider_settings, default_persona, ) - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else "" self.timeout: int = int(provider_config.get("timeout", 180)) - self.api_base: Optional[str] = provider_config.get("api_base", None) + self.api_base: str | None = provider_config.get("api_base", None) if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] @@ -122,9 +121,9 @@ async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: async def _prepare_query_config( self, payloads: dict, - tools: Optional[ToolSet] = None, - system_instruction: Optional[str] = None, - modalities: Optional[list[str]] = None, + tools: ToolSet | None = None, + system_instruction: str | None = None, + modalities: list[str] | None = None, temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" @@ -406,7 +405,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: conversation = self._prepare_conversation(payloads) temperature = payloads.get("temperature", 0.7) - result: Optional[types.GenerateContentResponse] = None + result: types.GenerateContentResponse | None = None while True: try: config = await self._prepare_query_config( diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 5b210835b..66901497f 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -2,7 +2,8 @@ import os import uuid import aiohttp -from typing import Dict, List, Union, AsyncIterator + +from collections.abc import AsyncIterator from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.api import logger from ..entities import ProviderType @@ -30,7 +31,7 @@ def __init__( self.is_timber_weight: bool = provider_config.get( "minimax-is-timber-weight", False ) - self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads( + self.timber_weight: list[dict[str, str | int]] = json.loads( provider_config.get( "minimax-timber-weight", '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]', @@ -66,7 +67,7 @@ def __init__( def _build_tts_stream_body(self, text: str): """构建流式请求体""" - dict_body: Dict[str, object] = { + dict_body: dict[str, object] = { "model": self.model_name, "text": text, "stream": True, diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 09c284acb..6174ca9d9 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -17,7 +17,8 @@ from astrbot.api.provider import Provider from astrbot import logger from astrbot.core.provider.func_tool_manager import ToolSet -from typing import List, AsyncGenerator + +from collections.abc import AsyncGenerator from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse, ToolCallsResult @@ -38,7 +39,7 @@ def __init__( default_persona, ) self.chosen_api_key = None - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): @@ -280,7 +281,7 @@ async def _handle_api_error( context_query: list, func_tool: ToolSet, chosen_key: str, - available_api_keys: List[str], + available_api_keys: list[str], retry_cnt: int, max_retries: int, ) -> tuple: @@ -497,7 +498,7 @@ async def text_chat_stream( raise Exception("未知错误") raise last_exception - async def _remove_image_from_context(self, contexts: List): + async def _remove_image_from_context(self, contexts: list): """ 从上下文中删除所有带有 image 的记录 """ @@ -521,14 +522,14 @@ async def _remove_image_from_context(self, contexts: List): def get_current_key(self) -> str: return self.client.api_key - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key): self.client.api_key = key async def assemble_context( - self, text: str, image_urls: List[str] | None = None + self, text: str, image_urls: list[str] | None = None ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index 23a522dc1..8d9acbcb2 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -2,13 +2,12 @@ 此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta """ -from typing import Union import os import json from astrbot.core.utils.astrbot_path import get_astrbot_data_path -def load_config(namespace: str) -> Union[dict, bool]: +def load_config(namespace: str) -> dict | bool: """ 从配置文件中加载配置。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 @@ -17,7 +16,7 @@ def load_config(namespace: str) -> Union[dict, bool]: path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): return False - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: ret = {} data = json.load(f) for k in data: @@ -51,7 +50,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): if not os.path.exists(path): with open(path, "w", encoding="utf-8-sig") as f: f.write("{}") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: @@ -78,7 +77,7 @@ def update_config(namespace: str, key: str, value): path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0229f4dbb..ce795d767 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,5 +1,4 @@ from asyncio import Queue -from typing import List, Union from astrbot.core.provider.provider import ( Provider, @@ -25,7 +24,9 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, @@ -42,7 +43,7 @@ class Context: registered_web_apis: list = [] # back compatibility - _register_tasks: List[Awaitable] = [] + _register_tasks: list[Awaitable] = [] _star_manager = None def __init__( @@ -78,7 +79,7 @@ def get_registered_star(self, star_name: str) -> StarMetadata | None: if star.name == star_name: return star - def get_all_stars(self) -> List[StarMetadata]: + def get_all_stars(self) -> list[StarMetadata]: """获取当前载入的所有插件 Metadata 的列表""" return star_registry @@ -116,19 +117,19 @@ def get_provider_by_id( prov = self.provider_manager.inst_map.get(provider_id) return prov - def get_all_providers(self) -> List[Provider]: + def get_all_providers(self) -> list[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.provider_insts - def get_all_tts_providers(self) -> List[TTSProvider]: + def get_all_tts_providers(self) -> list[TTSProvider]: """获取所有用于 TTS 任务的 Provider。""" return self.provider_manager.tts_provider_insts - def get_all_stt_providers(self) -> List[STTProvider]: + def get_all_stt_providers(self) -> list[STTProvider]: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_all_embedding_providers(self) -> List[EmbeddingProvider]: + def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts @@ -196,9 +197,7 @@ def get_event_queue(self) -> Queue: return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") - def get_platform( - self, platform_type: Union[PlatformAdapterType, str] - ) -> Platform | None: + def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: """ 获取指定类型的平台适配器。 @@ -231,7 +230,7 @@ def get_platform_inst(self, platform_id: str) -> Platform | None: return platform async def send_message( - self, session: Union[str, MessageSesion], message_chain: MessageChain + self, session: str | MessageSesion, message_chain: MessageChain ) -> bool: """ 根据 session(unified_msg_origin) 主动发送消息。 diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 3d67cb750..9adc528ee 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -2,7 +2,7 @@ import inspect import types import typing -from typing import List, Any, Type, Dict +from typing import Any from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig @@ -37,14 +37,14 @@ def __init__( command_name: str, alias: set | None = None, handler_md: StarHandlerMetadata | None = None, - parent_command_names: List[str] = [""], + parent_command_names: list[str] = [""], ): self.command_name = command_name self.alias = alias if alias else set() self.parent_command_names = parent_command_names if handler_md: self.init_handler_md(handler_md) - self.custom_filter_list: List[CustomFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] # Cache for complete command names list self._cmpl_cmd_names: list | None = None @@ -89,8 +89,8 @@ def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: return True def validate_and_convert_params( - self, params: List[Any], param_type: Dict[str, Type] - ) -> Dict[str, Any]: + self, params: list[Any], param_type: dict[str, type] + ) -> dict[str, Any]: """将参数列表 params 根据 param_type 转换为参数字典。""" result = {} param_items = list(param_type.items()) @@ -111,7 +111,7 @@ def validate_and_convert_params( # 没有 GreedyStr 的情况 if i >= len(params): if ( - isinstance(param_type_or_default_val, (Type, types.UnionType)) + isinstance(param_type_or_default_val, (type, types.UnionType)) or typing.get_origin(param_type_or_default_val) is typing.Union or param_type_or_default_val is inspect.Parameter.empty ): diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index e01fa2c58..2f11f193b 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import List, Union from . import HandlerFilter from .command import CommandFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -18,22 +17,22 @@ def __init__( ): self.group_name = group_name self.alias = alias if alias else set() - self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] - self.custom_filter_list: List[CustomFilter] = [] + self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] self.parent_group = parent_group # Cache for complete command names list self._cmpl_cmd_names: list | None = None def add_sub_command_filter( - self, sub_command_filter: Union[CommandFilter, CommandGroupFilter] + self, sub_command_filter: CommandFilter | CommandGroupFilter ): self.sub_command_filters.append(sub_command_filter) def add_custom_filter(self, custom_filter: CustomFilter): self.custom_filter_list.append(custom_filter) - def get_complete_command_names(self) -> List[str]: + def get_complete_command_names(self) -> list[str]: """遍历父节点获取完整的指令名。 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。""" @@ -59,7 +58,7 @@ def get_complete_command_names(self) -> List[str]: # 以树的形式打印出来 def print_cmd_tree( self, - sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], + sub_command_filters: list[CommandFilter | CommandGroupFilter], prefix: str = "", event: AstrMessageEvent | None = None, cfg: AstrBotConfig | None = None, diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index d1c5a6dce..c61e09a86 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -12,7 +12,9 @@ from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools from astrbot.core.agent.agent import Agent @@ -220,9 +222,7 @@ def decorator(obj): class RegisteringCommandable: """用于指令组级联注册""" - group: Callable[..., Callable[..., "RegisteringCommandable"]] = ( - register_command_group - ) + group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., None]] = register_custom_filter diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index 94a0c8a4d..313ab7c8c 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -3,7 +3,6 @@ """ from astrbot.core import sp, logger -from typing import Dict, List from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -95,7 +94,7 @@ def set_plugin_status_for_session( ) @staticmethod - def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]: + def get_session_plugin_config(session_id: str) -> dict[str, list[str]]: """获取指定会话的插件配置 Args: @@ -112,7 +111,7 @@ def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]: ) @staticmethod - def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List: + def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list: """根据会话配置过滤处理器列表 Args: diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 80b5adb60..5dc076bca 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,7 +1,9 @@ from __future__ import annotations import enum from dataclasses import dataclass, field -from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic +from typing import Any, TypeVar, Generic + +from collections.abc import Callable, Awaitable from .filter import HandlerFilter from .star import star_map @@ -10,8 +12,8 @@ class StarHandlerRegistry(Generic[T]): def __init__(self): - self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} - self._handlers: List[StarHandlerMetadata] = [] + self.star_handlers_map: dict[str, StarHandlerMetadata] = {} + self._handlers: list[StarHandlerMetadata] = [] def append(self, handler: StarHandlerMetadata): """添加一个 Handler,并保持按优先级有序""" @@ -31,7 +33,7 @@ def get_handlers_by_event_type( event_type: EventType, only_activated=True, plugins_name: list[str] | None = None, - ) -> List[StarHandlerMetadata]: + ) -> list[StarHandlerMetadata]: handlers = [] for handler in self._handlers: # 过滤事件类型 @@ -65,7 +67,7 @@ def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None def get_handlers_by_module_name( self, module_name: str - ) -> List[StarHandlerMetadata]: + ) -> list[StarHandlerMetadata]: return [ handler for handler in self._handlers @@ -126,7 +128,7 @@ class StarHandlerMetadata: handler: Callable[..., Awaitable[Any]] """Handler 的函数对象,应当是一个异步函数""" - event_filters: List[HandlerFilter] + event_filters: list[HandlerFilter] """一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件""" desc: str = "" diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 6f9dfe2fa..d37e79189 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -22,7 +22,9 @@ import os import uuid from pathlib import Path -from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar +from typing import Any, ClassVar + +from collections.abc import Awaitable, Callable from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType @@ -44,7 +46,7 @@ class StarTools: 这些方法封装了一些常用操作,使插件开发更加简单便捷! """ - _context: ClassVar[Optional[Context]] = None + _context: ClassVar[Context | None] = None @classmethod def initialize(cls, context: Context) -> None: @@ -58,7 +60,7 @@ def initialize(cls, context: Context) -> None: @classmethod async def send_message( - cls, session: Union[str, MessageSesion], message_chain: MessageChain + cls, session: str | MessageSesion, message_chain: MessageChain ) -> bool: """ 根据session(unified_msg_origin)主动发送消息 @@ -122,7 +124,7 @@ async def create_message( self_id: str, session_id: str, sender: MessageMember, - message: List[BaseMessageComponent], + message: list[BaseMessageComponent], message_str: str, message_id: str = "", raw_message: object = None, @@ -254,7 +256,7 @@ def unregister_llm_tool(cls, name: str) -> None: cls._context.unregister_llm_tool(name) @classmethod - def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: + def get_data_dir(cls, plugin_name: str | None = None) -> Path: """ 返回插件数据目录的绝对路径。 diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 15a6b71fb..18be474c8 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -2,7 +2,9 @@ import json from astrbot.core import logger from aiohttp import ClientSession, ClientResponse -from typing import Dict, List, Any, AsyncGenerator +from typing import Any + +from collections.abc import AsyncGenerator async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: @@ -39,14 +41,14 @@ def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): async def chat_messages( self, - inputs: Dict, + inputs: dict, query: str, user: str, response_mode: str = "streaming", conversation_id: str = "", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") @@ -65,10 +67,10 @@ async def chat_messages( async def workflow_run( self, - inputs: Dict, + inputs: dict, user: str, response_mode: str = "streaming", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, ): url = f"{self.api_base}/workflows/run" @@ -91,7 +93,7 @@ async def file_upload( self, file_path: str, user: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = f"{self.api_base}/files/upload" with open(file_path, "rb") as f: payload = { diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 7fe9bde05..19a50976e 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -21,7 +21,7 @@ def get_installation_id(): if os.path.exists(id_file): try: - with open(id_file, "r") as f: + with open(id_file) as f: Metric._iid_cache = f.read().strip() return Metric._iid_cache except Exception: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index c27a54113..ae7737a7a 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -8,11 +8,13 @@ import functools import copy import astrbot.core.message.components as Comp -from typing import Dict, Any, Callable, Awaitable, List +from typing import Any + +from collections.abc import Callable, Awaitable from astrbot.core.platform import AstrMessageEvent -USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 -FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例 +USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 +FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例 class SessionController: @@ -29,7 +31,7 @@ def __init__(self): self.timeout: float | int = None """上次保持(keep)开始时的超时时间""" - self.history_chains: List[List[Comp.BaseMessageComponent]] = [] + self.history_chains: list[list[Comp.BaseMessageComponent]] = [] def stop(self, error: Exception = None): """立即结束这个会话""" @@ -81,7 +83,7 @@ async def _holding(self, event: asyncio.Event, timeout: int): pass # 避免报错 # finally: - def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]: + def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]: """获取历史消息链""" return self.history_chains diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 19eab2efe..bc9f7dad9 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -66,7 +66,7 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: + def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> tuple[int, int]: """获取文本的尺寸""" try: # PIL 9.0.0 以上版本 @@ -82,7 +82,7 @@ def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: @staticmethod def split_text_to_fit_width( text: str, font: ImageFont.FreeTypeFont, max_width: int - ) -> List[str]: + ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] if not text: @@ -532,7 +532,7 @@ def render( class CodeBlockElement(MarkdownElement): """代码块元素""" - def __init__(self, content: List[str]): + def __init__(self, content: list[str]): super().__init__("\n".join(content)) def calculate_height(self, image_width: int, font_size: int) -> int: @@ -705,7 +705,7 @@ class MarkdownParser: """Markdown解析器,将文本解析为元素""" @staticmethod - async def parse(text: str) -> List[MarkdownElement]: + async def parse(text: str) -> list[MarkdownElement]: elements = [] lines = text.split("\n") @@ -847,7 +847,7 @@ def __init__( self, font_size: int = 26, width: int = 800, - bg_color: Tuple[int, int, int] = (255, 255, 255), + bg_color: tuple[int, int, int] = (255, 255, 255), ): self.font_size = font_size self.width = width diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index b441a908e..9ae422947 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -43,7 +43,7 @@ def _get_user_template_path(self, name: str) -> str: def _read_file(self, path: str) -> str: """读取文件内容。""" - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return f.read() def list_templates(self) -> list[dict]: diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 8fd89919a..38c3defd2 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -300,7 +300,7 @@ async def get_tool_list(self): """获取所有注册的工具列表""" try: tools = self.tool_mgr.func_list - tools_dict = [tool.__dict__() for tool in tools] + tools_dict = [tool.to_dict() for tool in tools] return Response().ok(data=tools_dict).__dict__ except Exception as e: logger.error(traceback.format_exc())