diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index d86aa5e154052..2aa2fe8b99c1f 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -63,6 +63,7 @@ target_link_libraries( presto_function_metadata presto_connectors presto_http + presto_thrift_server presto_operators presto_session_properties presto_velox_plan_conversion diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 15cb145aa9596..2772177af6a50 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -43,6 +43,7 @@ #include "presto_cpp/main/operators/ShuffleRead.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include "presto_cpp/main/types/VeloxPlanConversion.h" +#include "thrift/server/ThriftServer.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/caching/CacheTTLController.h" @@ -260,9 +261,9 @@ void PrestoServer::run() { } sslContext_ = util::createSSLContext( - optionalClientCertPath.value(), - ciphers, - systemConfig->httpClientHttp2Enabled()); + optionalClientCertPath.value(), ciphers, util::SSLProtocol::HTTP_1_1); + thriftSslContext_ = util::createSSLContext( + optionalClientCertPath.value(), ciphers, util::SSLProtocol::THRIFT); } if (systemConfig->internalCommunicationJwtEnabled()) { @@ -619,6 +620,8 @@ void PrestoServer::run() { } }; + startThriftServer(bindToNodeInternalAddressOnly, certPath, keyPath, ciphers); + // Start everything. After the return from the following call we are shutting // down. httpServer_->start(getHttpServerFilters(), [&](proxygen::HTTPServer* server) { @@ -679,6 +682,7 @@ void PrestoServer::run() { taskManager_.reset(); PRESTO_SHUTDOWN_LOG(INFO) << "Destroying HTTP Server"; httpServer_.reset(); + thriftServer_.reset(); unregisterFileReadersAndWriters(); unregisterFileSystems(); @@ -1070,6 +1074,7 @@ void PrestoServer::stop() { httpServer_->stop(); PRESTO_SHUTDOWN_LOG(INFO) << "HTTP Server stopped."; } + shutdownThriftServer(); } size_t PrestoServer::numDriverThreads() const { @@ -1492,7 +1497,7 @@ void PrestoServer::enableWorkerStatsReporting() { void PrestoServer::initVeloxPlanValidator() { VELOX_CHECK_NULL(planValidator_); - planValidator_ = std::make_unique(); + planValidator_ = std::make_shared(); } VeloxPlanValidator* PrestoServer::getVeloxPlanValidator() { @@ -1793,6 +1798,70 @@ void PrestoServer::createTaskManager() { driverExecutor_.get(), httpSrvCpuExecutor_.get(), spillerExecutor_.get()); } +void PrestoServer::startThriftServer( + bool bindToNodeInternalAddressOnly, + const std::string& certPath, + const std::string& keyPath, + const std::string& ciphers) { + auto* systemConfig = SystemConfig::instance(); + bool thriftServerEnabled = systemConfig->thriftServerEnabled(); + + if (thriftServerEnabled) { + std::unique_ptr thriftConfig; + folly::SocketAddress thriftAddress; + int thriftPort = systemConfig->thriftServerPort(); + if (bindToNodeInternalAddressOnly) { + thriftAddress.setFromHostPort(address_, thriftPort); + } else { + thriftAddress.setFromLocalPort(thriftPort); + } + thriftConfig = std::make_unique( + thriftAddress, certPath, keyPath, ciphers); + thriftServer_ = std::make_unique( + std::move(thriftConfig), + httpSrvIoExecutor_, + pool_, + planValidator_, + taskManager_); + + thriftServerFuture_ = + folly::via(folly::getGlobalCPUExecutor().get()) + .thenTry([this](folly::Try) { + try { + PRESTO_STARTUP_LOG(INFO) + << "Starting Thrift server asynchronously..."; + thriftServer_->start(); + PRESTO_STARTUP_LOG(INFO) + << "Thrift server started successfully"; + } catch (const std::exception& e) { + PRESTO_STARTUP_LOG(ERROR) + << "Thrift server failed to start: " << e.what(); + throw; + } + }); + } +} + +void PrestoServer::shutdownThriftServer() { + if (thriftServer_) { + PRESTO_SHUTDOWN_LOG(INFO) << "Stopping Thrift server"; + thriftServer_->stop(); + + // Wait for Thrift server thread to complete with timeout + try { + std::move(thriftServerFuture_) + .within(std::chrono::seconds(5)) // 5-second timeout + .get(); + PRESTO_SHUTDOWN_LOG(INFO) << "Thrift server stopped gracefully"; + } catch (const std::exception& e) { + PRESTO_SHUTDOWN_LOG(WARNING) + << "Thrift server shutdown timeout or error: " << e.what(); + } + + thriftServer_.reset(); + } +} + void PrestoServer::reportNodeStats(proxygen::ResponseHandler* downstream) { protocol::NodeStats nodeStats; diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index 391c5306f9987..159972992139b 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -48,6 +48,10 @@ namespace facebook::presto::http { class HttpServer; } +namespace facebook::presto::thrift { +class ThriftServer; +} + namespace proxygen { class ResponseHandler; } // namespace proxygen @@ -231,6 +235,16 @@ class PrestoServer { virtual void createTaskManager(); + /// Utility method to start the Thrift server if enabled + void startThriftServer( + bool bindToNodeInternalAddressOnly, + const std::string& certPath, + const std::string& keyPath, + const std::string& ciphers); + + /// Utility method to safely shutdown the Thrift server if running + void shutdownThriftServer(); + const std::string configDirectoryPath_; std::shared_ptr coordinatorDiscoverer_; @@ -269,7 +283,7 @@ class PrestoServer { // Executor for spilling. std::unique_ptr spillerExecutor_; - std::unique_ptr planValidator_; + std::shared_ptr planValidator_; std::unique_ptr exchangeSourceConnectionPool_; @@ -277,12 +291,14 @@ class PrestoServer { std::shared_ptr cache_; std::unique_ptr httpServer_; + std::unique_ptr thriftServer_; + folly::Future thriftServerFuture_{folly::makeFuture()}; std::unique_ptr signalHandler_; std::unique_ptr announcer_; std::unique_ptr heartbeatManager_; std::shared_ptr pool_; std::shared_ptr nativeWorkerPool_; - std::unique_ptr taskManager_; + std::shared_ptr taskManager_; std::unique_ptr taskResource_; std::atomic nodeState_{NodeState::kActive}; folly::Synchronized shuttingDown_{false}; @@ -313,6 +329,7 @@ class PrestoServer { std::string nodeLocation_; std::string nodePoolType_; folly::SSLContextPtr sslContext_; + folly::SSLContextPtr thriftSslContext_; std::string prestoBuiltinFunctionPrefix_; }; diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 450d4b34253bd..f5e78028f2918 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -134,6 +134,13 @@ SystemConfig::SystemConfig() { std::unordered_map>{ BOOL_PROP(kMutableConfig, false), NONE_PROP(kPrestoVersion), + BOOL_PROP(kThriftServerEnabled, true), + NUM_PROP(kThriftServerPort, 9090), + NUM_PROP(kThriftServerMaxConnections, 50000), + NUM_PROP(kThriftServerMaxRequests, 200), + NUM_PROP(kThriftServerIdleTimeout, 120000), + NUM_PROP(kThriftServerTaskExpireTimeMs, 60000), + NUM_PROP(kThriftServerStreamExpireTime, 60000), NONE_PROP(kHttpServerHttpPort), BOOL_PROP(kHttpServerReusePort, false), BOOL_PROP(kHttpServerBindToNodeInternalAddressOnlyEnabled, false), @@ -279,6 +286,34 @@ SystemConfig* SystemConfig::instance() { return instance.get(); } +bool SystemConfig::thriftServerEnabled() const { + return optionalProperty(kThriftServerEnabled).value(); +} + +int32_t SystemConfig::thriftServerPort() const { + return optionalProperty(kThriftServerPort).value(); +} + +int32_t SystemConfig::thriftServerMaxConnections() const { + return optionalProperty(kThriftServerMaxConnections).value(); +} + +int32_t SystemConfig::thriftServerMaxRequests() const { + return optionalProperty(kThriftServerMaxRequests).value(); +} + +int32_t SystemConfig::thriftServerIdleTimeout() const { + return optionalProperty(kThriftServerIdleTimeout).value(); +} + +int32_t SystemConfig::thriftServerTaskExpireTimeMs() const { + return optionalProperty(kThriftServerTaskExpireTimeMs).value(); +} + +int32_t SystemConfig::thriftServerStreamExpireTime() const { + return optionalProperty(kThriftServerStreamExpireTime).value(); +} + int SystemConfig::httpServerHttpPort() const { return requiredProperty(kHttpServerHttpPort); } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index a70b091962e14..175d72b7e9eb4 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -159,6 +159,30 @@ class ConfigBase { class SystemConfig : public ConfigBase { public: static constexpr std::string_view kPrestoVersion{"presto.version"}; + + /// Thrift server configuration + static constexpr std::string_view kThriftServerEnabled{ + "presto.thrift-server.enabled"}; + + static constexpr std::string_view kThriftServerPort{ + "presto.thrift-server.port"}; + + static constexpr std::string_view kThriftServerMaxConnections{ + "presto.thrift-server.max-connections"}; + + static constexpr std::string_view kThriftServerMaxRequests{ + "presto.thrift-server.max-requests"}; + + static constexpr std::string_view kThriftServerIdleTimeout{ + "presto.thrift-server.idle-timeout"}; + + static constexpr std::string_view kThriftServerTaskExpireTimeMs{ + "presto.thrift-server.task-expire-time-ms"}; + + static constexpr std::string_view kThriftServerStreamExpireTime{ + "presto.thrift-server.stream-expire-time"}; + + /// HTTP server configuration static constexpr std::string_view kHttpServerHttpPort{ "http-server.http.port"}; @@ -808,6 +832,22 @@ class SystemConfig : public ConfigBase { static SystemConfig* instance(); + // Thrift server configuration + bool thriftServerEnabled() const; + + int32_t thriftServerPort() const; + + int32_t thriftServerMaxConnections() const; + + int32_t thriftServerMaxRequests() const; + + int32_t thriftServerIdleTimeout() const; + + int32_t thriftServerTaskExpireTimeMs() const; + + int32_t thriftServerStreamExpireTime() const; + + // HTTP server configuration int httpServerHttpPort() const; bool httpServerReusePort() const; diff --git a/presto-native-execution/presto_cpp/main/common/Utils.cpp b/presto-native-execution/presto_cpp/main/common/Utils.cpp index 96fb762e38b89..948d46a525300 100644 --- a/presto-native-execution/presto_cpp/main/common/Utils.cpp +++ b/presto-native-execution/presto_cpp/main/common/Utils.cpp @@ -32,16 +32,25 @@ DateTime toISOTimestamp(uint64_t timeMilli) { std::shared_ptr createSSLContext( const std::string& clientCertAndKeyPath, const std::string& ciphers, - bool http2Enabled) { + SSLProtocol protocol) { try { auto sslContext = std::make_shared(); sslContext->loadCertKeyPairFromFiles( clientCertAndKeyPath.c_str(), clientCertAndKeyPath.c_str()); sslContext->setCiphersOrThrow(ciphers); - if (http2Enabled) { - sslContext->setAdvertisedNextProtocols({"h2", "http/1.1"}); - } else { - sslContext->setAdvertisedNextProtocols({"http/1.1"}); + switch (protocol) { + case SSLProtocol::THRIFT: + sslContext->setVerificationOption( + folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY); + // Set ALPN for Rocket protocol + sslContext->setAdvertisedNextProtocols({"rs"}); + break; + case SSLProtocol::HTTP_1_1: + sslContext->setAdvertisedNextProtocols({"http/1.1"}); + break; + case SSLProtocol::HTTP_2: + sslContext->setAdvertisedNextProtocols({"h2", "http/1.1"}); + break; } return sslContext; } catch (const std::exception& ex) { diff --git a/presto-native-execution/presto_cpp/main/common/Utils.h b/presto-native-execution/presto_cpp/main/common/Utils.h index 1cc6f8cf9c0ae..9d64091bbe65f 100644 --- a/presto-native-execution/presto_cpp/main/common/Utils.h +++ b/presto-native-execution/presto_cpp/main/common/Utils.h @@ -28,10 +28,16 @@ namespace facebook::presto::util { using DateTime = std::string; DateTime toISOTimestamp(uint64_t timeMilli); +enum class SSLProtocol { + THRIFT, // Rocket protocol (rs) + HTTP_1_1, // HTTP/1.1 + HTTP_2 // HTTP/2 (h2) +}; + std::shared_ptr createSSLContext( const std::string& clientCertAndKeyPath, const std::string& ciphers, - bool http2Enabled); + SSLProtocol protocol); /// Returns current process-wide CPU time in nanoseconds. long getProcessCpuTimeNs(); diff --git a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp index ce07dc3cfa545..18c97466445cd 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp +++ b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp @@ -191,6 +191,88 @@ TEST_F(ConfigTest, optionalSystemConfigs) { ASSERT_EQ(config.discoveryUri(), "my uri"); } +TEST_F(ConfigTest, thriftServerConfigs) { + SystemConfig config; + + // Test default values (when no thrift server configs are provided) + init(config, {}); + + // Test with thrift server enabled + init(config, {{"presto.thrift-server.enabled", "true"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.enabled")).value_or(false), + true); + + // Test thrift server port configuration + init(config, {{"presto.thrift-server.port", "9090"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.port")).value_or(9090), + 9090); + + // Test thrift server max connections + init(config, {{"presto.thrift-server.max-connections", "5000"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.max-connections")).value_or(10000), + 5000); + + // Test thrift server max requests + init(config, {{"presto.thrift-server.max-requests", "5000"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.max-requests")).value_or(10000), + 5000); + + // Test thrift server idle timeout + init(config, {{"presto.thrift-server.idle-timeout", "600000"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.idle-timeout")).value_or(300000), + 600000); + + // Test thrift server task expire time + init(config, {{"presto.thrift-server.task-expire-time-ms", "600000"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.task-expire-time-ms")).value_or(300000), + 600000); + + // Test thrift server stream expire time + init(config, {{"presto.thrift-server.stream-expire-time", "600000"}}); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.stream-expire-time")).value_or(300000), + 600000); + + // Test multiple thrift server configs together + init(config, { + {"presto.thrift-server.enabled", "true"}, + {"presto.thrift-server.port", "9091"}, + {"presto.thrift-server.max-connections", "8000"}, + {"presto.thrift-server.max-requests", "8000"}, + {"presto.thrift-server.idle-timeout", "900000"}, + {"presto.thrift-server.task-expire-time-ms", "900000"}, + {"presto.thrift-server.stream-expire-time", "900000"} + }); + + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.enabled")).value_or(false), + true); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.port")).value_or(9090), + 9091); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.max-connections")).value_or(10000), + 8000); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.max-requests")).value_or(10000), + 8000); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.idle-timeout")).value_or(300000), + 900000); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.task-expire-time-ms")).value_or(300000), + 900000); + ASSERT_EQ( + config.optionalProperty(std::string_view("presto.thrift-server.stream-expire-time")).value_or(300000), + 900000); +} + TEST_F(ConfigTest, optionalNodeConfigs) { NodeConfig config; init(config, {}); diff --git a/presto-native-execution/presto_cpp/main/thrift/CMakeLists.txt b/presto-native-execution/presto_cpp/main/thrift/CMakeLists.txt index 9deb725d94d06..cbd348a8ca983 100644 --- a/presto-native-execution/presto_cpp/main/thrift/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/thrift/CMakeLists.txt @@ -70,3 +70,32 @@ add_dependencies(presto_thrift_extra presto_thrift-cpp2) if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) endif() + +add_library( + presto_thrift_server + server/ThriftServer.cpp + server/PrestoThriftServiceHandler.cpp +) + +target_include_directories( + presto_thrift_server PUBLIC + ${presto_thrift_INCLUDES} + ${THRIFT_INCLUDES} + ${GLOG_INCLUDE_DIR} + # Include main directory for PrestoThriftServiceHandler.h + ${CMAKE_SOURCE_DIR}/presto_cpp/main +) + +target_link_libraries( + presto_thrift_server + presto_thrift + ${presto_thrift_LIBRARIES} + ${THRIFT_LIBRARIES} + ${GLOG_LIBRARY} + presto_cpp_main_common + presto_cpp_main_types + presto_task_lib + ${THRIFT_TRANSPORT} + xsimd +) + diff --git a/presto-native-execution/presto_cpp/main/thrift/presto_thrift.thrift b/presto-native-execution/presto_cpp/main/thrift/presto_thrift.thrift index 035e40d5b2823..5c02b913435ce 100644 --- a/presto-native-execution/presto_cpp/main/thrift/presto_thrift.thrift +++ b/presto-native-execution/presto_cpp/main/thrift/presto_thrift.thrift @@ -705,6 +705,45 @@ struct TaskUpdateRequest { 6: optional TableWriteInfo tableWriteInfo; } +struct TaskResult { + 1: i64 sequence; + 2: i64 nextSequence; + 3: optional IOBufPtr data; + 4: bool complete; + 5: optional list remainingBytes; +} + service PrestoThrift { - void fake(); + /** + * Get task results - corresponds to /v1/task/{taskId}/results/{bufferId}/{token} + * @param taskId The ID of the task to get results for + * @param bufferId The buffer ID to get results from + * @param token Continuation token for paging + * @param maxSizeBytes Maximum number of bytes to return + * @param maxWaitMicros Maximum time to wait in microseconds + * @param getDataSize Two phase protocol: if true, return the size of the data in the first phrase + * @return TaskResult containing the data and metadata + */ + TaskResult getTaskResults( + 1: string taskId, + 2: i64 bufferId, + 3: i64 token, + 4: i64 maxSizeBytes, + 5: i64 maxWaitMicros, + 6: bool getDataSize, + ); + + /** + * Acknowledge task results - corresponds to /v1/task/{taskId}/results/{bufferId}/{token}/acknowledge + * @param taskId The ID of the task to acknowledge results for + * @param bufferId The buffer ID to acknowledge results for + * @param token The token to acknowledge up to + */ + void acknowledgeTaskResults(1: string taskId, 2: i64 bufferId, 3: i64 token); + + /** + * Abort task results - corresponds to /v1/task/{taskId}/results + * @param taskId The ID of the task to abort results for + */ + void abortTaskResults(1: string taskId, 2: i64 destination); } diff --git a/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.cpp b/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.cpp new file mode 100644 index 0000000000000..812abd0f5b023 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.cpp @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/thrift/server/PrestoThriftServiceHandler.h" +#include "presto_cpp/main/TaskManager.h" +#include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +namespace facebook::presto::thrift { + +folly::Future> +PrestoThriftServiceHandler::future_getTaskResults( + std::unique_ptr taskId, + int64_t bufferId, + int64_t token, + int64_t maxSizeBytes, + int64_t maxWaitMicros, + bool getDataSize) { + protocol::Duration maxWait(maxWaitMicros, protocol::TimeUnit::MICROSECONDS); + protocol::DataSize protoMaxSize; + if (getDataSize) { + protoMaxSize = protocol::DataSize(0, protocol::DataUnit::BYTE); + } else { + protoMaxSize = protocol::DataSize(maxSizeBytes, protocol::DataUnit::BYTE); + } + + // Create a callback state for HTTP compatibility + auto callbackState = std::make_shared(); + return taskManager_ + ->getResults( + *taskId, bufferId, token, protoMaxSize, maxWait, callbackState) + .thenValue( + [callbackState](std::unique_ptr result) + -> std::unique_ptr<::facebook::presto::thrift::TaskResult> { + auto thriftResult = + std::make_unique<::facebook::presto::thrift::TaskResult>(); + + *thriftResult->sequence_ref() = result->sequence; + *thriftResult->nextSequence_ref() = result->nextSequence; + *thriftResult->complete_ref() = result->complete; + if (!result->remainingBytes.empty()) { + thriftResult->remainingBytes_ref() = + std::move(result->remainingBytes); + } + if (result->data && result->data->length() > 0) { + thriftResult->data_ref() = std::move(result->data); + } + + return thriftResult; + }); +} + +folly::Future +PrestoThriftServiceHandler::future_acknowledgeTaskResults( + std::unique_ptr taskId, + int64_t bufferId, + int64_t token) { + return folly::makeFutureWith( + [this, taskId = std::move(taskId), bufferId, token]() mutable { + taskManager_->acknowledgeResults(*taskId, bufferId, token); + return folly::unit; + }); +} + +folly::Future PrestoThriftServiceHandler::future_abortTaskResults( + std::unique_ptr taskId, + int64_t destination) { + return folly::makeFutureWith( + [this, taskId = std::move(taskId), destination]() mutable { + taskManager_->abortResults(*taskId, destination); + return folly::unit; + }); +} + +} // namespace facebook::presto::thrift diff --git a/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.h b/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.h new file mode 100644 index 0000000000000..bf47d1959f90c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/thrift/server/PrestoThriftServiceHandler.h @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/thrift/gen-cpp2/PrestoThrift.h" +#include "presto_cpp/main/TaskManager.h" +#include "presto_cpp/main/types/VeloxPlanValidator.h" +#include "velox/common/memory/Memory.h" + +namespace facebook::presto::thrift { + +class PrestoThriftServiceHandler + : public facebook::presto::thrift::PrestoThriftSvIf { + public: + explicit PrestoThriftServiceHandler( + std::shared_ptr pool, + std::shared_ptr planValidator, + std::shared_ptr taskManager) + : pool_(std::move(pool)), + planValidator_(std::move(planValidator)), + taskManager_(std::move(taskManager)) {} + + folly::Future> + future_getTaskResults( + std::unique_ptr taskId, + int64_t bufferId, + int64_t token, + int64_t maxSizeBytes, + int64_t maxWaitMicros, + bool getDataSize) override; + + folly::Future future_acknowledgeTaskResults( + std::unique_ptr taskId, + int64_t bufferId, + int64_t token) override; + + folly::Future future_abortTaskResults( + std::unique_ptr taskId, + int64_t destination) override; + + private: + std::shared_ptr const pool_; + std::shared_ptr const planValidator_; + std::shared_ptr const taskManager_; +}; + +} // namespace facebook::presto::thrift diff --git a/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.cpp b/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.cpp new file mode 100644 index 0000000000000..fb9cd6e199a45 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.cpp @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/thrift/server/ThriftServer.h" + +#include +#include + +#include "PrestoThriftServiceHandler.h" +#include "common/services/cpp/ServiceFramework.h" +#include "common/services/cpp/TLSConstants.h" +#include "presto_cpp/main/common/Configs.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::thrift { + +ThriftConfig::ThriftConfig( + const folly::SocketAddress& address, + const std::string& certPath, + const std::string& keyPath, + const std::string& supportedCiphers) + : address_(address), + certPath_(certPath), + keyPath_(keyPath), + taskExpireTimeMs_( + SystemConfig::instance()->thriftServerTaskExpireTimeMs()), + streamExpireTimeMs_( + SystemConfig::instance()->thriftServerStreamExpireTime()), + maxRequest_(SystemConfig::instance()->thriftServerMaxRequests()), + maxConnections_(SystemConfig::instance()->thriftServerMaxConnections()), + idleTimeout_(SystemConfig::instance()->thriftServerIdleTimeout()) {} + +ThriftServer::ThriftServer( + std::unique_ptr config, + std::shared_ptr ioExecutor, + std::shared_ptr pool, + std::shared_ptr planValidator, + std::shared_ptr taskManager) + : config_(std::move(config)), ioExecutor_(std::move(ioExecutor)) { + VELOX_CHECK_NOT_NULL(config_); + VELOX_CHECK_NOT_NULL(pool); + VELOX_CHECK_NOT_NULL(planValidator); + VELOX_CHECK_NOT_NULL(taskManager); + + server_ = std::make_unique(); + handler_ = std::make_shared( + pool, planValidator, taskManager); + + server_->setIOThreadPool(ioExecutor_); + server_->setInterface(handler_); + server_->setAddress(config_->getAddress()); + + // Set connection limits and timeouts + server_->setMaxConnections(config_->getMaxConnections()); + server_->setMaxRequests(config_->getMaxRequest()); + server_->setIdleTimeout(std::chrono::milliseconds(config_->getIdleTimeout())); + server_->setTaskExpireTime( + std::chrono::milliseconds(config_->getTaskExpireTimeMs())); + server_->setStreamExpireTime( + std::chrono::milliseconds(config_->getStreamExpireTimeMs())); + + // Configure SSL if cert path is provided + if (!config_->getCertPath().empty() && !config_->getKeyPath().empty()) { + wangle::SSLContextConfig sslCfg; + sslCfg.isDefault = true; + sslCfg.clientVerification = + folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST; + sslCfg.setCertificate(config_->getCertPath(), config_->getKeyPath(), ""); + sslCfg.sslCiphers = config_->getSupportedCiphers(); + sslCfg.setNextProtocols({"rs"}); + server_->setSSLConfig(std::make_shared(sslCfg)); + } +} + +void ThriftServer::start() { + PRESTO_STARTUP_LOG(INFO) << "=== THRIFT SERVER CONFIGURATION SUMMARY ===" + << "\n Address: " + << config_->getAddress().getAddressStr() << ":" + << config_->getAddress().getPort() + << "\n Max Conns: " << config_->getMaxConnections() + << "\n Max Requests: " << config_->getMaxRequest() + << "\n Idle Timeout: " << config_->getIdleTimeout() + << "ms" << "\n Task/Stream Timeout: " + << config_->getTaskExpireTimeMs() << "ms"; + server_->serve(); +} + +} // namespace facebook::presto::thrift diff --git a/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.h b/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.h new file mode 100644 index 0000000000000..944e045b54446 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/thrift/server/ThriftServer.h @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include +#include + +#include "PrestoThriftServiceHandler.h" +#include "presto_cpp/main/common/Utils.h" + +namespace apache { +namespace thrift { +class ThriftServer; +} +} // namespace apache + +namespace facebook { +namespace velox { +namespace memory { +class MemoryPool; +} +} // namespace velox +} // namespace facebook + +namespace facebook::presto { + +class TaskManager; +class VeloxPlanValidator; + +namespace thrift { +class ThriftConfig { + public: + ThriftConfig( + const folly::SocketAddress& address, + const std::string& certPath, + const std::string& keyPath, + const std::string& supportedCiphers); + + const folly::SocketAddress& getAddress() const { + return address_; + } + const std::string& getCertPath() const { + return certPath_; + } + const std::string& getKeyPath() const { + return keyPath_; + } + const std::string& getSupportedCiphers() const { + return supportedCiphers_; + } + int getTaskExpireTimeMs() const { + return taskExpireTimeMs_; + } + int getStreamExpireTimeMs() const { + return streamExpireTimeMs_; + } + int getMaxRequest() const { + return maxRequest_; + } + int getMaxConnections() const { + return maxConnections_; + } + int getIdleTimeout() const { + return idleTimeout_; + } + + private: + const folly::SocketAddress address_; + const std::string certPath_; + const std::string keyPath_; + const std::string ciphers_; + const std::string supportedCiphers_; + + int taskExpireTimeMs_; + int streamExpireTimeMs_; + int maxRequest_; + int maxConnections_; + int idleTimeout_; +}; + +class ThriftServer { + public: + explicit ThriftServer( + std::unique_ptr config, + std::shared_ptr ioExecutor, + std::shared_ptr pool, + std::shared_ptr planValidator, + std::shared_ptr taskManager); + + void start(); + + void stop() { + server_->stop(); + PRESTO_SHUTDOWN_LOG(INFO) << "Stopping Thrift server..."; + } + + folly::SocketAddress address() const; + + private: + std::unique_ptr config_; + std::shared_ptr ioExecutor_; + std::unique_ptr server_; + std::shared_ptr handler_; +}; + +} // namespace thrift +} // namespace facebook::presto