Skip to content

Commit eed622f

Browse files
committed
chore(core.provider): 🚨 修正实现错误Lint
1 parent dd09108 commit eed622f

17 files changed

+95
-52
lines changed

astrbot/core/provider/sources/azure_tts_source.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import cast
12
import uuid
23
import time
34
import json
@@ -40,7 +41,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
4041

4142
async def _sync_time(self):
4243
try:
43-
response = await self.client.get(self.auth_time_url)
44+
response = await cast(AsyncClient, self.client).get(self.auth_time_url)
4445
response.raise_for_status()
4546
server_time = int(response.json()["timestamp"])
4647
local_time = int(time.time())
@@ -62,7 +63,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str:
6263
signature = await self._generate_signature()
6364
for attempt in range(self.retry_count):
6465
try:
65-
response = await self.client.post(
66+
response = await cast(AsyncClient, self.client).post(
6667
f"{self.api_url}?sign={signature}",
6768
data={
6869
"text": text,
@@ -87,6 +88,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str:
8788
if attempt == self.retry_count - 1:
8889
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
8990
await asyncio.sleep(0.5 * (attempt + 1))
91+
raise RuntimeError("OTTS未返回音频文件")
9092

9193

9294
class AzureNativeProvider(TTSProvider):
@@ -130,7 +132,7 @@ async def _refresh_token(self):
130132
token_url = (
131133
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
132134
)
133-
response = await self.client.post(
135+
response = await cast(AsyncClient, self.client).post(
134136
token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
135137
)
136138
response.raise_for_status()
@@ -153,7 +155,7 @@ async def get_audio(self, text: str) -> str:
153155
</mstts:express-as>
154156
</voice>
155157
</speak>"""
156-
response = await self.client.post(
158+
response = await cast(AsyncClient, self.client).post(
157159
self.endpoint,
158160
content=ssml,
159161
headers={
@@ -176,8 +178,11 @@ def __init__(self, provider_config: dict, provider_settings: dict):
176178
key_value = provider_config.get("azure_tts_subscription_key", "")
177179
self.provider = self._parse_provider(key_value, provider_config)
178180

179-
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
181+
def _parse_provider(
182+
self, key_value: str, config: dict
183+
) -> OTTSProvider | AzureNativeProvider:
180184
if key_value.lower().startswith("other["):
185+
json_str = ""
181186
try:
182187
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
183188
if not match:

astrbot/core/provider/sources/coze_source.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
provider_settings,
2727
default_persona,
2828
)
29-
self.api_key = provider_config.get("coze_api_key", "")
29+
self.api_key: str = provider_config.get("coze_api_key", "")
3030
if not self.api_key:
3131
raise Exception("Coze API Key 不能为空。")
3232
self.bot_id = provider_config.get("bot_id", "")
@@ -64,8 +64,8 @@ def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
6464

6565
try:
6666
if is_base64 and data.startswith("data:image/"):
67+
header, encoded = data.split(",", 1)
6768
try:
68-
header, encoded = data.split(",", 1)
6969
image_bytes = base64.b64decode(encoded)
7070
cache_key = hashlib.md5(image_bytes).hexdigest()
7171
return cache_key
@@ -567,11 +567,11 @@ async def forget(self, session_id: str):
567567
logger.error(f"清空 Coze 会话失败: {str(e)}")
568568
return False
569569

570-
async def get_current_key(self):
570+
def get_current_key(self):
571571
"""获取当前API Key"""
572572
return self.api_key
573573

574-
async def set_key(self, key: str):
574+
def set_key(self, key: str):
575575
"""设置新的API Key"""
576576
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
577577

@@ -583,12 +583,12 @@ def get_model(self):
583583
"""获取当前模型"""
584584
return f"bot_{self.bot_id}"
585585

586-
def set_model(self, model: str):
586+
def set_model(self, model_name: str):
587587
"""设置模型(在Coze中是Bot ID)"""
588-
if model.startswith("bot_"):
589-
self.bot_id = model[4:]
588+
if model_name.startswith("bot_"):
589+
self.bot_id = model_name[4:]
590590
else:
591-
self.bot_id = model
591+
self.bot_id = model_name
592592

593593
async def get_human_readable_context(
594594
self, session_id: str, page: int = 1, page_size: int = 10

astrbot/core/provider/sources/dashscope_source.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
provider_settings,
2626
default_persona,
2727
)
28-
self.api_key = provider_config.get("dashscope_api_key", "")
28+
self.api_key: str = provider_config.get("dashscope_api_key", "")
2929
if not self.api_key:
3030
raise Exception("阿里云百炼 API Key 不能为空。")
3131
self.app_id = provider_config.get("dashscope_app_id", "")
@@ -66,6 +66,7 @@ async def text_chat(
6666
func_tool=None,
6767
contexts=None,
6868
system_prompt=None,
69+
tool_calls_result=None,
6970
model=None,
7071
**kwargs,
7172
) -> LLMResponse:
@@ -75,6 +76,8 @@ async def text_chat(
7576
payload_vars = self.variables.copy()
7677
# 动态变量
7778
session_var = await sp.session_get(session_id, "session_variables", default={})
79+
if not isinstance(session_var, dict):
80+
session_var = {}
7881
payload_vars.update(session_var)
7982

8083
if (
@@ -159,9 +162,9 @@ async def text_chat_stream(
159162
self,
160163
prompt,
161164
session_id=None,
162-
image_urls=...,
165+
image_urls=None,
163166
func_tool=None,
164-
contexts=...,
167+
contexts=None,
165168
system_prompt=None,
166169
tool_calls_result=None,
167170
model=None,
@@ -186,10 +189,10 @@ async def text_chat_stream(
186189
async def forget(self, session_id):
187190
return True
188191

189-
async def get_current_key(self):
192+
def get_current_key(self):
190193
return self.api_key
191194

192-
async def set_key(self, key):
195+
def set_key(self, key):
193196
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
194197

195198
async def get_models(self):

astrbot/core/provider/sources/dashscope_tts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
super().__init__(provider_config, provider_settings)
3333
self.chosen_api_key: str = provider_config.get("api_key", "")
3434
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
35-
self.set_model(provider_config.get("model", None))
35+
self.set_model(provider_config["model"])
3636
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
3737
dashscope.api_key = self.chosen_api_key
3838

@@ -67,9 +67,10 @@ def _call_qwen_tts(self, model: str, text: str):
6767

6868
kwargs = {
6969
"model": model,
70-
"text": text,
70+
"messages": None,
7171
"api_key": self.chosen_api_key,
7272
"voice": self.voice or "Cherry",
73+
"text": text,
7374
}
7475
if not self.voice:
7576
logging.warning(

astrbot/core/provider/sources/dify_source.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ async def text_chat(
9696
payload_vars = self.variables.copy()
9797
# 动态变量
9898
session_var = await sp.session_get(session_id, "session_variables", default={})
99+
if not isinstance(session_var, dict):
100+
session_var = {}
99101
payload_vars.update(session_var)
100102
payload_vars["system_prompt"] = system_prompt
101103

@@ -266,10 +268,10 @@ async def forget(self, session_id):
266268
self.conversation_ids[session_id] = ""
267269
return True
268270

269-
async def get_current_key(self):
271+
def get_current_key(self):
270272
return self.api_key
271273

272-
async def set_key(self, key):
274+
def set_key(self, key):
273275
raise Exception("Dify 适配器不支持设置 API Key。")
274276

275277
async def get_models(self):

astrbot/core/provider/sources/edge_tts_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def get_audio(self, text: str) -> str:
6262
from pyffmpeg import FFmpeg
6363

6464
ff = FFmpeg()
65-
ff.convert(input=mp3_path, output=wav_path)
65+
ff.convert(input_file=mp3_path, output_file=wav_path)
6666
except Exception as e:
6767
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
6868
# use ffmpeg command line

astrbot/core/provider/sources/fishaudio_tts_api_source.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def __init__(
5353
self.headers = {
5454
"Authorization": f"Bearer {self.chosen_api_key}",
5555
}
56-
self.set_model(provider_config.get("model", None))
56+
self.set_model(provider_config["model"])
5757

58-
async def _get_reference_id_by_character(self, character: str) -> str:
58+
async def _get_reference_id_by_character(self, character: str) -> str | None:
5959
"""
6060
获取角色的reference_id
6161
@@ -120,7 +120,7 @@ async def _generate_request(self, text: str) -> dict:
120120
text=text,
121121
format="wav",
122122
reference_id=reference_id,
123-
)
123+
).model_dump()
124124

125125
async def get_audio(self, text: str) -> str:
126126
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -138,5 +138,6 @@ async def get_audio(self, text: str) -> str:
138138
async for chunk in response.aiter_bytes():
139139
f.write(chunk)
140140
return path
141-
text = await response.aread()
141+
body = await response.aread()
142+
text = body.decode("utf-8", errors="replace")
142143
raise Exception(f"Fish Audio API请求失败: {text}")

astrbot/core/provider/sources/gemini_embedding_source.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..provider import EmbeddingProvider
55
from ..register import register_provider_adapter
66
from ..entities import ProviderType
7+
from typing import cast
78

89

910
@register_provider_adapter(
@@ -17,8 +18,8 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
1718
self.provider_config = provider_config
1819
self.provider_settings = provider_settings
1920

20-
api_key: str = provider_config.get("embedding_api_key")
21-
api_base: str = provider_config.get("embedding_api_base", None)
21+
api_key: str = provider_config["embedding_api_key"]
22+
api_base: str = provider_config["embedding_api_base"]
2223
timeout: int = int(provider_config.get("timeout", 20))
2324

2425
http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,19 +42,27 @@ async def get_embedding(self, text: str) -> list[float]:
4142
result = await self.client.models.embed_content(
4243
model=self.model, contents=text
4344
)
45+
assert result.embeddings is not None
46+
assert result.embeddings[0].values is not None
4447
return result.embeddings[0].values
4548
except APIError as e:
4649
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
4750

48-
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
51+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
4952
"""
5053
批量获取文本的嵌入
5154
"""
5255
try:
5356
result = await self.client.models.embed_content(
54-
model=self.model, contents=texts
57+
model=self.model, contents=cast(types.ContentListUnion, text)
5558
)
56-
return [embedding.values for embedding in result.embeddings]
59+
assert result.embeddings is not None
60+
61+
embeddings: list[list[float]] = []
62+
for embedding in result.embeddings:
63+
assert embedding.values is not None
64+
embeddings.append(embedding.values)
65+
return embeddings
5766
except APIError as e:
5867
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
5968

astrbot/core/provider/sources/gemini_source.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import random
66
from collections.abc import AsyncGenerator
7+
from typing import cast
78

89
from google import genai
910
from google.genai import types
@@ -138,7 +139,7 @@ async def _prepare_query_config(
138139
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
139140
modalities = ["Text"]
140141

141-
tool_list = []
142+
tool_list: list[types.Tool] | None = []
142143
model_name = self.get_model()
143144
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
144145
native_search = self.provider_config.get("gm_native_search", False)
@@ -215,7 +216,7 @@ async def _prepare_query_config(
215216
logprobs=payloads.get("logprobs"),
216217
seed=payloads.get("seed"),
217218
response_modalities=modalities,
218-
tools=tool_list,
219+
tools=cast(types.ToolListUnion | None, tool_list),
219220
safety_settings=self.safety_settings if self.safety_settings else None,
220221
thinking_config=(
221222
types.ThinkingConfig(
@@ -258,6 +259,7 @@ def append_or_extend(
258259
content_cls: type[types.Content],
259260
) -> None:
260261
if contents and isinstance(contents[-1], content_cls):
262+
assert contents[-1].parts is not None
261263
contents[-1].parts.extend(part)
262264
else:
263265
contents.append(content_cls(parts=part))
@@ -413,7 +415,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
413415
)
414416
result = await self.client.models.generate_content(
415417
model=self.get_model(),
416-
contents=conversation,
418+
contents=cast(types.ContentListUnion, conversation),
417419
config=config,
418420
)
419421

@@ -483,7 +485,7 @@ async def _query_stream(
483485
)
484486
result = await self.client.models.generate_content_stream(
485487
model=self.get_model(),
486-
contents=conversation,
488+
contents=cast(types.ContentListUnion, conversation),
487489
config=config,
488490
)
489491
break

astrbot/core/provider/sources/minimax_tts_api_source.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _build_tts_stream_body(self, text: str):
8080

8181
return json.dumps(dict_body)
8282

83-
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
83+
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
8484
"""进行流式请求"""
8585
try:
8686
async with aiohttp.ClientSession() as session:
@@ -108,7 +108,9 @@ async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
108108
data = json.loads(message[6:])
109109
if "extra_info" in data:
110110
continue
111-
audio = data.get("data", {}).get("audio")
111+
audio: str | None = data.get("data", {}).get(
112+
"audio"
113+
)
112114
if audio is not None:
113115
yield audio
114116
except json.JSONDecodeError:

0 commit comments

Comments
 (0)