diff --git a/mig/server/grid_openid.py b/mig/server/grid_openid.py
index ab83fd5bf..e1db1851d 100755
--- a/mig/server/grid_openid.py
+++ b/mig/server/grid_openid.py
@@ -58,17 +58,19 @@
from __future__ import print_function
from __future__ import absolute_import
-from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
-from SocketServer import ThreadingMixIn
+from http.cookies import SimpleCookie, CookieError
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from socketserver import ThreadingMixIn
import base64
-import Cookie
import cgi
import cgitb
+import codecs
import os
import re
import socket
import sys
+import threading
import time
import types
@@ -87,6 +89,7 @@
from mig.shared.accountstate import check_account_accessible
from mig.shared.base import client_id_dir, cert_field_map
from mig.shared.conf import get_configuration_object
+from mig.shared.compat import PY2, escape_html
from mig.shared.griddaemons.openid import default_max_user_hits, \
default_user_abuse_hits, default_proto_abuse_hits, \
default_username_validator, refresh_user_creds, update_login_map, \
@@ -99,14 +102,12 @@
valid_path, valid_ascii, valid_job_id, valid_base_url, valid_url, \
valid_complex_url, InputException
from mig.shared.tlsserver import hardened_ssl_context
-from mig.shared.url import urlparse, urlencode
+from mig.shared.url import urlparse, urlencode, parse_qsl
from mig.shared.useradm import get_openid_user_dn, check_password_scramble, \
check_hash
from mig.shared.userdb import default_db_path
from mig.shared.validstring import possible_user_id
-configuration, logger = None, None
-
# Update with extra fields
cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK',
'fullname': 'CN', 'o': 'O', 'ou': 'OU'})
@@ -120,9 +121,17 @@
pw_regexp = re.compile(pw_pattern)
+if PY2:
+ def _ensure_encoded_string(chunk):
+ return chunk
+else:
+ def _ensure_encoded_string(chunk):
+ return codecs.encode(chunk, 'utf8')
+
+
def quoteattr(val):
"""Escape string for safe printing"""
- esc = cgi.escape(val, 1)
+ esc = escape_html(val, 1)
return '"%s"' % (esc,)
@@ -246,8 +255,14 @@ class OpenIDHTTPServer(HTTPServer):
# any gain and it potentially introduces a race
hash_cache, scramble_cache = None, None
- def __init__(self, *args, **kwargs):
- HTTPServer.__init__(self, *args, **kwargs)
+ def __init__(self, configuration, **kwargs):
+ self.configuration = configuration
+
+ address = configuration.daemon_conf['address']
+ port = configuration.daemon_conf['port']
+
+ addr = (address, port)
+ HTTPServer.__init__(self, addr, OpenIDRequestHandler, **kwargs)
fqdn = self.server_name
port = self.server_port
@@ -270,7 +285,6 @@ def __init__(self, *args, **kwargs):
# We serve from sub dir to ease targeted proxying
self.server_base = 'openid'
self.base_url += "%s/" % self.server_base
-
self.openid = None
self.approved = {}
self.lastCheckIDRequest = {}
@@ -287,6 +301,13 @@ def __init__(self, *args, **kwargs):
cert_field_aliases[name].append(target)
# print "DEBUG: cert field aliases: %s" % cert_field_aliases
+ # Instantiate OpenID consumer store and OpenID consumer. If you
+ # were connecting to a database, you would create the database
+ # connection and instantiate an appropriate store here.
+ data_path = configuration.openid_store
+ store = FileOpenIDStore(data_path)
+ oidserver = server.Server(store, self.base_url + 'openidserver')
+
def expire_volatile(self):
"""Expire old entries in the volatile helper dictionaries"""
if self.last_expire + self.min_expire_delay < time.time():
@@ -299,17 +320,24 @@ def expire_volatile(self):
self.scramble_cache.clear()
logger.debug("Expired old rate limits and scramble cache")
- def setOpenIDServer(self, oidserver):
- """Override openid attribute"""
- self.openid = oidserver
-
class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer):
"""Multi-threaded version of the OpenIDHTTPServer"""
- pass
+ def __init__(self, configuration, **kwargs):
+ self._started = threading.Event()
+ OpenIDHTTPServer.__init__(self, configuration, **kwargs)
+
+ def server_activate(self):
+ OpenIDHTTPServer.server_activate(self)
+ self._started.set()
-class ServerHandler(BaseHTTPRequestHandler):
+ def wait_start(self):
+ if not self._started.is_set():
+ self._started.wait()
+
+
+class OpenIDRequestHandler(BaseHTTPRequestHandler):
"""Override BaseHTTPRequestHandler to handle OpenID protocol"""
# Input validation helper which must hold validators for all valid query
@@ -352,8 +380,10 @@ class ServerHandler(BaseHTTPRequestHandler):
}
def __init__(self, *args, **kwargs):
- if configuration.daemon_conf['session_ttl'] > 0:
- self.session_ttl = configuration.daemon_conf['session_ttl']
+ self.server = args[2]
+
+ if self.daemon_conf['session_ttl'] > 0:
+ self.session_ttl = self.daemon_conf['session_ttl']
else:
self.session_ttl = 48 * 3600
@@ -364,19 +394,32 @@ def __init__(self, *args, **kwargs):
self.retry_url = ''
BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
+ @property
+ def configuration(self):
+ return self.server.configuration
+
+ @property
+ def daemon_conf(self):
+ return self.server.configuration.daemon_conf
+
+ @property
+ def logger(self):
+ return self.server.configuration.daemon_conf['logger']
+
def __retry_url_from_cookie(self):
"""Extract retry_url from cookie and validate"""
+ logger = self.logger
retry_url = None
try:
cookies = self.headers.get('Cookie')
- cookie = Cookie.SimpleCookie(cookies)
+ cookie = SimpleCookie(cookies)
cookie_dict = dict((k, v.value) for k, v in cookie.iteritems())
retry_url = cookie_dict.get('retry_url', '')
if retry_url and retry_url.startswith("http"):
raise InputException("invalid retry_url: %s" % retry_url)
elif retry_url:
valid_url(retry_url)
- except Cookie.CookieError as err:
+ except CookieError as err:
retry_url = None
logger.error("found invalid cookie: %s" % err)
except InputException as exc:
@@ -397,10 +440,11 @@ def do_GET(self):
"""Handle all HTTP GET requests"""
# Make sure key is always available for exception handler
key = 'UNSET'
+ logger = self.logger
try:
self.parsed_uri = urlparse(self.path)
self.query = {}
- for (key, val) in cgi.parse_qsl(self.parsed_uri[4]):
+ for (key, val) in parse_qsl(self.parsed_uri[4]):
# print "DEBUG: checking input arg %s: '%s'" % (key, val)
validate_helper = self.validators.get(key, invalid_argument)
# Let validation errors pass to general exception handler below
@@ -467,9 +511,12 @@ def do_GET(self):
logger.error("do_GET with ref %d crashed: %s" % (error_ref, exc))
# Do not disclose internal details
filtered_exc = cgitb.text(sys.exc_info(), context=10)
- # IMPORTANT: do NOT ever print or log raw password
- pw = self.password
- filtered_exc = filtered_exc.replace(pw, '*' * len(pw))
+
+ if self.password:
+ # IMPORTANT: do NOT ever print or log raw password
+ pw = self.password
+ filtered_exc = filtered_exc.replace(pw, '*' * len(pw))
+
logger.debug("Traceback %s: %s" % (error_ref, filtered_exc))
err_msg = """
Internal error while handling your request - please contact the site support
@@ -477,7 +524,7 @@ def do_GET(self):
Back
-
""" % (configuration.support_email, error_ref)
+""" % (self.daemon_conf['support_email'], error_ref)
self.showErrorPage(err_msg, error_code=500)
def do_POST(self):
@@ -636,7 +683,7 @@ def handleAllow(self, query):
self.password = None
account_accessible = check_account_accessible(
- configuration, self.user, 'openid')
+ self.configuration, self.user, 'openid')
# NOTE: returns None for invalid user, and boolean otherwise
accepted = self.checkLogin(self.user, self.password, client_ip)
if accepted is None:
@@ -647,7 +694,7 @@ def handleAllow(self, query):
# Update rate limits and write to auth log
(authorized, _) = validate_auth_attempt(
- configuration,
+ self.configuration,
'openid',
'password',
self.user,
@@ -836,7 +883,7 @@ def displayResponse(self, response):
Error details:
%s
-''' % cgi.escape(text))
+''' % escape_html(text))
return
self.send_response(webresponse.code)
@@ -855,6 +902,8 @@ def checkLogin(self, username, password, addr):
None if no such user was found or username is invalid.
"""
+ configuration = self.configuration
+
# Only need to update users here
changed_users = []
if possible_user_id(configuration, username):
@@ -921,7 +970,8 @@ def doLogin(self):
invalid_user = False
account_accessible = False
valid_password = False
- daemon_conf = configuration.daemon_conf
+ configuration = self.configuration
+ daemon_conf = self.daemon_conf
max_user_hits = daemon_conf['auth_limits']['max_user_hits']
user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits']
proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits']
@@ -1034,6 +1084,9 @@ def redirect(self, url):
def writeUserHeader(self):
"""Response helper"""
+
+ logger = self.logger
+
# NOTE: we added secure and httponly flags as suggested by OpenVAS
# NOTE: we need to set empty user cookie for logout to work
@@ -1067,7 +1120,7 @@ def showAboutPage(self):
def link(url):
url_attr = quoteattr(url)
- url_text = cgi.escape(url)
+ url_text = escape_html(url)
return '%s
' % (url_attr, url_text)
def term(url, text):
@@ -1099,7 +1152,7 @@ def showPingPage(self):
def link(url):
url_attr = quoteattr(url)
- url_text = cgi.escape(url)
+ url_text = escape_html(url)
return '%s
' % (url_attr, url_text)
# IMPORTANT: This is the format availability checker looks for.
@@ -1308,7 +1361,7 @@ def showIdPage(self, path):
for (aident, trust_root) in self.server.approved:
if aident == ident:
trs = '%s\n' % \
- cgi.escape(trust_root)
+ escape_html(trust_root)
approved_trust_roots.append(trs)
else:
logger.debug("Not disclosing trust roots for %s (active user %s)"
@@ -1440,6 +1493,10 @@ def showLoginPage(self, success_to, fail_to, query):
def showPage(self, response_code, title,
head_extras='', msg=None, err=None, form=None):
"""Show page helper"""
+
+ configuration = self.configuration
+ logger = self.logger
+
if self.user is None:
user_link = 'not logged in.' % \
self.server.server_base
@@ -1519,7 +1576,8 @@ def showPage(self, response_code, title,
self.send_header('Content-type', 'text/html')
self.end_headers()
page_template = openid_page_template(configuration, head_extras)
- self.wfile.write(page_template % fill_helpers)
+ output_string = page_template % fill_helpers
+ self.wfile.write(_ensure_encoded_string(output_string))
def limited_accept(self, *args, **kwargs):
@@ -1552,24 +1610,12 @@ def limited_accept(self, *args, **kwargs):
def start_service(configuration):
"""Service launcher"""
- host = configuration.user_openid_address
- port = configuration.user_openid_port
- data_path = configuration.openid_store
daemon_conf = configuration.daemon_conf
nossl = daemon_conf['nossl']
addr = (host, port)
# TODO: is this threaded version robust enough (thread safety)?
# OpenIDServer = OpenIDHTTPServer
- OpenIDServer = ThreadedOpenIDHTTPServer
- httpserver = OpenIDServer(addr, ServerHandler)
-
- # Instantiate OpenID consumer store and OpenID consumer. If you
- # were connecting to a database, you would create the database
- # connection and instantiate an appropriate store here.
- store = FileOpenIDStore(data_path)
- oidserver = server.Server(store, httpserver.base_url + 'openidserver')
-
- httpserver.setOpenIDServer(oidserver)
+ httpserver = ThreadedOpenIDHTTPServer(configuration, addr)
# Wrap in SSL if enabled
if nossl:
@@ -1601,7 +1647,43 @@ def start_service(configuration):
httpserver.expire_volatile()
-if __name__ == '__main__':
+def _extend_configuration(configuration, address, port, **kwargs):
+ configuration.daemon_conf = {
+ 'address': address,
+ 'port': port,
+ 'root_dir': os.path.abspath(configuration.user_home),
+ 'db_path': os.path.abspath(default_db_path(configuration)),
+ 'session_store': os.path.abspath(configuration.openid_store),
+ 'session_ttl': 24*3600,
+ 'allow_password': 'password' in configuration.user_openid_auth,
+ 'allow_digest': 'digest' in configuration.user_openid_auth,
+ 'allow_publickey': 'publickey' in configuration.user_openid_auth,
+ 'user_alias': configuration.user_openid_alias,
+ 'host_rsa_key': kwargs['host_rsa_key'],
+ 'users': [],
+ 'login_map': {},
+ 'time_stamp': 0,
+ 'logger': kwargs['logger'],
+ 'nossl': kwargs['nossl'],
+ 'expandusername': kwargs['expandusername'],
+ 'show_address': kwargs['show_address'],
+ 'show_port': kwargs['show_port'],
+ 'support_email': configuration.support_email,
+ # TODO: Add the following to configuration:
+ # max_openid_user_hits
+ # max_openid_user_abuse_hits
+ # max_openid_proto_abuse_hits
+ # max_openid_secret_hits
+ 'auth_limits':
+ {'max_user_hits': default_max_user_hits,
+ 'user_abuse_hits': default_user_abuse_hits,
+ 'proto_abuse_hits': default_proto_abuse_hits,
+ 'max_secret_hits': 1,
+ },
+ }
+
+
+def main():
# Force no log init since we use separate logger
configuration = get_configuration_object(skip_log=True)
@@ -1644,9 +1726,7 @@ def start_service(configuration):
unless it is available in mig/server/MiGserver.conf
""")
print(__doc__)
- address = configuration.user_openid_address
- port = configuration.user_openid_port
- session_store = configuration.openid_store
+
default_host_key = """
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA404IBMReHOdvhhJ5YtgquY3DNi0v0QwfPUk+EcH/CxFW8UCC
@@ -1676,6 +1756,7 @@ def start_service(configuration):
3stfzMDGtKM9lntAsfFQ8n4yvvEbn/quEWad6srf1yxt9B4t5JA=
-----END RSA PRIVATE KEY-----
"""
+
try:
host_key_fd = open(configuration.user_openid_key, 'r')
host_rsa_key = host_key_fd.read()
@@ -1683,38 +1764,21 @@ def start_service(configuration):
except IOError:
logger.info("No valid host key provided - using default")
host_rsa_key = default_host_key
- configuration.daemon_conf = {
- 'address': address,
- 'port': port,
- 'root_dir': os.path.abspath(configuration.user_home),
- 'db_path': os.path.abspath(default_db_path(configuration)),
- 'session_store': os.path.abspath(configuration.openid_store),
- 'session_ttl': 24*3600,
- 'allow_password': 'password' in configuration.user_openid_auth,
- 'allow_digest': 'digest' in configuration.user_openid_auth,
- 'allow_publickey': 'publickey' in configuration.user_openid_auth,
- 'user_alias': configuration.user_openid_alias,
- 'host_rsa_key': host_rsa_key,
- 'users': [],
- 'login_map': {},
- 'time_stamp': 0,
- 'logger': logger,
- 'nossl': nossl,
- 'expandusername': expandusername,
- 'show_address': show_address,
- 'show_port': show_port,
- # TODO: Add the following to configuration:
- # max_openid_user_hits
- # max_openid_user_abuse_hits
- # max_openid_proto_abuse_hits
- # max_openid_secret_hits
- 'auth_limits':
- {'max_user_hits': default_max_user_hits,
- 'user_abuse_hits': default_user_abuse_hits,
- 'proto_abuse_hits': default_proto_abuse_hits,
- 'max_secret_hits': 1,
- },
- }
+
+ address = configuration.user_openid_address
+ port = configuration.user_openid_port
+ _extend_configuration(
+ configuration,
+ address,
+ port,
+ logger=logger,
+ expandusername=False,
+ host_rsa_key=host_rsa_key,
+ nossl=True,
+ show_address=False,
+ show_port=False,
+ )
+
logger.info("Starting OpenID server")
info_msg = "Listening on address '%s' and port %d" % (address, port)
logger.info(info_msg)
@@ -1728,3 +1792,7 @@ def start_service(configuration):
info_msg = "Leaving with no more workers active"
logger.info(info_msg)
print(info_msg)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/mig/shared/auth.py b/mig/shared/auth.py
index 54c73b57a..b69c9ceed 100644
--- a/mig/shared/auth.py
+++ b/mig/shared/auth.py
@@ -29,11 +29,11 @@
from __future__ import absolute_import
-import Cookie
import base64
import glob
import os
import re
+import sys
import time
# Only needed for 2FA so ignore import error and only fail on use
@@ -42,6 +42,13 @@
except ImportError:
pyotp = None
+PY2 = sys.version_info[0] < 3
+
+if PY2:
+ from Cookie import SimpleCookie
+else:
+ from http.cookies import SimpleCookie
+
from mig.shared.base import client_id_dir, extract_field, force_utf8
from mig.shared.defaults import twofactor_key_name, twofactor_interval_name, \
twofactor_key_bytes, twofactor_cookie_bytes, twofactor_cookie_ttl
@@ -291,7 +298,7 @@ def client_twofactor_session(configuration,
if configuration.site_enable_gdp:
client_id = get_base_client_id(configuration, client_id,
expand_oid_alias=False)
- session_cookie = Cookie.SimpleCookie()
+ session_cookie = SimpleCookie()
session_cookie.load(environ.get('HTTP_COOKIE', ""))
session_cookie = session_cookie.get('2FA_Auth', None)
if session_cookie is None:
diff --git a/mig/shared/compat.py b/mig/shared/compat.py
index 3c5263f78..c96752a51 100644
--- a/mig/shared/compat.py
+++ b/mig/shared/compat.py
@@ -40,6 +40,14 @@
_TYPE_UNICODE = type(u"")
+escape_html = None
+if PY2:
+ from cgi import escape as escape_html
+else:
+ from html import escape as escape_html
+assert escape_html is not None
+
+
def _is_unicode(val):
"""Return boolean indicating if the value is a unicode string.
diff --git a/mig/shared/safeinput.py b/mig/shared/safeinput.py
index 9042f0132..a225df246 100644
--- a/mig/shared/safeinput.py
+++ b/mig/shared/safeinput.py
@@ -49,16 +49,8 @@
except ImportError:
nbformat = None
-PY2 = sys.version_info[0] < 3
-
-escape_html = None
-if PY2:
- from cgi import escape as escape_html
-else:
- from html import escape as escape_html
-assert escape_html is not None
-
from mig.shared.base import force_unicode, force_utf8
+from mig.shared.compat import escape_html
from mig.shared.defaults import src_dst_sep, username_charset, \
username_max_length, session_id_charset, session_id_length, \
subject_id_charset, subject_id_min_length, subject_id_max_length, \
diff --git a/requirements.txt b/requirements.txt
index 210b8c9ec..07786232b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
# migrid core dependencies on a format suitable for pip install as described on
# https://pip.pypa.io/en/stable/reference/requirement-specifiers/
future
+python-openid2
pyotp;python_version >= "3"
pyotp<2.4;python_version < "3"
pyyaml
diff --git a/tests/test_mig_server_grid_openid.py b/tests/test_mig_server_grid_openid.py
new file mode 100644
index 000000000..af8cb976e
--- /dev/null
+++ b/tests/test_mig_server_grid_openid.py
@@ -0,0 +1,96 @@
+from __future__ import print_function
+import os
+import sys
+from threading import Thread
+
+from tests.support import PY2, MIG_BASE, testmain, MigTestCase
+
+from mig.server.grid_openid import ThreadedOpenIDHTTPServer, _extend_configuration, main
+from mig.shared.conf import get_configuration_object
+
+_PYTHON_MAJOR = '2' if PY2 else '3'
+_TEST_CONF_DIR = os.path.join(
+ MIG_BASE, "envhelp/output/testconfs-py%s" % (_PYTHON_MAJOR,))
+_TEST_CONF_FILE = os.path.join(_TEST_CONF_DIR, "MiGserver.conf")
+
+
+if PY2:
+ from urllib2 import urlopen
+else:
+ from urllib.request import urlopen
+
+
+class TestThreadedOpenIDServerExecutor(Thread):
+ def __init__(self, configuration):
+ super(TestThreadedOpenIDServerExecutor, self).__init__()
+ self.server = ThreadedOpenIDHTTPServer(configuration)
+
+ def run(self):
+ try:
+ self.server.serve_forever()
+ except Exception as e:
+ pass
+
+ def start_wait_until_ready(self):
+ self.start()
+ self.wait_until_ready()
+ return self
+
+ def stop(self):
+ self.server.shutdown()
+ super(TestThreadedOpenIDServerExecutor, self).join()
+
+ def wait_until_ready(self):
+ self.server.wait_start()
+
+
+class MigServerGrid_openid(MigTestCase):
+ def before_each(self):
+ self.server_addr = None
+ self.server_thread = None
+
+ def after_each(self):
+ if self.server_thread:
+ self.server_thread.stop()
+
+ def issue_request(self, request_path):
+ assert isinstance(request_path, str) and request_path.startswith('/'), "require http path starting with /"
+ request_url = ''.join(('http://', self.server_addr[0], ':', str(self.server_addr[1]), request_path))
+ response = urlopen(request_url, None, timeout=2000)
+ status = response.getcode()
+ data = response.read()
+ return (status, data)
+
+ def test_top_level_request_succeeds_with_status_ok(self):
+ self.server_addr = ('localhost', 4567)
+ configuration = self._make_configuration(self.logger, self.server_addr)
+ self.server_thread = self._make_server(configuration).start_wait_until_ready()
+
+ status, _ = self.issue_request('/')
+
+ self.assertEqual(status, 200)
+
+ @staticmethod
+ def _make_configuration(test_logger, server_addr):
+ configuration = get_configuration_object(
+ _TEST_CONF_FILE, skip_log=True, disable_auth_log=True)
+ _extend_configuration(
+ configuration,
+ server_addr[0],
+ server_addr[1],
+ logger=test_logger,
+ expandusername=False,
+ host_rsa_key='',
+ nossl=True,
+ show_address=False,
+ show_port=False,
+ )
+ return configuration
+
+ @staticmethod
+ def _make_server(configuration):
+ return TestThreadedOpenIDServerExecutor(configuration)
+
+
+if __name__ == '__main__':
+ testmain()