Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(LLM config): Restructure LLM configuration and management #276

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
# * Settings marked with * are advanced settings that won't appear in the Streamlit page and can only be modified manually in config.yaml
version: "2.0.0"
## ======================== Basic Settings ======================== ##
# API settings
api:
key: 'YOUR_API_KEY'
base_url: 'https://api.302.ai'
model: 'claude-3-5-sonnet-20240620'
# LLM model configuration section
llm_models:
default_model: # default model configuration
key: 'YOUR_API_KEY'
base_url: 'https://api.302.ai'
model: 'claude-3-5-sonnet-20240620'
fast_model: # fast model configuration
key: 'YOUR_API_KEY'
base_url: 'https://api.302.ai'
model: 'Yi-Lightning'

# Model used for each pipeline stage
llm_stages:
align: 'fast_model' # subtitle alignment
split: 'fast_model' # subtitle splitting
summarize: 'default_model' # subtitle summarization
translate_faithfulness: 'default_model' # translation (faithful)
translate_expressiveness: 'default_model' # translation (expressive)
reduce: 'fast_model' # subtitle reduction

# Language settings, written into the prompt, can be described in natural language
target_language: '简体中文'
Expand Down Expand Up @@ -34,7 +48,7 @@ subtitle:
target_multiplier: 1.2

# *Number of LLM multi-threaded accesses
max_workers: 5
max_workers: 8
# *Maximum number of words for the first rough cut, below 18 will cut too finely affecting translation, above 22 is too long and will make subsequent subtitle splitting difficult to align
max_split_length: 20

Expand Down Expand Up @@ -129,6 +143,7 @@ llm_support_json:
- 'gemini-1.5-flash-latest'
- 'gemini-1.5-pro-latest'
- 'gemini-1.5-pro-002'
- 'Yi-Lightning'

# have problems
# - 'Qwen/Qwen2.5-72B-Instruct'
Expand Down
25 changes: 20 additions & 5 deletions core/ask_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def save_log(model, prompt, response, log_title = 'default', message = None):
with open(log_file, 'w', encoding='utf-8') as f:
json.dump(logs, f, ensure_ascii=False, indent=4)

def check_ask_gpt_history(prompt, model, log_title):
def check_ask_gpt_history(prompt, log_title):
# check if the prompt has been asked before
if not os.path.exists(LOG_FOLDER):
return False
Expand All @@ -43,11 +43,26 @@ def check_ask_gpt_history(prompt, model, log_title):
return item["response"]
return False

def ask_gpt(prompt, response_json=True, valid_def=None, log_title='default'):
api_set = load_key("api")
def ask_gpt(prompt, response_json=True, valid_def=None, log_title='default', check_api=False, api_set=None):
llm_support_json = load_key("llm_support_json")
if check_api:
api_set = api_set
else:
if log_title == 'sentence_splitbymeaning':
api_set = load_key(f"llm_models.{load_key('llm_stages')['split']}")
elif log_title == 'align_subs':
api_set = load_key(f"llm_models.{load_key('llm_stages')['align']}")
elif log_title == 'summary':
api_set = load_key(f"llm_models.{load_key('llm_stages')['summarize']}")
elif log_title == 'translate_faithfulness':
api_set = load_key(f"llm_models.{load_key('llm_stages')['translate_faithfulness']}")
elif log_title == 'translate_expressiveness':
api_set = load_key(f"llm_models.{load_key('llm_stages')['translate_expressiveness']}")
elif log_title == 'subtitle_trim':
api_set = load_key(f"llm_models.{load_key('llm_stages')['reduce']}")

with LOCK:
history_response = check_ask_gpt_history(prompt, api_set["model"], log_title)
history_response = check_ask_gpt_history(prompt, log_title)
if history_response:
return history_response

Expand Down Expand Up @@ -109,4 +124,4 @@ def ask_gpt(prompt, response_json=True, valid_def=None, log_title='default'):


