From e0ee9ebe032482c371b96e3b830fe23dfe950dc6 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Thu, 26 Sep 2024 10:08:10 -0400 Subject: [PATCH 1/2] Add the OAuth Authorization Code Flow with PKCE This adds support for the OAuth2.0 authorization code flow with PKCE to the aws sso login command. It is the new default behavior, but users can fall back to the device code flow using the new --use-device-code option. --- .changes/next-release/feature-sso-81096.json | 5 + awscli/botocore/exceptions.py | 14 +- awscli/botocore/utils.py | 368 +++++++++++++++++-- awscli/customizations/sso/index.html | 171 +++++++++ awscli/customizations/sso/login.py | 1 + awscli/customizations/sso/utils.py | 164 ++++++++- tests/functional/sso/__init__.py | 61 ++- tests/functional/sso/test_login.py | 323 ++++++++++++++-- tests/unit/customizations/sso/test_utils.py | 130 ++++++- 9 files changed, 1150 insertions(+), 87 deletions(-) create mode 100644 .changes/next-release/feature-sso-81096.json create mode 100644 awscli/customizations/sso/index.html diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json new file mode 100644 index 000000000000..53acbc72cc32 --- /dev/null +++ b/.changes/next-release/feature-sso-81096.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``sso``", + "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for ``aws sso login``." +} diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py index f1f7de43cfd5..9adbf32070ef 100644 --- a/awscli/botocore/exceptions.py +++ b/awscli/botocore/exceptions.py @@ -674,7 +674,7 @@ class SSOError(BotoCoreError): class PendingAuthorizationExpiredError(SSOError): fmt = ( "The pending authorization to retrieve an SSO token has expired. The " - "device authorization flow to retrieve an SSO token must be restarted." + "login flow to retrieve an SSO token must be restarted." ) @@ -682,6 +682,10 @@ class SSOTokenLoadError(SSOError): fmt = "Error loading SSO Token: {error_msg}" +class AuthorizationCodeLoadError(SSOError): + fmt = "Error loading authorization code: {error_msg}" + + class UnauthorizedSSOTokenError(SSOError): fmt = ( "The SSO session associated with this profile has expired or is " @@ -690,6 +694,14 @@ class UnauthorizedSSOTokenError(SSOError): ) +class AuthCodeFetcherError(SSOError): + fmt = ( + "Unable to initialize the OAuth 2.0 authorization callback handler: " + "{error_msg} \n You may use --use-device-code to fall back to the " + "device code flow which does not require the callback handler." + ) + + class CapacityNotAvailableError(BotoCoreError): fmt = ( 'Insufficient request capacity available.' diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 5a7614211a4d..affcf155baa6 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -22,8 +22,11 @@ import os import random import re +import secrets import socket +import string import time +import uuid import warnings import weakref from datetime import datetime as _DatetimeClass @@ -48,6 +51,7 @@ zip_longest, ) from botocore.exceptions import ( + AuthorizationCodeLoadError, ClientError, ConfigNotFound, ConnectionClosedError, @@ -3057,20 +3061,16 @@ def __call__(self): return token_file.read() -class SSOTokenFetcher(object): - # The device flow RFC defines the slow down delay to be an additional - # 5 seconds: - # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 - _SLOW_DOWN_DELAY = 5 - # The default interval of 5 is also defined in the RFC (see above link) - _DEFAULT_INTERVAL = 5 +class BaseSSOTokenFetcher(object): + """Base class for SSO token fetchers, for functionality + shared between the device and authorization code grant flows. + """ _EXPIRY_WINDOW = 15 * 60 _CLIENT_REGISTRATION_TYPE = 'public' - _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' def __init__( self, sso_region, client_creator, cache=None, - on_pending_authorization=None, time_fetcher=None, sleep=None, + on_pending_authorization=None, time_fetcher=None ): self._sso_region = sso_region self._client_creator = client_creator @@ -3080,14 +3080,19 @@ def __init__( time_fetcher = self._utc_now self._time_fetcher = time_fetcher - if sleep is None: - sleep = time.sleep - self._sleep = sleep - if cache is None: cache = {} self._cache = cache + def fetch_token( + self, + start_url, + force_refresh, + registration_scopes, + session_name, + ): + raise NotImplementedError('Must implement fetch_token()') + def _utc_now(self): return datetime.datetime.now(tzutc()) @@ -3101,6 +3106,17 @@ def _is_expired(self, response): seconds = total_seconds(end_time - self._time_fetcher()) return seconds < self._EXPIRY_WINDOW + def _is_registration_for_auth_code(self, registration): + if ( + 'grantTypes' in registration + and 'authorization_code' in registration['grantTypes'] + ): + return True + + # Else assume that it's device flow, + # since the CLI didn't cache grantTypes previously + return False + @CachedProperty def _client(self): config = botocore.config.Config( @@ -3109,13 +3125,69 @@ def _client(self): ) return self._client_creator('sso-oidc', config=config) - def _register_client(self, session_name, scopes): + def _generate_client_name(self, session_name): if session_name is None: # Use a timestamp for the session name for legacy configuration timestamp = datetime2timestamp(self._time_fetcher()) session_name = int(timestamp) + return f'botocore-client-{str(session_name)}' + + def _registration_cache_key(self, start_url, session_name, scopes): + # Registration is unique based on the following properties to ensure + # modifications to the registration do not affect the permissions of + # tokens derived for other start URLs. + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self._sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + + def _token_cache_key(self, start_url, session_name): + input_str = start_url + if session_name is not None: + input_str = session_name + return hashlib.sha1(input_str.encode('utf-8')).hexdigest() + + +class SSOTokenFetcher(BaseSSOTokenFetcher): + """Performs the device grant OAuth2.0 flow""" + # The device flow RFC defines the slow-down delay to be an additional + # 5 seconds: + # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 + _SLOW_DOWN_DELAY = 5 + # The default interval of 5 is also defined in the RFC (see above link) + _DEFAULT_INTERVAL = 5 + _EXPIRY_WINDOW = 15 * 60 + _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + + def __init__( + self, + sso_region, + client_creator, + cache=None, + on_pending_authorization=None, + time_fetcher=None, + sleep=None, + ): + super().__init__( + sso_region, + client_creator, + cache, + on_pending_authorization, + time_fetcher, + ) + + if sleep is None: + sleep = time.sleep + self._sleep = sleep + + def _register_client(self, session_name, scopes): register_kwargs = { - 'clientName': f'botocore-client-{session_name}', + 'clientName': self._generate_client_name(session_name), 'clientType': self._CLIENT_REGISTRATION_TYPE, } if scopes: @@ -3132,20 +3204,6 @@ def _register_client(self, session_name, scopes): registration['scopes'] = scopes return registration - def _registration_cache_key(self, start_url, session_name, scopes): - # Registration is unique based on the following properties to ensure - # modifications to the registration do not affect the permissions of - # tokens derived for other start URLs. - args = { - 'tool': 'botocore', - 'startUrl': start_url, - 'region': self._sso_region, - 'scopes': scopes, - 'session_name': session_name, - } - cache_args = json.dumps(args, sort_keys=True).encode('utf-8') - return hashlib.sha1(cache_args).hexdigest() - def _registration( self, start_url, @@ -3160,7 +3218,10 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if not self._is_expired(registration): + if ( + not self._is_expired(registration) + and not self._is_registration_for_auth_code(registration) + ): return registration registration = self._register_client( @@ -3172,7 +3233,7 @@ def _registration( def _authorize_client(self, start_url, registration): # NOTE: The authorization response is not cached. These responses are - # short lived (currently only 10 minutes) and can only be exchanged for + # short-lived (currently only 10 minutes) and can only be exchanged for # a token once. Having multiple clients share this is problematic. response = self._client.start_device_authorization( clientId=registration['clientId'], @@ -3261,12 +3322,6 @@ def _create_token_attempt( raise PendingAuthorizationExpiredError() return interval, None - def _token_cache_key(self, start_url, session_name): - input_str = start_url - if session_name is not None: - input_str = session_name - return hashlib.sha1(input_str.encode('utf-8')).hexdigest() - def _token( self, start_url, @@ -3305,6 +3360,245 @@ def fetch_token( ) +class SSOTokenFetcherAuth(BaseSSOTokenFetcher): + """Performs the authorization code grant with PKCE OAuth2.0 flow""" + _AUTH_GRANT_TYPES = ('authorization_code', 'refresh_token') + _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' + + def __init__( + self, + sso_region, + client_creator, + auth_code_fetcher, + cache=None, + on_pending_authorization=None, + time_fetcher=None, + ): + super().__init__( + sso_region, + client_creator, + cache, + on_pending_authorization, + time_fetcher, + ) + + self._auth_code_fetcher = auth_code_fetcher + + # Generate the PKCE pair + self.code_verifier = ''.join( + secrets.choice( + string.ascii_letters + string.digits + '-._~' + ) + for _ in range(64) + ) + self.code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(self.code_verifier.encode()).digest() + ).decode() + + def _register_client(self, session_name, scopes, redirect_uri, issuer_url): + register_kwargs = { + 'clientName': self._generate_client_name(session_name), + 'clientType': self._CLIENT_REGISTRATION_TYPE, + 'grantTypes': self._AUTH_GRANT_TYPES, + 'redirectUris': [redirect_uri], + 'issuerUrl': issuer_url, + 'scopes': scopes or [self._AUTH_GRANT_DEFAULT_SCOPE], + } + + response = self._client.register_client(**register_kwargs) + + expires_at = response['clientSecretExpiresAt'] + expires_at = datetime.datetime.fromtimestamp(expires_at, tzutc()) + registration = { + 'clientId': response['clientId'], + 'clientSecret': response['clientSecret'], + 'expiresAt': expires_at, + 'scopes': register_kwargs['scopes'], + 'grantTypes': register_kwargs['grantTypes'] + } + + return registration + + def _registration( + self, + start_url, + session_name, + scopes, + force_refresh=False, + ): + cache_key = self._registration_cache_key( + start_url, + session_name, + scopes, + ) + if not force_refresh and cache_key in self._cache: + registration = self._cache[cache_key] + if ( + not self._is_expired(registration) + and self._is_registration_for_auth_code(registration) + ): + return registration + + registration = self._register_client( + session_name, + scopes, + self._auth_code_fetcher.redirect_uri_without_port(), + start_url + ) + self._cache[cache_key] = registration + return registration + + def _extract_resolved_endpoint(self, params, **kwargs): + """Event handler for before-call that will extract the resolved endpoint + for a given request without actually running it + """ + # This will contain any path and query params specific to + # the operation/input, so extract just the scheme and hostname + if params['url']: + parsed = urlparse(params['url']) + self._base_endpoint = f'{parsed.scheme}://{parsed.netloc}' + + # Return a tuple containing the "response" to short-circuit the request + return botocore.awsrequest.AWSResponse(None, 200, {}, None), {} + + def _get_base_authorization_uri(self): + """Simulates an SSO-OIDC request so that we can extract the "base" + endpoint for the current client to use for the un-modeled Authorize + operation + """ + self._client.meta.events.register( + 'before-call', self._extract_resolved_endpoint + ) + self._client.register_client( + clientName='temp', + clientType='public' + ) + self._client.meta.events.unregister( + 'before-call', self._extract_resolved_endpoint + ) + + return self._base_endpoint + + def _get_authorization_uri( + self, + client_id, + registration_scopes, + expected_state + ): + + query_params = { + 'response_type': 'code', + 'client_id': client_id, + 'redirect_uri': self._auth_code_fetcher.redirect_uri_with_port(), + 'state': expected_state, + 'code_challenge_method': 'S256' + # Don't want to encode code_challenge again, so we append below + } + + # For the query param, scopes must be space separated before encoding + if registration_scopes: + query_params['scopes'] = " ".join(registration_scopes) + else: + query_params['scopes'] = self._AUTH_GRANT_DEFAULT_SCOPE + + return ( + f'{self._get_base_authorization_uri()}/authorize?' + f'{percent_encode_sequence(query_params)}' + f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' + ) + + def _get_new_token(self, start_url, session_name, registration_scopes): + registration = self._registration( + start_url, + session_name, + registration_scopes, + ) + + expected_state = uuid.uuid4() + + authorization_uri = self._get_authorization_uri( + registration['clientId'], + registration_scopes, + expected_state + ) + + # Even though there's just one URI, this matches the inputs + # for the device code flow so that we can reuse the browser handlers + authorization_args = { + 'verificationUri': authorization_uri, + 'verificationUriComplete': authorization_uri, + 'userCode': None + } + + # Open/display the link, then block until the redirect uri is hit and + # the auth code is retrieved + self._on_pending_authorization(**authorization_args) + auth_code, state = self._auth_code_fetcher.get_auth_code_and_state() + + if auth_code is None: + raise AuthorizationCodeLoadError( + error_msg='Failed to retrieve an authorization code.' + ) + + # The state we get back from the redirect is just a string, so + # cast our original UUID before comparing + if state != str(expected_state): + raise AuthorizationCodeLoadError( + error_msg='State parameter does not match expected value.' + ) + + return self._create_token_(start_url, registration, auth_code) + + def _create_token_(self, start_url, registration, auth_code): + try: + response = self._client.create_token( + grantType='authorization_code', + clientId=registration['clientId'], + clientSecret=registration['clientSecret'], + redirectUri=self._auth_code_fetcher.redirect_uri_with_port(), + codeVerifier=self.code_verifier, + code=auth_code + ) + expires_in = datetime.timedelta(seconds=response['expiresIn']) + token = { + 'startUrl': start_url, + 'region': self._sso_region, + 'accessToken': response['accessToken'], + 'expiresAt': self._time_fetcher() + expires_in, + # Cache the registration alongside the token + 'clientId': registration['clientId'], + 'clientSecret': registration['clientSecret'], + 'registrationExpiresAt': registration['expiresAt'], + } + if 'refreshToken' in response: + token['refreshToken'] = response['refreshToken'] + return token + except self._client.exceptions.ExpiredTokenException: + raise PendingAuthorizationExpiredError() + + def fetch_token( + self, + start_url, + force_refresh=False, + registration_scopes=None, + session_name=None, + ): + cache_key = self._token_cache_key(start_url, session_name) + # Only obey the token cache if we are not forcing a refresh. + if not force_refresh and cache_key in self._cache: + token = self._cache[cache_key] + if not self._is_expired(token): + return token + + token = self._get_new_token( + start_url, + session_name, + registration_scopes + ) + self._cache[cache_key] = token + return token + + class SSOTokenLoader(object): def __init__(self, cache=None): if cache is None: diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html new file mode 100644 index 000000000000..214a32fe1fa6 --- /dev/null +++ b/awscli/customizations/sso/index.html @@ -0,0 +1,171 @@ + + + + + AWS Authentication + + + + + +
+
+ + + + + +
+
+ +
+
+ +
+

Request approved

+

+
+
+

+
+ + + +
+
+ + + diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index 4d7790b62c76..2c7917b82660 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -52,6 +52,7 @@ def _run_main(self, parsed_args, parsed_globals): force_refresh=True, session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), + use_device_code=parsed_args.use_device_code, ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae9a83e9e8b8..1e9aef74d120 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -11,19 +11,29 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import datetime -import os -import logging import json +import logging +import os +import socket +import time import webbrowser +from functools import partial +from http.server import HTTPServer, BaseHTTPRequestHandler -from botocore.utils import SSOTokenFetcher -from botocore.utils import original_ld_library_path +from botocore.compat import urlparse, parse_qs from botocore.credentials import JSONFileCache +from botocore.exceptions import ( + AuthCodeFetcherError, + PendingAuthorizationExpiredError, +) +from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth +from botocore.utils import original_ld_library_path -from awscli.customizations.commands import BasicCommand -from awscli.customizations.utils import uni_print +from awscli import __version__ as awscli_version from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR +from awscli.customizations.commands import BasicCommand from awscli.customizations.exceptions import ConfigurationError +from awscli.customizations.utils import uni_print LOG = logging.getLogger(__name__) @@ -37,9 +47,18 @@ 'action': 'store_true', 'default': False, 'help_text': ( - 'Disables automatically opening the verfication URL in the ' + 'Disables automatically opening the verification URL in the ' 'default browser.' ) + }, + { + 'name': 'use-device-code', + 'action': 'store_true', + 'default': False, + 'help_text': ( + 'Uses the Device Code authorization grant and login flow ' + 'instead of the Authorization Code flow.' + ) } ] @@ -56,19 +75,33 @@ def _sso_json_dumps(obj): def do_sso_login(session, sso_region, start_url, token_cache=None, on_pending_authorization=None, force_refresh=False, - registration_scopes=None, session_name=None): + registration_scopes=None, session_name=None, + use_device_code=False): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) if on_pending_authorization is None: on_pending_authorization = OpenBrowserHandler( open_browser=open_browser_with_original_ld_path ) - token_fetcher = SSOTokenFetcher( - sso_region=sso_region, - client_creator=session.create_client, - cache=token_cache, - on_pending_authorization=on_pending_authorization - ) + + # For the auth flow, we need a non-legacy sso-session and check that the + # user hasn't opted into falling back to the device code flow + if session_name and not use_device_code: + token_fetcher = SSOTokenFetcherAuth( + sso_region=sso_region, + client_creator=session.create_client, + auth_code_fetcher=AuthCodeFetcher(), + cache=token_cache, + on_pending_authorization=on_pending_authorization, + ) + else: + token_fetcher = SSOTokenFetcher( + sso_region=sso_region, + client_creator=session.create_client, + cache=token_cache, + on_pending_authorization=on_pending_authorization, + ) + return token_fetcher.fetch_token( start_url=start_url, session_name=session_name, @@ -110,13 +143,20 @@ def __call__( f'Browser will not be automatically opened.\n' f'Please visit the following URL:\n' f'\n{verificationUri}\n' + + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' f'\nAlternatively, you may visit the following URL which will ' f'autofill the code upon loading:' f'\n{verificationUriComplete}\n' ) + uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) class OpenBrowserHandler(BaseAuthorizationhandler): @@ -135,10 +175,16 @@ def __call__( f'to use a different device to authorize this request, open the ' f'following URL:\n' f'\n{verificationUri}\n' + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' ) uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) + if self._open_browser: try: return self._open_browser(verificationUriComplete) @@ -146,6 +192,96 @@ def __call__( LOG.debug('Failed to open browser:', exc_info=True) +class AuthCodeFetcher: + """Manages the local web server that will be used + to retrieve the authorization code from the OAuth callback + """ + # How many seconds handle_request should wait for an incoming request + _REQUEST_TIMEOUT = 10 + # How long we wait overall for the callback + _OVERALL_TIMEOUT = 60 * 10 + + def __init__(self): + self._auth_code = None + self._state = None + self._is_done = False + + # We do this so that the request handler can have a reference to this + # AuthCodeFetcher so that it can pass back the state and auth code + try: + handler = partial(OAuthCallbackHandler, self) + self.http_server = HTTPServer(('', 0), handler) + self.http_server.timeout = self._REQUEST_TIMEOUT + except socket.error as e: + raise AuthCodeFetcherError(error_msg=e) + + def redirect_uri_without_port(self): + return 'http://127.0.0.1/oauth/callback' + + def redirect_uri_with_port(self): + return f'http://127.0.0.1:{self.http_server.server_port}/oauth/callback' + + def get_auth_code_and_state(self): + """Blocks until the expected redirect request with either the + authorization code/state or and error is handled + """ + start = time.time() + while not self._is_done and time.time() < start + self._OVERALL_TIMEOUT: + self.http_server.handle_request() + self.http_server.server_close() + + if not self._is_done: + raise PendingAuthorizationExpiredError + + return self._auth_code, self._state + + def set_auth_code_and_state(self, auth_code, state): + self._auth_code = auth_code + self._state = state + self._is_done = True + + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def version_string(self): + # Override the Host header in case helpful for debugging + return f'AWS CLI/{awscli_version}' + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb', + ) as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + None, + None, + ) + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + query_params['code'][0], + query_params['state'][0], + ) + + class InvalidSSOConfigError(ConfigurationError): pass diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 98036f9173b8..712cf5003f2c 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -11,9 +11,10 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import time +import uuid from awscli.clidriver import AWSCLIEntryPoint -from awscli.customizations.sso.utils import OpenBrowserHandler +from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher from awscli.testutils import create_clidriver from awscli.testutils import FileCreator from awscli.testutils import BaseAWSCommandParamsTest @@ -45,6 +46,29 @@ def setUp(self): self.open_browser_mock, ) self.open_browser_patch.start() + + self.fetcher_mock = mock.Mock(spec=AuthCodeFetcher) + self.fetcher_mock.return_value.redirect_uri_without_port.return_value = ( + 'http://127.0.0.1/oauth/callback' + ) + self.fetcher_mock.return_value.redirect_uri_with_port.return_value = ( + 'http://127.0.0.1:55555/oauth/callback' + ) + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", "00000000-0000-0000-0000-000000000000" + ) + self.auth_code_fetcher_patch = mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher', + self.fetcher_mock, + ) + self.auth_code_fetcher_patch.start() + + self.uuid_mock = mock.Mock( + return_value=uuid.UUID("00000000-0000-0000-0000-000000000000") + ) + self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock) + self.uuid_patch.start() + self.expires_in = 28800 self.expiration_time = time.time() + 1000 @@ -52,11 +76,46 @@ def tearDown(self): super(BaseSSOTest, self).tearDown() self.files.remove_all() self.open_browser_patch.stop() + self.auth_code_fetcher_patch.stop() + self.uuid_patch.stop() self.token_cache_dir_patch.stop() def assert_used_expected_sso_region(self, expected_region): self.assertIn(expected_region, self.last_request_dict['url']) + def assert_device_browser_handler_called_with( + self, + userCode, + verificationUri, + verificationUriComplete, + ): + # assert_called_with is matching the __init__ parameters instead of + # __call__, so verify the arguments we're interested in this way + self.open_browser_mock.assert_called_once() + _, kwargs = self.open_browser_mock.return_value.call_args + self.assertEqual(userCode, kwargs['userCode']) + self.assertEqual(verificationUri, kwargs['verificationUri']) + self.assertEqual(verificationUriComplete, kwargs['verificationUriComplete']) + + def assert_auth_browser_handler_called_with(self, expected_scopes): + # The endpoint is subject to the endpoint rules, and the + # code_challenge is not fixed so assert against the rest of the url + expected_url = ( + 'authorize?' + 'response_type=code' + '&client_id=auth-client-id' + '&redirect_uri=http%3A%2F%2F127.0.0.1%3A55555%2Foauth%2Fcallback' + '&state=00000000-0000-0000-0000-000000000000' + '&code_challenge_method=S256' + '&scopes=' + expected_scopes + ) + + self.open_browser_mock.assert_called_once() + _, kwargs = self.open_browser_mock.return_value.call_args + self.assertEqual(None, kwargs['userCode']) + self.assertIn(expected_url, kwargs['verificationUri']) + self.assertIn(expected_url, kwargs['verificationUriComplete']) + def get_legacy_config(self): content = ( f'[default]\n' diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index e24c47591716..a2d6cc025079 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -23,8 +23,8 @@ class TestLoginCommand(BaseSSOTest): r'\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ' ) - def add_oidc_workflow_responses(self, access_token, - include_register_response=True): + def add_oidc_device_responses(self, access_token, + include_register_response=True): responses = [ # StartDeviceAuthorization response { @@ -53,12 +53,50 @@ def add_oidc_workflow_responses(self, access_token, 0, { 'clientSecretExpiresAt': self.expiration_time, - 'clientId': 'foo-client-id', - 'clientSecret': 'foo-client-secret', + 'clientId': 'device-client-id', + 'clientSecret': 'device-client-secret', } ) self.parsed_responses = responses + def add_oidc_auth_code_responses(self, access_token, + include_register_response=True): + responses = [ + # CreateToken responses + { + 'expiresIn': self.expires_in, + 'tokenType': 'Bearer', + 'accessToken': access_token, + } + ] + if include_register_response: + responses.insert( + 0, + { + 'clientSecretExpiresAt': self.expiration_time, + 'clientId': 'auth-client-id', + 'clientSecret': 'auth-client-secret', + } + ) + self.parsed_responses = responses + + def assert_cache_contains_registration( + self, + start_url, + session_name, + scopes, + expected_client_id): + cached_files = os.listdir(self.token_cache_dir) + + cached_registration_filename = self._get_cached_registration_filename( + start_url, session_name, scopes) + + self.assertIn(cached_registration_filename, cached_files) + self.assertEqual( + self._get_cached_response(cached_registration_filename)['clientId'], + expected_client_id + ) + def assert_cache_contains_token( self, start_url, @@ -78,6 +116,17 @@ def assert_cache_contains_token( expected_token ) + def _get_cached_registration_filename(self, start_url, session_name, scopes): + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self.sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + '.json' + def _get_cached_token_filename(self, start_url, session_name): to_hash = start_url if session_name: @@ -101,18 +150,39 @@ def assert_cache_token_expiration_time_format_is_correct(self): ) ) - def test_login(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_explicit_device(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token + ) + + def test_login_implicit_device(self): + # This is a legacy profile via setUp, so we expect + # it to fall back to device flow automatically + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_token( start_url=self.start_url, expected_token=self.access_token ) - def test_login_no_browser(self): - self.add_oidc_workflow_responses(self.access_token) - stdout, _, _ = self.run_cmd('sso login --no-browser') + def test_login_device_no_browser(self): + self.add_oidc_device_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --use-device-code --no-browser') self.assertIn('Browser will not be automatically opened.', stdout) self.open_browser_mock.assert_not_called() self.assert_used_expected_sso_region(expected_region=self.sso_region) @@ -121,20 +191,86 @@ def test_login_no_browser(self): expected_token=self.access_token ) - def test_login_forces_refresh(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_auth_no_browser(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --no-browser') + self.assertIn('Browser will not be automatically opened.', stdout) + self.open_browser_mock.assert_not_called() + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token, + session_name='test-session' + ) + + def test_login_device_forces_refresh(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_device_responses( + 'new.token', include_register_response=False) + self.run_cmd('sso login --use-device-code') + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + ) + + def test_login_auth_forces_refresh(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') # The register response from the first login should have been # cached. - self.add_oidc_workflow_responses( + self.add_oidc_auth_code_responses( 'new.token', include_register_response=False) self.run_cmd('sso login') self.assert_cache_contains_token( start_url=self.start_url, - expected_token='new.token' + expected_token='new.token', + session_name='test-session' + ) + + def test_login_auth_after_device_forces_refresh(self): + self.set_config_file_content( + content=self.get_sso_session_config('test-session')) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_auth_code_responses('new.token') + self.run_cmd('sso login') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + session_name='test-session' ) - def test_login_no_sso_configuration(self): + def test_login_device_no_sso_configuration(self): + self.set_config_file_content(content='') + _, stderr, _ = self.run_cmd('sso login --use-device-code', + expected_rc=253) + self.assertIn( + 'Missing the following required SSO configuration', + stderr + ) + + def test_login_auth_no_sso_configuration(self): self.set_config_file_content(content='') _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn( @@ -142,28 +278,35 @@ def test_login_no_sso_configuration(self): stderr ) - def test_login_minimal_sso_configuration(self): + def test_login_device_minimal_sso_configuration(self): content = ( '[default]\n' 'sso_start_url={start_url}\n' 'sso_region={sso_region}\n' ).format(start_url=self.start_url, sso_region=self.sso_region) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_token( start_url=self.start_url, expected_token=self.access_token ) - def test_login_partially_missing_sso_configuration(self): + def test_login_device_partially_missing_sso_configuration(self): content = ( '[default]\n' 'sso_start_url=%s\n' % self.start_url ) self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) + _, stderr, _ = self.run_cmd( + 'sso login --use-device-code', expected_rc=253 + ) self.assertIn( 'Missing the following required SSO configuration', stderr @@ -174,8 +317,8 @@ def test_login_partially_missing_sso_configuration(self): self.assertNotIn('sso_role_name', stderr) def test_token_cache_datetime_format(self): - self.add_oidc_workflow_responses(self.access_token) - self.run_cmd('sso login') + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( start_url=self.start_url, @@ -183,38 +326,133 @@ def test_token_cache_datetime_format(self): ) self.assert_cache_token_expiration_time_format_is_correct() - def test_login_sso_session(self): + def test_login_device_sso_session(self): content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_auth_sso_session(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Aaccount%3Aaccess') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) - def test_login_sso_with_explicit_sso_session_arg(self): + def test_login_auth_sso_with_explicit_sso_session_arg(self): content = self.get_sso_session_config( 'test-session', include_profile=False) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login --sso-session test-session') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Aaccount%3Aaccess') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) - def test_login_sso_session_with_scopes(self): + def test_login_device_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + operation, params = self.operations_called[0] + self.assertEqual(operation.name, 'RegisterClient') + self.assertEqual(params.get('scopes'), self.registration_scopes) + + def test_login_auth_sso_session_with_scopes(self): + self.registration_scopes = ['sso:foo', 'sso:bar'] + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Afoo%20sso%3Abar') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', @@ -245,3 +483,36 @@ def test_login_sso_session_missing(self): self.set_config_file_content(content=content) _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn('sso-session does not exist: "test"', stderr) + + def test_login_auth_sso_no_authorization_code_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + None, None + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'Failed to retrieve an authorization code.', + stderr + ) + + def test_login_auth_sso_state_mismatch_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", '00000000-0000-0000-0000-000000000001' + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'State parameter does not match expected value.', + stderr + ) + diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 47f5762b3d8c..3a583dc29b39 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -11,22 +11,24 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import os +import threading import webbrowser - import pytest +import urllib3 -from awscli.testutils import mock -from awscli.testutils import unittest - +from botocore.exceptions import PendingAuthorizationExpiredError from botocore.session import Session -from botocore.exceptions import ClientError -from awscli.compat import StringIO -from awscli.customizations.sso.utils import parse_sso_registration_scopes -from awscli.customizations.sso.utils import do_sso_login +from awscli.compat import BytesIO, StringIO from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler +from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path +from awscli.customizations.sso.utils import ( + parse_sso_registration_scopes, AuthCodeFetcher, OAuthCallbackHandler +) +from awscli.testutils import mock +from awscli.testutils import unittest @pytest.mark.parametrize( @@ -205,3 +207,115 @@ def test_can_patch_env(self): os.environ) open_browser_with_original_ld_path('http://example.com') self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) + + +class MockRequest(object): + def __init__(self, request): + self._request = request + + def makefile(self, *args, **kwargs): + return BytesIO(self._request) + + def sendall(self, data): + pass + + +class TestOAuthCallbackHandler: + """Tests for OAuthCallbackHandler, which handles + individual requests that we receive at the callback uri + """ + def test_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?state=123&code=456'), + mock.MagicMock(), + mock.MagicMock(), + ) + fetcher.set_auth_code_and_state.assert_called_once_with('456', '123') + + def test_error(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?error=Error%20message'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_called_once_with(None, None) + + def test_missing_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + # We generally don't expect to be missing the expected query params, + # but if we do we expect the server to keep waiting for a valid callback + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_not_called() + + +class TestAuthCodeFetcher: + """Tests for the AuthCodeFetcher class, which is the local + web server we use to handle the OAuth 2.0 callback + """ + + def setup_method(self): + self.fetcher = AuthCodeFetcher() + self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/' + + # Start the server on a background thread so that + # the test thread can make the request + self.server_thread = threading.Thread( + target=self.fetcher.get_auth_code_and_state + ) + self.server_thread.daemon = True + self.server_thread.start() + + def test_expected_auth_code(self): + expected_code = '1234' + expected_state = '4567' + url = self.url + f'?code={expected_code}&state={expected_state}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + actual_code, actual_state = self.fetcher.get_auth_code_and_state() + assert response.status == 200 + assert actual_code == expected_code + assert actual_state == expected_state + + def test_error(self): + expected_code = 'Failed' + url = self.url + f'?error={expected_code}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + actual_code, actual_state = self.fetcher.get_auth_code_and_state() + assert response.status == 200 + assert actual_code is None + assert actual_state is None + + +@mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher._REQUEST_TIMEOUT', + 0.1 +) +@mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher._OVERALL_TIMEOUT', + 0.1 +) +def test_get_auth_code_and_state_timeout(): + """Tests the timeout case separately of TestAuthCodeFetcher, + since we need to override the constants + """ + with pytest.raises(PendingAuthorizationExpiredError): + AuthCodeFetcher().get_auth_code_and_state() From e78d5329794f9cea2a9b4f769c18623257afde4e Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Fri, 15 Nov 2024 12:52:48 -0600 Subject: [PATCH 2/2] Add additional user agent metadata for the SSO auth code flow --- awscli/botocore/utils.py | 3 +++ tests/functional/sso/test_login.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index affcf155baa6..2f4e1ef4fd5e 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -3067,6 +3067,7 @@ class BaseSSOTokenFetcher(object): """ _EXPIRY_WINDOW = 15 * 60 _CLIENT_REGISTRATION_TYPE = 'public' + _USER_AGENT_EXTRA = None def __init__( self, sso_region, client_creator, cache=None, @@ -3122,6 +3123,7 @@ def _client(self): config = botocore.config.Config( region_name=self._sso_region, signature_version=botocore.UNSIGNED, + user_agent_extra=self._USER_AGENT_EXTRA, ) return self._client_creator('sso-oidc', config=config) @@ -3364,6 +3366,7 @@ class SSOTokenFetcherAuth(BaseSSOTokenFetcher): """Performs the authorization code grant with PKCE OAuth2.0 flow""" _AUTH_GRANT_TYPES = ('authorization_code', 'refresh_token') _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' + _USER_AGENT_EXTRA = 'md/sso#auth' def __init__( self, diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index a2d6cc025079..a7cc724423ed 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -516,3 +516,17 @@ def test_login_auth_sso_state_mismatch_throws_error(self): stderr ) + def test_login_device_no_extra_user_agent(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assertNotIn('md/sso#auth', + self.last_request_dict['headers']['User-Agent']) + + def test_login_auth_includes_extra_user_agent(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + self.run_cmd('sso login') + self.assertIn('md/sso#auth', + self.last_request_dict['headers']['User-Agent']) +