Skip to content

Commit

Permalink
feat(auth): add EntraId integration tests
Browse files Browse the repository at this point in the history
- Add integration tests for token renewal and re-authentication flows
- Update credentials provider to use uniqueId as username instead of account username
- Add test utilities for loading Redis endpoint configurations
- Split TypeScript configs into separate files for samples and integration tests
  • Loading branch information
bobymicroby committed Jan 21, 2025
1 parent ac972bd commit df0ca4b
Show file tree
Hide file tree
Showing 15 changed files with 563 additions and 118 deletions.
16 changes: 8 additions & 8 deletions packages/authx/lib/token-manager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ describe('TokenManager', () => {
assert.equal(listener.receivedTokens.length, 1, 'Should not receive new token after failure');
assert.equal(listener.errors.length, 1, 'Should receive error');
assert.equal(listener.errors[0].message, 'Fatal error', 'Should have correct error message');
assert.equal(listener.errors[0].isFatal, true, 'Should be a fatal error');
assert.equal(listener.errors[0].isRetryable, false, 'Should be a fatal error');

// verify that the token manager is stopped and no more requests are made after the error and expected refresh time
await delay(80);
Expand All @@ -352,7 +352,7 @@ describe('TokenManager', () => {
initialDelayMs: 100,
maxDelayMs: 1000,
backoffMultiplier: 2,
shouldRetry: (error: unknown) => error instanceof Error && error.message === 'Temporary failure'
isRetryable: (error: unknown) => error instanceof Error && error.message === 'Temporary failure'
}
};

Expand Down Expand Up @@ -389,7 +389,7 @@ describe('TokenManager', () => {
// Should have first error but not stop due to retry config
assert.equal(listener.errors.length, 1, 'Should have first error');
assert.ok(listener.errors[0].message.includes('attempt 1'), 'Error should indicate first attempt');
assert.equal(listener.errors[0].isFatal, false, 'Should not be a fatal error');
assert.equal(listener.errors[0].isRetryable, true, 'Should not be a fatal error');
assert.equal(manager.isRunning(), true, 'Should continue running during retries');

// Advance past first retry (delay: 100ms due to backoff)
Expand All @@ -401,7 +401,7 @@ describe('TokenManager', () => {

assert.equal(listener.errors.length, 2, 'Should have second error');
assert.ok(listener.errors[1].message.includes('attempt 2'), 'Error should indicate second attempt');
assert.equal(listener.errors[0].isFatal, false, 'Should not be a fatal error');
assert.equal(listener.errors[0].isRetryable, true, 'Should not be a fatal error');
assert.equal(manager.isRunning(), true, 'Should continue running during retries');

// Advance past second retry (delay: 200ms due to backoff)
Expand Down Expand Up @@ -435,7 +435,7 @@ describe('TokenManager', () => {
maxDelayMs: 1000,
backoffMultiplier: 2,
jitterPercentage: 0,
shouldRetry: (error: unknown) => error instanceof Error && error.message === 'Temporary failure'
isRetryable: (error: unknown) => error instanceof Error && error.message === 'Temporary failure'
}
};

Expand Down Expand Up @@ -470,7 +470,7 @@ describe('TokenManager', () => {
// First error
assert.equal(listener.errors.length, 1, 'Should have first error');
assert.equal(manager.isRunning(), true, 'Should continue running after first error');
assert.equal(listener.errors[0].isFatal, false, 'Should not be a fatal error');
assert.equal(listener.errors[0].isRetryable, true, 'Should not be a fatal error');

// Advance past first retry
await delay(100);
Expand All @@ -483,7 +483,7 @@ describe('TokenManager', () => {
// Second error
assert.equal(listener.errors.length, 2, 'Should have second error');
assert.equal(manager.isRunning(), true, 'Should continue running after second error');
assert.equal(listener.errors[1].isFatal, false, 'Should not be a fatal error');
assert.equal(listener.errors[1].isRetryable, true, 'Should not be a fatal error');

// Advance past second retry
await delay(200);
Expand All @@ -495,7 +495,7 @@ describe('TokenManager', () => {

// Should stop after max retries
assert.equal(listener.errors.length, 3, 'Should have final error');
assert.equal(listener.errors[2].isFatal, true, 'Should not be a fatal error');
assert.equal(listener.errors[2].isRetryable, false, 'Should be a fatal error');
assert.equal(manager.isRunning(), false, 'Should stop after max retries exceeded');
assert.equal(identityProvider.getRequestCount(), 4, 'Should have made exactly 4 requests');

Expand Down
90 changes: 70 additions & 20 deletions packages/authx/lib/token-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,70 @@ import { Token } from './token';
* The configuration for retrying token refreshes.
*/
export interface RetryPolicy {
// The maximum number of attempts to retry token refreshes.
/**
* The maximum number of attempts to retry token refreshes.
*/
maxAttempts: number;
// The initial delay in milliseconds before the first retry.

/**
* The initial delay in milliseconds before the first retry.
*/
initialDelayMs: number;
// The maximum delay in milliseconds between retries (the calculated delay will be capped at this value).

/**
* The maximum delay in milliseconds between retries.
* The calculated delay will be capped at this value.
*/
maxDelayMs: number;
// The multiplier for exponential backoff between retries. e.g. 2 will double the delay each time.

/**
* The multiplier for exponential backoff between retries.
* @example
* A value of 2 will double the delay each time:
* - 1st retry: initialDelayMs
* - 2nd retry: initialDelayMs * 2
* - 3rd retry: initialDelayMs * 4
*/
backoffMultiplier: number;
// The percentage of jitter to apply to the delay. e.g. 0.1 will add or subtract up to 10% of the delay.

/**
* The percentage of jitter to apply to the delay.
* @example
* A value of 0.1 will add or subtract up to 10% of the delay.
*/
jitterPercentage?: number;
// A custom function to determine if a retry should be attempted based on the error and attempt number.
shouldRetry?: (error: unknown, attempt: number) => boolean;

/**
* Function to classify errors from the identity provider as retryable or non-retryable.
* Used to determine if a token refresh failure should be retried based on the type of error.
*
* The default behavior is to retry all types of errors if no function is provided.
*
* Common use cases:
* - Network errors that may be transient (should retry)
* - Invalid credentials (should not retry)
* - Rate limiting responses (should retry)
*
* @param error - The error from the identity provider3
* @param attempt - Current retry attempt (0-based)
* @returns `true` if the error is considered transient and the operation should be retried
*
* @example
* ```typescript
* const retryPolicy: RetryPolicy = {
* maxAttempts: 3,
* initialDelayMs: 1000,
* maxDelayMs: 5000,
* backoffMultiplier: 2,
* isRetryable: (error) => {
* // Retry on network errors or rate limiting
* return error instanceof NetworkError ||
* error instanceof RateLimitError;
* }
* };
* ```
*/
isRetryable?: (error: unknown, attempt: number) => boolean;
}

/**
Expand All @@ -36,14 +88,13 @@ export interface TokenManagerConfig {
}

/**
* IDPError is an error that occurs while calling the underlying IdentityProvider.
* IDPError indicates a failure from the identity provider.
*
* It can be transient and if retry policy is configured, the token manager will attempt to obtain a token again.
* This means that receiving non-fatal error is not a stream termination event.
* The stream will be terminated only if the error is fatal.
* The `isRetryable` flag is determined by the RetryPolicy's error classification function - if an error is
* classified as retryable, it will be marked as transient and the token manager will attempt to recover.
*/
export class IDPError extends Error {
constructor(public readonly message: string, public readonly isFatal: boolean) {
constructor(public readonly message: string, public readonly isRetryable: boolean) {
super(message);
this.name = 'IDPError';
}
Expand Down Expand Up @@ -105,7 +156,6 @@ export class TokenManager<T> {
*/
public start(listener: TokenStreamListener<T>, initialDelayMs: number = 0): Disposable {
if (this.listener) {
console.log('TokenManager is already running, stopping the previous instance');
this.stop();
}

Expand Down Expand Up @@ -142,14 +192,14 @@ export class TokenManager<T> {
private shouldRetry(error: unknown): boolean {
if (!this.config.retry) return false;

const { maxAttempts, shouldRetry } = this.config.retry;
const { maxAttempts, isRetryable } = this.config.retry;

if (this.retryAttempt >= maxAttempts) {
return false;
}

if (shouldRetry) {
return shouldRetry(error, this.retryAttempt);
if (isRetryable) {
return isRetryable(error, this.retryAttempt);
}

return false;
Expand All @@ -172,10 +222,10 @@ export class TokenManager<T> {
if (this.shouldRetry(error)) {
this.retryAttempt++;
const retryDelay = this.calculateRetryDelay();
this.notifyError(`Token refresh failed (attempt ${this.retryAttempt}), retrying in ${retryDelay}ms: ${error}`, false)
this.notifyError(`Token refresh failed (attempt ${this.retryAttempt}), retrying in ${retryDelay}ms: ${error}`, true)
this.scheduleNextRefresh(retryDelay);
} else {
this.notifyError(error, true);
this.notifyError(error, false);
this.stop();
}
}
Expand Down Expand Up @@ -255,13 +305,13 @@ export class TokenManager<T> {
return this.currentToken;
}

private notifyError = (error: unknown, isFatal: boolean): void => {
private notifyError(error: unknown, isRetryable: boolean): void {
const errorMessage = error instanceof Error ? error.message : String(error);

if (!this.listener) {
throw new Error(`TokenManager is not running but received an error: ${errorMessage}`);
}

this.listener.onError(new IDPError(errorMessage, isFatal));
this.listener.onError(new IDPError(errorMessage, isRetryable));
}
}
30 changes: 14 additions & 16 deletions packages/client/lib/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ export default class RedisClient<
#epoch: number;
#watchEpoch?: number;

private credentialsSubscription: Disposable | null = null;
#credentialsSubscription: Disposable | null = null;

get options(): RedisClientOptions<M, F, S, RESP> | undefined {
return this._self.#options;
Expand Down Expand Up @@ -394,19 +394,17 @@ export default class RedisClient<
}
}

private subscribeForStreamingCredentials(cp: StreamingCredentialsProvider): Promise<[BasicAuth, Disposable]> {
#subscribeForStreamingCredentials(cp: StreamingCredentialsProvider): Promise<[BasicAuth, Disposable]> {
return cp.subscribe({
onNext: credentials => {
this.reAuthenticate(credentials).catch(error => {
const errorMessage = error instanceof Error ? error.message : String(error);
console.error('Error during re-authentication', errorMessage);
cp.onReAuthenticationError(new CredentialsError(errorMessage));
});

},
onError: (e: Error) => {
const errorMessage = `Error from streaming credentials provider: ${e.message}`;
console.error(errorMessage);
cp.onReAuthenticationError(new UnableToObtainNewCredentialsError(errorMessage));
}
});
Expand All @@ -431,8 +429,8 @@ export default class RedisClient<

if (cp && cp.type === 'streaming-credentials-provider') {

const [credentials, disposable] = await this.subscribeForStreamingCredentials(cp)
this.credentialsSubscription = disposable;
const [credentials, disposable] = await this.#subscribeForStreamingCredentials(cp)
this.#credentialsSubscription = disposable;

if (credentials.password) {
hello.AUTH = {
Expand Down Expand Up @@ -467,8 +465,8 @@ export default class RedisClient<

if (cp && cp.type === 'streaming-credentials-provider') {

const [credentials, disposable] = await this.subscribeForStreamingCredentials(cp)
this.credentialsSubscription = disposable;
const [credentials, disposable] = await this.#subscribeForStreamingCredentials(cp)
this.#credentialsSubscription = disposable;

if (credentials.username || credentials.password) {
commands.push(
Expand Down Expand Up @@ -1105,8 +1103,8 @@ export default class RedisClient<
const chainId = Symbol('Reset Chain'),
promises = [this._self.#queue.reset(chainId)],
selectedDB = this._self.#options?.database ?? 0;
this.credentialsSubscription?.[Symbol.dispose]();
this.credentialsSubscription = null;
this._self.#credentialsSubscription?.[Symbol.dispose]();
this._self.#credentialsSubscription = null;
for (const command of (await this._self.#handshake(selectedDB))) {
promises.push(
this._self.#queue.addCommand(command, {
Expand Down Expand Up @@ -1158,8 +1156,8 @@ export default class RedisClient<
* @deprecated use .close instead
*/
QUIT(): Promise<string> {
this.credentialsSubscription?.[Symbol.dispose]();
this.credentialsSubscription = null;
this._self.#credentialsSubscription?.[Symbol.dispose]();
this._self.#credentialsSubscription = null;
return this._self.#socket.quit(async () => {
clearTimeout(this._self.#pingTimer);
const quitPromise = this._self.#queue.addCommand<string>(['QUIT']);
Expand Down Expand Up @@ -1198,8 +1196,8 @@ export default class RedisClient<
resolve();
};
this._self.#socket.on('data', maybeClose);
this.credentialsSubscription?.[Symbol.dispose]();
this.credentialsSubscription = null;
this._self.#credentialsSubscription?.[Symbol.dispose]();
this._self.#credentialsSubscription = null;
});
}

Expand All @@ -1210,8 +1208,8 @@ export default class RedisClient<
clearTimeout(this._self.#pingTimer);
this._self.#queue.flushAll(new DisconnectsClientError());
this._self.#socket.destroy();
this.credentialsSubscription?.[Symbol.dispose]();
this.credentialsSubscription = null;
this._self.#credentialsSubscription?.[Symbol.dispose]();
this._self.#credentialsSubscription = null;
}

ref() {
Expand Down
Loading

0 comments on commit df0ca4b

Please sign in to comment.