diff --git a/mig/server/grid_openid.py b/mig/server/grid_openid.py index aad04742e..a77472600 100755 --- a/mig/server/grid_openid.py +++ b/mig/server/grid_openid.py @@ -58,13 +58,14 @@ from __future__ import print_function from __future__ import absolute_import -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from SocketServer import ThreadingMixIn +from http.cookies import SimpleCookie, BaseCookie, CookieError +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn import base64 -import Cookie import cgi import cgitb +import codecs import os import re import socket @@ -87,6 +88,7 @@ from mig.shared.accountstate import check_account_accessible from mig.shared.base import client_id_dir, cert_field_map from mig.shared.conf import get_configuration_object +from mig.shared.compat import PY2 from mig.shared.griddaemons.openid import default_max_user_hits, \ default_user_abuse_hits, default_proto_abuse_hits, \ default_username_validator, refresh_user_creds, update_login_map, \ @@ -99,14 +101,12 @@ valid_path, valid_ascii, valid_job_id, valid_base_url, valid_url, \ valid_complex_url, InputException from mig.shared.tlsserver import hardened_ssl_context -from mig.shared.url import urlparse, urlencode, check_local_site_url +from mig.shared.url import urlparse, urlencode, parse_qsl, check_local_site_url from mig.shared.useradm import get_openid_user_dn, check_password_scramble, \ check_hash from mig.shared.userdb import default_db_path from mig.shared.validstring import possible_user_id -configuration, logger = None, None - # Update with extra fields cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK', 'fullname': 'CN', 'o': 'O', 'ou': 'OU'}) @@ -120,9 +120,23 @@ pw_regexp = re.compile(pw_pattern) +if PY2: + def _ensure_encoded_string(chunk): + return chunk +else: + def _ensure_encoded_string(chunk): + return codecs.encode(chunk, 'utf8') + + +if PY2: + from cgi import escape as escape_html +else: + from html import escape as escape_html + + def quoteattr(val): """Escape string for safe printing""" - esc = cgi.escape(val, 1) + esc = escape_html(val, 1) return '"%s"' % (esc,) @@ -201,7 +215,7 @@ def filter_why_pw(configuration, why): return 'Unexpected %s to filter for safe log and output.' % why_type -def lookup_full_user(username): +def lookup_full_user(configuration, username): """Look up the full user identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database. @@ -225,7 +239,7 @@ def lookup_full_user(username): return (username, {}) -def lookup_full_identity(username): +def lookup_full_identity(configuration, username): """Look up the full identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database @@ -236,7 +250,7 @@ def lookup_full_identity(username): """ # print "DEBUG: lookup full ID for %s" % username - return lookup_full_user(username)[0] + return lookup_full_user(configuration, username)[0] class OpenIDHTTPServer(HTTPServer): @@ -253,8 +267,15 @@ class OpenIDHTTPServer(HTTPServer): # any gain and it potentially introduces a race hash_cache, scramble_cache = None, None - def __init__(self, *args, **kwargs): - HTTPServer.__init__(self, *args, **kwargs) + def __init__(self, configuration, **kwargs): + self.configuration = configuration + self._on_start = kwargs.pop('on_start', lambda _: None) + + address = configuration.daemon_conf['address'] + port = configuration.daemon_conf['port'] + + addr = (address, port) + HTTPServer.__init__(self, addr, OpenIDRequestHandler, **kwargs) fqdn = self.server_name port = self.server_port @@ -277,7 +298,6 @@ def __init__(self, *args, **kwargs): # We serve from sub dir to ease targeted proxying self.server_base = 'openid' self.base_url += "%s/" % self.server_base - self.openid = None self.approved = {} self.lastCheckIDRequest = {} @@ -294,8 +314,19 @@ def __init__(self, *args, **kwargs): cert_field_aliases[name].append(target) # print "DEBUG: cert field aliases: %s" % cert_field_aliases + # Instantiate OpenID consumer store and OpenID consumer. If you + # were connecting to a database, you would create the database + # connection and instantiate an appropriate store here. + data_path = configuration.openid_store + store = FileOpenIDStore(data_path) + oidserver = server.Server(store, self.base_url + 'openidserver') + def expire_volatile(self): """Expire old entries in the volatile helper dictionaries""" + + configuration = self.configuration + logger = self.logger + if self.last_expire + self.min_expire_delay < time.time(): self.last_expire = time.time() expire_rate_limit(configuration, "openid", @@ -306,9 +337,9 @@ def expire_volatile(self): self.scramble_cache.clear() logger.debug("Expired old rate limits and scramble cache") - def setOpenIDServer(self, oidserver): - """Override openid attribute""" - self.openid = oidserver + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): @@ -316,7 +347,7 @@ class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): pass -class ServerHandler(BaseHTTPRequestHandler): +class OpenIDRequestHandler(BaseHTTPRequestHandler): """Override BaseHTTPRequestHandler to handle OpenID protocol""" # Input validation helper which must hold validators for all valid query @@ -359,8 +390,10 @@ class ServerHandler(BaseHTTPRequestHandler): } def __init__(self, *args, **kwargs): - if configuration.daemon_conf['session_ttl'] > 0: - self.session_ttl = configuration.daemon_conf['session_ttl'] + self.server = args[2] + + if self.daemon_conf['session_ttl'] > 0: + self.session_ttl = self.daemon_conf['session_ttl'] else: self.session_ttl = 48 * 3600 @@ -371,19 +404,32 @@ def __init__(self, *args, **kwargs): self.retry_url = '' BaseHTTPRequestHandler.__init__(self, *args, **kwargs) + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.configuration.daemon_conf['logger'] + def __retry_url_from_cookie(self): """Extract retry_url from cookie and validate""" + logger = self.logger retry_url = None try: cookies = self.headers.get('Cookie') - cookie = Cookie.SimpleCookie(cookies) + cookie = SimpleCookie(cookies) cookie_dict = dict((k, v.value) for k, v in cookie.iteritems()) retry_url = cookie_dict.get('retry_url', '') if retry_url and retry_url.startswith("http"): raise InputException("invalid retry_url: %s" % retry_url) elif retry_url: valid_url(retry_url) - except Cookie.CookieError as err: + except CookieError as err: retry_url = None logger.error("found invalid cookie: %s" % err) except InputException as exc: @@ -402,12 +448,17 @@ def clearUser(self): def do_GET(self): """Handle all HTTP GET requests""" + + configuration = self.configuration + logger = self.logger + # Make sure key is always available for exception handler key = 'UNSET' + try: self.parsed_uri = urlparse(self.path) self.query = {} - for (key, val) in cgi.parse_qsl(self.parsed_uri[4]): + for (key, val) in parse_qsl(self.parsed_uri[4]): # print "DEBUG: checking input arg %s: '%s'" % (key, val) validate_helper = self.validators.get(key, invalid_argument) # Let validation errors pass to general exception handler below @@ -474,9 +525,12 @@ def do_GET(self): logger.error("do_GET with ref %d crashed: %s" % (error_ref, exc)) # Do not disclose internal details filtered_exc = cgitb.text(sys.exc_info(), context=10) - # IMPORTANT: do NOT ever print or log raw password - pw = self.password or '' - filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + + if self.password: + # IMPORTANT: do NOT ever print or log raw password + pw = self.password + filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + logger.debug("Traceback %s: %s" % (error_ref, filtered_exc)) err_msg = """

