diff --git a/wsd/Admin.cpp b/wsd/Admin.cpp index 2e4a353f1a513..4ae7d226a0b3b 100644 --- a/wsd/Admin.cpp +++ b/wsd/Admin.cpp @@ -9,10 +9,14 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +#include +#include #include #include +#include #include +#include #include #include @@ -23,6 +27,7 @@ #include "Admin.hpp" #include "AdminModel.hpp" #include "Auth.hpp" +#include "HttpRequest.hpp" #include #include #include @@ -41,7 +46,6 @@ using namespace COOLProtocol; -using Poco::Net::HTTPResponse; using Poco::Util::Application; const int Admin::MinStatsIntervalMs = 50; @@ -416,6 +420,10 @@ void AdminSocketHandler::handleMessage(const std::vector &payload) { _admin->setCloseMonitorFlag(); } + else if (tokens.equals(0, "rollingupdate") && tokens.size() > 1) + { + _admin->setRollingUpdateInfo(tokens[1]); + } } AdminSocketHandler::AdminSocketHandler(Admin* adminManager, @@ -496,8 +504,7 @@ bool AdminSocketHandler::handleInitialRequest( return true; } - HTTPResponse response; - response.setStatusAndReason(HTTPResponse::HTTP_BAD_REQUEST); + http::Response response(http::StatusCode::BadRequest); response.setContentLength(0); LOG_INF_S("Admin::handleInitialRequest bad request"); socket->send(response); @@ -1293,6 +1300,42 @@ void Admin::deleteMonitorSocket(const std::string& uriWithoutParam) } } +void Admin::setRollingUpdateInfo(const std::string& jsonString) +{ + Poco::JSON::Object::Ptr object; + if (JsonUtil::parseJSON(jsonString, object)) + { + bool status = JsonUtil::getJSONValue(object, "inprogress"); + setRollingUpdateStatus(status); + Poco::JSON::Array::Ptr infoArray = object->getArray("serverinfo"); + if (!infoArray.isNull()) + { + for(size_t i=0; i < infoArray->size(); i++) + { + if (!infoArray->isObject(i)) + { + return; + } + const auto serverInfoObject = infoArray->getObject(i); + const std::string gitHash = JsonUtil::getJSONValue(serverInfoObject , "gitHash"); + const std::string serverId = JsonUtil::getJSONValue(serverInfoObject, "serverId"); + const std::string routeToken = JsonUtil::getJSONValue(serverInfoObject, "routeToken"); + _rollingUpdateInfo.try_emplace(gitHash, RollingUpdateServerInfo(gitHash, serverId, routeToken)); + } + } + } +} + +std::string Admin::getBuddyServer(const std::string& gitHash) +{ + auto iterator = _rollingUpdateInfo.find(gitHash); + if (iterator != _rollingUpdateInfo.end()) + { + return iterator->second.getRouteToken(); + } + return std::string(); +} + void Admin::stop() { joinThread(); diff --git a/wsd/Admin.hpp b/wsd/Admin.hpp index 4437f85e89dd5..0176c31367293 100644 --- a/wsd/Admin.hpp +++ b/wsd/Admin.hpp @@ -15,6 +15,8 @@ #include "net/WebSocketHandler.hpp" #include "COOLWSD.hpp" +#include +#include class Admin; @@ -186,6 +188,14 @@ class Admin : public SocketPoll void setCloseMonitorFlag() { _closeMonitor = true; } + void setRollingUpdateInfo(const std::string& jsonString); + + void setRollingUpdateStatus(bool status) { _rollingUpdateStatus = status; } + + bool getRollingUpdateStatus() { return _rollingUpdateStatus; } + + std::string getBuddyServer(const std::string& gitHash); + private: /// Notify Forkit of changed settings. void notifyForkit(); @@ -253,6 +263,30 @@ class Admin : public SocketPoll std::map> _monitorSockets; std::atomic _closeMonitor = false; + + class RollingUpdateServerInfo + { + public: + std::string getGitHash() { return _gitHash; } + std::string getServerId() { return _serverId; } + std::string getRouteToken() { return _routeToken; } + + RollingUpdateServerInfo(const std::string& gitHash, const std::string& serverId, + const std::string& routeToken) + : _gitHash(gitHash) + , _serverId(serverId) + , _routeToken(routeToken) + { + } + + private: + std::string _gitHash; + std::string _serverId; + std::string _routeToken; + }; + + std::map _rollingUpdateInfo; + std::atomic _rollingUpdateStatus; }; /* vim:set shiftwidth=4 softtabstop=4 expandtab: */ diff --git a/wsd/COOLWSD.cpp b/wsd/COOLWSD.cpp index b2bf851ddb282..ae9005618bd8f 100644 --- a/wsd/COOLWSD.cpp +++ b/wsd/COOLWSD.cpp @@ -4321,21 +4321,35 @@ class ClientRequestDispatcher final : public SimpleSocketHandler { // Unit testing, nothing to do here } - else if (requestDetails.equals(RequestDetails::Field::Type, "browser") || requestDetails.equals(RequestDetails::Field::Type, "wopi")) + else if (requestDetails.equals(RequestDetails::Field::Type, "browser") || + requestDetails.equals(RequestDetails::Field::Type, "wopi")) { + + std::string protocol = "http"; + if (socket->sniffSSL()) + protocol = "https"; + + Poco::URI requestUri(protocol + "://" + request.getHost() + request.getURI()); + const std::string& path = requestUri.getPath(); + bool versionMismatch = false; + if (path.find("browser/" COOLWSD_VERSION_HASH "/") == std::string::npos && + path.find("admin/") == std::string::npos) + { + versionMismatch = true; + } + // File server assert(socket && "Must have a valid socket"); constexpr auto ProxyRemote = "/remote/"; constexpr auto ProxyRemoteLen = sizeof(ProxyRemote) - 1; constexpr auto ProxyRemoteStatic = "/remote/static/"; - const auto uri = requestDetails.getURI(); - const auto pos = uri.find(ProxyRemoteStatic); + const auto pos = path.find(ProxyRemoteStatic); if (pos != std::string::npos) { - if (Util::endsWith(uri, "lokit-extra-img.svg")) + if (Util::endsWith(path, "lokit-extra-img.svg")) { ProxyRequestHandler::handleRequest( - uri.substr(pos + ProxyRemoteLen), socket, + path.substr(pos + ProxyRemoteLen), socket, ProxyRequestHandler::getProxyRatingServer()); } #if ENABLE_FEATURE_LOCK @@ -4347,12 +4361,42 @@ class ClientRequestDispatcher final : public SimpleSocketHandler { const std::string& serverUri = unlockImageUri.getScheme() + "://" + unlockImageUri.getAuthority(); - ProxyRequestHandler::handleRequest(uri.substr(pos + sizeof("/remote/static") - 1), - socket, serverUri); + ProxyRequestHandler::handleRequest( + path.substr(pos + sizeof("/remote/static") - 1), socket, serverUri); } } #endif } + else if (COOLWSD::IndirectionServerEnabled && versionMismatch && + Admin::instance().getRollingUpdateStatus()) + { + std::string searchString = "/browser/"; + size_t startHashPos = path.find(searchString); + if (startHashPos != std::string::npos) + { + startHashPos += searchString.length(); + size_t endHashPos = path.find('/', startHashPos); + + std::string gitHash; + if (endHashPos != std::string::npos) + { + gitHash = path.substr(startHashPos, endHashPos - startHashPos); + } + else + { + gitHash = path.substr(startHashPos); + } + std::string routeToken = Admin::instance().getBuddyServer(gitHash); + if (!routeToken.empty()) + requestUri.addQueryParameter("RouteToken", routeToken); + + LOG_DBG("proxyRequestUri: " << requestUri.toString()); + + ProxyRequestHandler::handleRequest(requestUri.getPath(), socket, + requestUri.getScheme() + "://" + + requestUri.getAuthority()); + } + } else { COOLWSD::FileRequestHandler->handleRequest(request, requestDetails, message, socket); diff --git a/wsd/ProxyRequestHandler.cpp b/wsd/ProxyRequestHandler.cpp index 441e744353596..944c17fbad375 100644 --- a/wsd/ProxyRequestHandler.cpp +++ b/wsd/ProxyRequestHandler.cpp @@ -44,9 +44,11 @@ void ProxyRequestHandler::handleRequest(const std::string& relPath, } uriProxy.setPath(relPath); - auto sessionProxy = http::Session::create(uriProxy.getHost(), - http::Session::Protocol::HttpSsl, - uriProxy.getPort()); + LOG_DBG("proxyRequestUri: " << uriProxy.toString()); + auto protocol = uriProxy.getScheme() == "https" ? http::Session::Protocol::HttpSsl + : http::Session::Protocol::HttpUnencrypted; + + auto sessionProxy = http::Session::create(uriProxy.getHost(), protocol, uriProxy.getPort()); sessionProxy->setTimeout(std::chrono::seconds(10)); http::Request requestProxy(uriProxy.getPathAndQuery()); http::Session::FinishedCallback proxyCallback = diff --git a/wsd/ProxyRequestHandler.hpp b/wsd/ProxyRequestHandler.hpp index 100531a1e2e14..c810a61b2fc49 100644 --- a/wsd/ProxyRequestHandler.hpp +++ b/wsd/ProxyRequestHandler.hpp @@ -11,6 +11,7 @@ #pragma once +#include #include #include "Socket.hpp" @@ -20,6 +21,7 @@ class ProxyRequestHandler static void handleRequest(const std::string& relPath, const std::shared_ptr& socket, const std::string& serverUri); + static std::string getProxyRatingServer() { return ProxyRatingServer; } private: