From a45e92a00031d07ff2cfff31c6176dd9adf27f6e Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Thu, 8 Aug 2024 09:35:20 +0200 Subject: [PATCH] Apply Python 3 changes to grid_openid allowing test instantiation. Make use of a need to cover the basic operation of this code for some minimal changes that allow the server it implements to be trivial to instantiate directly as an object. Use this ability to do so under test and add the necessary wiring to issue a single top-level request from the server and validate it responds with no internal issues logged. Note that a top-level request was not previously succeedding due to certain values being errantly dereferenced (though the code was clearly written to have been safe under these circumstances). Fix this here. --- align requirements with extant reality ensure server is closed between tests and validate with a status 404 case pull the test synchronisation up a level thus out of the server logic ensure the server is fully managed inside test case arranged thread improve naming make the class configurable an extend documentation inline escape_hmtl fixup fixup fixup fixup remove unnecessary version swizzle make use of the same Py3 compatible openid library as rocky 9 fixup --- mig/server/grid_openid.py | 305 +++++++++++++++++++-------- mig/shared/auth.py | 5 +- requirements.txt | 2 + tests/test_mig_server_grid_openid.py | 82 +++++++ 4 files changed, 299 insertions(+), 95 deletions(-) create mode 100644 tests/test_mig_server_grid_openid.py diff --git a/mig/server/grid_openid.py b/mig/server/grid_openid.py index aad04742e..a77472600 100755 --- a/mig/server/grid_openid.py +++ b/mig/server/grid_openid.py @@ -58,13 +58,14 @@ from __future__ import print_function from __future__ import absolute_import -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from SocketServer import ThreadingMixIn +from http.cookies import SimpleCookie, BaseCookie, CookieError +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn import base64 -import Cookie import cgi import cgitb +import codecs import os import re import socket @@ -87,6 +88,7 @@ from mig.shared.accountstate import check_account_accessible from mig.shared.base import client_id_dir, cert_field_map from mig.shared.conf import get_configuration_object +from mig.shared.compat import PY2 from mig.shared.griddaemons.openid import default_max_user_hits, \ default_user_abuse_hits, default_proto_abuse_hits, \ default_username_validator, refresh_user_creds, update_login_map, \ @@ -99,14 +101,12 @@ valid_path, valid_ascii, valid_job_id, valid_base_url, valid_url, \ valid_complex_url, InputException from mig.shared.tlsserver import hardened_ssl_context -from mig.shared.url import urlparse, urlencode, check_local_site_url +from mig.shared.url import urlparse, urlencode, parse_qsl, check_local_site_url from mig.shared.useradm import get_openid_user_dn, check_password_scramble, \ check_hash from mig.shared.userdb import default_db_path from mig.shared.validstring import possible_user_id -configuration, logger = None, None - # Update with extra fields cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK', 'fullname': 'CN', 'o': 'O', 'ou': 'OU'}) @@ -120,9 +120,23 @@ pw_regexp = re.compile(pw_pattern) +if PY2: + def _ensure_encoded_string(chunk): + return chunk +else: + def _ensure_encoded_string(chunk): + return codecs.encode(chunk, 'utf8') + + +if PY2: + from cgi import escape as escape_html +else: + from html import escape as escape_html + + def quoteattr(val): """Escape string for safe printing""" - esc = cgi.escape(val, 1) + esc = escape_html(val, 1) return '"%s"' % (esc,) @@ -201,7 +215,7 @@ def filter_why_pw(configuration, why): return 'Unexpected %s to filter for safe log and output.' % why_type -def lookup_full_user(username): +def lookup_full_user(configuration, username): """Look up the full user identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database. @@ -225,7 +239,7 @@ def lookup_full_user(username): return (username, {}) -def lookup_full_identity(username): +def lookup_full_identity(configuration, username): """Look up the full identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database @@ -236,7 +250,7 @@ def lookup_full_identity(username): """ # print "DEBUG: lookup full ID for %s" % username - return lookup_full_user(username)[0] + return lookup_full_user(configuration, username)[0] class OpenIDHTTPServer(HTTPServer): @@ -253,8 +267,15 @@ class OpenIDHTTPServer(HTTPServer): # any gain and it potentially introduces a race hash_cache, scramble_cache = None, None - def __init__(self, *args, **kwargs): - HTTPServer.__init__(self, *args, **kwargs) + def __init__(self, configuration, **kwargs): + self.configuration = configuration + self._on_start = kwargs.pop('on_start', lambda _: None) + + address = configuration.daemon_conf['address'] + port = configuration.daemon_conf['port'] + + addr = (address, port) + HTTPServer.__init__(self, addr, OpenIDRequestHandler, **kwargs) fqdn = self.server_name port = self.server_port @@ -277,7 +298,6 @@ def __init__(self, *args, **kwargs): # We serve from sub dir to ease targeted proxying self.server_base = 'openid' self.base_url += "%s/" % self.server_base - self.openid = None self.approved = {} self.lastCheckIDRequest = {} @@ -294,8 +314,19 @@ def __init__(self, *args, **kwargs): cert_field_aliases[name].append(target) # print "DEBUG: cert field aliases: %s" % cert_field_aliases + # Instantiate OpenID consumer store and OpenID consumer. If you + # were connecting to a database, you would create the database + # connection and instantiate an appropriate store here. + data_path = configuration.openid_store + store = FileOpenIDStore(data_path) + oidserver = server.Server(store, self.base_url + 'openidserver') + def expire_volatile(self): """Expire old entries in the volatile helper dictionaries""" + + configuration = self.configuration + logger = self.logger + if self.last_expire + self.min_expire_delay < time.time(): self.last_expire = time.time() expire_rate_limit(configuration, "openid", @@ -306,9 +337,9 @@ def expire_volatile(self): self.scramble_cache.clear() logger.debug("Expired old rate limits and scramble cache") - def setOpenIDServer(self, oidserver): - """Override openid attribute""" - self.openid = oidserver + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): @@ -316,7 +347,7 @@ class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): pass -class ServerHandler(BaseHTTPRequestHandler): +class OpenIDRequestHandler(BaseHTTPRequestHandler): """Override BaseHTTPRequestHandler to handle OpenID protocol""" # Input validation helper which must hold validators for all valid query @@ -359,8 +390,10 @@ class ServerHandler(BaseHTTPRequestHandler): } def __init__(self, *args, **kwargs): - if configuration.daemon_conf['session_ttl'] > 0: - self.session_ttl = configuration.daemon_conf['session_ttl'] + self.server = args[2] + + if self.daemon_conf['session_ttl'] > 0: + self.session_ttl = self.daemon_conf['session_ttl'] else: self.session_ttl = 48 * 3600 @@ -371,19 +404,32 @@ def __init__(self, *args, **kwargs): self.retry_url = '' BaseHTTPRequestHandler.__init__(self, *args, **kwargs) + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.configuration.daemon_conf['logger'] + def __retry_url_from_cookie(self): """Extract retry_url from cookie and validate""" + logger = self.logger retry_url = None try: cookies = self.headers.get('Cookie') - cookie = Cookie.SimpleCookie(cookies) + cookie = SimpleCookie(cookies) cookie_dict = dict((k, v.value) for k, v in cookie.iteritems()) retry_url = cookie_dict.get('retry_url', '') if retry_url and retry_url.startswith("http"): raise InputException("invalid retry_url: %s" % retry_url) elif retry_url: valid_url(retry_url) - except Cookie.CookieError as err: + except CookieError as err: retry_url = None logger.error("found invalid cookie: %s" % err) except InputException as exc: @@ -402,12 +448,17 @@ def clearUser(self): def do_GET(self): """Handle all HTTP GET requests""" + + configuration = self.configuration + logger = self.logger + # Make sure key is always available for exception handler key = 'UNSET' + try: self.parsed_uri = urlparse(self.path) self.query = {} - for (key, val) in cgi.parse_qsl(self.parsed_uri[4]): + for (key, val) in parse_qsl(self.parsed_uri[4]): # print "DEBUG: checking input arg %s: '%s'" % (key, val) validate_helper = self.validators.get(key, invalid_argument) # Let validation errors pass to general exception handler below @@ -474,9 +525,12 @@ def do_GET(self): logger.error("do_GET with ref %d crashed: %s" % (error_ref, exc)) # Do not disclose internal details filtered_exc = cgitb.text(sys.exc_info(), context=10) - # IMPORTANT: do NOT ever print or log raw password - pw = self.password or '' - filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + + if self.password: + # IMPORTANT: do NOT ever print or log raw password + pw = self.password + filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + logger.debug("Traceback %s: %s" % (error_ref, filtered_exc)) err_msg = """

