diff --git a/src/engine/Server.cpp b/src/engine/Server.cpp index f22ab3b79..cb9bf96a3 100644 --- a/src/engine/Server.cpp +++ b/src/engine/Server.cpp @@ -344,18 +344,25 @@ Awaitable Server::process( checkParameterNotPresent("default-graph-uri"); checkParameterNotPresent("named-graph-uri"); - auto checkParameter = [¶meters](std::string_view key, - std::optional value, - bool accessAllowed = true) { - return Server::checkParameter(parameters, key, std::move(value), - accessAllowed); - }; + // We always want to call `Server::checkParameter` with the same first + // parameter. + auto checkParameter = + std::bind_front(&Server::checkParameter, std::cref(parameters)); // Check the access token. If an access token is provided and the check fails, // throw an exception and do not process any part of the query (even if the // processing had been allowed without access token). bool accessTokenOk = checkAccessToken(checkParameter("access-token", std::nullopt)); + auto requireValidAccessToken = [&accessTokenOk]( + const std::string& actionName) { + if (!accessTokenOk) { + throw std::runtime_error(absl::StrCat( + actionName, + " requires a valid access token. No valid access token is present.", + "Processing of request aborted.")); + } + }; // Process all URL parameters known to QLever. If there is more than one, // QLever processes all of them, but only returns the result from the last @@ -382,8 +389,8 @@ Awaitable Server::process( logCommand(cmd, "clear the cache (unpinned elements only)"); cache_.clearUnpinnedOnly(); response = createJsonResponse(composeCacheStatsJson(), request); - } else if (auto cmd = - checkParameter("cmd", "clear-cache-complete", accessTokenOk)) { + } else if (auto cmd = checkParameter("cmd", "clear-cache-complete")) { + requireValidAccessToken("clear-cache-complete"); logCommand(cmd, "clear cache completely (including unpinned elements)"); cache_.clearAll(); response = createJsonResponse(composeCacheStatsJson(), request); @@ -394,8 +401,8 @@ Awaitable Server::process( logCommand(cmd, "get index ID"); response = createOkResponse(index_.getIndexId(), request, MediaType::textPlain); - } else if (auto cmd = - checkParameter("cmd", "dump-active-queries", accessTokenOk)) { + } else if (auto cmd = checkParameter("cmd", "dump-active-queries")) { + requireValidAccessToken("dump-active-queries"); logCommand(cmd, "dump active queries"); nlohmann::json json; for (auto& [key, value] : queryRegistry_.getActiveQueries()) { @@ -417,8 +424,8 @@ Awaitable Server::process( } // Set description of KB index. - if (auto description = - checkParameter("index-description", std::nullopt, accessTokenOk)) { + if (auto description = checkParameter("index-description", std::nullopt)) { + requireValidAccessToken("index-description"); LOG(INFO) << "Setting index description to: \"" << description.value() << "\"" << std::endl; index_.setKbName(std::string{description.value()}); @@ -426,8 +433,8 @@ Awaitable Server::process( } // Set description of text index. - if (auto description = - checkParameter("text-description", std::nullopt, accessTokenOk)) { + if (auto description = checkParameter("text-description", std::nullopt)) { + requireValidAccessToken("text-description"); LOG(INFO) << "Setting text description to: \"" << description.value() << "\"" << std::endl; index_.setTextName(std::string{description.value()}); @@ -436,7 +443,8 @@ Awaitable Server::process( // Set one or several of the runtime parameters. for (auto key : RuntimeParameters().getKeys()) { - if (auto value = checkParameter(key, std::nullopt, accessTokenOk)) { + if (auto value = checkParameter(key, std::nullopt)) { + requireValidAccessToken("setting runtime parameters"); LOG(INFO) << "Setting runtime parameter \"" << key << "\"" << " to value \"" << value.value() << "\"" << std::endl; RuntimeParameters().set(key, std::string{value.value()}); @@ -451,19 +459,31 @@ Awaitable Server::process( if (auto timeLimit = co_await verifyUserSubmittedQueryTimeout( checkParameter("timeout", std::nullopt), accessTokenOk, request, send)) { - co_return co_await processQuery(parameters, query.query_, requestTimer, - std::move(request), send, - timeLimit.value()); + co_return co_await processQueryOrUpdate( + parameters, query.query_, requestTimer, std::move(request), send, + timeLimit.value()); } else { // If the optional is empty, this indicates an error response has been // sent to the client already. We can stop here. co_return; } }; - auto visitUpdate = [](const ad_utility::url_parser::sparqlOperation::Update&) + auto visitUpdate = + [&checkParameter, &accessTokenOk, &request, &send, ¶meters, + &requestTimer, + this](const ad_utility::url_parser::sparqlOperation::Update& update) -> Awaitable { - throw std::runtime_error( - "SPARQL 1.1 Update is currently not supported by QLever."); + if (auto timeLimit = co_await verifyUserSubmittedQueryTimeout( + checkParameter("timeout", std::nullopt), accessTokenOk, request, + send)) { + co_return co_await processQueryOrUpdate( + parameters, update.update_, requestTimer, std::move(request), send, + timeLimit.value()); + } else { + // If the optional is empty, this indicates an error response has been + // sent to the client already. We can stop here. + co_return; + } }; auto visitNone = [&response, &send, &request]( @@ -492,6 +512,37 @@ Awaitable Server::process( parsedHttpRequest.operation_); } +// ____________________________________________________________________________ +std::pair Server::determineResultPinning( + const ad_utility::url_parser::ParamValueMap& params) { + const bool pinSubtrees = + checkParameter(params, "pinsubtrees", "true").has_value(); + const bool pinResult = + checkParameter(params, "pinresult", "true").has_value(); + return {pinSubtrees, pinResult}; +} +// ____________________________________________________________________________ +Awaitable Server::setupPlannedQuery( + const ad_utility::url_parser::ParamValueMap& params, + const std::string& operation, QueryExecutionContext& qec, + SharedCancellationHandle handle, TimeLimit timeLimit, + const ad_utility::Timer& requestTimer) { + auto queryDatasets = ad_utility::url_parser::parseDatasetClauses(params); + std::optional plannedQuery = + co_await parseAndPlan(operation, queryDatasets, qec, handle, timeLimit); + AD_CORRECTNESS_CHECK(plannedQuery.has_value()); + auto& qet = plannedQuery.value().queryExecutionTree_; + qet.isRoot() = true; // allow pinning of the final result + auto timeForQueryPlanning = requestTimer.msecs(); + auto& runtimeInfoWholeQuery = + qet.getRootOperation()->getRuntimeInfoWholeQuery(); + runtimeInfoWholeQuery.timeQueryPlanning = timeForQueryPlanning; + LOG(INFO) << "Query planning done in " << timeForQueryPlanning.count() + << " ms" << std::endl; + LOG(TRACE) << qet.getCacheKey() << std::endl; + + co_return std::move(plannedQuery.value()); +} // _____________________________________________________________________________ json Server::composeErrorResponseJson( const string& query, const std::string& errorMsg, @@ -512,11 +563,6 @@ json Server::composeErrorResponseJson( j["metadata"]["stopIndex"] = value.stopIndex_; j["metadata"]["line"] = value.line_; j["metadata"]["positionInLine"] = value.charPositionInLine_; - // The ANTLR parser may not see the whole query. (The reason is value mixing - // of the old and new parser.) To detect/work with this we also transmit - // what ANTLR saw as query. - // TODO remove once the whole query is parsed with ANTLR. - j["metadata"]["query"] = value.query_; } return j; @@ -574,7 +620,6 @@ class QueryAlreadyInUseError : public std::runtime_error { }; // _____________________________________________ - ad_utility::websocket::OwningQueryId Server::getQueryId( const ad_utility::httpUtils::HttpRequest auto& request, std::string_view query) { @@ -592,7 +637,6 @@ ad_utility::websocket::OwningQueryId Server::getQueryId( } // _____________________________________________________________________________ - auto Server::cancelAfterDeadline( std::weak_ptr> cancellationHandle, TimeLimit timeLimit) @@ -665,20 +709,143 @@ Awaitable Server::sendStreamableResponse( } // ____________________________________________________________________________ -boost::asio::awaitable Server::processQuery( +class NoSupportedMediatypeError : public std::runtime_error { + public: + explicit NoSupportedMediatypeError(std::string_view msg) + : std::runtime_error{std::string{msg}} {} +}; + +// ____________________________________________________________________________ +MediaType Server::determineMediaType( + const ad_utility::url_parser::ParamValueMap& params, + const ad_utility::httpUtils::HttpRequest auto& request) { + // The following code block determines the media type to be used for the + // result. The media type is either determined by the "Accept:" header of + // the request or by the URL parameter "action=..." (for TSV and CSV export, + // for QLever-historical reasons). + std::optional mediaType = std::nullopt; + + // The explicit `action=..._export` parameter have precedence over the + // `Accept:...` header field + if (checkParameter(params, "action", "csv_export")) { + mediaType = MediaType::csv; + } else if (checkParameter(params, "action", "tsv_export")) { + mediaType = MediaType::tsv; + } else if (checkParameter(params, "action", "qlever_json_export")) { + mediaType = MediaType::qleverJson; + } else if (checkParameter(params, "action", "sparql_json_export")) { + mediaType = MediaType::sparqlJson; + } else if (checkParameter(params, "action", "turtle_export")) { + mediaType = MediaType::turtle; + } else if (checkParameter(params, "action", "binary_export")) { + mediaType = MediaType::octetStream; + } + + std::string_view acceptHeader = request.base()[http::field::accept]; + + if (!mediaType.has_value()) { + mediaType = ad_utility::getMediaTypeFromAcceptHeader(acceptHeader); + } + AD_CORRECTNESS_CHECK(mediaType.has_value()); + + return mediaType.value(); +} + +// ____________________________________________________________________________ +Awaitable Server::processQuery( const ad_utility::url_parser::ParamValueMap& params, const string& query, ad_utility::Timer& requestTimer, const ad_utility::httpUtils::HttpRequest auto& request, auto&& send, TimeLimit timeLimit) { - using namespace ad_utility::httpUtils; + MediaType mediaType = determineMediaType(params, request); + LOG(INFO) << "Requested media type of result is \"" + << ad_utility::toString(mediaType) << "\"" << std::endl; + + // TODO use std::optional::transform + std::optional maxSend = std::nullopt; + auto parameterValue = + ad_utility::url_parser::getParameterCheckAtMostOnce(params, "send"); + if (parameterValue.has_value()) { + maxSend = std::stoul(parameterValue.value()); + } + // Limit JSON requests by default + if (!maxSend.has_value() && (mediaType == MediaType::sparqlJson || + mediaType == MediaType::qleverJson)) { + maxSend = MAX_NOF_ROWS_IN_RESULT; + } - auto sendJson = - [&request, &send]( - const json& jsonString, - http::status responseStatus) -> boost::asio::awaitable { - auto response = createJsonResponse(jsonString, request, responseStatus); - co_return co_await send(std::move(response)); - }; + auto queryHub = queryHub_.lock(); + AD_CORRECTNESS_CHECK(queryHub); + ad_utility::websocket::MessageSender messageSender{getQueryId(request, query), + *queryHub}; + + auto [cancellationHandle, cancelTimeoutOnDestruction] = + setupCancellationHandle(messageSender.getQueryId(), timeLimit); + + // Do the query planning. This creates a `QueryExecutionTree`, which will + // then be used to process the query. + auto [pinSubtrees, pinResult] = determineResultPinning(params); + LOG(INFO) << "Processing the following SPARQL query:" + << (pinResult ? " [pin result]" : "") + << (pinSubtrees ? " [pin subresults]" : "") << "\n" + << query << std::endl; + QueryExecutionContext qec(index_, &cache_, allocator_, + sortPerformanceEstimator_, std::ref(messageSender), + pinSubtrees, pinResult); + auto plannedQuery = co_await setupPlannedQuery( + params, query, qec, cancellationHandle, timeLimit, requestTimer); + auto qet = plannedQuery.queryExecutionTree_; + + if (plannedQuery.parsedQuery_.hasUpdateClause()) { + throw std::runtime_error("Expected Query but received Update."); + } + + // Apply stricter limit for export if present + if (maxSend.has_value()) { + auto& pq = plannedQuery.parsedQuery_; + pq._limitOffset._limit = + std::min(maxSend.value(), pq._limitOffset.limitOrDefault()); + } + // Make sure we don't underflow here + AD_CORRECTNESS_CHECK(plannedQuery.parsedQuery_._limitOffset._offset >= + qet.getRootOperation()->getLimit()._offset); + // Don't apply offset twice, if the offset was not applied to the operation + // then the exporter can safely apply it during export. + plannedQuery.parsedQuery_._limitOffset._offset -= + qet.getRootOperation()->getLimit()._offset; + + // This actually processes the query and sends the result in the requested + // format. + co_await sendStreamableResponse(request, send, mediaType, plannedQuery, qet, + requestTimer, cancellationHandle); + + // Print the runtime info. This needs to be done after the query + // was computed. + + // Log that we are done with the query and how long it took. + // + // NOTE: We need to explicitly stop the `requestTimer` here because in the + // sending code above, it is done only in some cases and not in others (in + // particular, not for TSV and CSV because for those, the result does not + // contain timing information). + // + // TODO Also log an identifier of the query. + LOG(INFO) << "Done processing query and sending result" + << ", total time was " << requestTimer.msecs().count() << " ms" + << std::endl; + LOG(DEBUG) << "Runtime Info:\n" + << qet.getRootOperation()->runtimeInfo().toString() << std::endl; + co_return; +} + +// ____________________________________________________________________________ +template +Awaitable Server::processQueryOrUpdate( + const ad_utility::url_parser::ParamValueMap& params, + const string& queryOrUpdate, ad_utility::Timer& requestTimer, + const ad_utility::httpUtils::HttpRequest auto& request, auto&& send, + TimeLimit timeLimit) { + using namespace ad_utility::httpUtils; http::status responseStatus = http::status::ok; @@ -692,141 +859,13 @@ boost::asio::awaitable Server::processQuery( // access to the runtimeInformation in the case of an error. std::optional plannedQuery; try { - auto containsParam = [¶ms](const std::string& param, - const std::string& expected) { - auto parameterValue = - ad_utility::url_parser::getParameterCheckAtMostOnce(params, param); - return parameterValue.has_value() && parameterValue.value() == expected; - }; - const bool pinSubtrees = containsParam("pinsubtrees", "true"); - const bool pinResult = containsParam("pinresult", "true"); - LOG(INFO) << "Processing the following SPARQL query:" - << (pinResult ? " [pin result]" : "") - << (pinSubtrees ? " [pin subresults]" : "") << "\n" - << query << std::endl; - - // The following code block determines the media type to be used for the - // result. The media type is either determined by the "Accept:" header of - // the request or by the URL parameter "action=..." (for TSV and CSV export, - // for QLever-historical reasons). - - std::optional mediaType = std::nullopt; - - // The explicit `action=..._export` parameter have precedence over the - // `Accept:...` header field - if (containsParam("action", "csv_export")) { - mediaType = MediaType::csv; - } else if (containsParam("action", "tsv_export")) { - mediaType = MediaType::tsv; - } else if (containsParam("action", "qlever_json_export")) { - mediaType = MediaType::qleverJson; - } else if (containsParam("action", "sparql_json_export")) { - mediaType = MediaType::sparqlJson; - } else if (containsParam("action", "turtle_export")) { - mediaType = MediaType::turtle; - } else if (containsParam("action", "binary_export")) { - mediaType = MediaType::octetStream; - } - - std::string_view acceptHeader = request.base()[http::field::accept]; - - if (!mediaType.has_value()) { - mediaType = ad_utility::getMediaTypeFromAcceptHeader(acceptHeader); - } - - // TODO use std::optional::transform - std::optional maxSend = std::nullopt; - auto parameterValue = - ad_utility::url_parser::getParameterCheckAtMostOnce(params, "send"); - if (parameterValue.has_value()) { - maxSend = std::stoul(parameterValue.value()); - } - // Limit JSON requests by default - if (!maxSend.has_value() && (mediaType == MediaType::sparqlJson || - mediaType == MediaType::qleverJson)) { - maxSend = MAX_NOF_ROWS_IN_RESULT; - } - - if (!mediaType.has_value()) { - co_return co_await send(createBadRequestResponse( - absl::StrCat("Did not find any supported media type " - "in this \'Accept:\' header field: \"", - acceptHeader, "\". ", - ad_utility::getErrorMessageForSupportedMediaTypes()), - request)); - } - AD_CONTRACT_CHECK(mediaType.has_value()); - LOG(INFO) << "Requested media type of result is \"" - << ad_utility::toString(mediaType.value()) << "\"" << std::endl; - - auto queryHub = queryHub_.lock(); - AD_CORRECTNESS_CHECK(queryHub); - ad_utility::websocket::MessageSender messageSender{ - getQueryId(request, query), *queryHub}; - // Do the query planning. This creates a `QueryExecutionTree`, which will - // then be used to process the query. - // - // NOTE: This should come after determining the media type. Otherwise, it - // might happen that the query planner runs for a while (recall that it many - // do index scans) and then we get an error message afterwards that a - // certain media type is not supported. - QueryExecutionContext qec(index_, &cache_, allocator_, - sortPerformanceEstimator_, - std::ref(messageSender), pinSubtrees, pinResult); - auto [cancellationHandle, cancelTimeoutOnDestruction] = - setupCancellationHandle(messageSender.getQueryId(), timeLimit); - - auto queryDatasets = ad_utility::url_parser::parseDatasetClauses(params); - plannedQuery = co_await parseAndPlan(query, queryDatasets, qec, - cancellationHandle, timeLimit); - AD_CORRECTNESS_CHECK(plannedQuery.has_value()); - auto& qet = plannedQuery.value().queryExecutionTree_; - qet.isRoot() = true; // allow pinning of the final result - auto timeForQueryPlanning = requestTimer.msecs(); - auto& runtimeInfoWholeQuery = - qet.getRootOperation()->getRuntimeInfoWholeQuery(); - runtimeInfoWholeQuery.timeQueryPlanning = timeForQueryPlanning; - LOG(INFO) << "Query planning done in " << timeForQueryPlanning.count() - << " ms" << std::endl; - LOG(TRACE) << qet.getCacheKey() << std::endl; - - // Apply stricter limit for export if present - if (maxSend.has_value()) { - auto& pq = plannedQuery.value().parsedQuery_; - pq._limitOffset._limit = - std::min(maxSend.value(), pq._limitOffset.limitOrDefault()); + if constexpr (type == OperationType::Query) { + co_await processQuery(params, queryOrUpdate, requestTimer, request, send, + timeLimit); + } else { + throw std::runtime_error( + "SPARQL 1.1 Update is currently not supported by QLever."); } - // Make sure we don't underflow here - AD_CORRECTNESS_CHECK( - plannedQuery.value().parsedQuery_._limitOffset._offset >= - qet.getRootOperation()->getLimit()._offset); - // Don't apply offset twice, if the offset was not applied to the operation - // then the exporter can safely apply it during export. - plannedQuery.value().parsedQuery_._limitOffset._offset -= - qet.getRootOperation()->getLimit()._offset; - - // This actually processes the query and sends the result in the requested - // format. - co_await sendStreamableResponse(request, send, mediaType.value(), - plannedQuery.value(), qet, requestTimer, - cancellationHandle); - - // Print the runtime info. This needs to be done after the query - // was computed. - - // Log that we are done with the query and how long it took. - // - // NOTE: We need to explicitly stop the `requestTimer` here because in the - // sending code above, it is done only in some cases and not in others (in - // particular, not for TSV and CSV because for those, the result does not - // contain timing information). - // - // TODO Also log an identifier of the query. - LOG(INFO) << "Done processing query and sending result" - << ", total time was " << requestTimer.msecs().count() << " ms" - << std::endl; - LOG(DEBUG) << "Runtime Info:\n" - << qet.getRootOperation()->runtimeInfo().toString() << std::endl; } catch (const ParseException& e) { responseStatus = http::status::bad_request; exceptionErrorMsg = e.errorMessageWithoutPositionalInfo(); @@ -834,6 +873,9 @@ boost::asio::awaitable Server::processQuery( } catch (const QueryAlreadyInUseError& e) { responseStatus = http::status::conflict; exceptionErrorMsg = e.what(); + } catch (const NoSupportedMediatypeError& e) { + responseStatus = http::status::bad_request; + exceptionErrorMsg = e.what(); } catch (const ad_utility::CancellationException& e) { // Send 429 status code to indicate that the time limit was reached // or the query was cancelled because of some other reason. @@ -863,14 +905,16 @@ boost::asio::awaitable Server::processQuery( } } auto errorResponseJson = composeErrorResponseJson( - query, exceptionErrorMsg.value(), requestTimer, metadata); + queryOrUpdate, exceptionErrorMsg.value(), requestTimer, metadata); if (plannedQuery.has_value()) { errorResponseJson["runtimeInformation"] = nlohmann::ordered_json(plannedQuery.value() .queryExecutionTree_.getRootOperation() ->runtimeInfo()); } - co_return co_await sendJson(errorResponseJson, responseStatus); + auto response = + createJsonResponse(errorResponseJson, request, responseStatus); + co_return co_await send(std::move(response)); } } @@ -964,11 +1008,9 @@ bool Server::checkAccessToken( } // _____________________________________________________________________________ - std::optional Server::checkParameter( const ad_utility::url_parser::ParamValueMap& parameters, - std::string_view key, std::optional value, - bool accessAllowed) { + std::string_view key, std::optional value) { auto param = ad_utility::url_parser::getParameterCheckAtMostOnce(parameters, key); if (!param.has_value()) { @@ -983,12 +1025,5 @@ std::optional Server::checkParameter( } else if (value != parameterValue) { return std::nullopt; } - // Now that we have the value, check if there is a problem with the access. - // If yes, we abort the query processing at this point. - if (!accessAllowed) { - throw std::runtime_error(absl::StrCat( - "Access to \"", key, "=", value.value(), "\" denied", - " (requires a valid access token)", ", processing of request aborted")); - } return value; } diff --git a/src/engine/Server.h b/src/engine/Server.h index c863e9377..7abaf3eaf 100644 --- a/src/engine/Server.h +++ b/src/engine/Server.h @@ -5,6 +5,8 @@ #pragma once +#include + #include #include @@ -118,9 +120,14 @@ class Server { Awaitable process( const ad_utility::httpUtils::HttpRequest auto& request, auto&& send); - /// Handle a http request that asks for the processing of a query. + // Indicates which type of operation is being processed. + enum class OperationType { Query, Update }; + + /// Handle a http request that asks for the processing of an query or update. + /// This is only a wrapper for `processQuery` and `processUpdate` which + /// does the error handling. /// \param params The key-value-pairs sent in the HTTP GET request. - /// \param query The query. + /// \param queryOrUpdate The query or update. /// \param requestTimer Timer that measure the total processing /// time of this request. /// \param request The HTTP request. @@ -128,12 +135,37 @@ class Server { /// `HttpServer.h` for documentation). /// \param timeLimit Duration in seconds after which the query will be /// cancelled. + template + Awaitable processQueryOrUpdate( + const ad_utility::url_parser::ParamValueMap& params, + const string& queryOrUpdate, ad_utility::Timer& requestTimer, + const ad_utility::httpUtils::HttpRequest auto& request, auto&& send, + TimeLimit timeLimit); + // Do the actual execution of a query. Awaitable processQuery( const ad_utility::url_parser::ParamValueMap& params, const string& query, ad_utility::Timer& requestTimer, const ad_utility::httpUtils::HttpRequest auto& request, auto&& send, TimeLimit timeLimit); + // Determine the media type to be used for the result. The media type is + // determined (in this order) by the current action (e.g., + // "action=csv_export") and by the "Accept" header of the request. + static ad_utility::MediaType determineMediaType( + const ad_utility::url_parser::ParamValueMap& params, + const ad_utility::httpUtils::HttpRequest auto& request); + FRIEND_TEST(ServerTest, determineMediaType); + // Determine whether the subtrees and the result should be pinned. + static std::pair determineResultPinning( + const ad_utility::url_parser::ParamValueMap& params); + FRIEND_TEST(ServerTest, determineResultPinning); + // Sets up the PlannedQuery s.t. it is ready to be executed. + Awaitable setupPlannedQuery( + const ad_utility::url_parser::ParamValueMap& params, + const std::string& operation, QueryExecutionContext& qec, + SharedCancellationHandle handle, TimeLimit timeLimit, + const ad_utility::Timer& requestTimer); + static json composeErrorResponseJson( const string& query, const std::string& errorMsg, ad_utility::Timer& requestTimer, @@ -204,19 +236,16 @@ class Server { /// HTTP error response. bool checkAccessToken(std::optional accessToken) const; - /// Checks if a URL parameter exists in the request, if we are allowed to - /// access it and it matches the expected `value`. If yes, return the value, - /// otherwise return `std::nullopt`. If `value` is `std::nullopt`, only check - /// if the key exists. We need this because we have parameters like - /// "cmd=stats", where a fixed combination of the key and value determines the - /// kind of action, as well as parameters like "index-decription=...", where - /// the key determines the kind of action. If the key is not found, always - /// return `std::nullopt`. If `accessAllowed` is false and a value is present, - /// throw an exception. + /// Checks if a URL parameter exists in the request, and it matches the + /// expected `value`. If yes, return the value, otherwise return + /// `std::nullopt`. If `value` is `std::nullopt`, only check if the key + /// exists. We need this because we have parameters like "cmd=stats", where a + /// fixed combination of the key and value determines the kind of action, as + /// well as parameters like "index-decription=...", where the key determines + /// the kind of action. If the key is not found, always return `std::nullopt`. static std::optional checkParameter( const ad_utility::url_parser::ParamValueMap& parameters, - std::string_view key, std::optional value, - bool accessAllowed); + std::string_view key, std::optional value); FRIEND_TEST(ServerTest, checkParameter); /// Check if user-provided timeout is authorized with a valid access-token or diff --git a/src/util/http/MediaTypes.h b/src/util/http/MediaTypes.h index 43a809e33..0b2634b8c 100644 --- a/src/util/http/MediaTypes.h +++ b/src/util/http/MediaTypes.h @@ -112,6 +112,9 @@ std::vector parseAcceptHeader( /// media types that appear earlier in the `SUPPORTED_MEDIA_TYPES`. If none of /// the `SUPPORTED_MEDIA_TYPES` is accepted by `acceptHeader`, then /// `std::nullopt` is returned. +// TODO: This function never returns `nullopt`, because an exception is thrown +// if no supported media type is found. Update the docstring and make the return +// type just `MediaType`. std::optional getMediaTypeFromAcceptHeader( std::string_view acceptHeader); diff --git a/test/ServerTest.cpp b/test/ServerTest.cpp index 4f98de679..5ed589fc4 100644 --- a/test/ServerTest.cpp +++ b/test/ServerTest.cpp @@ -146,29 +146,79 @@ TEST(ServerTest, checkParameter) { const ParamValueMap exampleParams = {{"foo", {"bar"}}, {"baz", {"qux", "quux"}}}; - EXPECT_THAT(Server::checkParameter(exampleParams, "doesNotExist", "", false), + EXPECT_THAT(Server::checkParameter(exampleParams, "doesNotExist", ""), testing::Eq(std::nullopt)); - EXPECT_THAT(Server::checkParameter(exampleParams, "foo", "baz", false), + EXPECT_THAT(Server::checkParameter(exampleParams, "foo", "baz"), testing::Eq(std::nullopt)); - AD_EXPECT_THROW_WITH_MESSAGE( - Server::checkParameter(exampleParams, "foo", "bar", false), - testing::StrEq("Access to \"foo=bar\" denied (requires a valid access " - "token), processing of request aborted")); - EXPECT_THAT(Server::checkParameter(exampleParams, "foo", "bar", true), + EXPECT_THAT(Server::checkParameter(exampleParams, "foo", "bar"), testing::Optional(testing::StrEq("bar"))); AD_EXPECT_THROW_WITH_MESSAGE( - Server::checkParameter(exampleParams, "baz", "qux", false), + Server::checkParameter(exampleParams, "baz", "qux"), testing::StrEq("Parameter \"baz\" must be given exactly once. Is: 2")); - EXPECT_THAT(Server::checkParameter(exampleParams, "foo", std::nullopt, true), + EXPECT_THAT(Server::checkParameter(exampleParams, "foo", std::nullopt), testing::Optional(testing::StrEq("bar"))); AD_EXPECT_THROW_WITH_MESSAGE( - Server::checkParameter(exampleParams, "foo", std::nullopt, false), - testing::StrEq("Access to \"foo=bar\" denied (requires a valid access " - "token), processing of request aborted")); - AD_EXPECT_THROW_WITH_MESSAGE( - Server::checkParameter(exampleParams, "baz", std::nullopt, false), + Server::checkParameter(exampleParams, "baz", std::nullopt), testing::StrEq("Parameter \"baz\" must be given exactly once. Is: 2")); AD_EXPECT_THROW_WITH_MESSAGE( - Server::checkParameter(exampleParams, "baz", std::nullopt, true), + Server::checkParameter(exampleParams, "baz", std::nullopt), testing::StrEq("Parameter \"baz\" must be given exactly once. Is: 2")); } + +TEST(ServerTest, determineResultPinning) { + EXPECT_THAT(Server::determineResultPinning( + {{"pinsubtrees", {"true"}}, {"pinresult", {"true"}}}), + testing::Pair(true, true)); + EXPECT_THAT(Server::determineResultPinning({{"pinresult", {"true"}}}), + testing::Pair(false, true)); + EXPECT_THAT(Server::determineResultPinning({{"pinsubtrees", {"otherValue"}}}), + testing::Pair(false, false)); +} + +TEST(ServerTest, determineMediaType) { + auto MakeRequest = [](const std::optional& accept, + const http::verb method = http::verb::get, + const std::string& target = "/", + const std::string& body = "") { + auto req = http::request{method, target, 11}; + if (accept.has_value()) { + req.set(http::field::accept, accept.value()); + } + req.body() = body; + req.prepare_payload(); + return req; + }; + auto checkActionMediatype = [&](const std::string& actionName, + ad_utility::MediaType expectedMediaType) { + EXPECT_THAT(Server::determineMediaType({{"action", {actionName}}}, + MakeRequest(std::nullopt)), + testing::Eq(expectedMediaType)); + }; + // The media type associated with the action overrides the `Accept` header. + EXPECT_THAT(Server::determineMediaType( + {{"action", {"csv_export"}}}, + MakeRequest("application/sparql-results+json")), + testing::Eq(ad_utility::MediaType::csv)); + checkActionMediatype("csv_export", ad_utility::MediaType::csv); + checkActionMediatype("tsv_export", ad_utility::MediaType::tsv); + checkActionMediatype("qlever_json_export", ad_utility::MediaType::qleverJson); + checkActionMediatype("sparql_json_export", ad_utility::MediaType::sparqlJson); + checkActionMediatype("turtle_export", ad_utility::MediaType::turtle); + checkActionMediatype("binary_export", ad_utility::MediaType::octetStream); + EXPECT_THAT(Server::determineMediaType( + {}, MakeRequest("application/sparql-results+json")), + testing::Eq(ad_utility::MediaType::sparqlJson)); + // No supported media type in the `Accept` header. (Contrary to it's docstring + // and interface) `ad_utility::getMediaTypeFromAcceptHeader` throws an + // exception if no supported media type is found. + AD_EXPECT_THROW_WITH_MESSAGE( + Server::determineMediaType({}, MakeRequest("text/css")), + testing::HasSubstr("Not a single media type known to this parser was " + "detected in \"text/css\".")); + // No `Accept` header means that any content type is allowed. + EXPECT_THAT(Server::determineMediaType({}, MakeRequest(std::nullopt)), + testing::Eq(ad_utility::MediaType::sparqlJson)); + // No `Accept` header and an empty `Accept` header are not distinguished. + EXPECT_THAT(Server::determineMediaType({}, MakeRequest("")), + testing::Eq(ad_utility::MediaType::sparqlJson)); +}