Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 131 additions & 49 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self._inst_map_lock = asyncio.Lock()
"""inst_map 操作的互斥锁"""
self.llm_tools = llm_tools

self.curr_provider_inst: Provider | None = None
Expand Down Expand Up @@ -77,15 +79,21 @@ async def set_provider(

Version 4.0.0: 这个版本下已经默认隔离提供商
"""
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
async with self._inst_map_lock:
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if not umo:
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]

if umo:
await sp.session_put(
umo,
f"provider_perf_{provider_type.value}",
provider_id,
)
return

# 不启用提供商会话隔离模式的情况

prov = self.inst_map[provider_id]
Expand All @@ -107,7 +115,8 @@ async def set_provider(

async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
async with self._inst_map_lock:
return self.inst_map.get(provider_id)

def get_using_provider(
self, provider_type: ProviderType, umo=None
Expand Down Expand Up @@ -185,17 +194,18 @@ async def initialize(self):
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
async with self._inst_map_lock:
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]

self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]

self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]

# 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
Expand Down Expand Up @@ -386,25 +396,94 @@ async def load_provider(self, provider_config: dict):
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)

self.inst_map[provider_config["id"]] = inst
async with self._inst_map_lock:
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
logger.error(
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
)

async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"])
if provider_config["enable"]:
await self.load_provider(provider_config)
provider_id = provider_config["id"]

# 只有禁用的直接终止
if not provider_config["enable"]:
async with self._inst_map_lock:
if provider_id in self.inst_map:
await self.terminate_provider(provider_id)
return

# 备份一份旧实例
old_instance = None
async with self._inst_map_lock:
old_instance = self.inst_map.get(provider_id)

try:
# 使用临时 ID 加载新实例
temp_config = provider_config.copy()
temp_id = f"{provider_id}_reload_temp"
temp_config["id"] = temp_id

await self.load_provider(temp_config)

# 新实例加载成功,替换旧实例
async with self._inst_map_lock:
if temp_id in self.inst_map:
new_instance = self.inst_map[temp_id]
del self.inst_map[temp_id]

new_instance.provider_config["id"] = provider_id
self.inst_map[provider_id] = new_instance

if old_instance:
if old_instance in self.provider_insts:
idx = self.provider_insts.index(old_instance)
self.provider_insts[idx] = new_instance
if old_instance in self.stt_provider_insts:
idx = self.stt_provider_insts.index(old_instance)
self.stt_provider_insts[idx] = new_instance
if old_instance in self.tts_provider_insts:
idx = self.tts_provider_insts.index(old_instance)
self.tts_provider_insts[idx] = new_instance
if old_instance in self.embedding_provider_insts:
idx = self.embedding_provider_insts.index(old_instance)
self.embedding_provider_insts[idx] = new_instance

# 更新当前实例引用
if self.curr_provider_inst == old_instance:
self.curr_provider_inst = new_instance
if self.curr_stt_provider_inst == old_instance:
self.curr_stt_provider_inst = new_instance
if self.curr_tts_provider_inst == old_instance:
self.curr_tts_provider_inst = new_instance

# 锁外清理旧实例
if old_instance and hasattr(old_instance, "terminate"):
try:
await old_instance.terminate()
except Exception as e:
logger.warning(f"清理旧 Provider 实例时出错: {e}")

# 和配置文件保持同步
except Exception as e:
# 清理临时实例
async with self._inst_map_lock:
if temp_id in self.inst_map:
temp_instance = self.inst_map[temp_id]
del self.inst_map[temp_id]
if hasattr(temp_instance, "terminate"):
try:
await temp_instance.terminate()
except Exception:
pass
raise e

# 清理已移除的提供商
config_ids = [provider["id"] for provider in self.providers_config]
logger.debug(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
async with self._inst_map_lock:
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)

if len(self.provider_insts) == 0:
self.curr_provider_inst = None
Expand Down Expand Up @@ -434,38 +513,41 @@ def get_insts(self):
return self.provider_insts

async def terminate_provider(self, provider_id: str):
if provider_id in self.inst_map:
logger.info(
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
)

if self.inst_map[provider_id] in self.provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)

if self.inst_map[provider_id] == self.curr_provider_inst:
instance = None
logger.info(
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
)
# 锁内移除引用
async with self._inst_map_lock:
instance = self.inst_map.get(provider_id)
if not instance:
return
del self.inst_map[provider_id]
if instance in self.provider_insts:
self.provider_insts.remove(instance)
if instance in self.stt_provider_insts:
self.stt_provider_insts.remove(instance)
if instance in self.tts_provider_insts:
self.tts_provider_insts.remove(instance)
if instance in self.embedding_provider_insts:
self.embedding_provider_insts.remove(instance)
if self.curr_provider_inst == instance:
self.curr_provider_inst = None
if self.inst_map[provider_id] == self.curr_stt_provider_inst:
if self.curr_stt_provider_inst == instance:
self.curr_stt_provider_inst = None
if self.inst_map[provider_id] == self.curr_tts_provider_inst:
if self.curr_tts_provider_inst == instance:
self.curr_tts_provider_inst = None

if getattr(self.inst_map[provider_id], "terminate", None):
await self.inst_map[provider_id].terminate() # type: ignore
# 锁外终止实例
if hasattr(instance, "terminate"):
try:
await instance.terminate()
except Exception as e:
logger.warning(f"终止提供商实例时出错: {e}")

logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
)
del self.inst_map[provider_id]
logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
)

async def terminate(self):
for provider_inst in self.provider_insts:
Expand Down
58 changes: 55 additions & 3 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,18 +479,70 @@ async def check_one_provider_status(self):
logger.info(f"API call: /config/provider/check_one id={provider_id}")
try:
prov_mgr = self.core_lifecycle.provider_manager
target = prov_mgr.inst_map.get(provider_id)

# 配置里没有
provider_config = None
for config in self.config["provider"]:
if config.get("id") == provider_id:
provider_config = config
break

if not provider_config:
logger.warning(
f"Provider config with id '{provider_id}' not found in configuration."
)
return (
Response()
.error(
f"Provider with id '{provider_id}' not found in configuration"
)
.__dict__
)

# 没启用
if not provider_config.get("enable", False):
logger.info(f"Provider with id '{provider_id}' is disabled.")
return (
Response()
.ok(
{
"id": provider_id,
"model": provider_config.get("model", "Unknown Model"),
"type": provider_config.get(
"provider_type", "Unknown Type"
),
"name": provider_config.get("name", provider_id),
"status": "disabled",
"error": "Provider is disabled",
}
)
.__dict__
)

# 先等待加载
target = await prov_mgr.get_provider_by_id(provider_id)
if not target:
logger.warning(
f"Provider with id '{provider_id}' not found in provider_manager."
f"Provider with id '{provider_id}' is enabled but not loaded in provider_manager."
)
return (
Response()
.error(f"Provider with id '{provider_id}' not found")
.ok(
{
"id": provider_id,
"model": provider_config.get("model", "Unknown Model"),
"type": provider_config.get(
"provider_type", "Unknown Type"
),
"name": provider_config.get("name", provider_id),
"status": "not_loaded",
"error": "Provider is enabled but failed to load. Check logs for details.",
}
)
.__dict__
)

# 已加载
result = await self._test_single_provider(target)
return Response().ok(result).__dict__

Expand Down
Loading