Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable caching capabilities on UserRepository #264

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config-templates/module_oidc.php
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@
// 60 * 60 * 6, // Default lifetime in seconds (used when particular cache item doesn't define its own lifetime)
//],

// Cache duration for user entities (authenticated users data). If not set, cache duration will be the same as
// session duration. This is used to avoid fetching user data from database on every authentication event.
// This is only relevant if protocol cache adapter is set up. For duration format info, check
// https://www.php.net/manual/en/dateinterval.construct.php.
// ModuleConfig::OPTION_PROTOCOL_USER_ENTITY_CACHE_DURATION => 'PT1H', // 1 hour
ModuleConfig::OPTION_PROTOCOL_USER_ENTITY_CACHE_DURATION => null, // fallback to session duration

/**
* Cron related options.
*/
Expand Down
5 changes: 5 additions & 0 deletions docker/ssp/module_oidc.php
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,9 @@
ModuleConfig::OPTION_AUTH_FORCED_ACR_VALUE_FOR_COOKIE_AUTHENTICATION => null,

ModuleConfig::OPTION_CRON_TAG => 'hourly',

ModuleConfig::OPTION_PROTOCOL_CACHE_ADAPTER => \Symfony\Component\Cache\Adapter\FilesystemAdapter::class,
ModuleConfig::OPTION_PROTOCOL_CACHE_ADAPTER_ARGUMENTS => [
// Use defaults
],
];
6 changes: 5 additions & 1 deletion routing/services/services.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,8 @@ services:
SimpleSAML\OpenID\Federation:
factory: [ '@SimpleSAML\Module\oidc\Factories\FederationFactory', 'build' ]
SimpleSAML\OpenID\Jwks:
factory: [ '@SimpleSAML\Module\oidc\Factories\JwksFactory', 'build' ]
factory: [ '@SimpleSAML\Module\oidc\Factories\JwksFactory', 'build' ]

# SSP
SimpleSAML\Database:
factory: [ 'SimpleSAML\Database', 'getInstance' ]
2 changes: 2 additions & 0 deletions src/Controller/Federation/Test.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

namespace SimpleSAML\Module\oidc\Controller\Federation;

