Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions astrbot/api/event/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
register_on_star_activated as on_star_activated,
register_on_star_deactivated as on_star_deactivated,
)

from astrbot.core.star.filter.event_message_type import (
Expand Down Expand Up @@ -46,4 +48,6 @@
"on_decorating_result",
"after_message_sent",
"on_llm_response",
"on_star_activated",
"on_star_deactivated",
]
4 changes: 4 additions & 0 deletions astrbot/core/star/register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
register_agent,
register_on_decorating_result,
register_after_message_sent,
register_on_star_activated,
register_on_star_deactivated,
)

__all__ = [
Expand All @@ -32,4 +34,6 @@
"register_agent",
"register_on_decorating_result",
"register_after_message_sent",
"register_on_star_activated",
"register_on_star_deactivated",
]
28 changes: 28 additions & 0 deletions astrbot/core/star/register/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,31 @@ def decorator(awaitable):
return awaitable

return decorator


def register_on_star_activated(star_name: str = None, **kwargs):
"""当指定插件被激活时"""

def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable, EventType.OnStarActivatedEvent, **kwargs
)
if star_name:
handler_md.extras_configs["target_star_name"] = star_name
return awaitable

return decorator


def register_on_star_deactivated(star_name: str = None, **kwargs):
"""当指定插件被停用时"""

def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable, EventType.OnStarDeactivatedEvent, **kwargs
)
if star_name:
handler_md.extras_configs["target_star_name"] = star_name
return awaitable

return decorator
2 changes: 2 additions & 0 deletions astrbot/core/star/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StarMetadata:
"""插件版本"""
repo: str | None = None
"""插件仓库地址"""
dependencies: list[str] = field(default_factory=list)
"""插件依赖列表"""

star_cls_type: type[Star] | None = None
"""插件的类对象的类型"""
Expand Down
2 changes: 2 additions & 0 deletions astrbot/core/star/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class EventType(enum.Enum):
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后

OnStarActivatedEvent = enum.auto() # 插件启用
OnStarDeactivatedEvent = enum.auto() # 插件禁用

@dataclass
class StarHandlerMetadata:
Expand Down
201 changes: 165 additions & 36 deletions astrbot/core/star/star_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from .star import star_map, star_registry
from .star_handler import star_handlers_registry
from .updator import PluginUpdator
from .star_handler import EventType, StarHandlerMetadata
import networkx as nx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里如果不用 networkx 的话可以实现嘛?可能目前基于图的长期记忆还暂时不会引入。

tips: 抱歉这么久没处理;这个 PR 打算放到 releases/4.0.0 分支中;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

理论上可行, 即在 2340c5c 基础上修改,之前稍微尝试了一下感觉比较冗长(时间有点久了,只记得至少有一处细节手动处理起来比较复杂,而交由networkx管理相对清晰)并且不易扩展(比如目前想到的是依赖项后跟版本号进行更细致的区分)


try:
from watchfiles import PythonFilter, awatch
Expand Down Expand Up @@ -144,13 +146,11 @@ def _get_modules(path):
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
os.path.join(path, d, d + ".py")
):
modules.append(
{
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
}
)
modules.append({
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
})
return modules

def _get_plugin_modules(self) -> list[dict]:
Expand Down Expand Up @@ -226,6 +226,7 @@ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | N
desc=metadata["desc"],
version=metadata["version"],
repo=metadata["repo"] if "repo" in metadata else None,
dependencies=metadata.get("dependencies", []),
)

return metadata
Expand Down Expand Up @@ -321,25 +322,17 @@ async def reload(self, specified_plugin_name=None):
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
plugin_modules = await self._get_load_order()
result = await self.load(plugin_modules=plugin_modules)
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
if smd.name:
await self._unbind_plugin(smd.name, specified_module_path)

result = await self.load(specified_module_path)
result = await self.batch_reload(
specified_module_path=specified_module_path
)

return result

