From 38f74fb913194ea52839b439b6ca4648989c7428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=B0=91=E6=9D=B0?= Date: Tue, 17 Mar 2026 15:57:34 +0800 Subject: [PATCH 1/2] fix: add configurable models for tts and asr --- app/api/generate/tts/route.ts | 6 +- app/api/transcription/route.ts | 2 + app/generation-preview/page.tsx | 1 + components/audio/tts-config-popover.tsx | 4 +- components/generation/media-popover.tsx | 4 +- components/settings/asr-settings.tsx | 216 +++++++++++++++++++++++- components/settings/tts-settings.tsx | 216 +++++++++++++++++++++++- lib/audio/asr-providers.ts | 8 +- lib/audio/constants.ts | 24 +++ lib/audio/tts-providers.ts | 10 +- lib/audio/types.ts | 4 + lib/hooks/use-audio-recorder.ts | 4 +- lib/hooks/use-scene-generator.ts | 1 + lib/store/settings.ts | 45 ++++- 14 files changed, 525 insertions(+), 20 deletions(-) diff --git a/app/api/generate/tts/route.ts b/app/api/generate/tts/route.ts index 542f105b..b5489552 100644 --- a/app/api/generate/tts/route.ts +++ b/app/api/generate/tts/route.ts @@ -22,10 +22,11 @@ export const maxDuration = 30; export async function POST(req: NextRequest) { try { const body = await req.json(); - const { text, audioId, ttsProviderId, ttsVoice, ttsSpeed, ttsApiKey, ttsBaseUrl } = body as { + const { text, audioId, ttsProviderId, ttsModelId, ttsVoice, ttsSpeed, ttsApiKey, ttsBaseUrl } = body as { text: string; audioId: string; ttsProviderId: TTSProviderId; + ttsModelId?: string; ttsVoice: string; ttsSpeed?: number; ttsApiKey?: string; @@ -64,6 +65,7 @@ export async function POST(req: NextRequest) { // Build TTS config const config = { providerId: ttsProviderId, + modelId: ttsModelId, voice: ttsVoice, speed: ttsSpeed ?? 1.0, apiKey, @@ -71,7 +73,7 @@ export async function POST(req: NextRequest) { }; log.info( - `Generating TTS: provider=${ttsProviderId}, voice=${ttsVoice}, audioId=${audioId}, textLen=${text.length}`, + `Generating TTS: provider=${ttsProviderId}, model=${ttsModelId || 'default'}, voice=${ttsVoice}, audioId=${audioId}, textLen=${text.length}`, ); // Generate audio diff --git a/app/api/transcription/route.ts b/app/api/transcription/route.ts index c3bf5a27..8b0cbd8a 100644 --- a/app/api/transcription/route.ts +++ b/app/api/transcription/route.ts @@ -14,6 +14,7 @@ export async function POST(req: NextRequest) { const formData = await req.formData(); const audioFile = formData.get('audio') as File; const providerId = formData.get('providerId') as ASRProviderId | null; + const modelId = formData.get('modelId') as string | null; const language = formData.get('language') as string | null; const apiKey = formData.get('apiKey') as string | null; const baseUrl = formData.get('baseUrl') as string | null; @@ -35,6 +36,7 @@ export async function POST(req: NextRequest) { const config = { providerId: effectiveProviderId, + modelId: modelId || undefined, language: language || 'auto', apiKey: clientBaseUrl ? apiKey || '' diff --git a/app/generation-preview/page.tsx b/app/generation-preview/page.tsx index 213a5140..ec703809 100644 --- a/app/generation-preview/page.tsx +++ b/app/generation-preview/page.tsx @@ -666,6 +666,7 @@ function GenerationPreviewContent() { text: action.text, audioId, ttsProviderId: settings.ttsProviderId, + ttsModelId: settings.ttsModelId, ttsVoice: settings.ttsVoice, ttsSpeed: settings.ttsSpeed, ttsApiKey: ttsProviderConfig?.apiKey || undefined, diff --git a/components/audio/tts-config-popover.tsx b/components/audio/tts-config-popover.tsx index bda09055..3cb09d68 100644 --- a/components/audio/tts-config-popover.tsx +++ b/components/audio/tts-config-popover.tsx @@ -35,6 +35,7 @@ export function TtsConfigPopover() { const ttsEnabled = useSettingsStore((s) => s.ttsEnabled); const setTTSEnabled = useSettingsStore((s) => s.setTTSEnabled); const ttsProviderId = useSettingsStore((s) => s.ttsProviderId); + const ttsModelId = useSettingsStore((s) => s.ttsModelId); const ttsVoice = useSettingsStore((s) => s.ttsVoice); const ttsProvidersConfig = useSettingsStore((s) => s.ttsProvidersConfig); const setTTSVoice = useSettingsStore((s) => s.setTTSVoice); @@ -70,6 +71,7 @@ export function TtsConfigPopover() { text: '你好,欢迎来到AI课堂!让我们一起学习吧。', audioId: 'preview', ttsProviderId: ttsProviderId, + ttsModelId: ttsModelId, ttsVoice: ttsVoice, ttsApiKey: providerConfig?.apiKey, ttsBaseUrl: providerConfig?.baseUrl, @@ -95,7 +97,7 @@ export function TtsConfigPopover() { } catch { setPreviewing(false); } - }, [ttsProviderId, ttsVoice, ttsProvidersConfig, previewing]); + }, [ttsProviderId, ttsModelId, ttsVoice, ttsProvidersConfig, previewing]); return ( diff --git a/components/generation/media-popover.tsx b/components/generation/media-popover.tsx index 0a46230f..460cc327 100644 --- a/components/generation/media-popover.tsx +++ b/components/generation/media-popover.tsx @@ -101,6 +101,7 @@ export function MediaPopover({ onSettingsOpen }: MediaPopoverProps) { const setVideoModelId = useSettingsStore((s) => s.setVideoModelId); const ttsProviderId = useSettingsStore((s) => s.ttsProviderId); + const ttsModelId = useSettingsStore((s) => s.ttsModelId); const ttsVoice = useSettingsStore((s) => s.ttsVoice); const ttsSpeed = useSettingsStore((s) => s.ttsSpeed); const ttsProvidersConfig = useSettingsStore((s) => s.ttsProvidersConfig); @@ -198,6 +199,7 @@ export function MediaPopover({ onSettingsOpen }: MediaPopoverProps) { text: '你好,欢迎来到AI课堂!让我们一起学习吧。', audioId: 'preview', ttsProviderId, + ttsModelId, ttsVoice, ttsApiKey: providerConfig?.apiKey, ttsBaseUrl: providerConfig?.baseUrl, @@ -221,7 +223,7 @@ export function MediaPopover({ onSettingsOpen }: MediaPopoverProps) { } catch { setPreviewing(false); } - }, [ttsProviderId, ttsVoice, ttsProvidersConfig, previewing]); + }, [ttsProviderId, ttsModelId, ttsVoice, ttsProvidersConfig, previewing]); // ASR: only available providers const asrGroups = useMemo( diff --git a/components/settings/asr-settings.tsx b/components/settings/asr-settings.tsx index 2281e7cf..99f301ff 100644 --- a/components/settings/asr-settings.tsx +++ b/components/settings/asr-settings.tsx @@ -1,14 +1,27 @@ 'use client'; -import { useState, useRef } from 'react'; +import { useState, useRef, useEffect, useCallback, useMemo } from 'react'; import { Label } from '@/components/ui/label'; import { Input } from '@/components/ui/input'; import { Button } from '@/components/ui/button'; +import { Dialog, DialogContent, DialogDescription, DialogTitle } from '@/components/ui/dialog'; import { useI18n } from '@/lib/hooks/use-i18n'; import { useSettingsStore } from '@/lib/store/settings'; import { ASR_PROVIDERS } from '@/lib/audio/constants'; import type { ASRProviderId } from '@/lib/audio/types'; -import { Mic, MicOff, CheckCircle2, XCircle, Eye, EyeOff } from 'lucide-react'; +import { + Mic, + MicOff, + CheckCircle2, + XCircle, + Eye, + EyeOff, + Plus, + Settings2, + Trash2, + Circle, + CircleDot, +} from 'lucide-react'; import { cn } from '@/lib/utils'; import { createLogger } from '@/lib/logger'; @@ -21,11 +34,18 @@ interface ASRSettingsProps { export function ASRSettings({ selectedProviderId }: ASRSettingsProps) { const { t } = useI18n(); + const asrModelId = useSettingsStore((state) => state.asrModelId); const asrLanguage = useSettingsStore((state) => state.asrLanguage); const asrProvidersConfig = useSettingsStore((state) => state.asrProvidersConfig); const setASRProviderConfig = useSettingsStore((state) => state.setASRProviderConfig); + const setASRModelId = useSettingsStore((state) => state.setASRModelId); const asrProvider = ASR_PROVIDERS[selectedProviderId] ?? ASR_PROVIDERS['openai-whisper']; + const builtInModels = useMemo(() => asrProvider.models || [], [asrProvider.models]); + const customModels = useMemo( + () => asrProvidersConfig[selectedProviderId]?.customModels || [], + [selectedProviderId, asrProvidersConfig], + ); const isServerConfigured = !!asrProvidersConfig[selectedProviderId]?.isServerConfigured; const [showApiKey, setShowApiKey] = useState(false); @@ -33,6 +53,9 @@ export function ASRSettings({ selectedProviderId }: ASRSettingsProps) { const [asrResult, setASRResult] = useState(''); const [testStatus, setTestStatus] = useState<'idle' | 'testing' | 'success' | 'error'>('idle'); const [testMessage, setTestMessage] = useState(''); + const [showModelDialog, setShowModelDialog] = useState(false); + const [editingModelIndex, setEditingModelIndex] = useState(null); + const [modelForm, setModelForm] = useState({ id: '', name: '' }); const mediaRecorderRef = useRef(null); // Reset state when provider changes (derived state pattern) @@ -45,6 +68,63 @@ export function ASRSettings({ selectedProviderId }: ASRSettingsProps) { setASRResult(''); } + useEffect(() => { + const availableModelIds = new Set([ + ...builtInModels.map((model) => model.id), + ...customModels.map((model) => model.id), + ]); + if (availableModelIds.size > 0 && !availableModelIds.has(asrModelId)) { + const nextModelId = builtInModels[0]?.id || customModels[0]?.id || ''; + if (nextModelId) setASRModelId(nextModelId); + } + }, [asrModelId, builtInModels, customModels, setASRModelId]); + + const handleOpenAddModel = () => { + setEditingModelIndex(null); + setModelForm({ id: '', name: '' }); + setShowModelDialog(true); + }; + + const handleOpenEditModel = (index: number) => { + setEditingModelIndex(index); + setModelForm({ ...customModels[index] }); + setShowModelDialog(true); + }; + + const handleSaveModel = useCallback(() => { + if (!modelForm.id.trim()) return; + const nextCustomModels = [...customModels]; + const normalizedModel = { + id: modelForm.id.trim(), + name: modelForm.name.trim() || modelForm.id.trim(), + }; + if (editingModelIndex !== null) { + nextCustomModels[editingModelIndex] = normalizedModel; + } else { + nextCustomModels.push(normalizedModel); + } + setASRProviderConfig(selectedProviderId, { customModels: nextCustomModels }); + setASRModelId(normalizedModel.id); + setShowModelDialog(false); + }, [ + customModels, + editingModelIndex, + modelForm, + selectedProviderId, + setASRModelId, + setASRProviderConfig, + ]); + + const handleDeleteModel = (index: number) => { + const targetModel = customModels[index]; + const nextCustomModels = customModels.filter((_, i) => i !== index); + setASRProviderConfig(selectedProviderId, { customModels: nextCustomModels }); + if (asrModelId === targetModel?.id) { + const nextModelId = builtInModels[0]?.id || nextCustomModels[0]?.id || ''; + setASRModelId(nextModelId); + } + }; + const handleToggleASRRecording = async () => { if (isRecording) { if (mediaRecorderRef.current && mediaRecorderRef.current.state === 'recording') { @@ -104,6 +184,7 @@ export function ASRSettings({ selectedProviderId }: ASRSettingsProps) { const formData = new FormData(); formData.append('audio', audioBlob, 'recording.webm'); formData.append('providerId', selectedProviderId); + formData.append('modelId', asrModelId); formData.append('language', asrLanguage); const apiKeyValue = asrProvidersConfig[selectedProviderId]?.apiKey; if (apiKeyValue?.trim()) formData.append('apiKey', apiKeyValue); @@ -271,6 +352,137 @@ export function ASRSettings({ selectedProviderId }: ASRSettingsProps) { )} + +
+
+ + +
+ +
+ {builtInModels.map((model) => { + const selected = asrModelId === model.id; + return ( + + ); + })} + + {customModels.map((model, index) => { + const selected = asrModelId === model.id; + return ( +
+ +
+ + +
+
+ ); + })} +
+
+ + + + + {editingModelIndex !== null ? t('settings.editModel') : t('settings.addNewModel')} + + + {editingModelIndex !== null ? t('settings.editModel') : t('settings.addNewModel')} + +
+
+ + setModelForm((prev) => ({ ...prev, id: e.target.value }))} + placeholder="e.g. my-custom-asr-model" + className="h-8 font-mono text-sm" + /> +
+
+ + setModelForm((prev) => ({ ...prev, name: e.target.value }))} + placeholder="e.g. My Custom ASR Model" + className="h-8 text-sm" + /> +
+
+ + +
+
+
+
); } diff --git a/components/settings/tts-settings.tsx b/components/settings/tts-settings.tsx index 45a03f51..a77d226d 100644 --- a/components/settings/tts-settings.tsx +++ b/components/settings/tts-settings.tsx @@ -1,14 +1,27 @@ 'use client'; -import { useState, useRef, useEffect } from 'react'; +import { useState, useRef, useEffect, useCallback, useMemo } from 'react'; import { Label } from '@/components/ui/label'; import { Input } from '@/components/ui/input'; import { Button } from '@/components/ui/button'; +import { Dialog, DialogContent, DialogDescription, DialogTitle } from '@/components/ui/dialog'; import { useI18n } from '@/lib/hooks/use-i18n'; import { useSettingsStore } from '@/lib/store/settings'; import { TTS_PROVIDERS, DEFAULT_TTS_VOICES } from '@/lib/audio/constants'; import type { TTSProviderId } from '@/lib/audio/types'; -import { Volume2, Loader2, CheckCircle2, XCircle, Eye, EyeOff } from 'lucide-react'; +import { + Volume2, + Loader2, + CheckCircle2, + XCircle, + Eye, + EyeOff, + Plus, + Settings2, + Trash2, + Circle, + CircleDot, +} from 'lucide-react'; import { cn } from '@/lib/utils'; import { createLogger } from '@/lib/logger'; @@ -21,10 +34,12 @@ interface TTSSettingsProps { export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { const { t } = useI18n(); + const ttsModelId = useSettingsStore((state) => state.ttsModelId); const ttsVoice = useSettingsStore((state) => state.ttsVoice); const ttsSpeed = useSettingsStore((state) => state.ttsSpeed); const ttsProvidersConfig = useSettingsStore((state) => state.ttsProvidersConfig); const setTTSProviderConfig = useSettingsStore((state) => state.setTTSProviderConfig); + const setTTSModelId = useSettingsStore((state) => state.setTTSModelId); const activeProviderId = useSettingsStore((state) => state.ttsProviderId); // When testing a non-active provider, use that provider's default voice @@ -35,6 +50,11 @@ export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { : DEFAULT_TTS_VOICES[selectedProviderId] || 'default'; const ttsProvider = TTS_PROVIDERS[selectedProviderId] ?? TTS_PROVIDERS['openai-tts']; + const builtInModels = useMemo(() => ttsProvider.models || [], [ttsProvider.models]); + const customModels = useMemo( + () => ttsProvidersConfig[selectedProviderId]?.customModels || [], + [selectedProviderId, ttsProvidersConfig], + ); const isServerConfigured = !!ttsProvidersConfig[selectedProviderId]?.isServerConfigured; const [showApiKey, setShowApiKey] = useState(false); @@ -42,6 +62,9 @@ export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { const [testText, setTestText] = useState(t('settings.ttsTestTextDefault')); const [testStatus, setTestStatus] = useState<'idle' | 'testing' | 'success' | 'error'>('idle'); const [testMessage, setTestMessage] = useState(''); + const [showModelDialog, setShowModelDialog] = useState(false); + const [editingModelIndex, setEditingModelIndex] = useState(null); + const [modelForm, setModelForm] = useState({ id: '', name: '' }); const audioRef = useRef(null); // Update test text when language changes @@ -56,6 +79,63 @@ export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { setTestMessage(''); }, [selectedProviderId]); + useEffect(() => { + const availableModelIds = new Set([ + ...builtInModels.map((model) => model.id), + ...customModels.map((model) => model.id), + ]); + if (availableModelIds.size > 0 && !availableModelIds.has(ttsModelId)) { + const nextModelId = builtInModels[0]?.id || customModels[0]?.id || ''; + if (nextModelId) setTTSModelId(nextModelId); + } + }, [builtInModels, customModels, ttsModelId, setTTSModelId]); + + const handleOpenAddModel = () => { + setEditingModelIndex(null); + setModelForm({ id: '', name: '' }); + setShowModelDialog(true); + }; + + const handleOpenEditModel = (index: number) => { + setEditingModelIndex(index); + setModelForm({ ...customModels[index] }); + setShowModelDialog(true); + }; + + const handleSaveModel = useCallback(() => { + if (!modelForm.id.trim()) return; + const nextCustomModels = [...customModels]; + const normalizedModel = { + id: modelForm.id.trim(), + name: modelForm.name.trim() || modelForm.id.trim(), + }; + if (editingModelIndex !== null) { + nextCustomModels[editingModelIndex] = normalizedModel; + } else { + nextCustomModels.push(normalizedModel); + } + setTTSProviderConfig(selectedProviderId, { customModels: nextCustomModels }); + setTTSModelId(normalizedModel.id); + setShowModelDialog(false); + }, [ + customModels, + editingModelIndex, + modelForm, + selectedProviderId, + setTTSModelId, + setTTSProviderConfig, + ]); + + const handleDeleteModel = (index: number) => { + const targetModel = customModels[index]; + const nextCustomModels = customModels.filter((_, i) => i !== index); + setTTSProviderConfig(selectedProviderId, { customModels: nextCustomModels }); + if (ttsModelId === targetModel?.id) { + const nextModelId = builtInModels[0]?.id || nextCustomModels[0]?.id || ''; + setTTSModelId(nextModelId); + } + }; + const handleTestTTS = async () => { if (!testText.trim()) return; setTestingTTS(true); @@ -93,6 +173,7 @@ export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { text: testText, audioId: 'tts-test', ttsProviderId: selectedProviderId, + ttsModelId, ttsVoice: effectiveVoice, ttsSpeed: ttsSpeed, }; @@ -268,6 +349,137 @@ export function TTSSettings({ selectedProviderId }: TTSSettingsProps) { )} +
+
+ + +
+ +
+ {builtInModels.map((model) => { + const selected = ttsModelId === model.id; + return ( + + ); + })} + + {customModels.map((model, index) => { + const selected = ttsModelId === model.id; + return ( +
+ +
+ + +
+
+ ); + })} +
+
+ + + + + {editingModelIndex !== null ? t('settings.editModel') : t('settings.addNewModel')} + + + {editingModelIndex !== null ? t('settings.editModel') : t('settings.addNewModel')} + +
+
+ + setModelForm((prev) => ({ ...prev, id: e.target.value }))} + placeholder="e.g. my-custom-tts-model" + className="h-8 font-mono text-sm" + /> +
+
+ + setModelForm((prev) => ({ ...prev, name: e.target.value }))} + placeholder="e.g. My Custom TTS Model" + className="h-8 text-sm" + /> +
+
+ + +
+
+
+
+