diff --git a/PROCESSORS.md b/PROCESSORS.md index c3ebccb10f..1fa6cdfec0 100644 --- a/PROCESSORS.md +++ b/PROCESSORS.md @@ -1136,16 +1136,16 @@ Establishes a TCP Server that defines and retrieves one or more byte messages fr In the list below, the names of required properties appear in bold. Any other properties (not in bold) are considered optional. The table also indicates any default values, and whether a property supports the NiFi Expression Language. -| Name | Default Value | Allowable Values | Description | -|----------------------------|---------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| **endpoint-list** | | | A comma delimited list of the endpoints to connect to. The format should be :. | -| concurrent-handler-count | 1 | | Number of concurrent handlers for this session | -| reconnect-interval | 5 s | | The number of seconds to wait before attempting to reconnect to the endpoint. | -| Stay Connected | true | | Determines if we keep the same socket despite having no data | -| receive-buffer-size | 16 MB | | The size of the buffer to receive data in. Default 16384 (16MB). | -| SSL Context Service | | | SSL Context Service Name | -| connection-attempt-timeout | 3 | | Maximum number of connection attempts before attempting backup hosts, if configured | -| end-of-message-byte | 13 | | Byte value which denotes end of message. Must be specified as integer within the valid byte range (-128 thru 127). For example, '13' = Carriage return and '10' = New line. Default '13'. 
| +| Name | Default Value | Allowable Values | Description | +|-------------------------------|---------------|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Endpoint List** | | | A comma delimited list of the endpoints to connect to. The format should be `<hostname>:<service>`. | +| SSL Context Service | | | SSL Context Service Name | +| Message Delimiter | \n | | Character that denotes the end of the message. | +| **Max Size of Message Queue** | 10000 | | Maximum number of messages allowed to be buffered before processing them when the processor is triggered. If the buffer is full, the message is ignored. If set to zero the buffer is unlimited. | +| Maximum Message Size | | | Optional size of the buffer to receive data in. | +| **Max Batch Size** | 500 | | The maximum number of messages to process at a time. | +| **Timeout** | 1s | | The timeout for connecting to and communicating with the destination.
**Supports Expression Language: true** | +| **Reconnection Interval** | 1 min | | The duration to wait before attempting to reconnect to the endpoints.
**Supports Expression Language: true** | ### Relationships @@ -1154,6 +1154,13 @@ In the list below, the names of required properties appear in bold. Any other pr | success | All files are routed to success | | partial | Indicates an incomplete message as a result of encountering the end of message byte trigger | +### Output Attributes + +| Attribute | Relationship | Description | +|-----------------|------------------|----------------------------------------------------------| +| source.endpoint | success, partial | The address of the source endpoint the message came from | + + ## GetUSBCamera @@ -2779,5 +2786,3 @@ In the list below, the names of required properties appear in bold. Any other pr |---------|-----------------------------------------| | success | All files are routed to success | | failure | Failed files are transferred to failure | - - diff --git a/extensions/standard-processors/processors/GetTCP.cpp b/extensions/standard-processors/processors/GetTCP.cpp index 4eed2d13d4..a1cebafe28 100644 --- a/extensions/standard-processors/processors/GetTCP.cpp +++ b/extensions/standard-processors/processors/GetTCP.cpp @@ -17,275 +17,312 @@ */ #include "GetTCP.h" -#ifndef WIN32 -#include -#endif -#include -#include #include -#include #include -#include -#include #include -#include "io/ClientSocket.h" +#include +#include +#include "utils/net/AsioCoro.h" #include "io/StreamFactory.h" #include "utils/gsl.h" #include "utils/StringUtils.h" -#include "utils/TimeUtil.h" #include "core/ProcessContext.h" #include "core/ProcessSession.h" #include "core/ProcessSessionFactory.h" #include "core/PropertyBuilder.h" #include "core/Resource.h" -namespace org::apache::nifi::minifi::processors { +using namespace std::literals::chrono_literals; -const char *DataHandler::SOURCE_ENDPOINT_ATTRIBUTE = "source.endpoint"; +namespace org::apache::nifi::minifi::processors { const core::Property GetTCP::EndpointList( - core::PropertyBuilder::createProperty("endpoint-list")->withDescription("A 
comma delimited list of the endpoints to connect to. The format should be :.")->isRequired(true) - ->build()); - -const core::Property GetTCP::ConcurrentHandlers( - core::PropertyBuilder::createProperty("concurrent-handler-count")->withDescription("Number of concurrent handlers for this session")->withDefaultValue(1)->build()); - -const core::Property GetTCP::ReconnectInterval( - core::PropertyBuilder::createProperty("reconnect-interval")->withDescription("The number of seconds to wait before attempting to reconnect to the endpoint.") - ->withDefaultValue("5 s")->build()); + core::PropertyBuilder::createProperty("Endpoint List") + ->withDescription("A comma delimited list of the endpoints to connect to. The format should be :.") + ->isRequired(true)->build()); -const core::Property GetTCP::ReceiveBufferSize( - core::PropertyBuilder::createProperty("receive-buffer-size")->withDescription("The size of the buffer to receive data in. Default 16384 (16MB).")->withDefaultValue("16 MB") +const core::Property GetTCP::SSLContextService( + core::PropertyBuilder::createProperty("SSL Context Service") + ->withDescription("SSL Context Service Name") + ->asType()->build()); + +const core::Property GetTCP::MessageDelimiter( + core::PropertyBuilder::createProperty("Message Delimiter")->withDescription( + "Character that denotes the end of the message.") + ->withDefaultValue("\\n")->build()); + +const core::Property GetTCP::MaxQueueSize( + core::PropertyBuilder::createProperty("Max Size of Message Queue") + ->withDescription("Maximum number of messages allowed to be buffered before processing them when the processor is triggered. " + "If the buffer is full, the message is ignored. 
If set to zero the buffer is unlimited.") + ->withDefaultValue(10000) + ->isRequired(true) ->build()); -const core::Property GetTCP::SSLContextService( - core::PropertyBuilder::createProperty("SSL Context Service")->withDescription("SSL Context Service Name")->asType()->build()); +const core::Property GetTCP::MaxBatchSize( + core::PropertyBuilder::createProperty("Max Batch Size") + ->withDescription("The maximum number of messages to process at a time.") + ->withDefaultValue(500) + ->isRequired(true) + ->build()); -const core::Property GetTCP::StayConnected( - core::PropertyBuilder::createProperty("Stay Connected")->withDescription("Determines if we keep the same socket despite having no data")->withDefaultValue(true)->build()); +const core::Property GetTCP::MaxMessageSize( + core::PropertyBuilder::createProperty("Maximum Message Size") + ->withDescription("Optional size of the buffer to receive data in.")->build()); -const core::Property GetTCP::ConnectionAttemptLimit( - core::PropertyBuilder::createProperty("connection-attempt-timeout")->withDescription("Maximum number of connection attempts before attempting backup hosts, if configured")->withDefaultValue( - 3)->build()); +const core::Property GetTCP::Timeout = core::PropertyBuilder::createProperty("Timeout") + ->withDescription("The timeout for connecting to and communicating with the destination.") + ->withDefaultValue("1s") + ->isRequired(true) + ->supportsExpressionLanguage(true) + ->build(); -const core::Property GetTCP::EndOfMessageByte( - core::PropertyBuilder::createProperty("end-of-message-byte")->withDescription( - "Byte value which denotes end of message. Must be specified as integer within the valid byte range (-128 thru 127). For example, '13' = Carriage return and '10' = New line. 
Default '13'.") - ->withDefaultValue("13")->build()); +const core::Property GetTCP::ReconnectInterval = core::PropertyBuilder::createProperty("Reconnection Interval") + ->withDescription("The duration to wait before attempting to reconnect to the endpoints.") + ->withDefaultValue("1 min") + ->isRequired(true) + ->supportsExpressionLanguage(true) + ->build(); const core::Relationship GetTCP::Success("success", "All files are routed to success"); const core::Relationship GetTCP::Partial("partial", "Indicates an incomplete message as a result of encountering the end of message byte trigger"); -int16_t DataHandler::handle(const std::string& source, uint8_t *message, size_t size, bool partial) { - std::shared_ptr my_session = sessionFactory_->createSession(); - std::shared_ptr flowFile = my_session->create(); - - my_session->writeBuffer(flowFile, gsl::make_span(reinterpret_cast(message), size)); +const core::OutputAttribute GetTCP::SourceEndpoint{"source.endpoint", {Success, Partial}, "The address of the source endpoint the message came from"}; - my_session->putAttribute(flowFile, SOURCE_ENDPOINT_ATTRIBUTE, source); - - if (partial) { - my_session->transfer(flowFile, GetTCP::Partial); - } else { - my_session->transfer(flowFile, GetTCP::Success); - } - - my_session->commit(); - - return 0; -} void GetTCP::initialize() { setSupportedProperties(properties()); setSupportedRelationships(relationships()); } -void GetTCP::onSchedule(const std::shared_ptr &context, const std::shared_ptr &sessionFactory) { - std::string value; - if (context->getProperty(EndpointList.getName(), value)) { - endpoints = utils::StringUtils::split(value, ","); - } - int handlers = 0; - if (context->getProperty(ConcurrentHandlers.getName(), handlers)) { - concurrent_handlers_ = handlers; +std::vector GetTCP::parseEndpointList(core::ProcessContext& context) { + std::vector connections_to_make; + if (auto endpoint_list_str = context.getProperty(EndpointList)) { + for (const auto& endpoint_str : 
utils::StringUtils::splitAndTrim(*endpoint_list_str, ",")) { + auto hostname_service_pair = utils::StringUtils::splitAndTrim(endpoint_str, ":"); + if (hostname_service_pair.size() != 2) { + logger_->log_error("%s endpoint is invalid, expected {hostname}:{service} format", endpoint_str); + continue; + } + connections_to_make.emplace_back(hostname_service_pair[0], hostname_service_pair[1]); + } } + if (connections_to_make.empty()) + throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("No valid endpoint in {} property", EndpointList.getName())); - stay_connected_ = true; - if (context->getProperty(StayConnected.getName(), value)) { - stay_connected_ = utils::StringUtils::toBool(value).value_or(true); + return connections_to_make; +} + +char GetTCP::parseDelimiter(core::ProcessContext& context) { + char delimiter = '\n'; + if (auto delimiter_str = context.getProperty(GetTCP::MessageDelimiter)) { + auto parsed_delimiter = utils::StringUtils::parseCharacter(*delimiter_str); + if (!parsed_delimiter || !parsed_delimiter->has_value()) + throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("Invalid delimiter: {} (it must be a single (escaped or not) character", *delimiter_str)); + delimiter = **parsed_delimiter; } + return delimiter; +} - int connects = 0; - if (context->getProperty(ConnectionAttemptLimit.getName(), connects)) { - connection_attempt_limit_ = connects; +std::optional GetTCP::parseSSLContext(core::ProcessContext& context) { + std::optional ssl_context; + if (auto context_name = context.getProperty(SSLContextService)) { + if (auto controller_service = context.getControllerService(*context_name)) { + if (auto ssl_context_service = std::dynamic_pointer_cast(context.getControllerService(*context_name))) { + ssl_context = utils::net::getSslContext(*ssl_context_service); + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, *context_name + " is not an SSL Context Service"); + } + } else { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, "Invalid controller 
service: " + *context_name); + } } - context->getProperty(ReceiveBufferSize.getName(), receive_buffer_size_); + return ssl_context; +} - if (context->getProperty(EndOfMessageByte.getName(), value)) { - logger_->log_trace("EOM is passed in as %s", value); - int64_t byteValue = 0; - core::Property::StringToInt(value, byteValue); - endOfMessageByte = static_cast(byteValue & 0xFF); +uint64_t GetTCP::parseMaxBatchSize(core::ProcessContext& context) { + if (auto max_batch_size = context.getProperty(MaxBatchSize)) { + if (*max_batch_size == 0) { + throw Exception(PROCESS_SCHEDULE_EXCEPTION, fmt::format("{} should be non-zero.", MaxBatchSize.getName())); + } + return *max_batch_size; } + return MaxBatchSize.getDefaultValue(); +} + +void GetTCP::onSchedule(const std::shared_ptr& context, const std::shared_ptr&) { + gsl_Expects(context); + + auto connections_to_make = parseEndpointList(*context); + auto delimiter = parseDelimiter(*context); + auto ssl_context = parseSSLContext(*context); + + std::optional max_queue_size = context->getProperty(MaxQueueSize); + std::optional max_message_size = context->getProperty(MaxMessageSize); - logger_->log_trace("EOM is defined as %i", static_cast(endOfMessageByte)); - if (auto reconnect_interval = context->getProperty(ReconnectInterval)) { - reconnect_interval_ = reconnect_interval->getMilliseconds(); - logger_->log_debug("Reconnect interval is %" PRId64 " ms", reconnect_interval_.count()); - } else { - logger_->log_debug("Reconnect interval using default value of %" PRId64 " ms", reconnect_interval_.count()); + asio::steady_timer::duration timeout_duration = 1s; + if (auto timeout_value = context->getProperty(Timeout)) { + timeout_duration = timeout_value->getMilliseconds(); } - handler_ = std::make_unique(sessionFactory); - - f_ex = [&] { - std::unique_ptr socket_ptr; - // reuse the byte buffer. 
- std::vector buffer; - int reconnects = 0; - do { - if ( socket_ring_buffer_.try_dequeue(socket_ptr) ) { - buffer.resize(receive_buffer_size_); - const auto size_read = socket_ptr->read(buffer, false); - if (!io::isError(size_read)) { - if (size_read != 0) { - // determine cut location - size_t startLoc = 0; - for (size_t i = 0; i < size_read; i++) { - if (buffer.at(i) == endOfMessageByte && i > 0) { - if (i-startLoc > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast(buffer.data())+startLoc, (i-startLoc), true); - } - startLoc = i; - } - } - if (startLoc > 0) { - logger_->log_trace("Starting at %i, ending at %i", startLoc, size_read); - if (size_read-startLoc > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast(buffer.data())+startLoc, (size_read-startLoc), true); - } - } else { - logger_->log_trace("Handling at %i, ending at %i", startLoc, size_read); - if (size_read > 0) { - handler_->handle(socket_ptr->getHostname(), reinterpret_cast(buffer.data()), size_read, false); - } - } - reconnects = 0; - } - socket_ring_buffer_.enqueue(std::move(socket_ptr)); - } else if (size_read == static_cast(-2) && stay_connected_) { - if (++reconnects > connection_attempt_limit_) { - logger_->log_info("Too many reconnects, exiting thread"); - socket_ptr->close(); - return -1; - } - logger_->log_info("Sleeping for %" PRId64 " msec before attempting to reconnect", int64_t{reconnect_interval_.count()}); - std::this_thread::sleep_for(reconnect_interval_); - socket_ring_buffer_.enqueue(std::move(socket_ptr)); - } else { - socket_ptr->close(); - std::this_thread::sleep_for(reconnect_interval_); - logger_->log_info("Read response returned a -1 from socket, exiting thread"); - return -1; - } - } else { - std::this_thread::sleep_for(reconnect_interval_); - logger_->log_info("Could not use socket, exiting thread"); - return -1; - } - }while (running_); - logger_->log_debug("Ending private thread"); - return 0; - }; - - if 
(context->getProperty(SSLContextService.getName(), value)) { - std::shared_ptr service = context->getControllerService(value); - if (nullptr != service) { - ssl_service_ = std::static_pointer_cast(service); - } + asio::steady_timer::duration reconnection_interval = 1min; + if (auto reconnect_interval_value = context->getProperty(ReconnectInterval)) { + reconnection_interval = reconnect_interval_value->getMilliseconds(); } - client_thread_pool_.setMaxConcurrentTasks(concurrent_handlers_); - client_thread_pool_.start(); - running_ = true; + client_.emplace(delimiter, timeout_duration, reconnection_interval, std::move(ssl_context), max_queue_size, max_message_size, std::move(connections_to_make), logger_); + client_thread_ = std::thread([this]() { client_->run(); }); // NOLINT + + max_batch_size_ = parseMaxBatchSize(*context); } void GetTCP::notifyStop() { - running_ = false; - // await threads to shutdown. - client_thread_pool_.shutdown(); - std::unique_ptr socket_ptr; - while (socket_ring_buffer_.size_approx() > 0) { - socket_ring_buffer_.try_dequeue(socket_ptr); + if (client_) + client_->stop(); +} + +void GetTCP::transferAsFlowFile(const utils::net::Message& message, core::ProcessSession& session) { + auto flow_file = session.create(); + session.writeBuffer(flow_file, message.message_data); + flow_file->setAttribute(GetTCP::SourceEndpoint.getName(), fmt::format("{}:{}", message.sender_address.to_string(), std::to_string(message.server_port))); + if (message.is_partial) + session.transfer(flow_file, Partial); + else + session.transfer(flow_file, Success); +} + +void GetTCP::onTrigger(const std::shared_ptr&, const std::shared_ptr& session) { + gsl_Expects(session && max_batch_size_ > 0); + size_t logs_processed = 0; + while (!client_->queueEmpty() && logs_processed < max_batch_size_) { + utils::net::Message received_message; + if (!client_->tryDequeue(received_message)) + break; + transferAsFlowFile(received_message, *session); + ++logs_processed; } } -void 
GetTCP::onTrigger(const std::shared_ptr &context, const std::shared_ptr& /*session*/) { - // Perform directory list - std::lock_guard lock(mutex_); - // check if the futures are valid. If they've terminated remove it from the map. - - for (auto &initEndpoint : endpoints) { - std::vector hostAndPort = utils::StringUtils::split(initEndpoint, ":"); - auto realizedHost = hostAndPort.at(0); -#ifdef WIN32 - if ("localhost" == realizedHost) { - realizedHost = org::apache::nifi::minifi::io::Socket::getMyHostName(); + +GetTCP::TcpClient::TcpClient(char delimiter, + asio::steady_timer::duration timeout_duration, + asio::steady_timer::duration reconnection_interval, + std::optional ssl_context, + std::optional max_queue_size, + std::optional max_message_size, + std::vector connections, + std::shared_ptr logger) + : delimiter_(delimiter), + timeout_duration_(timeout_duration), + reconnection_interval_(reconnection_interval), + ssl_context_(std::move(ssl_context)), + max_queue_size_(max_queue_size), + max_message_size_(max_message_size), + connections_(std::move(connections)), + logger_(std::move(logger)) { +} + +GetTCP::TcpClient::~TcpClient() { + stop(); +} + + +void GetTCP::TcpClient::run() { + gsl_Expects(!connections_.empty()); + for (const auto& connection_id : connections_) { + asio::co_spawn(io_context_, doReceiveFrom(connection_id), asio::detached); // NOLINT + } + io_context_.run(); +} + +void GetTCP::TcpClient::stop() { + io_context_.stop(); +} + +bool GetTCP::TcpClient::queueEmpty() const { + return concurrent_queue_.empty(); +} + +bool GetTCP::TcpClient::tryDequeue(utils::net::Message& received_message) { + return concurrent_queue_.tryDequeue(received_message); +} + +asio::awaitable GetTCP::TcpClient::readLoop(auto& socket) { + std::string read_message; + bool previous_didnt_end_with_delimiter = false; + bool current_doesnt_end_with_delimiter = false; + while (true) { + { + previous_didnt_end_with_delimiter = current_doesnt_end_with_delimiter; + 
current_doesnt_end_with_delimiter = false; } -#endif - if (hostAndPort.size() != 2) { + auto dynamic_buffer = max_message_size_ ? asio::dynamic_buffer(read_message, *max_message_size_) : asio::dynamic_buffer(read_message); + auto [read_error, bytes_read] = co_await asio::async_read_until(socket, dynamic_buffer, delimiter_, utils::net::use_nothrow_awaitable); // NOLINT + + if (*max_message_size_ && read_error == asio::error::not_found) { + current_doesnt_end_with_delimiter = true; + bytes_read = *max_message_size_; + } else if (read_error) { + logger_->log_error("Error during read %s", read_error.message()); + co_return read_error; + } + + if (bytes_read == 0) continue; + + if (!max_queue_size_ || max_queue_size_ > concurrent_queue_.size()) { + utils::net::Message message{read_message.substr(0, bytes_read), utils::net::IpProtocol::TCP, socket.lowest_layer().remote_endpoint().address(), socket.lowest_layer().remote_endpoint().port()}; + if (previous_didnt_end_with_delimiter || current_doesnt_end_with_delimiter) + message.is_partial = true; + concurrent_queue_.enqueue(std::move(message)); + } else { + logger_->log_warn("Queue is full. TCP message ignored."); } + read_message.erase(0, bytes_read); + } +} - auto portStr = hostAndPort.at(1); - auto endpoint = utils::StringUtils::join_pack(realizedHost, ":", portStr); - - auto endPointFuture = live_clients_.find(endpoint); - // does not exist - if (endPointFuture == live_clients_.end()) { - logger_->log_info("creating endpoint for %s", endpoint); - if (hostAndPort.size() == 2) { - logger_->log_debug("Opening another socket to %s:%s is secure %d", realizedHost, portStr, (ssl_service_ != nullptr)); - std::unique_ptr socket = - ssl_service_ != nullptr ? 
stream_factory_->createSecureSocket(realizedHost, std::stoi(portStr), ssl_service_) : stream_factory_->createSocket(realizedHost, std::stoi(portStr)); - if (!socket) { - logger_->log_error("Could not create socket during initialization for %s", endpoint); - continue; - } - socket->setNonBlocking(); - if (socket->initialize() != -1) { - logger_->log_debug("Enqueueing socket into ring buffer %s:%s", realizedHost, portStr); - socket_ring_buffer_.enqueue(std::move(socket)); - } else { - logger_->log_error("Could not create socket during initialization for %s", endpoint); +template +asio::awaitable GetTCP::TcpClient::doReceiveFromEndpoint(const asio::ip::tcp::endpoint& endpoint, SocketType& socket) { + auto [connection_error] = co_await utils::net::asyncOperationWithTimeout(socket.lowest_layer().async_connect(endpoint, utils::net::use_nothrow_awaitable), timeout_duration_); // NOLINT + if (connection_error) + co_return connection_error; + auto [handshake_error] = co_await utils::net::handshake(socket, timeout_duration_); + if (handshake_error) + co_return handshake_error; + co_return co_await readLoop(socket); +} + +asio::awaitable GetTCP::TcpClient::doReceiveFrom(const utils::net::ConnectionId& connection_id) { + while (true) { + asio::ip::tcp::resolver resolver(io_context_); + auto [resolve_error, resolve_result] = co_await utils::net::asyncOperationWithTimeout( // NOLINT + resolver.async_resolve(connection_id.getHostname(), connection_id.getService(), utils::net::use_nothrow_awaitable), timeout_duration_); + if (resolve_error) { + logger_->log_error("Error during resolution: %s", resolve_error.message()); + co_await utils::net::async_wait(reconnection_interval_); + continue; + } + + std::error_code last_error; + for (const auto& endpoint : resolve_result) { + if (ssl_context_) { + utils::net::SslSocket ssl_socket{io_context_, *ssl_context_}; + last_error = co_await doReceiveFromEndpoint(endpoint, ssl_socket); + if (last_error) continue; - } } else { - 
logger_->log_error("Could not create socket for %s", endpoint); - } - auto* future = new std::future(); - std::unique_ptr> after_execute = std::unique_ptr>(new SocketAfterExecute(running_, endpoint, &live_clients_, &mutex_)); - utils::Worker functor(f_ex, "workers", std::move(after_execute)); - client_thread_pool_.execute(std::move(functor), *future); - live_clients_[endpoint] = future; - } else { - if (!endPointFuture->second->valid()) { - delete endPointFuture->second; - auto* future = new std::future(); - std::unique_ptr> after_execute = std::unique_ptr>(new SocketAfterExecute(running_, endpoint, &live_clients_, &mutex_)); - utils::Worker functor(f_ex, "workers", std::move(after_execute)); - client_thread_pool_.execute(std::move(functor), *future); - live_clients_[endpoint] = future; - } else { - logger_->log_debug("Thread still running for %s", endPointFuture->first); - // we have a thread corresponding to this. + utils::net::TcpSocket tcp_socket(io_context_); + last_error = co_await doReceiveFromEndpoint(endpoint, tcp_socket); + if (last_error) + continue; } } + logger_->log_error("Error connecting to %s:%s due to %s", connection_id.getHostname().data(), connection_id.getService().data(), last_error.message()); + co_await utils::net::async_wait(reconnection_interval_); } - logger_->log_debug("Updating endpoint"); - context->yield(); } REGISTER_RESOURCE(GetTCP, Processor); diff --git a/extensions/standard-processors/processors/GetTCP.h b/extensions/standard-processors/processors/GetTCP.h index 7e0dd03fd3..67207643c6 100644 --- a/extensions/standard-processors/processors/GetTCP.h +++ b/extensions/standard-processors/processors/GetTCP.h @@ -23,6 +23,8 @@ #include #include #include +#include +#include "utils/Literals.h" #include "../core/state/nodes/MetricsBase.h" #include "FlowFileRecord.h" @@ -36,64 +38,11 @@ #include "controllers/SSLContextService.h" #include "utils/gsl.h" #include "utils/Export.h" +#include "utils/net/AsioSocketUtils.h" +#include 
"utils/net/Message.h" namespace org::apache::nifi::minifi::processors { -class SocketAfterExecute : public utils::AfterExecute { - public: - explicit SocketAfterExecute(std::atomic &running, std::string endpoint, std::map*> *list, std::mutex *mutex) - : running_(running.load()), - endpoint_(std::move(endpoint)), - mutex_(mutex), - list_(list) { - } - - SocketAfterExecute(const SocketAfterExecute&) = delete; - SocketAfterExecute(SocketAfterExecute&&) = delete; - - SocketAfterExecute& operator=(const SocketAfterExecute&) = delete; - SocketAfterExecute& operator=(SocketAfterExecute&&) = delete; - - ~SocketAfterExecute() override = default; - - bool isFinished(const int &result) override { - if (result == -1 || result == 0 || !running_) { - std::lock_guard lock(*mutex_); - list_->erase(endpoint_); - return true; - } else { - return false; - } - } - bool isCancelled(const int& /*result*/) override { - return !running_; - } - - std::chrono::steady_clock::duration wait_time() override { - // wait 500ms - return std::chrono::milliseconds(500); - } - - protected: - std::atomic running_; - std::string endpoint_; - std::mutex *mutex_; - std::map*> *list_; -}; - -class DataHandler { - public: - DataHandler(std::shared_ptr sessionFactory) // NOLINT - : sessionFactory_(std::move(sessionFactory)) { - } - static const char *SOURCE_ENDPOINT_ATTRIBUTE; - - int16_t handle(const std::string& source, uint8_t *message, size_t size, bool partial); - - private: - std::shared_ptr sessionFactory_; -}; - class GetTCP : public core::Processor { public: explicit GetTCP(std::string name, const utils::Identifier& uuid = {}) @@ -101,30 +50,35 @@ class GetTCP : public core::Processor { } ~GetTCP() override { - // thread pool must be shut down first before members it is using are destructed, otherwise segfault is possible - client_thread_pool_.shutdown(); + if (client_) { + client_->stop(); + } + if (client_thread_.joinable()) { + client_thread_.join(); + } + client_.reset(); } EXTENSIONAPI static 
constexpr const char* Description = "Establishes a TCP Server that defines and retrieves one or more byte messages from clients"; EXTENSIONAPI static const core::Property EndpointList; - EXTENSIONAPI static const core::Property ConcurrentHandlers; - EXTENSIONAPI static const core::Property ReconnectInterval; - EXTENSIONAPI static const core::Property StayConnected; - EXTENSIONAPI static const core::Property ReceiveBufferSize; EXTENSIONAPI static const core::Property SSLContextService; - EXTENSIONAPI static const core::Property ConnectionAttemptLimit; - EXTENSIONAPI static const core::Property EndOfMessageByte; + EXTENSIONAPI static const core::Property MessageDelimiter; + EXTENSIONAPI static const core::Property MaxQueueSize; + EXTENSIONAPI static const core::Property MaxMessageSize; + EXTENSIONAPI static const core::Property MaxBatchSize; + EXTENSIONAPI static const core::Property Timeout; + EXTENSIONAPI static const core::Property ReconnectInterval; static auto properties() { return std::array{ EndpointList, - ConcurrentHandlers, - ReconnectInterval, - StayConnected, - ReceiveBufferSize, SSLContextService, - ConnectionAttemptLimit, - EndOfMessageByte + MessageDelimiter, + MaxQueueSize, + MaxMessageSize, + MaxBatchSize, + Timeout, + ReconnectInterval }; } @@ -137,6 +91,10 @@ class GetTCP : public core::Processor { EXTENSIONAPI static constexpr core::annotation::Input InputRequirement = core::annotation::Input::INPUT_ALLOWED; EXTENSIONAPI static constexpr bool IsSingleThreaded = false; + EXTENSIONAPI static const core::OutputAttribute SourceEndpoint; + + static auto outputAttributes() { return std::array{SourceEndpoint}; } + ADD_COMMON_VIRTUAL_FUNCTIONS_FOR_PROCESSORS void onSchedule(const std::shared_ptr &processContext, const std::shared_ptr &sessionFactory) override; @@ -148,28 +106,60 @@ class GetTCP : public core::Processor { throw std::logic_error{"GetTCP::onTrigger(ProcessContext*, ProcessSession*) is unimplemented"}; } void initialize() override; - - 
protected: void notifyStop() override; private: - std::function f_ex; - std::atomic running_{false}; - std::unique_ptr handler_; - std::vector endpoints; - std::map*> live_clients_; - moodycamel::ConcurrentQueue> socket_ring_buffer_; - bool stay_connected_{true}; - uint16_t concurrent_handlers_{2}; - std::byte endOfMessageByte{13}; - std::chrono::milliseconds reconnect_interval_{5000}; - uint64_t receive_buffer_size_{16 * 1024 * 1024}; - uint16_t connection_attempt_limit_{3}; - // Mutex for ensuring clients are running - std::mutex mutex_; - std::shared_ptr ssl_service_; + static void transferAsFlowFile(const utils::net::Message& message, core::ProcessSession& session); + + std::vector parseEndpointList(core::ProcessContext& context); + static char parseDelimiter(core::ProcessContext& context); + static std::optional parseSSLContext(core::ProcessContext& context); + static uint64_t parseMaxBatchSize(core::ProcessContext& context); + + class TcpClient { + public: + TcpClient(char delimiter, + asio::steady_timer::duration timeout_duration, + asio::steady_timer::duration reconnection_interval, + std::optional ssl_context, + std::optional max_queue_size, + std::optional max_message_size, + std::vector connections, + std::shared_ptr logger); + + ~TcpClient(); + + void run(); + void stop(); + + bool queueEmpty() const; + bool tryDequeue(utils::net::Message& received_message); + + private: + asio::awaitable doReceiveFrom(const utils::net::ConnectionId& connection_id); + + template + asio::awaitable doReceiveFromEndpoint(const asio::ip::tcp::endpoint& endpoint, SocketType& socket); + + asio::awaitable readLoop(auto& socket); + + utils::ConcurrentQueue concurrent_queue_; + asio::io_context io_context_; + + char delimiter_; + asio::steady_timer::duration timeout_duration_; + asio::steady_timer::duration reconnection_interval_; + std::optional ssl_context_; + std::optional max_queue_size_; + std::optional max_message_size_; + std::vector connections_; + std::shared_ptr 
logger_; + }; + + std::optional client_; + size_t max_batch_size_{500}; + std::thread client_thread_; std::shared_ptr logger_ = core::logging::LoggerFactory::getLogger(uuid_); - utils::ThreadPool client_thread_pool_; }; } // namespace org::apache::nifi::minifi::processors diff --git a/extensions/standard-processors/processors/PutTCP.cpp b/extensions/standard-processors/processors/PutTCP.cpp index 19144d8fbb..57e96e6c4e 100644 --- a/extensions/standard-processors/processors/PutTCP.cpp +++ b/extensions/standard-processors/processors/PutTCP.cpp @@ -113,20 +113,6 @@ void PutTCP::initialize() { void PutTCP::notifyStop() {} -namespace { -asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service) { - asio::ssl::context ssl_context(asio::ssl::context::tls_client); - ssl_context.set_options(asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); - ssl_context.load_verify_file(ssl_context_service.getCACertificate().string()); - ssl_context.set_verify_mode(asio::ssl::verify_peer); - if (const auto& cert_file = ssl_context_service.getCertificateFile(); !cert_file.empty()) - ssl_context.use_certificate_file(cert_file.string(), asio::ssl::context::pem); - if (const auto& private_key_file = ssl_context_service.getPrivateKeyFile(); !private_key_file.empty()) - ssl_context.use_private_key_file(private_key_file.string(), asio::ssl::context::pem); - ssl_context.set_password_callback([password = ssl_context_service.getPassphrase()](std::size_t&, asio::ssl::context_base::password_purpose&) { return password; }); - return ssl_context; -} -} // namespace void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessionFactory*) { gsl_Expects(context); @@ -158,7 +144,7 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio if (context->getProperty(SSLContextService.getName(), context_name) && !IsNullOrEmpty(context_name)) { if (auto controller_service = context->getControllerService(context_name)) { if 
(auto ssl_context_service = std::dynamic_pointer_cast(context->getControllerService(context_name))) { - ssl_context_ = getSslContext(*ssl_context_service); + ssl_context_ = utils::net::getSslContext(*ssl_context_service); } else { throw Exception(PROCESS_SCHEDULE_EXCEPTION, context_name + " is not an SSL Context Service"); } @@ -177,20 +163,11 @@ void PutTCP::onSchedule(core::ProcessContext* const context, core::ProcessSessio } namespace { -template -asio::awaitable> handshake(SocketType&, asio::steady_timer::duration) { - co_return std::error_code(); -} - -template<> -asio::awaitable> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) { - co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT -} template class ConnectionHandler : public ConnectionHandlerBase { public: - ConnectionHandler(detail::ConnectionId connection_id, + ConnectionHandler(utils::net::ConnectionId connection_id, std::chrono::milliseconds timeout, std::shared_ptr logger, std::optional max_size_of_socket_send_buffer, @@ -227,11 +204,11 @@ class ConnectionHandler : public ConnectionHandlerBase { SocketType createNewSocket(asio::io_context& io_context_); - detail::ConnectionId connection_id_; + utils::net::ConnectionId connection_id_; std::optional socket_; std::optional last_used_; - std::chrono::milliseconds timeout_duration_; + asio::steady_timer::duration timeout_duration_; std::shared_ptr logger_; std::optional max_size_of_socket_send_buffer_; @@ -262,7 +239,7 @@ asio::awaitable ConnectionHandler::establishNewConn last_error = connection_error; continue; } - auto [handshake_error] = co_await handshake(socket, timeout_duration_); + auto [handshake_error] = co_await utils::net::handshake(socket, timeout_duration_); if (handshake_error) { core::logging::LOG_DEBUG(logger_) << "Handshake with " << endpoint.endpoint() << " failed due to " << handshake_error.message(); last_error = 
handshake_error; @@ -281,7 +258,8 @@ template if (hasUsableSocket()) co_return std::error_code(); tcp::resolver resolver(io_context); - auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout(resolver.async_resolve(connection_id_.getHostname(), connection_id_.getPort(), use_nothrow_awaitable), timeout_duration_); + auto [resolve_error, resolve_result] = co_await asyncOperationWithTimeout( + resolver.async_resolve(connection_id_.getHostname(), connection_id_.getService(), use_nothrow_awaitable), timeout_duration_); if (resolve_error) co_return resolve_error; co_return co_await establishNewConnection(resolve_result, io_context); @@ -343,7 +321,7 @@ void PutTCP::onTrigger(core::ProcessContext* context, core::ProcessSession* cons return; } - auto connection_id = detail::ConnectionId(std::move(hostname), std::move(port)); + auto connection_id = utils::net::ConnectionId(std::move(hostname), std::move(port)); std::shared_ptr handler; if (!connections_ || !connections_->contains(connection_id)) { if (ssl_context_) diff --git a/extensions/standard-processors/processors/PutTCP.h b/extensions/standard-processors/processors/PutTCP.h index 58ae28f94a..113a38d2f2 100644 --- a/extensions/standard-processors/processors/PutTCP.h +++ b/extensions/standard-processors/processors/PutTCP.h @@ -31,39 +31,12 @@ #include "utils/expected.h" #include "utils/StringUtils.h" // for string <=> on libc++ +#include "utils/net/AsioSocketUtils.h" #include #include #include -namespace org::apache::nifi::minifi::processors::detail { - -class ConnectionId { - public: - ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), port_(std::move(port)) {} - - auto operator<=>(const ConnectionId&) const = default; - - [[nodiscard]] std::string_view getHostname() const { return hostname_; } - [[nodiscard]] std::string_view getPort() const { return port_; } - - private: - std::string hostname_; - std::string port_; -}; -} // namespace 
org::apache::nifi::minifi::processors::detail - -namespace std { -template<> -struct hash { - size_t operator()(const org::apache::nifi::minifi::processors::detail::ConnectionId& connection_id) const { - return org::apache::nifi::minifi::utils::hash_combine( - std::hash{}(connection_id.getHostname()), - std::hash{}(connection_id.getPort())); - } -}; -} // namespace std - namespace org::apache::nifi::minifi::processors { class ConnectionHandlerBase { public: @@ -128,7 +101,7 @@ class PutTCP final : public core::Processor { std::vector delimiter_; asio::io_context io_context_; - std::optional>> connections_; + std::optional>> connections_; std::optional idle_connection_expiration_; std::optional max_size_of_socket_send_buffer_; std::chrono::milliseconds timeout_duration_ = std::chrono::seconds(15); diff --git a/extensions/standard-processors/processors/TailFile.cpp b/extensions/standard-processors/processors/TailFile.cpp index 97bcfa04ab..12281a21af 100644 --- a/extensions/standard-processors/processors/TailFile.cpp +++ b/extensions/standard-processors/processors/TailFile.cpp @@ -49,6 +49,7 @@ #include "core/PropertyBuilder.h" #include "core/Resource.h" #include "utils/RegexUtils.h" +#include "utils/expected.h" namespace org::apache::nifi::minifi::processors { @@ -173,19 +174,6 @@ uint64_t readOptionalUint64(const Container &container, const Key &key) { } } -// the delimiter is the first character of the input, allowing some escape sequences -std::string parseDelimiter(const std::string &input) { - if (input.empty()) return ""; - if (input[0] != '\\') return std::string{ input[0] }; - if (input.size() == std::size_t{1}) return "\\"; - switch (input[1]) { - case 'r': return "\r"; - case 't': return "\t"; - case 'n': return "\n"; - default: return std::string{ input[1] }; - } -} - std::map update_keys_in_legacy_states(const std::map &legacy_tail_states) { std::map new_tail_states; for (const auto &key_value_pair : legacy_tail_states) { @@ -326,6 +314,19 @@ class 
WholeFileReaderCallback { std::ifstream input_stream_; std::shared_ptr logger_ = core::logging::LoggerFactory::getLogger(); }; + +// This is for backwards compatibility only, as it will accept any string as Input Delimiter while only use the first character from it, which can be confusing +std::optional getDelimiterOld(const std::string& delimiter_str) { + if (delimiter_str.empty()) return std::nullopt; + if (delimiter_str[0] != '\\') return delimiter_str[0]; + if (delimiter_str.size() == 1) return '\\'; + switch (delimiter_str[1]) { + case 'r': return '\r'; + case 't': return '\t'; + case 'n': return '\n'; + default: return delimiter_str[1]; + } +} } // namespace void TailFile::initialize() { @@ -343,10 +344,14 @@ void TailFile::onSchedule(const std::shared_ptr &context, throw Exception(PROCESSOR_EXCEPTION, "Failed to get StateManager"); } - std::string value; - - if (context->getProperty(Delimiter.getName(), value)) { - delimiter_ = parseDelimiter(value); + if (auto delimiter_str = context->getProperty(Delimiter)) { + if (auto parsed_delimiter = utils::StringUtils::parseCharacter(*delimiter_str)) { + delimiter_ = *parsed_delimiter; + } else { + logger_->log_error("Invalid %s: \"%s\" (it should be a single character, whether escaped or not). 
Using the first character as the %s", + TailFile::Delimiter.getName(), *delimiter_str, TailFile::Delimiter.getName()); + delimiter_ = getDelimiterOld(*delimiter_str); + } } std::string file_name_str; @@ -788,12 +793,11 @@ void TailFile::processSingleFile(const std::shared_ptr &se if (extension.starts_with('.')) extension.erase(extension.begin()); - if (!delimiter_.empty()) { - char delim = delimiter_[0]; - logger_->log_trace("Looking for delimiter 0x%X", delim); + if (delimiter_) { + logger_->log_trace("Looking for delimiter 0x%X", *delimiter_); std::size_t num_flow_files = 0; - FileReaderCallback file_reader{full_file_name, state.position_, delim, state.checksum_}; + FileReaderCallback file_reader{full_file_name, state.position_, *delimiter_, state.checksum_}; TailState state_copy{state}; while (file_reader.hasMoreToRead() && (!batch_size_ || *batch_size_ > num_flow_files)) { diff --git a/extensions/standard-processors/processors/TailFile.h b/extensions/standard-processors/processors/TailFile.h index 286ec6b8bc..705717dd2b 100644 --- a/extensions/standard-processors/processors/TailFile.h +++ b/extensions/standard-processors/processors/TailFile.h @@ -200,7 +200,7 @@ class TailFile : public core::Processor { static const char *POSITION_STR; static const int BUFFER_SIZE = 512; - std::string delimiter_; // Delimiter for the data incoming from the tailed file. + std::optional delimiter_; // Delimiter for the data incoming from the tailed file. 
core::StateManager* state_manager_ = nullptr; std::map tail_states_; Mode tail_mode_ = Mode::UNDEFINED; diff --git a/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp b/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp index 4dc638c5bc..68492d5bb1 100644 --- a/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp +++ b/extensions/standard-processors/tests/integration/SecureSocketGetTCPTest.cpp @@ -77,7 +77,7 @@ class SecureSocketTest : public IntegrationBase { void runAssertions() override { using org::apache::nifi::minifi::utils::verifyLogLinePresenceInPollTime; - assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "SSL socket connect success")); + assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "Accepted on")); isRunning_ = false; server_socket_.reset(); assert(verifyLogLinePresenceInPollTime(std::chrono::milliseconds(wait_time_), "send succeed 20")); diff --git a/extensions/standard-processors/tests/unit/GetTCPTests.cpp b/extensions/standard-processors/tests/unit/GetTCPTests.cpp index a267b67a5f..6c537d8df5 100644 --- a/extensions/standard-processors/tests/unit/GetTCPTests.cpp +++ b/extensions/standard-processors/tests/unit/GetTCPTests.cpp @@ -1,5 +1,4 @@ /** - * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,391 +14,285 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include #include -#include -#include -#include "unit/ProvenanceTestHelper.h" -#include "TestBase.h" -#include "Catch.h" -#include "RandomServerSocket.h" -#include "Scheduling.h" -#include "LogAttribute.h" -#include "GetTCP.h" -#include "core/Core.h" -#include "core/FlowFile.h" -#include "core/Processor.h" -#include "core/ProcessContext.h" -#include "core/ProcessSession.h" -#include "core/ProcessorNode.h" -#include "core/reporting/SiteToSiteProvenanceReportingTask.h" - -TEST_CASE("GetTCPWithoutEOM", "[GetTCP1]") { - TestController testController; - std::vector buffer; - for (auto c : "Hello World\nHello Warld\nGoodByte Cruel world") { - buffer.push_back(c); - } - std::shared_ptr content_repo = std::make_shared(); - - content_repo->initialize(std::make_shared()); - - std::shared_ptr stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared()); - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); - - LogTestController::getInstance().setDebug(); - LogTestController::getInstance().setDebug(); - LogTestController::getInstance().setTrace(); - - std::shared_ptr repo = std::make_shared(); - - auto processor = std::make_unique("gettcpexample"); - - auto logAttribute = std::make_unique("logattribute"); - - processor->setStreamFactory(stream_factory); - processor->initialize(); - - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); - - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); - - REQUIRE(processoruuid.to_string() != logattribute_uuid.to_string()); - - auto connection = std::make_unique(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("success", "description")); - - auto connection2 = std::make_unique(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("success", "description")); - - // link the connections so that we can 
test results at the end for this - connection->setSource(processor.get()); - - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); - connection2->setSource(logAttribute.get()); - - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); - - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); - - auto node = std::make_shared(processor.get()); - auto node2 = std::make_shared(logAttribute.get()); - auto context = std::make_shared(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - auto session = std::make_shared(context); - auto session2 = std::make_shared(context2); - - REQUIRE(processor->getName() == "gettcpexample"); - - std::shared_ptr record; - processor->setScheduledState(core::ScheduledState::RUNNING); - - std::shared_ptr factory = std::make_shared(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); - - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr factory2 = std::make_shared(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); +#include "Catch.h" +#include "processors/GetTCP.h" 
+#include "SingleProcessorTestController.h" +#include "Utils.h" +#include "utils/net/AsioCoro.h" +#include "utils/net/AsioSocketUtils.h" +#include "controllers/SSLContextService.h" +#include "range/v3/algorithm/contains.hpp" +#include "utils/gsl.h" - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); +using GetTCP = org::apache::nifi::minifi::processors::GetTCP; - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, session); - reporter = session->getProvenanceReporter(); +using namespace std::literals::chrono_literals; - session->commit(); +namespace org::apache::nifi::minifi::test { - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); +void check_for_attributes(core::FlowFile& flow_file, uint16_t port) { + const auto local_addresses = {"127.0.0.1:" + std::to_string(port), "::ffff:127.0.0.1:" + std::to_string(port), "::1:" + std::to_string(port)}; + CHECK(ranges::contains(local_addresses, flow_file.getAttribute(GetTCP::SourceEndpoint.getName()))); +} - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Size:45 Offset:0")); +minifi::utils::net::SslData createSslDataForServer() { + const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); + minifi::utils::net::SslData ssl_data; + ssl_data.ca_loc = (executable_dir / "resources" / "ca_A.crt").string(); + ssl_data.cert_loc = (executable_dir / "resources" / "localhost_by_A.pem").string(); + ssl_data.key_loc = (executable_dir / "resources" / "localhost_by_A.pem").string(); + return ssl_data; +} - LogTestController::getInstance().reset(); +void 
addSslContextServiceTo(SingleProcessorTestController& controller) { + auto ssl_context_service = controller.plan->addController("SSLContextService", "SSLContextService"); + LogTestController::getInstance().setTrace(); + const auto executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::CACertificate.getName(), (executable_dir / "resources" / "ca_A.crt").string())); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::ClientCertificate.getName(), (executable_dir / "resources" / "alice_by_A.pem").string())); + REQUIRE(controller.plan->setProperty(ssl_context_service, controllers::SSLContextService::PrivateKey.getName(), (executable_dir / "resources" / "alice_by_A.pem").string())); + ssl_context_service->enable(); } -TEST_CASE("GetTCPWithOEM", "[GetTCP2]") { - std::vector buffer; - for (auto c : "Hello World\nHello Warld\nGoodByte Cruel world") { - buffer.push_back(c); +class TcpTestServer { + public: + void run() { + server_thread_ = std::thread([&]() { + asio::co_spawn(io_context_, listenAndSendMessages(), asio::detached); + io_context_.run(); + }); } - std::shared_ptr content_repo = std::make_shared(); - - content_repo->initialize(std::make_shared()); - - std::shared_ptr stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared()); - - TestController testController; - - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); - - LogTestController::getInstance().setDebug(); - LogTestController::getInstance().setTrace(); - LogTestController::getInstance().setTrace(); - LogTestController::getInstance().setTrace(); - LogTestController::getInstance().setTrace(); - - std::shared_ptr repo = std::make_shared(); - - std::shared_ptr processor = std::make_shared("gettcpexample"); - std::shared_ptr logAttribute = std::make_shared("logattribute"); - - 
processor->setStreamFactory(stream_factory); - processor->initialize(); - - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); - - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); - - auto connection = std::make_unique(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("partial", "description")); - - auto connection2 = std::make_unique(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("partial", "description")); - - // link the connections so that we can test results at the end for this - connection->setSource(processor.get()); - - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); - - connection2->setSource(logAttribute.get()); - - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); - - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); - - auto node = std::make_shared(processor.get()); - auto node2 = std::make_shared(logAttribute.get()); - auto context = std::make_shared(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - // we're using new lines above - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte, "10"); - auto session = std::make_shared(context); - auto session2 = 
std::make_shared(context2); - - - REQUIRE(processor->getName() == "gettcpexample"); - - std::shared_ptr record; - processor->setScheduledState(core::ScheduledState::RUNNING); - - std::shared_ptr factory = std::make_shared(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); + void queueMessage(std::string message) { + messages_to_send_.enqueue(std::move(message)); + } - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr factory2 = std::make_shared(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); + void enableSSL() { + const std::filesystem::path executable_dir = minifi::utils::file::FileUtils::get_executable_dir(); - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + asio::ssl::context ssl_context(asio::ssl::context::tls_server); + ssl_context.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + ssl_context.set_password_callback([key_pw = "Password12"](std::size_t&, asio::ssl::context_base::password_purpose&) { return key_pw; }); + ssl_context.use_certificate_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem); + ssl_context.use_private_key_file((executable_dir / "resources" / "localhost_by_A.pem").string(), asio::ssl::context::pem); + ssl_context.load_verify_file((executable_dir / "resources" / "ca_A.crt").string()); + ssl_context.set_verify_mode(asio::ssl::verify_peer); - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, 
session); - reporter = session->getProvenanceReporter(); + ssl_context_ = std::move(ssl_context); + } - session->commit(); + uint16_t getPort() const { + return port_; + } - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + ~TcpTestServer() { + io_context_.stop(); + if (server_thread_.joinable()) + server_thread_.join(); + } - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + private: + asio::awaitable sendMessages(auto& socket) { + while (true) { + std::string message_to_send; + if (!messages_to_send_.tryDequeue(message_to_send)) { + co_await minifi::utils::net::async_wait(10ms); + continue; + } + co_await asio::async_write(socket, asio::buffer(message_to_send), minifi::utils::net::use_nothrow_awaitable); + } + } - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == LogTestController::getInstance().contains("Size:11 Offset:0")); - REQUIRE(true == LogTestController::getInstance().contains("Size:12 Offset:0")); - REQUIRE(true == LogTestController::getInstance().contains("Size:22 Offset:0")); + asio::awaitable secureSession(asio::ip::tcp::socket socket) { + gsl_Expects(ssl_context_); + minifi::utils::net::SslSocket ssl_socket(std::move(socket), *ssl_context_); + auto [handshake_error] = co_await ssl_socket.async_handshake(minifi::utils::net::HandshakeType::server, minifi::utils::net::use_nothrow_awaitable); + if (handshake_error) { + co_return; + } + co_await sendMessages(ssl_socket); + } - LogTestController::getInstance().reset(); -} + asio::awaitable insecureSession(asio::ip::tcp::socket socket) { + co_await sendMessages(socket); + } -TEST_CASE("GetTCPWithOnlyOEM", "[GetTCP3]") { - std::vector buffer; - for (auto c : "\n") { - buffer.push_back(c); + asio::awaitable listenAndSendMessages() { + 
asio::ip::tcp::acceptor acceptor(io_context_, asio::ip::tcp::endpoint(asio::ip::tcp::v6(), port_)); + if (port_ == 0) + port_ = acceptor.local_endpoint().port(); + while (true) { + auto [accept_error, socket] = co_await acceptor.async_accept(minifi::utils::net::use_nothrow_awaitable); + if (accept_error) { + co_return; + } + if (ssl_context_) + co_spawn(io_context_, secureSession(std::move(socket)), asio::detached); + else + co_spawn(io_context_, insecureSession(std::move(socket)), asio::detached); + } } - std::shared_ptr content_repo = std::make_shared(); + std::optional ssl_context_; + minifi::utils::ConcurrentQueue messages_to_send_; + std::atomic port_ = 0; + std::thread server_thread_; + asio::io_context io_context_; +}; - content_repo->initialize(std::make_shared()); +TEST_CASE("GetTCP test with delimiter", "[GetTCP]") { + const auto get_tcp = std::make_shared("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); - std::shared_ptr stream_factory = minifi::io::StreamFactory::getInstance(std::make_shared()); - TestController testController; + TcpTestServer tcp_test_server; - LogTestController::getInstance().setDebug(); + SECTION("No SSL") {} - org::apache::nifi::minifi::io::RandomServerSocket server(org::apache::nifi::minifi::io::Socket::getMyHostName()); + SECTION("SSL") { + addSslContextServiceTo(controller); + tcp_test_server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - LogTestController::getInstance().setDebug(); + tcp_test_server.queueMessage("Hello\n"); + tcp_test_server.run(); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return tcp_test_server.getPort() != 0; }, 20ms)); - LogTestController::getInstance().setDebug(); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", tcp_test_server.getPort()))); + 
controller.plan->scheduleProcessor(get_tcp); - std::shared_ptr repo = std::make_shared(); + ProcessorTriggerResult result; + REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms)); + CHECK(controller.plan->getContent(result.at(GetTCP::Success)[0]) == "Hello\n"); - std::shared_ptr processor = std::make_shared("gettcpexample"); + check_for_attributes(*result.at(GetTCP::Success)[0], tcp_test_server.getPort()); +} - std::shared_ptr logAttribute = std::make_shared("logattribute"); +TEST_CASE("GetTCP test with too large message", "[GetTCP]") { + const auto get_tcp = std::make_shared("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); + REQUIRE(get_tcp->setProperty(GetTCP::MaxMessageSize, "10")); + REQUIRE(get_tcp->setProperty(GetTCP::MessageDelimiter, "\r")); - processor->setStreamFactory(stream_factory); - processor->initialize(); + TcpTestServer tcp_test_server; - utils::Identifier processoruuid = processor->getUUID(); - REQUIRE(processoruuid); + SECTION("No SSL") {} - utils::Identifier logattribute_uuid = logAttribute->getUUID(); - REQUIRE(logattribute_uuid); + SECTION("SSL") { + addSslContextServiceTo(controller); + tcp_test_server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - auto connection = std::make_unique(repo, content_repo, "gettcpexampleConnection"); - connection->addRelationship(core::Relationship("success", "description")); + tcp_test_server.queueMessage("abcdefghijklmnopqrstuvwxyz\rBye\r"); + tcp_test_server.run(); - auto connection2 = std::make_unique(repo, content_repo, "logattribute"); - connection2->addRelationship(core::Relationship("success", "description")); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return tcp_test_server.getPort() != 0; }, 20ms)); - // link the connections so that we can test results at the end for this - 
connection->setSource(processor.get()); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", tcp_test_server.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - // link the connections so that we can test results at the end for this - connection->setDestination(logAttribute.get()); + ProcessorTriggerResult result; + REQUIRE(controller.triggerUntil({{GetTCP::Success, 1}}, result, 1s, 50ms)); + REQUIRE(result.at(GetTCP::Partial).size() == 3); + REQUIRE(result.at(GetTCP::Success).size() == 1); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[0]) == "abcdefghij"); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[1]) == "klmnopqrst"); + CHECK(controller.plan->getContent(result.at(GetTCP::Partial)[2]) == "uvwxyz\r"); + CHECK(controller.plan->getContent(result.at(GetTCP::Success)[0]) == "Bye\r"); - connection2->setSource(logAttribute.get()); + check_for_attributes(*result.at(GetTCP::Partial)[0], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Partial)[1], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Partial)[2], tcp_test_server.getPort()); + check_for_attributes(*result.at(GetTCP::Success)[0], tcp_test_server.getPort()); +} - connection2->setSourceUUID(logattribute_uuid); - connection->setSourceUUID(processoruuid); - connection->setDestinationUUID(logattribute_uuid); +TEST_CASE("GetTCP test multiple endpoints", "[GetTCP]") { + const auto get_tcp = std::make_shared("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "2")); - processor->addConnection(connection.get()); - logAttribute->addConnection(connection.get()); - logAttribute->addConnection(connection2.get()); + TcpTestServer server_1; + TcpTestServer server_2; - auto node = std::make_shared(processor.get()); - auto node2 = std::make_shared(logAttribute.get()); - auto context = 
std::make_shared(node, nullptr, repo, repo, content_repo); - auto context2 = std::make_shared(node2, nullptr, repo, repo, content_repo); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndpointList, org::apache::nifi::minifi::io::Socket::getMyHostName() + ":" + std::to_string(server.getPort())); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval, "200 msec"); - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit, "10"); - // we're using new lines above - context->setProperty(org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte, "10"); - auto session = std::make_shared(context); - auto session2 = std::make_shared(context2); + SECTION("No SSL") {} + SECTION("SSL") { + addSslContextServiceTo(controller); + server_1.enableSSL(); + server_2.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } - REQUIRE(processor->getName() == "gettcpexample"); + server_1.queueMessage("abcdefghijklmnopqrstuvwxyz\nBye\n"); + server_1.run(); - std::shared_ptr record; - processor->setScheduledState(core::ScheduledState::RUNNING); + server_2.queueMessage("012345678901234567890\nAuf Wiedersehen\n"); + server_2.run(); - std::shared_ptr factory = std::make_shared(context); - processor->onSchedule(context, factory); - processor->onTrigger(context, session); - server.write(buffer, buffer.size()); - std::this_thread::sleep_for(std::chrono::seconds(2)); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server_1.getPort() != 0 && server_2.getPort() != 0; }, 20ms)); - logAttribute->initialize(); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - std::shared_ptr factory2 = std::make_shared(context2); - logAttribute->onSchedule(context2, factory2); - logAttribute->onTrigger(context2, session2); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, 
fmt::format("localhost:{},localhost:{}", server_1.getPort(), server_2.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - auto reporter = session->getProvenanceReporter(); - auto records = reporter->getEvents(); - record = session->get(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + ProcessorTriggerResult result; + CHECK(controller.triggerUntil({{GetTCP::Success, 4}}, result, 1s, 50ms)); + CHECK(result.at(GetTCP::Success).size() == 4); - processor->incrementActiveTasks(); - processor->setScheduledState(core::ScheduledState::RUNNING); - processor->onTrigger(context, session); - reporter = session->getProvenanceReporter(); + std::vector success_flow_file_contents; + for (const auto& flow_file: result.at(GetTCP::Success)) { + success_flow_file_contents.push_back(controller.plan->getContent(flow_file)); + } - session->commit(); + CHECK(ranges::contains(success_flow_file_contents, "abcdefghijklmnopqrstuvwxyz\n")); + CHECK(ranges::contains(success_flow_file_contents, "Bye\n")); + CHECK(ranges::contains(success_flow_file_contents, "012345678901234567890\n")); + CHECK(ranges::contains(success_flow_file_contents, "Auf Wiedersehen\n")); +} - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); +TEST_CASE("GetTCP max queue and max batch size test", "[GetTCP]") { + const auto get_tcp = std::make_shared("GetTCP"); + SingleProcessorTestController controller{get_tcp}; + LogTestController::getInstance().setTrace(); + REQUIRE(get_tcp->setProperty(GetTCP::MaxBatchSize, "10")); + REQUIRE(get_tcp->setProperty(GetTCP::MaxQueueSize, "50")); - logAttribute->incrementActiveTasks(); - logAttribute->setScheduledState(core::ScheduledState::RUNNING); - logAttribute->onTrigger(context2, session2); + TcpTestServer server; - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); - REQUIRE(true == 
LogTestController::getInstance().contains("Size:2 Offset:0")); - LogTestController::getInstance().reset(); -} + SECTION("No SSL") {} + SECTION("SSL") { + addSslContextServiceTo(controller); + server.enableSSL(); + REQUIRE(get_tcp->setProperty(GetTCP::SSLContextService, "SSLContextService")); + } -TEST_CASE("GetTCPEmptyNoConnect", "[GetTCP3]") { - TestController testController; - LogTestController::getInstance().setDebug(); - LogTestController::getInstance().setDebug(); - LogTestController::getInstance().setTrace(); + LogTestController::getInstance().setWarn(); - std::shared_ptr plan = testController.createPlan(); - std::shared_ptr getfile = plan->addProcessor("GetTCP", "gettcpexample"); + for (auto i = 0; i < 100; ++i) { + server.queueMessage("some_message\n"); + } - plan->addProcessor("LogAttribute", "logattribute", core::Relationship("success", "description"), true); + server.run(); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::EndpointList.getName(), org::apache::nifi::minifi::io::Socket::getMyHostName() + ":9182"); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::ReconnectInterval.getName(), "200 msec"); - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::ConnectionAttemptLimit.getName(), "10"); - // we're using new lines above - plan->setProperty(getfile, org::apache::nifi::minifi::processors::GetTCP::EndOfMessageByte.getName(), "10"); + REQUIRE(minifi::utils::verifyEventHappenedInPollTime(250ms, [&] { return server.getPort() != 0; }, 20ms)); - TestController::runSession(plan, false); - auto records = plan->getProvenanceRecords(); - std::shared_ptr record = plan->getCurrentFlowFile(); - REQUIRE(record == nullptr); - REQUIRE(records.empty()); + REQUIRE(get_tcp->setProperty(GetTCP::EndpointList, fmt::format("localhost:{}", server.getPort()))); + controller.plan->scheduleProcessor(get_tcp); - REQUIRE(true == LogTestController::getInstance().contains("Reconnect interval is 200 ms")); 
- REQUIRE(true == LogTestController::getInstance().contains("Could not create socket during initialization for " + org::apache::nifi::minifi::io::Socket::getMyHostName() + ":9182")); - LogTestController::getInstance().reset(); + CHECK(utils::countLogOccurrencesUntil("Queue is full. TCP message ignored.", 50, 300ms, 50ms)); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).size() == 10); + CHECK(controller.trigger().at(GetTCP::Success).empty()); } +} // namespace org::apache::nifi::minifi::test diff --git a/libminifi/include/utils/StringUtils.h b/libminifi/include/utils/StringUtils.h index 9effff8ea2..0001f88d8b 100644 --- a/libminifi/include/utils/StringUtils.h +++ b/libminifi/include/utils/StringUtils.h @@ -32,6 +32,7 @@ #include #include #endif +#include "utils/expected.h" #include "utils/FailurePolicy.h" #include "utils/gsl.h" #include "utils/span.h" @@ -493,6 +494,10 @@ class StringUtils { static bool matchesSequence(std::string_view str, const std::vector& patterns); static bool splitToValueAndUnit(std::string_view input, int64_t& value, std::string& unit); + + struct ParseError {}; + + static nonstd::expected, ParseError> parseCharacter(const std::string &input); }; } // namespace org::apache::nifi::minifi::utils diff --git a/libminifi/include/utils/net/AsioCoro.h b/libminifi/include/utils/net/AsioCoro.h index 5c2e5268b4..55a3a4cbcc 100644 --- a/libminifi/include/utils/net/AsioCoro.h +++ b/libminifi/include/utils/net/AsioCoro.h @@ -35,10 +35,6 @@ namespace org::apache::nifi::minifi::utils::net { constexpr auto use_nothrow_awaitable = asio::experimental::as_tuple(asio::use_awaitable); -using HandshakeType = asio::ssl::stream_base::handshake_type; -using TcpSocket = asio::ip::tcp::socket; -using SslSocket = 
asio::ssl::stream; - #if defined(__GNUC__) && __GNUC__ < 11 // [coroutines] unexpected 'warning: statement has no effect [-Wunused-value]' // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=96749 @@ -52,19 +48,17 @@ inline asio::awaitable async_wait(asio::steady_timer& timer) { #pragma GCC diagnostic pop #endif // defined(__GNUC__) && __GNUC__ < 11 -namespace detail { -inline asio::awaitable timeout(std::chrono::steady_clock::duration duration) { +inline asio::awaitable async_wait(std::chrono::steady_clock::duration duration) { asio::steady_timer timer(co_await asio::this_coro::executor); // NOLINT timer.expires_after(duration); co_await async_wait(timer); } -} // namespace detail template asio::awaitable> asyncOperationWithTimeout(asio::awaitable>&& async_operation, std::chrono::steady_clock::duration timeout_duration) { using asio::experimental::awaitable_operators::operator||; - auto operation_result = co_await(std::move(async_operation) || detail::timeout(timeout_duration)); + auto operation_result = co_await(std::move(async_operation) || async_wait(timeout_duration)); // NOLINT if (operation_result.index() == 1) { std::tuple result; std::get<0>(result) = asio::error::timed_out; diff --git a/libminifi/include/utils/net/AsioSocketUtils.h b/libminifi/include/utils/net/AsioSocketUtils.h new file mode 100644 index 0000000000..9ae531232e --- /dev/null +++ b/libminifi/include/utils/net/AsioSocketUtils.h @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "asio/ssl.hpp" +#include "asio/ip/tcp.hpp" + +#include "AsioCoro.h" +#include "utils/Hash.h" +#include "utils/StringUtils.h" // for string <=> on libc++ +#include "controllers/SSLContextService.h" + + +namespace org::apache::nifi::minifi::utils::net { + +using HandshakeType = asio::ssl::stream_base::handshake_type; +using TcpSocket = asio::ip::tcp::socket; +using SslSocket = asio::ssl::stream; + +class ConnectionId { + public: + ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), service_(std::move(port)) {} + ConnectionId(const ConnectionId& connection_id) = default; + ConnectionId(ConnectionId&& connection_id) = default; + + auto operator<=>(const ConnectionId&) const = default; + + [[nodiscard]] std::string_view getHostname() const { return hostname_; } + [[nodiscard]] std::string_view getService() const { return service_; } + + private: + std::string hostname_; + std::string service_; +}; + +template +asio::awaitable> handshake(SocketType&, asio::steady_timer::duration) = delete; +template<> +asio::awaitable> handshake(TcpSocket&, asio::steady_timer::duration); +template<> +asio::awaitable> handshake(SslSocket& socket, asio::steady_timer::duration); + + +asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service); +} // namespace org::apache::nifi::minifi::utils::net + +namespace std { +template<> +struct hash { + size_t operator()(const org::apache::nifi::minifi::utils::net::ConnectionId& connection_id) const { + 
return org::apache::nifi::minifi::utils::hash_combine( + std::hash{}(connection_id.getHostname()), + std::hash{}(connection_id.getService())); + } +}; +} // namespace std diff --git a/libminifi/include/utils/net/Message.h b/libminifi/include/utils/net/Message.h new file mode 100644 index 0000000000..2cc05f7170 --- /dev/null +++ b/libminifi/include/utils/net/Message.h @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include + +#include "IpProtocol.h" +#include "asio/ts/internet.hpp" + +namespace org::apache::nifi::minifi::utils::net { + +struct Message { + public: + Message() = default; + Message(std::string message_data, IpProtocol protocol, asio::ip::address sender_address, asio::ip::port_type server_port) + : message_data(std::move(message_data)), + protocol(protocol), + server_port(server_port), + sender_address(std::move(sender_address)) { + } + + bool is_partial = false; + std::string message_data; + IpProtocol protocol; + asio::ip::port_type server_port; + asio::ip::address sender_address; +}; + +} // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/include/utils/net/Server.h b/libminifi/include/utils/net/Server.h index e84815c3b4..b36936ef61 100644 --- a/libminifi/include/utils/net/Server.h +++ b/libminifi/include/utils/net/Server.h @@ -25,30 +25,13 @@ #include "utils/MinifiConcurrentQueue.h" #include "core/logging/Logger.h" #include "asio/ts/buffer.hpp" -#include "asio/ts/internet.hpp" #include "asio/awaitable.hpp" #include "asio/co_spawn.hpp" #include "asio/detached.hpp" -#include "IpProtocol.h" +#include "Message.h" namespace org::apache::nifi::minifi::utils::net { -struct Message { - public: - Message() = default; - Message(std::string message_data, IpProtocol protocol, asio::ip::address sender_address, asio::ip::port_type server_port) - : message_data(std::move(message_data)), - protocol(protocol), - server_port(server_port), - sender_address(std::move(sender_address)) { - } - - std::string message_data; - IpProtocol protocol; - asio::ip::port_type server_port; - asio::ip::address sender_address; -}; - class Server { public: virtual void run() { diff --git a/libminifi/src/utils/StringUtils.cpp b/libminifi/src/utils/StringUtils.cpp index 0b02175698..9daf145de0 100644 --- a/libminifi/src/utils/StringUtils.cpp +++ b/libminifi/src/utils/StringUtils.cpp @@ -517,4 +517,24 @@ bool 
StringUtils::splitToValueAndUnit(std::string_view input, int64_t& value, st return true; } +nonstd::expected, StringUtils::ParseError> StringUtils::parseCharacter(const std::string &input) { + if (input.empty()) { return std::nullopt; } + if (input.size() == 1) { return input[0]; } + + if (input.size() == 2 && input.starts_with('\\')) { + switch (input[1]) { + case '0': return '\0'; // Null + case 'a': return '\a'; // Bell + case 'b': return '\b'; // Backspace + case 't': return '\t'; // Horizontal Tab + case 'n': return '\n'; // Line Feed + case 'v': return '\v'; // Vertical Tab + case 'f': return '\f'; // Form Feed + case 'r': return '\r'; // Carriage Return + default: return input[1]; + } + } + return nonstd::make_unexpected(ParseError{}); +} + } // namespace org::apache::nifi::minifi::utils diff --git a/libminifi/src/utils/net/AsioSocketUtils.cpp b/libminifi/src/utils/net/AsioSocketUtils.cpp new file mode 100644 index 0000000000..8eeb61ffe7 --- /dev/null +++ b/libminifi/src/utils/net/AsioSocketUtils.cpp @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/net/AsioSocketUtils.h" +#include "controllers/SSLContextService.h" + +namespace org::apache::nifi::minifi::utils::net { + +template<> +asio::awaitable> handshake(TcpSocket&, asio::steady_timer::duration) { + co_return std::error_code(); +} + +template<> +asio::awaitable> handshake(SslSocket& socket, asio::steady_timer::duration timeout_duration) { + co_return co_await asyncOperationWithTimeout(socket.async_handshake(HandshakeType::client, use_nothrow_awaitable), timeout_duration); // NOLINT +} + +asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service) { + asio::ssl::context ssl_context(asio::ssl::context::tls_client); + ssl_context.set_options(asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + ssl_context.load_verify_file(ssl_context_service.getCACertificate().string()); + ssl_context.set_verify_mode(asio::ssl::verify_peer); + ssl_context.set_password_callback([password = ssl_context_service.getPassphrase()](std::size_t&, asio::ssl::context_base::password_purpose&) { return password; }); + if (const auto& cert_file = ssl_context_service.getCertificateFile(); !cert_file.empty()) + ssl_context.use_certificate_file(cert_file.string(), asio::ssl::context::pem); + if (const auto& private_key_file = ssl_context_service.getPrivateKeyFile(); !private_key_file.empty()) + ssl_context.use_private_key_file(private_key_file.string(), asio::ssl::context::pem); + return ssl_context; +} +} // namespace org::apache::nifi::minifi::utils::net diff --git a/libminifi/src/utils/net/TcpServer.cpp b/libminifi/src/utils/net/TcpServer.cpp index b1fa06b200..c443bf347b 100644 --- a/libminifi/src/utils/net/TcpServer.cpp +++ b/libminifi/src/utils/net/TcpServer.cpp @@ -16,6 +16,9 @@ */ #include "utils/net/TcpServer.h" #include "utils/net/AsioCoro.h" +#include "utils/net/AsioSocketUtils.h" + +using namespace std::literals::chrono_literals; namespace org::apache::nifi::minifi::utils::net { @@ -27,7 +30,8 @@ asio::awaitable 
TcpServer::doReceive() { auto [accept_error, socket] = co_await acceptor.async_accept(use_nothrow_awaitable); if (accept_error) { logger_->log_error("Error during accepting new connection: %s", accept_error.message()); - break; + co_await utils::net::async_wait(1s); + continue; } if (ssl_data_) co_spawn(io_context_, secureSession(std::move(socket)), asio::detached); diff --git a/libminifi/test/resources/TestC2Metrics.yml b/libminifi/test/resources/TestC2Metrics.yml index ea3b7eb745..6a0af5e4c1 100644 --- a/libminifi/test/resources/TestC2Metrics.yml +++ b/libminifi/test/resources/TestC2Metrics.yml @@ -31,10 +31,10 @@ Processors: run duration nanos: 0 auto-terminated relationships list: Properties: - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -66,4 +66,3 @@ Connections: Controller Services: [] Remote Processing Groups: - diff --git a/libminifi/test/resources/TestGetTCPSecure.yml b/libminifi/test/resources/TestGetTCPSecure.yml index ecf56fea76..15618931b8 100644 --- a/libminifi/test/resources/TestGetTCPSecure.yml +++ b/libminifi/test/resources/TestGetTCPSecure.yml @@ -32,10 +32,8 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: d - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml b/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml index f9dd9a0d62..dfcaa2fdd4 100644 --- 
a/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml +++ b/libminifi/test/resources/TestGetTCPSecureEmptyPass.yml @@ -32,10 +32,10 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:29776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:29776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -88,4 +88,3 @@ Controller Services: - value: nifi-cert.pem Remote Processing Groups: - diff --git a/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml b/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml index c48aa25674..0e68be7827 100644 --- a/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml +++ b/libminifi/test/resources/TestGetTCPSecureWithFilePass.yml @@ -32,10 +32,8 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:18776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:18776 + Message Delimiter: \r - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestGetTCPSecureWithPass.yml b/libminifi/test/resources/TestGetTCPSecureWithPass.yml index c11d76fe64..0393eb6a9d 100644 --- a/libminifi/test/resources/TestGetTCPSecureWithPass.yml +++ b/libminifi/test/resources/TestGetTCPSecureWithPass.yml @@ -32,10 +32,10 @@ Processors: auto-terminated relationships list: Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:28776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:28776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - 
name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute diff --git a/libminifi/test/resources/TestSameProcessorMetrics.yml b/libminifi/test/resources/TestSameProcessorMetrics.yml index 2c842b8b81..4b5a4d69dd 100644 --- a/libminifi/test/resources/TestSameProcessorMetrics.yml +++ b/libminifi/test/resources/TestSameProcessorMetrics.yml @@ -60,10 +60,10 @@ Processors: - partial Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: GetTCP2 id: 2438e3c8-015a-1000-79ca-83af40ec1996 class: org.apache.nifi.processors.standard.GetTCP @@ -78,10 +78,10 @@ Processors: - partial Properties: SSL Context Service: SSLContextService - endpoint-list: localhost:8776 - end-of-message-byte: d - reconnect-interval: 100ms - connection-attempt-timeout: 2000 + Endpoint List: localhost:8776 + Message Delimiter: \r + Reconnection Interval: 100ms + Timeout: 1s - name: LogAttribute id: 2438e3c8-015a-1000-79ca-83af40ec1992 class: org.apache.nifi.processors.standard.LogAttribute @@ -120,4 +120,3 @@ Connections: flowfile expiration: 60 sec Remote Processing Groups: - diff --git a/libminifi/test/resources/encrypted.cn.pass b/libminifi/test/resources/encrypted.cn.pass index 9dd74dac1e..cdbe4f5d3f 100644 --- a/libminifi/test/resources/encrypted.cn.pass +++ b/libminifi/test/resources/encrypted.cn.pass @@ -1 +1 @@ -VsVTmHBzixyA9UfTCttRYXus1oMpIxO6jmDXrNrOp5w +VsVTmHBzixyA9UfTCttRYXus1oMpIxO6jmDXrNrOp5w \ No newline at end of file diff --git a/libminifi/test/unit/StringUtilsTests.cpp b/libminifi/test/unit/StringUtilsTests.cpp index a63c963fb2..6e1f3012e8 100644 --- a/libminifi/test/unit/StringUtilsTests.cpp +++ b/libminifi/test/unit/StringUtilsTests.cpp @@ -584,4 +584,19 @@ TEST_CASE("StringUtils::splitToValueAndUnit 
tests") { } } +TEST_CASE("StringUtils::parseCharacter tests") { + CHECK(StringUtils::parseCharacter("a") == 'a'); + CHECK(StringUtils::parseCharacter("\\n") == '\n'); + CHECK(StringUtils::parseCharacter("\\t") == '\t'); + CHECK(StringUtils::parseCharacter("\\r") == '\r'); + CHECK(StringUtils::parseCharacter("\\s") == 's'); + CHECK(StringUtils::parseCharacter("\\'") == '\''); + CHECK(StringUtils::parseCharacter("\\") == '\\'); + CHECK(StringUtils::parseCharacter("\\?") == '\?'); + + CHECK_FALSE(StringUtils::parseCharacter("abc").has_value()); + CHECK_FALSE(StringUtils::parseCharacter("\\nd").has_value()); + CHECK(StringUtils::parseCharacter("") == std::nullopt); +} + // NOLINTEND(readability-container-size-empty)