Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ Unreleased
a deprecated alias. If an app context is already pushed, it is not reused
when dispatching a request. This greatly simplifies the internal code for tracking
the active context. :issue:`5639`
- Many ``Flask`` methods involved in request dispatch now take the current
``AppContext`` as the first parameter, instead of using the proxy objects.
If subclasses were overriding these methods, the old signature is detected,
shows a deprecation warning, and will continue to work during the
deprecation period. :issue:`5815`
- ``template_filter``, ``template_test``, and ``template_global`` decorators
can be used without parentheses. :issue:`5729`

Expand Down
167 changes: 129 additions & 38 deletions src/flask/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import collections.abc as cabc
import inspect
import os
import sys
import typing as t
import weakref
from datetime import timedelta
from functools import update_wrapper
from inspect import iscoroutinefunction
from itertools import chain
from types import TracebackType
Expand All @@ -30,6 +32,7 @@
from . import typing as ft
from .ctx import AppContext
from .globals import _cv_app
from .globals import app_ctx
from .globals import g
from .globals import request
from .globals import session
Expand Down Expand Up @@ -73,6 +76,35 @@ def _make_timedelta(value: timedelta | int | None) -> timedelta | None:
return timedelta(seconds=value)


F = t.TypeVar("F", bound=t.Callable[..., t.Any])


# Other methods may call the overridden method with the new ctx arg. Remove it
# and call the method with the remaining args.
def remove_ctx(f: F) -> F:
def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any:
if args and isinstance(args[0], AppContext):
args = args[1:]

return f(self, *args, **kwargs)

return update_wrapper(wrapper, f) # type: ignore[return-value]


# The overridden method may call super().base_method without the new ctx arg.
# Add it to the args for the call.
def add_ctx(f: F) -> F:
def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not args:
args = (app_ctx._get_current_object(),)
elif not isinstance(args[0], AppContext):
args = (app_ctx._get_current_object(), *args)

return f(self, *args, **kwargs)

return update_wrapper(wrapper, f) # type: ignore[return-value]


class Flask(App):
"""The flask object implements a WSGI application and acts as the central
object. It is passed the name of the module or package of the
Expand Down Expand Up @@ -218,6 +250,62 @@ class Flask(App):
#: .. versionadded:: 0.8
session_interface: SessionInterface = SecureCookieSessionInterface()

def __init_subclass__(cls, **kwargs: t.Any) -> None:
import warnings

# These method signatures were updated to take a ctx param. Detect
# overridden methods in subclasses that still have the old signature.
# Show a deprecation warning and wrap to call with correct args.
for method in (
cls.handle_http_exception,
cls.handle_user_exception,
cls.handle_exception,
cls.log_exception,
cls.dispatch_request,
cls.full_dispatch_request,
cls.finalize_request,
cls.make_default_options_response,
cls.preprocess_request,
cls.process_response,
cls.do_teardown_request,
cls.do_teardown_appcontext,
):
base_method = getattr(Flask, method.__name__)

if method is base_method:
# not overridden
continue

# get the second parameter (first is self)
iter_params = iter(inspect.signature(method).parameters.values())
next(iter_params)
param = next(iter_params, None)

# must have second parameter named ctx or annotated AppContext
if param is None or not (
# no annotation, match name
(param.annotation is inspect.Parameter.empty and param.name == "ctx")
or (
# string annotation, access path ends with AppContext
isinstance(param.annotation, str)
and param.annotation.rpartition(".")[2] == "AppContext"
)
or (
# class annotation
inspect.isclass(param.annotation)
and issubclass(param.annotation, AppContext)
)
):
warnings.warn(
f"The '{method.__name__}' method now takes 'ctx: AppContext'"
" as the first parameter. The old signature is deprecated"
" and will not be supported in Flask 4.0.",
DeprecationWarning,
stacklevel=2,
)
setattr(cls, method.__name__, remove_ctx(method))
setattr(Flask, method.__name__, add_ctx(base_method))

def __init__(
self,
import_name: str,
Expand Down Expand Up @@ -498,7 +586,9 @@ def raise_routing_exception(self, request: Request) -> t.NoReturn:

raise FormDataRoutingRedirect(request)

def update_template_context(self, context: dict[str, t.Any]) -> None:
def update_template_context(
self, ctx: AppContext, context: dict[str, t.Any]
) -> None:
"""Update the template context with some commonly used variables.
This injects request, session, config and g into the template
context as well as everything template context processors want
Expand All @@ -512,7 +602,7 @@ def update_template_context(self, context: dict[str, t.Any]) -> None:
names: t.Iterable[str | None] = (None,)

