diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json new file mode 100644 index 000000000000..e4419d6bbd19 --- /dev/null +++ b/.changes/next-release/feature-sso-81096.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "sso", + "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for aws sso login." +} diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py index 115d2442a572..b92fa6bac0b2 100644 --- a/awscli/botocore/exceptions.py +++ b/awscli/botocore/exceptions.py @@ -682,6 +682,10 @@ class SSOTokenLoadError(SSOError): fmt = "Error loading SSO Token: {error_msg}" +class AuthorizationCodeLoadError(SSOError): + fmt = "Error loading authorization code: {error_msg}" + + class UnauthorizedSSOTokenError(SSOError): fmt = ( "The SSO session associated with this profile has expired or is " @@ -690,6 +694,14 @@ class UnauthorizedSSOTokenError(SSOError): ) +class AuthCodeFetcherError(SSOError): + fmt = ( + "Unable to initialize the OAuth 2.0 authorization callback handler: " + "{error_msg} \n You may use --use-device-code to fall back to the " + "device code flow which does not require the callback handler." + ) + + class CapacityNotAvailableError(BotoCoreError): fmt = ( 'Insufficient request capacity available.' diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 5a7614211a4d..cf7c90630a37 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -23,7 +23,10 @@ import random import re import socket +import string import time +import urllib +import uuid import warnings import weakref from datetime import datetime as _DatetimeClass @@ -48,6 +51,7 @@ zip_longest, ) from botocore.exceptions import ( + AuthorizationCodeLoadError, ClientError, ConfigNotFound, ConnectionClosedError, @@ -3057,20 +3061,16 @@ def __call__(self): return token_file.read() -class SSOTokenFetcher(object): - # The device flow RFC defines the slow down delay to be an additional - # 5 seconds: - # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 - _SLOW_DOWN_DELAY = 5 - # The default interval of 5 is also defined in the RFC (see above link) - _DEFAULT_INTERVAL = 5 +class BaseSSOTokenFetcher(object): + """Base class for SSO token fetchers, for functionality + shared between the device and authorization code grant flows. + """ _EXPIRY_WINDOW = 15 * 60 _CLIENT_REGISTRATION_TYPE = 'public' - _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' def __init__( self, sso_region, client_creator, cache=None, - on_pending_authorization=None, time_fetcher=None, sleep=None, + on_pending_authorization=None, time_fetcher=None ): self._sso_region = sso_region self._client_creator = client_creator @@ -3080,10 +3080,6 @@ def __init__( time_fetcher = self._utc_now self._time_fetcher = time_fetcher - if sleep is None: - sleep = time.sleep - self._sleep = sleep - if cache is None: cache = {} self._cache = cache @@ -3101,6 +3097,15 @@ def _is_expired(self, response): seconds = total_seconds(end_time - self._time_fetcher()) return seconds < self._EXPIRY_WINDOW + def _is_registration_for_auth_code(self, registration): + if ('grantTypes' in registration and + 'authorization_code' in registration['grantTypes']): + return True + + # Else assume that it's device flow, + # since the CLI didn't cache grantTypes previously + return False + @CachedProperty def _client(self): config = botocore.config.Config( @@ -3109,13 +3114,61 @@ def _client(self): ) return self._client_creator('sso-oidc', config=config) - def _register_client(self, session_name, scopes): + def _generate_client_name(self, session_name): if session_name is None: # Use a timestamp for the session name for legacy configuration timestamp = datetime2timestamp(self._time_fetcher()) session_name = int(timestamp) + return f'botocore-client-{str(session_name)}' + + def _registration_cache_key(self, start_url, session_name, scopes): + # Registration is unique based on the following properties to ensure + # modifications to the registration do not affect the permissions of + # tokens derived for other start URLs. + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self._sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + + def _token_cache_key(self, start_url, session_name): + input_str = start_url + if session_name is not None: + input_str = session_name + return hashlib.sha1(input_str.encode('utf-8')).hexdigest() + + +class SSOTokenFetcher(BaseSSOTokenFetcher): + """Performs the device grant OAuth2.0 flow""" + # The device flow RFC defines the slow-down delay to be an additional + # 5 seconds: + # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 + _SLOW_DOWN_DELAY = 5 + # The default interval of 5 is also defined in the RFC (see above link) + _DEFAULT_INTERVAL = 5 + _EXPIRY_WINDOW = 15 * 60 + _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + + def __init__( + self, sso_region, client_creator, cache=None, + on_pending_authorization=None, time_fetcher=None, sleep=None + ): + super().__init__( + sso_region, client_creator, cache, on_pending_authorization, + time_fetcher + ) + + if sleep is None: + sleep = time.sleep + self._sleep = sleep + + def _register_client(self, session_name, scopes): register_kwargs = { - 'clientName': f'botocore-client-{session_name}', + 'clientName': self._generate_client_name(session_name), 'clientType': self._CLIENT_REGISTRATION_TYPE, } if scopes: @@ -3132,20 +3185,6 @@ def _register_client(self, session_name, scopes): registration['scopes'] = scopes return registration - def _registration_cache_key(self, start_url, session_name, scopes): - # Registration is unique based on the following properties to ensure - # modifications to the registration do not affect the permissions of - # tokens derived for other start URLs. - args = { - 'tool': 'botocore', - 'startUrl': start_url, - 'region': self._sso_region, - 'scopes': scopes, - 'session_name': session_name, - } - cache_args = json.dumps(args, sort_keys=True).encode('utf-8') - return hashlib.sha1(cache_args).hexdigest() - def _registration( self, start_url, @@ -3160,7 +3199,8 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if not self._is_expired(registration): + if (not self._is_expired(registration) and + not self._is_registration_for_auth_code(registration)): return registration registration = self._register_client( @@ -3172,7 +3212,7 @@ def _registration( def _authorize_client(self, start_url, registration): # NOTE: The authorization response is not cached. These responses are - # short lived (currently only 10 minutes) and can only be exchanged for + # short-lived (currently only 10 minutes) and can only be exchanged for # a token once. Having multiple clients share this is problematic. response = self._client.start_device_authorization( clientId=registration['clientId'], @@ -3261,12 +3301,6 @@ def _create_token_attempt( raise PendingAuthorizationExpiredError() return interval, None - def _token_cache_key(self, start_url, session_name): - input_str = start_url - if session_name is not None: - input_str = session_name - return hashlib.sha1(input_str.encode('utf-8')).hexdigest() - def _token( self, start_url, @@ -3305,6 +3339,231 @@ def fetch_token( ) +class SSOTokenFetcherAuth(BaseSSOTokenFetcher): + """Performs the authorization code grant with PKCE OAuth2.0 flow""" + _AUTH_GRANT_TYPES = ['authorization_code', 'refresh_token'] + _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' + + def __init__( + self, sso_region, client_creator, auth_code_fetcher, cache=None, + on_pending_authorization=None, time_fetcher=None + ): + super().__init__(sso_region, client_creator, cache, + on_pending_authorization, time_fetcher + ) + + self._auth_code_fetcher = auth_code_fetcher + + # Generate the PKCE pair + self.code_verifier = ''.join( + random.SystemRandom().choice( + string.ascii_letters + string.digits + '-._~' + ) + for _ in range(64) + ) + self.code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(self.code_verifier.encode()).digest()).decode() + + def _register_client(self, session_name, scopes, redirect_uri, issuer_url): + register_kwargs = { + 'clientName': self._generate_client_name(session_name), + 'clientType': self._CLIENT_REGISTRATION_TYPE, + 'grantTypes': self._AUTH_GRANT_TYPES, + 'redirectUris': [redirect_uri], + 'issuerUrl': issuer_url + } + + if scopes: + register_kwargs['scopes'] = scopes + else: + register_kwargs['scopes'] = [self._AUTH_GRANT_DEFAULT_SCOPE] + + response = self._client.register_client(**register_kwargs) + + expires_at = response['clientSecretExpiresAt'] + expires_at = datetime.datetime.fromtimestamp(expires_at, tzutc()) + registration = { + 'clientId': response['clientId'], + 'clientSecret': response['clientSecret'], + 'expiresAt': expires_at, + 'scopes': register_kwargs['scopes'], + 'grantTypes': register_kwargs['grantTypes'] + } + + return registration + + def _registration( + self, + start_url, + session_name, + scopes, + force_refresh=False, + ): + cache_key = self._registration_cache_key( + start_url, + session_name, + scopes, + ) + if not force_refresh and cache_key in self._cache: + registration = self._cache[cache_key] + if (not self._is_expired(registration) and + self._is_registration_for_auth_code(registration)): + return registration + + registration = self._register_client( + session_name, + scopes, + self._auth_code_fetcher.redirect_uri_without_port(), + start_url + ) + self._cache[cache_key] = registration + return registration + + def _extract_resolved_endpoint(self, params, **kwargs): + """Event handler for before-call that will extract the resolved endpoint + for a given request without actually running it + """ + # This will contain any path and query params specific to + # the operation/input, so extract just the scheme and hostname + if params['url']: + parsed = urlparse(params['url']) + self._base_endpoint = f'{parsed.scheme}://{parsed.netloc}' + + # Return a tuple containing the "response" to short-circuit the request + return botocore.awsrequest.AWSResponse(None, 200, {}, None), {} + + def _get_base_authorization_uri(self): + """Simulates an SSO-OIDC request so that we can extract the "base" + endpoint for the current client to use for the un-modeled Authorize + operation + """ + self._client.meta.events.register('before-call', + self._extract_resolved_endpoint) + self._client.register_client( + clientName='temp', + clientType='public' + ) + self._client.meta.events.unregister('before-call', + self._extract_resolved_endpoint) + + return self._base_endpoint + + def _get_authorization_uri( + self, + client_id, + registration_scopes, + expected_state): + + query_params = { + 'response_type': 'code', + 'client_id': client_id, + 'redirect_uri': self._auth_code_fetcher.redirect_uri_with_port(), + 'state': expected_state, + 'code_challenge_method': 'S256' + # Don't want to encode code_challenge again, so we append below + } + + # For the query param, scopes must be space separated before encoding + if registration_scopes: + query_params['scope'] = " ".join(registration_scopes) + + return ( + f'{self._get_base_authorization_uri()}/authorize?' + f'{urllib.parse.urlencode(query_params)}' + f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' + ) + + def _get_new_token(self, start_url, session_name, registration_scopes): + registration = self._registration( + start_url, + session_name, + registration_scopes, + ) + + expected_state = uuid.uuid4() + + authorization_uri = self._get_authorization_uri( + registration['clientId'], + registration_scopes, + expected_state) + + # Even though there's just one uri, this matches the inputs + # for the device code flow so that we can reuse the browser handlers + authorization_args = { + 'verificationUri': authorization_uri, + 'verificationUriComplete': authorization_uri, + 'userCode': None + } + + # Open/display the link, then block until the redirect uri is hit and + # the auth code is retrieved + self._on_pending_authorization(**authorization_args) + auth_code, state = self._auth_code_fetcher.get_auth_code_and_state() + + if auth_code is None: + raise AuthorizationCodeLoadError( + error_msg='Failed to retrieve an authorization code.' + ) + + if state != expected_state: + raise AuthorizationCodeLoadError( + error_msg='State parameter does not match expected value.' + ) + + return self._create_token_(start_url, registration, auth_code) + + def _create_token_(self, start_url, registration, auth_code): + try: + response = self._client.create_token( + grantType='authorization_code', + clientId=registration['clientId'], + clientSecret=registration['clientSecret'], + redirectUri=self._auth_code_fetcher.redirect_uri_with_port(), + codeVerifier=self.code_verifier, + code=auth_code + ) + expires_in = datetime.timedelta(seconds=response['expiresIn']) + token = { + 'startUrl': start_url, + 'region': self._sso_region, + 'accessToken': response['accessToken'], + 'expiresAt': self._time_fetcher() + expires_in, + # Cache the registration alongside the token + 'clientId': registration['clientId'], + 'clientSecret': registration['clientSecret'], + 'registrationExpiresAt': registration['expiresAt'], + } + if 'refreshToken' in response: + token['refreshToken'] = response['refreshToken'] + return token + except self._client.exceptions.InvalidGrantException as error: + print(error) + except self._client.exceptions.ExpiredTokenException: + raise PendingAuthorizationExpiredError() + + def fetch_token( + self, + start_url, + force_refresh=False, + registration_scopes=None, + session_name=None, + ): + cache_key = self._token_cache_key(start_url, session_name) + # Only obey the token cache if we are not forcing a refresh. + if not force_refresh and cache_key in self._cache: + token = self._cache[cache_key] + if not self._is_expired(token): + return token + + token = self._get_new_token( + start_url, + session_name, + registration_scopes + ) + self._cache[cache_key] = token + return token + + class SSOTokenLoader(object): def __init__(self, cache=None): if cache is None: diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html new file mode 100644 index 000000000000..f1572f4b04f9 --- /dev/null +++ b/awscli/customizations/sso/index.html @@ -0,0 +1,176 @@ + + + + + AWS Authentication + + + + + +
+
+ + + + + +
+
+ +
+
+ +
+

Request approved

+

+
+
+

+
+ + + +
+
+ + + diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index 4d7790b62c76..acd58e287dd4 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -52,6 +52,7 @@ def _run_main(self, parsed_args, parsed_globals): force_refresh=True, session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), + fallback_to_device_flow=parsed_args.use_device_code ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae9a83e9e8b8..d7d813bf11cc 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -11,19 +11,22 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import datetime -import os -import logging import json +import logging +import os import webbrowser +from functools import partial +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse, parse_qs -from botocore.utils import SSOTokenFetcher -from botocore.utils import original_ld_library_path from botocore.credentials import JSONFileCache +from botocore.exceptions import AuthCodeFetcherError +from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth +from botocore.utils import original_ld_library_path from awscli.customizations.commands import BasicCommand -from awscli.customizations.utils import uni_print -from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR from awscli.customizations.exceptions import ConfigurationError +from awscli.customizations.utils import uni_print LOG = logging.getLogger(__name__) @@ -37,9 +40,18 @@ 'action': 'store_true', 'default': False, 'help_text': ( - 'Disables automatically opening the verfication URL in the ' + 'Disables automatically opening the verification URL in the ' 'default browser.' ) + }, + { + 'name': 'use-device-code', + 'action': 'store_true', + 'default': False, + 'help_text': ( + 'Uses the Device Code authorization grant and login flow ' + 'instead of the Authorization Code flow.' + ) } ] @@ -56,24 +68,38 @@ def _sso_json_dumps(obj): def do_sso_login(session, sso_region, start_url, token_cache=None, on_pending_authorization=None, force_refresh=False, - registration_scopes=None, session_name=None): + registration_scopes=None, session_name=None, + fallback_to_device_flow=False): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) if on_pending_authorization is None: on_pending_authorization = OpenBrowserHandler( open_browser=open_browser_with_original_ld_path ) - token_fetcher = SSOTokenFetcher( - sso_region=sso_region, - client_creator=session.create_client, - cache=token_cache, - on_pending_authorization=on_pending_authorization - ) + + # For the auth flow, we need a non-legacy sso-session and check that the + # user hasn't opted into falling back to the device code flow + if session_name and not fallback_to_device_flow: + token_fetcher = SSOTokenFetcherAuth( + sso_region=sso_region, + client_creator=session.create_client, + auth_code_fetcher=AuthCodeFetcher(), + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + else: + token_fetcher = SSOTokenFetcher( + sso_region=sso_region, + client_creator=session.create_client, + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + return token_fetcher.fetch_token( start_url=start_url, session_name=session_name, force_refresh=force_refresh, - registration_scopes=registration_scopes, + registration_scopes=registration_scopes ) @@ -110,13 +136,19 @@ def __call__( f'Browser will not be automatically opened.\n' f'Please visit the following URL:\n' f'\n{verificationUri}\n' + + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' f'\nAlternatively, you may visit the following URL which will ' f'autofill the code upon loading:' - f'\n{verificationUriComplete}\n' - ) + f'\n{verificationUriComplete}\n') + uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) class OpenBrowserHandler(BaseAuthorizationhandler): @@ -135,10 +167,16 @@ def __call__( f'to use a different device to authorize this request, open the ' f'following URL:\n' f'\n{verificationUri}\n' + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' ) uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) + if self._open_browser: try: return self._open_browser(verificationUriComplete) @@ -146,6 +184,73 @@ def __call__( LOG.debug('Failed to open browser:', exc_info=True) +class AuthCodeFetcher: + """Manages the local web server that will be used + to retrieve the authorization code from the OAuth callback + """ + def __init__(self): + self._auth_code = None + self._state = None + self._is_done = False + + # We do this so that the request handler can have a reference to this + # AuthCodeFetcher so that it can pass back the state and auth code + try: + handler = partial(OAuthCallbackHandler, self) + self.http_server = HTTPServer(('', 0), handler) + except Exception as e: + raise AuthCodeFetcherError(error_msg=e) + + def redirect_uri_without_port(self): + return 'http://127.0.0.1/oauth/callback' + + def redirect_uri_with_port(self): + return f'http://127.0.0.1:{self.http_server.server_port}/oauth/callback' + + def get_auth_code_and_state(self): + """Blocks until the expected redirect request with either the + authorization code/state or and error is handled + """ + while not self._is_done: + self.http_server.handle_request() + self.http_server.server_close() + + return self._auth_code, self._state + + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb') as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher._is_done = True + return + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher._is_done = True + self._auth_code_fetcher._auth_code = query_params['code'][0] + self._auth_code_fetcher._state = query_params['state'][0] + + class InvalidSSOConfigError(ConfigurationError): pass diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 98036f9173b8..ace3e5e63b03 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -13,7 +13,7 @@ import time from awscli.clidriver import AWSCLIEntryPoint -from awscli.customizations.sso.utils import OpenBrowserHandler +from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher from awscli.testutils import create_clidriver from awscli.testutils import FileCreator from awscli.testutils import BaseAWSCommandParamsTest @@ -45,6 +45,29 @@ def setUp(self): self.open_browser_mock, ) self.open_browser_patch.start() + + self.fetcher_mock = mock.Mock(spec=AuthCodeFetcher) + self.fetcher_mock.return_value.redirect_uri_without_port.return_value = ( + 'http://127.0.0.1/oauth/callback' + ) + self.fetcher_mock.return_value.redirect_uri_with_port.return_value = ( + 'http://127.0.0.1:55555/oauth/callback' + ) + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", "00000000-0000-0000-0000-000000000000" + ) + self.auth_code_fetcher_patch = mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher', + self.fetcher_mock, + ) + self.auth_code_fetcher_patch.start() + + self.uuid_mock = mock.Mock( + return_value="00000000-0000-0000-0000-000000000000" + ) + self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock) + self.uuid_patch.start() + self.expires_in = 28800 self.expiration_time = time.time() + 1000 @@ -52,6 +75,8 @@ def tearDown(self): super(BaseSSOTest, self).tearDown() self.files.remove_all() self.open_browser_patch.stop() + self.auth_code_fetcher_patch.stop() + self.uuid_patch.stop() self.token_cache_dir_patch.stop() def assert_used_expected_sso_region(self, expected_region): diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index e24c47591716..f8684908db21 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -23,8 +23,8 @@ class TestLoginCommand(BaseSSOTest): r'\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ' ) - def add_oidc_workflow_responses(self, access_token, - include_register_response=True): + def add_oidc_device_responses(self, access_token, + include_register_response=True): responses = [ # StartDeviceAuthorization response { @@ -53,12 +53,50 @@ def add_oidc_workflow_responses(self, access_token, 0, { 'clientSecretExpiresAt': self.expiration_time, - 'clientId': 'foo-client-id', - 'clientSecret': 'foo-client-secret', + 'clientId': 'device-client-id', + 'clientSecret': 'device-client-secret', } ) self.parsed_responses = responses + def add_oidc_auth_code_responses(self, access_token, + include_register_response=True): + responses = [ + # CreateToken responses + { + 'expiresIn': self.expires_in, + 'tokenType': 'Bearer', + 'accessToken': access_token, + } + ] + if include_register_response: + responses.insert( + 0, + { + 'clientSecretExpiresAt': self.expiration_time, + 'clientId': 'auth-client-id', + 'clientSecret': 'auth-client-secret', + } + ) + self.parsed_responses = responses + + def assert_cache_contains_registration( + self, + start_url, + session_name, + scopes, + expected_client_id): + cached_files = os.listdir(self.token_cache_dir) + + cached_registration_filename = self._get_cached_registration_filename( + start_url, session_name, scopes) + + self.assertIn(cached_registration_filename, cached_files) + self.assertEqual( + self._get_cached_response(cached_registration_filename)['clientId'], + expected_client_id + ) + def assert_cache_contains_token( self, start_url, @@ -78,6 +116,17 @@ def assert_cache_contains_token( expected_token ) + def _get_cached_registration_filename(self, start_url, session_name, scopes): + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self.sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + '.json' + def _get_cached_token_filename(self, start_url, session_name): to_hash = start_url if session_name: @@ -101,8 +150,19 @@ def assert_cache_token_expiration_time_format_is_correct(self): ) ) - def test_login(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_explicit_device(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token + ) + + def test_login_implicit_device(self): + # This is a legacy profile via setUp, so we expect + # it to fall back to device flow automatically + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -110,9 +170,9 @@ def test_login(self): expected_token=self.access_token ) - def test_login_no_browser(self): - self.add_oidc_workflow_responses(self.access_token) - stdout, _, _ = self.run_cmd('sso login --no-browser') + def test_login_device_no_browser(self): + self.add_oidc_device_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --use-device-code --no-browser') self.assertIn('Browser will not be automatically opened.', stdout) self.open_browser_mock.assert_not_called() self.assert_used_expected_sso_region(expected_region=self.sso_region) @@ -121,20 +181,86 @@ def test_login_no_browser(self): expected_token=self.access_token ) - def test_login_forces_refresh(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_auth_no_browser(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --no-browser') + self.assertIn('Browser will not be automatically opened.', stdout) + self.open_browser_mock.assert_not_called() + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token, + session_name='test-session' + ) + + def test_login_device_forces_refresh(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_device_responses( + 'new.token', include_register_response=False) + self.run_cmd('sso login --use-device-code') + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + ) + + def test_login_auth_forces_refresh(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') # The register response from the first login should have been # cached. - self.add_oidc_workflow_responses( + self.add_oidc_auth_code_responses( 'new.token', include_register_response=False) self.run_cmd('sso login') self.assert_cache_contains_token( start_url=self.start_url, - expected_token='new.token' + expected_token='new.token', + session_name='test-session' + ) + + def test_login_auth_after_device_forces_refresh(self): + self.set_config_file_content( + content=self.get_sso_session_config('test-session')) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_auth_code_responses('new.token') + self.run_cmd('sso login') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + session_name='test-session' + ) + + def test_login_device_no_sso_configuration(self): + self.set_config_file_content(content='') + _, stderr, _ = self.run_cmd('sso login --use-device-code', + expected_rc=253) + self.assertIn( + 'Missing the following required SSO configuration', + stderr ) - def test_login_no_sso_configuration(self): + def test_login_auth_no_sso_configuration(self): self.set_config_file_content(content='') _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn( @@ -142,14 +268,14 @@ def test_login_no_sso_configuration(self): stderr ) - def test_login_minimal_sso_configuration(self): + def test_login_device_minimal_sso_configuration(self): content = ( '[default]\n' 'sso_start_url={start_url}\n' 'sso_region={sso_region}\n' ).format(start_url=self.start_url, sso_region=self.sso_region) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -157,13 +283,15 @@ def test_login_minimal_sso_configuration(self): expected_token=self.access_token ) - def test_login_partially_missing_sso_configuration(self): + def test_login_device_partially_missing_sso_configuration(self): content = ( '[default]\n' 'sso_start_url=%s\n' % self.start_url ) self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) + _, stderr, _ = self.run_cmd( + 'sso login --use-device-code', expected_rc=253 + ) self.assertIn( 'Missing the following required SSO configuration', stderr @@ -174,8 +302,8 @@ def test_login_partially_missing_sso_configuration(self): self.assertNotIn('sso_role_name', stderr) def test_token_cache_datetime_format(self): - self.add_oidc_workflow_responses(self.access_token) - self.run_cmd('sso login') + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( start_url=self.start_url, @@ -183,38 +311,115 @@ def test_token_cache_datetime_format(self): ) self.assert_cache_token_expiration_time_format_is_correct() - def test_login_sso_session(self): + def test_login_device_sso_session(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_auth_sso_session(self): content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) - def test_login_sso_with_explicit_sso_session_arg(self): + def test_login_auth_sso_with_explicit_sso_session_arg(self): content = self.get_sso_session_config( 'test-session', include_profile=False) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login --sso-session test-session') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_session_with_scopes(self): + self.registration_scopes = ['sso:foo', 'sso:bar'] + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) + operation, params = self.operations_called[0] + self.assertEqual(operation.name, 'RegisterClient') + self.assertEqual(params.get('scopes'), self.registration_scopes) - def test_login_sso_session_with_scopes(self): + def test_login_auth_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', @@ -245,3 +450,36 @@ def test_login_sso_session_missing(self): self.set_config_file_content(content=content) _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn('sso-session does not exist: "test"', stderr) + + def test_login_auth_sso_no_authorization_code_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + None, None + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'Failed to retrieve an authorization code.', + stderr + ) + + def test_login_auth_sso_state_mismatch_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", '00000000-0000-0000-0000-000000000001' + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'State parameter does not match expected value.', + stderr + ) + diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 47f5762b3d8c..7a24acee3baf 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -11,22 +11,23 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import os +import threading import webbrowser import pytest - -from awscli.testutils import mock -from awscli.testutils import unittest - +import urllib3 from botocore.session import Session -from botocore.exceptions import ClientError from awscli.compat import StringIO -from awscli.customizations.sso.utils import parse_sso_registration_scopes -from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler +from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path +from awscli.customizations.sso.utils import ( + parse_sso_registration_scopes, AuthCodeFetcher +) +from awscli.testutils import mock +from awscli.testutils import unittest @pytest.mark.parametrize( @@ -205,3 +206,45 @@ def test_can_patch_env(self): os.environ) open_browser_with_original_ld_path('http://example.com') self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) + + +class TestAuthCodeFetcher(unittest.TestCase): + """Tests for the AuthCodeFetcher class, which is the local + web server we use to handle the OAuth 2.0 callback + """ + + def setUp(self): + self.fetcher = AuthCodeFetcher() + self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/' + + # Start the server on a background thread so that + # the test thread can make the request + self.server_thread = threading.Thread( + target=self.fetcher.get_auth_code_and_state + ) + self.server_thread.setDaemon(True) + self.server_thread.start() + + def test_expected_auth_code(self): + expected_code = '1234' + expected_state = '4567' + url = self.url + f'?code={expected_code}&state={expected_state}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, expected_code) + self.assertEqual(self.fetcher._state, expected_state) + + def test_error(self): + expected_code = 'Failed' + url = self.url + f'?error={expected_code}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, None) + self.assertEqual(self.fetcher._state, None) +