Skip to content

Commit

Permalink
feat(NODE-6161): allow custom aws sdk config
Browse files Browse the repository at this point in the history
  • Loading branch information
durran committed Jan 31, 2025
1 parent f82aa57 commit 42f446e
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 18 deletions.
9 changes: 8 additions & 1 deletion src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
import * as net from 'net';

import { deserialize, type Document, serialize } from '../bson';
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
import { type CommandOptions, type ProxyOptions } from '../cmap/connection';
import { kDecorateResult } from '../constants';
import { getMongoDBClientEncryption } from '../deps';
Expand Down Expand Up @@ -153,6 +154,7 @@ export class AutoEncrypter {
_kmsProviders: KMSProviders;
_bypassMongocryptdAndCryptShared: boolean;
_contextCounter: number;
_awsCredentialProvider?: AWSCredentialProvider;

_mongocryptdManager?: MongocryptdManager;
_mongocryptdClient?: MongoClient;
Expand Down Expand Up @@ -327,6 +329,11 @@ export class AutoEncrypter {
* This function is a no-op when bypassSpawn is set or the crypt shared library is used.
*/
async init(): Promise<MongoClient | void> {
// This is handled during init() as the auto encrypter is instantiated during the client's
// parseOptions() call, so the client doesn't have its options set at that point.
this._awsCredentialProvider =
this._client.options.credentials?.mechanismProperties.AWS_CREDENTIAL_PROVIDER;

if (this._bypassMongocryptdAndCryptShared || this.cryptSharedLibVersionInfo) {
return;
}
Expand Down Expand Up @@ -438,7 +445,7 @@ export class AutoEncrypter {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
}

/**
Expand Down
8 changes: 7 additions & 1 deletion src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
type UUID
} from '../bson';
import { type AnyBulkWriteOperation, type BulkWriteResult } from '../bulk/common';
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
import { type ProxyOptions } from '../cmap/connection';
import { type Collection } from '../collection';
import { type FindCursor } from '../cursor/find_cursor';
Expand Down Expand Up @@ -81,6 +82,9 @@ export class ClientEncryption {
/** @internal */
_mongoCrypt: MongoCrypt;

/** @internal */
_awsCredentialProvider?: AWSCredentialProvider;

/** @internal */
static getMongoCrypt(): MongoCryptConstructor {
const encryption = getMongoDBClientEncryption();
Expand Down Expand Up @@ -125,6 +129,8 @@ export class ClientEncryption {
this._kmsProviders = options.kmsProviders || {};
const { timeoutMS } = resolveTimeoutOptions(client, options);
this._timeoutMS = timeoutMS;
this._awsCredentialProvider =
client.options.credentials?.mechanismProperties.AWS_CREDENTIAL_PROVIDER;

if (options.keyVaultNamespace == null) {
throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`');
Expand Down Expand Up @@ -712,7 +718,7 @@ export class ClientEncryption {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
}

static get libmongocryptVersion() {
Expand Down
12 changes: 9 additions & 3 deletions src/client-side-encryption/providers/aws.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider
} from '../../cmap/auth/aws_temporary_credentials';
import { type KMSProviders } from '.';

/**
* @internal
*/
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider();
export async function loadAWSCredentials(
kmsProviders: KMSProviders,
provider?: AWSCredentialProvider
): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider(provider);

// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
Expand Down
8 changes: 6 additions & 2 deletions src/client-side-encryption/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Binary } from '../../bson';
import { type AWSCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import { loadAWSCredentials } from './aws';
import { loadAzureCredentials } from './azure';
import { loadGCPCredentials } from './gcp';
Expand Down Expand Up @@ -176,11 +177,14 @@ export function isEmptyCredentials(
*
* @internal
*/
export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
export async function refreshKMSCredentials(
kmsProviders: KMSProviders,
awsProvider?: AWSCredentialProvider
): Promise<KMSProviders> {
let finalKMSProviders = kmsProviders;

if (isEmptyCredentials('aws', kmsProviders)) {
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
finalKMSProviders = await loadAWSCredentials(finalKMSProviders, awsProvider);
}

if (isEmptyCredentials('gcp', kmsProviders)) {
Expand Down
18 changes: 17 additions & 1 deletion src/cmap/auth/aws_temporary_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ export interface AWSTempCredentials {
Expiration?: Date;
}

/** @public **/
export type AWSCredentialProvider = () => Promise<AWSCredentials>;

/**
* @internal
*
Expand All @@ -41,7 +44,20 @@ export abstract class AWSTemporaryCredentialProvider {

/** @internal */
export class AWSSDKCredentialProvider extends AWSTemporaryCredentialProvider {
private _provider?: () => Promise<AWSCredentials>;
private _provider?: AWSCredentialProvider;

/**
* Create the SDK credentials provider.
* @param credentialsProvider - The credentials provider.
*/
constructor(credentialsProvider?: AWSCredentialProvider) {
super();

if (credentialsProvider) {
this._provider = credentialsProvider;
}
}

/**
* The AWS SDK caches credentials automatically and handles refresh when the credentials have expired.
* To ensure this occurs, we need to cache the `provider` returned by the AWS sdk and re-use it when fetching credentials.
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/auth/mongo_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
MongoInvalidArgumentError,
MongoMissingCredentialsError
} from '../../error';
import type { AWSCredentialProvider } from './aws_temporary_credentials';
import { GSSAPICanonicalizationValue } from './gssapi';
import type { OIDCCallbackFunction } from './mongodb_oidc';
import { AUTH_MECHS_AUTH_SRC_EXTERNAL, AuthMechanism } from './providers';
Expand Down Expand Up @@ -68,6 +69,8 @@ export interface AuthMechanismProperties extends Document {
ALLOWED_HOSTS?: string[];
/** The resource token for OIDC auth in Azure and GCP. */
TOKEN_RESOURCE?: string;
/** A custom AWS credential provider to use. */
AWS_CREDENTIAL_PROVIDER?: AWSCredentialProvider;
}

/** @public */
Expand Down
5 changes: 3 additions & 2 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
import { ByteUtils, maxWireVersion, ns, randomBytes } from '../../utils';
import { type AuthContext, AuthProvider } from './auth_provider';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider,
type AWSTempCredentials,
AWSTemporaryCredentialProvider,
Expand All @@ -34,11 +35,11 @@ interface AWSSaslContinuePayload {

export class MongoDBAWS extends AuthProvider {
private credentialFetcher: AWSTemporaryCredentialProvider;
constructor() {
constructor(credentialProvider?: AWSCredentialProvider) {
super();

this.credentialFetcher = AWSTemporaryCredentialProvider.isAWSSDKInstalled
? new AWSSDKCredentialProvider()
? new AWSSDKCredentialProvider(credentialProvider)
: new LegacyAWSTemporaryCredentialProvider();
}

Expand Down
2 changes: 1 addition & 1 deletion src/deps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export function getZstdLibrary(): ZStandardLib | { kModuleError: MongoMissingDep
}

/**
* @internal
* @public
* Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts),
* the return type of the aws-sdk's `fromNodeProviderChain().provider()`.
*/
Expand Down
3 changes: 2 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ export { ReadPreferenceMode } from './read_preference';
export { ServerType, TopologyType } from './sdam/common';

// Helper classes
export type { AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
export type { AWSCredentials } from './deps';
export { ReadConcern } from './read_concern';
export { ReadPreference } from './read_preference';
export { WriteConcern } from './write_concern';

// events
export {
CommandFailedEvent,
Expand Down
10 changes: 8 additions & 2 deletions src/mongo_client_auth_providers.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { type AuthProvider } from './cmap/auth/auth_provider';
import { type AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
import { GSSAPI } from './cmap/auth/gssapi';
import { type AuthMechanismProperties } from './cmap/auth/mongo_credentials';
import { MongoDBAWS } from './cmap/auth/mongodb_aws';
Expand All @@ -13,8 +14,11 @@ import { X509 } from './cmap/auth/x509';
import { MongoInvalidArgumentError } from './error';

/** @internal */
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (workflow?: Workflow) => AuthProvider>([
[AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()],
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (param?: any) => AuthProvider>([
[
AuthMechanism.MONGODB_AWS,
(credentialProvider?: AWSCredentialProvider) => new MongoDBAWS(credentialProvider)
],
[
AuthMechanism.MONGODB_CR,
() => {
Expand Down Expand Up @@ -65,6 +69,8 @@ export class MongoClientAuthProviders {
let provider;
if (name === AuthMechanism.MONGODB_OIDC) {
provider = providerFunction(this.getWorkflow(authMechanismProperties));
} else if (name === AuthMechanism.MONGODB_AWS) {
provider = providerFunction(authMechanismProperties.AWS_CREDENTIAL_PROVIDER);
} else {
provider = providerFunction();
}
Expand Down
58 changes: 54 additions & 4 deletions test/integration/auth/mongodb_aws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ describe('MONGODB-AWS', function () {
expect(result).to.be.a('number');
});

context('when user supplies a credentials provider', function () {
beforeEach(function () {
if (!awsSdkPresent) {
this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed';
return this.skip();
}
});

it('authenticates with a user provided credentials provider', async function () {
// @ts-expect-error We intentionally access a protected variable.
const credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
client = this.configuration.newClient(process.env.MONGODB_URI, {
authMechanismProperties: {
AWS_CREDENTIAL_PROVIDER: credentialProvider.fromNodeProviderChain()
}
});

const result = await client
.db('aws')
.collection('aws_test')
.estimatedDocumentCount()
.catch(error => error);

expect(result).to.not.be.instanceOf(MongoServerError);
expect(result).to.be.a('number');
});
});

it('should allow empty string in authMechanismProperties.AWS_SESSION_TOKEN to override AWS_SESSION_TOKEN environment variable', function () {
client = this.configuration.newClient(this.configuration.url(), {
authMechanismProperties: { AWS_SESSION_TOKEN: '' }
Expand Down Expand Up @@ -351,11 +379,33 @@ describe('AWS KMS Credential Fetching', function () {
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
context('when a credential provider is not providered', function () {
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});
});

context('when a credential provider is provided', function () {
let credentialProvider;

beforeEach(function () {
// @ts-expect-error We intentionally access a protected variable.
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
});

it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials(
{ aws: {} },
credentialProvider.fromNodeProviderChain()
);

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});
});

it('does not return any extra keys for the `aws` credential provider', async function () {
Expand Down

0 comments on commit 42f446e

Please sign in to comment.