From ee4d227b562f98158928de17f03babefd96f9aff Mon Sep 17 00:00:00 2001 From: Tristan Ross Date: Fri, 17 Jan 2025 16:49:31 -0800 Subject: [PATCH] Wrap AWS HTTP requests with FileTransfer --- src/libstore/s3-binary-cache-store.cc | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/libstore/s3-binary-cache-store.cc b/src/libstore/s3-binary-cache-store.cc index cfa713b00c2..b17a3646e83 100644 --- a/src/libstore/s3-binary-cache-store.cc +++ b/src/libstore/s3-binary-cache-store.cc @@ -17,8 +17,12 @@ #include #include #include +#include +#include +#include #include #include +#include #include #include #include @@ -70,6 +74,49 @@ class AwsLogger : public Aws::Utils::Logging::FormattedLogSystem #endif }; +class AwsHttpClient : public Aws::Http::HttpClient +{ +public: + AwsHttpClient(const Aws::Client::ClientConfiguration& clientConfig) : HttpClient() {} + + std::shared_ptr MakeRequest(const std::shared_ptr& request, + Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr, + Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr) const override { + Aws::Http::URI uri = request->GetUri(); + Aws::String url = uri.GetURIString(); + + debug("Making request %s", url); + + std::shared_ptr response = std::make_shared(request); + + if (writeLimiter != nullptr) { + writeLimiter->ApplyAndPayForCost(request->GetSize()); + } + + FileTransferRequest ftr = FileTransferRequest(url); + getFileTransfer()->download(ftr); + return response; + } +}; + +class AwsHttpClientFactory : public Aws::Http::HttpClientFactory +{ +public: + std::shared_ptr CreateHttpClient(const Aws::Client::ClientConfiguration& clientConfiguration) const override { + return std::make_shared(clientConfiguration); + } + + std::shared_ptr CreateHttpRequest(const Aws::String& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override { + return CreateHttpRequest(Aws::Http::URI(uri), method, streamFactory); + } + + std::shared_ptr CreateHttpRequest(const Aws::Http::URI& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override { + auto request = std::make_shared(uri, method); + request->SetResponseStreamFactory(streamFactory); + return request; + } +}; + static void initAWS() { static std::once_flag flag; @@ -80,6 +127,10 @@ static void initAWS() shared.cc), so don't let aws-sdk-cpp override it. */ options.cryptoOptions.initAndCleanupOpenSSL = false; + options.httpOptions.httpClientFactory_create_fn = []() { + return std::make_shared(); + }; + if (verbosity >= lvlDebug) { options.loggingOptions.logLevel = verbosity == lvlDebug