Skip to content

Commit

Permalink
Merge pull request #110 from soxofaan/oidc-specify-redirect-port
Browse files Browse the repository at this point in the history
Allow specifying OAuth redirect port (and host)
  • Loading branch information
soxofaan authored Feb 6, 2020
2 parents 66d643d + 01af400 commit cd204f4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
14 changes: 10 additions & 4 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ class HttpServerThread(threading.Thread):
Thread that runs a HTTP server (`http.server.HTTPServer`)
"""

def __init__(self, RequestHandlerClass, server_address=('', 0)):
def __init__(self, RequestHandlerClass, server_address: Tuple[str, int] = None):
# Make it a daemon to minimize potential shutdown issues due to `serve_forever`
super().__init__(daemon=True)
self._RequestHandlerClass = RequestHandlerClass
# Server address ('', 0): listen on all ips and let OS pick a free port
self._server_address = server_address
self._server_address = server_address or ('', 0)
self._server = None

def start(self):
Expand Down Expand Up @@ -187,11 +187,13 @@ class OidcAuthCodePkceAuthenticator(OidcAuthenticator):
AuthCodeResult = namedtuple("AuthCodeResult", ["auth_code", "nonce", "code_verifier", "redirect_uri"])
AccessTokenResult = namedtuple("AccessTokenResult", ["access_token", "id_token", "refresh_token"])

def __init__(self, client_id: str, oidc_discovery_url: str, webbrowser_open: Callable = None, timeout=120):
def __init__(self, client_id: str, oidc_discovery_url: str, webbrowser_open: Callable = None, timeout=120,
server_address: Tuple[str, int] = None):
self._client_id = client_id
self._provider_info = requests.get(oidc_discovery_url).json()
self._webbrowser_open = webbrowser_open or webbrowser.open
self._authentication_timeout = timeout
self._server_address = server_address

@staticmethod
def hash_code_verifier(code: str) -> str:
Expand Down Expand Up @@ -222,7 +224,11 @@ def _get_auth_code(self) -> AuthCodeResult:
# Set up HTTP server (in separate thread) to catch OAuth redirect URL
callback_queue = Queue()
RequestHandlerClass = OAuthRedirectRequestHandler.with_queue(callback_queue)
with HttpServerThread(RequestHandlerClass=RequestHandlerClass) as http_server_thread:
http_server_thread = HttpServerThread(
RequestHandlerClass=RequestHandlerClass,
server_address=self._server_address
)
with http_server_thread:
port, host, fqdn = http_server_thread.server_address_info()
# TODO: use fully qualified domain name instead of "localhost"?
# Otherwise things won't work when the client is for example
Expand Down
7 changes: 5 additions & 2 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil
import sys
import warnings
from typing import Dict, List
from typing import Dict, List, Tuple
from urllib.parse import urljoin

import requests
Expand Down Expand Up @@ -197,14 +197,16 @@ def authenticate_basic(self, username: str, password: str) -> 'Connection':
self.auth = BearerAuth(bearer=resp["access_token"])
return self

def authenticate_OIDC(self, client_id: str, webbrowser_open=None, timeout=120) -> 'Connection':
def authenticate_OIDC(self, client_id: str, webbrowser_open=None, timeout=120,
server_address: Tuple[str, int] = None) -> 'Connection':
"""
Authenticates a user to the backend using OpenID Connect.
:param client_id: Client id to use for OpenID Connect authentication
:param webbrowser_open: optional handler for the initial OAuth authentication request
(opens a webbrowser by default)
:param timeout: number of seconds after which to abort the authentication procedure
:param server_address: optional tuple (hostname, port_number) to serve the OAuth redirect callback on
"""
# Local import to avoid importing the whole OpenID Connect dependency chain. TODO: just do global import?
from openeo.rest.auth.oidc import OidcAuthCodePkceAuthenticator
Expand All @@ -216,6 +218,7 @@ def authenticate_OIDC(self, client_id: str, webbrowser_open=None, timeout=120) -
oidc_discovery_url=oidc_discovery_url,
webbrowser_open=webbrowser_open,
timeout=timeout,
server_address=server_address,
)
# Do the Oauth/OpenID Connect flow and use the access token as bearer token.
tokens = authenticator.get_tokens()
Expand Down
15 changes: 15 additions & 0 deletions tests/rest/auth/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def test_http_server_thread():
server_thread.join()


def test_http_server_thread_port():
queue = Queue()
server_thread = HttpServerThread(RequestHandlerClass=QueuingRequestHandler.with_queue(queue),
server_address=('', 12345))
server_thread.start()
port, host, fqdn = server_thread.server_address_info()
assert port == 12345
url = 'http://{f}:{p}/foo/bar'.format(f=fqdn, p=port)
response = requests.get(url)
response.raise_for_status()
assert list(drain_queue(queue)) == ['/foo/bar']
server_thread.shutdown()
server_thread.join()


def test_oidc_flow(oidc_test_setup):
# see test/rest/conftest.py for `oidc_test_setup` fixture
client_id = "myclient"
Expand Down

0 comments on commit cd204f4

Please sign in to comment.