Skip to content

Commit d308ceb

Browse files
committed
Add account ID to the environment variable credential provider (#3332)
1 parent ab902a3 commit d308ceb

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

botocore/credentials.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ class Credentials:
312312
:param str token: The security token, valid only for session credentials.
313313
:param str method: A string which identifies where the credentials
314314
were found.
315+
:param str account_id: (optional) An account ID associated with the credentials.
315316
"""
316317

317318
def __init__(
@@ -1118,6 +1119,7 @@ class EnvProvider(CredentialProvider):
11181119
# AWS_SESSION_TOKEN is what other AWS SDKs have standardized on.
11191120
TOKENS = ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN']
11201121
EXPIRY_TIME = 'AWS_CREDENTIAL_EXPIRATION'
1122+
ACCOUNT_ID = 'AWS_ACCOUNT_ID'
11211123

11221124
def __init__(self, environ=None, mapping=None):
11231125
"""
@@ -1127,8 +1129,12 @@ def __init__(self, environ=None, mapping=None):
11271129
:param mapping: An optional mapping of variable names to
11281130
environment variable names. Use this if you want to
11291131
change the mapping of access_key->AWS_ACCESS_KEY_ID, etc.
1130-
The dict can have up to 3 keys: ``access_key``, ``secret_key``,
1131-
``session_token``.
1132+
The dict can have up to 5 keys:
1133+
* ``access_key``
1134+
* ``secret_key``
1135+
* ``token``
1136+
* ``expiry_time``
1137+
* ``account_id``
11321138
"""
11331139
if environ is None:
11341140
environ = os.environ
@@ -1144,6 +1150,7 @@ def _build_mapping(self, mapping):
11441150
var_mapping['secret_key'] = self.SECRET_KEY
11451151
var_mapping['token'] = self.TOKENS
11461152
var_mapping['expiry_time'] = self.EXPIRY_TIME
1153+
var_mapping['account_id'] = self.ACCOUNT_ID
11471154
else:
11481155
var_mapping['access_key'] = mapping.get(
11491156
'access_key', self.ACCESS_KEY
@@ -1157,6 +1164,9 @@ def _build_mapping(self, mapping):
11571164
var_mapping['expiry_time'] = mapping.get(
11581165
'expiry_time', self.EXPIRY_TIME
11591166
)
1167+
var_mapping['account_id'] = mapping.get(
1168+
'account_id', self.ACCOUNT_ID
1169+
)
11601170
return var_mapping
11611171

11621172
def load(self):
@@ -1181,13 +1191,15 @@ def load(self):
11811191
expiry_time,
11821192
refresh_using=fetcher,
11831193
method=self.METHOD,
1194+
account_id=credentials['account_id'],
11841195
)
11851196

11861197
return Credentials(
11871198
credentials['access_key'],
11881199
credentials['secret_key'],
11891200
credentials['token'],
11901201
method=self.METHOD,
1202+
account_id=credentials['account_id'],
11911203
)
11921204
else:
11931205
return None
@@ -1230,6 +1242,11 @@ def fetch_credentials(require_expiry=True):
12301242
provider=method, cred_var=mapping['expiry_time']
12311243
)
12321244

1245+
credentials['account_id'] = None
1246+
account_id = environ.get(mapping['account_id'], '')
1247+
if account_id:
1248+
credentials['account_id'] = account_id
1249+
12331250
return credentials
12341251

12351252
return fetch_credentials

tests/unit/test_credentials.py

+30
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,20 @@ def test_envvars_found_with_session_token(self):
10201020
self.assertEqual(creds.token, 'baz')
10211021
self.assertEqual(creds.method, 'env')
10221022

1023+
def test_envvars_found_with_account_id(self):
1024+
environ = {
1025+
'AWS_ACCESS_KEY_ID': 'foo',
1026+
'AWS_SECRET_ACCESS_KEY': 'bar',
1027+
'AWS_ACCOUNT_ID': 'baz',
1028+
}
1029+
provider = credentials.EnvProvider(environ)
1030+
creds = provider.load()
1031+
self.assertIsNotNone(creds)
1032+
self.assertEqual(creds.access_key, 'foo')
1033+
self.assertEqual(creds.secret_key, 'bar')
1034+
self.assertEqual(creds.account_id, 'baz')
1035+
self.assertEqual(creds.method, 'env')
1036+
10231037
def test_envvars_not_found(self):
10241038
provider = credentials.EnvProvider(environ={})
10251039
creds = provider.load()
@@ -1127,6 +1141,22 @@ def test_can_override_expiry_env_var_mapping(self):
11271141
with self.assertRaisesRegex(RuntimeError, error_message):
11281142
creds.get_frozen_credentials()
11291143

1144+
def test_can_override_account_id_env_var_mapping(self):
1145+
environ = {
1146+
'AWS_ACCESS_KEY_ID': 'foo',
1147+
'AWS_SECRET_ACCESS_KEY': 'bar',
1148+
'AWS_SESSION_TOKEN': 'baz',
1149+
'FOO_ACCOUNT_ID': 'bin',
1150+
}
1151+
provider = credentials.EnvProvider(
1152+
environ, {'account_id': 'FOO_ACCOUNT_ID'}
1153+
)
1154+
creds = provider.load()
1155+
self.assertEqual(creds.access_key, 'foo')
1156+
self.assertEqual(creds.secret_key, 'bar')
1157+
self.assertEqual(creds.token, 'baz')
1158+
self.assertEqual(creds.account_id, 'bin')
1159+
11301160
def test_partial_creds_is_an_error(self):
11311161
# If the user provides an access key, they must also
11321162
# provide a secret key. Not doing so will generate an

0 commit comments

Comments
 (0)