Limit max request size #2155
If we're going to make it a configurable middleware, it might also make sense to have some sort of timeout for connections and for each chunk (maybe infinite by default, but definitely tunable). Another thing to keep in mind is that this is likely something users want to control on a per-endpoint basis. That is, if I have an app with an upload feature that expects 1 GB files, it's likely a single endpoint that expects those files, so I'd want to bump up the limit just for that endpoint. That makes me think the best strategy may be a per-endpoint middleware with a companion middleware that just tweaks the config by changing it in
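The "companion middleware that tweaks the config" idea above could be sketched as a tiny ASGI middleware that writes a per-endpoint override into the scope, for an outer global limiter to read. All names here (`RequestSizeOverride`, the `max_content_length` scope key) are hypothetical, not existing Starlette API:

```python
# Hypothetical sketch of the per-endpoint "companion middleware" idea:
# it only writes an override into the ASGI scope; a global size-limiting
# middleware further out would read scope["max_content_length"] and fall
# back to its own default. Class and key names are illustrative.
class RequestSizeOverride:
    def __init__(self, app, max_content_length: int):
        self.app = app
        self.max_content_length = max_content_length

    async def __call__(self, scope, receive, send):
        # Tweak the config "by changing it in the scope" for downstream apps.
        scope["max_content_length"] = self.max_content_length
        await self.app(scope, receive, send)
```

Wrapping a single route's app with this would then raise (or lower) the limit for just that endpoint.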
This is a good one! I also agree that we need a global setting and per-route settings (Route + Mount). We can add
Why should the ASGI application be the one to set this instead of the server?
Example: a global POST limit of 1 MB, and 100 MB for selected endpoints that upload files.
Adding a LimitRequestSizeMiddleware is the simplest and most forward-compatible way.
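A hedged sketch of what such a middleware could look like: a pure-ASGI middleware that rejects requests whose declared Content-Length exceeds a limit. The class name matches the suggestion above; everything else is an assumption, not existing Starlette API:

```python
# Sketch of a LimitRequestSizeMiddleware: reject requests whose declared
# Content-Length exceeds a configurable limit, before the body is read.
# This is an illustrative pure-ASGI implementation, not Starlette API.
class LimitRequestSizeMiddleware:
    def __init__(self, app, max_content_length: int = 1024 * 1024):
        self.app = app
        self.max_content_length = max_content_length

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            # ASGI headers are a list of (bytes, bytes) pairs.
            headers = dict(scope.get("headers", []))
            declared = headers.get(b"content-length")
            if declared is not None and int(declared) > self.max_content_length:
                # Reject early with 413 before consuming any of the body.
                await send({
                    "type": "http.response.start",
                    "status": 413,
                    "headers": [(b"content-type", b"text/plain")],
                })
                await send({"type": "http.response.body", "body": b"Payload Too Large"})
                return
        await self.app(scope, receive, send)
```

Note that this only trusts the client-supplied header, so it can at best be a first line of defense; the body itself still needs to be size-checked while it is being received.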
Yeah. Shall we follow this path?
Yes, I think someone should make a PR and we can discuss the details (override vs. min/max, should there be a default, etc.) there.
I am someone; I made a PR 😆: #2328
The PR was closed, but the idea is still on the table, so here are my two cents:
Hi guys, are you still interested in this? I just scanned through your discussion and came up with something like this:

```python
# Proposed changes to `starlette.requests.Request` and
# `starlette.routing.request_response` (an excerpt, not a standalone module).


class Request(HTTPConnection):
    _form: FormData | None

    def __init__(
        self,
        scope: Scope,
        receive: Receive = empty_receive,
        send: Send = empty_send,
        max_content_length: int | None = None,
    ):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None
        if max_content_length is not None:
            assert max_content_length > 0
            # The header value is a string (and may be absent), so compare as int.
            content_length = self.headers.get("content-length")
            if content_length is not None and int(content_length) > max_content_length:
                raise ValueError("Body too large")
        self._max_content_length = max_content_length

    @property
    def method(self) -> str:
        return typing.cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        return self._receive

    async def stream(self, chunk_size: int | None = None) -> typing.AsyncGenerator[bytes, None]:
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        buffer = bytearray()
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                buffer.extend(body)  # Append new data to the buffer
                if chunk_size:
                    while len(buffer) >= chunk_size:
                        yield bytes(buffer[:chunk_size])  # Yield a full chunk
                        del buffer[:chunk_size]  # Remove the yielded data
                if not message.get("more_body", False):
                    self._stream_consumed = True
                    if buffer:
                        yield bytes(buffer)  # Yield the remaining buffered data
            elif message["type"] == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self, chunk_size: int | None = None) -> bytes:
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream(chunk_size):
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> typing.Any:
        if not hasattr(self, "_json"):  # pragma: no branch
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
        chunk_size: int | None = None,
    ) -> FormData:
        if self._form is None:  # pragma: no branch
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(chunk_size),
                        max_files=max_files,
                        max_fields=max_fields,
                        max_part_size=max_part_size,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    if "app" in self.scope:
                        raise HTTPException(status_code=400, detail=exc.message)
                    raise exc
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream(chunk_size))
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
        chunk_size: int | None = None,
    ) -> AwaitableOrContextManager[FormData]:
        return AwaitableOrContextManagerWrapper(
            self._get_form(
                max_files=max_files,
                max_fields=max_fields,
                max_part_size=max_part_size,
                chunk_size=chunk_size,
            )
        )

    async def close(self) -> None:
        if self._form is not None:  # pragma: no branch
            await self._form.close()

    async def is_disconnected(self) -> bool:
        if not self._is_disconnected:
            message: Message = {}
            # If a message isn't immediately available, move on
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()
            if message.get("type") == "http.disconnect":
                self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: list[tuple[bytes, bytes]] = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
            await self._send(
                {"type": "http.response.push", "path": path, "headers": raw_headers}
            )


def request_response(
    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.
    """
    f: typing.Callable[[Request], typing.Awaitable[Response]] = (
        func if is_async_callable(func) else functools.partial(run_in_threadpool, func)  # type: ignore
    )

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        try:
            request = Request(scope, receive, send)
        except ValueError:
            # We might want something more specific, like RequestBodyOverSizedError.
            # `app_that_send_error_message` is a placeholder for an ASGI app that
            # sends an error response (e.g. a 413).
            await app_that_send_error_message(scope, receive, send)
            return

        async def inner(scope: Scope, receive: Receive, send: Send) -> None:
            response = await f(request)
            await response(scope, receive, send)

        await wrap_app_handling_exceptions(inner, request)(scope, receive, send)

    return app
```
A user can put any value into the header and render the server unreliable, so the proposed variant is not optimal.
Yeah, then we might:
The thing is though, if we assume the server is not reliable, and
WSGI defines that servers "should not" pass more bytes to the application than specified in the `CONTENT_LENGTH`. But does that really matter for this discussion? A safeguard to prevent OOMs or other resource-exhaustion attacks should never depend on a client-specified header or undefined server behavior. The only reliable way to enforce such a limit is to count received bytes.
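Counting received bytes, as suggested, could be done by wrapping the ASGI `receive` callable. A minimal sketch (the function name and error type are assumptions, not Starlette API):

```python
# Minimal sketch of the byte-counting approach: wrap `receive` so the request
# is aborted once more body bytes have arrived than allowed, independently of
# the client-supplied Content-Length header. Names are illustrative.
def limit_receive(receive, max_bytes: int):
    received = 0

    async def wrapper():
        nonlocal received
        message = await receive()
        if message["type"] == "http.request":
            received += len(message.get("body", b""))
            if received > max_bytes:
                # The caller would translate this into a 413 response.
                raise ValueError("request body exceeds limit")
        return message

    return wrapper
```

A size-limiting middleware would pass `limit_receive(receive, limit)` to the inner app in place of the original `receive`.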
Discussed in #1516
Originally posted by aviramha April 5, 2020
As discussed on Gitter, my opinion is that Starlette should provide a default limit for request size.
The main reason is that without it, any Starlette application is vulnerable to very easy DoS.
For example, newbie me can write a program as follows:
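The example program itself was lost from this excerpt. As an illustrative reconstruction of the kind of handler meant (a stand-in `FakeRequest` is used so the snippet runs without Starlette installed), the problem is a handler that buffers the entire body before parsing it:

```python
import asyncio
import json

# Illustrative reconstruction (the original example code was lost in extraction):
# a handler that reads the whole request body into memory before parsing it.
# Nothing caps the size, so a multi-gigabyte body is buffered in full.
class FakeRequest:
    """Stand-in for starlette.requests.Request, delivering body chunks."""

    def __init__(self, chunks):
        self._chunks = chunks

    async def body(self):
        # Concatenates every chunk with no size check, as Starlette does.
        return b"".join(self._chunks)

    async def json(self):
        return json.loads(await self.body())


async def handler(request):
    data = await request.json()  # the entire body is in memory by this point
    return {"items": len(data)}


# The client controls the chunk count: 3 tiny chunks here, but nothing
# stops a malicious client from streaming gigabytes instead.
result = asyncio.run(handler(FakeRequest([b"[1,", b"2,", b"3]"])))
print(result)  # {'items': 3}
```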
As a malicious user, I could send a 30GB sized JSON and cause the memory to go OOM.
Other frameworks also support this: Django, Quart.
My proposal is to add a default limit which can be overridden in the app configuration.