async def load(self, specified_module_path=None, specified_dir_name=None):
async def load(self, plugin_modules=None):
"""载入插件。
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。

Expand All @@ -356,10 +349,11 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
alter_cmd = await sp.global_get("alter_cmd", {})

plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"

logger.info(
f"正在按顺序加载插件: {[plugin_module['pname'] for plugin_module in plugin_modules]}"
)
fail_rec = ""

# 导入插件模块,并尝试实例化插件类
Expand All @@ -375,12 +369,6 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str

# 检查是否需要载入指定的插件
if specified_module_path and path != specified_module_path:
continue
if specified_dir_name and root_dir_name != specified_dir_name:
continue

logger.info(f"正在载入插件 {root_dir_name} ...")

# 尝试导入模块
Expand Down Expand Up @@ -451,6 +439,9 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
await self._trigger_star_lifecycle_event(
EventType.OnStarActivatedEvent, metadata
)
else:
logger.info(f"插件 {metadata.name} 已被禁用。")

Expand Down Expand Up @@ -622,7 +613,8 @@ async def install_plugin(self, repo_url: str, proxy=""):
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
plugin_modules = await self._get_load_order(specified_dir_name=dir_name)
await self.batch_reload(plugin_modules=plugin_modules)

# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
Expand Down Expand Up @@ -778,8 +770,7 @@ async def turn_off_plugin(self, plugin_name: str):

plugin.activated = False

@staticmethod
async def _terminate_plugin(star_metadata: StarMetadata):
async def _terminate_plugin(self, star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
logger.info(f"正在终止插件 {star_metadata.name} ...")

Expand All @@ -788,14 +779,18 @@ async def _terminate_plugin(star_metadata: StarMetadata):
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return

await self._trigger_star_lifecycle_event(
EventType.OnStarDeactivatedEvent, star_metadata
)

if star_metadata.star_cls is None:
return

if '__del__' in star_metadata.star_cls_type.__dict__:
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif 'terminate' in star_metadata.star_cls_type.__dict__:
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()

async def turn_on_plugin(self, plugin_name: str):
Expand Down Expand Up @@ -832,7 +827,8 @@ async def install_plugin_from_file(self, zip_file_path: str):
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
# await self.reload()
await self.load(specified_dir_name=dir_name)
plugin_modules = await self._get_load_order(specified_dir_name=dir_name)
await self.batch_reload(plugin_modules=plugin_modules)

# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
Expand Down Expand Up @@ -865,3 +861,136 @@ async def install_plugin_from_file(self, zip_file_path: str):
}

return plugin_info

async def _trigger_star_lifecycle_event(
self, event_type: EventType, star_metadata: StarMetadata
):
"""
内部辅助函数,用于触发插件(Star)相关的生命周期事件。
Args:
event_type: 要触发的事件类型 (EventType.OnStarActivatedEvent 或 EventType.OnStarDeactivatedEvent)。
star_metadata: 触发事件的插件的 StarMetadata 对象。
"""
handlers_to_run: list[StarHandlerMetadata] = []
# 获取所有监听该事件类型的 handlers
handlers = star_handlers_registry.get_handlers_by_event_type(event_type)

for handler in handlers:
# 检查这个 handler 是否监听了特定的插件名
target_star_name = handler.extras_configs.get("target_star_name")
if target_star_name and target_star_name == star_metadata.name:
# 如果指定了目标插件名,则只在匹配时添加
handlers_to_run.append(handler)

for handler in handlers_to_run:
try:
# 调用插件的钩子函数,并传入 StarMetadata 对象
logger.info(
f"hook({event_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} (目标插件: {star_metadata.name})"
)
await handler.handler(star_metadata) # 传递参数
except Exception:
logger.error(
f"执行插件 {handler.handler_name} 的 {event_type.name} 钩子时出错: {traceback.format_exc()}"
)

def _get_plugin_dir_path(self, root_dir_name: str, is_reserved: bool) -> str:
"""根据插件的根目录名和是否为保留插件,返回插件的完整文件路径。"""
return (
os.path.join(self.plugin_store_path, root_dir_name)
if not is_reserved
else os.path.join(self.reserved_plugin_path, root_dir_name)
)

def _build_module_path(self, plugin_module_info: dict) -> str:
"""根据插件模块信息构建完整的模块路径。"""
reserved = plugin_module_info.get("reserved", False)
path_prefix = "packages." if reserved else "data.plugins."
return (
f"{path_prefix}{plugin_module_info['pname']}.{plugin_module_info['module']}"
)

async def _get_load_order(
self, specified_dir_name: str = None, specified_module_path: str = None
):
star_graph = self._build_star_graph()
if star_graph is None:
return None
try:
if specified_dir_name:
for node in star_graph:
if (
star_graph.nodes[node]["data"].get("pname")
== specified_dir_name
):
dependent_nodes = nx.descendants(star_graph, node)
sub_graph = star_graph.subgraph(dependent_nodes.union({node}))
load_order = list(nx.topological_sort(sub_graph))
return [star_graph.nodes[node]["data"] for node in load_order]
elif specified_module_path:
for node in star_graph:
if specified_module_path == self._build_module_path(
star_graph.nodes[node].get("data")
):
dependent_nodes = nx.descendants(star_graph, node)
sub_graph = star_graph.subgraph(dependent_nodes.union({node}))
load_order = list(nx.topological_sort(sub_graph))
return [star_graph.nodes[node]["data"] for node in load_order]
else:
return [
star_graph.nodes[node]["data"]
for node in list(nx.topological_sort(star_graph))
]
except nx.NetworkXUnfeasible:
logger.error("出现循环依赖,无法确定加载顺序,按自然顺序加载")
return [star_graph.nodes[node]["data"] for node in star_graph]

def _build_star_graph(self):
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return None
G = nx.DiGraph()
for plugin_module in plugin_modules:
root_dir_name = plugin_module["pname"]
is_reserved = plugin_module.get("reserved", False)
plugin_dir_path = self._get_plugin_dir_path(root_dir_name, is_reserved)
G.add_node(root_dir_name, data=plugin_module)
try:
metadata = self._load_plugin_metadata(plugin_dir_path)
if metadata:
for dep_name in metadata.dependencies:
G.add_edge(root_dir_name, dep_name)
except Exception:
pass
# 过滤不存在的依赖(出边没有data, 就删除指向的节点)
nodes_to_remove = []
for node_name in list(G.nodes()):
for neighbor in list(G.neighbors(node_name)):
if G.nodes[neighbor].get("data") is None:
nodes_to_remove.append(neighbor)
logger.warning(
f"插件 {node_name} 声明依赖 {neighbor}, 但该插件未被发现,跳过加载。"
)
for node in nodes_to_remove:
G.remove_node(node)
return G

async def batch_reload(self, specified_module_path=None, plugin_modules=None):
if not plugin_modules:
plugin_modules = await self._get_load_order(
specified_module_path=specified_module_path
)
for plugin_module in plugin_modules:
specified_module_path = self._build_module_path(plugin_module)
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, specified_module_path)

return await self.load(plugin_modules=plugin_modules)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"watchfiles>=1.0.5",
"websockets>=15.0.1",
"wechatpy>=1.8.18",
"networkx>=3.4.2",
]

[project.scripts]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ faiss-cpu
aiosqlite
py-cord>=2.6.1
slack-sdk
pydub
pydub
networkx