From 9ff8570385ffc92c5360eb6072accba748e46a4a Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Fri, 19 Sep 2025 16:48:10 +0800 Subject: [PATCH 1/4] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dinst=5Fmap?= =?UTF-8?q?=E6=9C=AA=E4=BA=92=E6=96=A5=E8=AE=BF=E9=97=AE=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84=E9=87=8D=E8=BD=BD=E7=8A=B6=E6=80=81=E4=B8=8D=E4=B8=80?= =?UTF-8?q?=E8=87=B4,=20=E5=A2=9E=E5=8A=A0=E5=89=8D=E7=AB=AF=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E5=90=AF=E7=94=A8=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/manager.py | 175 ++++++++++++++++++++------- astrbot/dashboard/routes/config.py | 58 ++++++++- dashboard/src/views/ProviderPage.vue | 21 ++-- 3 files changed, 197 insertions(+), 57 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 3b50e4976..5685f157e 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -40,6 +40,8 @@ def __init__( """加载的 Embedding Provider 的实例""" self.inst_map: dict[str, Provider] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" + self._inst_map_lock = asyncio.Lock() + """inst_map 操作的互斥锁""" self.llm_tools = llm_tools self.curr_provider_inst: Provider | None = None @@ -77,8 +79,13 @@ async def set_provider( Version 4.0.0: 这个版本下已经默认隔离提供商 """ - if provider_id not in self.inst_map: - raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + async with self._inst_map_lock: + if provider_id not in self.inst_map: + raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + if not umo: + # 不启用提供商会话隔离模式的情况 + self.curr_provider_inst = self.inst_map[provider_id] + if umo: await sp.session_put( umo, @@ -86,8 +93,6 @@ async def set_provider( provider_id, ) return - # 不启用提供商会话隔离模式的情况 - self.curr_provider_inst = self.inst_map[provider_id] if provider_type == ProviderType.TEXT_TO_SPEECH: sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.SPEECH_TO_TEXT: @@ -97,7 +102,8 @@ async def set_provider( async def get_provider_by_id(self, provider_id: str) -> Provider | None: """根据提供商 ID 获取提供商实例""" - return self.inst_map.get(provider_id) + async with self._inst_map_lock: + return self.inst_map.get(provider_id) def get_using_provider(self, provider_type: ProviderType, umo=None): """获取正在使用的提供商实例。 @@ -173,17 +179,18 @@ async def initialize(self): scope="global", scope_id="global", ) - self.curr_provider_inst = self.inst_map.get(selected_provider_id) - if not self.curr_provider_inst and self.provider_insts: - self.curr_provider_inst = self.provider_insts[0] + async with self._inst_map_lock: + self.curr_provider_inst = self.inst_map.get(selected_provider_id) + if not self.curr_provider_inst and self.provider_insts: + self.curr_provider_inst = self.provider_insts[0] - self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) - if not self.curr_stt_provider_inst and self.stt_provider_insts: - self.curr_stt_provider_inst = self.stt_provider_insts[0] + self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) + if not self.curr_stt_provider_inst and self.stt_provider_insts: + self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) - if not self.curr_tts_provider_inst and self.tts_provider_insts: - self.curr_tts_provider_inst = self.tts_provider_insts[0] + self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) + if not self.curr_tts_provider_inst and self.tts_provider_insts: + self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") @@ -376,8 +383,8 @@ async def load_provider(self, provider_config: dict): if getattr(inst, "initialize", None): await inst.initialize() self.embedding_provider_insts.append(inst) - - self.inst_map[provider_config["id"]] = inst + async with self._inst_map_lock: + self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error(traceback.format_exc()) logger.error( @@ -385,16 +392,85 @@ async def load_provider(self, provider_config: dict): ) async def reload(self, provider_config: dict): - await self.terminate_provider(provider_config["id"]) - if provider_config["enable"]: - await self.load_provider(provider_config) + provider_id = provider_config["id"] - # 和配置文件保持同步 + # 只有禁用的直接终止 + if not provider_config["enable"]: + async with self._inst_map_lock: + if provider_id in self.inst_map: + await self.terminate_provider(provider_id) + return + + # 备份一份旧实例 + old_instance = None + async with self._inst_map_lock: + old_instance = self.inst_map.get(provider_id) + + try: + # 使用临时 ID 加载新实例 + temp_config = provider_config.copy() + temp_id = f"{provider_id}_reload_temp" + temp_config["id"] = temp_id + + await self.load_provider(temp_config) + + # 新实例加载成功,替换旧实例 + async with self._inst_map_lock: + if temp_id in self.inst_map: + new_instance = self.inst_map[temp_id] + del self.inst_map[temp_id] + + new_instance.provider_config["id"] = provider_id + self.inst_map[provider_id] = new_instance + + if old_instance: + if old_instance in self.provider_insts: + idx = self.provider_insts.index(old_instance) + self.provider_insts[idx] = new_instance + if old_instance in self.stt_provider_insts: + idx = self.stt_provider_insts.index(old_instance) + self.stt_provider_insts[idx] = new_instance + if old_instance in self.tts_provider_insts: + idx = self.tts_provider_insts.index(old_instance) + self.tts_provider_insts[idx] = new_instance + if old_instance in self.embedding_provider_insts: + idx = self.embedding_provider_insts.index(old_instance) + self.embedding_provider_insts[idx] = new_instance + + # 更新当前实例引用 + if self.curr_provider_inst == old_instance: + self.curr_provider_inst = new_instance + if self.curr_stt_provider_inst == old_instance: + self.curr_stt_provider_inst = new_instance + if self.curr_tts_provider_inst == old_instance: + self.curr_tts_provider_inst = new_instance + + # 锁外清理旧实例 + if old_instance and hasattr(old_instance, "terminate"): + try: + await old_instance.terminate() + except Exception as e: + logger.warning(f"清理旧 Provider 实例时出错: {e}") + + except Exception as e: + # 清理临时实例 + async with self._inst_map_lock: + if temp_id in self.inst_map: + temp_instance = self.inst_map[temp_id] + del self.inst_map[temp_id] + if hasattr(temp_instance, "terminate"): + try: + await temp_instance.terminate() + except Exception: + pass + raise e + + # 清理已移除的提供商 config_ids = [provider["id"] for provider in self.providers_config] - logger.debug(f"providers in user's config: {config_ids}") - for key in list(self.inst_map.keys()): - if key not in config_ids: - await self.terminate_provider(key) + async with self._inst_map_lock: + for key in list(self.inst_map.keys()): + if key not in config_ids: + await self.terminate_provider(key) if len(self.provider_insts) == 0: self.curr_provider_inst = None @@ -424,32 +500,41 @@ def get_insts(self): return self.provider_insts async def terminate_provider(self, provider_id: str): - if provider_id in self.inst_map: - logger.info( - f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." - ) - - if self.inst_map[provider_id] in self.provider_insts: - self.provider_insts.remove(self.inst_map[provider_id]) - if self.inst_map[provider_id] in self.stt_provider_insts: - self.stt_provider_insts.remove(self.inst_map[provider_id]) - if self.inst_map[provider_id] in self.tts_provider_insts: - self.tts_provider_insts.remove(self.inst_map[provider_id]) - - if self.inst_map[provider_id] == self.curr_provider_inst: + instance = None + logger.info( + f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." + ) + # 锁内移除引用 + async with self._inst_map_lock: + instance = self.inst_map.get(provider_id) + if not instance: + return + del self.inst_map[provider_id] + if instance in self.provider_insts: + self.provider_insts.remove(instance) + if instance in self.stt_provider_insts: + self.stt_provider_insts.remove(instance) + if instance in self.tts_provider_insts: + self.tts_provider_insts.remove(instance) + if instance in self.embedding_provider_insts: + self.embedding_provider_insts.remove(instance) + if self.curr_provider_inst == instance: self.curr_provider_inst = None - if self.inst_map[provider_id] == self.curr_stt_provider_inst: + if self.curr_stt_provider_inst == instance: self.curr_stt_provider_inst = None - if self.inst_map[provider_id] == self.curr_tts_provider_inst: + if self.curr_tts_provider_inst == instance: self.curr_tts_provider_inst = None - if getattr(self.inst_map[provider_id], "terminate", None): - await self.inst_map[provider_id].terminate() # type: ignore + # 锁外终止实例 + if hasattr(instance, "terminate"): + try: + await instance.terminate() + except Exception as e: + logger.warning(f"终止提供商实例时出错: {e}") - logger.info( - f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" - ) - del self.inst_map[provider_id] + logger.info( + f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" + ) async def terminate(self): for provider_inst in self.provider_insts: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 0983cf8d5..896dfdb79 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -479,18 +479,70 @@ async def check_one_provider_status(self): logger.info(f"API call: /config/provider/check_one id={provider_id}") try: prov_mgr = self.core_lifecycle.provider_manager - target = prov_mgr.inst_map.get(provider_id) + # 配置里没有 + provider_config = None + for config in self.config["provider"]: + if config.get("id") == provider_id: + provider_config = config + break + + if not provider_config: + logger.warning( + f"Provider config with id '{provider_id}' not found in configuration." + ) + return ( + Response() + .error( + f"Provider with id '{provider_id}' not found in configuration" + ) + .__dict__ + ) + + # 没启用 + if not provider_config.get("enable", False): + logger.info(f"Provider with id '{provider_id}' is disabled.") + return ( + Response() + .ok( + { + "id": provider_id, + "model": provider_config.get("model", "Unknown Model"), + "type": provider_config.get( + "provider_type", "Unknown Type" + ), + "name": provider_config.get("name", provider_id), + "status": "disabled", + "error": "Provider is disabled", + } + ) + .__dict__ + ) + + # 先等待加载 + target = await prov_mgr.get_provider_by_id(provider_id) if not target: logger.warning( - f"Provider with id '{provider_id}' not found in provider_manager." + f"Provider with id '{provider_id}' is enabled but not loaded in provider_manager." ) return ( Response() - .error(f"Provider with id '{provider_id}' not found") + .ok( + { + "id": provider_id, + "model": provider_config.get("model", "Unknown Model"), + "type": provider_config.get( + "provider_type", "Unknown Type" + ), + "name": provider_config.get("name", provider_id), + "status": "not_loaded", + "error": "Provider is enabled but failed to load. Check logs for details.", + } + ) .__dict__ ) + # 已加载 result = await self._test_single_provider(target) return Response().ok(result).__dict__ diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index b36c0db21..01e60925c 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -830,19 +830,22 @@ export default { }, getStatusColor(status) { switch (status) { - case 'available': - return 'success'; - case 'unavailable': - return 'error'; - case 'pending': - return 'grey'; - default: - return 'default'; + case 'available': return 'success' + case 'unavailable': return 'error' + case 'disabled': return 'warning' + case 'not_loaded': return 'warning' + default: return 'info' } }, getStatusText(status) { - return this.messages.status[status] || status; + switch(status) { + case 'available': return '可用' + case 'unavailable': return '不可用' + case 'disabled': return '已禁用' + case 'not_loaded': return '加载失败' + default: return '未知' + } }, } } From 27f7efd2fb5939926a90a2c44fa50866b2d9d61d Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Fri, 19 Sep 2025 16:58:53 +0800 Subject: [PATCH 2/4] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E4=B8=80=E5=A4=84?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E5=90=8D=E7=A7=B0,=20=E6=9C=AA=E5=90=AF?= =?UTF-8?q?=E7=94=A8=E4=B8=BA=E4=BB=80=E4=B9=88=E5=8F=ABunavailable,=20?= =?UTF-8?q?=E8=BF=99=E4=B8=8D=E5=AF=B9=E7=9A=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dashboard/src/views/ProviderPage.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 01e60925c..abae63f3d 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -753,7 +753,7 @@ export default { if (index !== -1) { const disabledStatus = { ...this.providerStatuses[index], - status: 'unavailable', + status: 'disabled', error: '该提供商未被用户启用' }; this.providerStatuses.splice(index, 1, disabledStatus); From 861642f08abdaffbd8fa17829c7e7621fb53f456 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Fri, 19 Sep 2025 17:38:15 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E6=96=B0?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E6=A0=B7=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dashboard/src/views/ProviderPage.vue | 44 +++++++++++++++------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index abae63f3d..e09a4780f 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -102,27 +102,29 @@ - - mdi-check-circle - mdi-alert-circle - - - {{ status.id }} - - - {{ getStatusText(status.status) }} - - - - {{ tm('availability.errorMessage') }}: {{ status.error }} - + + mdi-check-circle + mdi-alert-circle + mdi-minus-circle + mdi-alert-circle-outline + + + {{ status.id }} + + + {{ getStatusText(status.status) }} + + + + {{ tm('availability.errorMessage') }}: {{ status.error }} + From 44345069d99e8b11eccc73b317846a57416b5d95 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Sun, 21 Sep 2025 18:36:43 +0800 Subject: [PATCH 4/4] style: format code --- astrbot/core/provider/manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index bc7c3d3c1..5ee86c885 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -532,7 +532,6 @@ async def terminate_provider(self, provider_id: str): if instance in self.embedding_provider_insts: self.embedding_provider_insts.remove(instance) if self.curr_provider_inst == instance: - self.curr_provider_inst = None if self.curr_stt_provider_inst == instance: self.curr_stt_provider_inst = None