From 3cafdecf47fac879eb36bb2e2d8eb1537434ebe4 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Thu, 26 Sep 2024 10:08:10 -0400 Subject: [PATCH] Add the OAuth Authorization Code Flow with PKCE This adds support for the OAuth2.0 authorization code flow with PKCE to the aws sso login command. It is the new default behavior, but users can fall back to the device code flow using the new --use-device-code option. --- .changes/next-release/feature-sso-81096.json | 5 + awscli/botocore/exceptions.py | 12 + awscli/botocore/utils.py | 333 ++++++++++++++++--- awscli/customizations/sso/index.html | 176 ++++++++++ awscli/customizations/sso/login.py | 1 + awscli/customizations/sso/utils.py | 139 +++++++- tests/functional/sso/__init__.py | 27 +- tests/functional/sso/test_login.py | 290 ++++++++++++++-- tests/unit/customizations/sso/test_utils.py | 57 +++- 9 files changed, 952 insertions(+), 88 deletions(-) create mode 100644 .changes/next-release/feature-sso-81096.json create mode 100644 awscli/customizations/sso/index.html diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json new file mode 100644 index 000000000000..e4419d6bbd19 --- /dev/null +++ b/.changes/next-release/feature-sso-81096.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "sso", + "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for aws sso login." +} diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py index 115d2442a572..b92fa6bac0b2 100644 --- a/awscli/botocore/exceptions.py +++ b/awscli/botocore/exceptions.py @@ -682,6 +682,10 @@ class SSOTokenLoadError(SSOError): fmt = "Error loading SSO Token: {error_msg}" +class AuthorizationCodeLoadError(SSOError): + fmt = "Error loading authorization code: {error_msg}" + + class UnauthorizedSSOTokenError(SSOError): fmt = ( "The SSO session associated with this profile has expired or is " @@ -690,6 +694,14 @@ class UnauthorizedSSOTokenError(SSOError): ) +class AuthCodeFetcherError(SSOError): + fmt = ( + "Unable to initialize the OAuth 2.0 authorization callback handler: " + "{error_msg} \n You may use --use-device-code to fall back to the " + "device code flow which does not require the callback handler." + ) + + class CapacityNotAvailableError(BotoCoreError): fmt = ( 'Insufficient request capacity available.' diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 5a7614211a4d..cf7c90630a37 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -23,7 +23,10 @@ import random import re import socket +import string import time +import urllib +import uuid import warnings import weakref from datetime import datetime as _DatetimeClass @@ -48,6 +51,7 @@ zip_longest, ) from botocore.exceptions import ( + AuthorizationCodeLoadError, ClientError, ConfigNotFound, ConnectionClosedError, @@ -3057,20 +3061,16 @@ def __call__(self): return token_file.read() -class SSOTokenFetcher(object): - # The device flow RFC defines the slow down delay to be an additional - # 5 seconds: - # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 - _SLOW_DOWN_DELAY = 5 - # The default interval of 5 is also defined in the RFC (see above link) - _DEFAULT_INTERVAL = 5 +class BaseSSOTokenFetcher(object): + """Base class for SSO token fetchers, for functionality + shared between the device and authorization code grant flows. + """ _EXPIRY_WINDOW = 15 * 60 _CLIENT_REGISTRATION_TYPE = 'public' - _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' def __init__( self, sso_region, client_creator, cache=None, - on_pending_authorization=None, time_fetcher=None, sleep=None, + on_pending_authorization=None, time_fetcher=None ): self._sso_region = sso_region self._client_creator = client_creator @@ -3080,10 +3080,6 @@ def __init__( time_fetcher = self._utc_now self._time_fetcher = time_fetcher - if sleep is None: - sleep = time.sleep - self._sleep = sleep - if cache is None: cache = {} self._cache = cache @@ -3101,6 +3097,15 @@ def _is_expired(self, response): seconds = total_seconds(end_time - self._time_fetcher()) return seconds < self._EXPIRY_WINDOW + def _is_registration_for_auth_code(self, registration): + if ('grantTypes' in registration and + 'authorization_code' in registration['grantTypes']): + return True + + # Else assume that it's device flow, + # since the CLI didn't cache grantTypes previously + return False + @CachedProperty def _client(self): config = botocore.config.Config( @@ -3109,13 +3114,61 @@ def _client(self): ) return self._client_creator('sso-oidc', config=config) - def _register_client(self, session_name, scopes): + def _generate_client_name(self, session_name): if session_name is None: # Use a timestamp for the session name for legacy configuration timestamp = datetime2timestamp(self._time_fetcher()) session_name = int(timestamp) + return f'botocore-client-{str(session_name)}' + + def _registration_cache_key(self, start_url, session_name, scopes): + # Registration is unique based on the following properties to ensure + # modifications to the registration do not affect the permissions of + # tokens derived for other start URLs. + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self._sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + + def _token_cache_key(self, start_url, session_name): + input_str = start_url + if session_name is not None: + input_str = session_name + return hashlib.sha1(input_str.encode('utf-8')).hexdigest() + + +class SSOTokenFetcher(BaseSSOTokenFetcher): + """Performs the device grant OAuth2.0 flow""" + # The device flow RFC defines the slow-down delay to be an additional + # 5 seconds: + # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 + _SLOW_DOWN_DELAY = 5 + # The default interval of 5 is also defined in the RFC (see above link) + _DEFAULT_INTERVAL = 5 + _EXPIRY_WINDOW = 15 * 60 + _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + + def __init__( + self, sso_region, client_creator, cache=None, + on_pending_authorization=None, time_fetcher=None, sleep=None + ): + super().__init__( + sso_region, client_creator, cache, on_pending_authorization, + time_fetcher + ) + + if sleep is None: + sleep = time.sleep + self._sleep = sleep + + def _register_client(self, session_name, scopes): register_kwargs = { - 'clientName': f'botocore-client-{session_name}', + 'clientName': self._generate_client_name(session_name), 'clientType': self._CLIENT_REGISTRATION_TYPE, } if scopes: @@ -3132,20 +3185,6 @@ def _register_client(self, session_name, scopes): registration['scopes'] = scopes return registration - def _registration_cache_key(self, start_url, session_name, scopes): - # Registration is unique based on the following properties to ensure - # modifications to the registration do not affect the permissions of - # tokens derived for other start URLs. - args = { - 'tool': 'botocore', - 'startUrl': start_url, - 'region': self._sso_region, - 'scopes': scopes, - 'session_name': session_name, - } - cache_args = json.dumps(args, sort_keys=True).encode('utf-8') - return hashlib.sha1(cache_args).hexdigest() - def _registration( self, start_url, @@ -3160,7 +3199,8 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if not self._is_expired(registration): + if (not self._is_expired(registration) and + not self._is_registration_for_auth_code(registration)): return registration registration = self._register_client( @@ -3172,7 +3212,7 @@ def _registration( def _authorize_client(self, start_url, registration): # NOTE: The authorization response is not cached. These responses are - # short lived (currently only 10 minutes) and can only be exchanged for + # short-lived (currently only 10 minutes) and can only be exchanged for # a token once. Having multiple clients share this is problematic. response = self._client.start_device_authorization( clientId=registration['clientId'], @@ -3261,12 +3301,6 @@ def _create_token_attempt( raise PendingAuthorizationExpiredError() return interval, None - def _token_cache_key(self, start_url, session_name): - input_str = start_url - if session_name is not None: - input_str = session_name - return hashlib.sha1(input_str.encode('utf-8')).hexdigest() - def _token( self, start_url, @@ -3305,6 +3339,231 @@ def fetch_token( ) +class SSOTokenFetcherAuth(BaseSSOTokenFetcher): + """Performs the authorization code grant with PKCE OAuth2.0 flow""" + _AUTH_GRANT_TYPES = ['authorization_code', 'refresh_token'] + _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' + + def __init__( + self, sso_region, client_creator, auth_code_fetcher, cache=None, + on_pending_authorization=None, time_fetcher=None + ): + super().__init__(sso_region, client_creator, cache, + on_pending_authorization, time_fetcher + ) + + self._auth_code_fetcher = auth_code_fetcher + + # Generate the PKCE pair + self.code_verifier = ''.join( + random.SystemRandom().choice( + string.ascii_letters + string.digits + '-._~' + ) + for _ in range(64) + ) + self.code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(self.code_verifier.encode()).digest()).decode() + + def _register_client(self, session_name, scopes, redirect_uri, issuer_url): + register_kwargs = { + 'clientName': self._generate_client_name(session_name), + 'clientType': self._CLIENT_REGISTRATION_TYPE, + 'grantTypes': self._AUTH_GRANT_TYPES, + 'redirectUris': [redirect_uri], + 'issuerUrl': issuer_url + } + + if scopes: + register_kwargs['scopes'] = scopes + else: + register_kwargs['scopes'] = [self._AUTH_GRANT_DEFAULT_SCOPE] + + response = self._client.register_client(**register_kwargs) + + expires_at = response['clientSecretExpiresAt'] + expires_at = datetime.datetime.fromtimestamp(expires_at, tzutc()) + registration = { + 'clientId': response['clientId'], + 'clientSecret': response['clientSecret'], + 'expiresAt': expires_at, + 'scopes': register_kwargs['scopes'], + 'grantTypes': register_kwargs['grantTypes'] + } + + return registration + + def _registration( + self, + start_url, + session_name, + scopes, + force_refresh=False, + ): + cache_key = self._registration_cache_key( + start_url, + session_name, + scopes, + ) + if not force_refresh and cache_key in self._cache: + registration = self._cache[cache_key] + if (not self._is_expired(registration) and + self._is_registration_for_auth_code(registration)): + return registration + + registration = self._register_client( + session_name, + scopes, + self._auth_code_fetcher.redirect_uri_without_port(), + start_url + ) + self._cache[cache_key] = registration + return registration + + def _extract_resolved_endpoint(self, params, **kwargs): + """Event handler for before-call that will extract the resolved endpoint + for a given request without actually running it + """ + # This will contain any path and query params specific to + # the operation/input, so extract just the scheme and hostname + if params['url']: + parsed = urlparse(params['url']) + self._base_endpoint = f'{parsed.scheme}://{parsed.netloc}' + + # Return a tuple containing the "response" to short-circuit the request + return botocore.awsrequest.AWSResponse(None, 200, {}, None), {} + + def _get_base_authorization_uri(self): + """Simulates an SSO-OIDC request so that we can extract the "base" + endpoint for the current client to use for the un-modeled Authorize + operation + """ + self._client.meta.events.register('before-call', + self._extract_resolved_endpoint) + self._client.register_client( + clientName='temp', + clientType='public' + ) + self._client.meta.events.unregister('before-call', + self._extract_resolved_endpoint) + + return self._base_endpoint + + def _get_authorization_uri( + self, + client_id, + registration_scopes, + expected_state): + + query_params = { + 'response_type': 'code', + 'client_id': client_id, + 'redirect_uri': self._auth_code_fetcher.redirect_uri_with_port(), + 'state': expected_state, + 'code_challenge_method': 'S256' + # Don't want to encode code_challenge again, so we append below + } + + # For the query param, scopes must be space separated before encoding + if registration_scopes: + query_params['scope'] = " ".join(registration_scopes) + + return ( + f'{self._get_base_authorization_uri()}/authorize?' + f'{urllib.parse.urlencode(query_params)}' + f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' + ) + + def _get_new_token(self, start_url, session_name, registration_scopes): + registration = self._registration( + start_url, + session_name, + registration_scopes, + ) + + expected_state = uuid.uuid4() + + authorization_uri = self._get_authorization_uri( + registration['clientId'], + registration_scopes, + expected_state) + + # Even though there's just one uri, this matches the inputs + # for the device code flow so that we can reuse the browser handlers + authorization_args = { + 'verificationUri': authorization_uri, + 'verificationUriComplete': authorization_uri, + 'userCode': None + } + + # Open/display the link, then block until the redirect uri is hit and + # the auth code is retrieved + self._on_pending_authorization(**authorization_args) + auth_code, state = self._auth_code_fetcher.get_auth_code_and_state() + + if auth_code is None: + raise AuthorizationCodeLoadError( + error_msg='Failed to retrieve an authorization code.' + ) + + if state != expected_state: + raise AuthorizationCodeLoadError( + error_msg='State parameter does not match expected value.' + ) + + return self._create_token_(start_url, registration, auth_code) + + def _create_token_(self, start_url, registration, auth_code): + try: + response = self._client.create_token( + grantType='authorization_code', + clientId=registration['clientId'], + clientSecret=registration['clientSecret'], + redirectUri=self._auth_code_fetcher.redirect_uri_with_port(), + codeVerifier=self.code_verifier, + code=auth_code + ) + expires_in = datetime.timedelta(seconds=response['expiresIn']) + token = { + 'startUrl': start_url, + 'region': self._sso_region, + 'accessToken': response['accessToken'], + 'expiresAt': self._time_fetcher() + expires_in, + # Cache the registration alongside the token + 'clientId': registration['clientId'], + 'clientSecret': registration['clientSecret'], + 'registrationExpiresAt': registration['expiresAt'], + } + if 'refreshToken' in response: + token['refreshToken'] = response['refreshToken'] + return token + except self._client.exceptions.InvalidGrantException as error: + print(error) + except self._client.exceptions.ExpiredTokenException: + raise PendingAuthorizationExpiredError() + + def fetch_token( + self, + start_url, + force_refresh=False, + registration_scopes=None, + session_name=None, + ): + cache_key = self._token_cache_key(start_url, session_name) + # Only obey the token cache if we are not forcing a refresh. + if not force_refresh and cache_key in self._cache: + token = self._cache[cache_key] + if not self._is_expired(token): + return token + + token = self._get_new_token( + start_url, + session_name, + registration_scopes + ) + self._cache[cache_key] = token + return token + + class SSOTokenLoader(object): def __init__(self, cache=None): if cache is None: diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html new file mode 100644 index 000000000000..f1572f4b04f9 --- /dev/null +++ b/awscli/customizations/sso/index.html @@ -0,0 +1,176 @@ + + + + + AWS Authentication + + + + + +
+
+ + + + + +
+
+ +
+
+ +
+

Request approved

+

+
+
+

+
+ + + +
+
+ + + diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index 4d7790b62c76..acd58e287dd4 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -52,6 +52,7 @@ def _run_main(self, parsed_args, parsed_globals): force_refresh=True, session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), + fallback_to_device_flow=parsed_args.use_device_code ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae9a83e9e8b8..d7d813bf11cc 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -11,19 +11,22 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import datetime -import os -import logging import json +import logging +import os import webbrowser +from functools import partial +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse, parse_qs -from botocore.utils import SSOTokenFetcher -from botocore.utils import original_ld_library_path from botocore.credentials import JSONFileCache +from botocore.exceptions import AuthCodeFetcherError +from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth +from botocore.utils import original_ld_library_path from awscli.customizations.commands import BasicCommand -from awscli.customizations.utils import uni_print -from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR from awscli.customizations.exceptions import ConfigurationError +from awscli.customizations.utils import uni_print LOG = logging.getLogger(__name__) @@ -37,9 +40,18 @@ 'action': 'store_true', 'default': False, 'help_text': ( - 'Disables automatically opening the verfication URL in the ' + 'Disables automatically opening the verification URL in the ' 'default browser.' ) + }, + { + 'name': 'use-device-code', + 'action': 'store_true', + 'default': False, + 'help_text': ( + 'Uses the Device Code authorization grant and login flow ' + 'instead of the Authorization Code flow.' + ) } ] @@ -56,24 +68,38 @@ def _sso_json_dumps(obj): def do_sso_login(session, sso_region, start_url, token_cache=None, on_pending_authorization=None, force_refresh=False, - registration_scopes=None, session_name=None): + registration_scopes=None, session_name=None, + fallback_to_device_flow=False): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) if on_pending_authorization is None: on_pending_authorization = OpenBrowserHandler( open_browser=open_browser_with_original_ld_path ) - token_fetcher = SSOTokenFetcher( - sso_region=sso_region, - client_creator=session.create_client, - cache=token_cache, - on_pending_authorization=on_pending_authorization - ) + + # For the auth flow, we need a non-legacy sso-session and check that the + # user hasn't opted into falling back to the device code flow + if session_name and not fallback_to_device_flow: + token_fetcher = SSOTokenFetcherAuth( + sso_region=sso_region, + client_creator=session.create_client, + auth_code_fetcher=AuthCodeFetcher(), + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + else: + token_fetcher = SSOTokenFetcher( + sso_region=sso_region, + client_creator=session.create_client, + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + return token_fetcher.fetch_token( start_url=start_url, session_name=session_name, force_refresh=force_refresh, - registration_scopes=registration_scopes, + registration_scopes=registration_scopes ) @@ -110,13 +136,19 @@ def __call__( f'Browser will not be automatically opened.\n' f'Please visit the following URL:\n' f'\n{verificationUri}\n' + + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' f'\nAlternatively, you may visit the following URL which will ' f'autofill the code upon loading:' - f'\n{verificationUriComplete}\n' - ) + f'\n{verificationUriComplete}\n') + uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) class OpenBrowserHandler(BaseAuthorizationhandler): @@ -135,10 +167,16 @@ def __call__( f'to use a different device to authorize this request, open the ' f'following URL:\n' f'\n{verificationUri}\n' + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' ) uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) + if self._open_browser: try: return self._open_browser(verificationUriComplete) @@ -146,6 +184,73 @@ def __call__( LOG.debug('Failed to open browser:', exc_info=True) +class AuthCodeFetcher: + """Manages the local web server that will be used + to retrieve the authorization code from the OAuth callback + """ + def __init__(self): + self._auth_code = None + self._state = None + self._is_done = False + + # We do this so that the request handler can have a reference to this + # AuthCodeFetcher so that it can pass back the state and auth code + try: + handler = partial(OAuthCallbackHandler, self) + self.http_server = HTTPServer(('', 0), handler) + except Exception as e: + raise AuthCodeFetcherError(error_msg=e) + + def redirect_uri_without_port(self): + return 'http://127.0.0.1/oauth/callback' + + def redirect_uri_with_port(self): + return f'http://127.0.0.1:{self.http_server.server_port}/oauth/callback' + + def get_auth_code_and_state(self): + """Blocks until the expected redirect request with either the + authorization code/state or and error is handled + """ + while not self._is_done: + self.http_server.handle_request() + self.http_server.server_close() + + return self._auth_code, self._state + + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb') as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher._is_done = True + return + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher._is_done = True + self._auth_code_fetcher._auth_code = query_params['code'][0] + self._auth_code_fetcher._state = query_params['state'][0] + + class InvalidSSOConfigError(ConfigurationError): pass diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 98036f9173b8..ace3e5e63b03 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -13,7 +13,7 @@ import time from awscli.clidriver import AWSCLIEntryPoint -from awscli.customizations.sso.utils import OpenBrowserHandler +from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher from awscli.testutils import create_clidriver from awscli.testutils import FileCreator from awscli.testutils import BaseAWSCommandParamsTest @@ -45,6 +45,29 @@ def setUp(self): self.open_browser_mock, ) self.open_browser_patch.start() + + self.fetcher_mock = mock.Mock(spec=AuthCodeFetcher) + self.fetcher_mock.return_value.redirect_uri_without_port.return_value = ( + 'http://127.0.0.1/oauth/callback' + ) + self.fetcher_mock.return_value.redirect_uri_with_port.return_value = ( + 'http://127.0.0.1:55555/oauth/callback' + ) + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", "00000000-0000-0000-0000-000000000000" + ) + self.auth_code_fetcher_patch = mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher', + self.fetcher_mock, + ) + self.auth_code_fetcher_patch.start() + + self.uuid_mock = mock.Mock( + return_value="00000000-0000-0000-0000-000000000000" + ) + self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock) + self.uuid_patch.start() + self.expires_in = 28800 self.expiration_time = time.time() + 1000 @@ -52,6 +75,8 @@ def tearDown(self): super(BaseSSOTest, self).tearDown() self.files.remove_all() self.open_browser_patch.stop() + self.auth_code_fetcher_patch.stop() + self.uuid_patch.stop() self.token_cache_dir_patch.stop() def assert_used_expected_sso_region(self, expected_region): diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index e24c47591716..f8684908db21 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -23,8 +23,8 @@ class TestLoginCommand(BaseSSOTest): r'\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ' ) - def add_oidc_workflow_responses(self, access_token, - include_register_response=True): + def add_oidc_device_responses(self, access_token, + include_register_response=True): responses = [ # StartDeviceAuthorization response { @@ -53,12 +53,50 @@ def add_oidc_workflow_responses(self, access_token, 0, { 'clientSecretExpiresAt': self.expiration_time, - 'clientId': 'foo-client-id', - 'clientSecret': 'foo-client-secret', + 'clientId': 'device-client-id', + 'clientSecret': 'device-client-secret', } ) self.parsed_responses = responses + def add_oidc_auth_code_responses(self, access_token, + include_register_response=True): + responses = [ + # CreateToken responses + { + 'expiresIn': self.expires_in, + 'tokenType': 'Bearer', + 'accessToken': access_token, + } + ] + if include_register_response: + responses.insert( + 0, + { + 'clientSecretExpiresAt': self.expiration_time, + 'clientId': 'auth-client-id', + 'clientSecret': 'auth-client-secret', + } + ) + self.parsed_responses = responses + + def assert_cache_contains_registration( + self, + start_url, + session_name, + scopes, + expected_client_id): + cached_files = os.listdir(self.token_cache_dir) + + cached_registration_filename = self._get_cached_registration_filename( + start_url, session_name, scopes) + + self.assertIn(cached_registration_filename, cached_files) + self.assertEqual( + self._get_cached_response(cached_registration_filename)['clientId'], + expected_client_id + ) + def assert_cache_contains_token( self, start_url, @@ -78,6 +116,17 @@ def assert_cache_contains_token( expected_token ) + def _get_cached_registration_filename(self, start_url, session_name, scopes): + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self.sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + '.json' + def _get_cached_token_filename(self, start_url, session_name): to_hash = start_url if session_name: @@ -101,8 +150,19 @@ def assert_cache_token_expiration_time_format_is_correct(self): ) ) - def test_login(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_explicit_device(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token + ) + + def test_login_implicit_device(self): + # This is a legacy profile via setUp, so we expect + # it to fall back to device flow automatically + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -110,9 +170,9 @@ def test_login(self): expected_token=self.access_token ) - def test_login_no_browser(self): - self.add_oidc_workflow_responses(self.access_token) - stdout, _, _ = self.run_cmd('sso login --no-browser') + def test_login_device_no_browser(self): + self.add_oidc_device_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --use-device-code --no-browser') self.assertIn('Browser will not be automatically opened.', stdout) self.open_browser_mock.assert_not_called() self.assert_used_expected_sso_region(expected_region=self.sso_region) @@ -121,20 +181,86 @@ def test_login_no_browser(self): expected_token=self.access_token ) - def test_login_forces_refresh(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_auth_no_browser(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --no-browser') + self.assertIn('Browser will not be automatically opened.', stdout) + self.open_browser_mock.assert_not_called() + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token, + session_name='test-session' + ) + + def test_login_device_forces_refresh(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_device_responses( + 'new.token', include_register_response=False) + self.run_cmd('sso login --use-device-code') + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + ) + + def test_login_auth_forces_refresh(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') # The register response from the first login should have been # cached. - self.add_oidc_workflow_responses( + self.add_oidc_auth_code_responses( 'new.token', include_register_response=False) self.run_cmd('sso login') self.assert_cache_contains_token( start_url=self.start_url, - expected_token='new.token' + expected_token='new.token', + session_name='test-session' + ) + + def test_login_auth_after_device_forces_refresh(self): + self.set_config_file_content( + content=self.get_sso_session_config('test-session')) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_auth_code_responses('new.token') + self.run_cmd('sso login') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + session_name='test-session' + ) + + def test_login_device_no_sso_configuration(self): + self.set_config_file_content(content='') + _, stderr, _ = self.run_cmd('sso login --use-device-code', + expected_rc=253) + self.assertIn( + 'Missing the following required SSO configuration', + stderr ) - def test_login_no_sso_configuration(self): + def test_login_auth_no_sso_configuration(self): self.set_config_file_content(content='') _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn( @@ -142,14 +268,14 @@ def test_login_no_sso_configuration(self): stderr ) - def test_login_minimal_sso_configuration(self): + def test_login_device_minimal_sso_configuration(self): content = ( '[default]\n' 'sso_start_url={start_url}\n' 'sso_region={sso_region}\n' ).format(start_url=self.start_url, sso_region=self.sso_region) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -157,13 +283,15 @@ def test_login_minimal_sso_configuration(self): expected_token=self.access_token ) - def test_login_partially_missing_sso_configuration(self): + def test_login_device_partially_missing_sso_configuration(self): content = ( '[default]\n' 'sso_start_url=%s\n' % self.start_url ) self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) + _, stderr, _ = self.run_cmd( + 'sso login --use-device-code', expected_rc=253 + ) self.assertIn( 'Missing the following required SSO configuration', stderr @@ -174,8 +302,8 @@ def test_login_partially_missing_sso_configuration(self): self.assertNotIn('sso_role_name', stderr) def test_token_cache_datetime_format(self): - self.add_oidc_workflow_responses(self.access_token) - self.run_cmd('sso login') + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( start_url=self.start_url, @@ -183,38 +311,115 @@ def test_token_cache_datetime_format(self): ) self.assert_cache_token_expiration_time_format_is_correct() - def test_login_sso_session(self): + def test_login_device_sso_session(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_auth_sso_session(self): content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) - def test_login_sso_with_explicit_sso_session_arg(self): + def test_login_auth_sso_with_explicit_sso_session_arg(self): content = self.get_sso_session_config( 'test-session', include_profile=False) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login --sso-session test-session') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_session_with_scopes(self): + self.registration_scopes = ['sso:foo', 'sso:bar'] + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) + operation, params = self.operations_called[0] + self.assertEqual(operation.name, 'RegisterClient') + self.assertEqual(params.get('scopes'), self.registration_scopes) - def test_login_sso_session_with_scopes(self): + def test_login_auth_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', @@ -245,3 +450,36 @@ def test_login_sso_session_missing(self): self.set_config_file_content(content=content) _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn('sso-session does not exist: "test"', stderr) + + def test_login_auth_sso_no_authorization_code_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + None, None + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'Failed to retrieve an authorization code.', + stderr + ) + + def test_login_auth_sso_state_mismatch_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", '00000000-0000-0000-0000-000000000001' + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'State parameter does not match expected value.', + stderr + ) + diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 47f5762b3d8c..7a24acee3baf 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -11,22 +11,23 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import os +import threading import webbrowser import pytest - -from awscli.testutils import mock -from awscli.testutils import unittest - +import urllib3 from botocore.session import Session -from botocore.exceptions import ClientError from awscli.compat import StringIO -from awscli.customizations.sso.utils import parse_sso_registration_scopes -from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler +from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path +from awscli.customizations.sso.utils import ( + parse_sso_registration_scopes, AuthCodeFetcher +) +from awscli.testutils import mock +from awscli.testutils import unittest @pytest.mark.parametrize( @@ -205,3 +206,45 @@ def test_can_patch_env(self): os.environ) open_browser_with_original_ld_path('http://example.com') self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) + + +class TestAuthCodeFetcher(unittest.TestCase): + """Tests for the AuthCodeFetcher class, which is the local + web server we use to handle the OAuth 2.0 callback + """ + + def setUp(self): + self.fetcher = AuthCodeFetcher() + self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/' + + # Start the server on a background thread so that + # the test thread can make the request + self.server_thread = threading.Thread( + target=self.fetcher.get_auth_code_and_state + ) + self.server_thread.setDaemon(True) + self.server_thread.start() + + def test_expected_auth_code(self): + expected_code = '1234' + expected_state = '4567' + url = self.url + f'?code={expected_code}&state={expected_state}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, expected_code) + self.assertEqual(self.fetcher._state, expected_state) + + def test_error(self): + expected_code = 'Failed' + url = self.url + f'?error={expected_code}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, None) + self.assertEqual(self.fetcher._state, None) +