if __name__ == '__main__':
print(ask_gpt('hi there hey response in json format, just return 200.' , response_json=True, log_title=None))
print(ask_gpt('hi there hey response in json format, just return 200.' , response_json=True, log_title='None', api_set=load_key("llm_models.default_model"), check_api=True))
36 changes: 24 additions & 12 deletions core/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,37 @@ def load_key(key: str) -> Any:
return value

def update_key(key: str, new_value: Any) -> bool:
    """Update a key in the config file. If the key doesn't exist, it will be created.

    Args:
        key: Dot-separated key path (e.g. "llm_models.default_model.key")
        new_value: Value to set

    Returns:
        bool: True if successful

    Raises:
        TypeError: If an intermediate segment of the key path already exists
            but is not a mapping, so the path cannot be descended into.
    """
    with config_lock:
        with open(CONFIG_PATH, 'r', encoding='utf-8') as file:
            data = yaml.load(file)

        keys = key.split('.')
        current = data

        # Walk every key except the last, creating missing dicts on the way
        for k in keys[:-1]:
            if k not in current:
                # Key does not exist yet: create an empty mapping for it
                current[k] = {}
            elif not isinstance(current[k], dict):
                # An existing scalar blocks the path; fail with a clear message
                # instead of an obscure TypeError on the assignment below.
                raise TypeError(
                    f"Cannot descend into '{k}' in key path '{key}': existing value is not a mapping"
                )
            current = current[k]

        # Set the final value
        current[keys[-1]] = new_value

        # Persist the updated configuration
        with open(CONFIG_PATH, 'w', encoding='utf-8') as file:
            yaml.dump(data, file)
        return True


# basic utils
def get_joiner(language):
Expand Down
29 changes: 23 additions & 6 deletions i18n/中文/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
# * 标有 * 的设置是高级设置,不会出现在 Streamlit 页面中,只能在 config.yaml 中手动修改
version: "2.0.0"
## ======================== 基本设置 ======================== ##
# API 设置
api:
key: 'YOUR_API_KEY'
base_url: 'https://api.siliconflow.cn'
model: 'Qwen/Qwen2.5-72B-Instruct'
# LLM 模型配置部分
llm_models:
default_model: # 默认模型配置
key: 'YOUR_API_KEY'
base_url: 'https://api.302.ai'
model: 'claude-3-5-sonnet-20240620'
fast_model: # 快速模型配置
key: 'YOUR_API_KEY'
base_url: 'https://api.302.ai'
model: 'Yi-Lightning'


# 各阶段使用的模型配置
llm_stages:
align: 'fast_model' # 字幕对齐
split: 'fast_model' # 字幕分割
summarize: 'default_model' # 字幕总结
translate_faithfulness: 'default_model' # 字幕翻译(精确)
translate_expressiveness: 'default_model' # 字幕翻译(信达雅)
reduce: 'fast_model' # 字幕缩减


# 语言设置,写入提示词,可以用自然语言描述
target_language: '简体中文'
Expand Down Expand Up @@ -129,6 +145,7 @@ llm_support_json:
- 'gemini-1.5-flash-latest'
- 'gemini-1.5-pro-latest'
- 'gemini-1.5-pro-002'
- 'Yi-Lightning'

# 存在问题
# - 'Qwen/Qwen2.5-72B-Instruct'
Expand Down Expand Up @@ -158,4 +175,4 @@ language_split_with_space:
# 不使用空格作为分隔符的语言
language_split_without_space:
- 'zh'
- 'ja'
- 'ja'
99 changes: 86 additions & 13 deletions i18n/中文/st_components/sidebar_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,91 @@