# A template may be rendered outside a request context.
if (ctx := _cv_app.get(None)) is not None and ctx.has_request:
if ctx.has_request:
names = chain(names, reversed(ctx.request.blueprints))

# The values passed to render_template take precedence. Keep a
Expand Down Expand Up @@ -737,7 +827,7 @@ def test_cli_runner(self, **kwargs: t.Any) -> FlaskCliRunner:
return cls(self, **kwargs) # type: ignore

def handle_http_exception(
self, e: HTTPException
self, ctx: AppContext, e: HTTPException
) -> HTTPException | ft.ResponseReturnValue:
"""Handles an HTTP exception. By default this will invoke the
registered error handlers and fall back to returning the
Expand Down Expand Up @@ -766,13 +856,13 @@ def handle_http_exception(
if isinstance(e, RoutingException):
return e

handler = self._find_error_handler(e, request.blueprints)
handler = self._find_error_handler(e, ctx.request.blueprints)
if handler is None:
return e
return self.ensure_sync(handler)(e) # type: ignore[no-any-return]

def handle_user_exception(
self, e: Exception
self, ctx: AppContext, e: Exception
) -> HTTPException | ft.ResponseReturnValue:
"""This method is called whenever an exception occurs that
should be handled. A special case is :class:`~werkzeug
Expand All @@ -794,16 +884,16 @@ def handle_user_exception(
e.show_exception = True

if isinstance(e, HTTPException) and not self.trap_http_exception(e):
return self.handle_http_exception(e)
return self.handle_http_exception(ctx, e)

handler = self._find_error_handler(e, request.blueprints)
handler = self._find_error_handler(e, ctx.request.blueprints)

if handler is None:
raise

return self.ensure_sync(handler)(e) # type: ignore[no-any-return]

def handle_exception(self, e: Exception) -> Response:
def handle_exception(self, ctx: AppContext, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
associated with it, or that was raised from an error handler.
This always causes a 500 ``InternalServerError``.
Expand Down Expand Up @@ -846,19 +936,20 @@ def handle_exception(self, e: Exception) -> Response:

raise e

self.log_exception(exc_info)
self.log_exception(ctx, exc_info)
server_error: InternalServerError | ft.ResponseReturnValue
server_error = InternalServerError(original_exception=e)
handler = self._find_error_handler(server_error, request.blueprints)
handler = self._find_error_handler(server_error, ctx.request.blueprints)

if handler is not None:
server_error = self.ensure_sync(handler)(server_error)

return self.finalize_request(server_error, from_error_handler=True)
return self.finalize_request(ctx, server_error, from_error_handler=True)

def log_exception(
self,
exc_info: (tuple[type, BaseException, TracebackType] | tuple[None, None, None]),
ctx: AppContext,
exc_info: tuple[type, BaseException, TracebackType] | tuple[None, None, None],
) -> None:
"""Logs an exception. This is called by :meth:`handle_exception`
if debugging is disabled and right before the handler is called.
Expand All @@ -868,10 +959,10 @@ def log_exception(
.. versionadded:: 0.8
"""
self.logger.error(
f"Exception on {request.path} [{request.method}]", exc_info=exc_info
f"Exception on {ctx.request.path} [{ctx.request.method}]", exc_info=exc_info
)

def dispatch_request(self) -> ft.ResponseReturnValue:
def dispatch_request(self, ctx: AppContext) -> ft.ResponseReturnValue:
"""Does the request dispatching. Matches the URL and returns the
return value of the view or error handler. This does not have to
be a response object. In order to convert the return value to a
Expand All @@ -881,7 +972,7 @@ def dispatch_request(self) -> ft.ResponseReturnValue:
This no longer does the exception handling, this code was
moved to the new :meth:`full_dispatch_request`.
"""
req = _cv_app.get().request
req = ctx.request

if req.routing_exception is not None:
self.raise_routing_exception(req)
Expand All @@ -892,12 +983,12 @@ def dispatch_request(self) -> ft.ResponseReturnValue:
getattr(rule, "provide_automatic_options", False)
and req.method == "OPTIONS"
):
return self.make_default_options_response()
return self.make_default_options_response(ctx)
# otherwise dispatch to the handler for that endpoint
view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment]
return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return]

def full_dispatch_request(self) -> Response:
def full_dispatch_request(self, ctx: AppContext) -> Response:
"""Dispatches the request and on top of that performs request
pre and postprocessing as well as HTTP exception catching and
error handling.
Expand All @@ -908,15 +999,16 @@ def full_dispatch_request(self) -> Response:

try:
request_started.send(self, _async_wrapper=self.ensure_sync)
rv = self.preprocess_request()
rv = self.preprocess_request(ctx)
if rv is None:
rv = self.dispatch_request()
rv = self.dispatch_request(ctx)
except Exception as e:
rv = self.handle_user_exception(e)
return self.finalize_request(rv)
rv = self.handle_user_exception(ctx, e)
return self.finalize_request(ctx, rv)

def finalize_request(
self,
ctx: AppContext,
rv: ft.ResponseReturnValue | HTTPException,
from_error_handler: bool = False,
) -> Response:
Expand All @@ -934,7 +1026,7 @@ def finalize_request(
"""
response = self.make_response(rv)
try:
response = self.process_response(response)
response = self.process_response(ctx, response)
request_finished.send(
self, _async_wrapper=self.ensure_sync, response=response
)
Expand All @@ -946,15 +1038,14 @@ def finalize_request(
)
return response

def make_default_options_response(self) -> Response:
def make_default_options_response(self, ctx: AppContext) -> Response:
"""This method is called to create the default ``OPTIONS`` response.
This can be changed through subclassing to change the default
behavior of ``OPTIONS`` responses.

.. versionadded:: 0.7
"""
adapter = _cv_app.get().url_adapter
methods = adapter.allowed_methods() # type: ignore[union-attr]
methods = ctx.url_adapter.allowed_methods() # type: ignore[union-attr]
rv = self.response_class()
rv.allow.update(methods)
return rv
Expand Down Expand Up @@ -1260,7 +1351,7 @@ def make_response(self, rv: ft.ResponseReturnValue) -> Response:

return rv

def preprocess_request(self) -> ft.ResponseReturnValue | None:
def preprocess_request(self, ctx: AppContext) -> ft.ResponseReturnValue | None:
"""Called before the request is dispatched. Calls
:attr:`url_value_preprocessors` registered with the app and the
current blueprint (if any). Then calls :attr:`before_request_funcs`
Expand All @@ -1270,7 +1361,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None:
value is handled as if it was the return value from the view, and
further request handling is stopped.
"""
req = _cv_app.get().request
req = ctx.request
names = (None, *reversed(req.blueprints))

for name in names:
Expand All @@ -1288,7 +1379,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None:

return None

def process_response(self, response: Response) -> Response:
def process_response(self, ctx: AppContext, response: Response) -> Response:
"""Can be overridden in order to modify the response object
before it's sent to the WSGI server. By default this will
call all the :meth:`after_request` decorated functions.
Expand All @@ -1301,8 +1392,6 @@ def process_response(self, response: Response) -> Response:
:return: a new response object or the same, has to be an
instance of :attr:`response_class`.
"""
ctx = _cv_app.get()

for func in ctx._after_request_functions:
response = self.ensure_sync(func)(response)

Expand All @@ -1316,7 +1405,9 @@ def process_response(self, response: Response) -> Response:

return response

def do_teardown_request(self, exc: BaseException | None = None) -> None:
def do_teardown_request(
self, ctx: AppContext, exc: BaseException | None = None
) -> None:
"""Called after the request is dispatched and the response is finalized,
right before the request context is popped. Called by
:meth:`.AppContext.pop`.
Expand All @@ -1331,16 +1422,16 @@ def do_teardown_request(self, exc: BaseException | None = None) -> None:
.. versionchanged:: 0.9
Added the ``exc`` argument.
"""
req = _cv_app.get().request

for name in chain(req.blueprints, (None,)):
for name in chain(ctx.request.blueprints, (None,)):
if name in self.teardown_request_funcs:
for func in reversed(self.teardown_request_funcs[name]):
self.ensure_sync(func)(exc)

request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)

def do_teardown_appcontext(self, exc: BaseException | None = None) -> None:
def do_teardown_appcontext(
self, ctx: AppContext, exc: BaseException | None = None
) -> None:
"""Called right before the application context is popped. Called by
:meth:`.AppContext.pop`.

Expand Down Expand Up @@ -1473,17 +1564,17 @@ def wsgi_app(
try:
try:
ctx.push()
response = self.full_dispatch_request()
response = self.full_dispatch_request(ctx)
except Exception as e:
error = e
response = self.handle_exception(e)
response = self.handle_exception(ctx, e)
except: # noqa: B001
error = sys.exc_info()[1]
raise
return response(environ, start_response)
finally:
if "werkzeug.debug.preserve_context" in environ:
environ["werkzeug.debug.preserve_context"](_cv_app.get())
environ["werkzeug.debug.preserve_context"](ctx)

if error is not None and self.should_ignore_error(error):
error = None
Expand Down
Loading
Loading