diff --git a/mig/server/grid_openid.py b/mig/server/grid_openid.py index ab83fd5bf..e1db1851d 100755 --- a/mig/server/grid_openid.py +++ b/mig/server/grid_openid.py @@ -58,17 +58,19 @@ from __future__ import print_function from __future__ import absolute_import -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from SocketServer import ThreadingMixIn +from http.cookies import SimpleCookie, CookieError +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn import base64 -import Cookie import cgi import cgitb +import codecs import os import re import socket import sys +import threading import time import types @@ -87,6 +89,7 @@ from mig.shared.accountstate import check_account_accessible from mig.shared.base import client_id_dir, cert_field_map from mig.shared.conf import get_configuration_object +from mig.shared.compat import PY2, escape_html from mig.shared.griddaemons.openid import default_max_user_hits, \ default_user_abuse_hits, default_proto_abuse_hits, \ default_username_validator, refresh_user_creds, update_login_map, \ @@ -99,14 +102,12 @@ valid_path, valid_ascii, valid_job_id, valid_base_url, valid_url, \ valid_complex_url, InputException from mig.shared.tlsserver import hardened_ssl_context -from mig.shared.url import urlparse, urlencode +from mig.shared.url import urlparse, urlencode, parse_qsl from mig.shared.useradm import get_openid_user_dn, check_password_scramble, \ check_hash from mig.shared.userdb import default_db_path from mig.shared.validstring import possible_user_id -configuration, logger = None, None - # Update with extra fields cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK', 'fullname': 'CN', 'o': 'O', 'ou': 'OU'}) @@ -120,9 +121,17 @@ pw_regexp = re.compile(pw_pattern) +if PY2: + def _ensure_encoded_string(chunk): + return chunk +else: + def _ensure_encoded_string(chunk): + return codecs.encode(chunk, 'utf8') + + def quoteattr(val): """Escape string for safe printing""" - esc = cgi.escape(val, 1) + esc = escape_html(val, 1) return '"%s"' % (esc,) @@ -246,8 +255,14 @@ class OpenIDHTTPServer(HTTPServer): # any gain and it potentially introduces a race hash_cache, scramble_cache = None, None - def __init__(self, *args, **kwargs): - HTTPServer.__init__(self, *args, **kwargs) + def __init__(self, configuration, **kwargs): + self.configuration = configuration + + address = configuration.daemon_conf['address'] + port = configuration.daemon_conf['port'] + + addr = (address, port) + HTTPServer.__init__(self, addr, OpenIDRequestHandler, **kwargs) fqdn = self.server_name port = self.server_port @@ -270,7 +285,6 @@ def __init__(self, *args, **kwargs): # We serve from sub dir to ease targeted proxying self.server_base = 'openid' self.base_url += "%s/" % self.server_base - self.openid = None self.approved = {} self.lastCheckIDRequest = {} @@ -287,6 +301,13 @@ def __init__(self, *args, **kwargs): cert_field_aliases[name].append(target) # print "DEBUG: cert field aliases: %s" % cert_field_aliases + # Instantiate OpenID consumer store and OpenID consumer. If you + # were connecting to a database, you would create the database + # connection and instantiate an appropriate store here. + data_path = configuration.openid_store + store = FileOpenIDStore(data_path) + oidserver = server.Server(store, self.base_url + 'openidserver') + def expire_volatile(self): """Expire old entries in the volatile helper dictionaries""" if self.last_expire + self.min_expire_delay < time.time(): @@ -299,17 +320,24 @@ def expire_volatile(self): self.scramble_cache.clear() logger.debug("Expired old rate limits and scramble cache") - def setOpenIDServer(self, oidserver): - """Override openid attribute""" - self.openid = oidserver - class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): """Multi-threaded version of the OpenIDHTTPServer""" - pass + def __init__(self, configuration, **kwargs): + self._started = threading.Event() + OpenIDHTTPServer.__init__(self, configuration, **kwargs) + + def server_activate(self): + OpenIDHTTPServer.server_activate(self) + self._started.set() -class ServerHandler(BaseHTTPRequestHandler): + def wait_start(self): + if not self._started.is_set(): + self._started.wait() + + +class OpenIDRequestHandler(BaseHTTPRequestHandler): """Override BaseHTTPRequestHandler to handle OpenID protocol""" # Input validation helper which must hold validators for all valid query @@ -352,8 +380,10 @@ class ServerHandler(BaseHTTPRequestHandler): } def __init__(self, *args, **kwargs): - if configuration.daemon_conf['session_ttl'] > 0: - self.session_ttl = configuration.daemon_conf['session_ttl'] + self.server = args[2] + + if self.daemon_conf['session_ttl'] > 0: + self.session_ttl = self.daemon_conf['session_ttl'] else: self.session_ttl = 48 * 3600 @@ -364,19 +394,32 @@ def __init__(self, *args, **kwargs): self.retry_url = '' BaseHTTPRequestHandler.__init__(self, *args, **kwargs) + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.configuration.daemon_conf['logger'] + def __retry_url_from_cookie(self): """Extract retry_url from cookie and validate""" + logger = self.logger retry_url = None try: cookies = self.headers.get('Cookie') - cookie = Cookie.SimpleCookie(cookies) + cookie = SimpleCookie(cookies) cookie_dict = dict((k, v.value) for k, v in cookie.iteritems()) retry_url = cookie_dict.get('retry_url', '') if retry_url and retry_url.startswith("http"): raise InputException("invalid retry_url: %s" % retry_url) elif retry_url: valid_url(retry_url) - except Cookie.CookieError as err: + except CookieError as err: retry_url = None logger.error("found invalid cookie: %s" % err) except InputException as exc: @@ -397,10 +440,11 @@ def do_GET(self): """Handle all HTTP GET requests""" # Make sure key is always available for exception handler key = 'UNSET' + logger = self.logger try: self.parsed_uri = urlparse(self.path) self.query = {} - for (key, val) in cgi.parse_qsl(self.parsed_uri[4]): + for (key, val) in parse_qsl(self.parsed_uri[4]): # print "DEBUG: checking input arg %s: '%s'" % (key, val) validate_helper = self.validators.get(key, invalid_argument) # Let validation errors pass to general exception handler below @@ -467,9 +511,12 @@ def do_GET(self): logger.error("do_GET with ref %d crashed: %s" % (error_ref, exc)) # Do not disclose internal details filtered_exc = cgitb.text(sys.exc_info(), context=10) - # IMPORTANT: do NOT ever print or log raw password - pw = self.password - filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + + if self.password: + # IMPORTANT: do NOT ever print or log raw password + pw = self.password + filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + logger.debug("Traceback %s: %s" % (error_ref, filtered_exc)) err_msg = """

