diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 3efd3cb3a573..3a583dc29b39 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -19,13 +19,13 @@ from botocore.exceptions import PendingAuthorizationExpiredError from botocore.session import Session -from awscli.compat import StringIO +from awscli.compat import BytesIO, StringIO from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path from awscli.customizations.sso.utils import ( - parse_sso_registration_scopes, AuthCodeFetcher + parse_sso_registration_scopes, AuthCodeFetcher, OAuthCallbackHandler ) from awscli.testutils import mock from awscli.testutils import unittest @@ -209,6 +209,59 @@ def test_can_patch_env(self): self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) +class MockRequest(object): + def __init__(self, request): + self._request = request + + def makefile(self, *args, **kwargs): + return BytesIO(self._request) + + def sendall(self, data): + pass + + +class TestOAuthCallbackHandler: + """Tests for OAuthCallbackHandler, which handles + individual requests that we receive at the callback uri + """ + def test_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?state=123&code=456'), + mock.MagicMock(), + mock.MagicMock(), + ) + fetcher.set_auth_code_and_state.assert_called_once_with('456', '123') + + def test_error(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?error=Error%20message'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_called_once_with(None, None) + + def test_missing_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + # We generally don't expect to be missing the expected query params, + # but if we do we expect the server to keep waiting for a valid callback + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_not_called() + + class TestAuthCodeFetcher: """Tests for the AuthCodeFetcher class, which is the local web server we use to handle the OAuth 2.0 callback