From a107921bc92f14fd27ae334943a221861205b37a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 27 Nov 2025 15:30:53 +0800 Subject: [PATCH] perf: enhance provider management with reload locking and logging - Introduced a reload lock to prevent concurrent reloads of providers. - Added logging to indicate when a provider is disabled and when providers are being synchronized with the configuration. - Refactored the reload method to improve clarity and maintainability. --- astrbot/core/provider/manager.py | 74 ++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 115a48463..3e477255a 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,7 +1,7 @@ import asyncio import traceback -from astrbot.core import logger, sp +from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase @@ -24,6 +24,7 @@ def __init__( db_helper: BaseDatabase, persona_mgr: PersonaManager, ): + self.reload_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm config = acm.confs["default"] @@ -226,6 +227,7 @@ async def initialize(self): async def load_provider(self, provider_config: dict): if not provider_config["enable"]: + logger.info(f"Provider {provider_config['id']} is disabled, skipping") return if provider_config.get("provider_type", "") == "agent_runner": return @@ -434,40 +436,46 @@ async def load_provider(self, provider_config: dict): ) async def reload(self, provider_config: dict): - await self.terminate_provider(provider_config["id"]) - if provider_config["enable"]: - await self.load_provider(provider_config) - - # 和配置文件保持同步 - config_ids = [provider["id"] for provider in self.providers_config] - logger.debug(f"providers in user's config: {config_ids}") - for key in list(self.inst_map.keys()): - if key not in config_ids: - await self.terminate_provider(key) - - if len(self.provider_insts) == 0: - self.curr_provider_inst = None - elif self.curr_provider_inst is None and len(self.provider_insts) > 0: - self.curr_provider_inst = self.provider_insts[0] - logger.info( - f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", - ) + async with self.reload_lock: + await self.terminate_provider(provider_config["id"]) + if provider_config["enable"]: + await self.load_provider(provider_config) - if len(self.stt_provider_insts) == 0: - self.curr_stt_provider_inst = None - elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: - self.curr_stt_provider_inst = self.stt_provider_insts[0] - logger.info( - f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", - ) + # 和配置文件保持同步 + self.providers_config = astrbot_config["provider"] + config_ids = [provider["id"] for provider in self.providers_config] + logger.info(f"providers in user's config: {config_ids}") + for key in list(self.inst_map.keys()): + if key not in config_ids: + await self.terminate_provider(key) - if len(self.tts_provider_insts) == 0: - self.curr_tts_provider_inst = None - elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: - self.curr_tts_provider_inst = self.tts_provider_insts[0] - logger.info( - f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", - ) + if len(self.provider_insts) == 0: + self.curr_provider_inst = None + elif self.curr_provider_inst is None and len(self.provider_insts) > 0: + self.curr_provider_inst = self.provider_insts[0] + logger.info( + f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", + ) + + if len(self.stt_provider_insts) == 0: + self.curr_stt_provider_inst = None + elif ( + self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0 + ): + self.curr_stt_provider_inst = self.stt_provider_insts[0] + logger.info( + f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", + ) + + if len(self.tts_provider_insts) == 0: + self.curr_tts_provider_inst = None + elif ( + self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0 + ): + self.curr_tts_provider_inst = self.tts_provider_insts[0] + logger.info( + f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", + ) def get_insts(self): return self.provider_insts