Internal error while handling your request - please contact the site support @@ -477,7 +524,7 @@ def do_GET(self):

Back -

""" % (configuration.support_email, error_ref) +

""" % (self.daemon_conf['support_email'], error_ref) self.showErrorPage(err_msg, error_code=500) def do_POST(self): @@ -636,7 +683,7 @@ def handleAllow(self, query): self.password = None account_accessible = check_account_accessible( - configuration, self.user, 'openid') + self.configuration, self.user, 'openid') # NOTE: returns None for invalid user, and boolean otherwise accepted = self.checkLogin(self.user, self.password, client_ip) if accepted is None: @@ -647,7 +694,7 @@ def handleAllow(self, query): # Update rate limits and write to auth log (authorized, _) = validate_auth_attempt( - configuration, + self.configuration, 'openid', 'password', self.user, @@ -836,7 +883,7 @@ def displayResponse(self, response):

Error details:

%s
-''' % cgi.escape(text)) +''' % escape_html(text)) return self.send_response(webresponse.code) @@ -855,6 +902,8 @@ def checkLogin(self, username, password, addr): None if no such user was found or username is invalid. """ + configuration = self.configuration + # Only need to update users here changed_users = [] if possible_user_id(configuration, username): @@ -921,7 +970,8 @@ def doLogin(self): invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf + configuration = self.configuration + daemon_conf = self.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -1034,6 +1084,9 @@ def redirect(self, url): def writeUserHeader(self): """Response helper""" + + logger = self.logger + # NOTE: we added secure and httponly flags as suggested by OpenVAS # NOTE: we need to set empty user cookie for logout to work @@ -1067,7 +1120,7 @@ def showAboutPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) def term(url, text): @@ -1099,7 +1152,7 @@ def showPingPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '%s' % (url_attr, url_text) # IMPORTANT: This is the format availability checker looks for. @@ -1308,7 +1361,7 @@ def showIdPage(self, path): for (aident, trust_root) in self.server.approved: if aident == ident: trs = '
  • %s
  • \n' % \ - cgi.escape(trust_root) + escape_html(trust_root) approved_trust_roots.append(trs) else: logger.debug("Not disclosing trust roots for %s (active user %s)" @@ -1440,6 +1493,10 @@ def showLoginPage(self, success_to, fail_to, query): def showPage(self, response_code, title, head_extras='', msg=None, err=None, form=None): """Show page helper""" + + configuration = self.configuration + logger = self.logger + if self.user is None: user_link = 'not logged in.' % \ self.server.server_base @@ -1519,7 +1576,8 @@ def showPage(self, response_code, title, self.send_header('Content-type', 'text/html') self.end_headers() page_template = openid_page_template(configuration, head_extras) - self.wfile.write(page_template % fill_helpers) + output_string = page_template % fill_helpers + self.wfile.write(_ensure_encoded_string(output_string)) def limited_accept(self, *args, **kwargs): @@ -1552,24 +1610,12 @@ def limited_accept(self, *args, **kwargs): def start_service(configuration): """Service launcher""" - host = configuration.user_openid_address - port = configuration.user_openid_port - data_path = configuration.openid_store daemon_conf = configuration.daemon_conf nossl = daemon_conf['nossl'] addr = (host, port) # TODO: is this threaded version robust enough (thread safety)? # OpenIDServer = OpenIDHTTPServer - OpenIDServer = ThreadedOpenIDHTTPServer - httpserver = OpenIDServer(addr, ServerHandler) - - # Instantiate OpenID consumer store and OpenID consumer. If you - # were connecting to a database, you would create the database - # connection and instantiate an appropriate store here. - store = FileOpenIDStore(data_path) - oidserver = server.Server(store, httpserver.base_url + 'openidserver') - - httpserver.setOpenIDServer(oidserver) + httpserver = ThreadedOpenIDHTTPServer(configuration, addr) # Wrap in SSL if enabled if nossl: @@ -1601,7 +1647,43 @@ def start_service(configuration): httpserver.expire_volatile() -if __name__ == '__main__': +def _extend_configuration(configuration, address, port, **kwargs): + configuration.daemon_conf = { + 'address': address, + 'port': port, + 'root_dir': os.path.abspath(configuration.user_home), + 'db_path': os.path.abspath(default_db_path(configuration)), + 'session_store': os.path.abspath(configuration.openid_store), + 'session_ttl': 24*3600, + 'allow_password': 'password' in configuration.user_openid_auth, + 'allow_digest': 'digest' in configuration.user_openid_auth, + 'allow_publickey': 'publickey' in configuration.user_openid_auth, + 'user_alias': configuration.user_openid_alias, + 'host_rsa_key': kwargs['host_rsa_key'], + 'users': [], + 'login_map': {}, + 'time_stamp': 0, + 'logger': kwargs['logger'], + 'nossl': kwargs['nossl'], + 'expandusername': kwargs['expandusername'], + 'show_address': kwargs['show_address'], + 'show_port': kwargs['show_port'], + 'support_email': configuration.support_email, + # TODO: Add the following to configuration: + # max_openid_user_hits + # max_openid_user_abuse_hits + # max_openid_proto_abuse_hits + # max_openid_secret_hits + 'auth_limits': + {'max_user_hits': default_max_user_hits, + 'user_abuse_hits': default_user_abuse_hits, + 'proto_abuse_hits': default_proto_abuse_hits, + 'max_secret_hits': 1, + }, + } + + +def main(): # Force no log init since we use separate logger configuration = get_configuration_object(skip_log=True) @@ -1644,9 +1726,7 @@ def start_service(configuration): unless it is available in mig/server/MiGserver.conf """) print(__doc__) - address = configuration.user_openid_address - port = configuration.user_openid_port - session_store = configuration.openid_store + default_host_key = """ -----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEA404IBMReHOdvhhJ5YtgquY3DNi0v0QwfPUk+EcH/CxFW8UCC @@ -1676,6 +1756,7 @@ def start_service(configuration): 3stfzMDGtKM9lntAsfFQ8n4yvvEbn/quEWad6srf1yxt9B4t5JA= -----END RSA PRIVATE KEY----- """ + try: host_key_fd = open(configuration.user_openid_key, 'r') host_rsa_key = host_key_fd.read() @@ -1683,38 +1764,21 @@ def start_service(configuration): except IOError: logger.info("No valid host key provided - using default") host_rsa_key = default_host_key - configuration.daemon_conf = { - 'address': address, - 'port': port, - 'root_dir': os.path.abspath(configuration.user_home), - 'db_path': os.path.abspath(default_db_path(configuration)), - 'session_store': os.path.abspath(configuration.openid_store), - 'session_ttl': 24*3600, - 'allow_password': 'password' in configuration.user_openid_auth, - 'allow_digest': 'digest' in configuration.user_openid_auth, - 'allow_publickey': 'publickey' in configuration.user_openid_auth, - 'user_alias': configuration.user_openid_alias, - 'host_rsa_key': host_rsa_key, - 'users': [], - 'login_map': {}, - 'time_stamp': 0, - 'logger': logger, - 'nossl': nossl, - 'expandusername': expandusername, - 'show_address': show_address, - 'show_port': show_port, - # TODO: Add the following to configuration: - # max_openid_user_hits - # max_openid_user_abuse_hits - # max_openid_proto_abuse_hits - # max_openid_secret_hits - 'auth_limits': - {'max_user_hits': default_max_user_hits, - 'user_abuse_hits': default_user_abuse_hits, - 'proto_abuse_hits': default_proto_abuse_hits, - 'max_secret_hits': 1, - }, - } + + address = configuration.user_openid_address + port = configuration.user_openid_port + _extend_configuration( + configuration, + address, + port, + logger=logger, + expandusername=False, + host_rsa_key=host_rsa_key, + nossl=True, + show_address=False, + show_port=False, + ) + logger.info("Starting OpenID server") info_msg = "Listening on address '%s' and port %d" % (address, port) logger.info(info_msg) @@ -1728,3 +1792,7 @@ def start_service(configuration): info_msg = "Leaving with no more workers active" logger.info(info_msg) print(info_msg) + + +if __name__ == '__main__': + main() diff --git a/mig/shared/auth.py b/mig/shared/auth.py index 54c73b57a..b69c9ceed 100644 --- a/mig/shared/auth.py +++ b/mig/shared/auth.py @@ -29,11 +29,11 @@ from __future__ import absolute_import -import Cookie import base64 import glob import os import re +import sys import time # Only needed for 2FA so ignore import error and only fail on use @@ -42,6 +42,13 @@ except ImportError: pyotp = None +PY2 = sys.version_info[0] < 3 + +if PY2: + from Cookie import SimpleCookie +else: + from http.cookies import SimpleCookie + from mig.shared.base import client_id_dir, extract_field, force_utf8 from mig.shared.defaults import twofactor_key_name, twofactor_interval_name, \ twofactor_key_bytes, twofactor_cookie_bytes, twofactor_cookie_ttl @@ -291,7 +298,7 @@ def client_twofactor_session(configuration, if configuration.site_enable_gdp: client_id = get_base_client_id(configuration, client_id, expand_oid_alias=False) - session_cookie = Cookie.SimpleCookie() + session_cookie = SimpleCookie() session_cookie.load(environ.get('HTTP_COOKIE', "")) session_cookie = session_cookie.get('2FA_Auth', None) if session_cookie is None: diff --git a/mig/shared/compat.py b/mig/shared/compat.py index 3c5263f78..c96752a51 100644 --- a/mig/shared/compat.py +++ b/mig/shared/compat.py @@ -40,6 +40,14 @@ _TYPE_UNICODE = type(u"") +escape_html = None +if PY2: + from cgi import escape as escape_html +else: + from html import escape as escape_html +assert escape_html is not None + + def _is_unicode(val): """Return boolean indicating if the value is a unicode string. diff --git a/mig/shared/safeinput.py b/mig/shared/safeinput.py index 9042f0132..a225df246 100644 --- a/mig/shared/safeinput.py +++ b/mig/shared/safeinput.py @@ -49,16 +49,8 @@ except ImportError: nbformat = None -PY2 = sys.version_info[0] < 3 - -escape_html = None -if PY2: - from cgi import escape as escape_html -else: - from html import escape as escape_html -assert escape_html is not None - from mig.shared.base import force_unicode, force_utf8 +from mig.shared.compat import escape_html from mig.shared.defaults import src_dst_sep, username_charset, \ username_max_length, session_id_charset, session_id_length, \ subject_id_charset, subject_id_min_length, subject_id_max_length, \ diff --git a/requirements.txt b/requirements.txt index 210b8c9ec..07786232b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # migrid core dependencies on a format suitable for pip install as described on # https://pip.pypa.io/en/stable/reference/requirement-specifiers/ future +python-openid2 pyotp;python_version >= "3" pyotp<2.4;python_version < "3" pyyaml diff --git a/tests/test_mig_server_grid_openid.py b/tests/test_mig_server_grid_openid.py new file mode 100644 index 000000000..af8cb976e --- /dev/null +++ b/tests/test_mig_server_grid_openid.py @@ -0,0 +1,96 @@ +from __future__ import print_function +import os +import sys +from threading import Thread + +from tests.support import PY2, MIG_BASE, testmain, MigTestCase + +from mig.server.grid_openid import ThreadedOpenIDHTTPServer, _extend_configuration, main +from mig.shared.conf import get_configuration_object + +_PYTHON_MAJOR = '2' if PY2 else '3' +_TEST_CONF_DIR = os.path.join( + MIG_BASE, "envhelp/output/testconfs-py%s" % (_PYTHON_MAJOR,)) +_TEST_CONF_FILE = os.path.join(_TEST_CONF_DIR, "MiGserver.conf") + + +if PY2: + from urllib2 import urlopen +else: + from urllib.request import urlopen + + +class TestThreadedOpenIDServerExecutor(Thread): + def __init__(self, configuration): + super(TestThreadedOpenIDServerExecutor, self).__init__() + self.server = ThreadedOpenIDHTTPServer(configuration) + + def run(self): + try: + self.server.serve_forever() + except Exception as e: + pass + + def start_wait_until_ready(self): + self.start() + self.wait_until_ready() + return self + + def stop(self): + self.server.shutdown() + super(TestThreadedOpenIDServerExecutor, self).join() + + def wait_until_ready(self): + self.server.wait_start() + + +class MigServerGrid_openid(MigTestCase): + def before_each(self): + self.server_addr = None + self.server_thread = None + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + def issue_request(self, request_path): + assert isinstance(request_path, str) and request_path.startswith('/'), "require http path starting with /" + request_url = ''.join(('http://', self.server_addr[0], ':', str(self.server_addr[1]), request_path)) + response = urlopen(request_url, None, timeout=2000) + status = response.getcode() + data = response.read() + return (status, data) + + def test_top_level_request_succeeds_with_status_ok(self): + self.server_addr = ('localhost', 4567) + configuration = self._make_configuration(self.logger, self.server_addr) + self.server_thread = self._make_server(configuration).start_wait_until_ready() + + status, _ = self.issue_request('/') + + self.assertEqual(status, 200) + + @staticmethod + def _make_configuration(test_logger, server_addr): + configuration = get_configuration_object( + _TEST_CONF_FILE, skip_log=True, disable_auth_log=True) + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration): + return TestThreadedOpenIDServerExecutor(configuration) + + +if __name__ == '__main__': + testmain()