From 6a64969009bf1898bd302e6a9633f71942523f7f Mon Sep 17 00:00:00 2001 From: David Lord Date: Fri, 19 Sep 2025 16:58:48 -0700 Subject: [PATCH] pass context through dispatch methods --- CHANGES.rst | 5 ++ src/flask/app.py | 167 +++++++++++++++++++++++++++++--------- src/flask/ctx.py | 4 +- src/flask/templating.py | 52 ++++++------ tests/test_reqctx.py | 9 +- tests/test_subclassing.py | 2 +- 6 files changed, 167 insertions(+), 72 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index fed622578c..efd076e927 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,6 +9,11 @@ Unreleased a deprecated alias. If an app context is already pushed, it is not reused when dispatching a request. This greatly simplifies the internal code for tracking the active context. :issue:`5639` +- Many ``Flask`` methods involved in request dispatch now take the current + ``AppContext`` as the first parameter, instead of using the proxy objects. + If subclasses were overriding these methods, the old signature is detected, + shows a deprecation warning, and will continue to work during the + deprecation period. :issue:`5815` - ``template_filter``, ``template_test``, and ``template_global`` decorators can be used without parentheses. :issue:`5729` diff --git a/src/flask/app.py b/src/flask/app.py index 1149e2482f..e0c193dcb7 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -1,11 +1,13 @@ from __future__ import annotations import collections.abc as cabc +import inspect import os import sys import typing as t import weakref from datetime import timedelta +from functools import update_wrapper from inspect import iscoroutinefunction from itertools import chain from types import TracebackType @@ -30,6 +32,7 @@ from . import typing as ft from .ctx import AppContext from .globals import _cv_app +from .globals import app_ctx from .globals import g from .globals import request from .globals import session @@ -73,6 +76,35 @@ def _make_timedelta(value: timedelta | int | None) -> timedelta | None: return timedelta(seconds=value) +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + + +# Other methods may call the overridden method with the new ctx arg. Remove it +# and call the method with the remaining args. +def remove_ctx(f: F) -> F: + def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any: + if args and isinstance(args[0], AppContext): + args = args[1:] + + return f(self, *args, **kwargs) + + return update_wrapper(wrapper, f) # type: ignore[return-value] + + +# The overridden method may call super().base_method without the new ctx arg. +# Add it to the args for the call. +def add_ctx(f: F) -> F: + def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any: + if not args: + args = (app_ctx._get_current_object(),) + elif not isinstance(args[0], AppContext): + args = (app_ctx._get_current_object(), *args) + + return f(self, *args, **kwargs) + + return update_wrapper(wrapper, f) # type: ignore[return-value] + + class Flask(App): """The flask object implements a WSGI application and acts as the central object. It is passed the name of the module or package of the @@ -218,6 +250,62 @@ class Flask(App): #: .. versionadded:: 0.8 session_interface: SessionInterface = SecureCookieSessionInterface() + def __init_subclass__(cls, **kwargs: t.Any) -> None: + import warnings + + # These method signatures were updated to take a ctx param. Detect + # overridden methods in subclasses that still have the old signature. + # Show a deprecation warning and wrap to call with correct args. + for method in ( + cls.handle_http_exception, + cls.handle_user_exception, + cls.handle_exception, + cls.log_exception, + cls.dispatch_request, + cls.full_dispatch_request, + cls.finalize_request, + cls.make_default_options_response, + cls.preprocess_request, + cls.process_response, + cls.do_teardown_request, + cls.do_teardown_appcontext, + ): + base_method = getattr(Flask, method.__name__) + + if method is base_method: + # not overridden + continue + + # get the second parameter (first is self) + iter_params = iter(inspect.signature(method).parameters.values()) + next(iter_params) + param = next(iter_params, None) + + # must have second parameter named ctx or annotated AppContext + if param is None or not ( + # no annotation, match name + (param.annotation is inspect.Parameter.empty and param.name == "ctx") + or ( + # string annotation, access path ends with AppContext + isinstance(param.annotation, str) + and param.annotation.rpartition(".")[2] == "AppContext" + ) + or ( + # class annotation + inspect.isclass(param.annotation) + and issubclass(param.annotation, AppContext) + ) + ): + warnings.warn( + f"The '{method.__name__}' method now takes 'ctx: AppContext'" + " as the first parameter. The old signature is deprecated" + " and will not be supported in Flask 4.0.", + DeprecationWarning, + stacklevel=2, + ) + setattr(cls, method.__name__, remove_ctx(method)) + setattr(Flask, method.__name__, add_ctx(base_method)) + def __init__( self, import_name: str, @@ -498,7 +586,9 @@ def raise_routing_exception(self, request: Request) -> t.NoReturn: raise FormDataRoutingRedirect(request) - def update_template_context(self, context: dict[str, t.Any]) -> None: + def update_template_context( + self, ctx: AppContext, context: dict[str, t.Any] + ) -> None: """Update the template context with some commonly used variables. This injects request, session, config and g into the template context as well as everything template context processors want @@ -512,7 +602,7 @@ def update_template_context(self, context: dict[str, t.Any]) -> None: names: t.Iterable[str | None] = (None,) # A template may be rendered outside a request context. - if (ctx := _cv_app.get(None)) is not None and ctx.has_request: + if ctx.has_request: names = chain(names, reversed(ctx.request.blueprints)) # The values passed to render_template take precedence. Keep a @@ -737,7 +827,7 @@ def test_cli_runner(self, **kwargs: t.Any) -> FlaskCliRunner: return cls(self, **kwargs) # type: ignore def handle_http_exception( - self, e: HTTPException + self, ctx: AppContext, e: HTTPException ) -> HTTPException | ft.ResponseReturnValue: """Handles an HTTP exception. By default this will invoke the registered error handlers and fall back to returning the @@ -766,13 +856,13 @@ def handle_http_exception( if isinstance(e, RoutingException): return e - handler = self._find_error_handler(e, request.blueprints) + handler = self._find_error_handler(e, ctx.request.blueprints) if handler is None: return e return self.ensure_sync(handler)(e) # type: ignore[no-any-return] def handle_user_exception( - self, e: Exception + self, ctx: AppContext, e: Exception ) -> HTTPException | ft.ResponseReturnValue: """This method is called whenever an exception occurs that should be handled. A special case is :class:`~werkzeug @@ -794,16 +884,16 @@ def handle_user_exception( e.show_exception = True if isinstance(e, HTTPException) and not self.trap_http_exception(e): - return self.handle_http_exception(e) + return self.handle_http_exception(ctx, e) - handler = self._find_error_handler(e, request.blueprints) + handler = self._find_error_handler(e, ctx.request.blueprints) if handler is None: raise return self.ensure_sync(handler)(e) # type: ignore[no-any-return] - def handle_exception(self, e: Exception) -> Response: + def handle_exception(self, ctx: AppContext, e: Exception) -> Response: """Handle an exception that did not have an error handler associated with it, or that was raised from an error handler. This always causes a 500 ``InternalServerError``. @@ -846,19 +936,20 @@ def handle_exception(self, e: Exception) -> Response: raise e - self.log_exception(exc_info) + self.log_exception(ctx, exc_info) server_error: InternalServerError | ft.ResponseReturnValue server_error = InternalServerError(original_exception=e) - handler = self._find_error_handler(server_error, request.blueprints) + handler = self._find_error_handler(server_error, ctx.request.blueprints) if handler is not None: server_error = self.ensure_sync(handler)(server_error) - return self.finalize_request(server_error, from_error_handler=True) + return self.finalize_request(ctx, server_error, from_error_handler=True) def log_exception( self, - exc_info: (tuple[type, BaseException, TracebackType] | tuple[None, None, None]), + ctx: AppContext, + exc_info: tuple[type, BaseException, TracebackType] | tuple[None, None, None], ) -> None: """Logs an exception. This is called by :meth:`handle_exception` if debugging is disabled and right before the handler is called. @@ -868,10 +959,10 @@ def log_exception( .. versionadded:: 0.8 """ self.logger.error( - f"Exception on {request.path} [{request.method}]", exc_info=exc_info + f"Exception on {ctx.request.path} [{ctx.request.method}]", exc_info=exc_info ) - def dispatch_request(self) -> ft.ResponseReturnValue: + def dispatch_request(self, ctx: AppContext) -> ft.ResponseReturnValue: """Does the request dispatching. Matches the URL and returns the return value of the view or error handler. This does not have to be a response object. In order to convert the return value to a @@ -881,7 +972,7 @@ def dispatch_request(self) -> ft.ResponseReturnValue: This no longer does the exception handling, this code was moved to the new :meth:`full_dispatch_request`. """ - req = _cv_app.get().request + req = ctx.request if req.routing_exception is not None: self.raise_routing_exception(req) @@ -892,12 +983,12 @@ def dispatch_request(self) -> ft.ResponseReturnValue: getattr(rule, "provide_automatic_options", False) and req.method == "OPTIONS" ): - return self.make_default_options_response() + return self.make_default_options_response(ctx) # otherwise dispatch to the handler for that endpoint view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment] return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return] - def full_dispatch_request(self) -> Response: + def full_dispatch_request(self, ctx: AppContext) -> Response: """Dispatches the request and on top of that performs request pre and postprocessing as well as HTTP exception catching and error handling. @@ -908,15 +999,16 @@ def full_dispatch_request(self) -> Response: try: request_started.send(self, _async_wrapper=self.ensure_sync) - rv = self.preprocess_request() + rv = self.preprocess_request(ctx) if rv is None: - rv = self.dispatch_request() + rv = self.dispatch_request(ctx) except Exception as e: - rv = self.handle_user_exception(e) - return self.finalize_request(rv) + rv = self.handle_user_exception(ctx, e) + return self.finalize_request(ctx, rv) def finalize_request( self, + ctx: AppContext, rv: ft.ResponseReturnValue | HTTPException, from_error_handler: bool = False, ) -> Response: @@ -934,7 +1026,7 @@ def finalize_request( """ response = self.make_response(rv) try: - response = self.process_response(response) + response = self.process_response(ctx, response) request_finished.send( self, _async_wrapper=self.ensure_sync, response=response ) @@ -946,15 +1038,14 @@ def finalize_request( ) return response - def make_default_options_response(self) -> Response: + def make_default_options_response(self, ctx: AppContext) -> Response: """This method is called to create the default ``OPTIONS`` response. This can be changed through subclassing to change the default behavior of ``OPTIONS`` responses. .. versionadded:: 0.7 """ - adapter = _cv_app.get().url_adapter - methods = adapter.allowed_methods() # type: ignore[union-attr] + methods = ctx.url_adapter.allowed_methods() # type: ignore[union-attr] rv = self.response_class() rv.allow.update(methods) return rv @@ -1260,7 +1351,7 @@ def make_response(self, rv: ft.ResponseReturnValue) -> Response: return rv - def preprocess_request(self) -> ft.ResponseReturnValue | None: + def preprocess_request(self, ctx: AppContext) -> ft.ResponseReturnValue | None: """Called before the request is dispatched. Calls :attr:`url_value_preprocessors` registered with the app and the current blueprint (if any). Then calls :attr:`before_request_funcs` @@ -1270,7 +1361,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None: value is handled as if it was the return value from the view, and further request handling is stopped. """ - req = _cv_app.get().request + req = ctx.request names = (None, *reversed(req.blueprints)) for name in names: @@ -1288,7 +1379,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None: return None - def process_response(self, response: Response) -> Response: + def process_response(self, ctx: AppContext, response: Response) -> Response: """Can be overridden in order to modify the response object before it's sent to the WSGI server. By default this will call all the :meth:`after_request` decorated functions. @@ -1301,8 +1392,6 @@ def process_response(self, response: Response) -> Response: :return: a new response object or the same, has to be an instance of :attr:`response_class`. """ - ctx = _cv_app.get() - for func in ctx._after_request_functions: response = self.ensure_sync(func)(response) @@ -1316,7 +1405,9 @@ def process_response(self, response: Response) -> Response: return response - def do_teardown_request(self, exc: BaseException | None = None) -> None: + def do_teardown_request( + self, ctx: AppContext, exc: BaseException | None = None + ) -> None: """Called after the request is dispatched and the response is finalized, right before the request context is popped. Called by :meth:`.AppContext.pop`. @@ -1331,16 +1422,16 @@ def do_teardown_request(self, exc: BaseException | None = None) -> None: .. versionchanged:: 0.9 Added the ``exc`` argument. """ - req = _cv_app.get().request - - for name in chain(req.blueprints, (None,)): + for name in chain(ctx.request.blueprints, (None,)): if name in self.teardown_request_funcs: for func in reversed(self.teardown_request_funcs[name]): self.ensure_sync(func)(exc) request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc) - def do_teardown_appcontext(self, exc: BaseException | None = None) -> None: + def do_teardown_appcontext( + self, ctx: AppContext, exc: BaseException | None = None + ) -> None: """Called right before the application context is popped. Called by :meth:`.AppContext.pop`. @@ -1473,17 +1564,17 @@ def wsgi_app( try: try: ctx.push() - response = self.full_dispatch_request() + response = self.full_dispatch_request(ctx) except Exception as e: error = e - response = self.handle_exception(e) + response = self.handle_exception(ctx, e) except: # noqa: B001 error = sys.exc_info()[1] raise return response(environ, start_response) finally: if "werkzeug.debug.preserve_context" in environ: - environ["werkzeug.debug.preserve_context"](_cv_app.get()) + environ["werkzeug.debug.preserve_context"](ctx) if error is not None and self.should_ignore_error(error): error = None diff --git a/src/flask/ctx.py b/src/flask/ctx.py index 1ac86eafe3..ba72b17596 100644 --- a/src/flask/ctx.py +++ b/src/flask/ctx.py @@ -471,10 +471,10 @@ def pop(self, exc: BaseException | None = None) -> None: try: if self._request is not None: - self.app.do_teardown_request(exc) + self.app.do_teardown_request(self, exc) self._request.close() finally: - self.app.do_teardown_appcontext(exc) + self.app.do_teardown_appcontext(self, exc) _cv_app.reset(self._cv_token) self._cv_token = None appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync) diff --git a/src/flask/templating.py b/src/flask/templating.py index 9a0ace846d..4bb86d59ad 100644 --- a/src/flask/templating.py +++ b/src/flask/templating.py @@ -7,14 +7,13 @@ from jinja2 import Template from jinja2 import TemplateNotFound -from .globals import _cv_app -from .globals import current_app +from .ctx import AppContext +from .globals import app_ctx from .helpers import stream_with_context from .signals import before_render_template from .signals import template_rendered if t.TYPE_CHECKING: # pragma: no cover - from .app import Flask from .sansio.app import App from .sansio.scaffold import Scaffold @@ -23,15 +22,12 @@ def _default_template_ctx_processor() -> dict[str, t.Any]: """Default template context processor. Injects `request`, `session` and `g`. """ - ctx = _cv_app.get(None) - rv: dict[str, t.Any] = {} + ctx = app_ctx._get_current_object() + rv: dict[str, t.Any] = {"g": ctx.g} - if ctx is not None: - rv["g"] = ctx.g - - if ctx.has_request: - rv["request"] = ctx.request - rv["session"] = ctx.session + if ctx.has_request: + rv["request"] = ctx.request + rv["session"] = ctx.session return rv @@ -123,8 +119,9 @@ def list_templates(self) -> list[str]: return list(result) -def _render(app: Flask, template: Template, context: dict[str, t.Any]) -> str: - app.update_template_context(context) +def _render(ctx: AppContext, template: Template, context: dict[str, t.Any]) -> str: + app = ctx.app + app.update_template_context(ctx, context) before_render_template.send( app, _async_wrapper=app.ensure_sync, template=template, context=context ) @@ -145,9 +142,9 @@ def render_template( a list is given, the first name to exist will be rendered. :param context: The variables to make available in the template. """ - app = current_app._get_current_object() - template = app.jinja_env.get_or_select_template(template_name_or_list) - return _render(app, template, context) + ctx = app_ctx._get_current_object() + template = ctx.app.jinja_env.get_or_select_template(template_name_or_list) + return _render(ctx, template, context) def render_template_string(source: str, **context: t.Any) -> str: @@ -157,15 +154,16 @@ def render_template_string(source: str, **context: t.Any) -> str: :param source: The source code of the template to render. :param context: The variables to make available in the template. """ - app = current_app._get_current_object() - template = app.jinja_env.from_string(source) - return _render(app, template, context) + ctx = app_ctx._get_current_object() + template = ctx.app.jinja_env.from_string(source) + return _render(ctx, template, context) def _stream( - app: Flask, template: Template, context: dict[str, t.Any] + ctx: AppContext, template: Template, context: dict[str, t.Any] ) -> t.Iterator[str]: - app.update_template_context(context) + app = ctx.app + app.update_template_context(ctx, context) before_render_template.send( app, _async_wrapper=app.ensure_sync, template=template, context=context ) @@ -193,9 +191,9 @@ def stream_template( .. versionadded:: 2.2 """ - app = current_app._get_current_object() - template = app.jinja_env.get_or_select_template(template_name_or_list) - return _stream(app, template, context) + ctx = app_ctx._get_current_object() + template = ctx.app.jinja_env.get_or_select_template(template_name_or_list) + return _stream(ctx, template, context) def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]: @@ -208,6 +206,6 @@ def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]: .. versionadded:: 2.2 """ - app = current_app._get_current_object() - template = app.jinja_env.from_string(source) - return _stream(app, template, context) + ctx = app_ctx._get_current_object() + template = ctx.app.jinja_env.from_string(source) + return _stream(ctx, template, context) diff --git a/tests/test_reqctx.py b/tests/test_reqctx.py index a7b77eb901..78561f520d 100644 --- a/tests/test_reqctx.py +++ b/tests/test_reqctx.py @@ -288,8 +288,9 @@ def test_bad_environ_raises_bad_request(): # use a non-printable character in the Host - this is key to this test environ["HTTP_HOST"] = "\x8a" - with app.request_context(environ): - response = app.full_dispatch_request() + with app.request_context(environ) as ctx: + response = app.full_dispatch_request(ctx) + assert response.status_code == 400 @@ -308,8 +309,8 @@ def index(): # these characters are all IDNA-compatible environ["HTTP_HOST"] = "ąśźäüжŠßя.com" - with app.request_context(environ): - response = app.full_dispatch_request() + with app.request_context(environ) as ctx: + response = app.full_dispatch_request(ctx) assert response.status_code == 200 diff --git a/tests/test_subclassing.py b/tests/test_subclassing.py index 087c50dc72..3b9fe31656 100644 --- a/tests/test_subclassing.py +++ b/tests/test_subclassing.py @@ -5,7 +5,7 @@ def test_suppressed_exception_logging(): class SuppressedFlask(flask.Flask): - def log_exception(self, exc_info): + def log_exception(self, ctx, exc_info): pass out = StringIO()