Internal error while handling your request - please contact the site support @@ -489,6 +543,10 @@ def do_GET(self): def do_POST(self): """Handle all HTTP POST requests""" + + configuration = self.configuration + logger = self.logger + try: self.parsed_uri = urlparse(self.path) @@ -563,6 +621,11 @@ def handleAllow(self, query): Must verify user is already logged in or validate username/password pair against user DB. """ + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + # Use client address directly but with optional local proxy override hashed_secret = None exceeded_rate_limit = False @@ -570,7 +633,6 @@ def handleAllow(self, query): invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -628,7 +690,7 @@ def handleAllow(self, query): if request.idSelect(): # Do any ID expansion to a specified format if daemon_conf['expandusername']: - user_id = lookup_full_identity(query.get('identifier', '')) + user_id = lookup_full_identity(configuration, query.get('identifier', '')) else: user_id = query.get('identifier', '') identity = self.server.base_url + 'id/' + user_id @@ -654,7 +716,7 @@ def handleAllow(self, query): self.password = None account_accessible = check_account_accessible( - configuration, self.user, 'openid') + self.configuration, self.user, 'openid') # NOTE: returns None for invalid user, and boolean otherwise accepted = self.checkLogin(self.user, self.password, client_ip) if accepted is None: @@ -665,7 +727,7 @@ def handleAllow(self, query): # Update rate limits and write to auth log (authorized, _) = validate_auth_attempt( - configuration, + self.configuration, 'openid', 'password', self.user, @@ -721,17 +783,20 @@ def handleAllow(self, query): def setUser(self): """Read any saved user value from cookie""" + + logger = self.logger + cookies = self.headers.get('Cookie') # print "found cookies: %s" % cookies if cookies: - morsel = Cookie.BaseCookie(cookies).get('user') + morsel = BaseCookie(cookies).get('user') # Added morsel value check here since IE sends empty string from # cookie after initial user=;expire is sent. Others leave it out. if morsel and morsel.value != '': self.user = morsel.value expire = int(time.time() + self.session_ttl) - morsel = Cookie.BaseCookie(cookies).get('session_expire') + morsel = BaseCookie(cookies).get('session_expire') if morsel and morsel.value != '': # print "found user session_expire value: %s" % morsel.value if morsel.value.isdigit() and int(morsel.value) <= expire: @@ -754,6 +819,10 @@ def isAuthorized(self, identity_url, trust_root): def serverEndPoint(self, query): """End-point handler""" + + configuration = self.configuration + logger = self.logger + try: request = self.server.openid.decodeRequest(query) # Pass any errors from previous login attempts on for display @@ -794,9 +863,10 @@ def addSRegResponse(self, request, response): return sreg_req = sreg.SRegRequest.fromOpenIDRequest(request) - (username, user) = lookup_full_user(self.user) + (username, user) = lookup_full_user(self.configuration, self.user) if not user: + logger = self.logger logger.warning("addSRegResponse user lookup failed!") return @@ -835,6 +905,9 @@ def rejected(self, request, identifier=None): def handleCheckIDRequest(self, request): """Check ID handler""" + + logger = self.logger + logger.debug("handleCheckIDRequest with req %s" % request) is_authorized = self.isAuthorized(request.identity, request.trust_root) if is_authorized: @@ -851,6 +924,10 @@ def handleCheckIDRequest(self, request): def displayResponse(self, response): """Response helper""" + + configuration = self.configuration + logger = self.logger + try: webresponse = self.server.openid.encodeResponse(response) except server.EncodingError as why: @@ -864,7 +941,7 @@ def displayResponse(self, response): inconsistent session state.

Error details:

%s
-''' % (configuration.short_title, cgi.escape(safe_why))) +''' % (configuration.short_title, escape_html(safe_why))) return self.send_response(webresponse.code) @@ -883,6 +960,9 @@ def checkLogin(self, username, password, addr): None if no such user was found or username is invalid. """ + configuration = self.configuration + logger = self.logger + # Only need to update users here changed_users = [] if possible_user_id(configuration, username): @@ -943,13 +1023,17 @@ def checkLogin(self, username, password, addr): def doLogin(self): """Login handler""" + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + hashed_secret = None exceeded_rate_limit = False invalid_username = False invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -1049,6 +1133,9 @@ def doLogin(self): def doLogout(self): """Logout handler""" + + logger = self.logger + logger.debug("logout clearing user %s" % self.user) self.clearUser() return_to_url = self.query.get('return_to', None) @@ -1058,7 +1145,9 @@ def doLogout(self): def redirect(self, url): """Redirect helper with built-in check for safe destination URL""" - if url and not check_local_site_url(configuration, url): + logger = self.logger + + if url and not check_local_site_url(self.configuration, url): logger.error("reject redirect to external URL %r" % url) self.send_response(400) else: @@ -1071,6 +1160,9 @@ def redirect(self, url): def writeUserHeader(self): """Response helper""" + + logger = self.logger + # NOTE: we added secure and httponly flags as suggested by OpenVAS # NOTE: we need to set empty user cookie for logout to work @@ -1104,7 +1196,7 @@ def showAboutPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) def term(url, text): @@ -1136,7 +1228,7 @@ def showPingPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) # IMPORTANT: This is the format availability checker looks for. @@ -1189,6 +1281,9 @@ def showErrorPage(self, error_message, error_code=400): def showDecidePage(self, request): """Decide page provider""" + + configuration = self.configuration + id_url_base = self.server.base_url + 'id/' # XXX: This may break if there are any synonyms for id_url_base, # such as referring to it by IP address or a CNAME. @@ -1331,6 +1426,9 @@ def showDecidePage(self, request): def showIdPage(self, path): """User info page provider""" + + logger = self.logger + link_tag = '' % \ self.server.base_url yadis_loc_tag = '' % \ @@ -1345,7 +1443,7 @@ def showIdPage(self, path): for (aident, trust_root) in self.server.approved: if aident == ident: trs = '
  • %s
  • \n' % \ - cgi.escape(trust_root) + escape_html(trust_root) approved_trust_roots.append(trs) else: logger.debug("Not disclosing trust roots for %s (active user %s)" @@ -1471,12 +1569,16 @@ def showLoginPage(self, success_to, fail_to, query):
    %s
    - ''' % (configuration.short_title, self.server.server_base, + ''' % (self.configuration.short_title, self.server.server_base, success_to, fail_to, err_msg)) def showPage(self, response_code, title, head_extras='', msg=None, err=None, form=None): """Show page helper""" + + configuration = self.configuration + logger = self.logger + if self.user is None: user_link = 'not logged in.' % \ self.server.server_base @@ -1556,10 +1658,11 @@ def showPage(self, response_code, title, self.send_header('Content-type', 'text/html') self.end_headers() page_template = openid_page_template(configuration, head_extras) - self.wfile.write(page_template % fill_helpers) + output_string = page_template % fill_helpers + self.wfile.write(_ensure_encoded_string(output_string)) -def limited_accept(self, *args, **kwargs): +def limited_accept(self, logger, *args, **kwargs): """Accepts a new connection from a remote client, and returns a tuple containing that new connection wrapped with a server-side SSL channel, and the address of the remote client. @@ -1573,6 +1676,7 @@ def limited_accept(self, *args, **kwargs): curl https://FQDN:PORT/ which should eventually show the page content. """ + newsock, addr = socket.socket.accept(self) # NOTE: fetch timeout from kwargs but with fall back to 10s # it must be short since server completely blocks here! @@ -1589,26 +1693,16 @@ def limited_accept(self, *args, **kwargs): def start_service(configuration): """Service launcher""" - host = configuration.user_openid_address - port = configuration.user_openid_port - data_path = configuration.openid_store + daemon_conf = configuration.daemon_conf - nossl = daemon_conf['nossl'] - addr = (host, port) + logger = configuration.logger + # TODO: is this threaded version robust enough (thread safety)? # OpenIDServer = OpenIDHTTPServer - OpenIDServer = ThreadedOpenIDHTTPServer - httpserver = OpenIDServer(addr, ServerHandler) - - # Instantiate OpenID consumer store and OpenID consumer. If you - # were connecting to a database, you would create the database - # connection and instantiate an appropriate store here. - store = FileOpenIDStore(data_path) - oidserver = server.Server(store, httpserver.base_url + 'openidserver') - - httpserver.setOpenIDServer(oidserver) + httpserver = ThreadedOpenIDHTTPServer(configuration) # Wrap in SSL if enabled + nossl = daemon_conf['nossl'] if nossl: logger.warning('Not wrapping connections in SSL - only for testing!') else: @@ -1625,8 +1719,12 @@ def start_service(configuration): server_side=True) # Override default SSLSocket accept function to inject timeout support # https://stackoverflow.com/questions/394770/override-a-method-at-instance-level/42154067#42154067 + + def logging_limited_accept(*args, **kwargs): + return limited_accept(logger, *args, **kwargs) + httpserver.socket.accept = types.MethodType( - limited_accept, httpserver.socket) + logging_limited_accept, httpserver.socket) serve_msg = 'Server running at: %s' % httpserver.base_url logger.info(serve_msg) @@ -1638,7 +1736,42 @@ def start_service(configuration): httpserver.expire_volatile() -if __name__ == '__main__': +def _extend_configuration(configuration, address, port, **kwargs): + configuration.daemon_conf = { + 'address': address, + 'port': port, + 'root_dir': os.path.abspath(configuration.user_home), + 'db_path': os.path.abspath(default_db_path(configuration)), + 'session_store': os.path.abspath(configuration.openid_store), + 'session_ttl': 24 * 3600, + 'allow_password': 'password' in configuration.user_openid_auth, + 'allow_digest': 'digest' in configuration.user_openid_auth, + 'allow_publickey': 'publickey' in configuration.user_openid_auth, + 'user_alias': configuration.user_openid_alias, + 'host_rsa_key': kwargs['host_rsa_key'], + 'users': [], + 'login_map': {}, + 'time_stamp': 0, + 'logger': kwargs['logger'], + 'nossl': kwargs['nossl'], + 'expandusername': kwargs['expandusername'], + 'show_address': kwargs['show_address'], + 'show_port': kwargs['show_port'], + # TODO: Add the following to configuration: + # max_openid_user_hits + # max_openid_user_abuse_hits + # max_openid_proto_abuse_hits + # max_openid_secret_hits + 'auth_limits': + {'max_user_hits': default_max_user_hits, + 'user_abuse_hits': default_user_abuse_hits, + 'proto_abuse_hits': default_proto_abuse_hits, + 'max_secret_hits': 1, + }, + } + + +def main(): # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) @@ -1681,9 +1814,7 @@ def start_service(configuration): unless it is available in mig/server/MiGserver.conf """) print(__doc__) - address = configuration.user_openid_address - port = configuration.user_openid_port - session_store = configuration.openid_store + default_host_key = """ -----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEA404IBMReHOdvhhJ5YtgquY3DNi0v0QwfPUk+EcH/CxFW8UCC @@ -1713,6 +1844,7 @@ def start_service(configuration): 3stfzMDGtKM9lntAsfFQ8n4yvvEbn/quEWad6srf1yxt9B4t5JA= -----END RSA PRIVATE KEY----- """ + try: host_key_fd = open(configuration.user_openid_key, 'r') host_rsa_key = host_key_fd.read() @@ -1720,38 +1852,21 @@ def start_service(configuration): except IOError: logger.info("No valid host key provided - using default") host_rsa_key = default_host_key - configuration.daemon_conf = { - 'address': address, - 'port': port, - 'root_dir': os.path.abspath(configuration.user_home), - 'db_path': os.path.abspath(default_db_path(configuration)), - 'session_store': os.path.abspath(configuration.openid_store), - 'session_ttl': 24 * 3600, - 'allow_password': 'password' in configuration.user_openid_auth, - 'allow_digest': 'digest' in configuration.user_openid_auth, - 'allow_publickey': 'publickey' in configuration.user_openid_auth, - 'user_alias': configuration.user_openid_alias, - 'host_rsa_key': host_rsa_key, - 'users': [], - 'login_map': {}, - 'time_stamp': 0, - 'logger': logger, - 'nossl': nossl, - 'expandusername': expandusername, - 'show_address': show_address, - 'show_port': show_port, - # TODO: Add the following to configuration: - # max_openid_user_hits - # max_openid_user_abuse_hits - # max_openid_proto_abuse_hits - # max_openid_secret_hits - 'auth_limits': - {'max_user_hits': default_max_user_hits, - 'user_abuse_hits': default_user_abuse_hits, - 'proto_abuse_hits': default_proto_abuse_hits, - 'max_secret_hits': 1, - }, - } + + address = configuration.user_openid_address + port = configuration.user_openid_port + _extend_configuration( + configuration, + address, + port, + logger=logger, + expandusername=False, + host_rsa_key=host_rsa_key, + nossl=True, + show_address=False, + show_port=False, + ) + logger.info("Starting OpenID server") info_msg = "Listening on address '%s' and port %d" % (address, port) logger.info(info_msg) @@ -1765,3 +1880,7 @@ def start_service(configuration): info_msg = "Leaving with no more workers active" logger.info(info_msg) print(info_msg) + + +if __name__ == '__main__': + main() diff --git a/mig/shared/auth.py b/mig/shared/auth.py index 54c73b57a..343d27ea9 100644 --- a/mig/shared/auth.py +++ b/mig/shared/auth.py @@ -29,11 +29,12 @@ from __future__ import absolute_import -import Cookie import base64 import glob +from http.cookies import SimpleCookie import os import re +import sys import time # Only needed for 2FA so ignore import error and only fail on use @@ -291,7 +292,7 @@ def client_twofactor_session(configuration, if configuration.site_enable_gdp: client_id = get_base_client_id(configuration, client_id, expand_oid_alias=False) - session_cookie = Cookie.SimpleCookie() + session_cookie = SimpleCookie() session_cookie.load(environ.get('HTTP_COOKIE', "")) session_cookie = session_cookie.get('2FA_Auth', None) if session_cookie is None: diff --git a/requirements.txt b/requirements.txt index 98ff33272..6e78004b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ # migrid core dependencies on a format suitable for pip install as described on # https://pip.pypa.io/en/stable/reference/requirement-specifiers/ future +python-openid;python_version < "3" +python3-openid;python_version >= "3" # NOTE: python-3.6 and earlier versions require older pyotp, whereas 3.7+ # should work with any modern version. We tested 2.9.0 to work. pyotp;python_version >= "3.7" diff --git a/tests/test_mig_server_grid_openid.py b/tests/test_mig_server_grid_openid.py new file mode 100644 index 000000000..415cfad68 --- /dev/null +++ b/tests/test_mig_server_grid_openid.py @@ -0,0 +1,82 @@ +from __future__ import print_function +import os +import sys +import threading + +from tests.support import PY2, MIG_BASE, MigTestCase, testmain, make_wrapped_server + +from mig.server.grid_openid import ThreadedOpenIDHTTPServer, _extend_configuration, main +from mig.shared.conf import get_configuration_object + +if PY2: + from urllib2 import HTTPError, urlopen +else: + from urllib.error import HTTPError + from urllib.request import urlopen + + +class MigServerGrid_openid(MigTestCase): + def before_each(self): + self.server_addr = None + self.server_thread = None + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + def _provide_configuration(self): + return 'testconfig' + + + def issue_request(self, request_path): + assert isinstance(request_path, str) and request_path.startswith('/'), "require http path starting with /" + request_url = ''.join(('http://', self.server_addr[0], ':', str(self.server_addr[1]), request_path)) + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + return (status, data) + except HTTPError as httpexc: + return (httpexc.code, None) + + def test_top_level_request_responds_status_ok(self): + self.server_addr = ('localhost', 4567) + configuration = self._make_configuration(self.configuration, self.logger, self.server_addr) + self.server_thread = self._make_server(configuration).start_wait_until_ready() + + status, _ = self.issue_request('/') + + self.assertEqual(status, 200) + + def test_unknown_request_responds_status_bad_request(self): + self.server_addr = ('localhost', 4567) + configuration = self._make_configuration(self.configuration, self.logger, self.server_addr) + self.server_thread = self._make_server(configuration).start_wait_until_ready() + + status, _ = self.issue_request('/foobar') + + self.assertEqual(status, 404) + + @staticmethod + def _make_configuration(configuration, test_logger, server_addr): + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration): + return make_wrapped_server(ThreadedOpenIDHTTPServer, configuration) + + +if __name__ == '__main__': + testmain()