diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 2097212e..7ddd8eec 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -9,6 +9,7 @@ use OCA\OpenAi\Capabilities; use OCA\OpenAi\OldProcessing\Translation\TranslationProvider as OldTranslationProvider; +use OCA\OpenAi\TaskProcessing\AudioToAudioChatProvider; use OCA\OpenAi\TaskProcessing\AudioToTextProvider; use OCA\OpenAi\TaskProcessing\ChangeToneProvider; use OCA\OpenAi\TaskProcessing\ChangeToneTaskType; @@ -131,6 +132,22 @@ public function register(IRegistrationContext $context): void { $context->registerTaskProcessingProvider(TextToImageProvider::class); } + // only register audio chat stuff if we're using OpenAI or stt+llm+tts are enabled + $serviceUrl = $this->appConfig->getValueString(Application::APP_ID, 'url'); + $isUsingOpenAI = $serviceUrl === '' || $serviceUrl === Application::OPENAI_API_BASE_URL; + if ( + $isUsingOpenAI + || ( + $this->appConfig->getValueString(Application::APP_ID, 'stt_provider_enabled', '1') === '1' + && $this->appConfig->getValueString(Application::APP_ID, 'llm_provider_enabled', '1') === '1' + && $this->appConfig->getValueString(Application::APP_ID, 'tts_provider_enabled', '1') === '1' + ) + ) { + if (class_exists('OCP\\TaskProcessing\\TaskTypes\\AudioToAudioChat')) { + $context->registerTaskProcessingProvider(AudioToAudioChatProvider::class); + } + } + $context->registerCapability(Capabilities::class); } diff --git a/lib/Service/OpenAiAPIService.php b/lib/Service/OpenAiAPIService.php index 0ebe187f..1c4bede7 100644 --- a/lib/Service/OpenAiAPIService.php +++ b/lib/Service/OpenAiAPIService.php @@ -437,7 +437,8 @@ public function createCompletion( * @param array|null $extraParams * @param string|null $toolMessage JSON string with role, content, tool_call_id * @param array|null $tools - * @return array{messages: array, tool_calls: array} + * @param string|null $userAudioPromptBase64 + * @return array{messages: array, tool_calls: 
array, audio_messages: list>} * @throws Exception */ public function createChatCompletion( @@ -451,6 +452,7 @@ public function createChatCompletion( ?array $extraParams = null, ?string $toolMessage = null, ?array $tools = null, + ?string $userAudioPromptBase64 = null, ): array { if ($this->isQuotaExceeded($userId, Application::QUOTA_TYPE_TEXT)) { throw new Exception($this->l10n->t('Text generation quota exceeded'), Http::STATUS_TOO_MANY_REQUESTS); @@ -494,8 +496,24 @@ public function createChatCompletion( $messages[] = $message; } } - if ($userPrompt !== null) { - $messages[] = ['role' => 'user', 'content' => $userPrompt]; + if ($userPrompt !== null || $userAudioPromptBase64 !== null) { + $message = ['role' => 'user', 'content' => []]; + if ($userPrompt !== null) { + $message['content'][] = [ + 'type' => 'text', + 'text' => $userPrompt, + ]; + } + if ($userAudioPromptBase64 !== null) { + $message['content'][] = [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $userAudioPromptBase64, + 'format' => 'mp3', + ], + ]; + } + $messages[] = $message; } if ($toolMessage !== null) { $msgs = json_decode($toolMessage, true); @@ -555,6 +573,7 @@ public function createChatCompletion( $completions = [ 'messages' => [], 'tool_calls' => [], + 'audio_messages' => [], ]; foreach ($response['choices'] as $choice) { @@ -583,6 +602,9 @@ public function createChatCompletion( if (isset($choice['message']['content']) && is_string($choice['message']['content'])) { $completions['messages'][] = $choice['message']['content']; } + if (isset($choice['message']['audio'], $choice['message']['audio']['data']) && is_string($choice['message']['audio']['data'])) { + $completions['audio_messages'][] = $choice['message']; + } } return $completions; diff --git a/lib/TaskProcessing/AudioToAudioChatProvider.php b/lib/TaskProcessing/AudioToAudioChatProvider.php new file mode 100644 index 00000000..637843f5 --- /dev/null +++ b/lib/TaskProcessing/AudioToAudioChatProvider.php @@ -0,0 +1,310 @@ 
+openAiAPIService->getServiceName();
+	}
+
+	/**
+	 * @return string ID of the task type handled by this provider (AudioToAudioChat::ID)
+	 */
+	public function getTaskTypeId(): string {
+		return AudioToAudioChat::ID;
+	}
+
+	/**
+	 * @return int Expected runtime, taken from the API service's text processing estimate
+	 */
+	public function getExpectedRuntime(): int {
+		return $this->openAiAPIService->getExpTextProcessingTime();
+	}
+
+	/**
+	 * @return array No enum values are declared for the mandatory inputs
+	 */
+	public function getInputShapeEnumValues(): array {
+		return [];
+	}
+
+	/**
+	 * @return array No defaults are declared for the mandatory inputs
+	 */
+	public function getInputShapeDefaults(): array {
+		return [];
+	}
+
+	/**
+	 * Describe the optional inputs of the task.
+	 *
+	 * 'llm_model' and 'voice' are always offered. 'tts_model' and 'speed' are only
+	 * offered when the configured service is not OpenAI.
+	 *
+	 * @return array<string, ShapeDescriptor>
+	 */
+	public function getOptionalInputShape(): array {
+		$ois = [
+			'llm_model' => new ShapeDescriptor(
+				$this->l->t('Completion model'),
+				$this->l->t('The model used to generate the completion'),
+				EShapeType::Enum
+			),
+			'voice' => new ShapeDescriptor(
+				$this->l->t('Output voice'),
+				$this->l->t('The voice used to generate speech'),
+				EShapeType::Enum
+			),
+		];
+		if (!$this->openAiAPIService->isUsingOpenAi()) {
+			$ois['tts_model'] = new ShapeDescriptor(
+				$this->l->t('Text-to-speech model'),
+				$this->l->t('The model used to generate the speech'),
+				EShapeType::Enum
+			);
+			// This branch is only reached when the service is not OpenAI, so the
+			// OpenAI-specific "(Valid values: 0.25-4)" description can never apply here.
+			$ois['speed'] = new ShapeDescriptor(
+				$this->l->t('Speed'),
+				$this->l->t('Speech speed modifier'),
+				EShapeType::Number
+			);
+		}
+		return $ois;
+	}
+
+	/**
+	 * List the selectable values of the optional enum inputs.
+	 *
+	 * Voices come from the 'tts_voices' app config value (a JSON list), falling back to
+	 * Application::DEFAULT_SPEECH_VOICES when it is empty or does not decode to an array.
+	 *
+	 * @return array<string, array>
+	 */
+	public function getOptionalInputShapeEnumValues(): array {
+		$voices = json_decode($this->appConfig->getValueString(Application::APP_ID, 'tts_voices'));
+		// A JSON object or scalar would be truthy but cannot be mapped: fall back to the defaults
+		if (!is_array($voices) || $voices === []) {
+			$voices = Application::DEFAULT_SPEECH_VOICES;
+		}
+		$models = $this->openAiAPIService->getModelEnumValues($this->userId);
+		$enumValues = [
+			'voice' => array_map(static fn ($voice) => new ShapeEnumValue($voice, $voice), $voices),
+			'llm_model' => $models,
+		];
+		if (!$this->openAiAPIService->isUsingOpenAi()) {
+			$enumValues['tts_model'] = $models;
+		}
+		return $enumValues;
+	}
+
+	/**
+	 * Default values of the optional inputs, based on the admin settings.
+	 *
+	 * @return array<string, mixed>
+	 */
+	public function getOptionalInputShapeDefaults(): array {
+		$isUsingOpenAi = $this->openAiAPIService->isUsingOpenAi();
+		$adminVoice = $this->appConfig->getValueString(Application::APP_ID, 'default_speech_voice') ?: Application::DEFAULT_SPEECH_VOICE;
+		$adminLlmModel = $isUsingOpenAi
+			? 'gpt-4o-audio-preview'
+			: $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id');
+		$defaults = [
+			'voice' => $adminVoice,
+			'llm_model' => $adminLlmModel,
+		];
+		if (!$isUsingOpenAi) {
+			$adminTtsModel = $this->appConfig->getValueString(Application::APP_ID, 'default_speech_model_id') ?: Application::DEFAULT_SPEECH_MODEL_ID;
+			$defaults['tts_model'] = $adminTtsModel;
+			$defaults['speed'] = 1;
+		}
+		return $defaults;
+	}
+
+	/**
+	 * @return array No enum values are declared for the mandatory outputs
+	 */
+	public function getOutputShapeEnumValues(): array {
+		return [];
+	}
+
+	/**
+	 * Describe the optional outputs: the remote audio ID and its expiration date.
+	 * They are only set by oneStep() when the remote response provides them.
+	 *
+	 * @return array<string, ShapeDescriptor>
+	 */
+	public function getOptionalOutputShape(): array {
+		return [
+			'audio_id' => new ShapeDescriptor(
+				$this->l->t('Remote audio ID'),
+				$this->l->t('The ID of the audio response returned by the remote service'),
+				EShapeType::Text
+			),
+			'audio_expires_at' => new ShapeDescriptor(
+				$this->l->t('Remote audio expiration date'),
+				$this->l->t('The remote audio response stays available in the service until this date'),
+				EShapeType::Number
+			),
+		];
+	}
+
+	/**
+	 * @return array No enum values are declared for the optional outputs
+	 */
+	public function getOptionalOutputShapeEnumValues(): array {
+		return [];
+	}
+
+	/**
+	 * Run an audio-to-audio chat task.
+	 *
+	 * Mandatory inputs: 'input' (readable File), 'system_prompt' (string), 'history' (array).
+	 * Optional inputs: 'tts_model', 'llm_model', 'voice' (strings) and 'speed' (numeric).
+	 * Missing optional inputs fall back to the admin settings.
+	 *
+	 * @param string|null $userId
+	 * @param array $input
+	 * @param callable $reportProgress Not used by this provider
+	 * @return array Contains 'output', 'output_transcript' and 'input_transcript',
+	 *               plus 'audio_id'/'audio_expires_at' when the remote service returns them
+	 * @throws RuntimeException If an input is invalid or one of the remote requests fails
+	 */
+	public function process(?string $userId, array $input, callable $reportProgress): array {
+		if (!isset($input['input']) || !$input['input'] instanceof File || !$input['input']->isReadable()) {
+			throw new RuntimeException('Invalid input audio file in the "input" field. A readable file is expected.');
+		}
+		$inputFile = $input['input'];
+
+		if (!isset($input['system_prompt']) || !is_string($input['system_prompt'])) {
+			throw new RuntimeException('Invalid system_prompt');
+		}
+		$systemPrompt = $input['system_prompt'];
+
+		if (!isset($input['history']) || !is_array($input['history'])) {
+			throw new RuntimeException('Invalid chat history, array expected');
+		}
+		$history = $input['history'];
+
+		$isUsingOpenAi = $this->openAiAPIService->isUsingOpenAi();
+
+		if (isset($input['tts_model']) && is_string($input['tts_model'])) {
+			$ttsModel = $input['tts_model'];
+		} else {
+			$ttsModel = $this->appConfig->getValueString(Application::APP_ID, 'default_speech_model_id', Application::DEFAULT_SPEECH_MODEL_ID) ?: Application::DEFAULT_SPEECH_MODEL_ID;
+		}
+
+		if (isset($input['llm_model']) && is_string($input['llm_model'])) {
+			$llmModel = $input['llm_model'];
+		} else {
+			$llmModel = $isUsingOpenAi
+				? 'gpt-4o-audio-preview'
+				: ($this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID);
+		}
+
+		if (isset($input['voice']) && is_string($input['voice'])) {
+			$outputVoice = $input['voice'];
+		} else {
+			$outputVoice = $this->appConfig->getValueString(Application::APP_ID, 'default_speech_voice', Application::DEFAULT_SPEECH_VOICE) ?: Application::DEFAULT_SPEECH_VOICE;
+		}
+
+		$speed = 1.0;
+		if (isset($input['speed']) && is_numeric($input['speed'])) {
+			// is_numeric() also accepts numeric strings: cast so the float-typed helpers always get a float
+			$speed = (float)$input['speed'];
+			if ($isUsingOpenAi) {
+				// OpenAI only accepts speeds between 0.25 and 4
+				$speed = max(0.25, min(4.0, $speed));
+			}
+		}
+
+		$sttModel = $this->appConfig->getValueString(Application::APP_ID, 'default_stt_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID;
+		$serviceName = $this->appConfig->getValueString(Application::APP_ID, 'service_name') ?: Application::APP_ID;
+
+		// Using the chat API if connected to OpenAI
+		// there is an issue if the history mostly contains text, the model will answer text even if we add the audio modality
+		if ($isUsingOpenAi) {
+			return $this->oneStep($userId, $systemPrompt, $inputFile, $history, $outputVoice, $sttModel, $llmModel, $ttsModel, $speed, $serviceName);
+		}
+
+		// 3 steps: STT -> LLM -> TTS
+		return $this->threeSteps($userId, $systemPrompt, $inputFile, $history, $outputVoice, $sttModel, $llmModel, $ttsModel, $speed, $serviceName);
+	}
+
+	/**
+	 * Answer with a single chat completion request that receives the audio prompt directly.
+	 *
+	 * If the completion contains no audio message, the last text message is converted to
+	 * speech instead. The input audio is transcribed separately to fill 'input_transcript'.
+	 *
+	 * @return array Contains 'output', 'output_transcript' and 'input_transcript',
+	 *               plus 'audio_id'/'audio_expires_at' when present in the audio message
+	 * @throws RuntimeException If a remote request fails or the completion is empty
+	 */
+	private function oneStep(
+		?string $userId, string $systemPrompt, File $inputFile, array $history, string $outputVoice,
+		string $sttModel, string $llmModel, string $ttsModel, float $speed, string $serviceName,
+	): array {
+		$result = [];
+		$b64Audio = base64_encode($inputFile->getContent());
+		$extraParams = [
+			'modalities' => ['text', 'audio'],
+			'audio' => ['voice' => $outputVoice, 'format' => 'mp3'],
+		];
+		$systemPrompt .= ' Producing text responses will break the user interface. Important: You have multimodal voice capability, and you use voice exclusively to respond.';
+		try {
+			$completion = $this->openAiAPIService->createChatCompletion(
+				$userId, $llmModel, null, $systemPrompt, $history, 1, 1000,
+				$extraParams, null, null, $b64Audio,
+			);
+		} catch (\Exception $e) {
+			// same wrapping as in threeSteps(); the original code (e.g. a quota HTTP status) is preserved
+			throw new RuntimeException($serviceName . ' chat completion request failed: ' . $e->getMessage(), (int)$e->getCode(), $e);
+		}
+		$message = array_pop($completion['audio_messages']);
+		// TODO find a way to force the model to answer with audio when there is only text in the history
+		// https://community.openai.com/t/gpt-4o-audio-preview-responds-in-text-not-audio/1006486/5
+		if ($message === null) {
+			// no audio, TTS the text message
+			$textResponse = array_pop($completion['messages']);
+			if ($textResponse === null) {
+				// neither audio nor text: there is nothing to convert to speech
+				throw new RuntimeException('No completion in ' . $serviceName . ' response.');
+			}
+			$output = $this->synthesizeSpeech($userId, $textResponse, $ttsModel, $outputVoice, $speed, $serviceName);
+		} else {
+			$output = base64_decode($message['audio']['data']);
+			// only audio.data is validated by the API service, the transcript may be missing
+			$textResponse = $message['audio']['transcript'] ?? '';
+			if (isset($message['audio']['id'])) {
+				$result['audio_id'] = $message['audio']['id'];
+			}
+			if (isset($message['audio']['expires_at'])) {
+				$result['audio_expires_at'] = $message['audio']['expires_at'];
+			}
+		}
+		$result['output'] = $output;
+		$result['output_transcript'] = $textResponse;
+
+		// we still want the input transcription
+		try {
+			$result['input_transcript'] = $this->openAiAPIService->transcribeFile($userId, $inputFile, false, $sttModel);
+		} catch (\Exception $e) {
+			$this->logger->warning($serviceName . ' audio input transcription failed with: ' . $e->getMessage(), ['exception' => $e]);
+			throw new RuntimeException($serviceName . ' audio input transcription failed with: ' . $e->getMessage(), 0, $e);
+		}
+
+		return $result;
+	}
+
+	/**
+	 * Answer in 3 steps: speech-to-text, chat completion, then text-to-speech.
+	 *
+	 * @return array Contains 'output', 'output_transcript' and 'input_transcript'
+	 * @throws RuntimeException If a remote request fails or the completion is empty
+	 */
+	private function threeSteps(
+		?string $userId, string $systemPrompt, File $inputFile, array $history, string $outputVoice,
+		string $sttModel, string $llmModel, string $ttsModel, float $speed, string $serviceName,
+	): array {
+		// speech to text
+		try {
+			$inputTranscription = $this->openAiAPIService->transcribeFile($userId, $inputFile, false, $sttModel);
+		} catch (\Exception $e) {
+			$this->logger->warning($serviceName . ' transcription failed with: ' . $e->getMessage(), ['exception' => $e]);
+			throw new RuntimeException($serviceName . ' transcription failed with: ' . $e->getMessage(), 0, $e);
+		}
+
+		// free prompt
+		try {
+			$completion = $this->openAiAPIService->createChatCompletion($userId, $llmModel, $inputTranscription, $systemPrompt, $history, 1, 1000);
+			$messages = $completion['messages'];
+		} catch (\Exception $e) {
+			throw new RuntimeException($serviceName . ' chat completion request failed: ' . $e->getMessage(), 0, $e);
+		}
+		if (count($messages) === 0) {
+			throw new RuntimeException('No completion in ' . $serviceName . ' response.');
+		}
+		$llmResult = array_pop($messages);
+
+		// text to speech
+		return [
+			'output' => $this->synthesizeSpeech($userId, $llmResult, $ttsModel, $outputVoice, $speed, $serviceName),
+			'output_transcript' => $llmResult,
+			'input_transcript' => $inputTranscription,
+		];
+	}
+
+	/**
+	 * Convert a text to speech and return the 'body' of the speech creation response.
+	 *
+	 * The missing-body check is done outside of the try block so that this failure is
+	 * logged and reported once instead of being caught and wrapped a second time.
+	 *
+	 * @return mixed The 'body' entry of the speech creation response
+	 * @throws RuntimeException If the request fails or the response has no body
+	 */
+	private function synthesizeSpeech(
+		?string $userId, string $text, string $ttsModel, string $outputVoice, float $speed, string $serviceName,
+	): mixed {
+		try {
+			$apiResponse = $this->openAiAPIService->requestSpeechCreation($userId, $text, $ttsModel, $outputVoice, $speed);
+		} catch (\Exception $e) {
+			$this->logger->warning($serviceName . ' text to speech generation failed with: ' . $e->getMessage(), ['exception' => $e]);
+			throw new RuntimeException($serviceName . ' text to speech generation failed with: ' . $e->getMessage(), 0, $e);
+		}
+		if (!isset($apiResponse['body'])) {
+			$this->logger->warning($serviceName . ' text to speech generation failed: no speech returned');
+			throw new RuntimeException($serviceName . ' text to speech generation failed: no speech returned');
+		}
+		return $apiResponse['body'];
+	}
+}
diff --git a/psalm.xml b/psalm.xml
index 8ca4dcbc..128fcc86 100644
--- a/psalm.xml
+++ b/psalm.xml
@@ -39,6 +39,7 @@
+
diff --git a/tests/unit/Providers/OpenAiProviderTest.php b/tests/unit/Providers/OpenAiProviderTest.php
index ee9923b2..fe0d2bd5 100644
--- a/tests/unit/Providers/OpenAiProviderTest.php
+++ b/tests/unit/Providers/OpenAiProviderTest.php
@@ -140,7 +140,7 @@ public function testFreePromptProvider(): void {
 		$options = ['timeout' => Application::OPENAI_DEFAULT_REQUEST_TIMEOUT, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => self::AUTHORIZATION_HEADER, 'Content-Type' => 'application/json']];
 		$options['body'] = json_encode([
 			'model' => Application::DEFAULT_COMPLETION_MODEL_ID,
-			'messages' => [['role' => 'user', 'content' => $prompt]],
+			'messages' => [['role' => 'user', 'content' => [['type' => 'text', 'text' => $prompt]]]],
 			'n' => $n,
 			'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS,
 			'user' => self::TEST_USER1,
@@ -204,7 +204,7 @@ public function testEmojiProvider(): void {
 		$message = 'Give me an emoji for the following text. 
Output only the emoji without any other characters.' . "\n\n" . $prompt; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'user', 'content' => $message]], + 'messages' => [['role' => 'user', 'content' => [['type' => 'text', 'text' => $message]]]], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1, @@ -269,7 +269,7 @@ public function testHeadlineProvider(): void { $message = 'Give me the headline of the following text in its original language. Do not output the language. Output only the headline without any quotes or additional punctuation.' . "\n\n" . $prompt; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'user', 'content' => $message]], + 'messages' => [['role' => 'user', 'content' => [['type' => 'text', 'text' => $message]]]], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1, @@ -334,7 +334,7 @@ public function testChangeToneProvider(): void { $message = "Reformulate the following text in a $toneInput tone in its original language. Output only the reformulation. Here is the text:" . "\n\n" . $textInput . "\n\n" . 'Do not mention the used language in your reformulation. Here is your reformulation in the same language:'; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'user', 'content' => $message]], + 'messages' => [['role' => 'user', 'content' => [['type' => 'text', 'text' => $message]]]], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1, @@ -400,8 +400,10 @@ public function testSummaryProvider(): void { . 
'You should only return the summary without any additional information.'; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'system', 'content' => $systemPrompt], - ['role' => 'user', 'content' => $prompt]], + 'messages' => [ + ['role' => 'system', 'content' => $systemPrompt], + ['role' => 'user', 'content' => [['type' => 'text', 'text' => $prompt]]], + ], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1, @@ -465,7 +467,10 @@ public function testProofreadProvider(): void { $systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.'; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'system', 'content' => $systemPrompt],['role' => 'user', 'content' => $prompt]], + 'messages' => [ + ['role' => 'system', 'content' => $systemPrompt], + ['role' => 'user', 'content' => [['type' => 'text', 'text' => $prompt]]], + ], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1, @@ -533,7 +538,9 @@ public function testTranslationProvider(): void { $options = ['timeout' => Application::OPENAI_DEFAULT_REQUEST_TIMEOUT, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => self::AUTHORIZATION_HEADER, 'Content-Type' => 'application/json']]; $options['body'] = json_encode([ 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, - 'messages' => [['role' => 'user', 'content' => 'Translate from ' . $fromLang . ' to English (US): ' . $inputText]], + 'messages' => [ + ['role' => 'user', 'content' => [['type' => 'text', 'text' => 'Translate from ' . $fromLang . ' to English (US): ' . $inputText]]], + ], 'n' => $n, 'max_completion_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'user' => self::TEST_USER1,