1414 */
1515#include < aws/cryptosdk/private/kms_keyring.h>
1616
17+ #include < aws/core/utils/ARN.h>
1718#include < aws/core/utils/Outcome.h>
1819#include < aws/core/utils/logging/LogMacros.h>
1920#include < aws/core/utils/memory/MemorySystemInterface.h>
@@ -78,16 +79,26 @@ static int OnDecrypt(
7879 const Aws::String key_arn = Private::aws_string_from_c_aws_byte_buf (&edk->provider_info );
7980
8081 /* If there are no key IDs in the list, keyring is in "discovery" mode and will attempt KMS calls with
81- * every ARN it comes across in the message. If there are key IDs in the list, it will cross check the
82- * ARN it reads with that list before attempting KMS calls. Note that if caller provided key IDs in
83- * anything other than a CMK ARN format, the SDK will not attempt to decrypt those data keys, because
84- * the EDK data format always specifies the CMK with the full (non-alias) ARN.
82+ * every key ARN it comes across in the message, so long as the key ARN is authorized by the
83+ * DiscoveryFilter (matches the partition and an account ID).
84+ *
85+ * If there are key IDs in the list, it will cross check the ARN it reads with that list
86+ * before attempting KMS calls. Note that if caller provided key IDs in anything other than
87+ * a CMK ARN format, the SDK will not attempt to decrypt those data keys, because the EDK
88+ * data format always specifies the CMK with the full (non-alias) ARN.
8589 */
8690 if (self->key_ids .size () &&
8791 std::find (self->key_ids .begin (), self->key_ids .end (), key_arn) == self->key_ids .end ()) {
8892 // This keyring does not have access to the CMK used to encrypt this data key. Skip.
8993 continue ;
9094 }
95+ // self->discovery_filter is non-null only if self was constructed via BuildDiscovery, which
96+ // in turn implies discovery mode
97+ if (self->discovery_filter && !self->discovery_filter ->IsAuthorized (key_arn)) {
98+ // The DiscoveryFilter blocks the CMK used to encrypt this data key. Skip.
99+ continue ;
100+ }
101+
91102 Aws::String kms_region = Private::parse_region_from_kms_key_arn (key_arn);
92103 if (kms_region.empty ()) {
93104 error_buf << " Error: Malformed ciphertext. Provider ID field of KMS EDK is invalid KMS CMK ARN: " << key_arn
@@ -104,6 +115,7 @@ static int OnDecrypt(
104115
105116 Aws::KMS::Model::DecryptRequest kms_request;
106117 kms_request.WithGrantTokens (self->grant_tokens )
118+ .WithKeyId (key_arn)
107119 .WithCiphertextBlob (aws_utils_byte_buffer_from_c_aws_byte_buf (&edk->ciphertext ))
108120 .WithEncryptionContext (enc_ctx_cpp);
109121
@@ -131,6 +143,10 @@ static int OnDecrypt(
131143 AWS_CRYPTOSDK_WRAPPING_KEY_DECRYPTED_DATA_KEY | AWS_CRYPTOSDK_WRAPPING_KEY_VERIFIED_ENC_CTX);
132144 }
133145 return ret;
146+ } else {
147+ // Since we specified the key ARN explicitly in the request,
148+ // KMS had better use that key to decrypt
149+ return aws_raise_error (AWS_ERROR_INVALID_STATE);
134150 }
135151 }
136152
@@ -419,6 +435,20 @@ aws_cryptosdk_keyring *KmsKeyring::Builder::BuildDiscovery() const {
419435 BuildClientSupplier (empty_key_ids_list, kms_client, client_supplier));
420436}
421437
438+ aws_cryptosdk_keyring *KmsKeyring::Builder::BuildDiscovery (std::shared_ptr<DiscoveryFilter> discovery_filter) const {
439+ if (!discovery_filter) {
440+ return nullptr ;
441+ }
442+
443+ Aws::Vector<Aws::String> empty_key_ids_list;
444+ return Aws::New<Private::KmsKeyringImpl>(
445+ AWS_CRYPTO_SDK_KMS_CLASS_TAG,
446+ empty_key_ids_list,
447+ grant_tokens,
448+ BuildClientSupplier (empty_key_ids_list, kms_client, client_supplier),
449+ discovery_filter);
450+ }
451+
422452KmsKeyring::Builder &KmsKeyring::Builder::WithGrantTokens (const Aws::Vector<Aws::String> &grant_tokens) {
423453 this ->grant_tokens .insert (this ->grant_tokens .end (), grant_tokens.begin (), grant_tokens.end ());
424454 return *this ;
@@ -440,5 +470,59 @@ KmsKeyring::Builder &KmsKeyring::Builder::WithKmsClient(const std::shared_ptr<KM
440470 return *this ;
441471}
442472
473+ bool KmsKeyring::DiscoveryFilter::IsAuthorized (const Aws::String &key_arn) const {
474+ Utils::ARN arn (key_arn);
475+ if (!arn) {
476+ return false ;
477+ }
478+
479+ bool matching_partition = arn.GetPartition () == partition;
480+ bool matching_account = account_ids.find (arn.GetAccountId ()) != account_ids.end ();
481+ return matching_partition && matching_account;
482+ }
483+
484+ KmsKeyring::DiscoveryFilterBuilder KmsKeyring::DiscoveryFilter::Builder (Aws::String partition) {
485+ KmsKeyring::DiscoveryFilterBuilder builder (partition);
486+ return builder;
487+ }
488+
489+ KmsKeyring::DiscoveryFilterBuilder &KmsKeyring::DiscoveryFilterBuilder::AddAccount (const Aws::String &account_id) {
490+ this ->account_ids .insert (account_id);
491+ return *this ;
492+ }
493+
494+ KmsKeyring::DiscoveryFilterBuilder &KmsKeyring::DiscoveryFilterBuilder::AddAccounts (
495+ const Aws::Vector<Aws::String> &account_ids) {
496+ this ->account_ids .insert (account_ids.begin (), account_ids.end ());
497+ return *this ;
498+ }
499+
500+ KmsKeyring::DiscoveryFilterBuilder &KmsKeyring::DiscoveryFilterBuilder::WithAccounts (
501+ const Aws::Vector<Aws::String> &account_ids) {
502+ this ->account_ids .clear ();
503+ return this ->AddAccounts (account_ids);
504+ }
505+
506+ std::shared_ptr<KmsKeyring::DiscoveryFilter> KmsKeyring::DiscoveryFilterBuilder::Build () const {
507+ // Must have at least one account ID, and partition and account IDs cannot be the empty string
508+ if (account_ids.empty ()) {
509+ AWS_LOGSTREAM_ERROR (
510+ AWS_CRYPTO_SDK_KMS_CLASS_TAG, " Invalid DiscoveryFilterBuilder: account IDs cannot be empty" );
511+ return nullptr ;
512+ }
513+ if (partition.empty ()) {
514+ AWS_LOGSTREAM_ERROR (AWS_CRYPTO_SDK_KMS_CLASS_TAG, " Invalid DiscoveryFilterBuilder: partition cannot be blank" );
515+ return nullptr ;
516+ }
517+ if (account_ids.find (" " ) != account_ids.end ()) {
518+ AWS_LOGSTREAM_ERROR (
519+ AWS_CRYPTO_SDK_KMS_CLASS_TAG, " Invalid DiscoveryFilterBuilder: account IDs cannot be blank" );
520+ return nullptr ;
521+ }
522+
523+ return Aws::MakeShared<Private::DiscoveryFilterImpl>(
524+ KmsKeyring::AWS_CRYPTO_SDK_DISCOVERY_FILTER_CLASS_TAG, partition, account_ids);
525+ }
526+
443527} // namespace Cryptosdk
444528} // namespace Aws
0 commit comments