Internal error while handling your request - please contact the site support @@ -489,6 +543,10 @@ def do_GET(self): def do_POST(self): """Handle all HTTP POST requests""" + + configuration = self.configuration + logger = self.logger + try: self.parsed_uri = urlparse(self.path) @@ -563,6 +621,11 @@ def handleAllow(self, query): Must verify user is already logged in or validate username/password pair against user DB. """ + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + # Use client address directly but with optional local proxy override hashed_secret = None exceeded_rate_limit = False @@ -570,7 +633,6 @@ def handleAllow(self, query): invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -628,7 +690,7 @@ def handleAllow(self, query): if request.idSelect(): # Do any ID expansion to a specified format if daemon_conf['expandusername']: - user_id = lookup_full_identity(query.get('identifier', '')) + user_id = lookup_full_identity(configuration, query.get('identifier', '')) else: user_id = query.get('identifier', '') identity = self.server.base_url + 'id/' + user_id @@ -654,7 +716,7 @@ def handleAllow(self, query): self.password = None account_accessible = check_account_accessible( - configuration, self.user, 'openid') + self.configuration, self.user, 'openid') # NOTE: returns None for invalid user, and boolean otherwise accepted = self.checkLogin(self.user, self.password, client_ip) if accepted is None: @@ -665,7 +727,7 @@ def handleAllow(self, query): # Update rate limits and write to auth log (authorized, _) = validate_auth_attempt( - configuration, + self.configuration, 'openid', 'password', self.user, @@ -721,17 +783,20 @@ def handleAllow(self, query): def setUser(self): """Read any saved user value from cookie""" + + logger = self.logger + cookies = self.headers.get('Cookie') # print "found cookies: %s" % cookies if cookies: - morsel = Cookie.BaseCookie(cookies).get('user') + morsel = BaseCookie(cookies).get('user') # Added morsel value check here since IE sends empty string from # cookie after initial user=;expire is sent. Others leave it out. if morsel and morsel.value != '': self.user = morsel.value expire = int(time.time() + self.session_ttl) - morsel = Cookie.BaseCookie(cookies).get('session_expire') + morsel = BaseCookie(cookies).get('session_expire') if morsel and morsel.value != '': # print "found user session_expire value: %s" % morsel.value if morsel.value.isdigit() and int(morsel.value) <= expire: @@ -754,6 +819,10 @@ def isAuthorized(self, identity_url, trust_root): def serverEndPoint(self, query): """End-point handler""" + + configuration = self.configuration + logger = self.logger + try: request = self.server.openid.decodeRequest(query) # Pass any errors from previous login attempts on for display @@ -794,9 +863,10 @@ def addSRegResponse(self, request, response): return sreg_req = sreg.SRegRequest.fromOpenIDRequest(request) - (username, user) = lookup_full_user(self.user) + (username, user) = lookup_full_user(self.configuration, self.user) if not user: + logger = self.logger logger.warning("addSRegResponse user lookup failed!") return @@ -835,6 +905,9 @@ def rejected(self, request, identifier=None): def handleCheckIDRequest(self, request): """Check ID handler""" + + logger = self.logger + logger.debug("handleCheckIDRequest with req %s" % request) is_authorized = self.isAuthorized(request.identity, request.trust_root) if is_authorized: @@ -851,6 +924,10 @@ def handleCheckIDRequest(self, request): def displayResponse(self, response): """Response helper""" + + configuration = self.configuration + logger = self.logger + try: webresponse = self.server.openid.encodeResponse(response) except server.EncodingError as why: @@ -864,7 +941,7 @@ def displayResponse(self, response): inconsistent session state.

Error details:

%s
-''' % (configuration.short_title, cgi.escape(safe_why))) +''' % (configuration.short_title, escape_html(safe_why))) return self.send_response(webresponse.code) @@ -883,6 +960,9 @@ def checkLogin(self, username, password, addr): None if no such user was found or username is invalid. """ + configuration = self.configuration + logger = self.logger + # Only need to update users here changed_users = [] if possible_user_id(configuration, username): @@ -943,13 +1023,17 @@ def checkLogin(self, username, password, addr): def doLogin(self): """Login handler""" + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + hashed_secret = None exceeded_rate_limit = False invalid_username = False invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -1049,6 +1133,9 @@ def doLogin(self): def doLogout(self): """Logout handler""" + + logger = self.logger + logger.debug("logout clearing user %s" % self.user) self.clearUser() return_to_url = self.query.get('return_to', None) @@ -1058,7 +1145,9 @@ def doLogout(self): def redirect(self, url): """Redirect helper with built-in check for safe destination URL""" - if url and not check_local_site_url(configuration, url): + logger = self.logger + + if url and not check_local_site_url(self.configuration, url): logger.error("reject redirect to external URL %r" % url) self.send_response(400) else: @@ -1071,6 +1160,9 @@ def redirect(self, url): def writeUserHeader(self): """Response helper""" + + logger = self.logger + # NOTE: we added secure and httponly flags as suggested by OpenVAS # NOTE: we need to set empty user cookie for logout to work @@ -1104,7 +1196,7 @@ def showAboutPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) def term(url, text): @@ -1136,7 +1228,7 @@ def showPingPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) # IMPORTANT: This is the format availability checker looks for. @@ -1189,6 +1281,9 @@ def showErrorPage(self, error_message, error_code=400): def showDecidePage(self, request): """Decide page provider""" + + configuration = self.configuration + id_url_base = self.server.base_url + 'id/' # XXX: This may break if there are any synonyms for id_url_base, # such as referring to it by IP address or a CNAME. @@ -1331,6 +1426,9 @@ def showDecidePage(self, request): def showIdPage(self, path): """User info page provider""" + + logger = self.logger + link_tag = '' % \ self.server.base_url yadis_loc_tag = '' % \ @@ -1345,7 +1443,7 @@ def showIdPage(self, path): for (aident, trust_root) in self.server.approved: if aident == ident: trs = '
  • %s
  • \n' % \ - cgi.escape(trust_root) + escape_html(trust_root) approved_trust_roots.append(trs) else: logger.debug("Not disclosing trust roots for %s (active user %s)" @@ -1471,12 +1569,16 @@ def showLoginPage(self, success_to, fail_to, query):
    %s
    - ''' % (configuration.short_title, self.server.server_base, + ''' % (self.configuration.short_title, self.server.server_base, success_to, fail_to, err_msg)) def showPage(self, response_code, title, head_extras='', msg=None, err=None, form=None): """Show page helper""" + + configuration = self.configuration + logger = self.logger + if self.user is None: user_link = 'not logged in.' % \ self.server.server_base @@ -1556,10 +1658,11 @@ def showPage(self, response_code, title, self.send_header('Content-type', 'text/html') self.end_headers() page_template = openid_page_template(configuration, head_extras) - self.wfile.write(page_template % fill_helpers) + output_string = page_template % fill_helpers + self.wfile.write(_ensure_encoded_string(output_string)) -def limited_accept(self, *args, **kwargs): +def limited_accept(self, logger, *args, **kwargs): """Accepts a new connection from a remote client, and returns a tuple containing that new connection wrapped with a server-side SSL channel, and the address of the remote client. @@ -1573,6 +1676,7 @@ def limited_accept(self, *args, **kwargs): curl https://FQDN:PORT/ which should eventually show the page content. """ + newsock, addr = socket.socket.accept(self) # NOTE: fetch timeout from kwargs but with fall back to 10s # it must be short since server completely blocks here! @@ -1589,26 +1693,16 @@ def limited_accept(self, *args, **kwargs): def start_service(configuration): """Service launcher""" - host = configuration.user_openid_address - port = configuration.user_openid_port - data_path = configuration.openid_store + daemon_conf = configuration.daemon_conf - nossl = daemon_conf['nossl'] - addr = (host, port) + logger = configuration.logger + # TODO: is this threaded version robust enough (thread safety)? # OpenIDServer = OpenIDHTTPServer - OpenIDServer = ThreadedOpenIDHTTPServer - httpserver = OpenIDServer(addr, ServerHandler) - - # Instantiate OpenID consumer store and OpenID consumer. If you - # were connecting to a database, you would create the database - # connection and instantiate an appropriate store here. - store = FileOpenIDStore(data_path) - oidserver = server.Server(store, httpserver.base_url + 'openidserver') - - httpserver.setOpenIDServer(oidserver) + httpserver = ThreadedOpenIDHTTPServer(configuration) # Wrap in SSL if enabled + nossl = daemon_conf['nossl'] if nossl: logger.warning('Not wrapping connections in SSL - only for testing!') else: @@ -1625,8 +1719,12 @@ def start_service(configuration): server_side=True) # Override default SSLSocket accept function to inject timeout support # https://stackoverflow.com/questions/394770/override-a-method-at-instance-level/42154067#42154067 + + def logging_limited_accept(*args, **kwargs): + return limited_accept(logger, *args, **kwargs) + httpserver.socket.accept = types.MethodType( - limited_accept, httpserver.socket) + logging_limited_accept, httpserver.socket) serve_msg = 'Server running at: %s' % httpserver.base_url logger.info(serve_msg) @@ -1638,7 +1736,42 @@ def start_service(configuration): httpserver.expire_volatile() -if __name__ == '__main__': +def _extend_configuration(configuration, address, port, **kwargs): + configuration.daemon_conf = { + 'address': address, + 'port': port, + 'root_dir': os.path.abspath(configuration.user_home), + 'db_path': os.path.abspath(default_db_path(configuration)), + 'session_store': os.path.abspath(configuration.openid_store), + 'session_ttl': 24 * 3600, + 'allow_password': 'password' in configuration.user_openid_auth, + 'allow_digest': 'digest' in configuration.user_openid_auth, + 'allow_publickey': 'publickey' in configuration.user_openid_auth, + 'user_alias': configuration.user_openid_alias, + 'host_rsa_key': kwargs['host_rsa_key'], + 'users': [], + 'login_map': {}, + 'time_stamp': 0, + 'logger': kwargs['logger'], + 'nossl': kwargs['nossl'], + 'expandusername': kwargs['expandusername'], + 'show_address': kwargs['show_address'], + 'show_port': kwargs['show_port'], + # TODO: Add the following to configuration: + # max_openid_user_hits + # max_openid_user_abuse_hits + # max_openid_proto_abuse_hits + # max_openid_secret_hits + 'auth_limits': + {'max_user_hits': default_max_user_hits, + 'user_abuse_hits': default_user_abuse_hits, + 'proto_abuse_hits': default_proto_abuse_hits, + 'max_secret_hits': 1, + }, + } + + +def main(): # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) @@ -1681,9 +1814,7 @@ def start_service(configuration): unless it is available in mig/server/MiGserver.conf """) print(__doc__) - address = configuration.user_openid_address - port = configuration.user_openid_port - session_store = configuration.openid_store + default_host_key = """ -----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEA404IBMReHOdvhhJ5YtgquY3DNi0v0QwfPUk+EcH/CxFW8UCC @@ -1713,6 +1844,7 @@ def start_service(configuration): 3stfzMDGtKM9lntAsfFQ8n4yvvEbn/quEWad6srf1yxt9B4t5JA= -----END RSA PRIVATE KEY----- """ + try: host_key_fd = open(configuration.user_openid_key, 'r') host_rsa_key = host_key_fd.read() @@ -1720,38 +1852,21 @@ def start_service(configuration): except IOError: logger.info("No valid host key provided - using default") host_rsa_key = default_host_key - configuration.daemon_conf = { - 'address': address, - 'port': port, - 'root_dir': os.path.abspath(configuration.user_home), - 'db_path': os.path.abspath(default_db_path(configuration)), - 'session_store': os.path.abspath(configuration.openid_store), - 'session_ttl': 24 * 3600, - 'allow_password': 'password' in configuration.user_openid_auth, - 'allow_digest': 'digest' in configuration.user_openid_auth, - 'allow_publickey': 'publickey' in configuration.user_openid_auth, - 'user_alias': configuration.user_openid_alias, - 'host_rsa_key': host_rsa_key, - 'users': [], - 'login_map': {}, - 'time_stamp': 0, - 'logger': logger, - 'nossl': nossl, - 'expandusername': expandusername, - 'show_address': show_address, - 'show_port': show_port, - # TODO: Add the following to configuration: - # max_openid_user_hits - # max_openid_user_abuse_hits - # max_openid_proto_abuse_hits - # max_openid_secret_hits - 'auth_limits': - {'max_user_hits': default_max_user_hits, - 'user_abuse_hits': default_user_abuse_hits, - 'proto_abuse_hits': default_proto_abuse_hits, - 'max_secret_hits': 1, - }, - } + + address = configuration.user_openid_address + port = configuration.user_openid_port + _extend_configuration( + configuration, + address, + port, + logger=logger, + expandusername=False, + host_rsa_key=host_rsa_key, + nossl=True, + show_address=False, + show_port=False, + ) + logger.info("Starting OpenID server") info_msg = "Listening on address '%s' and port %d" % (address, port) logger.info(info_msg) @@ -1765,3 +1880,7 @@ def start_service(configuration): info_msg = "Leaving with no more workers active" logger.info(info_msg) print(info_msg) + + +if __name__ == '__main__': + main() diff --git a/mig/shared/auth.py b/mig/shared/auth.py index 54c73b57a..343d27ea9 100644 --- a/mig/shared/auth.py +++ b/mig/shared/auth.py @@ -29,11 +29,12 @@ from __future__ import absolute_import -import Cookie import base64 import glob +from http.cookies import SimpleCookie import os import re +import sys import time # Only needed for 2FA so ignore import error and only fail on use @@ -291,7 +292,7 @@ def client_twofactor_session(configuration, if configuration.site_enable_gdp: client_id = get_base_client_id(configuration, client_id, expand_oid_alias=False) - session_cookie = Cookie.SimpleCookie() + session_cookie = SimpleCookie() session_cookie.load(environ.get('HTTP_COOKIE', "")) session_cookie = session_cookie.get('2FA_Auth', None) if session_cookie is None: diff --git a/requirements.txt b/requirements.txt index 98ff33272..6e78004b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ # migrid core dependencies on a format suitable for pip install as described on # https://pip.pypa.io/en/stable/reference/requirement-specifiers/ future +python-openid;python_version < "3" +python3-openid;python_version >= "3" # NOTE: python-3.6 and earlier versions require older pyotp, whereas 3.7+ # should work with any modern version. We tested 2.9.0 to work. pyotp;python_version >= "3.7" diff --git a/tests/test_mig_server_grid_openid.py b/tests/test_mig_server_grid_openid.py new file mode 100644 index 000000000..415cfad68 --- /dev/null +++ b/tests/test_mig_server_grid_openid.py @@ -0,0 +1,82 @@ +from __future__ import print_function +import os +import sys +import threading + +from tests.support import PY2, MIG_BASE, MigTestCase, testmain, make_wrapped_server + +from mig.server.grid_openid import ThreadedOpenIDHTTPServer, _extend_configuration, main +from mig.shared.conf import get_configuration_object + +if PY2: + from urllib2 import HTTPError, urlopen +else: + from urllib.error import HTTPError + from urllib.request import urlopen + + +class MigServerGrid_openid(MigTestCase): + def before_each(self): + self.server_addr = None + self.server_thread = None + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + def _provide_configuration(self): + return 'testconfig' + + + def issue_request(self, request_path): + assert isinstance(request_path, str) and request_path.startswith('/'), "require http path starting with /" + request_url = ''.join(('http://', self.server_addr[0], ':', str(self.server_addr[1]), request_path)) + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + return (status, data) + except HTTPError as httpexc: + return (httpexc.code, None) + + def test_top_level_request_responds_status_ok(self): + self.server_addr = ('localhost', 4567) + configuration = self._make_configuration(self.configuration, self.logger, self.server_addr) + self.server_thread = self._make_server(configuration).start_wait_until_ready() + + status, _ = self.issue_request('/') + + self.assertEqual(status, 200) + + def test_unknown_request_responds_status_bad_request(self): + self.server_addr = ('localhost', 4567) + configuration = self._make_configuration(self.configuration, self.logger, self.server_addr) + self.server_thread = self._make_server(configuration).start_wait_until_ready() + + status, _ = self.issue_request('/foobar') + + self.assertEqual(status, 404) + + @staticmethod + def _make_configuration(configuration, test_logger, server_addr): + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration): + return make_wrapped_server(ThreadedOpenIDHTTPServer, configuration) + + +if __name__ == '__main__': + testmain()