diff --git a/common/Util.cpp b/common/Util.cpp index c99c65ebb702c..b104fb093560a 100644 --- a/common/Util.cpp +++ b/common/Util.cpp @@ -774,6 +774,13 @@ namespace Util hash.resize(std::min(8, (int)hash.length())); } + std::string getCoolVersionHash() + { + std::string hash(COOLWSD_VERSION_HASH); + hash.resize(std::min(8, (int)hash.length())); + return hash; + } + const std::string& getProcessIdentifier() { static std::string id = Util::rng::getHexString(8); diff --git a/common/Util.hpp b/common/Util.hpp index d8398d9ee1f9a..2f883df9003de 100644 --- a/common/Util.hpp +++ b/common/Util.hpp @@ -303,6 +303,9 @@ namespace Util /// Get version information void getVersionInfo(std::string& version, std::string& hash); + /// Returns the COOL Version Hash string. + std::string getCoolVersionHash(); + ///< A random hex string that identifies the current process. const std::string& getProcessIdentifier(); diff --git a/wsd/Admin.cpp b/wsd/Admin.cpp index 2e4a353f1a513..40eaf87928fe8 100644 --- a/wsd/Admin.cpp +++ b/wsd/Admin.cpp @@ -9,10 +9,14 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +#include +#include #include #include +#include #include +#include #include #include @@ -23,6 +27,7 @@ #include "Admin.hpp" #include "AdminModel.hpp" #include "Auth.hpp" +#include "HttpRequest.hpp" #include #include #include @@ -41,7 +46,6 @@ using namespace COOLProtocol; -using Poco::Net::HTTPResponse; using Poco::Util::Application; const int Admin::MinStatsIntervalMs = 50; @@ -416,6 +420,10 @@ void AdminSocketHandler::handleMessage(const std::vector &payload) { _admin->setCloseMonitorFlag(); } + else if (tokens.equals(0, "rollingupdate") && tokens.size() > 1) + { + _admin->setRollingUpdateInfo(tokens[1]); + } } AdminSocketHandler::AdminSocketHandler(Admin* adminManager, @@ -496,8 +504,7 @@ bool AdminSocketHandler::handleInitialRequest( return true; } - HTTPResponse response; - response.setStatusAndReason(HTTPResponse::HTTP_BAD_REQUEST); + http::Response response(http::StatusCode::BadRequest); response.setContentLength(0); LOG_INF_S("Admin::handleInitialRequest bad request"); socket->send(response); @@ -1293,6 +1300,48 @@ void Admin::deleteMonitorSocket(const std::string& uriWithoutParam) } } +void Admin::setRollingUpdateInfo(const std::string& jsonString) +{ + Poco::JSON::Object::Ptr object; + if (JsonUtil::parseJSON(jsonString, object)) + { + bool status = JsonUtil::getJSONValue(object, "inprogress"); + setRollingUpdateStatus(status); + Poco::JSON::Array::Ptr infoArray = object->getArray("serverinfo"); + if (!infoArray.isNull()) + { + for(size_t i=0; i < infoArray->size(); i++) + { + if (!infoArray->isObject(i)) + { + return; + } + const auto serverInfoObject = infoArray->getObject(i); + const std::string gitHash = JsonUtil::getJSONValue(serverInfoObject , "gitHash"); + const std::string serverId = JsonUtil::getJSONValue(serverInfoObject, "serverId"); + const std::string routeToken = JsonUtil::getJSONValue(serverInfoObject, "routeToken"); + _rollingUpdateInfo.try_emplace(gitHash, RollingUpdateServerInfo(gitHash, serverId, routeToken)); + } + } + } +} + +std::string Admin::getBuddyServer(const std::string& gitHash) +{ + LOG_DBG("Getting routeToken for gitHash[" << gitHash << ']'); + for (auto iterator : _rollingUpdateInfo) + { + LOG_DBG("gitHash[" << iterator.first << "] routeToken[" << iterator.second.getRouteToken() + << "] serverId[" << iterator.second.getRouteToken() << ']'); + } + auto iterator = _rollingUpdateInfo.find(gitHash); + if (iterator != _rollingUpdateInfo.end()) + { + return iterator->second.getRouteToken(); + } + return std::string(); +} + void Admin::stop() { joinThread(); diff --git a/wsd/Admin.hpp b/wsd/Admin.hpp index 4437f85e89dd5..0176c31367293 100644 --- a/wsd/Admin.hpp +++ b/wsd/Admin.hpp @@ -15,6 +15,8 @@ #include "net/WebSocketHandler.hpp" #include "COOLWSD.hpp" +#include +#include class Admin; @@ -186,6 +188,14 @@ class Admin : public SocketPoll void setCloseMonitorFlag() { _closeMonitor = true; } + void setRollingUpdateInfo(const std::string& jsonString); + + void setRollingUpdateStatus(bool status) { _rollingUpdateStatus = status; } + + bool getRollingUpdateStatus() { return _rollingUpdateStatus; } + + std::string getBuddyServer(const std::string& gitHash); + private: /// Notify Forkit of changed settings. void notifyForkit(); @@ -253,6 +263,30 @@ class Admin : public SocketPoll std::map> _monitorSockets; std::atomic _closeMonitor = false; + + class RollingUpdateServerInfo + { + public: + std::string getGitHash() { return _gitHash; } + std::string getServerId() { return _serverId; } + std::string getRouteToken() { return _routeToken; } + + RollingUpdateServerInfo(const std::string& gitHash, const std::string& serverId, + const std::string& routeToken) + : _gitHash(gitHash) + , _serverId(serverId) + , _routeToken(routeToken) + { + } + + private: + std::string _gitHash; + std::string _serverId; + std::string _routeToken; + }; + + std::map _rollingUpdateInfo; + std::atomic _rollingUpdateStatus; }; /* vim:set shiftwidth=4 softtabstop=4 expandtab: */ diff --git a/wsd/COOLWSD.cpp b/wsd/COOLWSD.cpp index b2bf851ddb282..8ad3b3ddc3b8b 100644 --- a/wsd/COOLWSD.cpp +++ b/wsd/COOLWSD.cpp @@ -4321,22 +4321,41 @@ class ClientRequestDispatcher final : public SimpleSocketHandler { // Unit testing, nothing to do here } - else if (requestDetails.equals(RequestDetails::Field::Type, "browser") || requestDetails.equals(RequestDetails::Field::Type, "wopi")) + else if (requestDetails.equals(RequestDetails::Field::Type, "browser") || + requestDetails.equals(RequestDetails::Field::Type, "wopi")) { + + std::string protocol = "http"; + if (socket->sniffSSL()) + protocol = "https"; + + Poco::URI requestUri(protocol + "://" + request.getHost() + request.getURI()); + const std::string& path = requestUri.getPath(); + bool versionMismatch = false; + if ((path.find("browser/" COOLWSD_VERSION_HASH "/") == std::string::npos || + path.find("browser/" + Util::getCoolVersionHash() + "/") == std::string::npos) && + path.find("admin/") == std::string::npos) + { + LOG_DBG("Client - server version mismatch, proxy request to different server " + "Expected: " COOLWSD_VERSION_HASH + "; Actual URI path with version hash: " + << path); + versionMismatch = true; + } + // File server assert(socket && "Must have a valid socket"); constexpr auto ProxyRemote = "/remote/"; constexpr auto ProxyRemoteLen = sizeof(ProxyRemote) - 1; constexpr auto ProxyRemoteStatic = "/remote/static/"; - const auto uri = requestDetails.getURI(); - const auto pos = uri.find(ProxyRemoteStatic); + const auto pos = path.find(ProxyRemoteStatic); if (pos != std::string::npos) { - if (Util::endsWith(uri, "lokit-extra-img.svg")) + if (Util::endsWith(path, "lokit-extra-img.svg")) { ProxyRequestHandler::handleRequest( - uri.substr(pos + ProxyRemoteLen), socket, - ProxyRequestHandler::getProxyRatingServer()); + path.substr(pos + ProxyRemoteLen), socket, + ProxyRequestHandler::getProxyRatingServer(), "GET"); } #if ENABLE_FEATURE_LOCK else @@ -4347,12 +4366,65 @@ class ClientRequestDispatcher final : public SimpleSocketHandler { const std::string& serverUri = unlockImageUri.getScheme() + "://" + unlockImageUri.getAuthority(); - ProxyRequestHandler::handleRequest(uri.substr(pos + sizeof("/remote/static") - 1), - socket, serverUri); + ProxyRequestHandler::handleRequest( + path.substr(pos + sizeof("/remote/static") - 1), socket, serverUri, "GET"); } } #endif } + else if (COOLWSD::IndirectionServerEnabled && versionMismatch && + Admin::instance().getRollingUpdateStatus()) + { + std::string searchString = "/browser/"; + size_t startHashPos = path.find(searchString); + if (startHashPos != std::string::npos) + { + startHashPos += searchString.length(); + size_t endHashPos = path.find('/', startHashPos); + + std::string gitHash; + if (endHashPos != std::string::npos) + { + gitHash = path.substr(startHashPos, endHashPos - startHashPos); + } + else + { + gitHash = path.substr(startHashPos); + } + + const std::string& hash = Util::getCoolVersionHash(); + std::string routeToken = Admin::instance().getBuddyServer(hash); + if (!routeToken.empty()) + { + Poco::URI::QueryParameters params = requestUri.getQueryParameters(); + const auto routeTokenIt = + std::find_if(params.begin(), params.end(), + [](const std::pair& element) + { return element.first == "RouteToken"; }); + if (routeTokenIt == params.end()) + { + LOG_DBG("Adding routeToken[" << routeToken + << "] as a parameter to requestUri[" + << requestUri.toString() << ']'); + + requestUri.addQueryParameter("RouteToken", routeToken); + } + else + { + LOG_DBG("Updating routeToken[" << routeToken + << "] parameter in requestUri[" + << requestUri.toString() << ']'); + + routeTokenIt->second = routeToken; + requestUri.setQueryParameters(params); + } + } + + ProxyRequestHandler::handleRequest(requestUri.getPathAndQuery(), socket, + requestUri.getScheme() + "://" + + requestUri.getAuthority(), request.getMethod()); + } + } else { COOLWSD::FileRequestHandler->handleRequest(request, requestDetails, message, socket); diff --git a/wsd/ProxyRequestHandler.cpp b/wsd/ProxyRequestHandler.cpp index 441e744353596..fe3b063a542be 100644 --- a/wsd/ProxyRequestHandler.cpp +++ b/wsd/ProxyRequestHandler.cpp @@ -23,7 +23,8 @@ std::chrono::system_clock::time_point ProxyRequestHandler::MaxAge; void ProxyRequestHandler::handleRequest(const std::string& relPath, const std::shared_ptr& socket, - const std::string& serverUri) + const std::string& serverUri, + const std::string& verb) { Poco::URI uriProxy(serverUri); @@ -36,19 +37,23 @@ void ProxyRequestHandler::handleRequest(const std::string& relPath, MaxAge = zero; } - const auto cacheEntry = CacheFileHash.find(relPath); - if (cacheEntry != CacheFileHash.end()) - { - socket->sendAndShutdown(*cacheEntry->second); - return; - } + // const auto cacheEntry = CacheFileHash.find(relPath); + // if (cacheEntry != CacheFileHash.end()) + // { + // socket->sendAndShutdown(*cacheEntry->second); + // return; + // } + + uriProxy.setPathEtc(relPath); + LOG_DBG("uriProxy[" << uriProxy.getPathAndQuery() << ']'); + + auto protocol = uriProxy.getScheme() == "https" ? http::Session::Protocol::HttpSsl + : http::Session::Protocol::HttpUnencrypted; - uriProxy.setPath(relPath); - auto sessionProxy = http::Session::create(uriProxy.getHost(), - http::Session::Protocol::HttpSsl, - uriProxy.getPort()); + auto sessionProxy = http::Session::create(uriProxy.getHost(), protocol, uriProxy.getPort()); sessionProxy->setTimeout(std::chrono::seconds(10)); http::Request requestProxy(uriProxy.getPathAndQuery()); + requestProxy.setVerb(verb); http::Session::FinishedCallback proxyCallback = [socket, zero](const std::shared_ptr& httpSession) { diff --git a/wsd/ProxyRequestHandler.hpp b/wsd/ProxyRequestHandler.hpp index 100531a1e2e14..8b88685b96ebe 100644 --- a/wsd/ProxyRequestHandler.hpp +++ b/wsd/ProxyRequestHandler.hpp @@ -11,6 +11,7 @@ #pragma once +#include #include #include "Socket.hpp" @@ -19,7 +20,9 @@ class ProxyRequestHandler public: static void handleRequest(const std::string& relPath, const std::shared_ptr& socket, - const std::string& serverUri); + const std::string& serverUri, + const std::string& verb); + static std::string getProxyRatingServer() { return ProxyRatingServer; } private: