Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions lib/private/TaskProcessing/Manager.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
use OCP\Files\NotPermittedException;
use OCP\Files\SimpleFS\ISimpleFile;
use OCP\Http\Client\IClientService;
use OCP\ICache;
use OCP\ICacheFactory;
use OCP\IConfig;
use OCP\IL10N;
use OCP\IServerContainer;
Expand Down Expand Up @@ -77,6 +79,11 @@ class Manager implements IManager {
private ?array $availableTaskTypes = null;

private IAppData $appData;
private ?array $preferences = null;
private ?array $providersById = null;
private ICache $cache;
private ICache $distributedCache;

public function __construct(
private IConfig $config,
private Coordinator $coordinator,
Expand All @@ -91,8 +98,11 @@ public function __construct(
private IUserMountCache $userMountCache,
private IClientService $clientService,
private IAppManager $appManager,
ICacheFactory $cacheFactory,
) {
$this->appData = $appDataFactory->get('core');
$this->cache = $cacheFactory->createLocal('task_processing::');
$this->distributedCache = $cacheFactory->createDistributed('task_processing::');
}


Expand Down Expand Up @@ -698,12 +708,23 @@ public function getProviders(): array {

public function getPreferredProvider(string $taskTypeId) {
try {
$preferences = json_decode($this->config->getAppValue('core', 'ai.taskprocessing_provider_preferences', 'null'), associative: true, flags: JSON_THROW_ON_ERROR);
if ($this->preferences === null) {
$this->preferences = $this->distributedCache->get('ai.taskprocessing_provider_preferences');
if ($this->preferences === null) {
$this->preferences = json_decode($this->config->getAppValue('core', 'ai.taskprocessing_provider_preferences', 'null'), associative: true, flags: JSON_THROW_ON_ERROR);
$this->distributedCache->set('ai.taskprocessing_provider_preferences', $this->preferences, 60 * 3);
}
}

$providers = $this->getProviders();
if (isset($preferences[$taskTypeId])) {
$provider = current(array_values(array_filter($providers, fn ($provider) => $provider->getId() === $preferences[$taskTypeId])));
if ($provider !== false) {
return $provider;
if (isset($this->preferences[$taskTypeId])) {
$providersById = $this->providersById ?? array_reduce($providers, static function (array $carry, IProvider $provider) {
$carry[$provider->getId()] = $provider;
return $carry;
}, []);
$this->providersById = $providersById;
if (isset($providersById[$this->preferences[$taskTypeId]])) {
return $providersById[$this->preferences[$taskTypeId]];
}
}
// By default, use the first available provider
Expand All @@ -719,6 +740,10 @@ public function getPreferredProvider(string $taskTypeId) {
}

public function getAvailableTaskTypes(): array {
if ($this->availableTaskTypes === null) {
// We use local cache only because distributed cache uses JSOn stringify which would botch our ShapeDescriptor objects
$this->availableTaskTypes = $this->cache->get('available_task_types');
}
if ($this->availableTaskTypes === null) {
$taskTypes = $this->_getTaskTypes();

Expand Down Expand Up @@ -750,6 +775,7 @@ public function getAvailableTaskTypes(): array {
}

$this->availableTaskTypes = $availableTaskTypes;
$this->cache->set('available_task_types', $this->availableTaskTypes, 60);
}

return $this->availableTaskTypes;
Expand Down
2 changes: 2 additions & 0 deletions tests/lib/TaskProcessing/TaskProcessingTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use OCP\Files\Config\IUserMountCache;
use OCP\Files\IRootFolder;
use OCP\Http\Client\IClientService;
use OCP\ICacheFactory;
use OCP\IConfig;
use OCP\IDBConnection;
use OCP\IServerContainer;
Expand Down Expand Up @@ -475,6 +476,7 @@ protected function setUp(): void {
$this->userMountCache,
\OC::$server->get(IClientService::class),
\OC::$server->get(IAppManager::class),
\OC::$server->get(ICacheFactory::class),
);
}

Expand Down
Loading