diff --git a/mig/server/grid_openid.py b/mig/server/grid_openid.py index aad04742e..a77472600 100755 --- a/mig/server/grid_openid.py +++ b/mig/server/grid_openid.py @@ -58,13 +58,14 @@ from __future__ import print_function from __future__ import absolute_import -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from SocketServer import ThreadingMixIn +from http.cookies import SimpleCookie, BaseCookie, CookieError +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn import base64 -import Cookie import cgi import cgitb +import codecs import os import re import socket @@ -87,6 +88,7 @@ from mig.shared.accountstate import check_account_accessible from mig.shared.base import client_id_dir, cert_field_map from mig.shared.conf import get_configuration_object +from mig.shared.compat import PY2 from mig.shared.griddaemons.openid import default_max_user_hits, \ default_user_abuse_hits, default_proto_abuse_hits, \ default_username_validator, refresh_user_creds, update_login_map, \ @@ -99,14 +101,12 @@ valid_path, valid_ascii, valid_job_id, valid_base_url, valid_url, \ valid_complex_url, InputException from mig.shared.tlsserver import hardened_ssl_context -from mig.shared.url import urlparse, urlencode, check_local_site_url +from mig.shared.url import urlparse, urlencode, parse_qsl, check_local_site_url from mig.shared.useradm import get_openid_user_dn, check_password_scramble, \ check_hash from mig.shared.userdb import default_db_path from mig.shared.validstring import possible_user_id -configuration, logger = None, None - # Update with extra fields cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK', 'fullname': 'CN', 'o': 'O', 'ou': 'OU'}) @@ -120,9 +120,23 @@ pw_regexp = re.compile(pw_pattern) +if PY2: + def _ensure_encoded_string(chunk): + return chunk +else: + def _ensure_encoded_string(chunk): + return codecs.encode(chunk, 'utf8') + + +if PY2: + from cgi import escape as escape_html +else: + from html import escape as escape_html + + def quoteattr(val): """Escape string for safe printing""" - esc = cgi.escape(val, 1) + esc = escape_html(val, 1) return '"%s"' % (esc,) @@ -201,7 +215,7 @@ def filter_why_pw(configuration, why): return 'Unexpected %s to filter for safe log and output.' % why_type -def lookup_full_user(username): +def lookup_full_user(configuration, username): """Look up the full user identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database. @@ -225,7 +239,7 @@ def lookup_full_user(username): return (username, {}) -def lookup_full_identity(username): +def lookup_full_identity(configuration, username): """Look up the full identity for username consisting of e.g. just an email address. The method to extract the full identity depends on the back end database @@ -236,7 +250,7 @@ def lookup_full_identity(username): """ # print "DEBUG: lookup full ID for %s" % username - return lookup_full_user(username)[0] + return lookup_full_user(configuration, username)[0] class OpenIDHTTPServer(HTTPServer): @@ -253,8 +267,15 @@ class OpenIDHTTPServer(HTTPServer): # any gain and it potentially introduces a race hash_cache, scramble_cache = None, None - def __init__(self, *args, **kwargs): - HTTPServer.__init__(self, *args, **kwargs) + def __init__(self, configuration, **kwargs): + self.configuration = configuration + self._on_start = kwargs.pop('on_start', lambda _: None) + + address = configuration.daemon_conf['address'] + port = configuration.daemon_conf['port'] + + addr = (address, port) + HTTPServer.__init__(self, addr, OpenIDRequestHandler, **kwargs) fqdn = self.server_name port = self.server_port @@ -277,7 +298,6 @@ def __init__(self, *args, **kwargs): # We serve from sub dir to ease targeted proxying self.server_base = 'openid' self.base_url += "%s/" % self.server_base - self.openid = None self.approved = {} self.lastCheckIDRequest = {} @@ -294,8 +314,19 @@ def __init__(self, *args, **kwargs): cert_field_aliases[name].append(target) # print "DEBUG: cert field aliases: %s" % cert_field_aliases + # Instantiate OpenID consumer store and OpenID consumer. If you + # were connecting to a database, you would create the database + # connection and instantiate an appropriate store here. + data_path = configuration.openid_store + store = FileOpenIDStore(data_path) + oidserver = server.Server(store, self.base_url + 'openidserver') + def expire_volatile(self): """Expire old entries in the volatile helper dictionaries""" + + configuration = self.configuration + logger = self.logger + if self.last_expire + self.min_expire_delay < time.time(): self.last_expire = time.time() expire_rate_limit(configuration, "openid", @@ -306,9 +337,9 @@ def expire_volatile(self): self.scramble_cache.clear() logger.debug("Expired old rate limits and scramble cache") - def setOpenIDServer(self, oidserver): - """Override openid attribute""" - self.openid = oidserver + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): @@ -316,7 +347,7 @@ class ThreadedOpenIDHTTPServer(ThreadingMixIn, OpenIDHTTPServer): pass -class ServerHandler(BaseHTTPRequestHandler): +class OpenIDRequestHandler(BaseHTTPRequestHandler): """Override BaseHTTPRequestHandler to handle OpenID protocol""" # Input validation helper which must hold validators for all valid query @@ -359,8 +390,10 @@ class ServerHandler(BaseHTTPRequestHandler): } def __init__(self, *args, **kwargs): - if configuration.daemon_conf['session_ttl'] > 0: - self.session_ttl = configuration.daemon_conf['session_ttl'] + self.server = args[2] + + if self.daemon_conf['session_ttl'] > 0: + self.session_ttl = self.daemon_conf['session_ttl'] else: self.session_ttl = 48 * 3600 @@ -371,19 +404,32 @@ def __init__(self, *args, **kwargs): self.retry_url = '' BaseHTTPRequestHandler.__init__(self, *args, **kwargs) + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.configuration.daemon_conf['logger'] + def __retry_url_from_cookie(self): """Extract retry_url from cookie and validate""" + logger = self.logger retry_url = None try: cookies = self.headers.get('Cookie') - cookie = Cookie.SimpleCookie(cookies) + cookie = SimpleCookie(cookies) cookie_dict = dict((k, v.value) for k, v in cookie.iteritems()) retry_url = cookie_dict.get('retry_url', '') if retry_url and retry_url.startswith("http"): raise InputException("invalid retry_url: %s" % retry_url) elif retry_url: valid_url(retry_url) - except Cookie.CookieError as err: + except CookieError as err: retry_url = None logger.error("found invalid cookie: %s" % err) except InputException as exc: @@ -402,12 +448,17 @@ def clearUser(self): def do_GET(self): """Handle all HTTP GET requests""" + + configuration = self.configuration + logger = self.logger + # Make sure key is always available for exception handler key = 'UNSET' + try: self.parsed_uri = urlparse(self.path) self.query = {} - for (key, val) in cgi.parse_qsl(self.parsed_uri[4]): + for (key, val) in parse_qsl(self.parsed_uri[4]): # print "DEBUG: checking input arg %s: '%s'" % (key, val) validate_helper = self.validators.get(key, invalid_argument) # Let validation errors pass to general exception handler below @@ -474,9 +525,12 @@ def do_GET(self): logger.error("do_GET with ref %d crashed: %s" % (error_ref, exc)) # Do not disclose internal details filtered_exc = cgitb.text(sys.exc_info(), context=10) - # IMPORTANT: do NOT ever print or log raw password - pw = self.password or '' - filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + + if self.password: + # IMPORTANT: do NOT ever print or log raw password + pw = self.password + filtered_exc = filtered_exc.replace(pw, '*' * len(pw)) + logger.debug("Traceback %s: %s" % (error_ref, filtered_exc)) err_msg = """
Internal error while handling your request - please contact the site support @@ -489,6 +543,10 @@ def do_GET(self): def do_POST(self): """Handle all HTTP POST requests""" + + configuration = self.configuration + logger = self.logger + try: self.parsed_uri = urlparse(self.path) @@ -563,6 +621,11 @@ def handleAllow(self, query): Must verify user is already logged in or validate username/password pair against user DB. """ + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + # Use client address directly but with optional local proxy override hashed_secret = None exceeded_rate_limit = False @@ -570,7 +633,6 @@ def handleAllow(self, query): invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -628,7 +690,7 @@ def handleAllow(self, query): if request.idSelect(): # Do any ID expansion to a specified format if daemon_conf['expandusername']: - user_id = lookup_full_identity(query.get('identifier', '')) + user_id = lookup_full_identity(configuration, query.get('identifier', '')) else: user_id = query.get('identifier', '') identity = self.server.base_url + 'id/' + user_id @@ -654,7 +716,7 @@ def handleAllow(self, query): self.password = None account_accessible = check_account_accessible( - configuration, self.user, 'openid') + self.configuration, self.user, 'openid') # NOTE: returns None for invalid user, and boolean otherwise accepted = self.checkLogin(self.user, self.password, client_ip) if accepted is None: @@ -665,7 +727,7 @@ def handleAllow(self, query): # Update rate limits and write to auth log (authorized, _) = validate_auth_attempt( - configuration, + self.configuration, 'openid', 'password', self.user, @@ -721,17 +783,20 @@ def handleAllow(self, query): def setUser(self): """Read any saved user value from cookie""" + + logger = self.logger + cookies = self.headers.get('Cookie') # print "found cookies: %s" % cookies if cookies: - morsel = Cookie.BaseCookie(cookies).get('user') + morsel = BaseCookie(cookies).get('user') # Added morsel value check here since IE sends empty string from # cookie after initial user=;expire is sent. Others leave it out. if morsel and morsel.value != '': self.user = morsel.value expire = int(time.time() + self.session_ttl) - morsel = Cookie.BaseCookie(cookies).get('session_expire') + morsel = BaseCookie(cookies).get('session_expire') if morsel and morsel.value != '': # print "found user session_expire value: %s" % morsel.value if morsel.value.isdigit() and int(morsel.value) <= expire: @@ -754,6 +819,10 @@ def isAuthorized(self, identity_url, trust_root): def serverEndPoint(self, query): """End-point handler""" + + configuration = self.configuration + logger = self.logger + try: request = self.server.openid.decodeRequest(query) # Pass any errors from previous login attempts on for display @@ -794,9 +863,10 @@ def addSRegResponse(self, request, response): return sreg_req = sreg.SRegRequest.fromOpenIDRequest(request) - (username, user) = lookup_full_user(self.user) + (username, user) = lookup_full_user(self.configuration, self.user) if not user: + logger = self.logger logger.warning("addSRegResponse user lookup failed!") return @@ -835,6 +905,9 @@ def rejected(self, request, identifier=None): def handleCheckIDRequest(self, request): """Check ID handler""" + + logger = self.logger + logger.debug("handleCheckIDRequest with req %s" % request) is_authorized = self.isAuthorized(request.identity, request.trust_root) if is_authorized: @@ -851,6 +924,10 @@ def handleCheckIDRequest(self, request): def displayResponse(self, response): """Response helper""" + + configuration = self.configuration + logger = self.logger + try: webresponse = self.server.openid.encodeResponse(response) except server.EncodingError as why: @@ -864,7 +941,7 @@ def displayResponse(self, response): inconsistent session state.
%s-''' % (configuration.short_title, cgi.escape(safe_why))) +''' % (configuration.short_title, escape_html(safe_why))) return self.send_response(webresponse.code) @@ -883,6 +960,9 @@ def checkLogin(self, username, password, addr): None if no such user was found or username is invalid. """ + configuration = self.configuration + logger = self.logger + # Only need to update users here changed_users = [] if possible_user_id(configuration, username): @@ -943,13 +1023,17 @@ def checkLogin(self, username, password, addr): def doLogin(self): """Login handler""" + + configuration = self.configuration + daemon_conf = self.daemon_conf + logger = self.logger + hashed_secret = None exceeded_rate_limit = False invalid_username = False invalid_user = False account_accessible = False valid_password = False - daemon_conf = configuration.daemon_conf max_user_hits = daemon_conf['auth_limits']['max_user_hits'] user_abuse_hits = daemon_conf['auth_limits']['user_abuse_hits'] proto_abuse_hits = daemon_conf['auth_limits']['proto_abuse_hits'] @@ -1049,6 +1133,9 @@ def doLogin(self): def doLogout(self): """Logout handler""" + + logger = self.logger + logger.debug("logout clearing user %s" % self.user) self.clearUser() return_to_url = self.query.get('return_to', None) @@ -1058,7 +1145,9 @@ def doLogout(self): def redirect(self, url): """Redirect helper with built-in check for safe destination URL""" - if url and not check_local_site_url(configuration, url): + logger = self.logger + + if url and not check_local_site_url(self.configuration, url): logger.error("reject redirect to external URL %r" % url) self.send_response(400) else: @@ -1071,6 +1160,9 @@ def redirect(self, url): def writeUserHeader(self): """Response helper""" + + logger = self.logger + # NOTE: we added secure and httponly flags as suggested by OpenVAS # NOTE: we need to set empty user cookie for logout to work @@ -1104,7 +1196,7 @@ def showAboutPage(self): def link(url): url_attr = quoteattr(url) - url_text = cgi.escape(url) + url_text = escape_html(url) return '
%s
' % (url_attr, url_text)
def term(url, text):
@@ -1136,7 +1228,7 @@ def showPingPage(self):
def link(url):
url_attr = quoteattr(url)
- url_text = cgi.escape(url)
+ url_text = escape_html(url)
return '%s
' % (url_attr, url_text)
# IMPORTANT: This is the format availability checker looks for.
@@ -1189,6 +1281,9 @@ def showErrorPage(self, error_message, error_code=400):
def showDecidePage(self, request):
"""Decide page provider"""
+
+ configuration = self.configuration
+
id_url_base = self.server.base_url + 'id/'
# XXX: This may break if there are any synonyms for id_url_base,
# such as referring to it by IP address or a CNAME.
@@ -1331,6 +1426,9 @@ def showDecidePage(self, request):
def showIdPage(self, path):
"""User info page provider"""
+
+ logger = self.logger
+
link_tag = '' % \
self.server.base_url
yadis_loc_tag = '' % \
@@ -1345,7 +1443,7 @@ def showIdPage(self, path):
for (aident, trust_root) in self.server.approved:
if aident == ident:
trs = '