From 9bd423ced1d82c3f0f4a417144cbc7e8e3d2c67a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Wed, 22 Dec 2021 16:46:00 +0100 Subject: [PATCH 1/7] PoC filtering using OpenAPI spec --- jupyter_server/extension/application.py | 46 +++- jupyter_server/firewall.py | 75 ++++++ jupyter_server/serverapp.py | 36 +++ jupyter_server/tests/test_firewall.py | 318 ++++++++++++++++++++++++ setup.cfg | 5 +- 5 files changed, 470 insertions(+), 10 deletions(-) create mode 100644 jupyter_server/firewall.py create mode 100644 jupyter_server/tests/test_firewall.py diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index a0cd5c3551..ee98ea0ff0 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -1,6 +1,8 @@ import logging import re import sys +import typing +from typing import Iterable, Optional from jinja2 import Environment from jinja2 import FileSystemLoader @@ -96,7 +98,9 @@ def _prepare_templates(self): self.initialize_templates() # Add templates to web app settings if extension has templates. if len(self.template_paths) > 0: - self.settings.update({"{}_template_paths".format(self.name): self.template_paths}) + self.settings.update( + {"{}_template_paths".format(self.name): self.template_paths} + ) # Create a jinja environment for logging html templates. self.jinja2_env = Environment( @@ -136,6 +140,12 @@ class method. This method can be set as a entry_point in the extensions setup.py """ + # Filtering rules to apply on handlers registration + # Subclasses can override this list to filter handlers + # They will be applied on the ServerApp + __allowed_spec: Optional[dict] = None + __blocked_spec: Optional[dict] = None + # Subclasses should override this trait. Tells the server if # this extension allows other other extensions to be loaded # side-by-side when launched directly. @@ -180,6 +190,10 @@ def get_extension_package(cls): def get_extension_point(cls): return cls.__module__ + @classmethod + def get_firewall_rules(cls): + return {"allowed": cls.__allowed_spec, "blocked": cls.__blocked_spec} + # Extension URL sets the default landing page for this extension. extension_url = "/" @@ -240,7 +254,9 @@ def _default_static_url_prefix(self): ), ).tag(config=True) - settings = Dict(help=_i18n("""Settings that will passed to the server.""")).tag(config=True) + settings = Dict(help=_i18n("""Settings that will passed to the server.""")).tag( + config=True + ) handlers = List(help=_i18n("""Handlers appended to the server.""")).tag(config=True) @@ -333,7 +349,9 @@ def _prepare_handlers(self): def _prepare_templates(self): # Add templates to web app settings if extension has templates. if len(self.template_paths) > 0: - self.settings.update({"{}_template_paths".format(self.name): self.template_paths}) + self.settings.update( + {"{}_template_paths".format(self.name): self.template_paths} + ) self.initialize_templates() def _jupyter_server_config(self): @@ -452,7 +470,11 @@ def load_classic_server_extension(cls, serverapp): ( r"/static/favicons/favicon.ico", RedirectHandler, - {"url": url_path_join(serverapp.base_url, "static/base/images/favicon.ico")}, + { + "url": url_path_join( + serverapp.base_url, "static/base/images/favicon.ico" + ) + }, ), ( r"/static/favicons/favicon-busy-1.ico", @@ -495,7 +517,8 @@ def load_classic_server_extension(cls, serverapp): RedirectHandler, { "url": url_path_join( - serverapp.base_url, "static/base/images/favicon-notebook.ico" + serverapp.base_url, + "static/base/images/favicon-notebook.ico", ) }, ), @@ -504,14 +527,19 @@ def load_classic_server_extension(cls, serverapp): RedirectHandler, { "url": url_path_join( - serverapp.base_url, "static/base/images/favicon-terminal.ico" + serverapp.base_url, + "static/base/images/favicon-terminal.ico", ) }, ), ( r"/static/logo/logo.png", RedirectHandler, - {"url": url_path_join(serverapp.base_url, "static/base/images/logo.png")}, + { + "url": url_path_join( + serverapp.base_url, "static/base/images/logo.png" + ) + }, ), ] ) @@ -532,7 +560,9 @@ def initialize_server(cls, argv=[], load_other_extensions=True, **kwargs): jpserver_extensions.update(cls.serverapp_config["jpserver_extensions"]) cls.serverapp_config["jpserver_extensions"] = jpserver_extensions find_extensions = False - serverapp = ServerApp.instance(jpserver_extensions=jpserver_extensions, **kwargs) + serverapp = ServerApp.instance( + jpserver_extensions=jpserver_extensions, **kwargs + ) serverapp.aliases.update(cls.aliases) serverapp.initialize( argv=argv, diff --git a/jupyter_server/firewall.py b/jupyter_server/firewall.py new file mode 100644 index 0000000000..e4c8925971 --- /dev/null +++ b/jupyter_server/firewall.py @@ -0,0 +1,75 @@ +from typing import Optional, Union + +from openapi_core import create_spec +from tornado import httpclient, httputil +from tornado.log import access_log +from tornado_openapi3 import RequestValidator + + +class FireWall: + """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. + + If allowed and blocked specifications are defined, the request must be allowed and not blocked; + i.e. blocked specification takes precedence. + + Args: + allowed_spec: [optional] Allowed endpoints + blocked_spec: [optional] Blocked endpoints + """ + + def __init__(self, base_url: str, allowed_spec: Optional[dict], blocked_spec: Optional[dict]): + self.__allowed_validator: Optional[RequestValidator] = None + self.__blocked_validator: Optional[RequestValidator] = None + + def add_base_url_server(spec: dict): + servers = spec.get("servers", []) + if not any(map(lambda s: s.get("url") == base_url, servers)): + servers.append({ + "url": base_url + }) + spec["servers"] = servers + + if allowed_spec is not None: + add_base_url_server(allowed_spec) + self.__allowed_validator = RequestValidator(create_spec(allowed_spec)) + if blocked_spec is not None: + add_base_url_server(blocked_spec) + self.__blocked_validator = RequestValidator(create_spec(blocked_spec)) + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> bool: + """Validate a request against allowed and blocked specifications. + + Args: + request: Request to validate + Returns: + Whether the request is valid or not. + """ + allowed_result = ( + None + if self.__allowed_validator is None + else self.__allowed_validator.validate(request) + ) + + blocked_result = ( + None + if self.__blocked_validator is None + else self.__blocked_validator.validate(request) + ) + + allowed = (allowed_result is None or len(allowed_result.errors) == 0) + not_blocked = ( + blocked_result is None or len(blocked_result.errors) > 0 + ) + + # The error raised if this is not valid will be logged + # So we only give the reason in debug level + if (not (allowed and not_blocked)): + if(not allowed): + # Provides only the first error + access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") + elif (not not_blocked): + access_log.debug(f"Request blocked.") + + return allowed and not_blocked diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index ce4ec2b5d3..ceb343b989 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -25,9 +25,15 @@ import sys import threading import time +import typing import urllib import webbrowser from base64 import encodebytes +from typing import Iterable, Optional + +from tornado import httputil + +from jupyter_server.firewall import FireWall try: import resource @@ -222,7 +228,16 @@ def __init__( default_url, settings_overrides, jinja_env_options, + endpoints_filters = None ): + if endpoints_filters is None: + self.__firewall = FireWall(base_url, None, None) + else: + self.__firewall = FireWall( + base_url, + endpoints_filters.get('allowed'), + endpoints_filters.get('blocked') + ) settings = self.init_settings( jupyter_app, @@ -433,6 +448,14 @@ def init_handlers(self, default_services, settings): new_handlers.append((r"(.*)", Template404)) return new_handlers + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> "web._HandlerDelegate": + if self.__firewall.validate(request): + return super().find_handler(request, **kwargs) + else: + return self.get_handler_delegate(request, web.ErrorHandler, {"status_code": 403}) + def last_activity(self): """Get a UTC timestamp for when the server last did something. @@ -756,6 +779,11 @@ class ServerApp(JupyterApp): "view", ) + # Filtering rules to apply on handlers registration + # Subclasses can override this list to filter handlers + __allowed_spec: Optional[dict] = None + __blocked_spec: Optional[dict] = None + _log_formatter_cls = LogFormatter @default("log_level") @@ -1843,6 +1871,10 @@ def init_webapp(self): self.default_url, self.tornado_settings, self.jinja_environment_options, + endpoints_filters={ + "allowed": self.__allowed_spec, + "blocked": self.__blocked_spec + } ) if self.certfile: self.ssl_options["certfile"] = self.certfile @@ -2316,6 +2348,10 @@ def initialize( # Set starter_app property. if point.app: self._starter_app = point.app + # Apply endpoint filters from the extension app + firewall_rules = point.app.get_firewall_rules() + self.__allowed_spec = firewall_rules["allowed"] + self.__blocked_spec = firewall_rules["blocked"] # Load any configuration that comes from the Extension point. self.update_config(Config(point.config)) diff --git a/jupyter_server/tests/test_firewall.py b/jupyter_server/tests/test_firewall.py new file mode 100644 index 0000000000..389ba25424 --- /dev/null +++ b/jupyter_server/tests/test_firewall.py @@ -0,0 +1,318 @@ +import logging + +import pytest +from tornado.httputil import HTTPServerRequest, HTTPHeaders +from tornado.log import access_log +from jupyter_server.firewall import FireWall + +access_log.setLevel(logging.DEBUG) + +spec1 = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Pet"} + }, + }, + "required": True, + }, + "responses": {"400": {"description": ""}}, + } + }, + "/pet/findByTags": { + "get": { + "parameters": [ + { + "name": "tags", + "in": "query", + "description": "Tags to filter by", + "required": True, + "style": "form", + "schema": {"type": "array", "items": {"type": "string"}}, + } + ], + "responses": { + "200": { + "description": "successful operation", + } + }, + } + }, + "/pet/{petId}": { + "get": { + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet to return", + "required": True, + "schema": {"type": "integer", "format": "int64"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + "components": { + "schemas": { + "Pet": { + "required": ["name"], + "type": "object", + "properties": { + "id": {"type": "integer", "format": "int64"}, + "name": {"type": "string", "example": "doggie"}, + }, + }, + } + }, +} + +spec2 = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "responses": {"400": {"description": ""}}, + } + }, + "/user/{anyId}": { + "get": { + "parameters": [ + { + "name": "anyId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + } + }, + }, +} + + +def format_id(val): + if isinstance(val, HTTPServerRequest): + return f"{val.method}-{val.uri}-{val.body.decode('utf-8') or None}" + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + ( + "/", + HTTPServerRequest( + "put", + "/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + ), + ( + "/dummy/base_url/", + HTTPServerRequest( + "put", + "/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + ), + # No body + ( + "/", + HTTPServerRequest( + "put", + "/pet", + body=b"", + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + # Wrong body + ( + "/", + HTTPServerRequest( + "put", + "/pet", + body=b'{"id":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + # Method not allowed + ( + "/", + HTTPServerRequest( + "get", + "/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + ( + "/", + HTTPServerRequest( + "get", + "/pet/22", + host="localhost:8888", + ), + True, + ), + # Wrong path parameter + ( + "/", + HTTPServerRequest( + "get", + "/pet/hello", + host="localhost:8888", + ), + False, + ), + ( + "/", + HTTPServerRequest( + "get", + "/pet/findByTags?tags=cat", + host="localhost:8888", + ), + True, + ), + # Missing query parameter + ( + "/", + HTTPServerRequest( + "get", + "/pet/findByTags", + host="localhost:8888", + ), + False, + ), + ), + ids=format_id, +) +def test_Firewall_allowed_spec(base_url, server_request, expected): + firewall = FireWall(base_url, spec1, None) + + assert firewall.validate(server_request) == expected + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + ( + "/", + HTTPServerRequest( + "put", + "/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + ( + "/dummy/base_url/", + HTTPServerRequest( + "put", + "/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + # No body + ( + "/", + HTTPServerRequest( + "put", + "/pet", + body=b"", + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + ), + ( + "/", + HTTPServerRequest( + "put", + "/pet?tags=22", + host="localhost:8888", + ), + False, + ), + # Other method + ( + "/", + HTTPServerRequest( + "get", + "/pet", + body=b'{"name":"puppy"}', + headers=HTTPHeaders({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + ), + ( + "/", + HTTPServerRequest( + "put", + "/pet/22", + host="localhost:8888", + ), + True, + ), + ( + "/", + HTTPServerRequest( + "get", + "/user?id=42", + host="localhost:8888", + ), + True, + ), + ( + "/", + HTTPServerRequest( + "get", + "/user/john/smith", + host="localhost:8888", + ), + False, + ), + ( + "/", + HTTPServerRequest( + "get", + "/user/william?page=42", + host="localhost:8888", + ), + False, + ), + ), + ids=format_id, +) +def test_Firewall_blocked_spec(base_url, server_request, expected): + firewall = FireWall(base_url, None, spec2) + + assert firewall.validate(server_request) == expected + + +def test_Firewall_allowed_and_blocked_spec(): + pass diff --git a/setup.cfg b/setup.cfg index 123c82e102..fb5a9e855e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,19 +16,20 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: BSD License Programming Language :: Python - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 [options] zip_safe = False include_package_data = True packages = find: -python_requires = >=3.6 +python_requires = >=3.7 install_requires = jinja2 tornado>=6.1.0 + tornado_openapi3 pyzmq>=17 argon2-cffi ipython_genutils From ed715ce62450d5bf3c6a6b73932851545a3efc4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Tue, 28 Dec 2021 12:29:21 +0100 Subject: [PATCH 2/7] Apply review comments and add slash encoder --- jupyter_server/extension/application.py | 4 +- jupyter_server/firewall.py | 75 --- jupyter_server/serverapp.py | 17 +- jupyter_server/specvalidator.py | 229 +++++++++ jupyter_server/tests/test_firewall.py | 318 ------------- jupyter_server/tests/test_specvalidator.py | 527 +++++++++++++++++++++ setup.cfg | 10 +- 7 files changed, 775 insertions(+), 405 deletions(-) delete mode 100644 jupyter_server/firewall.py create mode 100644 jupyter_server/specvalidator.py delete mode 100644 jupyter_server/tests/test_firewall.py create mode 100644 jupyter_server/tests/test_specvalidator.py diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index ee98ea0ff0..92687f5773 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -143,8 +143,8 @@ class method. This method can be set as a entry_point in # Filtering rules to apply on handlers registration # Subclasses can override this list to filter handlers # They will be applied on the ServerApp - __allowed_spec: Optional[dict] = None - __blocked_spec: Optional[dict] = None + _allowed_spec: Optional[dict] = None + _blocked_spec: Optional[dict] = None # Subclasses should override this trait. Tells the server if # this extension allows other other extensions to be loaded diff --git a/jupyter_server/firewall.py b/jupyter_server/firewall.py deleted file mode 100644 index e4c8925971..0000000000 --- a/jupyter_server/firewall.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional, Union - -from openapi_core import create_spec -from tornado import httpclient, httputil -from tornado.log import access_log -from tornado_openapi3 import RequestValidator - - -class FireWall: - """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. - - If allowed and blocked specifications are defined, the request must be allowed and not blocked; - i.e. blocked specification takes precedence. - - Args: - allowed_spec: [optional] Allowed endpoints - blocked_spec: [optional] Blocked endpoints - """ - - def __init__(self, base_url: str, allowed_spec: Optional[dict], blocked_spec: Optional[dict]): - self.__allowed_validator: Optional[RequestValidator] = None - self.__blocked_validator: Optional[RequestValidator] = None - - def add_base_url_server(spec: dict): - servers = spec.get("servers", []) - if not any(map(lambda s: s.get("url") == base_url, servers)): - servers.append({ - "url": base_url - }) - spec["servers"] = servers - - if allowed_spec is not None: - add_base_url_server(allowed_spec) - self.__allowed_validator = RequestValidator(create_spec(allowed_spec)) - if blocked_spec is not None: - add_base_url_server(blocked_spec) - self.__blocked_validator = RequestValidator(create_spec(blocked_spec)) - - def validate( - self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] - ) -> bool: - """Validate a request against allowed and blocked specifications. - - Args: - request: Request to validate - Returns: - Whether the request is valid or not. - """ - allowed_result = ( - None - if self.__allowed_validator is None - else self.__allowed_validator.validate(request) - ) - - blocked_result = ( - None - if self.__blocked_validator is None - else self.__blocked_validator.validate(request) - ) - - allowed = (allowed_result is None or len(allowed_result.errors) == 0) - not_blocked = ( - blocked_result is None or len(blocked_result.errors) > 0 - ) - - # The error raised if this is not valid will be logged - # So we only give the reason in debug level - if (not (allowed and not_blocked)): - if(not allowed): - # Provides only the first error - access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") - elif (not not_blocked): - access_log.debug(f"Request blocked.") - - return allowed and not_blocked diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index ceb343b989..4b0403d668 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -33,7 +33,10 @@ from tornado import httputil -from jupyter_server.firewall import FireWall +try: + from jupyter_server.specvalidator import SpecValidator +except ImportError: + SpecValidator = None try: import resource @@ -230,10 +233,10 @@ def __init__( jinja_env_options, endpoints_filters = None ): - if endpoints_filters is None: - self.__firewall = FireWall(base_url, None, None) + if SpecValidator is None or endpoints_filters is None: + self.__requestValidator: Optional[SpecValidator] = None else: - self.__firewall = FireWall( + self.__requestValidator: Optional[SpecValidator] = SpecValidator( base_url, endpoints_filters.get('allowed'), endpoints_filters.get('blocked') @@ -451,7 +454,7 @@ def init_handlers(self, default_services, settings): def find_handler( self, request: httputil.HTTPServerRequest, **kwargs: Any ) -> "web._HandlerDelegate": - if self.__firewall.validate(request): + if self.__requestValidator.validate(request): return super().find_handler(request, **kwargs) else: return self.get_handler_delegate(request, web.ErrorHandler, {"status_code": 403}) @@ -781,8 +784,8 @@ class ServerApp(JupyterApp): # Filtering rules to apply on handlers registration # Subclasses can override this list to filter handlers - __allowed_spec: Optional[dict] = None - __blocked_spec: Optional[dict] = None + _allowed_spec: Optional[dict] = None + _blocked_spec: Optional[dict] = None _log_formatter_cls = LogFormatter diff --git a/jupyter_server/specvalidator.py b/jupyter_server/specvalidator.py new file mode 100644 index 0000000000..9dc6b5c3e9 --- /dev/null +++ b/jupyter_server/specvalidator.py @@ -0,0 +1,229 @@ +import itertools +import re +from typing import Iterable, Optional, Union +from urllib.parse import parse_qsl + +from openapi_core import create_spec +from openapi_core.validation.request import validators +from openapi_core.validation.request.datatypes import ( + OpenAPIRequest, + RequestParameters, + RequestValidationResult, +) +from openapi_spec_validator.handlers import base +from tornado import httpclient, httputil +from tornado.log import access_log +from tornado_openapi3.util import parse_mimetype +from werkzeug.datastructures import Headers, ImmutableMultiDict + + +def encode_slash(regex: Iterable[Union[str, re.Pattern]], path: str) -> str: + """Encode slash (``%2F``) for regex groups found in URL path (it never contains the query arguments). + + Note: + The regex order matters as only the first regex matching the path will be applied. + + Args: + regex: List of regex to test the path against + path: URL path to escape + + Returns: + Escaped path + """ + regex = [re.compile(r) if isinstance(r, str) else r for r in regex] + for r in regex: + m = re.search(r, path) + if m is not None and m.lastindex is not None: + # Start from latest group to the first one so that position is preserved + # Skip the first group as it matches the full path not only the group to be encoded + for index in range(m.lastindex, 0, -1): + path = ( + path[: m.start(index)] + + path[m.start(index) : m.end(index)].replace("/", r"%2F") + + path[m.end(index) :] + ) + # Break at first match + break + + return path + + +# Ref: https://github.com/correl/tornado-openapi3/blob/master/tornado_openapi3/requests.py +class TornadoRequestFactory: + """Factory for converting Tornado requests to OpenAPI request objects.""" + + @classmethod + def create( + cls, + request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest], + encoded_slash_regex: Iterable[re.Pattern], + ) -> OpenAPIRequest: + """Creates an OpenAPI request from Tornado request objects. + + Supports both :class:`tornado.httpclient.HTTPRequest` and + :class:`tornado.httputil.HTTPServerRequest` objects. + """ + if isinstance(request, httpclient.HTTPRequest): + if request.url: + path, _, querystring = request.url.partition("?") + query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict( + parse_qsl(querystring) + ) + else: + path = "" + query_arguments = ImmutableMultiDict() + else: + path, _, _ = request.full_url().partition("?") + if path == "://": + path = "" + query_arguments = ImmutableMultiDict( + itertools.chain( + *[ + [(k, v.decode("utf-8")) for v in vs] + for k, vs in request.query_arguments.items() + ] + ) + ) + + # Encode slashes in path to be compliant with Open API specification + # e.g. /api/contents/path/to/file.txt -> /api/contents/path%2Fto%2Ffile.txt + path = encode_slash(encoded_slash_regex, path) + + return OpenAPIRequest( + full_url_pattern=path, + method=request.method.lower() if request.method else "get", + parameters=RequestParameters( + query=query_arguments, + header=Headers(request.headers.get_all()), + cookie=httputil.parse_cookie(request.headers.get("Cookie", "")), + ), + body=request.body if request.body else b"", + mimetype=parse_mimetype( + request.headers.get("Content-Type", "application/x-www-form-urlencoded") + ), + ) + + +class RequestValidator(validators.RequestValidator): + """Validator for Tornado HTTP Requests. + + Args: + base_url: [optional] Server base URL + custom_formatters: [optional] + custom_media_type_deserializers: [optional] + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped + """ + + def __init__( + self, + spec, + base_url: Optional[str] = None, + custom_formatters=None, + custom_media_type_deserializers=None, + encoded_slash_regex: Optional[Iterable[re.Pattern]] = None, + ): + super().__init__( + spec, + base_url=base_url, + custom_formatters=custom_formatters, + custom_media_type_deserializers=custom_media_type_deserializers, + ) + self.__encoded_slash_regex = encoded_slash_regex + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> RequestValidationResult: + """Validate a Tornado HTTP request object.""" + return super().validate( + TornadoRequestFactory.create(request, self.__encoded_slash_regex) + ) + + +class SpecValidator: + """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. + + If allowed and blocked specifications are defined, the request must be allowed and not blocked; + i.e. blocked specification takes precedence. + + Note: + OpenAPI does not accept path argument containing ``/``. Therefore you should provide regex to encode + them; e.g. ``"/api/contents/([^/]+(?:/[^/]+)*?)$"`` to match ``/api/contents/{path}``. The order in + which you provide the regex are important as only the first regex matching the path will be applied. + You can test your expression using :ref:`jupyter_server.specvalidator.encode_slash`. + + Args: + base_url: [optional] Server base URL + allowed_spec: [optional] Allowed endpoints + blocked_spec: [optional] Blocked endpoints + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped + """ + + def __init__( + self, + base_url: Optional[str] = None, + allowed_spec: Optional[dict] = None, + blocked_spec: Optional[dict] = None, + encoded_slash_regex: Optional[Iterable[str]] = None, + ): + self.__allowed_validator: Optional[RequestValidator] = None + self.__blocked_validator: Optional[RequestValidator] = None + slash_regex: Iterable[re.Pattern] = list( + map(lambda s: re.compile(s), encoded_slash_regex or []) + ) + + def add_base_url_server(spec: dict): + servers = spec.get("servers", []) + if not any(map(lambda s: s.get("url") == base_url, servers)): + servers.append({"url": base_url}) + spec["servers"] = servers + + if allowed_spec is not None: + if base_url is not None: + add_base_url_server(allowed_spec) + self.__allowed_validator = RequestValidator( + create_spec(allowed_spec), + encoded_slash_regex=slash_regex, + ) + if blocked_spec is not None: + if base_url is not None: + add_base_url_server(blocked_spec) + self.__blocked_validator = RequestValidator( + create_spec(blocked_spec), + encoded_slash_regex=slash_regex, + ) + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> bool: + """Validate a request against allowed and blocked specifications. + + Args: + request: Request to validate + Returns: + Whether the request is valid or not. + """ + allowed_result = ( + None + if self.__allowed_validator is None + else self.__allowed_validator.validate(request) + ) + + blocked_result = ( + None + if self.__blocked_validator is None + else self.__blocked_validator.validate(request) + ) + + allowed = allowed_result is None or len(allowed_result.errors) == 0 + not_blocked = blocked_result is None or len(blocked_result.errors) > 0 + + # The error raised if this is not valid will be logged + # So we only give the reason in debug level + if not (allowed and not_blocked): + if not allowed: + # Provides only the first error + access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") + elif not not_blocked: + access_log.debug(f"Request blocked.") + + return allowed and not_blocked diff --git a/jupyter_server/tests/test_firewall.py b/jupyter_server/tests/test_firewall.py deleted file mode 100644 index 389ba25424..0000000000 --- a/jupyter_server/tests/test_firewall.py +++ /dev/null @@ -1,318 +0,0 @@ -import logging - -import pytest -from tornado.httputil import HTTPServerRequest, HTTPHeaders -from tornado.log import access_log -from jupyter_server.firewall import FireWall - -access_log.setLevel(logging.DEBUG) - -spec1 = { - "openapi": "3.0.1", - "info": {"title": "Test specs", "version": "0.0.1"}, - "paths": { - "/pet": { - "put": { - "requestBody": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/Pet"} - }, - }, - "required": True, - }, - "responses": {"400": {"description": ""}}, - } - }, - "/pet/findByTags": { - "get": { - "parameters": [ - { - "name": "tags", - "in": "query", - "description": "Tags to filter by", - "required": True, - "style": "form", - "schema": {"type": "array", "items": {"type": "string"}}, - } - ], - "responses": { - "200": { - "description": "successful operation", - } - }, - } - }, - "/pet/{petId}": { - "get": { - "parameters": [ - { - "name": "petId", - "in": "path", - "description": "ID of pet to return", - "required": True, - "schema": {"type": "integer", "format": "int64"}, - } - ], - "responses": {"200": {"description": ""}}, - }, - }, - }, - "components": { - "schemas": { - "Pet": { - "required": ["name"], - "type": "object", - "properties": { - "id": {"type": "integer", "format": "int64"}, - "name": {"type": "string", "example": "doggie"}, - }, - }, - } - }, -} - -spec2 = { - "openapi": "3.0.1", - "info": {"title": "Test specs", "version": "0.0.1"}, - "paths": { - "/pet": { - "put": { - "responses": {"400": {"description": ""}}, - } - }, - "/user/{anyId}": { - "get": { - "parameters": [ - { - "name": "anyId", - "in": "path", - "required": True, - "schema": {"type": "string"}, - } - ], - "responses": {"200": {"description": ""}}, - } - }, - }, -} - - -def format_id(val): - if isinstance(val, HTTPServerRequest): - return f"{val.method}-{val.uri}-{val.body.decode('utf-8') or None}" - - -@pytest.mark.parametrize( - "base_url, server_request, expected", - ( - ( - "/", - HTTPServerRequest( - "put", - "/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - True, - ), - ( - "/dummy/base_url/", - HTTPServerRequest( - "put", - "/dummy/base_url/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - True, - ), - # No body - ( - "/", - HTTPServerRequest( - "put", - "/pet", - body=b"", - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - # Wrong body - ( - "/", - HTTPServerRequest( - "put", - "/pet", - body=b'{"id":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - # Method not allowed - ( - "/", - HTTPServerRequest( - "get", - "/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - ( - "/", - HTTPServerRequest( - "get", - "/pet/22", - host="localhost:8888", - ), - True, - ), - # Wrong path parameter - ( - "/", - HTTPServerRequest( - "get", - "/pet/hello", - host="localhost:8888", - ), - False, - ), - ( - "/", - HTTPServerRequest( - "get", - "/pet/findByTags?tags=cat", - host="localhost:8888", - ), - True, - ), - # Missing query parameter - ( - "/", - HTTPServerRequest( - "get", - "/pet/findByTags", - host="localhost:8888", - ), - False, - ), - ), - ids=format_id, -) -def test_Firewall_allowed_spec(base_url, server_request, expected): - firewall = FireWall(base_url, spec1, None) - - assert firewall.validate(server_request) == expected - - -@pytest.mark.parametrize( - "base_url, server_request, expected", - ( - ( - "/", - HTTPServerRequest( - "put", - "/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - ( - "/dummy/base_url/", - HTTPServerRequest( - "put", - "/dummy/base_url/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - # No body - ( - "/", - HTTPServerRequest( - "put", - "/pet", - body=b"", - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - False, - ), - ( - "/", - HTTPServerRequest( - "put", - "/pet?tags=22", - host="localhost:8888", - ), - False, - ), - # Other method - ( - "/", - HTTPServerRequest( - "get", - "/pet", - body=b'{"name":"puppy"}', - headers=HTTPHeaders({"Content-Type": "application/json"}), - host="localhost:8888", - ), - True, - ), - ( - "/", - HTTPServerRequest( - "put", - "/pet/22", - host="localhost:8888", - ), - True, - ), - ( - "/", - HTTPServerRequest( - "get", - "/user?id=42", - host="localhost:8888", - ), - True, - ), - ( - "/", - HTTPServerRequest( - "get", - "/user/john/smith", - host="localhost:8888", - ), - False, - ), - ( - "/", - HTTPServerRequest( - "get", - "/user/william?page=42", - host="localhost:8888", - ), - False, - ), - ), - ids=format_id, -) -def test_Firewall_blocked_spec(base_url, server_request, expected): - firewall = FireWall(base_url, None, spec2) - - assert firewall.validate(server_request) == expected - - -def test_Firewall_allowed_and_blocked_spec(): - pass diff --git a/jupyter_server/tests/test_specvalidator.py b/jupyter_server/tests/test_specvalidator.py new file mode 100644 index 0000000000..67f85a4a3d --- /dev/null +++ b/jupyter_server/tests/test_specvalidator.py @@ -0,0 +1,527 @@ +import logging + +import pytest +from tornado.httputil import HTTPServerRequest, HTTPHeaders +from tornado.log import access_log +from jupyter_server.specvalidator import SpecValidator, encode_slash + +pytest.importorskip("tornado_openapi3") + +access_log.setLevel(logging.DEBUG) + +allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Pet"} + }, + }, + "required": True, + }, + "responses": {"400": {"description": ""}}, + } + }, + "/pet/findByTags": { + "get": { + "parameters": [ + { + "name": "tags", + "in": "query", + "description": "Tags to filter by", + "required": True, + "style": "form", + "schema": {"type": "array", "items": {"type": "string"}}, + } + ], + "responses": { + "200": { + "description": "successful operation", + } + }, + } + }, + "/pet/{petId}": { + "get": { + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet to return", + "required": True, + "schema": {"type": "integer", "format": "int64"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + "components": { + "schemas": { + "Pet": { + "required": ["name"], + "type": "object", + "properties": { + "id": {"type": "integer", "format": "int64"}, + "name": {"type": "string", "example": "doggie"}, + }, + }, + } + }, +} + +blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "responses": {"400": {"description": ""}}, + } + }, + "/user/{anyId}": { + "get": { + "parameters": [ + { + "name": "anyId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + } + }, + }, +} + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + id="Default case", + ), + pytest.param( + "/dummy/base_url/", + dict( + method="put", + uri="/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + id="Non default base URL", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b"", + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="No body", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"id":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Wrong body", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Method not allowed", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/22", + host="localhost:8888", + ), + True, + id="Path argument", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/hello", + host="localhost:8888", + ), + False, + id="Wrong path parameter", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/findByTags?tags=cat", + host="localhost:8888", + ), + True, + id="Query parameters", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/findByTags", + host="localhost:8888", + ), + False, + id="Missing query parameter", + ), + ), +) +def test_SpecValidator_allowed_spec(base_url, server_request, expected): + validator = SpecValidator(base_url, allowed_spec, None) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate( + HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) + ) + == expected + ) + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Default case", + ), + pytest.param( + "/dummy/base_url/", + dict( + method="put", + uri="/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Non default base URL", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b"", + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Missing body", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet?tags=22", + host="localhost:8888", + ), + False, + id="Query argument", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + id="Non-blocked method", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet/22", + host="localhost:8888", + ), + True, + ), + pytest.param( + "/", + dict( + method="get", + uri="/user?id=42", + host="localhost:8888", + ), + True, + ), + pytest.param( + "/", + dict( + method="get", + uri="/user/john/smith", + host="localhost:8888", + ), + True, + id="Sub path is allowed", + ), + pytest.param( + "/", + dict( + method="get", + uri="/user/william?page=42", + host="localhost:8888", + ), + False, + id="Blocked path with query argument", + ), + ), +) +def test_SpecValidator_blocked_spec(base_url, server_request, expected): + validator = SpecValidator(base_url, None, blocked_spec) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate( + HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) + ) + == expected + ) + + +@pytest.mark.parametrize( + "server_request, expected", + ( + pytest.param( + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Blocked although allowed", + ), + pytest.param( + dict( + method="get", + uri="/pet/22", + host="localhost:8888", + ), + True, + id="Allowed and not part of blocked", + ), + ), +) +def test_SpecValidator_allowed_and_blocked_spec(server_request, expected): + validator = SpecValidator(allowed_spec=allowed_spec, blocked_spec=blocked_spec) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate( + HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) + ) + == expected + ) + + +@pytest.mark.parametrize( + "server_request, expected", + ( + pytest.param( + dict( + method="get", + uri="/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + host="localhost:8888", + ), + True, + id="Checkpoint path", + ), + pytest.param( + dict( + method="get", + uri="/api/contents/path/to/file.txt", + host="localhost:8888", + ), + True, + id="Content path", + ), + pytest.param( + dict( + method="get", + uri="/api/contents/path/to/file.txt?tags=dummy", + host="localhost:8888", + ), + True, + id="Encode with query arguments", + ), + pytest.param( + dict( + method="get", + uri="/api/sessions/path/to/file.txt", + host="localhost:8888", + ), + False, + id="Session Id does not support slashes", + ), + ), +) +def test_SpecValidator_encoded_slash(server_request, expected): + validator = SpecValidator( + allowed_spec={ + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints/{checkpointId}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }, + { + "name": "checkpointId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }, + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + }, + encoded_slash_regex=( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + # Test regex without group match + r"/api/sessions", + ), + ) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate( + HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) + ) + == expected + ) + + +@pytest.mark.parametrize( + "regex, test, expected", + ( + pytest.param( + (r"/api/contents/([^/]+(?:(?:/[^/]+)*?))$",), + "/api/contents/path/to/file.txt", + "/api/contents/path%2Fto%2Ffile.txt", + id="Path to encode", + ), + pytest.param( + ( + r"/api/contents/([^/]+(?:(?:/[^/]+)*))$", + r"/api/contents/([^/]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + ), + "/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + r"/api/contents/path%2Fto%2Ffile.txt%2Fcheckpoints%2F.ipynb_checkpoints%2Ffile.txt", + id="Bad regex order", + ), + pytest.param( + ( + r"/api/contents/([^/]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + r"/api/contents/([^/]+(?:(?:/[^/]+)*))$", + ), + "/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + r"/api/contents/path%2Fto%2Ffile.txt/checkpoints/.ipynb_checkpoints%2Ffile.txt", + id="Multiple regex", + ), + pytest.param( + (r"/api/contents/([^/]+(?:(?:/[^/]+)*?))$",), + "/api/sessions/path/to/file.txt", + "/api/sessions/path/to/file.txt", + id="No match", + ), + pytest.param( + (r"/api/sessions",), + "/api/sessions/path/to/file.txt", + "/api/sessions/path/to/file.txt", + id="No group", + ), + ), +) +def test_encode_slash(regex, test, expected): + assert encode_slash(regex, test) == expected diff --git a/setup.cfg b/setup.cfg index fb5a9e855e..65f999ea91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,20 +16,19 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: BSD License Programming Language :: Python + Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 [options] zip_safe = False include_package_data = True packages = find: -python_requires = >=3.7 +python_requires = >=3.6 install_requires = jinja2 tornado>=6.1.0 - tornado_openapi3 pyzmq>=17 argon2-cffi ipython_genutils @@ -45,7 +44,12 @@ install_requires = websocket-client [options.extras_require] +validation = + openapi-core>=0.14.0,<0.15.0 + tornado_openapi3>=1.1.0,<2.0.0 test = + openapi-core>=0.14.0,<0.15.0 + tornado_openapi3>=1.1.0,<2.0.0 coverage pytest>=6.0 pytest-cov From 6b739c9e09bdd2bbaa883112f7a23df58a523ea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Tue, 28 Dec 2021 16:38:33 +0100 Subject: [PATCH 3/7] Add application unit tests --- .github/workflows/python-linux.yml | 22 ++ jupyter_server/extension/application.py | 6 +- jupyter_server/serverapp.py | 32 ++- .../tests/extension/mockextensions/app.py | 63 ++++- jupyter_server/tests/extension/test_launch.py | 34 +++ jupyter_server/tests/test_spec_validation.py | 222 ++++++++++++++++++ jupyter_server/tests/test_specvalidator.py | 4 +- 7 files changed, 366 insertions(+), 17 deletions(-) create mode 100644 jupyter_server/tests/test_spec_validation.py diff --git a/.github/workflows/python-linux.yml b/.github/workflows/python-linux.yml index 7402aba1bd..d3356c3d89 100644 --- a/.github/workflows/python-linux.yml +++ b/.github/workflows/python-linux.yml @@ -74,3 +74,25 @@ jobs: pushd test_install ./bin/pytest --pyargs jupyter_server popd + + no-tornado-openapi3: + runs-on: ${{ matrix.os }}-latest + strategy: + fail-fast: false + matrix: + os: [ubuntu] + python-version: ["3.10"] + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Base Setup + uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + - name: Install the Python dependencies + run: | + pip install -e ".[test]" + # Remove optional dependency + pip uninstall -y tornado-openapi3 + - name: Run some tests + run: | + # Use that files as it contains some tests that should pass and other that needs to be skipped + pytest -vv jupyter_server/tests/extension/test_launch.py diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index 92687f5773..4b86cc74dd 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -1,7 +1,6 @@ import logging import re import sys -import typing from typing import Iterable, Optional from jinja2 import Environment @@ -145,6 +144,7 @@ class method. This method can be set as a entry_point in # They will be applied on the ServerApp _allowed_spec: Optional[dict] = None _blocked_spec: Optional[dict] = None + _slash_encoder: Optional[Iterable[str]] = None # Subclasses should override this trait. Tells the server if # this extension allows other other extensions to be loaded @@ -191,8 +191,8 @@ def get_extension_point(cls): return cls.__module__ @classmethod - def get_firewall_rules(cls): - return {"allowed": cls.__allowed_spec, "blocked": cls.__blocked_spec} + def get_openapi3_spec_rules(cls): + return {"allowed": cls._allowed_spec, "blocked": cls._blocked_spec, "slash_encoder": cls._slash_encoder} # Extension URL sets the default landing page for this extension. extension_url = "/" diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 4b0403d668..8262984b16 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -25,7 +25,6 @@ import sys import threading import time -import typing import urllib import webbrowser from base64 import encodebytes @@ -231,15 +230,21 @@ def __init__( default_url, settings_overrides, jinja_env_options, - endpoints_filters = None + spec_validators = None ): - if SpecValidator is None or endpoints_filters is None: - self.__requestValidator: Optional[SpecValidator] = None + if SpecValidator is None or spec_validators is None: + class DummyValidator: + """Dummy request validator that is always valid.""" + def validate(self, request): + return True + + self.__requestValidator: Optional[SpecValidator] = DummyValidator() else: self.__requestValidator: Optional[SpecValidator] = SpecValidator( base_url, - endpoints_filters.get('allowed'), - endpoints_filters.get('blocked') + spec_validators.get("allowed"), + spec_validators.get("blocked"), + spec_validators.get("slash_encoder") ) settings = self.init_settings( @@ -786,6 +791,7 @@ class ServerApp(JupyterApp): # Subclasses can override this list to filter handlers _allowed_spec: Optional[dict] = None _blocked_spec: Optional[dict] = None + _slash_encoder: Optional[Iterable[str]] = None _log_formatter_cls = LogFormatter @@ -1874,9 +1880,10 @@ def init_webapp(self): self.default_url, self.tornado_settings, self.jinja_environment_options, - endpoints_filters={ - "allowed": self.__allowed_spec, - "blocked": self.__blocked_spec + spec_validators={ + "allowed": self._allowed_spec, + "blocked": self._blocked_spec, + "slash_encoder": self._slash_encoder } ) if self.certfile: @@ -2352,9 +2359,10 @@ def initialize( if point.app: self._starter_app = point.app # Apply endpoint filters from the extension app - firewall_rules = point.app.get_firewall_rules() - self.__allowed_spec = firewall_rules["allowed"] - self.__blocked_spec = firewall_rules["blocked"] + spec_rules = point.app.get_openapi3_spec_rules() + self._allowed_spec = spec_rules["allowed"] + self._blocked_spec = spec_rules["blocked"] + self._slash_encoder = spec_rules["slash_encoder"] # Load any configuration that comes from the Extension point. self.update_config(Config(point.config)) diff --git a/jupyter_server/tests/extension/mockextensions/app.py b/jupyter_server/tests/extension/mockextensions/app.py index 7045417b23..d067901dcc 100644 --- a/jupyter_server/tests/extension/mockextensions/app.py +++ b/jupyter_server/tests/extension/mockextensions/app.py @@ -38,9 +38,70 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): loaded = False serverapp_config = { - "jpserver_extensions": {"jupyter_server.tests.extension.mockextensions.mock1": True} + "jpserver_extensions": { + "jupyter_server.tests.extension.mockextensions.mock1": True + } } + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + "/mock": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + "/mock_blocked/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/mock_template": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/mock_blocked/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = ( + r"/mocked_blocked/([^/?]+(?:(?:/[^/]+)*?))$", + ) + @staticmethod def get_extension_package(): return "jupyter_server.tests.extension.mockextensions" diff --git a/jupyter_server/tests/extension/test_launch.py b/jupyter_server/tests/extension/test_launch.py index e5cc12e7a7..72013dca35 100644 --- a/jupyter_server/tests/extension/test_launch.py +++ b/jupyter_server/tests/extension/test_launch.py @@ -11,6 +11,11 @@ import pytest import requests +try: + import tornado_openapi3 +except ImportError: + tornado_openapi3 = None + HERE = os.path.dirname(os.path.abspath(__file__)) @@ -109,3 +114,32 @@ def test_token_file(launch_instance, fetch, token): del os.environ["JUPYTER_TOKEN_FILE"] token_file.unlink() assert r.status_code == 200 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +def test_request_not_allowed(launch_instance, fetch): + launch_instance() + + r = fetch("/mocka") + + assert r.status_code == 403 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +def test_request_allowed(launch_instance, fetch): + launch_instance() + + r = fetch("/mock") + + assert r.status_code == 200 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +async def test_ServerApp_with_spec_validation_blocked_with_encoded_path( + launch_instance, fetch +): + launch_instance() + + r = fetch("/mock_blocked/path/dummy.txt") + + assert r.status_code == 403 diff --git a/jupyter_server/tests/test_spec_validation.py b/jupyter_server/tests/test_spec_validation.py new file mode 100644 index 0000000000..c0a8471472 --- /dev/null +++ b/jupyter_server/tests/test_spec_validation.py @@ -0,0 +1,222 @@ +import json +import logging +import os +from binascii import hexlify + +import pytest +from tornado.httpclient import HTTPClientError +from traitlets.config import Config +from jupyter_server.serverapp import ServerApp + + +pytest.importorskip("tornado_openapi3") + + +@pytest.fixture(scope="function") +def jp_serverapp( + jp_ensure_app_fixture, + jp_server_config, + jp_argv, + jp_nbconvert_templates, # this fixture must preceed jp_environ + jp_environ, + jp_http_port, + jp_base_url, + tmp_path, + jp_root_dir, + io_loop, + jp_logging_stream, +): + """Starts a Jupyter Server instance with endpoint specifications based on the established configuration values. + + It overrides the default fixture to define a server with spec validation. + """ + + class ServerAppWithSpec(ServerApp): + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + # Will be blocked by _blocked_spec + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints": { + "post": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"201": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = ( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + # Test regex without group match + r"/api/sessions", + ) + + ServerAppWithSpec.clear_instance() + + def _configurable_serverapp( + config=jp_server_config, + base_url=jp_base_url, + argv=jp_argv, + environ=jp_environ, + http_port=jp_http_port, + tmp_path=tmp_path, + root_dir=jp_root_dir, + **kwargs + ): + c = Config(config) + c.NotebookNotary.db_file = ":memory:" + token = hexlify(os.urandom(4)).decode("ascii") + app = ServerAppWithSpec.instance( + # Set the log level to debug for testing purposes + log_level="DEBUG", + port=http_port, + port_retries=0, + open_browser=False, + root_dir=str(root_dir), + base_url=base_url, + config=c, + allow_root=True, + token=token, + **kwargs + ) + + app.init_signal = lambda: None + app.log.propagate = True + app.log.handlers = [] + # Initialize app without httpserver + app.initialize(argv=argv, new_httpserver=False) + # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks + # these streams and closes them at unfortunate times. + stream_handlers = [ + h for h in app.log.handlers if isinstance(h, logging.StreamHandler) + ] + for handler in stream_handlers: + handler.setStream(jp_logging_stream) + app.log.propagate = True + app.log.handlers = [] + # Start app without ioloop + app.start_app() + return app + + app = _configurable_serverapp(config=jp_server_config, argv=jp_argv) + yield app + app.remove_server_info_file() + app.remove_browser_open_files() + + +async def test_ServerApp_with_spec_validation_blocked( + tmp_path, jp_serverapp, jp_fetch +): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + content.write_text("dummy content") + + with pytest.raises(HTTPClientError) as error: + await jp_fetch( + "api", + "contents", + content.parent.name, + content.name, + method="GET", + ) + + assert error.value.code == 403 + + +async def test_ServerApp_with_spec_validation_not_allowed( + tmp_path, jp_serverapp, jp_fetch +): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + content.write_text("dummy content") + + with pytest.raises(HTTPClientError) as error: + await jp_fetch( + "api", + "contents", + content.parent.name, + content.name, + method="PUT", + body=b"{}" + ) + + assert error.value.code == 403 + + +async def test_ServerApp_with_spec_validation_allowed_with_encoded_path( + tmp_path, jp_serverapp, jp_fetch +): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + content.write_text("dummy content") + + # Create a checkpoint + r = await jp_fetch( + "api", + "contents", + content.parent.name, + content.name, + "checkpoints", + method="POST", + allow_nonstandard_methods=True, + ) + + assert r.code == 201 + cp1 = json.loads(r.body.decode()) + assert set(cp1) == {"id", "last_modified"} + assert r.headers["Location"].split("/")[-1] == cp1["id"] diff --git a/jupyter_server/tests/test_specvalidator.py b/jupyter_server/tests/test_specvalidator.py index 67f85a4a3d..f39ebb8539 100644 --- a/jupyter_server/tests/test_specvalidator.py +++ b/jupyter_server/tests/test_specvalidator.py @@ -3,10 +3,12 @@ import pytest from tornado.httputil import HTTPServerRequest, HTTPHeaders from tornado.log import access_log -from jupyter_server.specvalidator import SpecValidator, encode_slash pytest.importorskip("tornado_openapi3") +from jupyter_server.specvalidator import SpecValidator, encode_slash + + access_log.setLevel(logging.DEBUG) allowed_spec = { From 010494667d4a7879b20e9b347b80671ca15f3f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Tue, 28 Dec 2021 16:41:39 +0100 Subject: [PATCH 4/7] Lint code --- jupyter_server/extension/application.py | 37 +++++++------------ jupyter_server/serverapp.py | 13 ++++--- jupyter_server/specvalidator.py | 37 ++++++++----------- .../tests/extension/mockextensions/app.py | 8 +--- jupyter_server/tests/extension/test_launch.py | 4 +- jupyter_server/tests/test_spec_validation.py | 22 +++-------- jupyter_server/tests/test_specvalidator.py | 23 ++++-------- 7 files changed, 52 insertions(+), 92 deletions(-) diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index 4b86cc74dd..5774636043 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -1,7 +1,8 @@ import logging import re import sys -from typing import Iterable, Optional +from typing import Iterable +from typing import Optional from jinja2 import Environment from jinja2 import FileSystemLoader @@ -97,9 +98,7 @@ def _prepare_templates(self): self.initialize_templates() # Add templates to web app settings if extension has templates. if len(self.template_paths) > 0: - self.settings.update( - {"{}_template_paths".format(self.name): self.template_paths} - ) + self.settings.update({"{}_template_paths".format(self.name): self.template_paths}) # Create a jinja environment for logging html templates. self.jinja2_env = Environment( @@ -192,7 +191,11 @@ def get_extension_point(cls): @classmethod def get_openapi3_spec_rules(cls): - return {"allowed": cls._allowed_spec, "blocked": cls._blocked_spec, "slash_encoder": cls._slash_encoder} + return { + "allowed": cls._allowed_spec, + "blocked": cls._blocked_spec, + "slash_encoder": cls._slash_encoder, + } # Extension URL sets the default landing page for this extension. extension_url = "/" @@ -254,9 +257,7 @@ def _default_static_url_prefix(self): ), ).tag(config=True) - settings = Dict(help=_i18n("""Settings that will passed to the server.""")).tag( - config=True - ) + settings = Dict(help=_i18n("""Settings that will passed to the server.""")).tag(config=True) handlers = List(help=_i18n("""Handlers appended to the server.""")).tag(config=True) @@ -349,9 +350,7 @@ def _prepare_handlers(self): def _prepare_templates(self): # Add templates to web app settings if extension has templates. if len(self.template_paths) > 0: - self.settings.update( - {"{}_template_paths".format(self.name): self.template_paths} - ) + self.settings.update({"{}_template_paths".format(self.name): self.template_paths}) self.initialize_templates() def _jupyter_server_config(self): @@ -470,11 +469,7 @@ def load_classic_server_extension(cls, serverapp): ( r"/static/favicons/favicon.ico", RedirectHandler, - { - "url": url_path_join( - serverapp.base_url, "static/base/images/favicon.ico" - ) - }, + {"url": url_path_join(serverapp.base_url, "static/base/images/favicon.ico")}, ), ( r"/static/favicons/favicon-busy-1.ico", @@ -535,11 +530,7 @@ def load_classic_server_extension(cls, serverapp): ( r"/static/logo/logo.png", RedirectHandler, - { - "url": url_path_join( - serverapp.base_url, "static/base/images/logo.png" - ) - }, + {"url": url_path_join(serverapp.base_url, "static/base/images/logo.png")}, ), ] ) @@ -560,9 +551,7 @@ def initialize_server(cls, argv=[], load_other_extensions=True, **kwargs): jpserver_extensions.update(cls.serverapp_config["jpserver_extensions"]) cls.serverapp_config["jpserver_extensions"] = jpserver_extensions find_extensions = False - serverapp = ServerApp.instance( - jpserver_extensions=jpserver_extensions, **kwargs - ) + serverapp = ServerApp.instance(jpserver_extensions=jpserver_extensions, **kwargs) serverapp.aliases.update(cls.aliases) serverapp.initialize( argv=argv, diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 8262984b16..9a333d400b 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -28,7 +28,8 @@ import urllib import webbrowser from base64 import encodebytes -from typing import Iterable, Optional +from typing import Iterable +from typing import Optional from tornado import httputil @@ -230,11 +231,13 @@ def __init__( default_url, settings_overrides, jinja_env_options, - spec_validators = None + spec_validators=None, ): if SpecValidator is None or spec_validators is None: + class DummyValidator: """Dummy request validator that is always valid.""" + def validate(self, request): return True @@ -244,7 +247,7 @@ def validate(self, request): base_url, spec_validators.get("allowed"), spec_validators.get("blocked"), - spec_validators.get("slash_encoder") + spec_validators.get("slash_encoder"), ) settings = self.init_settings( @@ -1883,8 +1886,8 @@ def init_webapp(self): spec_validators={ "allowed": self._allowed_spec, "blocked": self._blocked_spec, - "slash_encoder": self._slash_encoder - } + "slash_encoder": self._slash_encoder, + }, ) if self.certfile: self.ssl_options["certfile"] = self.certfile diff --git a/jupyter_server/specvalidator.py b/jupyter_server/specvalidator.py index 9dc6b5c3e9..194a2c5345 100644 --- a/jupyter_server/specvalidator.py +++ b/jupyter_server/specvalidator.py @@ -1,20 +1,21 @@ import itertools import re -from typing import Iterable, Optional, Union +from typing import Iterable +from typing import Optional +from typing import Union from urllib.parse import parse_qsl from openapi_core import create_spec from openapi_core.validation.request import validators -from openapi_core.validation.request.datatypes import ( - OpenAPIRequest, - RequestParameters, - RequestValidationResult, -) -from openapi_spec_validator.handlers import base -from tornado import httpclient, httputil +from openapi_core.validation.request.datatypes import OpenAPIRequest +from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.validation.request.datatypes import RequestValidationResult +from tornado import httpclient +from tornado import httputil from tornado.log import access_log from tornado_openapi3.util import parse_mimetype -from werkzeug.datastructures import Headers, ImmutableMultiDict +from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableMultiDict def encode_slash(regex: Iterable[Union[str, re.Pattern]], path: str) -> str: @@ -134,9 +135,7 @@ def validate( self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] ) -> RequestValidationResult: """Validate a Tornado HTTP request object.""" - return super().validate( - TornadoRequestFactory.create(request, self.__encoded_slash_regex) - ) + return super().validate(TornadoRequestFactory.create(request, self.__encoded_slash_regex)) class SpecValidator: @@ -192,9 +191,7 @@ def add_base_url_server(spec: dict): encoded_slash_regex=slash_regex, ) - def validate( - self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] - ) -> bool: + def validate(self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest]) -> bool: """Validate a request against allowed and blocked specifications. Args: @@ -203,15 +200,11 @@ def validate( Whether the request is valid or not. """ allowed_result = ( - None - if self.__allowed_validator is None - else self.__allowed_validator.validate(request) + None if self.__allowed_validator is None else self.__allowed_validator.validate(request) ) blocked_result = ( - None - if self.__blocked_validator is None - else self.__blocked_validator.validate(request) + None if self.__blocked_validator is None else self.__blocked_validator.validate(request) ) allowed = allowed_result is None or len(allowed_result.errors) == 0 @@ -224,6 +217,6 @@ def validate( # Provides only the first error access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") elif not not_blocked: - access_log.debug(f"Request blocked.") + access_log.debug("Request blocked.") return allowed and not_blocked diff --git a/jupyter_server/tests/extension/mockextensions/app.py b/jupyter_server/tests/extension/mockextensions/app.py index d067901dcc..4c5e315f20 100644 --- a/jupyter_server/tests/extension/mockextensions/app.py +++ b/jupyter_server/tests/extension/mockextensions/app.py @@ -38,9 +38,7 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): loaded = False serverapp_config = { - "jpserver_extensions": { - "jupyter_server.tests.extension.mockextensions.mock1": True - } + "jpserver_extensions": {"jupyter_server.tests.extension.mockextensions.mock1": True} } _allowed_spec = { @@ -98,9 +96,7 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): }, } - _slash_encoder = ( - r"/mocked_blocked/([^/?]+(?:(?:/[^/]+)*?))$", - ) + _slash_encoder = (r"/mocked_blocked/([^/?]+(?:(?:/[^/]+)*?))$",) @staticmethod def get_extension_package(): diff --git a/jupyter_server/tests/extension/test_launch.py b/jupyter_server/tests/extension/test_launch.py index 72013dca35..402a57d554 100644 --- a/jupyter_server/tests/extension/test_launch.py +++ b/jupyter_server/tests/extension/test_launch.py @@ -135,9 +135,7 @@ def test_request_allowed(launch_instance, fetch): @pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") -async def test_ServerApp_with_spec_validation_blocked_with_encoded_path( - launch_instance, fetch -): +async def test_ServerApp_with_spec_validation_blocked_with_encoded_path(launch_instance, fetch): launch_instance() r = fetch("/mock_blocked/path/dummy.txt") diff --git a/jupyter_server/tests/test_spec_validation.py b/jupyter_server/tests/test_spec_validation.py index c0a8471472..261d40e245 100644 --- a/jupyter_server/tests/test_spec_validation.py +++ b/jupyter_server/tests/test_spec_validation.py @@ -6,6 +6,7 @@ import pytest from tornado.httpclient import HTTPClientError from traitlets.config import Config + from jupyter_server.serverapp import ServerApp @@ -78,7 +79,7 @@ class ServerAppWithSpec(ServerApp): }, }, } - + _blocked_spec = { "openapi": "3.0.1", "info": {"title": "Test specs", "version": "0.0.1"}, @@ -142,9 +143,7 @@ def _configurable_serverapp( app.initialize(argv=argv, new_httpserver=False) # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks # these streams and closes them at unfortunate times. - stream_handlers = [ - h for h in app.log.handlers if isinstance(h, logging.StreamHandler) - ] + stream_handlers = [h for h in app.log.handlers if isinstance(h, logging.StreamHandler)] for handler in stream_handlers: handler.setStream(jp_logging_stream) app.log.propagate = True @@ -159,9 +158,7 @@ def _configurable_serverapp( app.remove_browser_open_files() -async def test_ServerApp_with_spec_validation_blocked( - tmp_path, jp_serverapp, jp_fetch -): +async def test_ServerApp_with_spec_validation_blocked(tmp_path, jp_serverapp, jp_fetch): content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" content.parent.mkdir(parents=True) content.write_text("dummy content") @@ -178,21 +175,14 @@ async def test_ServerApp_with_spec_validation_blocked( assert error.value.code == 403 -async def test_ServerApp_with_spec_validation_not_allowed( - tmp_path, jp_serverapp, jp_fetch -): +async def test_ServerApp_with_spec_validation_not_allowed(tmp_path, jp_serverapp, jp_fetch): content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" content.parent.mkdir(parents=True) content.write_text("dummy content") with pytest.raises(HTTPClientError) as error: await jp_fetch( - "api", - "contents", - content.parent.name, - content.name, - method="PUT", - body=b"{}" + "api", "contents", content.parent.name, content.name, method="PUT", body=b"{}" ) assert error.value.code == 403 diff --git a/jupyter_server/tests/test_specvalidator.py b/jupyter_server/tests/test_specvalidator.py index f39ebb8539..15374f2508 100644 --- a/jupyter_server/tests/test_specvalidator.py +++ b/jupyter_server/tests/test_specvalidator.py @@ -1,7 +1,8 @@ import logging import pytest -from tornado.httputil import HTTPServerRequest, HTTPHeaders +from tornado.httputil import HTTPHeaders +from tornado.httputil import HTTPServerRequest from tornado.log import access_log pytest.importorskip("tornado_openapi3") @@ -19,9 +20,7 @@ "put": { "requestBody": { "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/Pet"} - }, + "application/json": {"schema": {"$ref": "#/components/schemas/Pet"}}, }, "required": True, }, @@ -213,9 +212,7 @@ def test_SpecValidator_allowed_spec(base_url, server_request, expected): headers = server_request.pop("headers") if "headers" in server_request else {} assert ( - validator.validate( - HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) - ) + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) == expected ) @@ -327,9 +324,7 @@ def test_SpecValidator_blocked_spec(base_url, server_request, expected): headers = server_request.pop("headers") if "headers" in server_request else {} assert ( - validator.validate( - HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) - ) + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) == expected ) @@ -365,9 +360,7 @@ def test_SpecValidator_allowed_and_blocked_spec(server_request, expected): headers = server_request.pop("headers") if "headers" in server_request else {} assert ( - validator.validate( - HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) - ) + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) == expected ) @@ -477,9 +470,7 @@ def test_SpecValidator_encoded_slash(server_request, expected): headers = server_request.pop("headers") if "headers" in server_request else {} assert ( - validator.validate( - HTTPServerRequest(**server_request, headers=HTTPHeaders(headers)) - ) + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) == expected ) From 0e031b8cce8c7448e413a8e4f1e26b987ac19e12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Tue, 28 Dec 2021 16:48:13 +0100 Subject: [PATCH 5/7] Condition optional dep on Python version --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 65f999ea91..f315eaf44d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,11 +45,11 @@ install_requires = [options.extras_require] validation = - openapi-core>=0.14.0,<0.15.0 - tornado_openapi3>=1.1.0,<2.0.0 + openapi-core>=0.14.0,<0.15.0; python_version >= '3.7' + tornado_openapi3>=1.1.0,<2.0.0; python_version >= '3.7' test = - openapi-core>=0.14.0,<0.15.0 - tornado_openapi3>=1.1.0,<2.0.0 + openapi-core>=0.14.0,<0.15.0; python_version >= '3.7' + tornado_openapi3>=1.1.0,<2.0.0; python_version >= '3.7' coverage pytest>=6.0 pytest-cov From f1ecfd25b33f6bb116caa42a0d1e83482eb16dde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Tue, 28 Dec 2021 16:52:44 +0100 Subject: [PATCH 6/7] Make specvalidator optional --- jupyter_server/specvalidator.py | 410 ++++++++++++++++---------------- 1 file changed, 210 insertions(+), 200 deletions(-) diff --git a/jupyter_server/specvalidator.py b/jupyter_server/specvalidator.py index 194a2c5345..e96ac93a77 100644 --- a/jupyter_server/specvalidator.py +++ b/jupyter_server/specvalidator.py @@ -5,218 +5,228 @@ from typing import Union from urllib.parse import parse_qsl -from openapi_core import create_spec -from openapi_core.validation.request import validators -from openapi_core.validation.request.datatypes import OpenAPIRequest -from openapi_core.validation.request.datatypes import RequestParameters -from openapi_core.validation.request.datatypes import RequestValidationResult -from tornado import httpclient -from tornado import httputil -from tornado.log import access_log -from tornado_openapi3.util import parse_mimetype -from werkzeug.datastructures import Headers -from werkzeug.datastructures import ImmutableMultiDict - - -def encode_slash(regex: Iterable[Union[str, re.Pattern]], path: str) -> str: - """Encode slash (``%2F``) for regex groups found in URL path (it never contains the query arguments). - - Note: - The regex order matters as only the first regex matching the path will be applied. - - Args: - regex: List of regex to test the path against - path: URL path to escape - - Returns: - Escaped path - """ - regex = [re.compile(r) if isinstance(r, str) else r for r in regex] - for r in regex: - m = re.search(r, path) - if m is not None and m.lastindex is not None: - # Start from latest group to the first one so that position is preserved - # Skip the first group as it matches the full path not only the group to be encoded - for index in range(m.lastindex, 0, -1): - path = ( - path[: m.start(index)] - + path[m.start(index) : m.end(index)].replace("/", r"%2F") - + path[m.end(index) :] - ) - # Break at first match - break - - return path +try: + from openapi_core import create_spec + from openapi_core.validation.request import validators + from openapi_core.validation.request.datatypes import OpenAPIRequest + from openapi_core.validation.request.datatypes import RequestParameters + from openapi_core.validation.request.datatypes import RequestValidationResult + from tornado import httpclient + from tornado import httputil + from tornado.log import access_log + from tornado_openapi3.util import parse_mimetype + from werkzeug.datastructures import Headers + from werkzeug.datastructures import ImmutableMultiDict + + def encode_slash(regex: Iterable[Union[str, re.Pattern]], path: str) -> str: + """Encode slash (``%2F``) for regex groups found in URL path (it never contains the query arguments). + + Note: + The regex order matters as only the first regex matching the path will be applied. -# Ref: https://github.com/correl/tornado-openapi3/blob/master/tornado_openapi3/requests.py -class TornadoRequestFactory: - """Factory for converting Tornado requests to OpenAPI request objects.""" - - @classmethod - def create( - cls, - request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest], - encoded_slash_regex: Iterable[re.Pattern], - ) -> OpenAPIRequest: - """Creates an OpenAPI request from Tornado request objects. + Args: + regex: List of regex to test the path against + path: URL path to escape - Supports both :class:`tornado.httpclient.HTTPRequest` and - :class:`tornado.httputil.HTTPServerRequest` objects. + Returns: + Escaped path """ - if isinstance(request, httpclient.HTTPRequest): - if request.url: - path, _, querystring = request.url.partition("?") - query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict( - parse_qsl(querystring) - ) + regex = [re.compile(r) if isinstance(r, str) else r for r in regex] + for r in regex: + m = re.search(r, path) + if m is not None and m.lastindex is not None: + # Start from latest group to the first one so that position is preserved + # Skip the first group as it matches the full path not only the group to be encoded + for index in range(m.lastindex, 0, -1): + path = ( + path[: m.start(index)] + + path[m.start(index) : m.end(index)].replace("/", r"%2F") + + path[m.end(index) :] + ) + # Break at first match + break + + return path + + # Ref: https://github.com/correl/tornado-openapi3/blob/master/tornado_openapi3/requests.py + class TornadoRequestFactory: + """Factory for converting Tornado requests to OpenAPI request objects.""" + + @classmethod + def create( + cls, + request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest], + encoded_slash_regex: Iterable[re.Pattern], + ) -> OpenAPIRequest: + """Creates an OpenAPI request from Tornado request objects. + + Supports both :class:`tornado.httpclient.HTTPRequest` and + :class:`tornado.httputil.HTTPServerRequest` objects. + """ + if isinstance(request, httpclient.HTTPRequest): + if request.url: + path, _, querystring = request.url.partition("?") + query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict( + parse_qsl(querystring) + ) + else: + path = "" + query_arguments = ImmutableMultiDict() else: - path = "" - query_arguments = ImmutableMultiDict() - else: - path, _, _ = request.full_url().partition("?") - if path == "://": - path = "" - query_arguments = ImmutableMultiDict( - itertools.chain( - *[ - [(k, v.decode("utf-8")) for v in vs] - for k, vs in request.query_arguments.items() - ] + path, _, _ = request.full_url().partition("?") + if path == "://": + path = "" + query_arguments = ImmutableMultiDict( + itertools.chain( + *[ + [(k, v.decode("utf-8")) for v in vs] + for k, vs in request.query_arguments.items() + ] + ) ) + + # Encode slashes in path to be compliant with Open API specification + # e.g. /api/contents/path/to/file.txt -> /api/contents/path%2Fto%2Ffile.txt + path = encode_slash(encoded_slash_regex, path) + + return OpenAPIRequest( + full_url_pattern=path, + method=request.method.lower() if request.method else "get", + parameters=RequestParameters( + query=query_arguments, + header=Headers(request.headers.get_all()), + cookie=httputil.parse_cookie(request.headers.get("Cookie", "")), + ), + body=request.body if request.body else b"", + mimetype=parse_mimetype( + request.headers.get("Content-Type", "application/x-www-form-urlencoded") + ), ) - # Encode slashes in path to be compliant with Open API specification - # e.g. /api/contents/path/to/file.txt -> /api/contents/path%2Fto%2Ffile.txt - path = encode_slash(encoded_slash_regex, path) - - return OpenAPIRequest( - full_url_pattern=path, - method=request.method.lower() if request.method else "get", - parameters=RequestParameters( - query=query_arguments, - header=Headers(request.headers.get_all()), - cookie=httputil.parse_cookie(request.headers.get("Cookie", "")), - ), - body=request.body if request.body else b"", - mimetype=parse_mimetype( - request.headers.get("Content-Type", "application/x-www-form-urlencoded") - ), - ) - - -class RequestValidator(validators.RequestValidator): - """Validator for Tornado HTTP Requests. - - Args: - base_url: [optional] Server base URL - custom_formatters: [optional] - custom_media_type_deserializers: [optional] - encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped - """ - - def __init__( - self, - spec, - base_url: Optional[str] = None, - custom_formatters=None, - custom_media_type_deserializers=None, - encoded_slash_regex: Optional[Iterable[re.Pattern]] = None, - ): - super().__init__( + class RequestValidator(validators.RequestValidator): + """Validator for Tornado HTTP Requests. + + Args: + base_url: [optional] Server base URL + custom_formatters: [optional] + custom_media_type_deserializers: [optional] + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped + """ + + def __init__( + self, spec, - base_url=base_url, - custom_formatters=custom_formatters, - custom_media_type_deserializers=custom_media_type_deserializers, - ) - self.__encoded_slash_regex = encoded_slash_regex - - def validate( - self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] - ) -> RequestValidationResult: - """Validate a Tornado HTTP request object.""" - return super().validate(TornadoRequestFactory.create(request, self.__encoded_slash_regex)) - - -class SpecValidator: - """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. - - If allowed and blocked specifications are defined, the request must be allowed and not blocked; - i.e. blocked specification takes precedence. - - Note: - OpenAPI does not accept path argument containing ``/``. Therefore you should provide regex to encode - them; e.g. ``"/api/contents/([^/]+(?:/[^/]+)*?)$"`` to match ``/api/contents/{path}``. The order in - which you provide the regex are important as only the first regex matching the path will be applied. - You can test your expression using :ref:`jupyter_server.specvalidator.encode_slash`. - - Args: - base_url: [optional] Server base URL - allowed_spec: [optional] Allowed endpoints - blocked_spec: [optional] Blocked endpoints - encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped - """ - - def __init__( - self, - base_url: Optional[str] = None, - allowed_spec: Optional[dict] = None, - blocked_spec: Optional[dict] = None, - encoded_slash_regex: Optional[Iterable[str]] = None, - ): - self.__allowed_validator: Optional[RequestValidator] = None - self.__blocked_validator: Optional[RequestValidator] = None - slash_regex: Iterable[re.Pattern] = list( - map(lambda s: re.compile(s), encoded_slash_regex or []) - ) - - def add_base_url_server(spec: dict): - servers = spec.get("servers", []) - if not any(map(lambda s: s.get("url") == base_url, servers)): - servers.append({"url": base_url}) - spec["servers"] = servers - - if allowed_spec is not None: - if base_url is not None: - add_base_url_server(allowed_spec) - self.__allowed_validator = RequestValidator( - create_spec(allowed_spec), - encoded_slash_regex=slash_regex, + base_url: Optional[str] = None, + custom_formatters=None, + custom_media_type_deserializers=None, + encoded_slash_regex: Optional[Iterable[re.Pattern]] = None, + ): + super().__init__( + spec, + base_url=base_url, + custom_formatters=custom_formatters, + custom_media_type_deserializers=custom_media_type_deserializers, ) - if blocked_spec is not None: - if base_url is not None: - add_base_url_server(blocked_spec) - self.__blocked_validator = RequestValidator( - create_spec(blocked_spec), - encoded_slash_regex=slash_regex, + self.__encoded_slash_regex = encoded_slash_regex + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> RequestValidationResult: + """Validate a Tornado HTTP request object.""" + return super().validate( + TornadoRequestFactory.create(request, self.__encoded_slash_regex) ) - def validate(self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest]) -> bool: - """Validate a request against allowed and blocked specifications. + class SpecValidator: + """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. + + If allowed and blocked specifications are defined, the request must be allowed and not blocked; + i.e. blocked specification takes precedence. + + Note: + OpenAPI does not accept path argument containing ``/``. Therefore you should provide regex to encode + them; e.g. ``"/api/contents/([^/]+(?:/[^/]+)*?)$"`` to match ``/api/contents/{path}``. The order in + which you provide the regex are important as only the first regex matching the path will be applied. + You can test your expression using :ref:`jupyter_server.specvalidator.encode_slash`. Args: - request: Request to validate - Returns: - Whether the request is valid or not. + base_url: [optional] Server base URL + allowed_spec: [optional] Allowed endpoints + blocked_spec: [optional] Blocked endpoints + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped """ - allowed_result = ( - None if self.__allowed_validator is None else self.__allowed_validator.validate(request) - ) - - blocked_result = ( - None if self.__blocked_validator is None else self.__blocked_validator.validate(request) - ) - - allowed = allowed_result is None or len(allowed_result.errors) == 0 - not_blocked = blocked_result is None or len(blocked_result.errors) > 0 - - # The error raised if this is not valid will be logged - # So we only give the reason in debug level - if not (allowed and not_blocked): - if not allowed: - # Provides only the first error - access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") - elif not not_blocked: - access_log.debug("Request blocked.") - - return allowed and not_blocked + + def __init__( + self, + base_url: Optional[str] = None, + allowed_spec: Optional[dict] = None, + blocked_spec: Optional[dict] = None, + encoded_slash_regex: Optional[Iterable[str]] = None, + ): + self.__allowed_validator: Optional[RequestValidator] = None + self.__blocked_validator: Optional[RequestValidator] = None + slash_regex: Iterable[re.Pattern] = list( + map(lambda s: re.compile(s), encoded_slash_regex or []) + ) + + def add_base_url_server(spec: dict): + servers = spec.get("servers", []) + if not any(map(lambda s: s.get("url") == base_url, servers)): + servers.append({"url": base_url}) + spec["servers"] = servers + + if allowed_spec is not None: + if base_url is not None: + add_base_url_server(allowed_spec) + self.__allowed_validator = RequestValidator( + create_spec(allowed_spec), + encoded_slash_regex=slash_regex, + ) + if blocked_spec is not None: + if base_url is not None: + add_base_url_server(blocked_spec) + self.__blocked_validator = RequestValidator( + create_spec(blocked_spec), + encoded_slash_regex=slash_regex, + ) + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> bool: + """Validate a request against allowed and blocked specifications. + + Args: + request: Request to validate + Returns: + Whether the request is valid or not. + """ + allowed_result = ( + None + if self.__allowed_validator is None + else self.__allowed_validator.validate(request) + ) + + blocked_result = ( + None + if self.__blocked_validator is None + else self.__blocked_validator.validate(request) + ) + + allowed = allowed_result is None or len(allowed_result.errors) == 0 + not_blocked = blocked_result is None or len(blocked_result.errors) > 0 + + # The error raised if this is not valid will be logged + # So we only give the reason in debug level + if not (allowed and not_blocked): + if not allowed: + # Provides only the first error + access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") + elif not not_blocked: + access_log.debug("Request blocked.") + + return allowed and not_blocked + + +except ImportError: + pass From 6fc0130d4a226e926e7d14f22db4e706e7c70cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Collonval?= Date: Wed, 29 Dec 2021 11:53:28 +0100 Subject: [PATCH 7/7] Add documentation --- docs/source/developers/extensions.rst | 95 +++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/docs/source/developers/extensions.rst b/docs/source/developers/extensions.rst index b056746443..8e122ae95a 100644 --- a/docs/source/developers/extensions.rst +++ b/docs/source/developers/extensions.rst @@ -306,6 +306,101 @@ To make your extension executable from anywhere on your system, point an entry-p } ) +Filtering requests +------------------ + +When launching the ``ExtensionApp`` in standalone mode, you may want to restrict the available endpoints. You can do this by defining +either ``_allowed_spec`` or ``_blocked_spec`` class attributes as OpenAPI v3 specifications. The allowed paths will blocked any requests +not matching the specification and the blocked paths will allow any requests except the ones specified. + +.. note:: + + If both are defined, the blocked paths take precedence on the allowed one; i.e. if a path is allowed and blocked, it will be blocked. + +OpenAPI v3 does not support URL path arguments that contains slashes ``/``. For the specification, +``/api/contents/{path}`` can only be ``/api/contents/[^/]+``. To circumvent this limitation, you can provide a list of regex's through +the class attributes ``_slash_encoder``. The groups matching one of the encoder will have their slashes encoded (``/`` replaced by ``%2F``). +This will allow the validation of path through the OpenAPI v3 specification. + +Here is an example for the unit tests. + +.. code-block:: python + + class MyExtensionApp(ExtensionApp): + + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + # Will be blocked by _blocked_spec + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints": { + "post": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"201": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = ( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + ) + + ``ExtensionApp`` as a classic Notebook server extension -------------------------------------------------------