diff --git a/.github/workflows/python-linux.yml b/.github/workflows/python-linux.yml index 2f69cdaccd..f2c6b80703 100644 --- a/.github/workflows/python-linux.yml +++ b/.github/workflows/python-linux.yml @@ -79,3 +79,25 @@ jobs: cd docs pip install -r doc-requirements.txt make html SPHINXOPTS="-W" + + no-tornado-openapi3: + runs-on: ${{ matrix.os }}-latest + strategy: + fail-fast: false + matrix: + os: [ubuntu] + python-version: ["3.10"] + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Base Setup + uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + - name: Install the Python dependencies + run: | + pip install -e ".[test]" + # Remove optional dependency + pip uninstall -y tornado-openapi3 + - name: Run some tests + run: | + # Use that file as it contains some tests that should pass and others that need to be skipped + pytest -vv jupyter_server/tests/extension/test_launch.py diff --git a/docs/source/developers/extensions.rst b/docs/source/developers/extensions.rst index 378f1358d7..ca50152a8b 100644 --- a/docs/source/developers/extensions.rst +++ b/docs/source/developers/extensions.rst @@ -306,6 +306,101 @@ To make your extension executable from anywhere on your system, point an entry-p } ) +Filtering requests +------------------ + +When launching the ``ExtensionApp`` in standalone mode, you may want to restrict the available endpoints. You can do this by defining +the ``_allowed_spec`` and/or ``_blocked_spec`` class attributes as OpenAPI v3 specifications. The allowed paths will block any requests +not matching the specification and the blocked paths will allow any requests except the ones specified. + +.. note:: + + If both are defined, the blocked paths take precedence over the allowed ones; i.e. if a path is allowed and blocked, it will be blocked. + +OpenAPI v3 does not support URL path arguments that contain slashes ``/``. For the specification, +``/api/contents/{path}`` can only be ``/api/contents/[^/]+``. 
To circumvent this limitation, you can provide a list of regular expressions through +the class attribute ``_slash_encoder``. The capture groups of the first matching expression will have their slashes encoded (``/`` replaced by ``%2F``). +This allows the path to be validated against the OpenAPI v3 specification. + +Here is an example taken from the unit tests. + +.. code-block:: python + + class MyExtensionApp(ExtensionApp): + + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + # Will be blocked by _blocked_spec + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints": { + "post": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"201": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = ( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + ) + + ``ExtensionApp`` as a classic Notebook server extension ------------------------------------------------------- diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index a0cd5c3551..5774636043 100644 --- a/jupyter_server/extension/application.py +++ 
b/jupyter_server/extension/application.py @@ -1,6 +1,8 @@ import logging import re import sys +from typing import Iterable +from typing import Optional from jinja2 import Environment from jinja2 import FileSystemLoader @@ -136,6 +138,13 @@ class method. This method can be set as a entry_point in the extensions setup.py """ + # Filtering rules to apply on handlers registration + # Subclasses can override this list to filter handlers + # They will be applied on the ServerApp + _allowed_spec: Optional[dict] = None + _blocked_spec: Optional[dict] = None + _slash_encoder: Optional[Iterable[str]] = None + # Subclasses should override this trait. Tells the server if # this extension allows other other extensions to be loaded # side-by-side when launched directly. @@ -180,6 +189,14 @@ def get_extension_package(cls): def get_extension_point(cls): return cls.__module__ + @classmethod + def get_openapi3_spec_rules(cls): + return { + "allowed": cls._allowed_spec, + "blocked": cls._blocked_spec, + "slash_encoder": cls._slash_encoder, + } + # Extension URL sets the default landing page for this extension. 
extension_url = "/" @@ -495,7 +512,8 @@ def load_classic_server_extension(cls, serverapp): RedirectHandler, { "url": url_path_join( - serverapp.base_url, "static/base/images/favicon-notebook.ico" + serverapp.base_url, + "static/base/images/favicon-notebook.ico", ) }, ), @@ -504,7 +522,8 @@ def load_classic_server_extension(cls, serverapp): RedirectHandler, { "url": url_path_join( - serverapp.base_url, "static/base/images/favicon-terminal.ico" + serverapp.base_url, + "static/base/images/favicon-terminal.ico", ) }, ), diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 51ebd23f3e..e03058be45 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -28,6 +28,15 @@ import urllib import webbrowser from base64 import encodebytes +from typing import Iterable +from typing import Optional + +from tornado import httputil + +try: + from jupyter_server.specvalidator import SpecValidator +except ImportError: + SpecValidator = None try: import resource @@ -222,7 +231,24 @@ def __init__( default_url, settings_overrides, jinja_env_options, + spec_validators=None, ): + if SpecValidator is None or spec_validators is None: + + class DummyValidator: + """Dummy request validator that is always valid.""" + + def validate(self, request): + return True + + self.__requestValidator: Optional[SpecValidator] = DummyValidator() + else: + self.__requestValidator: Optional[SpecValidator] = SpecValidator( + base_url, + spec_validators.get("allowed"), + spec_validators.get("blocked"), + spec_validators.get("slash_encoder"), + ) settings = self.init_settings( jupyter_app, @@ -433,6 +459,14 @@ def init_handlers(self, default_services, settings): new_handlers.append((r"(.*)", Template404)) return new_handlers + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> "web._HandlerDelegate": + if self.__requestValidator.validate(request): + return super().find_handler(request, **kwargs) + else: + return 
self.get_handler_delegate(request, web.ErrorHandler, {"status_code": 403}) + def last_activity(self): """Get a UTC timestamp for when the server last did something. @@ -756,6 +790,12 @@ class ServerApp(JupyterApp): "view", ) + # Filtering rules to apply on handlers registration + # Subclasses can override this list to filter handlers + _allowed_spec: Optional[dict] = None + _blocked_spec: Optional[dict] = None + _slash_encoder: Optional[Iterable[str]] = None + _log_formatter_cls = LogFormatter @default("log_level") @@ -1843,6 +1883,11 @@ def init_webapp(self): self.default_url, self.tornado_settings, self.jinja_environment_options, + spec_validators={ + "allowed": self._allowed_spec, + "blocked": self._blocked_spec, + "slash_encoder": self._slash_encoder, + }, ) if self.certfile: self.ssl_options["certfile"] = self.certfile @@ -2316,6 +2361,11 @@ def initialize( # Set starter_app property. if point.app: self._starter_app = point.app + # Apply endpoint filters from the extension app + spec_rules = point.app.get_openapi3_spec_rules() + self._allowed_spec = spec_rules["allowed"] + self._blocked_spec = spec_rules["blocked"] + self._slash_encoder = spec_rules["slash_encoder"] # Load any configuration that comes from the Extension point. 
self.update_config(Config(point.config)) diff --git a/jupyter_server/specvalidator.py b/jupyter_server/specvalidator.py new file mode 100644 index 0000000000..e96ac93a77 --- /dev/null +++ b/jupyter_server/specvalidator.py @@ -0,0 +1,232 @@ +import itertools +import re +from typing import Iterable +from typing import Optional +from typing import Union +from urllib.parse import parse_qsl + + +try: + from openapi_core import create_spec + from openapi_core.validation.request import validators + from openapi_core.validation.request.datatypes import OpenAPIRequest + from openapi_core.validation.request.datatypes import RequestParameters + from openapi_core.validation.request.datatypes import RequestValidationResult + from tornado import httpclient + from tornado import httputil + from tornado.log import access_log + from tornado_openapi3.util import parse_mimetype + from werkzeug.datastructures import Headers + from werkzeug.datastructures import ImmutableMultiDict + + def encode_slash(regex: Iterable[Union[str, re.Pattern]], path: str) -> str: + """Encode slash (``%2F``) for regex groups found in URL path (it never contains the query arguments). + + Note: + The regex order matters as only the first regex matching the path will be applied. 
+ + Args: + regex: List of regex to test the path against + path: URL path to escape + + Returns: + Escaped path + """ + regex = [re.compile(r) if isinstance(r, str) else r for r in regex] + for r in regex: + m = re.search(r, path) + if m is not None and m.lastindex is not None: + # Start from latest group to the first one so that position is preserved + # Skip the first group as it matches the full path not only the group to be encoded + for index in range(m.lastindex, 0, -1): + path = ( + path[: m.start(index)] + + path[m.start(index) : m.end(index)].replace("/", r"%2F") + + path[m.end(index) :] + ) + # Break at first match + break + + return path + + # Ref: https://github.com/correl/tornado-openapi3/blob/master/tornado_openapi3/requests.py + class TornadoRequestFactory: + """Factory for converting Tornado requests to OpenAPI request objects.""" + + @classmethod + def create( + cls, + request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest], + encoded_slash_regex: Iterable[re.Pattern], + ) -> OpenAPIRequest: + """Creates an OpenAPI request from Tornado request objects. + + Supports both :class:`tornado.httpclient.HTTPRequest` and + :class:`tornado.httputil.HTTPServerRequest` objects. + """ + if isinstance(request, httpclient.HTTPRequest): + if request.url: + path, _, querystring = request.url.partition("?") + query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict( + parse_qsl(querystring) + ) + else: + path = "" + query_arguments = ImmutableMultiDict() + else: + path, _, _ = request.full_url().partition("?") + if path == "://": + path = "" + query_arguments = ImmutableMultiDict( + itertools.chain( + *[ + [(k, v.decode("utf-8")) for v in vs] + for k, vs in request.query_arguments.items() + ] + ) + ) + + # Encode slashes in path to be compliant with Open API specification + # e.g. 
/api/contents/path/to/file.txt -> /api/contents/path%2Fto%2Ffile.txt + path = encode_slash(encoded_slash_regex, path) + + return OpenAPIRequest( + full_url_pattern=path, + method=request.method.lower() if request.method else "get", + parameters=RequestParameters( + query=query_arguments, + header=Headers(request.headers.get_all()), + cookie=httputil.parse_cookie(request.headers.get("Cookie", "")), + ), + body=request.body if request.body else b"", + mimetype=parse_mimetype( + request.headers.get("Content-Type", "application/x-www-form-urlencoded") + ), + ) + + class RequestValidator(validators.RequestValidator): + """Validator for Tornado HTTP Requests. + + Args: + base_url: [optional] Server base URL + custom_formatters: [optional] + custom_media_type_deserializers: [optional] + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped + """ + + def __init__( + self, + spec, + base_url: Optional[str] = None, + custom_formatters=None, + custom_media_type_deserializers=None, + encoded_slash_regex: Optional[Iterable[re.Pattern]] = None, + ): + super().__init__( + spec, + base_url=base_url, + custom_formatters=custom_formatters, + custom_media_type_deserializers=custom_media_type_deserializers, + ) + self.__encoded_slash_regex = encoded_slash_regex + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> RequestValidationResult: + """Validate a Tornado HTTP request object.""" + return super().validate( + TornadoRequestFactory.create(request, self.__encoded_slash_regex) + ) + + class SpecValidator: + """Validate server request against a list of allowed and blocked OpenAPI v3 specifications. + + If allowed and blocked specifications are defined, the request must be allowed and not blocked; + i.e. blocked specification takes precedence. + + Note: + OpenAPI does not accept path argument containing ``/``. Therefore you should provide regex to encode + them; e.g. 
``"/api/contents/([^/]+(?:/[^/]+)*?)$"`` to match ``/api/contents/{path}``. The order in + which you provide the regex are important as only the first regex matching the path will be applied. + You can test your expression using :ref:`jupyter_server.specvalidator.encode_slash`. + + Args: + base_url: [optional] Server base URL + allowed_spec: [optional] Allowed endpoints + blocked_spec: [optional] Blocked endpoints + encoded_slash_regex: [optional] Regex expression to find part of URL path in which ``/`` must be escaped + """ + + def __init__( + self, + base_url: Optional[str] = None, + allowed_spec: Optional[dict] = None, + blocked_spec: Optional[dict] = None, + encoded_slash_regex: Optional[Iterable[str]] = None, + ): + self.__allowed_validator: Optional[RequestValidator] = None + self.__blocked_validator: Optional[RequestValidator] = None + slash_regex: Iterable[re.Pattern] = list( + map(lambda s: re.compile(s), encoded_slash_regex or []) + ) + + def add_base_url_server(spec: dict): + servers = spec.get("servers", []) + if not any(map(lambda s: s.get("url") == base_url, servers)): + servers.append({"url": base_url}) + spec["servers"] = servers + + if allowed_spec is not None: + if base_url is not None: + add_base_url_server(allowed_spec) + self.__allowed_validator = RequestValidator( + create_spec(allowed_spec), + encoded_slash_regex=slash_regex, + ) + if blocked_spec is not None: + if base_url is not None: + add_base_url_server(blocked_spec) + self.__blocked_validator = RequestValidator( + create_spec(blocked_spec), + encoded_slash_regex=slash_regex, + ) + + def validate( + self, request: Union[httpclient.HTTPRequest, httputil.HTTPServerRequest] + ) -> bool: + """Validate a request against allowed and blocked specifications. + + Args: + request: Request to validate + Returns: + Whether the request is valid or not. 
+ """ + allowed_result = ( + None + if self.__allowed_validator is None + else self.__allowed_validator.validate(request) + ) + + blocked_result = ( + None + if self.__blocked_validator is None + else self.__blocked_validator.validate(request) + ) + + allowed = allowed_result is None or len(allowed_result.errors) == 0 + not_blocked = blocked_result is None or len(blocked_result.errors) > 0 + + # The error raised if this is not valid will be logged + # So we only give the reason in debug level + if not (allowed and not_blocked): + if not allowed: + # Provides only the first error + access_log.debug(f"Request not allowed: {allowed_result.errors[0]!s}") + elif not not_blocked: + access_log.debug("Request blocked.") + + return allowed and not_blocked + + +except ImportError: + pass diff --git a/jupyter_server/tests/extension/mockextensions/app.py b/jupyter_server/tests/extension/mockextensions/app.py index 7045417b23..4c5e315f20 100644 --- a/jupyter_server/tests/extension/mockextensions/app.py +++ b/jupyter_server/tests/extension/mockextensions/app.py @@ -41,6 +41,63 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): "jpserver_extensions": {"jupyter_server.tests.extension.mockextensions.mock1": True} } + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + "/mock": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + "/mock_blocked/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/mock_template": { + "get": { + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/mock_blocked/{path}": { + "get": { + "parameters": [ + { + "name": "path", + 
"in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = (r"/mocked_blocked/([^/?]+(?:(?:/[^/]+)*?))$",) + @staticmethod def get_extension_package(): return "jupyter_server.tests.extension.mockextensions" diff --git a/jupyter_server/tests/extension/test_launch.py b/jupyter_server/tests/extension/test_launch.py index e5cc12e7a7..402a57d554 100644 --- a/jupyter_server/tests/extension/test_launch.py +++ b/jupyter_server/tests/extension/test_launch.py @@ -11,6 +11,11 @@ import pytest import requests +try: + import tornado_openapi3 +except ImportError: + tornado_openapi3 = None + HERE = os.path.dirname(os.path.abspath(__file__)) @@ -109,3 +114,30 @@ def test_token_file(launch_instance, fetch, token): del os.environ["JUPYTER_TOKEN_FILE"] token_file.unlink() assert r.status_code == 200 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +def test_request_not_allowed(launch_instance, fetch): + launch_instance() + + r = fetch("/mocka") + + assert r.status_code == 403 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +def test_request_allowed(launch_instance, fetch): + launch_instance() + + r = fetch("/mock") + + assert r.status_code == 200 + + +@pytest.mark.skipif(tornado_openapi3 is None, reason="tornado_openapi3 not available") +async def test_ServerApp_with_spec_validation_blocked_with_encoded_path(launch_instance, fetch): + launch_instance() + + r = fetch("/mock_blocked/path/dummy.txt") + + assert r.status_code == 403 diff --git a/jupyter_server/tests/test_spec_validation.py b/jupyter_server/tests/test_spec_validation.py new file mode 100644 index 0000000000..261d40e245 --- /dev/null +++ b/jupyter_server/tests/test_spec_validation.py @@ -0,0 +1,212 @@ +import json +import logging +import os +from binascii import hexlify + +import pytest +from tornado.httpclient import HTTPClientError 
+from traitlets.config import Config + +from jupyter_server.serverapp import ServerApp + + +pytest.importorskip("tornado_openapi3") + + +@pytest.fixture(scope="function") +def jp_serverapp( + jp_ensure_app_fixture, + jp_server_config, + jp_argv, + jp_nbconvert_templates, # this fixture must preceed jp_environ + jp_environ, + jp_http_port, + jp_base_url, + tmp_path, + jp_root_dir, + io_loop, + jp_logging_stream, +): + """Starts a Jupyter Server instance with endpoint specifications based on the established configuration values. + + It overrides the default fixture to define a server with spec validation. + """ + + class ServerAppWithSpec(ServerApp): + _allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + # Will be blocked by _blocked_spec + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints": { + "post": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"201": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + } + + _slash_encoder = ( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + # Test regex without group match + r"/api/sessions", + ) + + 
ServerAppWithSpec.clear_instance() + + def _configurable_serverapp( + config=jp_server_config, + base_url=jp_base_url, + argv=jp_argv, + environ=jp_environ, + http_port=jp_http_port, + tmp_path=tmp_path, + root_dir=jp_root_dir, + **kwargs + ): + c = Config(config) + c.NotebookNotary.db_file = ":memory:" + token = hexlify(os.urandom(4)).decode("ascii") + app = ServerAppWithSpec.instance( + # Set the log level to debug for testing purposes + log_level="DEBUG", + port=http_port, + port_retries=0, + open_browser=False, + root_dir=str(root_dir), + base_url=base_url, + config=c, + allow_root=True, + token=token, + **kwargs + ) + + app.init_signal = lambda: None + app.log.propagate = True + app.log.handlers = [] + # Initialize app without httpserver + app.initialize(argv=argv, new_httpserver=False) + # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks + # these streams and closes them at unfortunate times. + stream_handlers = [h for h in app.log.handlers if isinstance(h, logging.StreamHandler)] + for handler in stream_handlers: + handler.setStream(jp_logging_stream) + app.log.propagate = True + app.log.handlers = [] + # Start app without ioloop + app.start_app() + return app + + app = _configurable_serverapp(config=jp_server_config, argv=jp_argv) + yield app + app.remove_server_info_file() + app.remove_browser_open_files() + + +async def test_ServerApp_with_spec_validation_blocked(tmp_path, jp_serverapp, jp_fetch): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + content.write_text("dummy content") + + with pytest.raises(HTTPClientError) as error: + await jp_fetch( + "api", + "contents", + content.parent.name, + content.name, + method="GET", + ) + + assert error.value.code == 403 + + +async def test_ServerApp_with_spec_validation_not_allowed(tmp_path, jp_serverapp, jp_fetch): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + 
content.write_text("dummy content") + + with pytest.raises(HTTPClientError) as error: + await jp_fetch( + "api", "contents", content.parent.name, content.name, method="PUT", body=b"{}" + ) + + assert error.value.code == 403 + + +async def test_ServerApp_with_spec_validation_allowed_with_encoded_path( + tmp_path, jp_serverapp, jp_fetch +): + content = tmp_path / jp_serverapp.root_dir / "content" / "test.txt" + content.parent.mkdir(parents=True) + content.write_text("dummy content") + + # Create a checkpoint + r = await jp_fetch( + "api", + "contents", + content.parent.name, + content.name, + "checkpoints", + method="POST", + allow_nonstandard_methods=True, + ) + + assert r.code == 201 + cp1 = json.loads(r.body.decode()) + assert set(cp1) == {"id", "last_modified"} + assert r.headers["Location"].split("/")[-1] == cp1["id"] diff --git a/jupyter_server/tests/test_specvalidator.py b/jupyter_server/tests/test_specvalidator.py new file mode 100644 index 0000000000..15374f2508 --- /dev/null +++ b/jupyter_server/tests/test_specvalidator.py @@ -0,0 +1,520 @@ +import logging + +import pytest +from tornado.httputil import HTTPHeaders +from tornado.httputil import HTTPServerRequest +from tornado.log import access_log + +pytest.importorskip("tornado_openapi3") + +from jupyter_server.specvalidator import SpecValidator, encode_slash + + +access_log.setLevel(logging.DEBUG) + +allowed_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "requestBody": { + "content": { + "application/json": {"schema": {"$ref": "#/components/schemas/Pet"}}, + }, + "required": True, + }, + "responses": {"400": {"description": ""}}, + } + }, + "/pet/findByTags": { + "get": { + "parameters": [ + { + "name": "tags", + "in": "query", + "description": "Tags to filter by", + "required": True, + "style": "form", + "schema": {"type": "array", "items": {"type": "string"}}, + } + ], + "responses": { + "200": { + "description": "successful 
operation", + } + }, + } + }, + "/pet/{petId}": { + "get": { + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet to return", + "required": True, + "schema": {"type": "integer", "format": "int64"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + "components": { + "schemas": { + "Pet": { + "required": ["name"], + "type": "object", + "properties": { + "id": {"type": "integer", "format": "int64"}, + "name": {"type": "string", "example": "doggie"}, + }, + }, + } + }, +} + +blocked_spec = { + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/pet": { + "put": { + "responses": {"400": {"description": ""}}, + } + }, + "/user/{anyId}": { + "get": { + "parameters": [ + { + "name": "anyId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + } + }, + }, +} + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + id="Default case", + ), + pytest.param( + "/dummy/base_url/", + dict( + method="put", + uri="/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + True, + id="Non default base URL", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b"", + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="No body", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"id":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Wrong body", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + 
host="localhost:8888", + ), + False, + id="Method not allowed", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/22", + host="localhost:8888", + ), + True, + id="Path argument", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/hello", + host="localhost:8888", + ), + False, + id="Wrong path parameter", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/findByTags?tags=cat", + host="localhost:8888", + ), + True, + id="Query parameters", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet/findByTags", + host="localhost:8888", + ), + False, + id="Missing query parameter", + ), + ), +) +def test_SpecValidator_allowed_spec(base_url, server_request, expected): + validator = SpecValidator(base_url, allowed_spec, None) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) + == expected + ) + + +@pytest.mark.parametrize( + "base_url, server_request, expected", + ( + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Default case", + ), + pytest.param( + "/dummy/base_url/", + dict( + method="put", + uri="/dummy/base_url/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Non default base URL", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet", + body=b"", + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Missing body", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet?tags=22", + host="localhost:8888", + ), + False, + id="Query argument", + ), + pytest.param( + "/", + dict( + method="get", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + 
host="localhost:8888", + ), + True, + id="Non-blocked method", + ), + pytest.param( + "/", + dict( + method="put", + uri="/pet/22", + host="localhost:8888", + ), + True, + ), + pytest.param( + "/", + dict( + method="get", + uri="/user?id=42", + host="localhost:8888", + ), + True, + ), + pytest.param( + "/", + dict( + method="get", + uri="/user/john/smith", + host="localhost:8888", + ), + True, + id="Sub path is allowed", + ), + pytest.param( + "/", + dict( + method="get", + uri="/user/william?page=42", + host="localhost:8888", + ), + False, + id="Blocked path with query argument", + ), + ), +) +def test_SpecValidator_blocked_spec(base_url, server_request, expected): + validator = SpecValidator(base_url, None, blocked_spec) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) + == expected + ) + + +@pytest.mark.parametrize( + "server_request, expected", + ( + pytest.param( + dict( + method="put", + uri="/pet", + body=b'{"name":"puppy"}', + headers=dict({"Content-Type": "application/json"}), + host="localhost:8888", + ), + False, + id="Blocked although allowed", + ), + pytest.param( + dict( + method="get", + uri="/pet/22", + host="localhost:8888", + ), + True, + id="Allowed and not part of blocked", + ), + ), +) +def test_SpecValidator_allowed_and_blocked_spec(server_request, expected): + validator = SpecValidator(allowed_spec=allowed_spec, blocked_spec=blocked_spec) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) + == expected + ) + + +@pytest.mark.parametrize( + "server_request, expected", + ( + pytest.param( + dict( + method="get", + uri="/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + host="localhost:8888", + ), + True, + id="Checkpoint path", + ), + pytest.param( + dict( + 
method="get", + uri="/api/contents/path/to/file.txt", + host="localhost:8888", + ), + True, + id="Content path", + ), + pytest.param( + dict( + method="get", + uri="/api/contents/path/to/file.txt?tags=dummy", + host="localhost:8888", + ), + True, + id="Encode with query arguments", + ), + pytest.param( + dict( + method="get", + uri="/api/sessions/path/to/file.txt", + host="localhost:8888", + ), + False, + id="Session Id does not support slashes", + ), + ), +) +def test_SpecValidator_encoded_slash(server_request, expected): + validator = SpecValidator( + allowed_spec={ + "openapi": "3.0.1", + "info": {"title": "Test specs", "version": "0.0.1"}, + "paths": { + "/api/contents/{path}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/contents/{path}/checkpoints/{checkpointId}": { + "get": { + "parameters": [ + { + "name": "path", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }, + { + "name": "checkpointId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }, + ], + "responses": {"200": {"description": ""}}, + }, + }, + "/api/sessions/{sessionId}": { + "get": { + "parameters": [ + { + "name": "sessionId", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "responses": {"200": {"description": ""}}, + }, + }, + }, + }, + encoded_slash_regex=( + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + r"/api/contents/([^/?]+(?:(?:/[^/]+)*?))$", + # Test regex without group match + r"/api/sessions", + ), + ) + + headers = server_request.pop("headers") if "headers" in server_request else {} + + assert ( + validator.validate(HTTPServerRequest(**server_request, headers=HTTPHeaders(headers))) + == expected + ) + + +@pytest.mark.parametrize( + "regex, test, expected", + ( + pytest.param( + (r"/api/contents/([^/]+(?:(?:/[^/]+)*?))$",), + 
"/api/contents/path/to/file.txt", + "/api/contents/path%2Fto%2Ffile.txt", + id="Path to encode", + ), + pytest.param( + ( + r"/api/contents/([^/]+(?:(?:/[^/]+)*))$", + r"/api/contents/([^/]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + ), + "/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + r"/api/contents/path%2Fto%2Ffile.txt%2Fcheckpoints%2F.ipynb_checkpoints%2Ffile.txt", + id="Bad regex order", + ), + pytest.param( + ( + r"/api/contents/([^/]+(?:(?:/[^/]+)*?))/checkpoints/([^/]+(?:(?:/[^/]+)*?))$", + r"/api/contents/([^/]+(?:(?:/[^/]+)*))$", + ), + "/api/contents/path/to/file.txt/checkpoints/.ipynb_checkpoints/file.txt", + r"/api/contents/path%2Fto%2Ffile.txt/checkpoints/.ipynb_checkpoints%2Ffile.txt", + id="Multiple regex", + ), + pytest.param( + (r"/api/contents/([^/]+(?:(?:/[^/]+)*?))$",), + "/api/sessions/path/to/file.txt", + "/api/sessions/path/to/file.txt", + id="No match", + ), + pytest.param( + (r"/api/sessions",), + "/api/sessions/path/to/file.txt", + "/api/sessions/path/to/file.txt", + id="No group", + ), + ), +) +def test_encode_slash(regex, test, expected): + assert encode_slash(regex, test) == expected diff --git a/setup.cfg b/setup.cfg index 5f6cc64116..f45e8525a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,12 @@ install_requires = packaging [options.extras_require] +validation = + openapi-core>=0.14.0,<0.15.0; python_version >= '3.7' + tornado_openapi3>=1.1.0,<2.0.0; python_version >= '3.7' test = + openapi-core>=0.14.0,<0.15.0; python_version >= '3.7' + tornado_openapi3>=1.1.0,<2.0.0; python_version >= '3.7' coverage pytest>=6.0 pytest-cov