From d6be0ceb8dcdefe69e1efcdfb7eb9f28f6f10f93 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Thu, 13 Apr 2023 14:07:05 +0100 Subject: [PATCH] Requests chunk encoded request handling --- openapi_core/contrib/requests/requests.py | 13 +++++++++++++ .../requests/test_requests_validation.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/openapi_core/contrib/requests/requests.py b/openapi_core/contrib/requests/requests.py index 00a462f5..02a4b4d6 100644 --- a/openapi_core/contrib/requests/requests.py +++ b/openapi_core/contrib/requests/requests.py @@ -1,4 +1,5 @@ """OpenAPI core contrib requests requests module""" +from typing import Mapping from typing import Optional from typing import Union from urllib.parse import parse_qs @@ -7,6 +8,7 @@ from requests import PreparedRequest from requests import Request from requests.cookies import RequestsCookieJar +from requests.utils import rewind_body from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict @@ -67,6 +69,17 @@ def method(self) -> str: def body(self) -> Optional[str]: if self.request.body is None: return None + is_stream = all( + [ + hasattr(self.request.body, "__iter__"), + not isinstance(self.request.body, (str, list, tuple, Mapping)), + ] + ) + if is_stream: + chunks = list(self.request.body) + body = "".join(chunks) + self.request.body = (x for x in chunks) + return body if isinstance(self.request.body, bytes): return self.request.body.decode("utf-8") assert isinstance(self.request.body, str) diff --git a/tests/integration/contrib/requests/test_requests_validation.py b/tests/integration/contrib/requests/test_requests_validation.py index 2e8aee8c..d54011db 100644 --- a/tests/integration/contrib/requests/test_requests_validation.py +++ b/tests/integration/contrib/requests/test_requests_validation.py @@ -1,3 +1,5 @@ +from types import GeneratorType + import pytest import requests import responses @@ -72,6 +74,22 @@ def test_request_validator_path_pattern(self, request_unmarshaller): result = request_unmarshaller.unmarshal(openapi_request) assert not result.errors + def test_request_validator_encoded_chunks(self, request_unmarshaller): + def gen(): + yield '{"param1": 1}' + + request = requests.Request( + "POST", + "http://localhost/browse/12/", + params={"q": "string"}, + headers={"content-type": "application/json"}, + data=gen(), + ) + openapi_request = RequestsOpenAPIRequest(request) + result = request_unmarshaller.unmarshal(openapi_request) + assert not result.errors + assert request.data is GeneratorType + def test_request_validator_prepared_request(self, request_unmarshaller): request = requests.Request( "POST",