Skip to content

Commit 0aba2aa

Browse files
committed
feat: support multiple providers
Signed-off-by: Lukas Schaefer <[email protected]>
1 parent d5e606e commit 0aba2aa

File tree

10 files changed

+1088
-142
lines changed

10 files changed

+1088
-142
lines changed

lib/Controller/ConfigController.php

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,11 @@ public function setSensitiveUserConfig(array $values): DataResponse {
7272
* @return DataResponse
7373
*/
7474
public function setAdminConfig(array $values): DataResponse {
75-
if (isset($values['api_key']) || isset($values['basic_password']) || isset($values['basic_user']) || isset($values['url'])) {
76-
return new DataResponse('', Http::STATUS_BAD_REQUEST);
75+
$prefixes = ['', 'image_', 'tts_', 'stt_'];
76+
foreach ($prefixes as $prefix) {
77+
if (isset($values[$prefix . 'api_key']) || isset($values[$prefix . 'basic_password']) || isset($values[$prefix . 'basic_user']) || isset($values[$prefix . 'url'])) {
78+
return new DataResponse('', Http::STATUS_BAD_REQUEST);
79+
}
7780
}
7881
try {
7982
$this->openAiSettingsService->setAdminConfig($values);

lib/Controller/OpenAiAPIController.php

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ public function __construct(
2626
}
2727

2828
/**
29+
* @param string|null $serviceType
2930
* @return DataResponse
3031
*/
3132
#[NoAdminRequired]
32-
public function getModels(): DataResponse {
33+
public function getModels(?string $serviceType = null): DataResponse {
3334
try {
34-
$response = $this->openAiAPIService->getModels($this->userId, true);
35+
$response = $this->openAiAPIService->getModels($this->userId, true, $serviceType);
3536
return new DataResponse($response);
3637
} catch (Exception $e) {
3738
$code = $e->getCode() === 0 ? Http::STATUS_BAD_REQUEST : intval($e->getCode());

lib/Service/OpenAiAPIService.php

Lines changed: 95 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,47 @@ public function createQuotaUsage(string $userId, int $type, int $usage) {
6666
}
6767

6868
/**
69+
* @param ?string $serviceType
6970
* @return bool
7071
*/
71-
public function isUsingOpenAi(): bool {
72-
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
72+
public function isUsingOpenAi(?string $serviceType = null): bool {
73+
$serviceUrl = '';
74+
if ($serviceType === 'image') {
75+
$serviceUrl = $this->openAiSettingsService->getImageUrl();
76+
} elseif ($serviceType === 'stt') {
77+
$serviceUrl = $this->openAiSettingsService->getSttUrl();
78+
} elseif ($serviceType === 'tts') {
79+
$serviceUrl = $this->openAiSettingsService->getTtsUrl();
80+
}
81+
if ($serviceUrl === '') {
82+
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
83+
}
7384
return $serviceUrl === '' || $serviceUrl === Application::OPENAI_API_BASE_URL;
7485
}
7586

7687
/**
88+
* @param ?string $serviceType
89+
*
7790
* @return string
7891
*/
79-
public function getServiceName(): string {
80-
if ($this->isUsingOpenAi()) {
92+
public function getServiceName(?string $serviceType = null): string {
93+
if ($this->isUsingOpenAi($serviceType)) {
94+
if ($serviceType === 'image') {
95+
return $this->l->t('OpenAI\'s DALL-E 2');
96+
}
97+
if ($serviceType === 'tts') {
98+
$this->l->t('OpenAI\'s Text to Speech');
99+
}
81100
return 'OpenAI';
82101
} else {
83102
$serviceName = $this->openAiSettingsService->getServiceName();
103+
if ($serviceType === 'image' && $this->openAiSettingsService->imageOverrideEnabled()) {
104+
$serviceName = $this->openAiSettingsService->getImageServiceName();
105+
} elseif ($serviceType === 'stt' && $this->openAiSettingsService->sttOverrideEnabled()) {
106+
$serviceName = $this->openAiSettingsService->getSttServiceName();
107+
} elseif ($serviceType === 'tts' && $this->openAiSettingsService->ttsOverrideEnabled()) {
108+
$serviceName = $this->openAiSettingsService->getTtsServiceName();
109+
}
84110
if ($serviceName === '') {
85111
return 'LocalAI';
86112
}
@@ -111,13 +137,15 @@ private function isModelListValid($models): bool {
111137
/**
112138
* @param ?string $userId
113139
* @param bool $refresh
140+
* @param ?string $serviceType
114141
* @return array|string[]
115142
* @throws Exception
116143
*/
117-
public function getModels(?string $userId, bool $refresh = false): array {
144+
public function getModels(?string $userId, bool $refresh = false, ?string $serviceType = null): array {
118145
$cache = $this->cacheFactory->createDistributed(Application::APP_ID);
119-
$userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? '');
120-
$adminCacheKey = Application::MODELS_CACHE_KEY . '-main';
146+
$userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? '') . '_' . ($serviceType ?? 'main');
147+
$adminCacheKey = Application::MODELS_CACHE_KEY . '-main' . '_' . ($serviceType ?? 'main');
148+
$dbCacheKey = $serviceType ? 'models' . '_' . $serviceType : 'models';
121149

122150
if (!$refresh) {
123151
if ($this->modelsMemoryCache !== null) {
@@ -155,7 +183,7 @@ public function getModels(?string $userId, bool $refresh = false): array {
155183
}
156184

157185
// if we don't need to refresh to model list and it's not been found in the cache, it is obtained from the DB
158-
$modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, 'models', '{"data":[],"object":"list"}');
186+
$modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, $dbCacheKey, '{"data":[],"object":"list"}');
159187
$fallbackModels = [
160188
'data' => [],
161189
'object' => 'list',
@@ -177,7 +205,7 @@ public function getModels(?string $userId, bool $refresh = false): array {
177205

178206
try {
179207
$this->logger->debug('Actually getting OpenAI models with a network request');
180-
$modelsResponse = $this->request($userId, 'models');
208+
$modelsResponse = $this->request($userId, 'models', serviceType: $serviceType);
181209
} catch (Exception $e) {
182210
$this->logger->warning('Error retrieving models (exc): ' . $e->getMessage());
183211
throw $e;
@@ -200,7 +228,7 @@ public function getModels(?string $userId, bool $refresh = false): array {
200228
$this->modelsMemoryCache = $modelsResponse;
201229
// we always store the model list after getting it
202230
$modelsObjectString = json_encode($modelsResponse);
203-
$this->appConfig->setValueString(Application::APP_ID, 'models', $modelsObjectString);
231+
$this->appConfig->setValueString(Application::APP_ID, $dbCacheKey, $modelsObjectString);
204232
return $modelsResponse;
205233
}
206234

@@ -223,9 +251,9 @@ private function hasOwnOpenAiApiKey(string $userId): bool {
223251
* @param string|null $userId
224252
* @return array
225253
*/
226-
public function getModelEnumValues(?string $userId): array {
254+
public function getModelEnumValues(?string $userId, ?string $serviceType = null): array {
227255
try {
228-
$modelResponse = $this->getModels($userId);
256+
$modelResponse = $this->getModels($userId, false, $serviceType);
229257
$modelEnumValues = array_map(function (array $model) {
230258
return new ShapeEnumValue($model['id'], $model['id']);
231259
}, $modelResponse['data'] ?? []);
@@ -779,7 +807,7 @@ public function transcribe(
779807
$endpoint = $translate ? 'audio/translations' : 'audio/transcriptions';
780808
$contentType = 'multipart/form-data';
781809

782-
$response = $this->request($userId, $endpoint, $params, 'POST', $contentType);
810+
$response = $this->request($userId, $endpoint, $params, 'POST', $contentType, serviceType: 'stt');
783811

784812
if (!isset($response['text'])) {
785813
$this->logger->warning('Audio transcription error: ' . json_encode($response));
@@ -822,7 +850,7 @@ public function requestImageCreation(
822850
'model' => $model === Application::DEFAULT_MODEL_ID ? Application::DEFAULT_IMAGE_MODEL_ID : $model,
823851
];
824852

825-
$apiResponse = $this->request($userId, 'images/generations', $params, 'POST');
853+
$apiResponse = $this->request($userId, 'images/generations', $params, 'POST', serviceType: 'image');
826854

827855
if (!isset($apiResponse['data']) || !is_array($apiResponse['data'])) {
828856
$this->logger->warning('OpenAI image generation error', ['api_response' => $apiResponse]);
@@ -891,7 +919,7 @@ public function requestSpeechCreation(
891919
'speed' => $speed,
892920
];
893921

894-
$apiResponse = $this->request($userId, 'audio/speech', $params, 'POST');
922+
$apiResponse = $this->request($userId, 'audio/speech', $params, 'POST', serviceType: 'tts');
895923

896924
try {
897925
$charCount = mb_strlen($prompt);
@@ -930,7 +958,7 @@ public function updateExpTextProcessingTime(int $runtime): void {
930958
* @return int
931959
*/
932960
public function getExpImgProcessingTime(): int {
933-
return $this->isUsingOpenAi()
961+
return $this->isUsingOpenAi('image')
934962
? intval($this->appConfig->getValueString(Application::APP_ID, 'openai_image_generation_time', strval(Application::DEFAULT_OPENAI_IMAGE_GENERATION_TIME), lazy: true))
935963
: intval($this->appConfig->getValueString(Application::APP_ID, 'localai_image_generation_time', strval(Application::DEFAULT_LOCALAI_IMAGE_GENERATION_TIME), lazy: true));
936964
}
@@ -943,7 +971,7 @@ public function updateExpImgProcessingTime(int $runtime): void {
943971
$oldTime = floatval($this->getExpImgProcessingTime());
944972
$newTime = (1.0 - Application::EXPECTED_RUNTIME_LOWPASS_FACTOR) * $oldTime + Application::EXPECTED_RUNTIME_LOWPASS_FACTOR * floatval($runtime);
945973

946-
if ($this->isUsingOpenAi()) {
974+
if ($this->isUsingOpenAi('image')) {
947975
$this->appConfig->setValueString(Application::APP_ID, 'openai_image_generation_time', strval(intval($newTime)), lazy: true);
948976
} else {
949977
$this->appConfig->setValueString(Application::APP_ID, 'localai_image_generation_time', strval(intval($newTime)), lazy: true);
@@ -958,17 +986,54 @@ public function updateExpImgProcessingTime(int $runtime): void {
958986
* @param string $method HTTP query method
959987
* @param string|null $contentType
960988
* @param bool $logErrors if set to false error logs will be suppressed
989+
* @param string|null $serviceType
961990
* @return array decoded request result or error
962991
* @throws Exception
963992
*/
964-
public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true): array {
993+
public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true, ?string $serviceType = null): array {
965994
try {
966-
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
967-
if ($serviceUrl === '') {
968-
$serviceUrl = Application::OPENAI_API_BASE_URL;
995+
$serviceUrl = '';
996+
$apiKey = '';
997+
$basicUser = '';
998+
$basicPassword = '';
999+
$useBasicAuth = false;
1000+
$timeout = 0;
1001+
1002+
if ($serviceType === 'image') {
1003+
$serviceUrl = $this->openAiSettingsService->getImageUrl();
1004+
$apiKey = $this->openAiSettingsService->getAdminImageApiKey();
1005+
$basicUser = $this->openAiSettingsService->getAdminImageBasicUser();
1006+
$basicPassword = $this->openAiSettingsService->getAdminImageBasicPassword();
1007+
$useBasicAuth = $this->openAiSettingsService->getAdminImageUseBasicAuth();
1008+
$timeout = $this->openAiSettingsService->getImageRequestTimeout();
1009+
} elseif ($serviceType === 'stt') {
1010+
$serviceUrl = $this->openAiSettingsService->getSttUrl();
1011+
$apiKey = $this->openAiSettingsService->getAdminSttApiKey();
1012+
$basicUser = $this->openAiSettingsService->getAdminSttBasicUser();
1013+
$basicPassword = $this->openAiSettingsService->getAdminSttBasicPassword();
1014+
$useBasicAuth = $this->openAiSettingsService->getAdminSttUseBasicAuth();
1015+
$timeout = $this->openAiSettingsService->getSttRequestTimeout();
1016+
} elseif ($serviceType === 'tts') {
1017+
$serviceUrl = $this->openAiSettingsService->getTtsUrl();
1018+
$apiKey = $this->openAiSettingsService->getAdminTtsApiKey();
1019+
$basicUser = $this->openAiSettingsService->getAdminTtsBasicUser();
1020+
$basicPassword = $this->openAiSettingsService->getAdminTtsBasicPassword();
1021+
$useBasicAuth = $this->openAiSettingsService->getAdminTtsUseBasicAuth();
1022+
$timeout = $this->openAiSettingsService->getTtsRequestTimeout();
9691023
}
9701024

971-
$timeout = $this->openAiSettingsService->getRequestTimeout();
1025+
// Currently only supporting user api keys for the default service
1026+
if (empty($serviceUrl)) {
1027+
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
1028+
if ($serviceUrl === '') {
1029+
$serviceUrl = Application::OPENAI_API_BASE_URL;
1030+
}
1031+
$apiKey = $this->openAiSettingsService->getUserApiKey($userId, true);
1032+
$basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true);
1033+
$basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true);
1034+
$useBasicAuth = $this->openAiSettingsService->getUseBasicAuth();
1035+
$timeout = $this->openAiSettingsService->getRequestTimeout();
1036+
}
9721037

9731038
$url = rtrim($serviceUrl, '/') . '/' . $endPoint;
9741039
$options = [
@@ -978,20 +1043,11 @@ public function request(?string $userId, string $endPoint, array $params = [], s
9781043
],
9791044
];
9801045

981-
// an API key is mandatory when using OpenAI
982-
$apiKey = $this->openAiSettingsService->getUserApiKey($userId, true);
983-
984-
// We can also use basic authentication
985-
$basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true);
986-
$basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true);
987-
9881046
if ($serviceUrl === Application::OPENAI_API_BASE_URL && $apiKey === '') {
9891047
return ['error' => 'An API key is required for api.openai.com'];
9901048
}
9911049

992-
$useBasicAuth = $this->openAiSettingsService->getUseBasicAuth();
993-
994-
if ($this->isUsingOpenAi() || !$useBasicAuth) {
1050+
if ($this->isUsingOpenAi($serviceType) || !$useBasicAuth) {
9951051
if ($apiKey !== '') {
9961052
$options['headers']['Authorization'] = 'Bearer ' . $apiKey;
9971053
}
@@ -1001,7 +1057,7 @@ public function request(?string $userId, string $endPoint, array $params = [], s
10011057
}
10021058
}
10031059

1004-
if (!$this->isUsingOpenAi()) {
1060+
if (!$this->isUsingOpenAi($serviceType)) {
10051061
$options['nextcloud']['allow_local_address'] = true;
10061062
}
10071063

@@ -1095,15 +1151,15 @@ public function request(?string $userId, string $endPoint, array $params = [], s
10951151
* @return bool whether the T2I provider is available
10961152
*/
10971153
public function isT2IAvailable(): bool {
1098-
if ($this->isUsingOpenAi()) {
1154+
if ($this->isUsingOpenAi() || $this->openAiSettingsService->imageOverrideEnabled()) {
10991155
return true;
11001156
}
11011157
try {
11021158
$params = [
11031159
'prompt' => 'a',
11041160
'model' => 'invalid-model',
11051161
];
1106-
$this->request(null, 'images/generations', $params, 'POST', logErrors: false);
1162+
$this->request(null, 'images/generations', $params, 'POST', logErrors: false, serviceType: 'image');
11071163
} catch (Exception $e) {
11081164
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
11091165
}
@@ -1116,15 +1172,15 @@ public function isT2IAvailable(): bool {
11161172
* @return bool whether the STT provider is available
11171173
*/
11181174
public function isSTTAvailable(): bool {
1119-
if ($this->isUsingOpenAi()) {
1175+
if ($this->isUsingOpenAi() || $this->openAiSettingsService->sttOverrideEnabled()) {
11201176
return true;
11211177
}
11221178
try {
11231179
$params = [
11241180
'model' => 'invalid-model',
11251181
'file' => 'a',
11261182
];
1127-
$this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false);
1183+
$this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false, serviceType: 'stt');
11281184
} catch (Exception $e) {
11291185
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
11301186
}
@@ -1137,7 +1193,7 @@ public function isSTTAvailable(): bool {
11371193
* @return bool whether the TTS provider is available
11381194
*/
11391195
public function isTTSAvailable(): bool {
1140-
if ($this->isUsingOpenAi()) {
1196+
if ($this->isUsingOpenAi() || $this->openAiSettingsService->ttsOverrideEnabled()) {
11411197
return true;
11421198
}
11431199
try {
@@ -1148,7 +1204,7 @@ public function isTTSAvailable(): bool {
11481204
'response_format' => 'mp3',
11491205
];
11501206

1151-
$this->request(null, 'audio/speech', $params, 'POST', logErrors: false);
1207+
$this->request(null, 'audio/speech', $params, 'POST', logErrors: false, serviceType: 'tts');
11521208
} catch (Exception $e) {
11531209
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
11541210
}

0 commit comments

Comments
 (0)