From 960f7273a7319262cc9c8b619996d88d731489a5 Mon Sep 17 00:00:00 2001 From: hhvrc Date: Thu, 16 Jan 2025 16:50:03 +0100 Subject: [PATCH] Improve HTTP head parsing --- .../AsyncWebServer/ESPAsyncWebServer.h | 34 ++-- .../AsyncWebServer/HttpRequestMethod.h | 14 ++ include/external/AsyncWebServer/HttpVersion.h | 6 + .../external/AsyncWebServer/WebHandlerImpl.h | 4 +- src/external/AsyncWebServer/WebRequest.cpp | 183 ++++++++++++++---- src/external/AsyncWebServer/WebResponses.cpp | 4 +- src/external/AsyncWebServer/WebServer.cpp | 6 +- 7 files changed, 184 insertions(+), 67 deletions(-) create mode 100644 include/external/AsyncWebServer/HttpRequestMethod.h create mode 100644 include/external/AsyncWebServer/HttpVersion.h diff --git a/include/external/AsyncWebServer/ESPAsyncWebServer.h b/include/external/AsyncWebServer/ESPAsyncWebServer.h index 1eace397..685aa886 100644 --- a/include/external/AsyncWebServer/ESPAsyncWebServer.h +++ b/include/external/AsyncWebServer/ESPAsyncWebServer.h @@ -26,6 +26,8 @@ #include "FS.h" #include +#include "HttpRequestMethod.h" +#include "HttpVersion.h" #include "LinkedList.h" #include "util/StringUtils.h" @@ -45,23 +47,9 @@ class AsyncStaticWebHandler; class AsyncCallbackWebHandler; class AsyncResponseStream; -#ifndef WEBSERVER_H -typedef enum { - HTTP_GET = 0b00000001, - HTTP_POST = 0b00000010, - HTTP_DELETE = 0b00000100, - HTTP_PUT = 0b00001000, - HTTP_PATCH = 0b00010000, - HTTP_HEAD = 0b00100000, - HTTP_OPTIONS = 0b01000000, - HTTP_ANY = 0b01111111, -} WebRequestMethod; -#endif - // if this value is returned when asked for data, packet will not be sent and you will be asked for data again #define RESPONSE_TRY_AGAIN 0xFFFFFFFF -typedef uint8_t WebRequestMethodComposite; typedef std::function ArDisconnectHandler; /* @@ -156,8 +144,8 @@ class AsyncWebServerRequest { std::string _temp; uint8_t _parseState; - uint8_t _version; - WebRequestMethodComposite _method; + HttpVersion _version; + HttpRequestMethod _method; std::string _url; std::string _host; std::string _contentType; @@ -202,7 +190,7 @@ class AsyncWebServerRequest { void _parseLine(); void _parsePlainPostChar(uint8_t data); void _parseMultipartPostByte(uint8_t data, bool last); - void _addGetParams(std::string_view params); + void _parseQueryParams(std::string_view params); void _handleUploadStart(); void _handleUploadByte(uint8_t data, bool last); @@ -216,8 +204,8 @@ class AsyncWebServerRequest { ~AsyncWebServerRequest(); AsyncClient* client() { return _client; } - uint8_t version() const { return _version; } - WebRequestMethodComposite method() const { return _method; } + HttpVersion version() const { return _version; } + HttpRequestMethod method() const { return _method; } const std::string& url() const { return _url; } const std::string& host() const { return _host; } const std::string& contentType() const { return _contentType; } @@ -356,7 +344,7 @@ class AsyncWebServerResponse { virtual void setContentLength(size_t len); virtual void setContentType(std::string_view type); virtual void addHeader(std::string_view name, std::string_view value); - virtual std::string _assembleHead(uint8_t version); + virtual std::string _assembleHead(HttpVersion version); virtual bool _started() const; virtual bool _finished() const; virtual bool _failed() const; @@ -395,9 +383,9 @@ class AsyncWebServer { bool removeHandler(AsyncWebHandler* handler); AsyncCallbackWebHandler& on(const char* uri, ArRequestHandlerFunction onRequest); - AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest); - AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload); - AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload, ArBodyHandlerFunction onBody); + AsyncCallbackWebHandler& on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest); + AsyncCallbackWebHandler& on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload); + AsyncCallbackWebHandler& on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload, ArBodyHandlerFunction onBody); AsyncStaticWebHandler& serveStatic(const char* uri, fs::FS& fs, const char* path, const char* cache_control = NULL); diff --git a/include/external/AsyncWebServer/HttpRequestMethod.h b/include/external/AsyncWebServer/HttpRequestMethod.h new file mode 100644 index 00000000..5e02a949 --- /dev/null +++ b/include/external/AsyncWebServer/HttpRequestMethod.h @@ -0,0 +1,14 @@ +#pragma once + +enum HttpRequestMethod : uint16_t { + HTTP_GET = 0b0000000001, + HTTP_POST = 0b0000000010, + HTTP_DELETE = 0b0000000100, + HTTP_PUT = 0b0000001000, + HTTP_PATCH = 0b0000010000, + HTTP_HEAD = 0b0000100000, + HTTP_OPTIONS = 0b0001000000, + HTTP_CONNECT = 0b0010000000, + HTTP_TRACE = 0b0100000000, + HTTP_ANY = 0b0111111111, +}; \ No newline at end of file diff --git a/include/external/AsyncWebServer/HttpVersion.h b/include/external/AsyncWebServer/HttpVersion.h new file mode 100644 index 00000000..54b62f08 --- /dev/null +++ b/include/external/AsyncWebServer/HttpVersion.h @@ -0,0 +1,6 @@ +#pragma once + +struct HttpVersion { + uint8_t major; + uint8_t minor; +}; diff --git a/include/external/AsyncWebServer/WebHandlerImpl.h b/include/external/AsyncWebServer/WebHandlerImpl.h index 3b7cab1f..248a855d 100644 --- a/include/external/AsyncWebServer/WebHandlerImpl.h +++ b/include/external/AsyncWebServer/WebHandlerImpl.h @@ -57,7 +57,7 @@ class AsyncCallbackWebHandler : public AsyncWebHandler { private: protected: std::string _uri; - WebRequestMethodComposite _method; + HttpRequestMethod _method; ArRequestHandlerFunction _onRequest; ArUploadHandlerFunction _onUpload; ArBodyHandlerFunction _onBody; @@ -78,7 +78,7 @@ class AsyncCallbackWebHandler : public AsyncWebHandler { _uri = uri; _isRegex = uri.length() > 1 && uri.front() == '^' && uri.back() == '$'; } - void setMethod(WebRequestMethodComposite method) { _method = method; } + void setMethod(HttpRequestMethod method) { _method = method; } void onRequest(ArRequestHandlerFunction fn) { _onRequest = fn; } void onUpload(ArUploadHandlerFunction fn) { _onUpload = fn; } void onBody(ArBodyHandlerFunction fn) { _onBody = fn; } diff --git a/src/external/AsyncWebServer/WebRequest.cpp b/src/external/AsyncWebServer/WebRequest.cpp index 86401a9a..5f563a9f 100644 --- a/src/external/AsyncWebServer/WebRequest.cpp +++ b/src/external/AsyncWebServer/WebRequest.cpp @@ -21,6 +21,7 @@ #include "external/AsyncWebServer/ESPAsyncWebServer.h" #include "external/AsyncWebServer/WebResponseImpl.h" +#include "Convert.h" #include "util/HexUtils.h" static const std::string SharedEmptyString = std::string(); @@ -35,6 +36,105 @@ enum { PARSE_REQ_FAIL }; +static bool httpTryBasicUriDecode(std::string_view uri, std::string& uri_out) +{ + uri_out.clear(); + + if (uri.empty()) return false; + + uri_out.reserve(uri.length()); + + for (std::size_t i = 0; i < uri.length(); ++i) { + char c = uri[i]; + + // Escaped character handling + if (c == '%') [[unlikely]] { + // Check if theres enough space for the two hex chars + if (i + 2 >= uri.length()) { + uri_out.clear(); + return false; + } + + // Decode the hex characters + uint8_t decoded = 0; + if (!OpenShock::HexUtils::TryParseHexPair(uri[i + 1], uri[i + 2], decoded)) { + uri_out.clear(); + return false; + } + + // Push back the decoded character + uri_out.push_back(static_cast(decoded)); + + // Skip the hex characters + i += 2; + + continue; + } + + // Fail on whitespace + if (c == ' ' || c == '\t' || c == '\r' || c == '\n') [[unlikely]] { + uri_out.clear(); + return false; + } + + // Push back the hopefully valid character + uri_out.push_back(c); + } + + return true; +} + +static bool httpParseMethod(std::string_view str, HttpRequestMethod& method_out) +{ + using namespace std::string_view_literals; + + if (str == "GET"sv) { + method_out = HTTP_GET; + } else if (str == "POST"sv) { + method_out = HTTP_POST; + } else if (str == "DELETE"sv) { + method_out = HTTP_DELETE; + } else if (str == "PUT"sv) { + method_out = HTTP_PUT; + } else if (str == "PATCH"sv) { + method_out = HTTP_PATCH; + } else if (str == "HEAD"sv) { + method_out = HTTP_HEAD; + } else if (str == "OPTIONS"sv) { + method_out = HTTP_OPTIONS; + } else if (str == "CONNECT"sv) { + method_out = HTTP_CONNECT; + } else if (str == "TRACE"sv) { + method_out = HTTP_TRACE; + } else { + return false; + } + + return true; +} +static bool httpParseHttpVersion(std::string_view str, HttpVersion& http_version_out) +{ + using namespace std::string_view_literals; + if (!OpenShock::StringStartsWith(str, "HTTP/"sv)) { + return false; + } + + std::size_t dot_pos = str.find('.', 5); + if (dot_pos == std::string_view::npos) { + return false; + } + + if (!OpenShock::Convert::ToUint8(str.substr(5, dot_pos), http_version_out.major)) { + return false; + } + + if (!OpenShock::Convert::ToUint8(str.substr(dot_pos + 1), http_version_out.minor)) { + return false; + } + + return true; +} + AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) : _client(c) , _server(s) @@ -42,7 +142,7 @@ AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) , _response(NULL) , _temp() , _parseState(0) - , _version(0) + , _version {} , _method(HTTP_ANY) , _url() , _host() @@ -273,7 +373,7 @@ void AsyncWebServerRequest::_addPathParam(const char* p) _pathParams.add(new std::string(p)); } -void AsyncWebServerRequest::_addGetParams(std::string_view params) +void AsyncWebServerRequest::_parseQueryParams(std::string_view params) { size_t start = 0; while (start < params.length()) { @@ -290,41 +390,50 @@ void AsyncWebServerRequest::_addGetParams(std::string_view params) bool AsyncWebServerRequest::_parseReqHead() { - // Split the head into method, url and version - int index = _temp.indexOf(' '); - String m = _temp.substring(0, index); - index = _temp.indexOf(' ', index + 1); - String u = _temp.substring(m.length() + 1, index); - _temp = _temp.substring(index + 1); - - if (m == "GET") { - _method = HTTP_GET; - } else if (m == "POST") { - _method = HTTP_POST; - } else if (m == "DELETE") { - _method = HTTP_DELETE; - } else if (m == "PUT") { - _method = HTTP_PUT; - } else if (m == "PATCH") { - _method = HTTP_PATCH; - } else if (m == "HEAD") { - _method = HTTP_HEAD; - } else if (m == "OPTIONS") { - _method = HTTP_OPTIONS; - } - - String g = String(); - index = u.indexOf('?'); - if (index > 0) { - g = u.substring(index + 1); - u = u.substring(0, index); - } - _url = urlDecode(u); - _addGetParams(g); - - if (!_temp.startsWith("HTTP/1.0")) _version = 1; + using namespace std::string_view_literals; + + std::string_view body = _temp; + + std::size_t start_pos = 0, end_pos = 0; + + // Get request method + end_pos = body.find(' ', start_pos); + if (end_pos == std::string_view::npos) { + return false; // Should respond: 400 Bad Request + } + std::string_view method_str = body.substr(start_pos, end_pos); + + // Get request URI + start_pos = end_pos + 1; + end_pos = body.find(' ', start_pos); + if (end_pos == std::string_view::npos) { + return false; // Should respond: 400 Bad Request + } + std::string_view request_uri = body.substr(start_pos, end_pos); + + // Get request HTTP version + start_pos = end_pos + 1; + end_pos = body.find("\r\n"sv, start_pos); + if (end_pos == std::string_view::npos) { + return false; // Should respond: 400 Bad Request + } + std::string_view http_version_str = body.substr(start_pos, end_pos); + + // Parse request method + if (!httpParseMethod(method_str, _method)) { + return false; // Should respond: 405 Method Not Allowed + } + + // Parse request URI + if (!httpTryBasicUriDecode(request_uri, _url)) { + return false; // Should respond: 400 Bad Request + } + + // Parse request HTTP version + if (!httpParseHttpVersion(http_version_str, _version)) { + return false; // Should respond: 400 Bad Request + } - _temp = String(); return true; } @@ -758,7 +867,7 @@ AsyncWebServerResponse* AsyncWebServerRequest::beginResponse(std::string_view co AsyncWebServerResponse* AsyncWebServerRequest::beginChunkedResponse(std::string_view contentType, AwsResponseFiller callback) { - if (_version) return new AsyncChunkedResponse(contentType, callback); + if (_version.minor > 0) return new AsyncChunkedResponse(contentType, callback); return new AsyncCallbackResponse(contentType, 0, callback); } diff --git a/src/external/AsyncWebServer/WebResponses.cpp b/src/external/AsyncWebServer/WebResponses.cpp index e806b190..030bd1f9 100644 --- a/src/external/AsyncWebServer/WebResponses.cpp +++ b/src/external/AsyncWebServer/WebResponses.cpp @@ -176,7 +176,7 @@ void AsyncWebServerResponse::addHeader(std::string_view name, std::string_view v _headers.add(new AsyncWebHeader(name, value)); } -std::string AsyncWebServerResponse::_assembleHead(uint8_t version) +std::string AsyncWebServerResponse::_assembleHead(HttpVersion version) { if (version) { addHeader("Accept-Ranges", "none"); @@ -186,7 +186,7 @@ std::string AsyncWebServerResponse::_assembleHead(uint8_t version) int bufSize = 300; char buf[bufSize]; - snprintf(buf, bufSize, "HTTP/1.%d %d %s\r\n", version, _code, _responseCodeToString(_code)); + snprintf(buf, bufSize, "HTTP/%d.%d %d %s\r\n", version.major, version.minor, _code, _responseCodeToString(_code)); out.concat(buf); if (_sendContentLength) { diff --git a/src/external/AsyncWebServer/WebServer.cpp b/src/external/AsyncWebServer/WebServer.cpp index 544dccf4..d380d1ee 100644 --- a/src/external/AsyncWebServer/WebServer.cpp +++ b/src/external/AsyncWebServer/WebServer.cpp @@ -107,7 +107,7 @@ void AsyncWebServer::_attachHandler(AsyncWebServerRequest* request) request->setHandler(_catchAllHandler); } -AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload, ArBodyHandlerFunction onBody) +AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload, ArBodyHandlerFunction onBody) { AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler(); handler->setUri(uri); @@ -119,7 +119,7 @@ AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodCom return *handler; } -AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload) +AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload) { AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler(); handler->setUri(uri); @@ -130,7 +130,7 @@ AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodCom return *handler; } -AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest) +AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, HttpRequestMethod method, ArRequestHandlerFunction onRequest) { AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler(); handler->setUri(uri);