diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json
new file mode 100644
index 000000000000..e4419d6bbd19
--- /dev/null
+++ b/.changes/next-release/feature-sso-81096.json
@@ -0,0 +1,5 @@
+{
+ "type": "feature",
+ "category": "sso",
+ "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for aws sso login."
+}
diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py
index 115d2442a572..b92fa6bac0b2 100644
--- a/awscli/botocore/exceptions.py
+++ b/awscli/botocore/exceptions.py
@@ -682,6 +682,10 @@ class SSOTokenLoadError(SSOError):
fmt = "Error loading SSO Token: {error_msg}"
+class AuthorizationCodeLoadError(SSOError):
+ fmt = "Error loading authorization code: {error_msg}"
+
+
class UnauthorizedSSOTokenError(SSOError):
fmt = (
"The SSO session associated with this profile has expired or is "
@@ -690,6 +694,14 @@ class UnauthorizedSSOTokenError(SSOError):
)
+class AuthCodeFetcherError(SSOError):
+ fmt = (
+ "Unable to initialize the OAuth 2.0 authorization callback handler: "
+ "{error_msg} \n You may use --use-device-code to fall back to the "
+ "device code flow which does not require the callback handler."
+ )
+
+
class CapacityNotAvailableError(BotoCoreError):
fmt = (
'Insufficient request capacity available.'
diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py
index 5a7614211a4d..cf7c90630a37 100644
--- a/awscli/botocore/utils.py
+++ b/awscli/botocore/utils.py
@@ -23,7 +23,10 @@
import random
import re
import socket
+import string
import time
+import urllib
+import uuid
import warnings
import weakref
from datetime import datetime as _DatetimeClass
@@ -48,6 +51,7 @@
zip_longest,
)
from botocore.exceptions import (
+ AuthorizationCodeLoadError,
ClientError,
ConfigNotFound,
ConnectionClosedError,
@@ -3057,20 +3061,16 @@ def __call__(self):
return token_file.read()
-class SSOTokenFetcher(object):
- # The device flow RFC defines the slow down delay to be an additional
- # 5 seconds:
- # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5
- _SLOW_DOWN_DELAY = 5
- # The default interval of 5 is also defined in the RFC (see above link)
- _DEFAULT_INTERVAL = 5
+class BaseSSOTokenFetcher(object):
+ """Base class for SSO token fetchers, for functionality
+ shared between the device and authorization code grant flows.
+ """
_EXPIRY_WINDOW = 15 * 60
_CLIENT_REGISTRATION_TYPE = 'public'
- _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code'
def __init__(
self, sso_region, client_creator, cache=None,
- on_pending_authorization=None, time_fetcher=None, sleep=None,
+ on_pending_authorization=None, time_fetcher=None
):
self._sso_region = sso_region
self._client_creator = client_creator
@@ -3080,10 +3080,6 @@ def __init__(
time_fetcher = self._utc_now
self._time_fetcher = time_fetcher
- if sleep is None:
- sleep = time.sleep
- self._sleep = sleep
-
if cache is None:
cache = {}
self._cache = cache
@@ -3101,6 +3097,15 @@ def _is_expired(self, response):
seconds = total_seconds(end_time - self._time_fetcher())
return seconds < self._EXPIRY_WINDOW
+ def _is_registration_for_auth_code(self, registration):
+ if ('grantTypes' in registration and
+ 'authorization_code' in registration['grantTypes']):
+ return True
+
+ # Else assume that it's device flow,
+ # since the CLI didn't cache grantTypes previously
+ return False
+
@CachedProperty
def _client(self):
config = botocore.config.Config(
@@ -3109,13 +3114,61 @@ def _client(self):
)
return self._client_creator('sso-oidc', config=config)
- def _register_client(self, session_name, scopes):
+ def _generate_client_name(self, session_name):
if session_name is None:
# Use a timestamp for the session name for legacy configuration
timestamp = datetime2timestamp(self._time_fetcher())
session_name = int(timestamp)
+ return f'botocore-client-{str(session_name)}'
+
+ def _registration_cache_key(self, start_url, session_name, scopes):
+ # Registration is unique based on the following properties to ensure
+ # modifications to the registration do not affect the permissions of
+ # tokens derived for other start URLs.
+ args = {
+ 'tool': 'botocore',
+ 'startUrl': start_url,
+ 'region': self._sso_region,
+ 'scopes': scopes,
+ 'session_name': session_name,
+ }
+ cache_args = json.dumps(args, sort_keys=True).encode('utf-8')
+ return hashlib.sha1(cache_args).hexdigest()
+
+ def _token_cache_key(self, start_url, session_name):
+ input_str = start_url
+ if session_name is not None:
+ input_str = session_name
+ return hashlib.sha1(input_str.encode('utf-8')).hexdigest()
+
+
+class SSOTokenFetcher(BaseSSOTokenFetcher):
+ """Performs the device grant OAuth2.0 flow"""
+ # The device flow RFC defines the slow-down delay to be an additional
+ # 5 seconds:
+ # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5
+ _SLOW_DOWN_DELAY = 5
+ # The default interval of 5 is also defined in the RFC (see above link)
+ _DEFAULT_INTERVAL = 5
+ _EXPIRY_WINDOW = 15 * 60
+ _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code'
+
+ def __init__(
+ self, sso_region, client_creator, cache=None,
+ on_pending_authorization=None, time_fetcher=None, sleep=None
+ ):
+ super().__init__(
+ sso_region, client_creator, cache, on_pending_authorization,
+ time_fetcher
+ )
+
+ if sleep is None:
+ sleep = time.sleep
+ self._sleep = sleep
+
+ def _register_client(self, session_name, scopes):
register_kwargs = {
- 'clientName': f'botocore-client-{session_name}',
+ 'clientName': self._generate_client_name(session_name),
'clientType': self._CLIENT_REGISTRATION_TYPE,
}
if scopes:
@@ -3132,20 +3185,6 @@ def _register_client(self, session_name, scopes):
registration['scopes'] = scopes
return registration
- def _registration_cache_key(self, start_url, session_name, scopes):
- # Registration is unique based on the following properties to ensure
- # modifications to the registration do not affect the permissions of
- # tokens derived for other start URLs.
- args = {
- 'tool': 'botocore',
- 'startUrl': start_url,
- 'region': self._sso_region,
- 'scopes': scopes,
- 'session_name': session_name,
- }
- cache_args = json.dumps(args, sort_keys=True).encode('utf-8')
- return hashlib.sha1(cache_args).hexdigest()
-
def _registration(
self,
start_url,
@@ -3160,7 +3199,8 @@ def _registration(
)
if not force_refresh and cache_key in self._cache:
registration = self._cache[cache_key]
- if not self._is_expired(registration):
+ if (not self._is_expired(registration) and
+ not self._is_registration_for_auth_code(registration)):
return registration
registration = self._register_client(
@@ -3172,7 +3212,7 @@ def _registration(
def _authorize_client(self, start_url, registration):
# NOTE: The authorization response is not cached. These responses are
- # short lived (currently only 10 minutes) and can only be exchanged for
+ # short-lived (currently only 10 minutes) and can only be exchanged for
# a token once. Having multiple clients share this is problematic.
response = self._client.start_device_authorization(
clientId=registration['clientId'],
@@ -3261,12 +3301,6 @@ def _create_token_attempt(
raise PendingAuthorizationExpiredError()
return interval, None
- def _token_cache_key(self, start_url, session_name):
- input_str = start_url
- if session_name is not None:
- input_str = session_name
- return hashlib.sha1(input_str.encode('utf-8')).hexdigest()
-
def _token(
self,
start_url,
@@ -3305,6 +3339,231 @@ def fetch_token(
)
+class SSOTokenFetcherAuth(BaseSSOTokenFetcher):
+ """Performs the authorization code grant with PKCE OAuth2.0 flow"""
+ _AUTH_GRANT_TYPES = ['authorization_code', 'refresh_token']
+ _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access'
+
+ def __init__(
+ self, sso_region, client_creator, auth_code_fetcher, cache=None,
+ on_pending_authorization=None, time_fetcher=None
+ ):
+ super().__init__(sso_region, client_creator, cache,
+ on_pending_authorization, time_fetcher
+ )
+
+ self._auth_code_fetcher = auth_code_fetcher
+
+ # Generate the PKCE pair
+ self.code_verifier = ''.join(
+ random.SystemRandom().choice(
+ string.ascii_letters + string.digits + '-._~'
+ )
+ for _ in range(64)
+ )
+ self.code_challenge = base64.urlsafe_b64encode(
+ hashlib.sha256(self.code_verifier.encode()).digest()).decode()
+
+ def _register_client(self, session_name, scopes, redirect_uri, issuer_url):
+ register_kwargs = {
+ 'clientName': self._generate_client_name(session_name),
+ 'clientType': self._CLIENT_REGISTRATION_TYPE,
+ 'grantTypes': self._AUTH_GRANT_TYPES,
+ 'redirectUris': [redirect_uri],
+ 'issuerUrl': issuer_url
+ }
+
+ if scopes:
+ register_kwargs['scopes'] = scopes
+ else:
+ register_kwargs['scopes'] = [self._AUTH_GRANT_DEFAULT_SCOPE]
+
+ response = self._client.register_client(**register_kwargs)
+
+ expires_at = response['clientSecretExpiresAt']
+ expires_at = datetime.datetime.fromtimestamp(expires_at, tzutc())
+ registration = {
+ 'clientId': response['clientId'],
+ 'clientSecret': response['clientSecret'],
+ 'expiresAt': expires_at,
+ 'scopes': register_kwargs['scopes'],
+ 'grantTypes': register_kwargs['grantTypes']
+ }
+
+ return registration
+
+ def _registration(
+ self,
+ start_url,
+ session_name,
+ scopes,
+ force_refresh=False,
+ ):
+ cache_key = self._registration_cache_key(
+ start_url,
+ session_name,
+ scopes,
+ )
+ if not force_refresh and cache_key in self._cache:
+ registration = self._cache[cache_key]
+ if (not self._is_expired(registration) and
+ self._is_registration_for_auth_code(registration)):
+ return registration
+
+ registration = self._register_client(
+ session_name,
+ scopes,
+ self._auth_code_fetcher.redirect_uri_without_port(),
+ start_url
+ )
+ self._cache[cache_key] = registration
+ return registration
+
+ def _extract_resolved_endpoint(self, params, **kwargs):
+ """Event handler for before-call that will extract the resolved endpoint
+ for a given request without actually running it
+ """
+ # This will contain any path and query params specific to
+ # the operation/input, so extract just the scheme and hostname
+ if params['url']:
+ parsed = urlparse(params['url'])
+ self._base_endpoint = f'{parsed.scheme}://{parsed.netloc}'
+
+ # Return a tuple containing the "response" to short-circuit the request
+ return botocore.awsrequest.AWSResponse(None, 200, {}, None), {}
+
+ def _get_base_authorization_uri(self):
+ """Simulates an SSO-OIDC request so that we can extract the "base"
+ endpoint for the current client to use for the un-modeled Authorize
+ operation
+ """
+ self._client.meta.events.register('before-call',
+ self._extract_resolved_endpoint)
+ self._client.register_client(
+ clientName='temp',
+ clientType='public'
+ )
+ self._client.meta.events.unregister('before-call',
+ self._extract_resolved_endpoint)
+
+ return self._base_endpoint
+
+ def _get_authorization_uri(
+ self,
+ client_id,
+ registration_scopes,
+ expected_state):
+
+ query_params = {
+ 'response_type': 'code',
+ 'client_id': client_id,
+ 'redirect_uri': self._auth_code_fetcher.redirect_uri_with_port(),
+ 'state': expected_state,
+ 'code_challenge_method': 'S256'
+ # Don't want to encode code_challenge again, so we append below
+ }
+
+ # For the query param, scopes must be space separated before encoding
+ if registration_scopes:
+ query_params['scope'] = " ".join(registration_scopes)
+
+ return (
+ f'{self._get_base_authorization_uri()}/authorize?'
+ f'{urllib.parse.urlencode(query_params)}'
+ f'&code_challenge={self.code_challenge[:-1]}' # trim final '='
+ )
+
+ def _get_new_token(self, start_url, session_name, registration_scopes):
+ registration = self._registration(
+ start_url,
+ session_name,
+ registration_scopes,
+ )
+
+ expected_state = uuid.uuid4()
+
+ authorization_uri = self._get_authorization_uri(
+ registration['clientId'],
+ registration_scopes,
+ expected_state)
+
+ # Even though there's just one uri, this matches the inputs
+ # for the device code flow so that we can reuse the browser handlers
+ authorization_args = {
+ 'verificationUri': authorization_uri,
+ 'verificationUriComplete': authorization_uri,
+ 'userCode': None
+ }
+
+ # Open/display the link, then block until the redirect uri is hit and
+ # the auth code is retrieved
+ self._on_pending_authorization(**authorization_args)
+ auth_code, state = self._auth_code_fetcher.get_auth_code_and_state()
+
+ if auth_code is None:
+ raise AuthorizationCodeLoadError(
+ error_msg='Failed to retrieve an authorization code.'
+ )
+
+ if state != expected_state:
+ raise AuthorizationCodeLoadError(
+ error_msg='State parameter does not match expected value.'
+ )
+
+ return self._create_token_(start_url, registration, auth_code)
+
+ def _create_token_(self, start_url, registration, auth_code):
+ try:
+ response = self._client.create_token(
+ grantType='authorization_code',
+ clientId=registration['clientId'],
+ clientSecret=registration['clientSecret'],
+ redirectUri=self._auth_code_fetcher.redirect_uri_with_port(),
+ codeVerifier=self.code_verifier,
+ code=auth_code
+ )
+ expires_in = datetime.timedelta(seconds=response['expiresIn'])
+ token = {
+ 'startUrl': start_url,
+ 'region': self._sso_region,
+ 'accessToken': response['accessToken'],
+ 'expiresAt': self._time_fetcher() + expires_in,
+ # Cache the registration alongside the token
+ 'clientId': registration['clientId'],
+ 'clientSecret': registration['clientSecret'],
+ 'registrationExpiresAt': registration['expiresAt'],
+ }
+ if 'refreshToken' in response:
+ token['refreshToken'] = response['refreshToken']
+ return token
+ except self._client.exceptions.InvalidGrantException as error:
+ print(error)
+ except self._client.exceptions.ExpiredTokenException:
+ raise PendingAuthorizationExpiredError()
+
+ def fetch_token(
+ self,
+ start_url,
+ force_refresh=False,
+ registration_scopes=None,
+ session_name=None,
+ ):
+ cache_key = self._token_cache_key(start_url, session_name)
+ # Only obey the token cache if we are not forcing a refresh.
+ if not force_refresh and cache_key in self._cache:
+ token = self._cache[cache_key]
+ if not self._is_expired(token):
+ return token
+
+ token = self._get_new_token(
+ start_url,
+ session_name,
+ registration_scopes
+ )
+ self._cache[cache_key] = token
+ return token
+
+
class SSOTokenLoader(object):
def __init__(self, cache=None):
if cache is None:
diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html
new file mode 100644
index 000000000000..f1572f4b04f9
--- /dev/null
+++ b/awscli/customizations/sso/index.html
@@ -0,0 +1,176 @@
+
+
+
+
+ AWS Authentication
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
You can close this window and re-start the authorization flow
+
+
+
+
+
+
diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py
index 4d7790b62c76..acd58e287dd4 100644
--- a/awscli/customizations/sso/login.py
+++ b/awscli/customizations/sso/login.py
@@ -52,6 +52,7 @@ def _run_main(self, parsed_args, parsed_globals):
force_refresh=True,
session_name=sso_config.get('session_name'),
registration_scopes=sso_config.get('registration_scopes'),
+ fallback_to_device_flow=parsed_args.use_device_code
)
success_msg = 'Successfully logged into Start URL: %s\n'
uni_print(success_msg % sso_config['sso_start_url'])
diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py
index ae9a83e9e8b8..d7d813bf11cc 100644
--- a/awscli/customizations/sso/utils.py
+++ b/awscli/customizations/sso/utils.py
@@ -11,19 +11,22 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import datetime
-import os
-import logging
import json
+import logging
+import os
import webbrowser
+from functools import partial
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from urllib.parse import urlparse, parse_qs
-from botocore.utils import SSOTokenFetcher
-from botocore.utils import original_ld_library_path
from botocore.credentials import JSONFileCache
+from botocore.exceptions import AuthCodeFetcherError
+from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth
+from botocore.utils import original_ld_library_path
from awscli.customizations.commands import BasicCommand
-from awscli.customizations.utils import uni_print
-from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR
from awscli.customizations.exceptions import ConfigurationError
+from awscli.customizations.utils import uni_print
LOG = logging.getLogger(__name__)
@@ -37,9 +40,18 @@
'action': 'store_true',
'default': False,
'help_text': (
- 'Disables automatically opening the verfication URL in the '
+ 'Disables automatically opening the verification URL in the '
'default browser.'
)
+ },
+ {
+ 'name': 'use-device-code',
+ 'action': 'store_true',
+ 'default': False,
+ 'help_text': (
+ 'Uses the Device Code authorization grant and login flow '
+ 'instead of the Authorization Code flow.'
+ )
}
]
@@ -56,24 +68,38 @@ def _sso_json_dumps(obj):
def do_sso_login(session, sso_region, start_url, token_cache=None,
on_pending_authorization=None, force_refresh=False,
- registration_scopes=None, session_name=None):
+ registration_scopes=None, session_name=None,
+ fallback_to_device_flow=False):
if token_cache is None:
token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps)
if on_pending_authorization is None:
on_pending_authorization = OpenBrowserHandler(
open_browser=open_browser_with_original_ld_path
)
- token_fetcher = SSOTokenFetcher(
- sso_region=sso_region,
- client_creator=session.create_client,
- cache=token_cache,
- on_pending_authorization=on_pending_authorization
- )
+
+ # For the auth flow, we need a non-legacy sso-session and check that the
+ # user hasn't opted into falling back to the device code flow
+ if session_name and not fallback_to_device_flow:
+ token_fetcher = SSOTokenFetcherAuth(
+ sso_region=sso_region,
+ client_creator=session.create_client,
+ auth_code_fetcher=AuthCodeFetcher(),
+ cache=token_cache,
+ on_pending_authorization=on_pending_authorization
+ )
+ else:
+ token_fetcher = SSOTokenFetcher(
+ sso_region=sso_region,
+ client_creator=session.create_client,
+ cache=token_cache,
+ on_pending_authorization=on_pending_authorization
+ )
+
return token_fetcher.fetch_token(
start_url=start_url,
session_name=session_name,
force_refresh=force_refresh,
- registration_scopes=registration_scopes,
+ registration_scopes=registration_scopes
)
@@ -110,13 +136,19 @@ def __call__(
f'Browser will not be automatically opened.\n'
f'Please visit the following URL:\n'
f'\n{verificationUri}\n'
+
+ )
+
+ user_code_msg = (
f'\nThen enter the code:\n'
f'\n{userCode}\n'
f'\nAlternatively, you may visit the following URL which will '
f'autofill the code upon loading:'
- f'\n{verificationUriComplete}\n'
- )
+ f'\n{verificationUriComplete}\n')
+
uni_print(opening_msg, self._outfile)
+ if userCode:
+ uni_print(user_code_msg, self._outfile)
class OpenBrowserHandler(BaseAuthorizationhandler):
@@ -135,10 +167,16 @@ def __call__(
f'to use a different device to authorize this request, open the '
f'following URL:\n'
f'\n{verificationUri}\n'
+ )
+
+ user_code_msg = (
f'\nThen enter the code:\n'
f'\n{userCode}\n'
)
uni_print(opening_msg, self._outfile)
+ if userCode:
+ uni_print(user_code_msg, self._outfile)
+
if self._open_browser:
try:
return self._open_browser(verificationUriComplete)
@@ -146,6 +184,73 @@ def __call__(
LOG.debug('Failed to open browser:', exc_info=True)
+class AuthCodeFetcher:
+ """Manages the local web server that will be used
+ to retrieve the authorization code from the OAuth callback
+ """
+ def __init__(self):
+ self._auth_code = None
+ self._state = None
+ self._is_done = False
+
+ # We do this so that the request handler can have a reference to this
+ # AuthCodeFetcher so that it can pass back the state and auth code
+ try:
+ handler = partial(OAuthCallbackHandler, self)
+ self.http_server = HTTPServer(('', 0), handler)
+ except Exception as e:
+ raise AuthCodeFetcherError(error_msg=e)
+
+ def redirect_uri_without_port(self):
+ return 'http://127.0.0.1/oauth/callback'
+
+ def redirect_uri_with_port(self):
+ return f'http://127.0.0.1:{self.http_server.server_port}/oauth/callback'
+
+ def get_auth_code_and_state(self):
+ """Blocks until the expected redirect request with either the
+ authorization code/state or and error is handled
+ """
+ while not self._is_done:
+ self.http_server.handle_request()
+ self.http_server.server_close()
+
+ return self._auth_code, self._state
+
+
+class OAuthCallbackHandler(BaseHTTPRequestHandler):
+ """HTTP handler to handle OAuth callback requests, extracting
+ the auth code and state parameters, and displaying a page directing
+ the user to return to the CLI.
+ """
+ def __init__(self, auth_code_fetcher, *args, **kwargs):
+ self._auth_code_fetcher = auth_code_fetcher
+ super().__init__(*args, **kwargs)
+
+ def log_message(self, format, *args):
+ # Suppress built-in logging, otherwise it prints
+ # each request to console
+ pass
+
+ def do_GET(self):
+ self.send_response(200)
+ self.end_headers()
+ with open(
+ os.path.join(os.path.dirname(__file__), 'index.html'),
+ 'rb') as file:
+ self.wfile.write(file.read())
+
+ query_params = parse_qs(urlparse(self.path).query)
+
+ if 'error' in query_params:
+ self._auth_code_fetcher._is_done = True
+ return
+ elif 'code' in query_params and 'state' in query_params:
+ self._auth_code_fetcher._is_done = True
+ self._auth_code_fetcher._auth_code = query_params['code'][0]
+ self._auth_code_fetcher._state = query_params['state'][0]
+
+
class InvalidSSOConfigError(ConfigurationError):
pass
diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py
index 98036f9173b8..ace3e5e63b03 100644
--- a/tests/functional/sso/__init__.py
+++ b/tests/functional/sso/__init__.py
@@ -13,7 +13,7 @@
import time
from awscli.clidriver import AWSCLIEntryPoint
-from awscli.customizations.sso.utils import OpenBrowserHandler
+from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher
from awscli.testutils import create_clidriver
from awscli.testutils import FileCreator
from awscli.testutils import BaseAWSCommandParamsTest
@@ -45,6 +45,29 @@ def setUp(self):
self.open_browser_mock,
)
self.open_browser_patch.start()
+
+ self.fetcher_mock = mock.Mock(spec=AuthCodeFetcher)
+ self.fetcher_mock.return_value.redirect_uri_without_port.return_value = (
+ 'http://127.0.0.1/oauth/callback'
+ )
+ self.fetcher_mock.return_value.redirect_uri_with_port.return_value = (
+ 'http://127.0.0.1:55555/oauth/callback'
+ )
+ self.fetcher_mock.return_value.get_auth_code_and_state.return_value = (
+ "abc", "00000000-0000-0000-0000-000000000000"
+ )
+ self.auth_code_fetcher_patch = mock.patch(
+ 'awscli.customizations.sso.utils.AuthCodeFetcher',
+ self.fetcher_mock,
+ )
+ self.auth_code_fetcher_patch.start()
+
+ self.uuid_mock = mock.Mock(
+ return_value="00000000-0000-0000-0000-000000000000"
+ )
+ self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock)
+ self.uuid_patch.start()
+
self.expires_in = 28800
self.expiration_time = time.time() + 1000
@@ -52,6 +75,8 @@ def tearDown(self):
super(BaseSSOTest, self).tearDown()
self.files.remove_all()
self.open_browser_patch.stop()
+ self.auth_code_fetcher_patch.stop()
+ self.uuid_patch.stop()
self.token_cache_dir_patch.stop()
def assert_used_expected_sso_region(self, expected_region):
diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py
index e24c47591716..f8684908db21 100644
--- a/tests/functional/sso/test_login.py
+++ b/tests/functional/sso/test_login.py
@@ -23,8 +23,8 @@ class TestLoginCommand(BaseSSOTest):
r'\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ'
)
- def add_oidc_workflow_responses(self, access_token,
- include_register_response=True):
+ def add_oidc_device_responses(self, access_token,
+ include_register_response=True):
responses = [
# StartDeviceAuthorization response
{
@@ -53,12 +53,50 @@ def add_oidc_workflow_responses(self, access_token,
0,
{
'clientSecretExpiresAt': self.expiration_time,
- 'clientId': 'foo-client-id',
- 'clientSecret': 'foo-client-secret',
+ 'clientId': 'device-client-id',
+ 'clientSecret': 'device-client-secret',
}
)
self.parsed_responses = responses
+ def add_oidc_auth_code_responses(self, access_token,
+ include_register_response=True):
+ responses = [
+ # CreateToken responses
+ {
+ 'expiresIn': self.expires_in,
+ 'tokenType': 'Bearer',
+ 'accessToken': access_token,
+ }
+ ]
+ if include_register_response:
+ responses.insert(
+ 0,
+ {
+ 'clientSecretExpiresAt': self.expiration_time,
+ 'clientId': 'auth-client-id',
+ 'clientSecret': 'auth-client-secret',
+ }
+ )
+ self.parsed_responses = responses
+
+ def assert_cache_contains_registration(
+ self,
+ start_url,
+ session_name,
+ scopes,
+ expected_client_id):
+ cached_files = os.listdir(self.token_cache_dir)
+
+ cached_registration_filename = self._get_cached_registration_filename(
+ start_url, session_name, scopes)
+
+ self.assertIn(cached_registration_filename, cached_files)
+ self.assertEqual(
+ self._get_cached_response(cached_registration_filename)['clientId'],
+ expected_client_id
+ )
+
def assert_cache_contains_token(
self,
start_url,
@@ -78,6 +116,17 @@ def assert_cache_contains_token(
expected_token
)
+ def _get_cached_registration_filename(self, start_url, session_name, scopes):
+ args = {
+ 'tool': 'botocore',
+ 'startUrl': start_url,
+ 'region': self.sso_region,
+ 'scopes': scopes,
+ 'session_name': session_name,
+ }
+ cache_args = json.dumps(args, sort_keys=True).encode('utf-8')
+ return hashlib.sha1(cache_args).hexdigest() + '.json'
+
def _get_cached_token_filename(self, start_url, session_name):
to_hash = start_url
if session_name:
@@ -101,8 +150,19 @@ def assert_cache_token_expiration_time_format_is_correct(self):
)
)
- def test_login(self):
- self.add_oidc_workflow_responses(self.access_token)
+ def test_login_explicit_device(self):
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
+ self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ expected_token=self.access_token
+ )
+
+ def test_login_implicit_device(self):
+ # This is a legacy profile via setUp, so we expect
+ # it to fall back to device flow automatically
+ self.add_oidc_device_responses(self.access_token)
self.run_cmd('sso login')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
self.assert_cache_contains_token(
@@ -110,9 +170,9 @@ def test_login(self):
expected_token=self.access_token
)
- def test_login_no_browser(self):
- self.add_oidc_workflow_responses(self.access_token)
- stdout, _, _ = self.run_cmd('sso login --no-browser')
+ def test_login_device_no_browser(self):
+ self.add_oidc_device_responses(self.access_token)
+ stdout, _, _ = self.run_cmd('sso login --use-device-code --no-browser')
self.assertIn('Browser will not be automatically opened.', stdout)
self.open_browser_mock.assert_not_called()
self.assert_used_expected_sso_region(expected_region=self.sso_region)
@@ -121,20 +181,86 @@ def test_login_no_browser(self):
expected_token=self.access_token
)
- def test_login_forces_refresh(self):
- self.add_oidc_workflow_responses(self.access_token)
+ def test_login_auth_no_browser(self):
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_auth_code_responses(self.access_token)
+ stdout, _, _ = self.run_cmd('sso login --no-browser')
+ self.assertIn('Browser will not be automatically opened.', stdout)
+ self.open_browser_mock.assert_not_called()
+ self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='auth-client-id'
+ )
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ expected_token=self.access_token,
+ session_name='test-session'
+ )
+
+ def test_login_device_forces_refresh(self):
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
+ # The register response from the first login should have been
+ # cached.
+ self.add_oidc_device_responses(
+ 'new.token', include_register_response=False)
+ self.run_cmd('sso login --use-device-code')
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ expected_token='new.token',
+ )
+
+ def test_login_auth_forces_refresh(self):
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_auth_code_responses(self.access_token)
self.run_cmd('sso login')
# The register response from the first login should have been
# cached.
- self.add_oidc_workflow_responses(
+ self.add_oidc_auth_code_responses(
'new.token', include_register_response=False)
self.run_cmd('sso login')
self.assert_cache_contains_token(
start_url=self.start_url,
- expected_token='new.token'
+ expected_token='new.token',
+ session_name='test-session'
+ )
+
+ def test_login_auth_after_device_forces_refresh(self):
+ self.set_config_file_content(
+ content=self.get_sso_session_config('test-session'))
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
+ # The register response from the first login should have been
+ # cached.
+ self.add_oidc_auth_code_responses('new.token')
+ self.run_cmd('sso login')
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='auth-client-id'
+ )
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ expected_token='new.token',
+ session_name='test-session'
+ )
+
+ def test_login_device_no_sso_configuration(self):
+ self.set_config_file_content(content='')
+ _, stderr, _ = self.run_cmd('sso login --use-device-code',
+ expected_rc=253)
+ self.assertIn(
+ 'Missing the following required SSO configuration',
+ stderr
)
- def test_login_no_sso_configuration(self):
+ def test_login_auth_no_sso_configuration(self):
self.set_config_file_content(content='')
_, stderr, _ = self.run_cmd('sso login', expected_rc=253)
self.assertIn(
@@ -142,14 +268,14 @@ def test_login_no_sso_configuration(self):
stderr
)
- def test_login_minimal_sso_configuration(self):
+ def test_login_device_minimal_sso_configuration(self):
content = (
'[default]\n'
'sso_start_url={start_url}\n'
'sso_region={sso_region}\n'
).format(start_url=self.start_url, sso_region=self.sso_region)
self.set_config_file_content(content=content)
- self.add_oidc_workflow_responses(self.access_token)
+ self.add_oidc_device_responses(self.access_token)
self.run_cmd('sso login')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
self.assert_cache_contains_token(
@@ -157,13 +283,15 @@ def test_login_minimal_sso_configuration(self):
expected_token=self.access_token
)
- def test_login_partially_missing_sso_configuration(self):
+ def test_login_device_partially_missing_sso_configuration(self):
content = (
'[default]\n'
'sso_start_url=%s\n' % self.start_url
)
self.set_config_file_content(content=content)
- _, stderr, _ = self.run_cmd('sso login', expected_rc=253)
+ _, stderr, _ = self.run_cmd(
+ 'sso login --use-device-code', expected_rc=253
+ )
self.assertIn(
'Missing the following required SSO configuration',
stderr
@@ -174,8 +302,8 @@ def test_login_partially_missing_sso_configuration(self):
self.assertNotIn('sso_role_name', stderr)
def test_token_cache_datetime_format(self):
- self.add_oidc_workflow_responses(self.access_token)
- self.run_cmd('sso login')
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
self.assert_cache_contains_token(
start_url=self.start_url,
@@ -183,38 +311,115 @@ def test_token_cache_datetime_format(self):
)
self.assert_cache_token_expiration_time_format_is_correct()
- def test_login_sso_session(self):
+ def test_login_device_sso_session(self):
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
+ self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='device-client-id'
+ )
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ session_name='test-session',
+ expected_token=self.access_token,
+ )
+
+ def test_login_auth_sso_session(self):
content = self.get_sso_session_config('test-session')
self.set_config_file_content(content=content)
- self.add_oidc_workflow_responses(self.access_token)
+ self.add_oidc_auth_code_responses(self.access_token)
self.run_cmd('sso login')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='auth-client-id'
+ )
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ session_name='test-session',
+ expected_token=self.access_token,
+ )
+
+ def test_login_device_sso_with_explicit_sso_session_arg(self):
+ content = self.get_sso_session_config(
+ 'test-session', include_profile=False)
+ self.set_config_file_content(content=content)
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --sso-session test-session --use-device-code')
+ self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='device-client-id'
+ )
self.assert_cache_contains_token(
start_url=self.start_url,
session_name='test-session',
expected_token=self.access_token,
)
- def test_login_sso_with_explicit_sso_session_arg(self):
+ def test_login_auth_sso_with_explicit_sso_session_arg(self):
content = self.get_sso_session_config(
'test-session', include_profile=False)
self.set_config_file_content(content=content)
- self.add_oidc_workflow_responses(self.access_token)
+ self.add_oidc_auth_code_responses(self.access_token)
self.run_cmd('sso login --sso-session test-session')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='auth-client-id'
+ )
+ self.assert_cache_contains_token(
+ start_url=self.start_url,
+ session_name='test-session',
+ expected_token=self.access_token,
+ )
+
+ def test_login_device_sso_session_with_scopes(self):
+ self.registration_scopes = ['sso:foo', 'sso:bar']
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_device_responses(self.access_token)
+ self.run_cmd('sso login --use-device-code')
+ self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='device-client-id'
+ )
self.assert_cache_contains_token(
start_url=self.start_url,
session_name='test-session',
expected_token=self.access_token,
)
+ operation, params = self.operations_called[0]
+ self.assertEqual(operation.name, 'RegisterClient')
+ self.assertEqual(params.get('scopes'), self.registration_scopes)
- def test_login_sso_session_with_scopes(self):
+ def test_login_auth_sso_session_with_scopes(self):
self.registration_scopes = ['sso:foo', 'sso:bar']
content = self.get_sso_session_config('test-session')
self.set_config_file_content(content=content)
- self.add_oidc_workflow_responses(self.access_token)
+ self.add_oidc_auth_code_responses(self.access_token)
self.run_cmd('sso login')
self.assert_used_expected_sso_region(expected_region=self.sso_region)
+ self.assert_cache_contains_registration(
+ start_url=self.start_url,
+ session_name='test-session',
+ scopes=self.registration_scopes,
+ expected_client_id='auth-client-id'
+ )
self.assert_cache_contains_token(
start_url=self.start_url,
session_name='test-session',
@@ -245,3 +450,36 @@ def test_login_sso_session_missing(self):
self.set_config_file_content(content=content)
_, stderr, _ = self.run_cmd('sso login', expected_rc=253)
self.assertIn('sso-session does not exist: "test"', stderr)
+
+ def test_login_auth_sso_no_authorization_code_throws_error(self):
+ self.fetcher_mock.return_value.get_auth_code_and_state.return_value = (
+ None, None
+ )
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_auth_code_responses(self.access_token)
+
+ _, stderr, _ = self.run_cmd(
+ 'sso login', expected_rc=255
+ )
+ self.assertIn(
+ 'Failed to retrieve an authorization code.',
+ stderr
+ )
+
+ def test_login_auth_sso_state_mismatch_throws_error(self):
+ self.fetcher_mock.return_value.get_auth_code_and_state.return_value = (
+ "abc", '00000000-0000-0000-0000-000000000001'
+ )
+ content = self.get_sso_session_config('test-session')
+ self.set_config_file_content(content=content)
+ self.add_oidc_auth_code_responses(self.access_token)
+
+ _, stderr, _ = self.run_cmd(
+ 'sso login', expected_rc=255
+ )
+ self.assertIn(
+ 'State parameter does not match expected value.',
+ stderr
+ )
+
diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py
index 47f5762b3d8c..7a24acee3baf 100644
--- a/tests/unit/customizations/sso/test_utils.py
+++ b/tests/unit/customizations/sso/test_utils.py
@@ -11,22 +11,23 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
+import threading
import webbrowser
import pytest
-
-from awscli.testutils import mock
-from awscli.testutils import unittest
-
+import urllib3
from botocore.session import Session
-from botocore.exceptions import ClientError
from awscli.compat import StringIO
-from awscli.customizations.sso.utils import parse_sso_registration_scopes
-from awscli.customizations.sso.utils import do_sso_login
from awscli.customizations.sso.utils import OpenBrowserHandler
from awscli.customizations.sso.utils import PrintOnlyHandler
+from awscli.customizations.sso.utils import do_sso_login
from awscli.customizations.sso.utils import open_browser_with_original_ld_path
+from awscli.customizations.sso.utils import (
+ parse_sso_registration_scopes, AuthCodeFetcher
+)
+from awscli.testutils import mock
+from awscli.testutils import unittest
@pytest.mark.parametrize(
@@ -205,3 +206,45 @@ def test_can_patch_env(self):
os.environ)
open_browser_with_original_ld_path('http://example.com')
self.assertIsNone(captured_env.get('LD_LIBRARY_PATH'))
+
+
+class TestAuthCodeFetcher(unittest.TestCase):
+ """Tests for the AuthCodeFetcher class, which is the local
+ web server we use to handle the OAuth 2.0 callback
+ """
+
+ def setUp(self):
+ self.fetcher = AuthCodeFetcher()
+ self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/'
+
+ # Start the server on a background thread so that
+ # the test thread can make the request
+ self.server_thread = threading.Thread(
+ target=self.fetcher.get_auth_code_and_state
+ )
+ self.server_thread.setDaemon(True)
+ self.server_thread.start()
+
+ def test_expected_auth_code(self):
+ expected_code = '1234'
+ expected_state = '4567'
+ url = self.url + f'?code={expected_code}&state={expected_state}'
+
+ http = urllib3.PoolManager()
+ response = http.request("GET", url)
+
+ self.assertEqual(response.status, 200)
+ self.assertEqual(self.fetcher._auth_code, expected_code)
+ self.assertEqual(self.fetcher._state, expected_state)
+
+ def test_error(self):
+ expected_code = 'Failed'
+ url = self.url + f'?error={expected_code}'
+
+ http = urllib3.PoolManager()
+ response = http.request("GET", url)
+
+ self.assertEqual(response.status, 200)
+ self.assertEqual(self.fetcher._auth_code, None)
+ self.assertEqual(self.fetcher._state, None)
+