diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 4b9204a68..5ee86c885 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -40,6 +40,8 @@ def __init__( """加载的 Embedding Provider 的实例""" self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" + self._inst_map_lock = asyncio.Lock() + """inst_map 操作的互斥锁""" self.llm_tools = llm_tools self.curr_provider_inst: Provider | None = None @@ -77,8 +79,13 @@ async def set_provider( Version 4.0.0: 这个版本下已经默认隔离提供商 """ - if provider_id not in self.inst_map: - raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + async with self._inst_map_lock: + if provider_id not in self.inst_map: + raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") + if not umo: + # 不启用提供商会话隔离模式的情况 + self.curr_provider_inst = self.inst_map[provider_id] + if umo: await sp.session_put( umo, @@ -86,6 +93,7 @@ async def set_provider( provider_id, ) return + # 不启用提供商会话隔离模式的情况 prov = self.inst_map[provider_id] @@ -107,7 +115,8 @@ async def set_provider( async def get_provider_by_id(self, provider_id: str) -> Provider | None: """根据提供商 ID 获取提供商实例""" - return self.inst_map.get(provider_id) + async with self._inst_map_lock: + return self.inst_map.get(provider_id) def get_using_provider( self, provider_type: ProviderType, umo=None @@ -185,17 +194,18 @@ async def initialize(self): scope="global", scope_id="global", ) - self.curr_provider_inst = self.inst_map.get(selected_provider_id) - if not self.curr_provider_inst and self.provider_insts: - self.curr_provider_inst = self.provider_insts[0] + async with self._inst_map_lock: + self.curr_provider_inst = self.inst_map.get(selected_provider_id) + if not self.curr_provider_inst and self.provider_insts: + self.curr_provider_inst = self.provider_insts[0] - self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) - if not self.curr_stt_provider_inst and self.stt_provider_insts: - self.curr_stt_provider_inst = self.stt_provider_insts[0] + self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) + if not self.curr_stt_provider_inst and self.stt_provider_insts: + self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) - if not self.curr_tts_provider_inst and self.tts_provider_insts: - self.curr_tts_provider_inst = self.tts_provider_insts[0] + self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) + if not self.curr_tts_provider_inst and self.tts_provider_insts: + self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") @@ -386,8 +396,8 @@ async def load_provider(self, provider_config: dict): if getattr(inst, "initialize", None): await inst.initialize() self.embedding_provider_insts.append(inst) - - self.inst_map[provider_config["id"]] = inst + async with self._inst_map_lock: + self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error(traceback.format_exc()) logger.error( @@ -395,16 +405,85 @@ async def load_provider(self, provider_config: dict): ) async def reload(self, provider_config: dict): - await self.terminate_provider(provider_config["id"]) - if provider_config["enable"]: - await self.load_provider(provider_config) + provider_id = provider_config["id"] + + # 只有禁用的直接终止 + if not provider_config["enable"]: + async with self._inst_map_lock: + if provider_id in self.inst_map: + await self.terminate_provider(provider_id) + return + + # 备份一份旧实例 + old_instance = None + async with self._inst_map_lock: + old_instance = self.inst_map.get(provider_id) + + try: + # 使用临时 ID 加载新实例 + temp_config = provider_config.copy() + temp_id = f"{provider_id}_reload_temp" + temp_config["id"] = temp_id + + await self.load_provider(temp_config) + + # 新实例加载成功,替换旧实例 + async with self._inst_map_lock: + if temp_id in self.inst_map: + new_instance = self.inst_map[temp_id] + del self.inst_map[temp_id] + + new_instance.provider_config["id"] = provider_id + self.inst_map[provider_id] = new_instance + + if old_instance: + if old_instance in self.provider_insts: + idx = self.provider_insts.index(old_instance) + self.provider_insts[idx] = new_instance + if old_instance in self.stt_provider_insts: + idx = self.stt_provider_insts.index(old_instance) + self.stt_provider_insts[idx] = new_instance + if old_instance in self.tts_provider_insts: + idx = self.tts_provider_insts.index(old_instance) + self.tts_provider_insts[idx] = new_instance + if old_instance in self.embedding_provider_insts: + idx = self.embedding_provider_insts.index(old_instance) + self.embedding_provider_insts[idx] = new_instance + + # 更新当前实例引用 + if self.curr_provider_inst == old_instance: + self.curr_provider_inst = new_instance + if self.curr_stt_provider_inst == old_instance: + self.curr_stt_provider_inst = new_instance + if self.curr_tts_provider_inst == old_instance: + self.curr_tts_provider_inst = new_instance + + # 锁外清理旧实例 + if old_instance and hasattr(old_instance, "terminate"): + try: + await old_instance.terminate() + except Exception as e: + logger.warning(f"清理旧 Provider 实例时出错: {e}") - # 和配置文件保持同步 + except Exception as e: + # 清理临时实例 + async with self._inst_map_lock: + if temp_id in self.inst_map: + temp_instance = self.inst_map[temp_id] + del self.inst_map[temp_id] + if hasattr(temp_instance, "terminate"): + try: + await temp_instance.terminate() + except Exception: + pass + raise e + + # 清理已移除的提供商 config_ids = [provider["id"] for provider in self.providers_config] - logger.debug(f"providers in user's config: {config_ids}") - for key in list(self.inst_map.keys()): - if key not in config_ids: - await self.terminate_provider(key) + async with self._inst_map_lock: + for key in list(self.inst_map.keys()): + if key not in config_ids: + await self.terminate_provider(key) if len(self.provider_insts) == 0: self.curr_provider_inst = None @@ -434,38 +513,41 @@ def get_insts(self): return self.provider_insts async def terminate_provider(self, provider_id: str): - if provider_id in self.inst_map: - logger.info( - f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." - ) - - if self.inst_map[provider_id] in self.provider_insts: - prov_inst = self.inst_map[provider_id] - if isinstance(prov_inst, Provider): - self.provider_insts.remove(prov_inst) - if self.inst_map[provider_id] in self.stt_provider_insts: - prov_inst = self.inst_map[provider_id] - if isinstance(prov_inst, STTProvider): - self.stt_provider_insts.remove(prov_inst) - if self.inst_map[provider_id] in self.tts_provider_insts: - prov_inst = self.inst_map[provider_id] - if isinstance(prov_inst, TTSProvider): - self.tts_provider_insts.remove(prov_inst) - - if self.inst_map[provider_id] == self.curr_provider_inst: + instance = None + logger.info( + f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." + ) + # 锁内移除引用 + async with self._inst_map_lock: + instance = self.inst_map.get(provider_id) + if not instance: + return + del self.inst_map[provider_id] + if instance in self.provider_insts: + self.provider_insts.remove(instance) + if instance in self.stt_provider_insts: + self.stt_provider_insts.remove(instance) + if instance in self.tts_provider_insts: + self.tts_provider_insts.remove(instance) + if instance in self.embedding_provider_insts: + self.embedding_provider_insts.remove(instance) + if self.curr_provider_inst == instance: self.curr_provider_inst = None - if self.inst_map[provider_id] == self.curr_stt_provider_inst: + if self.curr_stt_provider_inst == instance: self.curr_stt_provider_inst = None - if self.inst_map[provider_id] == self.curr_tts_provider_inst: + if self.curr_tts_provider_inst == instance: self.curr_tts_provider_inst = None - if getattr(self.inst_map[provider_id], "terminate", None): - await self.inst_map[provider_id].terminate() # type: ignore + # 锁外终止实例 + if hasattr(instance, "terminate"): + try: + await instance.terminate() + except Exception as e: + logger.warning(f"终止提供商实例时出错: {e}") - logger.info( - f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" - ) - del self.inst_map[provider_id] + logger.info( + f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" + ) async def terminate(self): for provider_inst in self.provider_insts: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 0983cf8d5..896dfdb79 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -479,18 +479,70 @@ async def check_one_provider_status(self): logger.info(f"API call: /config/provider/check_one id={provider_id}") try: prov_mgr = self.core_lifecycle.provider_manager - target = prov_mgr.inst_map.get(provider_id) + # 配置里没有 + provider_config = None + for config in self.config["provider"]: + if config.get("id") == provider_id: + provider_config = config + break + + if not provider_config: + logger.warning( + f"Provider config with id '{provider_id}' not found in configuration." + ) + return ( + Response() + .error( + f"Provider with id '{provider_id}' not found in configuration" + ) + .__dict__ + ) + + # 没启用 + if not provider_config.get("enable", False): + logger.info(f"Provider with id '{provider_id}' is disabled.") + return ( + Response() + .ok( + { + "id": provider_id, + "model": provider_config.get("model", "Unknown Model"), + "type": provider_config.get( + "provider_type", "Unknown Type" + ), + "name": provider_config.get("name", provider_id), + "status": "disabled", + "error": "Provider is disabled", + } + ) + .__dict__ + ) + + # 先等待加载 + target = await prov_mgr.get_provider_by_id(provider_id) if not target: logger.warning( - f"Provider with id '{provider_id}' not found in provider_manager." + f"Provider with id '{provider_id}' is enabled but not loaded in provider_manager." ) return ( Response() - .error(f"Provider with id '{provider_id}' not found") + .ok( + { + "id": provider_id, + "model": provider_config.get("model", "Unknown Model"), + "type": provider_config.get( + "provider_type", "Unknown Type" + ), + "name": provider_config.get("name", provider_id), + "status": "not_loaded", + "error": "Provider is enabled but failed to load. Check logs for details.", + } + ) .__dict__ ) + # 已加载 result = await self._test_single_provider(target) return Response().ok(result).__dict__ diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index b36c0db21..e09a4780f 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -102,27 +102,29 @@ - - mdi-check-circle - mdi-alert-circle - - - {{ status.id }} - - - {{ getStatusText(status.status) }} - - - - {{ tm('availability.errorMessage') }}: {{ status.error }} - + + mdi-check-circle + mdi-alert-circle + mdi-minus-circle + mdi-alert-circle-outline + + + {{ status.id }} + + + {{ getStatusText(status.status) }} + + + + {{ tm('availability.errorMessage') }}: {{ status.error }} + @@ -753,7 +755,7 @@ export default { if (index !== -1) { const disabledStatus = { ...this.providerStatuses[index], - status: 'unavailable', + status: 'disabled', error: '该提供商未被用户启用' }; this.providerStatuses.splice(index, 1, disabledStatus); @@ -830,19 +832,22 @@ export default { }, getStatusColor(status) { switch (status) { - case 'available': - return 'success'; - case 'unavailable': - return 'error'; - case 'pending': - return 'grey'; - default: - return 'default'; + case 'available': return 'success' + case 'unavailable': return 'error' + case 'disabled': return 'warning' + case 'not_loaded': return 'warning' + default: return 'info' } }, getStatusText(status) { - return this.messages.status[status] || status; + switch(status) { + case 'available': return '可用' + case 'unavailable': return '不可用' + case 'disabled': return '已禁用' + case 'not_loaded': return '加载失败' + default: return '未知' + } }, } }