def config_input(label, key, help=None):
    """Render a sidebar text input bound to a config key and persist edits.

    Args:
        label: Widget label shown in the sidebar.
        key: Dot-separated config key (e.g. "llm_models.default_model.key").
        help: Optional tooltip text for the widget.

    Returns:
        The current value of the input widget.
    """
    # Read the config value once instead of twice — each load_key call
    # re-reads the YAML file from disk.
    current = load_key(key)
    val = st.text_input(
        label,
        value=current,
        help=help,
        key=f"config_input_{key}"
    )
    # Only write back when the user actually changed the value
    if val != current:
        update_key(key, val)
    return val

def page_setting():
with st.expander("LLM 配置", expanded=True):
config_input("API_KEY", "api.key")
config_input("BASE_URL", "api.base_url", help="API请求的基础URL")
# LLM配置部分
st.subheader("LLM配置")

# 模型配置管理
st.subheader("模型配置")

# 获取现有模型列表
models = load_key("llm_models")

# 添加新模型按钮
if st.button("➕ 添加新模型"):
new_model_name = f"model_{len(models)}"
update_key(f"llm_models.{new_model_name}", {
"key": "",
"base_url": "",
"model": ""
})
st.rerun()

# 显示现有模型配置
for model_name in models:
st.markdown(f"### 📑 {model_name}")
# 模型名称编辑(可选)
new_name = st.text_input("模型名称", value=model_name, key=f"name_{model_name}")

c1, c2 = st.columns([4, 1])
with c1:
config_input("模型", "api.model")
# 基础配置
config_input(f"API密钥", f"llm_models.{model_name}.key")
config_input(f"基础URL", f"llm_models.{model_name}.base_url")
config_input(f"模型", f"llm_models.{model_name}.model")

# 测试和删除按钮
c1, c2, c3 = st.columns([3, 1, 1])
with c2:
if st.button("📡", key="api"):
st.toast("API密钥有效" if check_api() else "API密钥无效",
icon="✅" if check_api() else "❌")
if st.button("🔍 测试", key=f"test_{model_name}"):
api_set = load_key(f"llm_models.{model_name}")
st.toast(
"API密钥有效" if check_api(api_set) else "API密钥无效",
icon="✅" if check_api(api_set) else "❌"
)
with c3:
if len(models) > 1 and st.button("🗑️ 删除", key=f"delete_{model_name}"):
models_dict = load_key("llm_models")
del models_dict[model_name]
update_key("llm_models", models_dict)
st.rerun()
st.divider()

# 阶段配置
st.subheader("阶段配置")
stages = {
"align": "字幕对齐",
"split": "字幕分割",
"summarize": "总结",
"translate_faithfulness": "翻译(精确)",
"translate_expressiveness": "翻译(优雅)",
"reduce": "字幕缩减"
}

for stage_key, stage_name in stages.items():
col1, col2 = st.columns([3, 2])
with col1:
st.write(stage_name)
with col2:
current_model = load_key(f"llm_stages.{stage_key}")
selected_model = st.selectbox(
"模型",
options=list(models.keys()),
index=list(models.keys()).index(current_model),
key=f"stage_{stage_key}",
label_visibility="collapsed"
)
if selected_model != current_model:
update_key(f"llm_stages.{stage_key}", selected_model)

with st.expander("转写和字幕设置", expanded=True):
c1, c2 = st.columns(2)
with c1:
Expand Down Expand Up @@ -131,10 +198,16 @@ def page_setting():
if selected_refer_mode != load_key("gpt_sovits.refer_mode"):
update_key("gpt_sovits.refer_mode", selected_refer_mode)

def check_api(api_set):
    """Check whether the given API configuration is valid.

    Sends a trivial JSON-format test prompt through ask_gpt with the
    supplied credentials and looks for the expected success marker.
    """
    test_prompt = "This is a test, response 'message':'success' in json format."
    try:
        resp = ask_gpt(
            test_prompt,
            response_json=True,
            log_title='None',
            check_api=True,
            api_set=api_set
        )
        return resp.get('message') == 'success'
    except Exception:
        # Any failure (network, auth, malformed response) means the config is invalid
        return False
Loading