Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/Controller/ConfigController.php
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ public function setSensitiveUserConfig(array $values): DataResponse {
* @return DataResponse
*/
public function setAdminConfig(array $values): DataResponse {
if (isset($values['api_key']) || isset($values['basic_password']) || isset($values['basic_user']) || isset($values['url'])) {
return new DataResponse('', Http::STATUS_BAD_REQUEST);
$prefixes = ['', 'image_', 'tts_', 'stt_'];
foreach ($prefixes as $prefix) {
if (isset($values[$prefix . 'api_key']) || isset($values[$prefix . 'basic_password']) || isset($values[$prefix . 'basic_user']) || isset($values[$prefix . 'url'])) {
return new DataResponse('', Http::STATUS_BAD_REQUEST);
}
}
try {
$this->openAiSettingsService->setAdminConfig($values);
Expand Down
5 changes: 3 additions & 2 deletions lib/Controller/OpenAiAPIController.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ public function __construct(
}

/**
* @param string|null $serviceType
* @return DataResponse
*/
#[NoAdminRequired]
public function getModels(): DataResponse {
public function getModels(?string $serviceType = null): DataResponse {
try {
$response = $this->openAiAPIService->getModels($this->userId, true);
$response = $this->openAiAPIService->getModels($this->userId, true, $serviceType);
return new DataResponse($response);
} catch (Exception $e) {
$code = $e->getCode() === 0 ? Http::STATUS_BAD_REQUEST : intval($e->getCode());
Expand Down
134 changes: 95 additions & 39 deletions lib/Service/OpenAiAPIService.php
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,47 @@ public function createQuotaUsage(string $userId, int $type, int $usage) {
}

/**
* @param ?string $serviceType
* @return bool
*/
public function isUsingOpenAi(): bool {
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
public function isUsingOpenAi(?string $serviceType = null): bool {
$serviceUrl = '';
if ($serviceType === 'image') {
$serviceUrl = $this->openAiSettingsService->getImageUrl();
} elseif ($serviceType === 'stt') {
$serviceUrl = $this->openAiSettingsService->getSttUrl();
} elseif ($serviceType === 'tts') {
$serviceUrl = $this->openAiSettingsService->getTtsUrl();
}
if ($serviceUrl === '') {
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
}
return $serviceUrl === '' || $serviceUrl === Application::OPENAI_API_BASE_URL;
}

/**
* @param ?string $serviceType
*
* @return string
*/
public function getServiceName(): string {
if ($this->isUsingOpenAi()) {
public function getServiceName(?string $serviceType = null): string {
if ($this->isUsingOpenAi($serviceType)) {
if ($serviceType === 'image') {
return $this->l10n->t('OpenAI\'s DALL-E 2');
}
if ($serviceType === 'tts') {
$this->l10n->t('OpenAI\'s Text to Speech');
}
return 'OpenAI';
} else {
$serviceName = $this->openAiSettingsService->getServiceName();
if ($serviceType === 'image' && $this->openAiSettingsService->imageOverrideEnabled()) {
$serviceName = $this->openAiSettingsService->getImageServiceName();
} elseif ($serviceType === 'stt' && $this->openAiSettingsService->sttOverrideEnabled()) {
$serviceName = $this->openAiSettingsService->getSttServiceName();
} elseif ($serviceType === 'tts' && $this->openAiSettingsService->ttsOverrideEnabled()) {
$serviceName = $this->openAiSettingsService->getTtsServiceName();
}
if ($serviceName === '') {
return 'LocalAI';
}
Expand Down Expand Up @@ -111,13 +137,15 @@ private function isModelListValid($models): bool {
/**
* @param ?string $userId
* @param bool $refresh
* @param ?string $serviceType
* @return array|string[]
* @throws Exception
*/
public function getModels(?string $userId, bool $refresh = false): array {
public function getModels(?string $userId, bool $refresh = false, ?string $serviceType = null): array {
$cache = $this->cacheFactory->createDistributed(Application::APP_ID);
$userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? '');
$adminCacheKey = Application::MODELS_CACHE_KEY . '-main';
$userCacheKey = Application::MODELS_CACHE_KEY . '_' . ($userId ?? '') . '_' . ($serviceType ?? 'main');
$adminCacheKey = Application::MODELS_CACHE_KEY . '-main' . '_' . ($serviceType ?? 'main');
$dbCacheKey = $serviceType ? 'models' . '_' . $serviceType : 'models';

if (!$refresh) {
if ($this->modelsMemoryCache !== null) {
Expand Down Expand Up @@ -155,7 +183,7 @@ public function getModels(?string $userId, bool $refresh = false): array {
}

// if we don't need to refresh to model list and it's not been found in the cache, it is obtained from the DB
$modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, 'models', '{"data":[],"object":"list"}');
$modelsObjectString = $this->appConfig->getValueString(Application::APP_ID, $dbCacheKey, '{"data":[],"object":"list"}');
$fallbackModels = [
'data' => [],
'object' => 'list',
Expand All @@ -177,7 +205,7 @@ public function getModels(?string $userId, bool $refresh = false): array {

try {
$this->logger->debug('Actually getting OpenAI models with a network request');
$modelsResponse = $this->request($userId, 'models');
$modelsResponse = $this->request($userId, 'models', serviceType: $serviceType);
} catch (Exception $e) {
$this->logger->warning('Error retrieving models (exc): ' . $e->getMessage());
throw $e;
Expand All @@ -200,7 +228,7 @@ public function getModels(?string $userId, bool $refresh = false): array {
$this->modelsMemoryCache = $modelsResponse;
// we always store the model list after getting it
$modelsObjectString = json_encode($modelsResponse);
$this->appConfig->setValueString(Application::APP_ID, 'models', $modelsObjectString);
$this->appConfig->setValueString(Application::APP_ID, $dbCacheKey, $modelsObjectString);
return $modelsResponse;
}

Expand All @@ -223,9 +251,9 @@ private function hasOwnOpenAiApiKey(string $userId): bool {
* @param string|null $userId
* @return array
*/
public function getModelEnumValues(?string $userId): array {
public function getModelEnumValues(?string $userId, ?string $serviceType = null): array {
try {
$modelResponse = $this->getModels($userId);
$modelResponse = $this->getModels($userId, false, $serviceType);
$modelEnumValues = array_map(function (array $model) {
return new ShapeEnumValue($model['id'], $model['id']);
}, $modelResponse['data'] ?? []);
Expand Down Expand Up @@ -779,7 +807,7 @@ public function transcribe(
$endpoint = $translate ? 'audio/translations' : 'audio/transcriptions';
$contentType = 'multipart/form-data';

$response = $this->request($userId, $endpoint, $params, 'POST', $contentType);
$response = $this->request($userId, $endpoint, $params, 'POST', $contentType, serviceType: 'stt');

if (!isset($response['text'])) {
$this->logger->warning('Audio transcription error: ' . json_encode($response));
Expand Down Expand Up @@ -822,7 +850,7 @@ public function requestImageCreation(
'model' => $model === Application::DEFAULT_MODEL_ID ? Application::DEFAULT_IMAGE_MODEL_ID : $model,
];

$apiResponse = $this->request($userId, 'images/generations', $params, 'POST');
$apiResponse = $this->request($userId, 'images/generations', $params, 'POST', serviceType: 'image');

if (!isset($apiResponse['data']) || !is_array($apiResponse['data'])) {
$this->logger->warning('OpenAI image generation error', ['api_response' => $apiResponse]);
Expand Down Expand Up @@ -891,7 +919,7 @@ public function requestSpeechCreation(
'speed' => $speed,
];

$apiResponse = $this->request($userId, 'audio/speech', $params, 'POST');
$apiResponse = $this->request($userId, 'audio/speech', $params, 'POST', serviceType: 'tts');

try {
$charCount = mb_strlen($prompt);
Expand Down Expand Up @@ -930,7 +958,7 @@ public function updateExpTextProcessingTime(int $runtime): void {
* @return int
*/
public function getExpImgProcessingTime(): int {
return $this->isUsingOpenAi()
return $this->isUsingOpenAi('image')
? intval($this->appConfig->getValueString(Application::APP_ID, 'openai_image_generation_time', strval(Application::DEFAULT_OPENAI_IMAGE_GENERATION_TIME), lazy: true))
: intval($this->appConfig->getValueString(Application::APP_ID, 'localai_image_generation_time', strval(Application::DEFAULT_LOCALAI_IMAGE_GENERATION_TIME), lazy: true));
}
Expand All @@ -943,7 +971,7 @@ public function updateExpImgProcessingTime(int $runtime): void {
$oldTime = floatval($this->getExpImgProcessingTime());
$newTime = (1.0 - Application::EXPECTED_RUNTIME_LOWPASS_FACTOR) * $oldTime + Application::EXPECTED_RUNTIME_LOWPASS_FACTOR * floatval($runtime);

if ($this->isUsingOpenAi()) {
if ($this->isUsingOpenAi('image')) {
$this->appConfig->setValueString(Application::APP_ID, 'openai_image_generation_time', strval(intval($newTime)), lazy: true);
} else {
$this->appConfig->setValueString(Application::APP_ID, 'localai_image_generation_time', strval(intval($newTime)), lazy: true);
Expand All @@ -958,17 +986,54 @@ public function updateExpImgProcessingTime(int $runtime): void {
* @param string $method HTTP query method
* @param string|null $contentType
* @param bool $logErrors if set to false error logs will be suppressed
* @param string|null $serviceType
* @return array decoded request result or error
* @throws Exception
*/
public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true): array {
public function request(?string $userId, string $endPoint, array $params = [], string $method = 'GET', ?string $contentType = null, bool $logErrors = true, ?string $serviceType = null): array {
try {
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
if ($serviceUrl === '') {
$serviceUrl = Application::OPENAI_API_BASE_URL;
$serviceUrl = '';
$apiKey = '';
$basicUser = '';
$basicPassword = '';
$useBasicAuth = false;
$timeout = 0;

if ($serviceType === 'image') {
$serviceUrl = $this->openAiSettingsService->getImageUrl();
$apiKey = $this->openAiSettingsService->getAdminImageApiKey();
$basicUser = $this->openAiSettingsService->getAdminImageBasicUser();
$basicPassword = $this->openAiSettingsService->getAdminImageBasicPassword();
$useBasicAuth = $this->openAiSettingsService->getAdminImageUseBasicAuth();
$timeout = $this->openAiSettingsService->getImageRequestTimeout();
} elseif ($serviceType === 'stt') {
$serviceUrl = $this->openAiSettingsService->getSttUrl();
$apiKey = $this->openAiSettingsService->getAdminSttApiKey();
$basicUser = $this->openAiSettingsService->getAdminSttBasicUser();
$basicPassword = $this->openAiSettingsService->getAdminSttBasicPassword();
$useBasicAuth = $this->openAiSettingsService->getAdminSttUseBasicAuth();
$timeout = $this->openAiSettingsService->getSttRequestTimeout();
} elseif ($serviceType === 'tts') {
$serviceUrl = $this->openAiSettingsService->getTtsUrl();
$apiKey = $this->openAiSettingsService->getAdminTtsApiKey();
$basicUser = $this->openAiSettingsService->getAdminTtsBasicUser();
$basicPassword = $this->openAiSettingsService->getAdminTtsBasicPassword();
$useBasicAuth = $this->openAiSettingsService->getAdminTtsUseBasicAuth();
$timeout = $this->openAiSettingsService->getTtsRequestTimeout();
}

$timeout = $this->openAiSettingsService->getRequestTimeout();
// Currently only supporting user api keys for the default service
if (empty($serviceUrl)) {
$serviceUrl = $this->openAiSettingsService->getServiceUrl();
if ($serviceUrl === '') {
$serviceUrl = Application::OPENAI_API_BASE_URL;
}
$apiKey = $this->openAiSettingsService->getUserApiKey($userId, true);
$basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true);
$basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true);
$useBasicAuth = $this->openAiSettingsService->getUseBasicAuth();
$timeout = $this->openAiSettingsService->getRequestTimeout();
}

$url = rtrim($serviceUrl, '/') . '/' . $endPoint;
$options = [
Expand All @@ -978,20 +1043,11 @@ public function request(?string $userId, string $endPoint, array $params = [], s
],
];

// an API key is mandatory when using OpenAI
$apiKey = $this->openAiSettingsService->getUserApiKey($userId, true);

// We can also use basic authentication
$basicUser = $this->openAiSettingsService->getUserBasicUser($userId, true);
$basicPassword = $this->openAiSettingsService->getUserBasicPassword($userId, true);

if ($serviceUrl === Application::OPENAI_API_BASE_URL && $apiKey === '') {
return ['error' => 'An API key is required for api.openai.com'];
}

$useBasicAuth = $this->openAiSettingsService->getUseBasicAuth();

if ($this->isUsingOpenAi() || !$useBasicAuth) {
if ($this->isUsingOpenAi($serviceType) || !$useBasicAuth) {
if ($apiKey !== '') {
$options['headers']['Authorization'] = 'Bearer ' . $apiKey;
}
Expand All @@ -1001,7 +1057,7 @@ public function request(?string $userId, string $endPoint, array $params = [], s
}
}

if (!$this->isUsingOpenAi()) {
if (!$this->isUsingOpenAi($serviceType)) {
$options['nextcloud']['allow_local_address'] = true;
}

Expand Down Expand Up @@ -1095,15 +1151,15 @@ public function request(?string $userId, string $endPoint, array $params = [], s
* @return bool whether the T2I provider is available
*/
public function isT2IAvailable(): bool {
if ($this->isUsingOpenAi()) {
if ($this->isUsingOpenAi() || $this->openAiSettingsService->imageOverrideEnabled()) {
return true;
}
try {
$params = [
'prompt' => 'a',
'model' => 'invalid-model',
];
$this->request(null, 'images/generations', $params, 'POST', logErrors: false);
$this->request(null, 'images/generations', $params, 'POST', logErrors: false, serviceType: 'image');
} catch (Exception $e) {
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
}
Expand All @@ -1116,15 +1172,15 @@ public function isT2IAvailable(): bool {
* @return bool whether the STT provider is available
*/
public function isSTTAvailable(): bool {
if ($this->isUsingOpenAi()) {
if ($this->isUsingOpenAi() || $this->openAiSettingsService->sttOverrideEnabled()) {
return true;
}
try {
$params = [
'model' => 'invalid-model',
'file' => 'a',
];
$this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false);
$this->request(null, 'audio/translations', $params, 'POST', 'multipart/form-data', logErrors: false, serviceType: 'stt');
} catch (Exception $e) {
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
}
Expand All @@ -1137,7 +1193,7 @@ public function isSTTAvailable(): bool {
* @return bool whether the TTS provider is available
*/
public function isTTSAvailable(): bool {
if ($this->isUsingOpenAi()) {
if ($this->isUsingOpenAi() || $this->openAiSettingsService->ttsOverrideEnabled()) {
return true;
}
try {
Expand All @@ -1148,7 +1204,7 @@ public function isTTSAvailable(): bool {
'response_format' => 'mp3',
];

$this->request(null, 'audio/speech', $params, 'POST', logErrors: false);
$this->request(null, 'audio/speech', $params, 'POST', logErrors: false, serviceType: 'tts');
} catch (Exception $e) {
return $e->getCode() !== Http::STATUS_NOT_FOUND && $e->getCode() !== Http::STATUS_UNAUTHORIZED;
}
Expand Down
Loading
Loading