use SimpleSAML\Database;
use SimpleSAML\Module\oidc\Codebooks\RegistrationTypeEnum;
use SimpleSAML\Module\oidc\Factories\CoreFactory;
use SimpleSAML\Module\oidc\Factories\Entities\ClientEntityFactory;
Expand Down Expand Up @@ -33,6 +34,7 @@ public function __construct(
protected ?FederationCache $federationCache,
protected LoggerService $loggerService,
protected Jwks $jwks,
protected Database $database,
protected ClientEntityFactory $clientEntityFactory,
protected CoreFactory $coreFactory,
protected \DateInterval $maxCacheDuration = new \DateInterval('PT30S'),
Expand Down
17 changes: 17 additions & 0 deletions src/ModuleConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
final public const OPTION_FEDERATION_ENTITY_STATEMENT_CACHE_DURATION = 'federation_entity_statement_cache_duration';
final public const OPTION_PROTOCOL_CACHE_ADAPTER = 'protocol_cache_adapter';
final public const OPTION_PROTOCOL_CACHE_ADAPTER_ARGUMENTS = 'protocol_cache_adapter_arguments';
final public const OPTION_PROTOCOL_USER_ENTITY_CACHE_DURATION = 'protocol_user_entity_cache_duration';

protected static array $standardScopes = [
ScopesEnum::OpenId->value => [
Expand Down Expand Up @@ -630,4 +631,20 @@
{
return $this->config()->getOptionalArray(self::OPTION_PROTOCOL_CACHE_ADAPTER_ARGUMENTS, []);
}

/**
* Get cache duration for user entities (user data). If not set in configuration, it will fall back to SSP session
* duration.
*
* @throws \Exception
*/
public function getProtocolUserEntityCacheDuration(): DateInterval

Check warning on line 641 in src/ModuleConfig.php

View check run for this annotation

Codecov / codecov/patch

src/ModuleConfig.php#L641

Added line #L641 was not covered by tests
{
return new DateInterval(
$this->config()->getOptionalString(
self::OPTION_PROTOCOL_USER_ENTITY_CACHE_DURATION,
null,
) ?? "PT{$this->sspConfig()->getInteger('session.duration')}S",
);

Check warning on line 648 in src/ModuleConfig.php

View check run for this annotation

Codecov / codecov/patch

src/ModuleConfig.php#L643-L648

Added lines #L643 - L648 were not covered by tests
}
}
15 changes: 6 additions & 9 deletions src/Repositories/AbstractDatabaseRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,21 @@
*/
namespace SimpleSAML\Module\oidc\Repositories;

use SimpleSAML\Configuration;
use SimpleSAML\Database;
use SimpleSAML\Module\oidc\ModuleConfig;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

abstract class AbstractDatabaseRepository
{
protected Configuration $config;

protected Database $database;

/**
* ClientRepository constructor.
* @throws \Exception
*/
public function __construct(protected ModuleConfig $moduleConfig)
{
$this->config = $this->moduleConfig->config();
$this->database = Database::getInstance();
public function __construct(

Check warning on line 28 in src/Repositories/AbstractDatabaseRepository.php

View check run for this annotation

Codecov / codecov/patch

src/Repositories/AbstractDatabaseRepository.php#L28

Added line #L28 was not covered by tests
protected readonly ModuleConfig $moduleConfig,
protected readonly Database $database,
protected readonly ?ProtocolCache $protocolCache,
) {
}

abstract public function getTableName(): ?string;
Expand Down
6 changes: 5 additions & 1 deletion src/Repositories/AccessTokenRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use League\OAuth2\Server\Entities\AccessTokenEntityInterface as OAuth2AccessTokenEntityInterface;
use League\OAuth2\Server\Entities\ClientEntityInterface as OAuth2ClientEntityInterface;
use RuntimeException;
use SimpleSAML\Database;
use SimpleSAML\Error\Error;
use SimpleSAML\Module\oidc\Codebooks\DateFormatsEnum;
use SimpleSAML\Module\oidc\Entities\AccessTokenEntity;
Expand All @@ -30,6 +31,7 @@
use SimpleSAML\Module\oidc\Repositories\Interfaces\AccessTokenRepositoryInterface;
use SimpleSAML\Module\oidc\Repositories\Traits\RevokeTokenByAuthCodeIdTrait;
use SimpleSAML\Module\oidc\Server\Exceptions\OidcServerException;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

class AccessTokenRepository extends AbstractDatabaseRepository implements AccessTokenRepositoryInterface
{
Expand All @@ -39,11 +41,13 @@ class AccessTokenRepository extends AbstractDatabaseRepository implements Access

public function __construct(
ModuleConfig $moduleConfig,
Database $database,
?ProtocolCache $protocolCache,
protected readonly ClientRepository $clientRepository,
protected readonly AccessTokenEntityFactory $accessTokenEntityFactory,
protected readonly Helpers $helpers,
) {
parent::__construct($moduleConfig);
parent::__construct($moduleConfig, $database, $protocolCache);
}

public function getTableName(): string
Expand Down
6 changes: 5 additions & 1 deletion src/Repositories/AuthCodeRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

use League\OAuth2\Server\Entities\AuthCodeEntityInterface as OAuth2AuthCodeEntityInterface;
use RuntimeException;
use SimpleSAML\Database;
use SimpleSAML\Error\Error;
use SimpleSAML\Module\oidc\Codebooks\DateFormatsEnum;
use SimpleSAML\Module\oidc\Entities\AuthCodeEntity;
Expand All @@ -26,16 +27,19 @@
use SimpleSAML\Module\oidc\Helpers;
use SimpleSAML\Module\oidc\ModuleConfig;
use SimpleSAML\Module\oidc\Repositories\Interfaces\AuthCodeRepositoryInterface;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

class AuthCodeRepository extends AbstractDatabaseRepository implements AuthCodeRepositoryInterface
{
public function __construct(
ModuleConfig $moduleConfig,
Database $database,
?ProtocolCache $protocolCache,
protected readonly ClientRepository $clientRepository,
protected readonly AuthCodeEntityFactory $authCodeEntityFactory,
protected readonly Helpers $helpers,
) {
parent::__construct($moduleConfig);
parent::__construct($moduleConfig, $database, $protocolCache);
}

final public const TABLE_NAME = 'oidc_auth_code';
Expand Down
8 changes: 6 additions & 2 deletions src/Repositories/ClientRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@

use League\OAuth2\Server\Repositories\ClientRepositoryInterface;
use PDO;
use SimpleSAML\Database;
use SimpleSAML\Module\oidc\Entities\ClientEntity;
use SimpleSAML\Module\oidc\Entities\Interfaces\ClientEntityInterface;
use SimpleSAML\Module\oidc\Factories\Entities\ClientEntityFactory;
use SimpleSAML\Module\oidc\ModuleConfig;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

class ClientRepository extends AbstractDatabaseRepository implements ClientRepositoryInterface
{
public function __construct(
ModuleConfig $moduleConfig,
Database $database,
?ProtocolCache $protocolCache,
protected readonly ClientEntityFactory $clientEntityFactory,
) {
parent::__construct($moduleConfig);
parent::__construct($moduleConfig, $database, $protocolCache);
}

final public const TABLE_NAME = 'oidc_client';
Expand Down Expand Up @@ -389,7 +393,7 @@ private function count(string $query, ?string $owner): int
*/
private function getItemsPerPage(): int
{
return $this->config
return $this->moduleConfig->config()
->getOptionalIntegerRange(ModuleConfig::OPTION_ADMIN_UI_PAGINATION_ITEMS_PER_PAGE, 1, 100, 20);
}

Expand Down
6 changes: 5 additions & 1 deletion src/Repositories/RefreshTokenRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use League\OAuth2\Server\Entities\RefreshTokenEntityInterface as OAuth2RefreshTokenEntityInterface;
use League\OAuth2\Server\Exception\OAuthServerException;
use RuntimeException;
use SimpleSAML\Database;
use SimpleSAML\Module\oidc\Codebooks\DateFormatsEnum;
use SimpleSAML\Module\oidc\Entities\Interfaces\RefreshTokenEntityInterface;
use SimpleSAML\Module\oidc\Entities\RefreshTokenEntity;
Expand All @@ -27,6 +28,7 @@
use SimpleSAML\Module\oidc\ModuleConfig;
use SimpleSAML\Module\oidc\Repositories\Interfaces\RefreshTokenRepositoryInterface;
use SimpleSAML\Module\oidc\Repositories\Traits\RevokeTokenByAuthCodeIdTrait;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

class RefreshTokenRepository extends AbstractDatabaseRepository implements RefreshTokenRepositoryInterface
{
Expand All @@ -36,11 +38,13 @@ class RefreshTokenRepository extends AbstractDatabaseRepository implements Refre

public function __construct(
ModuleConfig $moduleConfig,
Database $database,
?ProtocolCache $protocolCache,
protected readonly AccessTokenRepository $accessTokenRepository,
protected readonly RefreshTokenEntityFactory $refreshTokenEntityFactory,
protected readonly Helpers $helpers,
) {
parent::__construct($moduleConfig);
parent::__construct($moduleConfig, $database, $protocolCache);
}

/**
Expand Down
10 changes: 2 additions & 8 deletions src/Repositories/ScopeRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,12 @@
use function array_key_exists;
use function in_array;

class ScopeRepository extends AbstractDatabaseRepository implements ScopeRepositoryInterface
class ScopeRepository implements ScopeRepositoryInterface
{
public function __construct(
ModuleConfig $moduleConfig,
protected readonly ModuleConfig $moduleConfig,
protected readonly ScopeEntityFactory $scopeEntityFactory,
) {
parent::__construct($moduleConfig);
}

public function getTableName(): ?string
{
return null;
}

/**
Expand Down
52 changes: 45 additions & 7 deletions src/Repositories/UserRepository.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,38 @@
use League\OAuth2\Server\Entities\ClientEntityInterface as OAuth2ClientEntityInterface;
use League\OAuth2\Server\Entities\UserEntityInterface;
use League\OAuth2\Server\Repositories\UserRepositoryInterface;
use SimpleSAML\Database;
use SimpleSAML\Module\oidc\Entities\UserEntity;
use SimpleSAML\Module\oidc\Factories\Entities\UserEntityFactory;
use SimpleSAML\Module\oidc\Helpers;
use SimpleSAML\Module\oidc\ModuleConfig;
use SimpleSAML\Module\oidc\Repositories\Interfaces\IdentityProviderInterface;
use SimpleSAML\Module\oidc\Utils\ProtocolCache;

class UserRepository extends AbstractDatabaseRepository implements UserRepositoryInterface, IdentityProviderInterface
{
final public const TABLE_NAME = 'oidc_user';

public function __construct(
ModuleConfig $moduleConfig,
Database $database,
?ProtocolCache $protocolCache,
protected readonly Helpers $helpers,
protected readonly UserEntityFactory $userEntityFactory,
) {
parent::__construct($moduleConfig);
parent::__construct($moduleConfig, $database, $protocolCache);
}

public function getTableName(): string
{
return $this->database->applyPrefix(self::TABLE_NAME);
}

public function getCacheKey(string $identifier): string
{
return $this->getTableName() . '_' . $identifier;
}

/**
* @param string $identifier
*
Expand All @@ -52,6 +61,13 @@ public function getTableName(): string
*/
public function getUserEntityByIdentifier(string $identifier): ?UserEntity
{
/** @var ?array $cachedState */
$cachedState = $this->protocolCache?->get(null, $this->getCacheKey($identifier));

if (is_array($cachedState)) {
return $this->userEntityFactory->fromState($cachedState);
}

$stmt = $this->database->read(
"SELECT * FROM {$this->getTableName()} WHERE id = :id",
[
Expand All @@ -69,7 +85,15 @@ public function getUserEntityByIdentifier(string $identifier): ?UserEntity
return null;
}

return $this->userEntityFactory->fromState($row);
$userEntity = $this->userEntityFactory->fromState($row);

$this->protocolCache?->set(
$userEntity->getState(),
$this->moduleConfig->getProtocolUserEntityCacheDuration(),
$this->getCacheKey($userEntity->getIdentifier()),
);

return $userEntity;
}

/**
Expand All @@ -95,21 +119,29 @@ public function add(UserEntity $userEntity): void
$stmt,
$userEntity->getState(),
);

$this->protocolCache?->set(
$userEntity->getState(),
$this->moduleConfig->getProtocolUserEntityCacheDuration(),
$this->getCacheKey($userEntity->getIdentifier()),
);
}

public function delete(UserEntity $user): void
public function delete(UserEntity $userEntity): void
{
$this->database->write(
"DELETE FROM {$this->getTableName()} WHERE id = :id",
[
'id' => $user->getIdentifier(),
'id' => $userEntity->getIdentifier(),
],
);

$this->protocolCache?->delete($this->getCacheKey($userEntity->getIdentifier()));
}

public function update(UserEntity $user, ?DateTimeImmutable $updatedAt = null): void
public function update(UserEntity $userEntity, ?DateTimeImmutable $updatedAt = null): void
{
$user->setUpdatedAt($updatedAt ?? $this->helpers->dateTime()->getUtc());
$userEntity->setUpdatedAt($updatedAt ?? $this->helpers->dateTime()->getUtc());

$stmt = sprintf(
"UPDATE %s SET claims = :claims, updated_at = :updated_at, created_at = :created_at WHERE id = :id",
Expand All @@ -118,7 +150,13 @@ public function update(UserEntity $user, ?DateTimeImmutable $updatedAt = null):

$this->database->write(
$stmt,
$user->getState(),
$userEntity->getState(),
);

$this->protocolCache?->set(
$userEntity->getState(),
$this->moduleConfig->getProtocolUserEntityCacheDuration(),
$this->getCacheKey($userEntity->getIdentifier()),
);
}
}
Loading