From 4b58d55878db55372d1b09de49c6caf363fe3c06 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 6 Dec 2022 13:38:21 +0100 Subject: [PATCH 01/30] test: move the implementation of StaticContentsSock to .cpp Move the implementation (method definitions) from `test/util/net.h` to `test/util/net.cpp` to make the header easier to follow. --- src/test/util/net.cpp | 90 +++++++++++++++++++++++++++++++++++++++++++ src/test/util/net.h | 89 ++++++++---------------------------------- 2 files changed, 107 insertions(+), 72 deletions(-) diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index beefc32bee4b7..0861c2cc09b93 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -137,3 +137,93 @@ std::vector GetRandomNodeEvictionCandidates(int n_candida } return candidates; } + +StaticContentsSock::StaticContentsSock(const std::string& contents) + : Sock{INVALID_SOCKET}, m_contents{contents} +{ +} + +StaticContentsSock::~StaticContentsSock() { m_socket = INVALID_SOCKET; } + +StaticContentsSock& StaticContentsSock::operator=(Sock&& other) +{ + assert(false && "Move of Sock into MockSock not allowed."); + return *this; +} + +ssize_t StaticContentsSock::Send(const void*, size_t len, int) const { return len; } + +ssize_t StaticContentsSock::Recv(void* buf, size_t len, int flags) const +{ + const size_t consume_bytes{std::min(len, m_contents.size() - m_consumed)}; + std::memcpy(buf, m_contents.data() + m_consumed, consume_bytes); + if ((flags & MSG_PEEK) == 0) { + m_consumed += consume_bytes; + } + return consume_bytes; +} + +int StaticContentsSock::Connect(const sockaddr*, socklen_t) const { return 0; } + +int StaticContentsSock::Bind(const sockaddr*, socklen_t) const { return 0; } + +int StaticContentsSock::Listen(int) const { return 0; } + +std::unique_ptr StaticContentsSock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ + if (addr != nullptr) { + // Pretend all connections come from 5.5.5.5:6789 + memset(addr, 0x00, *addr_len); + const socklen_t write_len = static_cast(sizeof(sockaddr_in)); + if (*addr_len >= write_len) { + *addr_len = write_len; + sockaddr_in* addr_in = reinterpret_cast(addr); + addr_in->sin_family = AF_INET; + memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr)); + addr_in->sin_port = htons(6789); + } + } + return std::make_unique(""); +}; + +int StaticContentsSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const +{ + std::memset(opt_val, 0x0, *opt_len); + return 0; +} + +int StaticContentsSock::SetSockOpt(int, int, const void*, socklen_t) const { return 0; } + +int StaticContentsSock::GetSockName(sockaddr* name, socklen_t* name_len) const +{ + std::memset(name, 0x0, *name_len); + return 0; +} + +bool StaticContentsSock::SetNonBlocking() const { return true; } + +bool StaticContentsSock::IsSelectable() const { return true; } + +bool StaticContentsSock::Wait(std::chrono::milliseconds timeout, + Event requested, + Event* occurred) const +{ + if (occurred != nullptr) { + *occurred = requested; + } + return true; +} + +bool StaticContentsSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ + for (auto& [sock, events] : events_per_sock) { + (void)sock; + events.occurred = events.requested; + } + return true; +} + +bool StaticContentsSock::IsConnected(std::string&) const +{ + return true; +} diff --git a/src/test/util/net.h b/src/test/util/net.h index 043e317bf080f..9397dbad3d730 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -141,96 +141,41 @@ constexpr auto ALL_NETWORKS = std::array{ class StaticContentsSock : public Sock { public: - explicit StaticContentsSock(const std::string& contents) - : Sock{INVALID_SOCKET}, - m_contents{contents} - { - } + explicit StaticContentsSock(const std::string& contents); - ~StaticContentsSock() override { m_socket = INVALID_SOCKET; } + ~StaticContentsSock() override; - StaticContentsSock& operator=(Sock&& other) override - { - assert(false && "Move of Sock into MockSock not allowed."); - return *this; - } + StaticContentsSock& operator=(Sock&& other) override; - ssize_t Send(const void*, size_t len, int) const override { return len; } + ssize_t Send(const void*, size_t len, int) const override; - ssize_t Recv(void* buf, size_t len, int flags) const override - { - const size_t consume_bytes{std::min(len, m_contents.size() - m_consumed)}; - std::memcpy(buf, m_contents.data() + m_consumed, consume_bytes); - if ((flags & MSG_PEEK) == 0) { - m_consumed += consume_bytes; - } - return consume_bytes; - } + ssize_t Recv(void* buf, size_t len, int flags) const override; - int Connect(const sockaddr*, socklen_t) const override { return 0; } + int Connect(const sockaddr*, socklen_t) const override; - int Bind(const sockaddr*, socklen_t) const override { return 0; } + int Bind(const sockaddr*, socklen_t) const override; - int Listen(int) const override { return 0; } + int Listen(int) const override; - std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override - { - if (addr != nullptr) { - // Pretend all connections come from 5.5.5.5:6789 - memset(addr, 0x00, *addr_len); - const socklen_t write_len = static_cast(sizeof(sockaddr_in)); - if (*addr_len >= write_len) { - *addr_len = write_len; - sockaddr_in* addr_in = reinterpret_cast(addr); - addr_in->sin_family = AF_INET; - memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr)); - addr_in->sin_port = htons(6789); - } - } - return std::make_unique(""); - }; + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override; - int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override - { - std::memset(opt_val, 0x0, *opt_len); - return 0; - } + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override; - int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; } + int SetSockOpt(int, int, const void*, socklen_t) const override; - int GetSockName(sockaddr* name, socklen_t* name_len) const override - { - std::memset(name, 0x0, *name_len); - return 0; - } + int GetSockName(sockaddr* name, socklen_t* name_len) const override; - bool SetNonBlocking() const override { return true; } + bool SetNonBlocking() const override; - bool IsSelectable() const override { return true; } + bool IsSelectable() const override; bool Wait(std::chrono::milliseconds timeout, Event requested, - Event* occurred = nullptr) const override - { - if (occurred != nullptr) { - *occurred = requested; - } - return true; - } + Event* occurred = nullptr) const override; - bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override - { - for (auto& [sock, events] : events_per_sock) { - (void)sock; - events.occurred = events.requested; - } - return true; - } + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; - bool IsConnected(std::string&) const override - { - return true; - } + bool IsConnected(std::string&) const override; private: const std::string m_contents; From f1864148c4a091afd63be75bc1ff14ae93383523 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 6 Sep 2024 11:16:50 +0200 Subject: [PATCH 02/30] test: put the generic parts from StaticContentsSock into a separate class This allows reusing them in other mocked implementations. --- src/test/util/net.cpp | 83 +++++++++++++++++++++++++------------------ src/test/util/net.h | 40 ++++++++++++++++----- 2 files changed, 79 insertions(+), 44 deletions(-) diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 0861c2cc09b93..77ce3b7585de4 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -138,38 +138,31 @@ std::vector GetRandomNodeEvictionCandidates(int n_candida return candidates; } -StaticContentsSock::StaticContentsSock(const std::string& contents) - : Sock{INVALID_SOCKET}, m_contents{contents} -{ -} +// Have different ZeroSock (or others that inherit from it) objects have different +// m_socket because EqualSharedPtrSock compares m_socket and we want to avoid two +// different objects comparing as equal. +static std::atomic g_mocked_sock_fd{0}; -StaticContentsSock::~StaticContentsSock() { m_socket = INVALID_SOCKET; } +ZeroSock::ZeroSock() : Sock{g_mocked_sock_fd++} {} -StaticContentsSock& StaticContentsSock::operator=(Sock&& other) -{ - assert(false && "Move of Sock into MockSock not allowed."); - return *this; -} +// Sock::~Sock() would try to close(2) m_socket if it is not INVALID_SOCKET, avoid that. +ZeroSock::~ZeroSock() { m_socket = INVALID_SOCKET; } -ssize_t StaticContentsSock::Send(const void*, size_t len, int) const { return len; } +ssize_t ZeroSock::Send(const void*, size_t len, int) const { return len; } -ssize_t StaticContentsSock::Recv(void* buf, size_t len, int flags) const +ssize_t ZeroSock::Recv(void* buf, size_t len, int flags) const { - const size_t consume_bytes{std::min(len, m_contents.size() - m_consumed)}; - std::memcpy(buf, m_contents.data() + m_consumed, consume_bytes); - if ((flags & MSG_PEEK) == 0) { - m_consumed += consume_bytes; - } - return consume_bytes; + memset(buf, 0x0, len); + return len; } -int StaticContentsSock::Connect(const sockaddr*, socklen_t) const { return 0; } +int ZeroSock::Connect(const sockaddr*, socklen_t) const { return 0; } -int StaticContentsSock::Bind(const sockaddr*, socklen_t) const { return 0; } +int ZeroSock::Bind(const sockaddr*, socklen_t) const { return 0; } -int StaticContentsSock::Listen(int) const { return 0; } +int ZeroSock::Listen(int) const { return 0; } -std::unique_ptr StaticContentsSock::Accept(sockaddr* addr, socklen_t* addr_len) const +std::unique_ptr ZeroSock::Accept(sockaddr* addr, socklen_t* addr_len) const { if (addr != nullptr) { // Pretend all connections come from 5.5.5.5:6789 @@ -183,30 +176,28 @@ std::unique_ptr StaticContentsSock::Accept(sockaddr* addr, socklen_t* addr addr_in->sin_port = htons(6789); } } - return std::make_unique(""); -}; + return std::make_unique(); +} -int StaticContentsSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const +int ZeroSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const { std::memset(opt_val, 0x0, *opt_len); return 0; } -int StaticContentsSock::SetSockOpt(int, int, const void*, socklen_t) const { return 0; } +int ZeroSock::SetSockOpt(int, int, const void*, socklen_t) const { return 0; } -int StaticContentsSock::GetSockName(sockaddr* name, socklen_t* name_len) const +int ZeroSock::GetSockName(sockaddr* name, socklen_t* name_len) const { std::memset(name, 0x0, *name_len); return 0; } -bool StaticContentsSock::SetNonBlocking() const { return true; } +bool ZeroSock::SetNonBlocking() const { return true; } -bool StaticContentsSock::IsSelectable() const { return true; } +bool ZeroSock::IsSelectable() const { return true; } -bool StaticContentsSock::Wait(std::chrono::milliseconds timeout, - Event requested, - Event* occurred) const +bool ZeroSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { if (occurred != nullptr) { *occurred = requested; @@ -214,7 +205,7 @@ bool StaticContentsSock::Wait(std::chrono::milliseconds timeout, return true; } -bool StaticContentsSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +bool ZeroSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const { for (auto& [sock, events] : events_per_sock) { (void)sock; @@ -223,7 +214,29 @@ bool StaticContentsSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSo return true; } -bool StaticContentsSock::IsConnected(std::string&) const +ZeroSock& ZeroSock::operator=(Sock&& other) { - return true; + assert(false && "Move of Sock into ZeroSock not allowed."); + return *this; +} + +StaticContentsSock::StaticContentsSock(const std::string& contents) + : m_contents{contents} +{ +} + +ssize_t StaticContentsSock::Recv(void* buf, size_t len, int flags) const +{ + const size_t consume_bytes{std::min(len, m_contents.size() - m_consumed)}; + std::memcpy(buf, m_contents.data() + m_consumed, consume_bytes); + if ((flags & MSG_PEEK) == 0) { + m_consumed += consume_bytes; + } + return consume_bytes; +} + +StaticContentsSock& StaticContentsSock::operator=(Sock&& other) +{ + assert(false && "Move of Sock into StaticContentsSock not allowed."); + return *this; } diff --git a/src/test/util/net.h b/src/test/util/net.h index 9397dbad3d730..20b70cc45448a 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -134,18 +134,15 @@ constexpr auto ALL_NETWORKS = std::array{ }; /** - * A mocked Sock alternative that returns a statically contained data upon read and succeeds - * and ignores all writes. The data to be returned is given to the constructor and when it is - * exhausted an EOF is returned by further reads. + * A mocked Sock alternative that succeeds on all operations. + * Returns infinite amount of 0x0 bytes on reads. */ -class StaticContentsSock : public Sock +class ZeroSock : public Sock { public: - explicit StaticContentsSock(const std::string& contents); - - ~StaticContentsSock() override; + ZeroSock(); - StaticContentsSock& operator=(Sock&& other) override; + ~ZeroSock() override; ssize_t Send(const void*, size_t len, int) const override; @@ -175,9 +172,34 @@ class StaticContentsSock : public Sock bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; - bool IsConnected(std::string&) const override; +private: + ZeroSock& operator=(Sock&& other) override; +}; + +/** + * A mocked Sock alternative that returns a statically contained data upon read and succeeds + * and ignores all writes. The data to be returned is given to the constructor and when it is + * exhausted an EOF is returned by further reads. + */ +class StaticContentsSock : public ZeroSock +{ +public: + explicit StaticContentsSock(const std::string& contents); + + /** + * Return parts of the contents that was provided at construction until it is exhausted + * and then return 0 (EOF). + */ + ssize_t Recv(void* buf, size_t len, int flags) const override; + + bool IsConnected(std::string&) const override + { + return true; + } private: + StaticContentsSock& operator=(Sock&& other) override; + const std::string m_contents; mutable size_t m_consumed{0}; }; From b448b014947093cd217dbde47c8fb9e6c2bc8ba3 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 6 Dec 2022 13:42:03 +0100 Subject: [PATCH 03/30] test: add a mocked Sock that allows inspecting what has been Send() to it And also allows gradually providing the data to be returned by `Recv()` and sending and receiving net messages (`CNetMessage`). --- src/test/util/net.cpp | 168 ++++++++++++++++++++++++++++++++++++++++++ src/test/util/net.h | 154 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 322 insertions(+) diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 77ce3b7585de4..ddd96a50640bf 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -14,7 +14,10 @@ #include #include #include +#include +#include +#include #include void ConnmanTestMsg::Handshake(CNode& node, @@ -240,3 +243,168 @@ StaticContentsSock& StaticContentsSock::operator=(Sock&& other) assert(false && "Move of Sock into StaticContentsSock not allowed."); return *this; } + +ssize_t DynSock::Pipe::GetBytes(void* buf, size_t len, int flags) +{ + WAIT_LOCK(m_mutex, lock); + + if (m_data.empty()) { + if (m_eof) { + return 0; + } + errno = EAGAIN; // Same as recv(2) on a non-blocking socket. + return -1; + } + + const size_t read_bytes{std::min(len, m_data.size())}; + + std::memcpy(buf, m_data.data(), read_bytes); + if ((flags & MSG_PEEK) == 0) { + m_data.erase(m_data.begin(), m_data.begin() + read_bytes); + } + + return read_bytes; +} + +std::optional DynSock::Pipe::GetNetMsg() +{ + V1Transport transport{NodeId{0}}; + + { + WAIT_LOCK(m_mutex, lock); + + WaitForDataOrEof(lock); + if (m_eof && m_data.empty()) { + return std::nullopt; + } + + for (;;) { + Span s{m_data}; + if (!transport.ReceivedBytes(s)) { // Consumed bytes are removed from the front of s. + return std::nullopt; + } + m_data.erase(m_data.begin(), m_data.begin() + m_data.size() - s.size()); + if (transport.ReceivedMessageComplete()) { + break; + } + if (m_data.empty()) { + WaitForDataOrEof(lock); + if (m_eof && m_data.empty()) { + return std::nullopt; + } + } + } + } + + bool reject{false}; + CNetMessage msg{transport.GetReceivedMessage(/*time=*/{}, reject)}; + if (reject) { + return std::nullopt; + } + return std::make_optional(std::move(msg)); +} + +void DynSock::Pipe::PushBytes(const void* buf, size_t len) +{ + LOCK(m_mutex); + const uint8_t* b = static_cast(buf); + m_data.insert(m_data.end(), b, b + len); + m_cond.notify_all(); +} + +void DynSock::Pipe::Eof() +{ + LOCK(m_mutex); + m_eof = true; + m_cond.notify_all(); +} + +void DynSock::Pipe::WaitForDataOrEof(UniqueLock& lock) +{ + Assert(lock.mutex() == &m_mutex); + + m_cond.wait(lock, [&]() EXCLUSIVE_LOCKS_REQUIRED(m_mutex) { + AssertLockHeld(m_mutex); + return !m_data.empty() || m_eof; + }); +} + +DynSock::DynSock(std::shared_ptr pipes, std::shared_ptr accept_sockets) + : m_pipes{pipes}, m_accept_sockets{accept_sockets} +{ +} + +DynSock::~DynSock() +{ + m_pipes->send.Eof(); +} + +ssize_t DynSock::Recv(void* buf, size_t len, int flags) const +{ + return m_pipes->recv.GetBytes(buf, len, flags); +} + +ssize_t DynSock::Send(const void* buf, size_t len, int) const +{ + m_pipes->send.PushBytes(buf, len); + return len; +} + +std::unique_ptr DynSock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ + ZeroSock::Accept(addr, addr_len); + return m_accept_sockets->Pop().value_or(nullptr); +} + +bool DynSock::Wait(std::chrono::milliseconds timeout, + Event requested, + Event* occurred) const +{ + EventsPerSock ev; + ev.emplace(this, Events{requested}); + const bool ret{WaitMany(timeout, ev)}; + if (occurred != nullptr) { + *occurred = ev.begin()->second.occurred; + } + return ret; +} + +bool DynSock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const +{ + const auto deadline = std::chrono::steady_clock::now() + timeout; + bool at_least_one_event_occurred{false}; + + for (;;) { + // Check all sockets for readiness without waiting. + for (auto& [sock, events] : events_per_sock) { + if ((events.requested & Sock::SEND) != 0) { + // Always ready for Send(). + events.occurred |= Sock::SEND; + at_least_one_event_occurred = true; + } + + if ((events.requested & Sock::RECV) != 0) { + auto dyn_sock = reinterpret_cast(sock.get()); + uint8_t b; + if (dyn_sock->m_pipes->recv.GetBytes(&b, 1, MSG_PEEK) == 1 || !dyn_sock->m_accept_sockets->Empty()) { + events.occurred |= Sock::RECV; + at_least_one_event_occurred = true; + } + } + } + + if (at_least_one_event_occurred || std::chrono::steady_clock::now() > deadline) { + break; + } + + std::this_thread::sleep_for(10ms); + } + + return true; +} + +DynSock& DynSock::operator=(Sock&&) +{ + assert(false && "Move of Sock into DynSock not allowed."); + return *this; +} diff --git a/src/test/util/net.h b/src/test/util/net.h index 20b70cc45448a..acc135b0c135a 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -6,6 +6,7 @@ #define BITCOIN_TEST_UTIL_NET_H #include +#include #include #include #include @@ -19,9 +20,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -204,6 +207,157 @@ class StaticContentsSock : public ZeroSock mutable size_t m_consumed{0}; }; +/** + * A mocked Sock alternative that allows providing the data to be returned by Recv() + * and inspecting the data that has been supplied to Send(). + */ +class DynSock : public ZeroSock +{ +public: + /** + * Unidirectional bytes or CNetMessage queue (FIFO). + */ + class Pipe + { + public: + /** + * Get bytes and remove them from the pipe. + * @param[in] buf Destination to write bytes to. + * @param[in] len Write up to this number of bytes. + * @param[in] flags Same as the flags of `recv(2)`. Just `MSG_PEEK` is honored. + * @return The number of bytes written to `buf`. `0` if `Eof()` has been called. + * If no bytes are available then `-1` is returned and `errno` is set to `EAGAIN`. + */ + ssize_t GetBytes(void* buf, size_t len, int flags = 0) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); + + /** + * Deserialize a `CNetMessage` and remove it from the pipe. + * If not enough bytes are available then the function will wait. If parsing fails + * or EOF is signaled to the pipe, then `std::nullopt` is returned. + */ + std::optional GetNetMsg() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); + + /** + * Push bytes to the pipe. + */ + void PushBytes(const void* buf, size_t len) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); + + /** + * Construct and push CNetMessage to the pipe. + */ + template + void PushNetMsg(const std::string& type, Args&&... payload) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); + + /** + * Signal end-of-file on the receiving end (`GetBytes()` or `GetNetMsg()`). + */ + void Eof() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); + + private: + /** + * Return when there is some data to read or EOF has been signaled. + * @param[in,out] lock Unique lock that must have been derived from `m_mutex` by `WAIT_LOCK(m_mutex, lock)`. + */ + void WaitForDataOrEof(UniqueLock& lock) EXCLUSIVE_LOCKS_REQUIRED(m_mutex); + + Mutex m_mutex; + std::condition_variable m_cond; + std::vector m_data GUARDED_BY(m_mutex); + bool m_eof GUARDED_BY(m_mutex){false}; + }; + + struct Pipes { + Pipe recv; + Pipe send; + }; + + /** + * A basic thread-safe queue, used for queuing sockets to be returned by Accept(). + */ + class Queue + { + public: + using S = std::unique_ptr; + + void Push(S s) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) + { + LOCK(m_mutex); + m_queue.push(std::move(s)); + } + + std::optional Pop() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) + { + LOCK(m_mutex); + if (m_queue.empty()) { + return std::nullopt; + } + S front{std::move(m_queue.front())}; + m_queue.pop(); + return front; + } + + bool Empty() const EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) + { + LOCK(m_mutex); + return m_queue.empty(); + } + + private: + mutable Mutex m_mutex; + std::queue m_queue GUARDED_BY(m_mutex); + }; + + /** + * Create a new mocked sock. + * @param[in] pipes Send/recv pipes used by the Send() and Recv() methods. + * @param[in] accept_sockets Sockets to return by the Accept() method. + */ + explicit DynSock(std::shared_ptr pipes, std::shared_ptr accept_sockets); + + ~DynSock(); + + ssize_t Recv(void* buf, size_t len, int flags) const override; + + ssize_t Send(const void* buf, size_t len, int) const override; + + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override; + + bool Wait(std::chrono::milliseconds timeout, + Event requested, + Event* occurred = nullptr) const override; + + bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; + +private: + DynSock& operator=(Sock&&) override; + + std::shared_ptr m_pipes; + std::shared_ptr m_accept_sockets; +}; + +template +void DynSock::Pipe::PushNetMsg(const std::string& type, Args&&... payload) +{ + auto msg = NetMsg::Make(type, std::forward(payload)...); + V1Transport transport{NodeId{0}}; + + const bool queued{transport.SetMessageToSend(msg)}; + assert(queued); + + LOCK(m_mutex); + + for (;;) { + const auto& [bytes, _more, _msg_type] = transport.GetBytesToSend(/*have_next_message=*/true); + if (bytes.empty()) { + break; + } + m_data.insert(m_data.end(), bytes.begin(), bytes.end()); + transport.MarkBytesSent(bytes.size()); + } + + m_cond.notify_all(); +} + std::vector GetRandomNodeEvictionCandidates(int n_candidates, FastRandomContext& random_context); #endif // BITCOIN_TEST_UTIL_NET_H From b69d199b7fd98677c8ce08f7f7e3ea0155bfe512 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 23 Aug 2024 13:11:37 +0200 Subject: [PATCH 04/30] net: reduce CAddress usage to CService or CNetAddr * `CConnman::CalculateKeyedNetGroup()` needs `CNetAddr`, not `CAddress`, thus change its argument. * Both callers of `CConnman::CreateNodeFromAcceptedSocket()` create a dummy `CAddress` from `CService`, so use `CService` instead. * `GetBindAddress()` only needs to return `CAddress`. * `CNode::addrBind` does not need to be `CAddress`. --- src/net.cpp | 27 +++++++++++++-------------- src/net.h | 12 ++++++------ src/test/net_tests.cpp | 6 +++--- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 8ea7f6ce44508..77b3945bd1eb5 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -377,9 +377,9 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) } /** Get the bind address for a socket as CAddress */ -static CAddress GetBindAddress(const Sock& sock) +static CService GetBindAddress(const Sock& sock) { - CAddress addr_bind; + CService addr_bind; struct sockaddr_storage sockaddr_bind; socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { @@ -454,7 +454,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo // Connect std::unique_ptr sock; Proxy proxy; - CAddress addr_bind; + CService addr_bind; assert(!addr_bind.IsValid()); std::unique_ptr i2p_transient_session; @@ -491,7 +491,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo if (connected) { sock = std::move(conn.sock); - addr_bind = CAddress{conn.me, NODE_NONE}; + addr_bind = conn.me; } } else if (use_proxy) { LogPrintLevel(BCLog::PROXY, BCLog::Level::Debug, "Using proxy: %s to connect to %s\n", proxy.ToString(), target_addr.ToStringAddrPort()); @@ -1718,7 +1718,6 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); - CAddress addr; if (!sock) { const int nErr = WSAGetLastError(); @@ -1728,13 +1727,14 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { return; } + CService addr; if (!addr.SetSockAddr((const struct sockaddr*)&sockaddr)) { LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); } else { - addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE}; + addr = MaybeFlipIPv6toCJDNS(addr); } - const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock)), NODE_NONE}; + const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; NetPermissionFlags permission_flags = NetPermissionFlags::None; hListenSocket.AddSocketPermissionFlags(permission_flags); @@ -1744,8 +1744,8 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permission_flags, - const CAddress& addr_bind, - const CAddress& addr) + const CService& addr_bind, + const CService& addr) { int nInbound = 0; @@ -1812,7 +1812,7 @@ void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, CNode* pnode = new CNode(id, std::move(sock), - addr, + CAddress{addr, NODE_NONE}, CalculateKeyedNetGroup(addr), nonce, addr_bind, @@ -3074,8 +3074,7 @@ void CConnman::ThreadI2PAcceptIncoming() continue; } - CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None, - CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE}); + CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None, conn.me, conn.peer); err_wait = err_wait_begin; } @@ -3765,7 +3764,7 @@ CNode::CNode(NodeId idIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, - const CAddress& addrBindIn, + const CService& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion, @@ -3902,7 +3901,7 @@ CSipHasher CConnman::GetDeterministicRandomizer(uint64_t id) const return CSipHasher(nSeed0, nSeed1).Write(id); } -uint64_t CConnman::CalculateKeyedNetGroup(const CAddress& address) const +uint64_t CConnman::CalculateKeyedNetGroup(const CNetAddr& address) const { std::vector vchNetGroup(m_netgroupman.GetGroup(address)); diff --git a/src/net.h b/src/net.h index 99a9d0da4b45d..e64d9a67f4608 100644 --- a/src/net.h +++ b/src/net.h @@ -211,7 +211,7 @@ class CNodeStats // Address of this peer CAddress addr; // Bind address of our side of the connection - CAddress addrBind; + CService addrBind; // Network the peer connected through Network m_network; uint32_t m_mapped_as; @@ -707,7 +707,7 @@ class CNode // Address of this peer const CAddress addr; // Bind address of our side of the connection - const CAddress addrBind; + const CService addrBind; const std::string m_addr_name; /** The pszDest argument provided to ConnectNode(). Only used for reconnections. */ const std::string m_dest; @@ -883,7 +883,7 @@ class CNode const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, - const CAddress& addrBindIn, + const CService& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion, @@ -1312,8 +1312,8 @@ class CConnman */ void CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permission_flags, - const CAddress& addr_bind, - const CAddress& addr); + const CService& addr_bind, + const CService& addr); void DisconnectNodes() EXCLUSIVE_LOCKS_REQUIRED(!m_reconnections_mutex, !m_nodes_mutex); void NotifyNumConnectionsChanged(); @@ -1350,7 +1350,7 @@ class CConnman void ThreadSocketHandler() EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc, !m_nodes_mutex, !m_reconnections_mutex); void ThreadDNSAddressSeed() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_nodes_mutex); - uint64_t CalculateKeyedNetGroup(const CAddress& ad) const; + uint64_t CalculateKeyedNetGroup(const CNetAddr& ad) const; CNode* FindNode(const CNetAddr& ip); CNode* FindNode(const std::string& addrName); diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 384b1d7cc92d1..5f0f05c842ad4 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -671,7 +671,7 @@ BOOST_AUTO_TEST_CASE(get_local_addr_for_peer_port) /*addrIn=*/CAddress{CService{peer_out_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, - /*addrBindIn=*/CAddress{}, + /*addrBindIn=*/CService{}, /*addrNameIn=*/std::string{}, /*conn_type_in=*/ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false}; @@ -692,7 +692,7 @@ BOOST_AUTO_TEST_CASE(get_local_addr_for_peer_port) /*addrIn=*/CAddress{CService{peer_in_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, - /*addrBindIn=*/CAddress{}, + /*addrBindIn=*/CService{}, /*addrNameIn=*/std::string{}, /*conn_type_in=*/ConnectionType::INBOUND, /*inbound_onion=*/false}; @@ -829,7 +829,7 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message) /*addrIn=*/CAddress{CService{peer_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, - /*addrBindIn=*/CAddress{}, + /*addrBindIn=*/CService{}, /*addrNameIn=*/std::string{}, /*conn_type_in=*/ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false}; From fd81820214e695ba228a954506397c3d781fe3fe Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 14 Jan 2025 17:24:22 +0100 Subject: [PATCH 05/30] net: separate the listening socket from the permissions They were coupled in `struct ListenSocket`, but the socket belongs to the lower level transport protocol, whereas the permissions are specific to the higher Bitcoin P2P protocol. --- src/net.cpp | 14 ++++++++++---- src/net.h | 16 +++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 77b3945bd1eb5..1b98e197f11b1 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1737,7 +1737,10 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; NetPermissionFlags permission_flags = NetPermissionFlags::None; - hListenSocket.AddSocketPermissionFlags(permission_flags); + auto it{m_listen_permissions.find(addr_bind)}; + if (it != m_listen_permissions.end()) { + NetPermissions::AddFlag(permission_flags, it->second); + } CreateNodeFromAcceptedSocket(std::move(sock), permission_flags, addr_bind, addr); } @@ -3080,7 +3083,7 @@ void CConnman::ThreadI2PAcceptIncoming() } } -bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, NetPermissionFlags permissions) +bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError) { int nOne = 1; @@ -3145,7 +3148,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - vhListenSocket.emplace_back(std::move(sock), permissions); + vhListenSocket.emplace_back(std::move(sock)); return true; } @@ -3211,13 +3214,15 @@ bool CConnman::Bind(const CService& addr_, unsigned int flags, NetPermissionFlag const CService addr{MaybeFlipIPv6toCJDNS(addr_)}; bilingual_str strError; - if (!BindListenPort(addr, strError, permissions)) { + if (!BindListenPort(addr, strError)) { if ((flags & BF_REPORT_ERROR) && m_client_interface) { m_client_interface->ThreadSafeMessageBox(strError, "", CClientUIInterface::MSG_ERROR); } return false; } + m_listen_permissions.emplace(addr, permissions); + if (addr.IsRoutable() && fDiscover && !(flags & BF_DONT_ADVERTISE) && !NetPermissions::HasFlag(permissions, NetPermissionFlags::NoBan)) { AddLocal(addr, LOCAL_BIND); } @@ -3451,6 +3456,7 @@ void CConnman::StopNodes() DeleteNode(pnode); } m_nodes_disconnected.clear(); + m_listen_permissions.clear(); vhListenSocket.clear(); semOutbound.reset(); semAddnode.reset(); diff --git a/src/net.h b/src/net.h index e64d9a67f4608..8ece1071278a8 100644 --- a/src/net.h +++ b/src/net.h @@ -1276,21 +1276,17 @@ class CConnman struct ListenSocket { public: std::shared_ptr sock; - inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); } - ListenSocket(std::shared_ptr sock_, NetPermissionFlags permissions_) - : sock{sock_}, m_permissions{permissions_} + ListenSocket(std::shared_ptr sock_) + : sock{sock_} { } - - private: - NetPermissionFlags m_permissions; }; //! returns the time left in the current max outbound cycle //! in case of no limit, it will always return 0 std::chrono::seconds GetMaxOutboundTimeLeftInCycle_() const EXCLUSIVE_LOCKS_REQUIRED(m_total_bytes_sent_mutex); - bool BindListenPort(const CService& bindAddr, bilingual_str& strError, NetPermissionFlags permissions); + bool BindListenPort(const CService& bindAddr, bilingual_str& strError); bool Bind(const CService& addr, unsigned int flags, NetPermissionFlags permissions); bool InitBinds(const Options& options); @@ -1431,6 +1427,12 @@ class CConnman unsigned int nReceiveFloodSize{0}; std::vector vhListenSocket; + + /** + * Permissions that incoming peers get based on our listening address they connected to. + */ + std::unordered_map m_listen_permissions; + std::atomic fNetworkActive{true}; bool fAddressesInitialized{false}; AddrMan& addrman; From e5d36eea015efc31aa38d540af4cf39c9e2e46b0 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 23 Aug 2024 15:36:40 +0200 Subject: [PATCH 06/30] net: split CConnman::BindListenPort() off CConnman Introduce a new low-level socket managing class `SockMan` and move the `CConnman::BindListenPort()` method to it. Also, separate the listening socket from the permissions - they were coupled in `struct ListenSocket`, but the socket is protocol agnostic, whereas the permissions are specific to the application of the Bitcoin P2P protocol. --- src/CMakeLists.txt | 1 + src/common/sockman.cpp | 85 ++++++++++++++++++++++++++++++++++++++++++ src/common/sockman.h | 44 ++++++++++++++++++++++ src/net.cpp | 85 ++++-------------------------------------- src/net.h | 17 ++------- 5 files changed, 141 insertions(+), 91 deletions(-) create mode 100644 src/common/sockman.cpp create mode 100644 src/common/sockman.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 889c00c78327f..40f4ec87e4f7e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -123,6 +123,7 @@ add_library(bitcoin_common STATIC EXCLUDE_FROM_ALL common/run_command.cpp common/settings.cpp common/signmessage.cpp + common/sockman.cpp common/system.cpp common/url.cpp compressor.cpp diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp new file mode 100644 index 0000000000000..7cc7edb809a8c --- /dev/null +++ b/src/common/sockman.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2024-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or https://opensource.org/license/mit/. + +#include // IWYU pragma: keep + +#include +#include +#include +#include + +bool SockMan::BindListenPort(const CService& addrBind, bilingual_str& strError) +{ + int nOne = 1; + + // Create socket for listening for incoming connections + struct sockaddr_storage sockaddr; + socklen_t len = sizeof(sockaddr); + if (!addrBind.GetSockAddr((struct sockaddr*)&sockaddr, &len)) + { + strError = Untranslated(strprintf("Bind address family for %s not supported", addrBind.ToStringAddrPort())); + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + return false; + } + + std::unique_ptr sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); + if (!sock) { + strError = Untranslated(strprintf("Couldn't open socket for incoming connections (socket returned error %s)", NetworkErrorString(WSAGetLastError()))); + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + return false; + } + + // Allow binding if the port is still in TIME_WAIT state after + // the program was closed and restarted. + if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { + strError = Untranslated(strprintf("Error setting SO_REUSEADDR on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); + LogPrintf("%s\n", strError.original); + } + + // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option + // and enable it by default or not. Try to enable it, if possible. + if (addrBind.IsIPv6()) { +#ifdef IPV6_V6ONLY + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { + strError = Untranslated(strprintf("Error setting IPV6_V6ONLY on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); + LogPrintf("%s\n", strError.original); + } +#endif +#ifdef WIN32 + int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) { + strError = Untranslated(strprintf("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); + LogPrintf("%s\n", strError.original); + } +#endif + } + + if (sock->Bind(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { + int nErr = WSAGetLastError(); + if (nErr == WSAEADDRINUSE) + strError = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), addrBind.ToStringAddrPort(), CLIENT_NAME); + else + strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToStringAddrPort(), NetworkErrorString(nErr)); + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + return false; + } + LogPrintf("Bound to %s\n", addrBind.ToStringAddrPort()); + + // Listen for incoming connections + if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) + { + strError = strprintf(_("Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + return false; + } + + m_listen.emplace_back(std::move(sock)); + + return true; +} + +void SockMan::CloseSockets() +{ + m_listen.clear(); +} diff --git a/src/common/sockman.h b/src/common/sockman.h new file mode 100644 index 0000000000000..d96b59491b879 --- /dev/null +++ b/src/common/sockman.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or https://opensource.org/license/mit/. + +#ifndef BITCOIN_COMMON_SOCKMAN_H +#define BITCOIN_COMMON_SOCKMAN_H + +#include +#include +#include + +#include +#include + +/** + * A socket manager class which handles socket operations. + * To use this class, inherit from it and implement the pure virtual methods. + * Handled operations: + * - binding and listening on sockets + */ +class SockMan +{ +public: + /** + * Bind to a new address:port, start listening and add the listen socket to `m_listen`. + * @param[in] addrBind Where to bind. + * @param[out] strError Error string if an error occurs. + * @retval true Success. + * @retval false Failure, `strError` will be set. + */ + bool BindListenPort(const CService& addrBind, bilingual_str& strError); + + /** + * Close all sockets. + */ + void CloseSockets(); + + /** + * List of listening sockets. + */ + std::vector> m_listen; +}; + +#endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index 1b98e197f11b1..52a878eec94a6 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1714,10 +1714,10 @@ bool CConnman::AttemptToEvictConnection() return false; } -void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { +void CConnman::AcceptConnection(const Sock& listen_sock) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); + auto sock = listen_sock.Accept((struct sockaddr*)&sockaddr, &len); if (!sock) { const int nErr = WSAGetLastError(); @@ -2033,8 +2033,8 @@ Sock::EventsPerSock CConnman::GenerateWaitSockets(Span nodes) { Sock::EventsPerSock events_per_sock; - for (const ListenSocket& hListenSocket : vhListenSocket) { - events_per_sock.emplace(hListenSocket.sock, Sock::Events{Sock::RECV}); + for (const auto& sock : m_listen) { + events_per_sock.emplace(sock, Sock::Events{Sock::RECV}); } for (CNode* pnode : nodes) { @@ -2189,13 +2189,13 @@ void CConnman::SocketHandlerConnected(const std::vector& nodes, void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) { - for (const ListenSocket& listen_socket : vhListenSocket) { + for (const auto& sock : m_listen) { if (interruptNet) { return; } - const auto it = events_per_sock.find(listen_socket.sock); + const auto it = events_per_sock.find(sock); if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { - AcceptConnection(listen_socket); + AcceptConnection(*sock); } } } @@ -3083,75 +3083,6 @@ void CConnman::ThreadI2PAcceptIncoming() } } -bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError) -{ - int nOne = 1; - - // Create socket for listening for incoming connections - struct sockaddr_storage sockaddr; - socklen_t len = sizeof(sockaddr); - if (!addrBind.GetSockAddr((struct sockaddr*)&sockaddr, &len)) - { - strError = Untranslated(strprintf("Bind address family for %s not supported", addrBind.ToStringAddrPort())); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); - return false; - } - - std::unique_ptr sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); - if (!sock) { - strError = Untranslated(strprintf("Couldn't open socket for incoming connections (socket returned error %s)", NetworkErrorString(WSAGetLastError()))); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); - return false; - } - - // Allow binding if the port is still in TIME_WAIT state after - // the program was closed and restarted. - if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting SO_REUSEADDR on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); - } - - // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option - // and enable it by default or not. Try to enable it, if possible. - if (addrBind.IsIPv6()) { -#ifdef IPV6_V6ONLY - if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting IPV6_V6ONLY on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); - } -#endif -#ifdef WIN32 - int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); - } -#endif - } - - if (sock->Bind(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { - int nErr = WSAGetLastError(); - if (nErr == WSAEADDRINUSE) - strError = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), addrBind.ToStringAddrPort(), CLIENT_NAME); - else - strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToStringAddrPort(), NetworkErrorString(nErr)); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); - return false; - } - LogPrintf("Bound to %s\n", addrBind.ToStringAddrPort()); - - // Listen for incoming connections - if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) - { - strError = strprintf(_("Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); - return false; - } - - vhListenSocket.emplace_back(std::move(sock)); - return true; -} - void Discover() { if (!fDiscover) @@ -3457,7 +3388,7 @@ void CConnman::StopNodes() } m_nodes_disconnected.clear(); m_listen_permissions.clear(); - vhListenSocket.clear(); + CloseSockets(); semOutbound.reset(); semAddnode.reset(); } diff --git a/src/net.h b/src/net.h index 8ece1071278a8..cc9b82421ef33 100644 --- a/src/net.h +++ b/src/net.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1048,7 +1049,7 @@ class NetEventsInterface ~NetEventsInterface() = default; }; -class CConnman +class CConnman : private SockMan { public: @@ -1273,20 +1274,10 @@ class CConnman bool MultipleManualOrFullOutboundConns(Network net) const EXCLUSIVE_LOCKS_REQUIRED(m_nodes_mutex); private: - struct ListenSocket { - public: - std::shared_ptr sock; - ListenSocket(std::shared_ptr sock_) - : sock{sock_} - { - } - }; - //! returns the time left in the current max outbound cycle //! in case of no limit, it will always return 0 std::chrono::seconds GetMaxOutboundTimeLeftInCycle_() const EXCLUSIVE_LOCKS_REQUIRED(m_total_bytes_sent_mutex); - bool BindListenPort(const CService& bindAddr, bilingual_str& strError); bool Bind(const CService& addr, unsigned int flags, NetPermissionFlags permissions); bool InitBinds(const Options& options); @@ -1296,7 +1287,7 @@ class CConnman void ThreadOpenConnections(std::vector connect, Span seed_nodes) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !m_unused_i2p_sessions_mutex, !m_reconnections_mutex); void ThreadMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); void ThreadI2PAcceptIncoming(); - void AcceptConnection(const ListenSocket& hListenSocket); + void AcceptConnection(const Sock& listen_sock); /** * Create a `CNode` object from a socket that has just been accepted and add the node to @@ -1426,8 +1417,6 @@ class CConnman unsigned int nSendBufferMaxSize{0}; unsigned int nReceiveFloodSize{0}; - std::vector vhListenSocket; - /** * Permissions that incoming peers get based on our listening address they connected to. */ From b717550b58f7790682362bc95f60047c8a994779 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 17 Sep 2024 15:05:46 +0200 Subject: [PATCH 07/30] style: modernize the style of SockMan::BindListenPort() It was copied verbatim from `CConnman::BindListenPort()` in the previous commit. Modernize its variables and style and log the error messages from the caller. --- src/common/sockman.cpp | 79 ++++++++++++++++++--------------- src/common/sockman.h | 6 +-- src/net.cpp | 5 ++- test/functional/feature_port.py | 10 ++--- 4 files changed, 56 insertions(+), 44 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 7cc7edb809a8c..dd793ecd90639 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -9,68 +9,77 @@ #include #include -bool SockMan::BindListenPort(const CService& addrBind, bilingual_str& strError) +bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) { - int nOne = 1; - // Create socket for listening for incoming connections - struct sockaddr_storage sockaddr; - socklen_t len = sizeof(sockaddr); - if (!addrBind.GetSockAddr((struct sockaddr*)&sockaddr, &len)) - { - strError = Untranslated(strprintf("Bind address family for %s not supported", addrBind.ToStringAddrPort())); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + struct sockaddr_storage storage; + socklen_t len{sizeof(storage)}; + if (!to.GetSockAddr(reinterpret_cast(&storage), &len)) { + errmsg = Untranslated(strprintf("Bind address family for %s not supported", to.ToStringAddrPort())); return false; } - std::unique_ptr sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); + std::unique_ptr sock{CreateSock(to.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP)}; if (!sock) { - strError = Untranslated(strprintf("Couldn't open socket for incoming connections (socket returned error %s)", NetworkErrorString(WSAGetLastError()))); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + errmsg = Untranslated(strprintf("Cannot create %s listen socket: %s", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError()))); return false; } + int one{1}; + // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting SO_REUSEADDR on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); + if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&one), sizeof(one)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set SO_REUSEADDR on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); } // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // and enable it by default or not. Try to enable it, if possible. - if (addrBind.IsIPv6()) { + if (to.IsIPv6()) { #ifdef IPV6_V6ONLY - if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting IPV6_V6ONLY on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&one), sizeof(one)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set IPV6_V6ONLY on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); } #endif #ifdef WIN32 - int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) { - strError = Untranslated(strprintf("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway", NetworkErrorString(WSAGetLastError()))); - LogPrintf("%s\n", strError.original); + int prot_level{PROTECTION_LEVEL_UNRESTRICTED}; + if (sock->SetSockOpt(IPPROTO_IPV6, + IPV6_PROTECTION_LEVEL, + reinterpret_cast(&prot_level), + sizeof(prot_level)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set IPV6_PROTECTION_LEVEL on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); } #endif } - if (sock->Bind(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { - int nErr = WSAGetLastError(); - if (nErr == WSAEADDRINUSE) - strError = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), addrBind.ToStringAddrPort(), CLIENT_NAME); - else - strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToStringAddrPort(), NetworkErrorString(nErr)); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + if (sock->Bind(reinterpret_cast(&storage), len) == SOCKET_ERROR) { + const int err{WSAGetLastError()}; + errmsg = strprintf(_("Cannot bind to %s: %s%s"), + to.ToStringAddrPort(), + NetworkErrorString(err), + err == WSAEADDRINUSE + ? std::string{" ("} + CLIENT_NAME + " already running?)" + : ""); return false; } - LogPrintf("Bound to %s\n", addrBind.ToStringAddrPort()); // Listen for incoming connections - if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) - { - strError = strprintf(_("Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); - LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); + if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) { + errmsg = strprintf(_("Cannot listen to %s: %s"), to.ToStringAddrPort(), NetworkErrorString(WSAGetLastError())); return false; } diff --git a/src/common/sockman.h b/src/common/sockman.h index d96b59491b879..0dd72326135ad 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -23,12 +23,12 @@ class SockMan public: /** * Bind to a new address:port, start listening and add the listen socket to `m_listen`. - * @param[in] addrBind Where to bind. - * @param[out] strError Error string if an error occurs. + * @param[in] to Where to bind. + * @param[out] errmsg Error string if an error occurs. * @retval true Success. * @retval false Failure, `strError` will be set. */ - bool BindListenPort(const CService& addrBind, bilingual_str& strError); + bool BindAndStartListening(const CService& to, bilingual_str& errmsg); /** * Close all sockets. diff --git a/src/net.cpp b/src/net.cpp index 52a878eec94a6..256ebfd3651f4 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -3145,13 +3145,16 @@ bool CConnman::Bind(const CService& addr_, unsigned int flags, NetPermissionFlag const CService addr{MaybeFlipIPv6toCJDNS(addr_)}; bilingual_str strError; - if (!BindListenPort(addr, strError)) { + if (!BindAndStartListening(addr, strError)) { + LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); if ((flags & BF_REPORT_ERROR) && m_client_interface) { m_client_interface->ThreadSafeMessageBox(strError, "", CClientUIInterface::MSG_ERROR); } return false; } + LogPrintLevel(BCLog::NET, BCLog::Level::Info, "Bound to and listening on %s\n", addr.ToStringAddrPort()); + m_listen_permissions.emplace(addr, permissions); if (addr.IsRoutable() && fDiscover && !(flags & BF_DONT_ADVERTISE) && !NetPermissions::HasFlag(permissions, NetPermissionFlags::NoBan)) { diff --git a/test/functional/feature_port.py b/test/functional/feature_port.py index 2746d7d79c1d6..1317583c4caf6 100755 --- a/test/functional/feature_port.py +++ b/test/functional/feature_port.py @@ -29,23 +29,23 @@ def run_test(self): port2 = p2p_port(self.num_nodes + 5) self.log.info("When starting with -port, bitcoind binds to it and uses port + 1 for an onion bind") - with node.assert_debug_log(expected_msgs=[f'Bound to 0.0.0.0:{port1}', f'Bound to 127.0.0.1:{port1 + 1}']): + with node.assert_debug_log(expected_msgs=[f'Bound to and listening on 0.0.0.0:{port1}', f'Bound to and listening on 127.0.0.1:{port1 + 1}']): self.restart_node(0, extra_args=["-listen", f"-port={port1}"]) self.log.info("When specifying -port multiple times, only the last one is taken") - with node.assert_debug_log(expected_msgs=[f'Bound to 0.0.0.0:{port2}', f'Bound to 127.0.0.1:{port2 + 1}'], unexpected_msgs=[f'Bound to 0.0.0.0:{port1}']): + with node.assert_debug_log(expected_msgs=[f'Bound to and listening on 0.0.0.0:{port2}', f'Bound to and listening on 127.0.0.1:{port2 + 1}'], unexpected_msgs=[f'Bound to and listening on 0.0.0.0:{port1}']): self.restart_node(0, extra_args=["-listen", f"-port={port1}", f"-port={port2}"]) self.log.info("When specifying ports with both -port and -bind, the one from -port is ignored") - with node.assert_debug_log(expected_msgs=[f'Bound to 0.0.0.0:{port2}'], unexpected_msgs=[f'Bound to 0.0.0.0:{port1}']): + with node.assert_debug_log(expected_msgs=[f'Bound to and listening on 0.0.0.0:{port2}'], unexpected_msgs=[f'Bound to and listening on 0.0.0.0:{port1}']): self.restart_node(0, extra_args=["-listen", f"-port={port1}", f"-bind=0.0.0.0:{port2}"]) self.log.info("When -bind specifies no port, the values from -port and -bind are combined") - with self.nodes[0].assert_debug_log(expected_msgs=[f'Bound to 0.0.0.0:{port1}']): + with self.nodes[0].assert_debug_log(expected_msgs=[f'Bound to and listening on 0.0.0.0:{port1}']): self.restart_node(0, extra_args=["-listen", f"-port={port1}", "-bind=0.0.0.0"]) self.log.info("When an onion bind specifies no port, the value from -port, incremented by 1, is taken") - with self.nodes[0].assert_debug_log(expected_msgs=[f'Bound to 127.0.0.1:{port1 + 1}']): + with self.nodes[0].assert_debug_log(expected_msgs=[f'Bound to and listening on 127.0.0.1:{port1 + 1}']): self.restart_node(0, extra_args=["-listen", f"-port={port1}", "-bind=127.0.0.1=onion"]) self.log.info("Invalid values for -port raise errors") From 0241b04cf406d482abfac3fddfad9a9c28725f32 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 26 Aug 2024 13:14:53 +0200 Subject: [PATCH 08/30] net: split CConnman::AcceptConnection() off CConnman Move the `CConnman::AcceptConnection()` method to `SockMan` and split parts of it: * the flip-to-CJDNS part: to just after the `AcceptConnection()` call * the permissions part: at the start of `CreateNodeFromAcceptedSocket()` --- src/common/sockman.cpp | 21 ++++++++++++++++++ src/common/sockman.h | 9 ++++++++ src/net.cpp | 49 ++++++++++++++---------------------------- src/net.h | 3 --- 4 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index dd793ecd90639..610771b90c9c2 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -88,6 +88,27 @@ bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) return true; } +std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) +{ + struct sockaddr_storage sockaddr; + socklen_t len = sizeof(sockaddr); + auto sock = listen_sock.Accept((struct sockaddr*)&sockaddr, &len); + + if (!sock) { + const int nErr = WSAGetLastError(); + if (nErr != WSAEWOULDBLOCK) { + LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); + } + return {}; + } + + if (!addr.SetSockAddr((const struct sockaddr*)&sockaddr)) { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); + } + + return sock; +} + void SockMan::CloseSockets() { m_listen.clear(); diff --git a/src/common/sockman.h b/src/common/sockman.h index 0dd72326135ad..3aaed8df12307 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -17,6 +17,7 @@ * To use this class, inherit from it and implement the pure virtual methods. * Handled operations: * - binding and listening on sockets + * - accepting incoming connections */ class SockMan { @@ -30,6 +31,14 @@ class SockMan */ bool BindAndStartListening(const CService& to, bilingual_str& errmsg); + /** + * Accept a connection. + * @param[in] listen_sock Socket on which to accept the connection. + * @param[out] addr Address of the peer that was accepted. + * @return Newly created socket for the accepted connection. + */ + std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); + /** * Close all sockets. */ diff --git a/src/net.cpp b/src/net.cpp index 256ebfd3651f4..51f71f2b27840 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1714,27 +1714,11 @@ bool CConnman::AttemptToEvictConnection() return false; } -void CConnman::AcceptConnection(const Sock& listen_sock) { - struct sockaddr_storage sockaddr; - socklen_t len = sizeof(sockaddr); - auto sock = listen_sock.Accept((struct sockaddr*)&sockaddr, &len); - - if (!sock) { - const int nErr = WSAGetLastError(); - if (nErr != WSAEWOULDBLOCK) { - LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); - } - return; - } - - CService addr; - if (!addr.SetSockAddr((const struct sockaddr*)&sockaddr)) { - LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); - } else { - addr = MaybeFlipIPv6toCJDNS(addr); - } - - const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; +void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, + const CService& addr_bind, + const CService& addr) +{ + int nInbound = 0; NetPermissionFlags permission_flags = NetPermissionFlags::None; auto it{m_listen_permissions.find(addr_bind)}; @@ -1742,16 +1726,6 @@ void CConnman::AcceptConnection(const Sock& listen_sock) { NetPermissions::AddFlag(permission_flags, it->second); } - CreateNodeFromAcceptedSocket(std::move(sock), permission_flags, addr_bind, addr); -} - -void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, - NetPermissionFlags permission_flags, - const CService& addr_bind, - const CService& addr) -{ - int nInbound = 0; - AddWhitelistPermissionFlags(permission_flags, addr, vWhitelistedRangeIncoming); { @@ -2195,7 +2169,16 @@ void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock } const auto it = events_per_sock.find(sock); if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { - AcceptConnection(*sock); + CService addr_accepted; + + auto sock_accepted{AcceptConnection(*sock, addr_accepted)}; + + if (sock_accepted) { + addr_accepted = MaybeFlipIPv6toCJDNS(addr_accepted); + const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; + + CreateNodeFromAcceptedSocket(std::move(sock_accepted), addr_bind, addr_accepted); + } } } } @@ -3077,7 +3060,7 @@ void CConnman::ThreadI2PAcceptIncoming() continue; } - CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None, conn.me, conn.peer); + CreateNodeFromAcceptedSocket(std::move(conn.sock), conn.me, conn.peer); err_wait = err_wait_begin; } diff --git a/src/net.h b/src/net.h index cc9b82421ef33..01fc644e40c5a 100644 --- a/src/net.h +++ b/src/net.h @@ -1287,18 +1287,15 @@ class CConnman : private SockMan void ThreadOpenConnections(std::vector connect, Span seed_nodes) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !m_unused_i2p_sessions_mutex, !m_reconnections_mutex); void ThreadMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); void ThreadI2PAcceptIncoming(); - void AcceptConnection(const Sock& listen_sock); /** * Create a `CNode` object from a socket that has just been accepted and add the node to * the `m_nodes` member. * @param[in] sock Connected socket to communicate with the peer. - * @param[in] permission_flags The peer's permissions. * @param[in] addr_bind The address and port at our side of the connection. * @param[in] addr The address and port at the peer's side of the connection. */ void CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, - NetPermissionFlags permission_flags, const CService& addr_bind, const CService& addr); From 8658f832ba5502e3858f4d863583b8270449e063 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 17 Sep 2024 17:29:07 +0200 Subject: [PATCH 09/30] style: modernize the style of SockMan::AcceptConnection() --- src/common/sockman.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 610771b90c9c2..87b9e9b9f53dc 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -90,19 +90,23 @@ bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) { - struct sockaddr_storage sockaddr; - socklen_t len = sizeof(sockaddr); - auto sock = listen_sock.Accept((struct sockaddr*)&sockaddr, &len); + sockaddr_storage storage; + socklen_t len{sizeof(storage)}; + + auto sock{listen_sock.Accept(reinterpret_cast(&storage), &len)}; if (!sock) { - const int nErr = WSAGetLastError(); - if (nErr != WSAEWOULDBLOCK) { - LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); + const int err{WSAGetLastError()}; + if (err != WSAEWOULDBLOCK) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Error, + "Cannot accept new connection: %s\n", + NetworkErrorString(err)); } return {}; } - if (!addr.SetSockAddr((const struct sockaddr*)&sockaddr)) { + if (!addr.SetSockAddr(reinterpret_cast(&storage))) { LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); } From 1b05e1d4ba55a42ba74026b68fa4e616b973e06d Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 27 Aug 2024 11:25:24 +0200 Subject: [PATCH 10/30] net: move the generation of ids for new nodes from CConnman to SockMan --- src/common/sockman.cpp | 5 +++++ src/common/sockman.h | 15 +++++++++++++++ src/net.cpp | 5 ----- src/net.h | 5 ----- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 87b9e9b9f53dc..35605170958f9 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -113,6 +113,11 @@ std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CServic return sock; } +NodeId SockMan::GetNewNodeId() +{ + return m_next_node_id.fetch_add(1, std::memory_order_relaxed); +} + void SockMan::CloseSockets() { m_listen.clear(); diff --git a/src/common/sockman.h b/src/common/sockman.h index 3aaed8df12307..540ab27a6897f 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -9,9 +9,12 @@ #include #include +#include #include #include +typedef int64_t NodeId; + /** * A socket manager class which handles socket operations. * To use this class, inherit from it and implement the pure virtual methods. @@ -39,6 +42,11 @@ class SockMan */ std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); + /** + * Generate an id for a newly created node. + */ + NodeId GetNewNodeId(); + /** * Close all sockets. */ @@ -48,6 +56,13 @@ class SockMan * List of listening sockets. */ std::vector> m_listen; + +private: + + /** + * The id to assign to the next created node. Used to generate ids of nodes. + */ + std::atomic m_next_node_id{0}; }; #endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index 51f71f2b27840..9b9fc6ba3aa59 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -3107,11 +3107,6 @@ CConnman::CConnman(uint64_t nSeed0In, uint64_t nSeed1In, AddrMan& addrman_in, SetNetworkActive(network_active); } -NodeId CConnman::GetNewNodeId() -{ - return nLastNodeId.fetch_add(1, std::memory_order_relaxed); -} - uint16_t CConnman::GetDefaultPort(Network net) const { return net == NET_I2P ? I2P_SAM31_PORT : m_params.GetDefaultPort(); diff --git a/src/net.h b/src/net.h index 01fc644e40c5a..531438ced0131 100644 --- a/src/net.h +++ b/src/net.h @@ -95,8 +95,6 @@ static const size_t DEFAULT_MAXSENDBUFFER = 1 * 1000; static constexpr bool DEFAULT_V2_TRANSPORT{true}; -typedef int64_t NodeId; - struct AddedNodeParams { std::string m_added_node; bool m_use_v2transport; @@ -1352,8 +1350,6 @@ class CConnman : private SockMan void DeleteNode(CNode* pnode); - NodeId GetNewNodeId(); - /** (Try to) send data from node's vSendMsg. Returns (bytes_sent, data_left). */ std::pair SocketSendData(CNode& node) const EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend); @@ -1433,7 +1429,6 @@ class CConnman : private SockMan std::vector m_nodes GUARDED_BY(m_nodes_mutex); std::list m_nodes_disconnected; mutable RecursiveMutex m_nodes_mutex; - std::atomic nLastNodeId{0}; unsigned int nPrevNodeCount{0}; // Stores number of full-tx connections (outbound and manual) per network From 14fcef6b0d1d1fa9395f9af2bafbf3de63d14ac2 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 27 Aug 2024 16:11:35 +0200 Subject: [PATCH 11/30] net: move CConnman-specific parts away from ThreadI2PAcceptIncoming() CConnman-specific or in other words, Bitcoin P2P specific. Now the `ThreadI2PAcceptIncoming()` method is protocol agnostic and can be moved to `SockMan`. --- src/common/sockman.cpp | 2 ++ src/common/sockman.h | 27 +++++++++++++++++++++++++++ src/net.cpp | 27 ++++++++++++++++++--------- src/net.h | 6 ++++++ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 35605170958f9..b474986031d00 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -122,3 +122,5 @@ void SockMan::CloseSockets() { m_listen.clear(); } + +void SockMan::EventI2PListen(const CService&, bool) {} diff --git a/src/common/sockman.h b/src/common/sockman.h index 540ab27a6897f..615971463f872 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -25,6 +25,13 @@ typedef int64_t NodeId; class SockMan { public: + + virtual ~SockMan() = default; + + // + // Non-virtual functions, to be reused by children classes. + // + /** * Bind to a new address:port, start listening and add the listen socket to `m_listen`. * @param[in] to Where to bind. @@ -59,6 +66,26 @@ class SockMan private: + // + // Pure virtual functions must be implemented by children classes. + // + + // + // Non-pure virtual functions can be overridden by children classes or left + // alone to use the default implementation from SockMan. + // + + /** + * Be notified of a change in the state of listening for incoming I2P connections. + * The default behavior, implemented by `SockMan`, is to ignore this event. + * @param[in] addr Our listening address. + * @param[in] success If true then the listen succeeded and we are now + * listening for incoming I2P connections at `addr`. If false then the + * call failed and now we are not listening (even if this was invoked + * before with `true`). + */ + virtual void EventI2PListen(const CService& addr, bool success); + /** * The id to assign to the next created node. Used to generate ids of nodes. */ diff --git a/src/net.cpp b/src/net.cpp index 9b9fc6ba3aa59..5b0bc6f10c5a2 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -3023,13 +3023,28 @@ void CConnman::ThreadMessageHandler() } } +void CConnman::EventI2PListen(const CService& addr, bool success) +{ + if (success) { + if (!m_i2p_advertising_listen_addr) { + AddLocal(addr, LOCAL_MANUAL); + m_i2p_advertising_listen_addr = true; + } + return; + } + // a failure to listen + if (m_i2p_advertising_listen_addr && addr.IsValid()) { + RemoveLocal(addr); + m_i2p_advertising_listen_addr = false; + } +} + void CConnman::ThreadI2PAcceptIncoming() { static constexpr auto err_wait_begin = 1s; static constexpr auto err_wait_cap = 5min; auto err_wait = err_wait_begin; - bool advertising_listen_addr = false; i2p::Connection conn; auto SleepOnFailure = [&]() { @@ -3042,18 +3057,12 @@ void CConnman::ThreadI2PAcceptIncoming() while (!interruptNet) { if (!m_i2p_sam_session->Listen(conn)) { - if (advertising_listen_addr && conn.me.IsValid()) { - RemoveLocal(conn.me); - advertising_listen_addr = false; - } + EventI2PListen(conn.me, /*success=*/false); SleepOnFailure(); continue; } - if (!advertising_listen_addr) { - AddLocal(conn.me, LOCAL_MANUAL); - advertising_listen_addr = true; - } + EventI2PListen(conn.me, /*success=*/true); if (!m_i2p_sam_session->Accept(conn)) { SleepOnFailure(); diff --git a/src/net.h b/src/net.h index 531438ced0131..e650d252e2465 100644 --- a/src/net.h +++ b/src/net.h @@ -1284,6 +1284,12 @@ class CConnman : private SockMan void ProcessAddrFetch() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_unused_i2p_sessions_mutex); void ThreadOpenConnections(std::vector connect, Span seed_nodes) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !m_unused_i2p_sessions_mutex, !m_reconnections_mutex); void ThreadMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + + /// Whether we are currently advertising our I2P address (via `AddLocal()`). + bool m_i2p_advertising_listen_addr{false}; + + virtual void EventI2PListen(const CService& addr, bool success) override; + void ThreadI2PAcceptIncoming(); /** From 213ec07e7fa261fa1394992c2becc8f7f6d7ab4e Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 27 Aug 2024 16:23:31 +0200 Subject: [PATCH 12/30] net: move I2P-accept-incoming code from CConnman to SockMan --- src/CMakeLists.txt | 2 +- src/common/sockman.cpp | 55 +++++++++++++++++++++++++++++++++ src/common/sockman.h | 69 ++++++++++++++++++++++++++++++++++++++++++ src/net.cpp | 69 +++++++++--------------------------------- src/net.h | 31 ++++--------------- 5 files changed, 146 insertions(+), 80 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 40f4ec87e4f7e..a511215dc7886 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -131,6 +131,7 @@ add_library(bitcoin_common STATIC EXCLUDE_FROM_ALL core_write.cpp deploymentinfo.cpp external_signer.cpp + i2p.cpp init/common.cpp kernel/chainparams.cpp key.cpp @@ -209,7 +210,6 @@ add_library(bitcoin_node STATIC EXCLUDE_FROM_ALL headerssync.cpp httprpc.cpp httpserver.cpp - i2p.cpp index/base.cpp index/blockfilterindex.cpp index/coinstatsindex.cpp diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index b474986031d00..6d1a0ea1ef5a7 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -8,6 +8,7 @@ #include #include #include +#include bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) { @@ -88,6 +89,24 @@ bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) return true; } +void SockMan::StartSocketsThreads(const Options& options) +{ + if (options.i2p.has_value()) { + m_i2p_sam_session = std::make_unique( + options.i2p->private_key_file, options.i2p->sam_proxy, &interruptNet); + + m_thread_i2p_accept = + std::thread(&util::TraceThread, "i2paccept", [this] { ThreadI2PAccept(); }); + } +} + +void SockMan::JoinSocketsThreads() +{ + if (m_thread_i2p_accept.joinable()) { + m_thread_i2p_accept.join(); + } +} + std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) { sockaddr_storage storage; @@ -124,3 +143,39 @@ void SockMan::CloseSockets() } void SockMan::EventI2PListen(const CService&, bool) {} + +void SockMan::ThreadI2PAccept() +{ + static constexpr auto err_wait_begin = 1s; + static constexpr auto err_wait_cap = 5min; + auto err_wait = err_wait_begin; + + i2p::Connection conn; + + auto SleepOnFailure = [&]() { + interruptNet.sleep_for(err_wait); + if (err_wait < err_wait_cap) { + err_wait += 1s; + } + }; + + while (!interruptNet) { + + if (!m_i2p_sam_session->Listen(conn)) { + EventI2PListen(conn.me, /*success=*/false); + SleepOnFailure(); + continue; + } + + EventI2PListen(conn.me, /*success=*/true); + + if (!m_i2p_sam_session->Accept(conn)) { + SleepOnFailure(); + continue; + } + + EventNewConnectionAccepted(std::move(conn.sock), conn.me, conn.peer); + + err_wait = err_wait_begin; + } +} diff --git a/src/common/sockman.h b/src/common/sockman.h index 615971463f872..b51de9b68e095 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -5,12 +5,16 @@ #ifndef BITCOIN_COMMON_SOCKMAN_H #define BITCOIN_COMMON_SOCKMAN_H +#include #include +#include +#include #include #include #include #include +#include #include typedef int64_t NodeId; @@ -20,6 +24,7 @@ typedef int64_t NodeId; * To use this class, inherit from it and implement the pure virtual methods. * Handled operations: * - binding and listening on sockets + * - starting of necessary threads to process socket operations * - accepting incoming connections */ class SockMan @@ -34,6 +39,7 @@ class SockMan /** * Bind to a new address:port, start listening and add the listen socket to `m_listen`. + * Should be called before `StartSocketsThreads()`. * @param[in] to Where to bind. * @param[out] errmsg Error string if an error occurs. * @retval true Success. @@ -41,6 +47,33 @@ class SockMan */ bool BindAndStartListening(const CService& to, bilingual_str& errmsg); + /** + * Options to influence `StartSocketsThreads()`. + */ + struct Options { + struct I2P { + explicit I2P(const fs::path& file, const Proxy& proxy) : private_key_file{file}, sam_proxy{proxy} {} + + const fs::path private_key_file; + const Proxy sam_proxy; + }; + + /** + * I2P options. If set then a thread will be started that will accept incoming I2P connections. + */ + std::optional i2p; + }; + + /** + * Start the necessary threads for sockets IO. + */ + void StartSocketsThreads(const Options& options); + + /** + * Join (wait for) the threads started by `StartSocketsThreads()` to exit. + */ + void JoinSocketsThreads(); + /** * Accept a connection. * @param[in] listen_sock Socket on which to accept the connection. @@ -59,6 +92,21 @@ class SockMan */ void CloseSockets(); + /** + * This is signaled when network activity should cease. + * A pointer to it is saved in `m_i2p_sam_session`, so make sure that + * the lifetime of `interruptNet` is not shorter than + * the lifetime of `m_i2p_sam_session`. + */ + CThreadInterrupt interruptNet; + + /** + * I2P SAM session. + * Used to accept incoming and make outgoing I2P connections from a persistent + * address. + */ + std::unique_ptr m_i2p_sam_session; + /** * List of listening sockets. */ @@ -70,6 +118,16 @@ class SockMan // Pure virtual functions must be implemented by children classes. // + /** + * Be notified when a new connection has been accepted. + * @param[in] sock Connected socket to communicate with the peer. + * @param[in] me The address and port at our side of the connection. + * @param[in] them The address and port at the peer's side of the connection. + */ + virtual void EventNewConnectionAccepted(std::unique_ptr&& sock, + const CService& me, + const CService& them) = 0; + // // Non-pure virtual functions can be overridden by children classes or left // alone to use the default implementation from SockMan. @@ -86,10 +144,21 @@ class SockMan */ virtual void EventI2PListen(const CService& addr, bool success); + /** + * Accept incoming I2P connections in a loop and call + * `EventNewConnectionAccepted()` for each new connection. + */ + void ThreadI2PAccept(); + /** * The id to assign to the next created node. Used to generate ids of nodes. */ std::atomic m_next_node_id{0}; + + /** + * Thread that accepts incoming I2P connections in a loop, can be stopped via `interruptNet`. + */ + std::thread m_thread_i2p_accept; }; #endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index 5b0bc6f10c5a2..353b226eefaa0 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1714,9 +1714,9 @@ bool CConnman::AttemptToEvictConnection() return false; } -void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, - const CService& addr_bind, - const CService& addr) +void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, + const CService& addr_bind, + const CService& addr) { int nInbound = 0; @@ -2177,7 +2177,7 @@ void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock addr_accepted = MaybeFlipIPv6toCJDNS(addr_accepted); const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; - CreateNodeFromAcceptedSocket(std::move(sock_accepted), addr_bind, addr_accepted); + EventNewConnectionAccepted(std::move(sock_accepted), addr_bind, addr_accepted); } } } @@ -3039,42 +3039,6 @@ void CConnman::EventI2PListen(const CService& addr, bool success) } } -void CConnman::ThreadI2PAcceptIncoming() -{ - static constexpr auto err_wait_begin = 1s; - static constexpr auto err_wait_cap = 5min; - auto err_wait = err_wait_begin; - - i2p::Connection conn; - - auto SleepOnFailure = [&]() { - interruptNet.sleep_for(err_wait); - if (err_wait < err_wait_cap) { - err_wait += 1s; - } - }; - - while (!interruptNet) { - - if (!m_i2p_sam_session->Listen(conn)) { - EventI2PListen(conn.me, /*success=*/false); - SleepOnFailure(); - continue; - } - - EventI2PListen(conn.me, /*success=*/true); - - if (!m_i2p_sam_session->Accept(conn)) { - SleepOnFailure(); - continue; - } - - CreateNodeFromAcceptedSocket(std::move(conn.sock), conn.me, conn.peer); - - err_wait = err_wait_begin; - } -} - void Discover() { if (!fDiscover) @@ -3199,12 +3163,6 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) return false; } - Proxy i2p_sam; - if (GetProxy(NET_I2P, i2p_sam) && connOptions.m_i2p_accept_incoming) { - m_i2p_sam_session = std::make_unique(gArgs.GetDataDirNet() / "i2p_private_key", - i2p_sam, &interruptNet); - } - // Randomize the order in which we may query seednode to potentially prevent connecting to the same one every restart (and signal that we have restarted) std::vector seed_nodes = connOptions.vSeedNodes; if (!seed_nodes.empty()) { @@ -3250,6 +3208,15 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) // Send and receive from sockets, accept connections threadSocketHandler = std::thread(&util::TraceThread, "net", [this] { ThreadSocketHandler(); }); + SockMan::Options sockman_options; + + Proxy i2p_sam; + if (GetProxy(NET_I2P, i2p_sam) && connOptions.m_i2p_accept_incoming) { + sockman_options.i2p.emplace(gArgs.GetDataDirNet() / "i2p_private_key", i2p_sam); + } + + StartSocketsThreads(sockman_options); + if (!gArgs.GetBoolArg("-dnsseed", DEFAULT_DNSSEED)) LogPrintf("DNS seeding disabled\n"); else @@ -3275,11 +3242,6 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) // Process messages threadMessageHandler = std::thread(&util::TraceThread, "msghand", [this] { ThreadMessageHandler(); }); - if (m_i2p_sam_session) { - threadI2PAcceptIncoming = - std::thread(&util::TraceThread, "i2paccept", [this] { ThreadI2PAcceptIncoming(); }); - } - // Dump network addresses scheduler.scheduleEvery([this] { DumpAddresses(); }, DUMP_PEERS_INTERVAL); @@ -3333,9 +3295,8 @@ void CConnman::Interrupt() void CConnman::StopThreads() { - if (threadI2PAcceptIncoming.joinable()) { - threadI2PAcceptIncoming.join(); - } + JoinSocketsThreads(); + if (threadMessageHandler.joinable()) threadMessageHandler.join(); if (threadOpenConnections.joinable()) diff --git a/src/net.h b/src/net.h index e650d252e2465..b33dc5c418c1d 100644 --- a/src/net.h +++ b/src/net.h @@ -1290,18 +1290,15 @@ class CConnman : private SockMan virtual void EventI2PListen(const CService& addr, bool success) override; - void ThreadI2PAcceptIncoming(); - /** - * Create a `CNode` object from a socket that has just been accepted and add the node to - * the `m_nodes` member. + * Create a `CNode` object and add it to the `m_nodes` member. * @param[in] sock Connected socket to communicate with the peer. - * @param[in] addr_bind The address and port at our side of the connection. - * @param[in] addr The address and port at the peer's side of the connection. + * @param[in] me The address and port at our side of the connection. + * @param[in] them The address and port at the peer's side of the connection. */ - void CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, - const CService& addr_bind, - const CService& addr); + virtual void EventNewConnectionAccepted(std::unique_ptr&& sock, + const CService& me, + const CService& them) override; void DisconnectNodes() EXCLUSIVE_LOCKS_REQUIRED(!m_reconnections_mutex, !m_nodes_mutex); void NotifyNumConnectionsChanged(); @@ -1529,27 +1526,11 @@ class CConnman : private SockMan Mutex mutexMsgProc; std::atomic flagInterruptMsgProc{false}; - /** - * This is signaled when network activity should cease. - * A pointer to it is saved in `m_i2p_sam_session`, so make sure that - * the lifetime of `interruptNet` is not shorter than - * the lifetime of `m_i2p_sam_session`. - */ - CThreadInterrupt interruptNet; - - /** - * I2P SAM session. - * Used to accept incoming and make outgoing I2P connections from a persistent - * address. - */ - std::unique_ptr m_i2p_sam_session; - std::thread threadDNSAddressSeed; std::thread threadSocketHandler; std::thread threadOpenAddedConnections; std::thread threadOpenConnections; std::thread threadMessageHandler; - std::thread threadI2PAcceptIncoming; /** flag for deciding to connect to an extra outbound peer, * in excess of m_max_outbound_full_relay From 21c5e05619c8a6eb736bb1c61725f4b5f669ffb4 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 20 Sep 2024 13:27:28 +0200 Subject: [PATCH 13/30] net: index nodes in CConnman by id Change `CConnman::m_nodes` from `std::vector` to `std::unordered_map` because interaction between `CConnman` and `SockMan` is going to be based on `NodeId` and finding a node by its id would better be fast. As a nice side effect the existent search-by-id operations in `CConnman::AttemptToEvictConnection()`, `CConnman::DisconnectNode()` and `CConnman::ForNode()` now become `O(1)` (were `O(number of nodes)`), as well as the erase in `CConnman::DisconnectNodes()`. --- src/net.cpp | 146 +++++++++++++------------ src/net.h | 12 +- src/net_processing.cpp | 13 ++- src/rpc/net.cpp | 4 + src/test/fuzz/connman.cpp | 5 +- src/test/net_peer_connection_tests.cpp | 10 +- src/test/util/net.h | 6 +- 7 files changed, 106 insertions(+), 90 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 353b226eefaa0..feb80adfdd8ed 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -331,7 +331,7 @@ bool IsLocal(const CService& addr) CNode* CConnman::FindNode(const CNetAddr& ip) { LOCK(m_nodes_mutex); - for (CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (static_cast(pnode->addr) == ip) { return pnode; } @@ -342,7 +342,7 @@ CNode* CConnman::FindNode(const CNetAddr& ip) CNode* CConnman::FindNode(const std::string& addrName) { LOCK(m_nodes_mutex); - for (CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->m_addr_name == addrName) { return pnode; } @@ -353,7 +353,7 @@ CNode* CConnman::FindNode(const std::string& addrName) CNode* CConnman::FindNode(const CService& addr) { LOCK(m_nodes_mutex); - for (CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (static_cast(pnode->addr) == addr) { return pnode; } @@ -369,7 +369,7 @@ bool CConnman::AlreadyConnectedToAddress(const CAddress& addr) bool CConnman::CheckIncomingNonce(uint64_t nonce) { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (!pnode->fSuccessfullyConnected && !pnode->IsInboundConn() && pnode->GetLocalNonce() == nonce) return false; } @@ -1677,11 +1677,11 @@ bool CConnman::AttemptToEvictConnection() { LOCK(m_nodes_mutex); - for (const CNode* node : m_nodes) { + for (const auto& [id, node] : m_nodes) { if (node->fDisconnect) continue; NodeEvictionCandidate candidate{ - .id = node->GetId(), + .id = id, .m_connected = node->m_connected, .m_min_ping_time = node->m_min_ping_time, .m_last_block_time = node->m_last_block_time, @@ -1704,12 +1704,13 @@ bool CConnman::AttemptToEvictConnection() return false; } LOCK(m_nodes_mutex); - for (CNode* pnode : m_nodes) { - if (pnode->GetId() == *node_id_to_evict) { - LogDebug(BCLog::NET, "selected %s connection for eviction peer=%d; disconnecting\n", pnode->ConnectionTypeAsString(), pnode->GetId()); - pnode->fDisconnect = true; - return true; - } + auto it{m_nodes.find(*node_id_to_evict)}; + if (it != m_nodes.end()) { + auto id{it->first}; + auto node{it->second}; + LogDebug(BCLog::NET, "selected %s connection for eviction peer=%d; disconnecting\n", node->ConnectionTypeAsString(), id); + node->fDisconnect = true; + return true; } return false; } @@ -1730,7 +1731,7 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->IsInboundConn()) nInbound++; } } @@ -1806,7 +1807,7 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, m_msgproc->InitializeNode(*pnode, local_services); { LOCK(m_nodes_mutex); - m_nodes.push_back(pnode); + m_nodes.emplace(id, pnode); } LogDebug(BCLog::NET, "connection from %s accepted\n", addr.ToStringAddrPort()); @@ -1837,8 +1838,11 @@ bool CConnman::AddConnection(const std::string& address, ConnectionType conn_typ } // no default case, so the compiler can warn about missing cases // Count existing connections - int existing_connections = WITH_LOCK(m_nodes_mutex, - return std::count_if(m_nodes.begin(), m_nodes.end(), [conn_type](CNode* node) { return node->m_conn_type == conn_type; });); + int existing_connections = WITH_LOCK( + m_nodes_mutex, return std::count_if(m_nodes.begin(), m_nodes.end(), [conn_type](const auto& pair) { + const auto node{pair.second}; + return node->m_conn_type == conn_type; + });); // Max connections of specified type already exist if (max_connections != std::nullopt && existing_connections >= max_connections) return false; @@ -1865,7 +1869,7 @@ void CConnman::DisconnectNodes() if (!fNetworkActive) { // Disconnect any connected nodes - for (CNode* pnode : m_nodes) { + for (auto& [id, pnode] : m_nodes) { if (!pnode->fDisconnect) { LogDebug(BCLog::NET, "Network not active, %s\n", pnode->DisconnectMsg(fLogIPs)); pnode->fDisconnect = true; @@ -1874,40 +1878,42 @@ void CConnman::DisconnectNodes() } // Disconnect unused nodes - std::vector nodes_copy = m_nodes; - for (CNode* pnode : nodes_copy) - { - if (pnode->fDisconnect) - { - // remove from m_nodes - m_nodes.erase(remove(m_nodes.begin(), m_nodes.end(), pnode), m_nodes.end()); - - // Add to reconnection list if appropriate. We don't reconnect right here, because - // the creation of a connection is a blocking operation (up to several seconds), - // and we don't want to hold up the socket handler thread for that long. - if (pnode->m_transport->ShouldReconnectV1()) { - reconnections_to_add.push_back({ - .addr_connect = pnode->addr, - .grant = std::move(pnode->grantOutbound), - .destination = pnode->m_dest, - .conn_type = pnode->m_conn_type, - .use_v2transport = false}); - LogDebug(BCLog::NET, "retrying with v1 transport protocol for peer=%d\n", pnode->GetId()); - } + for (auto it{m_nodes.begin()}; it != m_nodes.end();) { + auto id{it->first}; + auto pnode{it->second}; - // release outbound grant (if any) - pnode->grantOutbound.Release(); + if (!pnode->fDisconnect) { + ++it; + continue; + } - // close socket and cleanup - pnode->CloseSocketDisconnect(); + it = m_nodes.erase(it); + + // Add to reconnection list if appropriate. We don't reconnect right here, because + // the creation of a connection is a blocking operation (up to several seconds), + // and we don't want to hold up the socket handler thread for that long. + if (pnode->m_transport->ShouldReconnectV1()) { + reconnections_to_add.push_back({ + .addr_connect = pnode->addr, + .grant = std::move(pnode->grantOutbound), + .destination = pnode->m_dest, + .conn_type = pnode->m_conn_type, + .use_v2transport = false}); + LogDebug(BCLog::NET, "retrying with v1 transport protocol for peer=%d\n", id); + } - // update connection count by network - if (pnode->IsManualOrFullOutboundConn()) --m_network_conn_counts[pnode->addr.GetNetwork()]; + // release outbound grant (if any) + pnode->grantOutbound.Release(); - // hold in disconnected pool until all refs are released - pnode->Release(); - m_nodes_disconnected.push_back(pnode); - } + // close socket and cleanup + pnode->CloseSocketDisconnect(); + + // update connection count by network + if (pnode->IsManualOrFullOutboundConn()) --m_network_conn_counts[pnode->addr.GetNetwork()]; + + // hold in disconnected pool until all refs are released + pnode->Release(); + m_nodes_disconnected.push_back(pnode); } } { @@ -2397,7 +2403,7 @@ int CConnman::GetFullOutboundConnCount() const int nRelevant = 0; { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->fSuccessfullyConnected && pnode->IsFullOutboundConn()) ++nRelevant; } } @@ -2415,7 +2421,7 @@ int CConnman::GetExtraFullOutboundCount() const int full_outbound_peers = 0; { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->fSuccessfullyConnected && !pnode->fDisconnect && pnode->IsFullOutboundConn()) { ++full_outbound_peers; } @@ -2429,7 +2435,7 @@ int CConnman::GetExtraBlockRelayCount() const int block_relay_peers = 0; { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->fSuccessfullyConnected && !pnode->fDisconnect && pnode->IsBlockOnlyConn()) { ++block_relay_peers; } @@ -2600,7 +2606,7 @@ void CConnman::ThreadOpenConnections(const std::vector connect, Spa { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->IsFullOutboundConn()) nOutboundFullRelay++; if (pnode->IsBlockOnlyConn()) nOutboundBlockRelay++; @@ -2845,7 +2851,7 @@ std::vector CConnman::GetCurrentBlockRelayOnlyConns() const { std::vector ret; LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->IsBlockOnlyConn()) { ret.push_back(pnode->addr); } @@ -2871,7 +2877,7 @@ std::vector CConnman::GetAddedNodeInfo(bool include_connected) co std::map> mapConnectedByName; { LOCK(m_nodes_mutex); - for (const CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (pnode->addr.IsValid()) { mapConnected[pnode->addr] = pnode->IsInboundConn(); } @@ -2975,7 +2981,7 @@ void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFai m_msgproc->InitializeNode(*pnode, m_local_services); { LOCK(m_nodes_mutex); - m_nodes.push_back(pnode); + m_nodes.emplace(pnode->GetId(), pnode); // update connection count by network if (pnode->IsManualOrFullOutboundConn()) ++m_network_conn_counts[pnode->addr.GetNetwork()]; @@ -3326,9 +3332,9 @@ void CConnman::StopNodes() } // Delete peer connections. - std::vector nodes; + decltype(m_nodes) nodes; WITH_LOCK(m_nodes_mutex, nodes.swap(m_nodes)); - for (CNode* pnode : nodes) { + for (auto& [id, pnode] : nodes) { LogDebug(BCLog::NET, "%s\n", pnode->DisconnectMsg(fLogIPs)); pnode->CloseSocketDisconnect(); DeleteNode(pnode); @@ -3457,7 +3463,7 @@ size_t CConnman::GetNodeCount(ConnectionDirection flags) const return m_nodes.size(); int nNum = 0; - for (const auto& pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { if (flags & (pnode->IsInboundConn() ? ConnectionDirection::In : ConnectionDirection::Out)) { nNum++; } @@ -3483,7 +3489,7 @@ void CConnman::GetNodeStats(std::vector& vstats) const vstats.clear(); LOCK(m_nodes_mutex); vstats.reserve(m_nodes.size()); - for (CNode* pnode : m_nodes) { + for (const auto& [id, pnode] : m_nodes) { vstats.emplace_back(); pnode->CopyStats(vstats.back()); vstats.back().m_mapped_as = GetMappedAS(pnode->addr); @@ -3505,7 +3511,7 @@ bool CConnman::DisconnectNode(const CSubNet& subnet) { bool disconnected = false; LOCK(m_nodes_mutex); - for (CNode* pnode : m_nodes) { + for (auto& [id, pnode] : m_nodes) { if (subnet.Match(pnode->addr)) { LogDebug(BCLog::NET, "disconnect by subnet%s matched peer=%d; disconnecting\n", (fLogIPs ? strprintf("=%s", subnet.ToString()) : ""), pnode->GetId()); pnode->fDisconnect = true; @@ -3523,14 +3529,14 @@ bool CConnman::DisconnectNode(const CNetAddr& addr) bool CConnman::DisconnectNode(NodeId id) { LOCK(m_nodes_mutex); - for(CNode* pnode : m_nodes) { - if (id == pnode->GetId()) { - LogDebug(BCLog::NET, "disconnect by id peer=%d; disconnecting\n", pnode->GetId()); - pnode->fDisconnect = true; - return true; - } + auto it{m_nodes.find(id)}; + if (it == m_nodes.end()) { + return false; } - return false; + auto node{it->second}; + LogDebug(BCLog::NET, "disconnect by id peer=%d; disconnecting\n", id); + node->fDisconnect = true; + return true; } void CConnman::RecordBytesRecv(uint64_t bytes) @@ -3775,11 +3781,9 @@ bool CConnman::ForNode(NodeId id, std::function func) { CNode* found = nullptr; LOCK(m_nodes_mutex); - for (auto&& pnode : m_nodes) { - if(pnode->GetId() == id) { - found = pnode; - break; - } + auto it{m_nodes.find(id)}; + if (it != m_nodes.end()) { + found = it->second; } return found != nullptr && NodeFullyConnected(found) && func(found); } diff --git a/src/net.h b/src/net.h index b33dc5c418c1d..81e376f91712c 100644 --- a/src/net.h +++ b/src/net.h @@ -44,6 +44,7 @@ #include #include #include +#include #include #include @@ -1150,7 +1151,7 @@ class CConnman : private SockMan void ForEachNode(const NodeFn& func) { LOCK(m_nodes_mutex); - for (auto&& node : m_nodes) { + for (auto& [id, node] : m_nodes) { if (NodeFullyConnected(node)) func(node); } @@ -1159,7 +1160,7 @@ class CConnman : private SockMan void ForEachNode(const NodeFn& func) const { LOCK(m_nodes_mutex); - for (auto&& node : m_nodes) { + for (auto& [id, node] : m_nodes) { if (NodeFullyConnected(node)) func(node); } @@ -1429,7 +1430,7 @@ class CConnman : private SockMan std::vector m_added_node_params GUARDED_BY(m_added_nodes_mutex); mutable Mutex m_added_nodes_mutex; - std::vector m_nodes GUARDED_BY(m_nodes_mutex); + std::unordered_map m_nodes GUARDED_BY(m_nodes_mutex); std::list m_nodes_disconnected; mutable RecursiveMutex m_nodes_mutex; unsigned int nPrevNodeCount{0}; @@ -1615,8 +1616,9 @@ class CConnman : private SockMan { { LOCK(connman.m_nodes_mutex); - m_nodes_copy = connman.m_nodes; - for (auto& node : m_nodes_copy) { + m_nodes_copy.reserve(connman.m_nodes.size()); + for (auto& [in, node] : connman.m_nodes) { + m_nodes_copy.push_back(node); node->AddRef(); } } diff --git a/src/net_processing.cpp b/src/net_processing.cpp index a19443c0f5615..1e9dba756e02d 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -5077,10 +5077,15 @@ void PeerManagerImpl::EvictExtraOutboundPeers(std::chrono::seconds now) m_connman.ForEachNode([&](CNode* pnode) { if (!pnode->IsBlockOnlyConn() || pnode->fDisconnect) return; - if (pnode->GetId() > youngest_peer.first) { - next_youngest_peer = youngest_peer; - youngest_peer.first = pnode->GetId(); - youngest_peer.second = pnode->m_last_block_time; + if (pnode->GetId() > next_youngest_peer.first) { + if (pnode->GetId() > youngest_peer.first) { + next_youngest_peer = youngest_peer; + youngest_peer.first = pnode->GetId(); + youngest_peer.second = pnode->m_last_block_time; + } else { + next_youngest_peer.first = pnode->GetId(); + next_youngest_peer.second = pnode->m_last_block_time; + } } }); NodeId to_disconnect = youngest_peer.first; diff --git a/src/rpc/net.cpp b/src/rpc/net.cpp index bda07365e0e76..26ab94fa9374e 100644 --- a/src/rpc/net.cpp +++ b/src/rpc/net.cpp @@ -203,6 +203,10 @@ static RPCHelpMan getpeerinfo() std::vector vstats; connman.GetNodeStats(vstats); + std::sort(vstats.begin(), vstats.end(), [](const CNodeStats& a, const CNodeStats& b) { + return a.nodeid < b.nodeid; + }); + UniValue ret(UniValue::VARR); for (const CNodeStats& stats : vstats) { diff --git a/src/test/fuzz/connman.cpp b/src/test/fuzz/connman.cpp index a62d227da8efb..5d2bdaf98b591 100644 --- a/src/test/fuzz/connman.cpp +++ b/src/test/fuzz/connman.cpp @@ -64,12 +64,13 @@ FUZZ_TARGET(connman, .init = initialize_connman) connman.Init(options); CNetAddr random_netaddr; - CNode random_node = ConsumeNode(fuzzed_data_provider); + NodeId node_id{0}; + CNode random_node = ConsumeNode(fuzzed_data_provider, node_id++); CSubNet random_subnet; std::string random_string; LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 100) { - CNode& p2p_node{*ConsumeNodeAsUniquePtr(fuzzed_data_provider).release()}; + CNode& p2p_node{*ConsumeNodeAsUniquePtr(fuzzed_data_provider, node_id++).release()}; connman.AddTestNode(p2p_node); } diff --git a/src/test/net_peer_connection_tests.cpp b/src/test/net_peer_connection_tests.cpp index e60ce8b99d360..33e8a7cc07ad9 100644 --- a/src/test/net_peer_connection_tests.cpp +++ b/src/test/net_peer_connection_tests.cpp @@ -117,9 +117,9 @@ BOOST_FIXTURE_TEST_CASE(test_addnode_getaddednodeinfo_and_connection_detection, BOOST_CHECK_EQUAL(nodes.back()->ConnectedThroughNetwork(), Network::NET_CJDNS); BOOST_TEST_MESSAGE("Call AddNode() for all the peers"); - for (auto node : connman->TestNodes()) { + for (const auto& [id, node] : connman->TestNodes()) { BOOST_CHECK(connman->AddNode({/*m_added_node=*/node->addr.ToStringAddrPort(), /*m_use_v2transport=*/true})); - BOOST_TEST_MESSAGE(strprintf("peer id=%s addr=%s", node->GetId(), node->addr.ToStringAddrPort())); + BOOST_TEST_MESSAGE(strprintf("peer id=%s addr=%s", id, node->addr.ToStringAddrPort())); } BOOST_TEST_MESSAGE("\nCall AddNode() with 2 addrs resolving to existing localhost addnode entry; neither should be added"); @@ -134,7 +134,7 @@ BOOST_FIXTURE_TEST_CASE(test_addnode_getaddednodeinfo_and_connection_detection, BOOST_CHECK(connman->GetAddedNodeInfo(/*include_connected=*/false).empty()); // Test AddedNodesContain() - for (auto node : connman->TestNodes()) { + for (const auto& [id, node] : connman->TestNodes()) { BOOST_CHECK(connman->AddedNodesContain(node->addr)); } AddPeer(id, nodes, *peerman, *connman, ConnectionType::OUTBOUND_FULL_RELAY); @@ -151,12 +151,12 @@ BOOST_FIXTURE_TEST_CASE(test_addnode_getaddednodeinfo_and_connection_detection, } BOOST_TEST_MESSAGE("\nCheck that all connected peers are correctly detected as connected"); - for (auto node : connman->TestNodes()) { + for (const auto& [id, node] : connman->TestNodes()) { BOOST_CHECK(connman->AlreadyConnectedPublic(node->addr)); } // Clean up - for (auto node : connman->TestNodes()) { + for (const auto& [id, node] : connman->TestNodes()) { peerman->FinalizeNode(*node); } connman->ClearTestNodes(); diff --git a/src/test/util/net.h b/src/test/util/net.h index d3aefda4f0e70..99872508363af 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -42,7 +42,7 @@ struct ConnmanTestMsg : public CConnman { m_peer_connect_timeout = timeout; } - std::vector TestNodes() + auto TestNodes() { LOCK(m_nodes_mutex); return m_nodes; @@ -51,7 +51,7 @@ struct ConnmanTestMsg : public CConnman { void AddTestNode(CNode& node) { LOCK(m_nodes_mutex); - m_nodes.push_back(&node); + m_nodes.emplace(node.GetId(), &node); if (node.IsManualOrFullOutboundConn()) ++m_network_conn_counts[node.addr.GetNetwork()]; } @@ -59,7 +59,7 @@ struct ConnmanTestMsg : public CConnman { void ClearTestNodes() { LOCK(m_nodes_mutex); - for (CNode* node : m_nodes) { + for (auto& [id, node] : m_nodes) { delete node; } m_nodes.clear(); From dc6393cb93c4851a363b69fd474656cac1ae3b3b Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Sat, 21 Sep 2024 10:05:34 +0200 Subject: [PATCH 14/30] net: isolate P2P specifics from GenerateWaitSockets() Move the parts of `CConnman::GenerateWaitSockets()` that are specific to the Bitcoin-P2P protocol to dedicated methods: `ShouldTryToSend()` and `ShouldTryToRecv()`. This brings us one step closer to moving `GenerateWaitSockets()` to the protocol agnostic `SockMan` (which would call `ShouldTry...()` from `CConnman`). --- src/common/sockman.cpp | 4 ++++ src/common/sockman.h | 16 +++++++++++++ src/net.cpp | 52 ++++++++++++++++++++++++++++++++++-------- src/net.h | 14 ++++++++++-- 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 6d1a0ea1ef5a7..4d9db32bfd5aa 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -142,6 +142,10 @@ void SockMan::CloseSockets() m_listen.clear(); } +bool SockMan::ShouldTryToSend(NodeId node_id) const { return true; } + +bool SockMan::ShouldTryToRecv(NodeId node_id) const { return true; } + void SockMan::EventI2PListen(const CService&, bool) {} void SockMan::ThreadI2PAccept() diff --git a/src/common/sockman.h b/src/common/sockman.h index b51de9b68e095..e030b91dd4905 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -133,6 +133,22 @@ class SockMan // alone to use the default implementation from SockMan. // + /** + * SockMan would only call EventReadyToSend() if this returns true. + * Can be used to temporarily pause sends for a node. + * The implementation in SockMan always returns true. + * @param[in] node_id Node for which to confirm or cancel a call to EventReadyToSend(). + */ + virtual bool ShouldTryToSend(NodeId node_id) const; + + /** + * SockMan would only call Recv() on a node's socket if this returns true. + * Can be used to temporarily pause receives for a node. + * The implementation in SockMan always returns true. + * @param[in] node_id Node for which to confirm or cancel a receive. + */ + virtual bool ShouldTryToRecv(NodeId node_id) const; + /** * Be notified of a change in the state of listening for incoming I2P connections. * The default behavior, implemented by `SockMan`, is to ignore this event. diff --git a/src/net.cpp b/src/net.cpp index feb80adfdd8ed..0b3f577e99e59 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1663,6 +1663,16 @@ std::pair CConnman::SocketSendData(CNode& node) const return {nSentSize, data_left}; } +CNode* CConnman::GetNodeById(NodeId node_id) const +{ + LOCK(m_nodes_mutex); + auto it{m_nodes.find(node_id)}; + if (it != m_nodes.end()) { + return it->second; + } + return nullptr; +} + /** Try to find a connection to evict when the node is full. * Extreme care must be taken to avoid opening the node to attacker * triggered network partitioning. @@ -2009,8 +2019,37 @@ bool CConnman::InactivityCheck(const CNode& node) const return false; } +bool CConnman::ShouldTryToSend(NodeId node_id) const +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return false; + } + LOCK(node->cs_vSend); + // Sending is possible if either there are bytes to send right now, or if there will be + // once a potential message from vSendMsg is handed to the transport. GetBytesToSend + // determines both of these in a single call. + const auto& [to_send, more, _msg_type] = node->m_transport->GetBytesToSend(!node->vSendMsg.empty()); + return !to_send.empty() || more; +} + +bool CConnman::ShouldTryToRecv(NodeId node_id) const +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return false; + } + return !node->fPauseRecv; +} + Sock::EventsPerSock CConnman::GenerateWaitSockets(Span nodes) { + AssertLockNotHeld(m_nodes_mutex); + Sock::EventsPerSock events_per_sock; for (const auto& sock : m_listen) { @@ -2018,16 +2057,8 @@ Sock::EventsPerSock CConnman::GenerateWaitSockets(Span nodes) } for (CNode* pnode : nodes) { - bool select_recv = !pnode->fPauseRecv; - bool select_send; - { - LOCK(pnode->cs_vSend); - // Sending is possible if either there are bytes to send right now, or if there will be - // once a potential message from vSendMsg is handed to the transport. GetBytesToSend - // determines both of these in a single call. - const auto& [to_send, more, _msg_type] = pnode->m_transport->GetBytesToSend(!pnode->vSendMsg.empty()); - select_send = !to_send.empty() || more; - } + const bool select_recv{ShouldTryToRecv(pnode->GetId())}; + const bool select_send{ShouldTryToSend(pnode->GetId())}; if (!select_recv && !select_send) continue; LOCK(pnode->m_sock_mutex); @@ -2191,6 +2222,7 @@ void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock void CConnman::ThreadSocketHandler() { + AssertLockNotHeld(m_nodes_mutex); AssertLockNotHeld(m_total_bytes_sent_mutex); while (!interruptNet) diff --git a/src/net.h b/src/net.h index 81e376f91712c..e9526a55d54f5 100644 --- a/src/net.h +++ b/src/net.h @@ -1306,17 +1306,25 @@ class CConnman : private SockMan /** Return true if the peer is inactive and should be disconnected. */ bool InactivityCheck(const CNode& node) const; + virtual bool ShouldTryToSend(NodeId node_id) const override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + + virtual bool ShouldTryToRecv(NodeId node_id) const override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + /** * Generate a collection of sockets to check for IO readiness. * @param[in] nodes Select from these nodes' sockets. * @return sockets to check for readiness */ - Sock::EventsPerSock GenerateWaitSockets(Span nodes); + Sock::EventsPerSock GenerateWaitSockets(Span nodes) + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); /** * Check connected and listening sockets for IO readiness and process them accordingly. */ - void SocketHandler() EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); + void SocketHandler() + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_total_bytes_sent_mutex, !mutexMsgProc); /** * Do the read/write for connected sockets that are ready for IO. @@ -1429,6 +1437,8 @@ class CConnman : private SockMan // connection string and whether to use v2 p2p std::vector m_added_node_params GUARDED_BY(m_added_nodes_mutex); + CNode* GetNodeById(NodeId node_id) const EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + mutable Mutex m_added_nodes_mutex; std::unordered_map m_nodes GUARDED_BY(m_nodes_mutex); std::list m_nodes_disconnected; From 7a4a9307632ad95a32a5b236ba575c68a44cfa57 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Sat, 21 Sep 2024 10:31:53 +0200 Subject: [PATCH 15/30] net: isolate P2P specifics from SocketHandlerConnected() and ThreadSocketHandler() Move some parts of `CConnman::SocketHandlerConnected()` and `CConnman::ThreadSocketHandler()` that are specific to the Bitcoin-P2P protocol to dedicated methods: `EventIOLoopCompletedForNode()` and `EventIOLoopCompletedForAllPeers()`. This brings us one step closer to moving `SocketHandlerConnected()` and `ThreadSocketHandler()` to the protocol agnostic `SockMan` (which would call `EventIOLoopCompleted...()` from `CConnman`). --- src/common/sockman.cpp | 4 ++++ src/common/sockman.h | 17 +++++++++++++++++ src/net.cpp | 29 ++++++++++++++++++++++++++--- src/net.h | 8 +++++++- 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 4d9db32bfd5aa..0a9aa1a72904f 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -146,6 +146,10 @@ bool SockMan::ShouldTryToSend(NodeId node_id) const { return true; } bool SockMan::ShouldTryToRecv(NodeId node_id) const { return true; } +void SockMan::EventIOLoopCompletedForNode(NodeId node_id) {} + +void SockMan::EventIOLoopCompletedForAllPeers() {} + void SockMan::EventI2PListen(const CService&, bool) {} void SockMan::ThreadI2PAccept() diff --git a/src/common/sockman.h b/src/common/sockman.h index e030b91dd4905..3a7eb91819fc2 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -149,6 +149,23 @@ class SockMan */ virtual bool ShouldTryToRecv(NodeId node_id) const; + /** + * SockMan has completed the current send+recv iteration for a node. + * It will do another send+recv for this node after processing all other nodes. + * Can be used to execute periodic tasks for a given node. + * The implementation in SockMan does nothing. + * @param[in] node_id Node for which send+recv has been done. + */ + virtual void EventIOLoopCompletedForNode(NodeId node_id); + + /** + * SockMan has completed send+recv for all nodes. + * Can be used to execute periodic tasks for all nodes, like disconnecting + * nodes due to higher level logic. + * The implementation in SockMan does nothing. + */ + virtual void EventIOLoopCompletedForAllPeers(); + /** * Be notified of a change in the state of listening for incoming I2P connections. * The default behavior, implemented by `SockMan`, is to ignore this event. diff --git a/src/net.cpp b/src/net.cpp index 0b3f577e99e59..821ab343f159f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -2046,6 +2046,29 @@ bool CConnman::ShouldTryToRecv(NodeId node_id) const return !node->fPauseRecv; } +void CConnman::EventIOLoopCompletedForNode(NodeId node_id) +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return; + } + + if (InactivityCheck(*node)) { + node->fDisconnect = true; + } +} + +void CConnman::EventIOLoopCompletedForAllPeers() +{ + AssertLockNotHeld(m_nodes_mutex); + AssertLockNotHeld(m_reconnections_mutex); + + DisconnectNodes(); + NotifyNumConnectionsChanged(); +} + Sock::EventsPerSock CConnman::GenerateWaitSockets(Span nodes) { AssertLockNotHeld(m_nodes_mutex); @@ -2102,6 +2125,7 @@ void CConnman::SocketHandler() void CConnman::SocketHandlerConnected(const std::vector& nodes, const Sock::EventsPerSock& events_per_sock) { + AssertLockNotHeld(m_nodes_mutex); AssertLockNotHeld(m_total_bytes_sent_mutex); for (CNode* pnode : nodes) { @@ -2194,7 +2218,7 @@ void CConnman::SocketHandlerConnected(const std::vector& nodes, } } - if (InactivityCheck(*pnode)) pnode->fDisconnect = true; + EventIOLoopCompletedForNode(pnode->GetId()); } } @@ -2227,8 +2251,7 @@ void CConnman::ThreadSocketHandler() while (!interruptNet) { - DisconnectNodes(); - NotifyNumConnectionsChanged(); + EventIOLoopCompletedForAllPeers(); SocketHandler(); } } diff --git a/src/net.h b/src/net.h index e9526a55d54f5..d9b2f4d2f0de3 100644 --- a/src/net.h +++ b/src/net.h @@ -1312,6 +1312,12 @@ class CConnman : private SockMan virtual bool ShouldTryToRecv(NodeId node_id) const override EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + virtual void EventIOLoopCompletedForNode(NodeId node_id) override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + + virtual void EventIOLoopCompletedForAllPeers() override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_reconnections_mutex); + /** * Generate a collection of sockets to check for IO readiness. * @param[in] nodes Select from these nodes' sockets. @@ -1333,7 +1339,7 @@ class CConnman : private SockMan */ void SocketHandlerConnected(const std::vector& nodes, const Sock::EventsPerSock& events_per_sock) - EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_total_bytes_sent_mutex, !mutexMsgProc); /** * Accept incoming connections, one from each read-ready listening socket. From 5eae99d1224f8972b8961cbea397ce9a147cfe1e Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Sun, 22 Sep 2024 12:11:42 +0200 Subject: [PATCH 16/30] net: isolate all remaining P2P specifics from SocketHandlerConnected() Introduce 4 new methods for the interaction between `CConnman` and `SockMan`: * `EventReadyToSend()`: called when there is readiness to send and do the actual sending of data. * `EventGotData()`, `EventGotEOF()`, `EventGotPermanentReadError()`: called when the corresponing recv events occur. These methods contain logic that is specific to the Bitcoin-P2P protocol and move it away from `CConnman::SocketHandlerConnected()` which will become a protocol agnostic method of `SockMan`. Also, move the counting of sent bytes to `CConnman::SocketSendData()` - both callers of that method called `RecordBytesSent()` just after the call, so move it from the callers to inside `CConnman::SocketSendData()`. --- src/common/sockman.h | 33 +++++++++++ src/net.cpp | 134 ++++++++++++++++++++++++++++++------------- src/net.h | 15 ++++- 3 files changed, 142 insertions(+), 40 deletions(-) diff --git a/src/common/sockman.h b/src/common/sockman.h index 3a7eb91819fc2..af089243e9d99 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -128,6 +128,39 @@ class SockMan const CService& me, const CService& them) = 0; + /** + * Called when the socket is ready to send data and `ShouldTryToSend()` has + * returned true. This is where the higher level code serializes its messages + * and calls `SockMan::SendBytes()`. + * @param[in] node_id Id of the node whose socket is ready to send. + * @param[out] cancel_recv Should always be set upon return and if it is true, + * then the next attempt to receive data from that node will be omitted. + */ + virtual void EventReadyToSend(NodeId node_id, bool& cancel_recv) = 0; + + /** + * Called when new data has been received. + * @param[in] node_id Node for which the data arrived. + * @param[in] data Data buffer. + * @param[in] n Number of bytes in `data`. + */ + virtual void EventGotData(NodeId node_id, const uint8_t* data, size_t n) = 0; + + /** + * Called when the remote peer has sent an EOF on the socket. This is a graceful + * close of their writing side, we can still send and they will receive, if it + * makes sense at the application level. + * @param[in] node_id Node whose socket got EOF. + */ + virtual void EventGotEOF(NodeId node_id) = 0; + + /** + * Called when we get an irrecoverable error trying to read from a socket. + * @param[in] node_id Node whose socket got an error. + * @param[in] errmsg Message describing the error. + */ + virtual void EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) = 0; + // // Non-pure virtual functions can be overridden by children classes or left // alone to use the default implementation from SockMan. diff --git a/src/net.cpp b/src/net.cpp index 821ab343f159f..15e2c233d1ddb 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1584,8 +1584,10 @@ Transport::Info V2Transport::GetInfo() const noexcept return info; } -std::pair CConnman::SocketSendData(CNode& node) const +std::pair CConnman::SocketSendData(CNode& node) { + AssertLockNotHeld(m_total_bytes_sent_mutex); + auto it = node.vSendMsg.begin(); size_t nSentSize = 0; bool data_left{false}; //!< second return value (whether unsent data remains) @@ -1660,6 +1662,11 @@ std::pair CConnman::SocketSendData(CNode& node) const assert(node.m_send_memusage == 0); } node.vSendMsg.erase(node.vSendMsg.begin(), it); + + if (nSentSize > 0) { + RecordBytesSent(nSentSize); + } + return {nSentSize, data_left}; } @@ -2019,6 +2026,83 @@ bool CConnman::InactivityCheck(const CNode& node) const return false; } +void CConnman::EventReadyToSend(NodeId node_id, bool& cancel_recv) +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + cancel_recv = true; + return; + } + + const auto [bytes_sent, data_left] = WITH_LOCK(node->cs_vSend, return SocketSendData(*node);); + + // If both receiving and (non-optimistic) sending were possible, we first attempt + // sending. If that succeeds, but does not fully drain the send queue, do not + // attempt to receive. This avoids needlessly queueing data if the remote peer + // is slow at receiving data, by means of TCP flow control. We only do this when + // sending actually succeeded to make sure progress is always made; otherwise a + // deadlock would be possible when both sides have data to send, but neither is + // receiving. + cancel_recv = bytes_sent > 0 && data_left; +} + +void CConnman::EventGotData(NodeId node_id, const uint8_t* data, size_t n) +{ + AssertLockNotHeld(mutexMsgProc); + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return; + } + + bool notify = false; + if (!node->ReceiveMsgBytes({data, n}, notify)) { + LogDebug(BCLog::NET, + "receiving message bytes failed, %s\n", + node->DisconnectMsg(fLogIPs) + ); + node->CloseSocketDisconnect(); + } + RecordBytesRecv(n); + if (notify) { + node->MarkReceivedMsgsForProcessing(); + WakeMessageHandler(); + } +} + +void CConnman::EventGotEOF(NodeId node_id) +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return; + } + + if (!node->fDisconnect) { + LogDebug(BCLog::NET, "socket closed for peer=%d\n", node_id); + } + node->CloseSocketDisconnect(); +} + +void CConnman::EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) +{ + AssertLockNotHeld(m_nodes_mutex); + + CNode* node{GetNodeById(node_id)}; + if (node == nullptr) { + return; + } + + if (!node->fDisconnect) { + LogDebug(BCLog::NET, "socket recv error for peer=%d: %s\n", node_id, errmsg); + } + node->CloseSocketDisconnect(); +} + bool CConnman::ShouldTryToSend(NodeId node_id) const { AssertLockNotHeld(m_nodes_mutex); @@ -2152,19 +2236,12 @@ void CConnman::SocketHandlerConnected(const std::vector& nodes, } if (sendSet) { - // Send data - auto [bytes_sent, data_left] = WITH_LOCK(pnode->cs_vSend, return SocketSendData(*pnode)); - if (bytes_sent) { - RecordBytesSent(bytes_sent); - - // If both receiving and (non-optimistic) sending were possible, we first attempt - // sending. If that succeeds, but does not fully drain the send queue, do not - // attempt to receive. This avoids needlessly queueing data if the remote peer - // is slow at receiving data, by means of TCP flow control. We only do this when - // sending actually succeeded to make sure progress is always made; otherwise a - // deadlock would be possible when both sides have data to send, but neither is - // receiving. - if (data_left) recvSet = false; + bool cancel_recv; + + EventReadyToSend(pnode->GetId(), cancel_recv); + + if (cancel_recv) { + recvSet = false; } } @@ -2182,27 +2259,11 @@ void CConnman::SocketHandlerConnected(const std::vector& nodes, } if (nBytes > 0) { - bool notify = false; - if (!pnode->ReceiveMsgBytes({pchBuf, (size_t)nBytes}, notify)) { - LogDebug(BCLog::NET, - "receiving message bytes failed, %s\n", - pnode->DisconnectMsg(fLogIPs) - ); - pnode->CloseSocketDisconnect(); - } - RecordBytesRecv(nBytes); - if (notify) { - pnode->MarkReceivedMsgsForProcessing(); - WakeMessageHandler(); - } + EventGotData(pnode->GetId(), pchBuf, nBytes); } else if (nBytes == 0) { - // socket closed gracefully - if (!pnode->fDisconnect) { - LogDebug(BCLog::NET, "socket closed, %s\n", pnode->DisconnectMsg(fLogIPs)); - } - pnode->CloseSocketDisconnect(); + EventGotEOF(pnode->GetId()); } else if (nBytes < 0) { @@ -2210,10 +2271,7 @@ void CConnman::SocketHandlerConnected(const std::vector& nodes, int nErr = WSAGetLastError(); if (nErr != WSAEWOULDBLOCK && nErr != WSAEMSGSIZE && nErr != WSAEINTR && nErr != WSAEINPROGRESS) { - if (!pnode->fDisconnect) { - LogDebug(BCLog::NET, "socket recv error, %s: %s\n", pnode->DisconnectMsg(fLogIPs), NetworkErrorString(nErr)); - } - pnode->CloseSocketDisconnect(); + EventGotPermanentReadError(pnode->GetId(), NetworkErrorString(nErr)); } } } @@ -3803,7 +3861,6 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) msg.data.data() ); - size_t nBytesSent = 0; { LOCK(pnode->cs_vSend); // Check if the transport still has unsent bytes, and indicate to it that we're about to @@ -3826,10 +3883,9 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) // results in sendable bytes there, but with V2Transport this is not the case (it may // still be in the handshake). if (queue_was_empty && more) { - std::tie(nBytesSent, std::ignore) = SocketSendData(*pnode); + SocketSendData(*pnode); } } - if (nBytesSent) RecordBytesSent(nBytesSent); } bool CConnman::ForNode(NodeId id, std::function func) diff --git a/src/net.h b/src/net.h index d9b2f4d2f0de3..477b1f14cdcc1 100644 --- a/src/net.h +++ b/src/net.h @@ -1306,6 +1306,18 @@ class CConnman : private SockMan /** Return true if the peer is inactive and should be disconnected. */ bool InactivityCheck(const CNode& node) const; + void EventReadyToSend(NodeId node_id, bool& cancel_recv) override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + + virtual void EventGotData(NodeId node_id, const uint8_t* data, size_t n) override + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc, !m_nodes_mutex); + + virtual void EventGotEOF(NodeId node_id) override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + + virtual void EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) override + EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); + virtual bool ShouldTryToSend(NodeId node_id) const override EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); @@ -1369,7 +1381,8 @@ class CConnman : private SockMan void DeleteNode(CNode* pnode); /** (Try to) send data from node's vSendMsg. Returns (bytes_sent, data_left). */ - std::pair SocketSendData(CNode& node) const EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend); + std::pair SocketSendData(CNode& node) + EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend, !m_total_bytes_sent_mutex); void DumpAddresses(); From d4de6c946c38f459cca8b1d8f650b9579ae0ae1e Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 23 Sep 2024 12:50:25 +0200 Subject: [PATCH 17/30] net: split CConnman::ConnectNode() Move the protocol agnostic parts of `CConnman::ConnectNode()` into `SockMan::ConnectAndMakeNodeId()` and leave the Bitcoin-P2P specific stuff in `CConnman::ConnectNode()`. Move the protocol agnostic `CConnman::m_unused_i2p_sessions`, its mutex and `MAX_UNUSED_I2P_SESSIONS_SIZE` to `SockMan`. Move `GetBindAddress()` from `net.cpp` to `sockman.cpp`. --- src/common/sockman.cpp | 90 +++++++++++++++++++++++++++++++++++++ src/common/sockman.h | 56 +++++++++++++++++++++++ src/net.cpp | 100 ++++++++++++----------------------------- src/net.h | 35 +++------------ src/test/util/net.h | 3 +- 5 files changed, 183 insertions(+), 101 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index 0a9aa1a72904f..a696787e1b441 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -10,6 +10,19 @@ #include #include +CService GetBindAddress(const Sock& sock) +{ + CService addr_bind; + struct sockaddr_storage sockaddr_bind; + socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); + if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { + addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind); + } else { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n"); + } + return addr_bind; +} + bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) { // Create socket for listening for incoming connections @@ -107,6 +120,83 @@ void SockMan::JoinSocketsThreads() } } +std::optional +SockMan::ConnectAndMakeNodeId(const std::variant& to, + bool is_important, + const Proxy& proxy, + bool& proxy_failed, + CService& me, + std::unique_ptr& sock, + std::unique_ptr& i2p_transient_session) +{ + AssertLockNotHeld(m_unused_i2p_sessions_mutex); + + Assume(!me.IsValid()); + + if (std::holds_alternative(to)) { + const CService& addr_to{std::get(to)}; + if (addr_to.IsI2P()) { + if (!Assume(proxy.IsValid())) { + return std::nullopt; + } + + i2p::Connection conn; + bool connected{false}; + + if (m_i2p_sam_session) { + connected = m_i2p_sam_session->Connect(addr_to, conn, proxy_failed); + } else { + { + LOCK(m_unused_i2p_sessions_mutex); + if (m_unused_i2p_sessions.empty()) { + i2p_transient_session = std::make_unique(proxy, &interruptNet); + } else { + i2p_transient_session.swap(m_unused_i2p_sessions.front()); + m_unused_i2p_sessions.pop(); + } + } + connected = i2p_transient_session->Connect(addr_to, conn, proxy_failed); + if (!connected) { + LOCK(m_unused_i2p_sessions_mutex); + if (m_unused_i2p_sessions.size() < MAX_UNUSED_I2P_SESSIONS_SIZE) { + m_unused_i2p_sessions.emplace(i2p_transient_session.release()); + } + } + } + + if (connected) { + sock = std::move(conn.sock); + me = conn.me; + } + } else if (proxy.IsValid()) { + sock = ConnectThroughProxy(proxy, addr_to.ToStringAddr(), addr_to.GetPort(), proxy_failed); + } else { + sock = ConnectDirectly(addr_to, is_important); + } + } else { + if (!Assume(proxy.IsValid())) { + return std::nullopt; + } + + const auto& hostport{std::get(to)}; + + bool dummy_proxy_failed; + sock = ConnectThroughProxy(proxy, hostport.host, hostport.port, dummy_proxy_failed); + } + + if (!sock) { + return std::nullopt; + } + + if (!me.IsValid()) { + me = GetBindAddress(*sock); + } + + const NodeId node_id{GetNewNodeId()}; + + return node_id; +} + std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) { sockaddr_storage storage; diff --git a/src/common/sockman.h b/src/common/sockman.h index af089243e9d99..c81d30a12aefb 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -14,11 +14,15 @@ #include #include +#include #include +#include #include typedef int64_t NodeId; +CService GetBindAddress(const Sock& sock); + /** * A socket manager class which handles socket operations. * To use this class, inherit from it and implement the pure virtual methods. @@ -26,6 +30,7 @@ typedef int64_t NodeId; * - binding and listening on sockets * - starting of necessary threads to process socket operations * - accepting incoming connections + * - making outbound connections */ class SockMan { @@ -74,6 +79,37 @@ class SockMan */ void JoinSocketsThreads(); + /** + * A more readable std::tuple for host and port. + */ + struct StringHostIntPort { + const std::string& host; + uint16_t port; + }; + + /** + * Make an outbound connection, save the socket internally and return a newly generated node id. + * @param[in] to The address to connect to, either as CService or a host as string and port as + * an integer, if the later is used, then `proxy` must be valid. + * @param[in] is_important If true, then log failures with higher severity. + * @param[in] proxy Proxy to connect through if `proxy.IsValid()` is true. + * @param[out] proxy_failed If `proxy` is valid and the connection failed because of the + * proxy, then it will be set to true. + * @param[out] me If the connection was successful then this is set to the address on the + * local side of the socket. + * @param[out] sock Connected socket, if the operation is successful. + * @param[out] i2p_transient_session I2P session, if the operation is successful. + * @return Newly generated node id, or std::nullopt if the operation fails. + */ + std::optional ConnectAndMakeNodeId(const std::variant& to, + bool is_important, + const Proxy& proxy, + bool& proxy_failed, + CService& me, + std::unique_ptr& sock, + std::unique_ptr& i2p_transient_session) + EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + /** * Accept a connection. * @param[in] listen_sock Socket on which to accept the connection. @@ -114,6 +150,12 @@ class SockMan private: + /** + * Cap on the size of `m_unused_i2p_sessions`, to ensure it does not + * unexpectedly use too much memory. + */ + static constexpr size_t MAX_UNUSED_I2P_SESSIONS_SIZE{10}; + // // Pure virtual functions must be implemented by children classes. // @@ -225,6 +267,20 @@ class SockMan * Thread that accepts incoming I2P connections in a loop, can be stopped via `interruptNet`. */ std::thread m_thread_i2p_accept; + + /** + * Mutex protecting m_i2p_sam_sessions. + */ + Mutex m_unused_i2p_sessions_mutex; + + /** + * A pool of created I2P SAM transient sessions that should be used instead + * of creating new ones in order to reduce the load on the I2P network. + * Creating a session in I2P is not cheap, thus if this is not empty, then + * pick an entry from it instead of creating a new session. If connecting to + * a host fails, then the created session is put to this pool for reuse. + */ + std::queue> m_unused_i2p_sessions GUARDED_BY(m_unused_i2p_sessions_mutex); }; #endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index 15e2c233d1ddb..f94b4207807b4 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -376,23 +376,8 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) return true; } -/** Get the bind address for a socket as CAddress */ -static CService GetBindAddress(const Sock& sock) -{ - CService addr_bind; - struct sockaddr_storage sockaddr_bind; - socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); - if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { - addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind); - } else { - LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n"); - } - return addr_bind; -} - CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type, bool use_v2transport) { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); assert(conn_type != ConnectionType::INBOUND); if (pszDest == nullptr) { @@ -454,52 +439,29 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo // Connect std::unique_ptr sock; Proxy proxy; - CService addr_bind; - assert(!addr_bind.IsValid()); + assert(!proxy.IsValid()); std::unique_ptr i2p_transient_session; + std::optional node_id; + CService me; + for (auto& target_addr: connect_to) { if (target_addr.IsValid()) { const bool use_proxy{GetProxy(target_addr.GetNetwork(), proxy)}; bool proxyConnectionFailed = false; - if (target_addr.IsI2P() && use_proxy) { - i2p::Connection conn; - bool connected{false}; - - if (m_i2p_sam_session) { - connected = m_i2p_sam_session->Connect(target_addr, conn, proxyConnectionFailed); - } else { - { - LOCK(m_unused_i2p_sessions_mutex); - if (m_unused_i2p_sessions.empty()) { - i2p_transient_session = - std::make_unique(proxy, &interruptNet); - } else { - i2p_transient_session.swap(m_unused_i2p_sessions.front()); - m_unused_i2p_sessions.pop(); - } - } - connected = i2p_transient_session->Connect(target_addr, conn, proxyConnectionFailed); - if (!connected) { - LOCK(m_unused_i2p_sessions_mutex); - if (m_unused_i2p_sessions.size() < MAX_UNUSED_I2P_SESSIONS_SIZE) { - m_unused_i2p_sessions.emplace(i2p_transient_session.release()); - } - } - } - - if (connected) { - sock = std::move(conn.sock); - addr_bind = conn.me; - } - } else if (use_proxy) { + if (use_proxy && !target_addr.IsI2P()) { LogPrintLevel(BCLog::PROXY, BCLog::Level::Debug, "Using proxy: %s to connect to %s\n", proxy.ToString(), target_addr.ToStringAddrPort()); - sock = ConnectThroughProxy(proxy, target_addr.ToStringAddr(), target_addr.GetPort(), proxyConnectionFailed); - } else { - // no proxy needed (none set for target network) - sock = ConnectDirectly(target_addr, conn_type == ConnectionType::MANUAL); } + + node_id = ConnectAndMakeNodeId(target_addr, + /*is_important=*/conn_type == ConnectionType::MANUAL, + proxy, + proxyConnectionFailed, + me, + sock, + i2p_transient_session); + if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to // the proxy, mark this as an attempt. @@ -508,12 +470,19 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo } else if (pszDest && GetNameProxy(proxy)) { std::string host; uint16_t port{default_port}; - SplitHostPort(std::string(pszDest), port, host); - bool proxyConnectionFailed; - sock = ConnectThroughProxy(proxy, host, port, proxyConnectionFailed); + SplitHostPort(pszDest, port, host); + + bool dummy; + node_id = ConnectAndMakeNodeId(StringHostIntPort{host, port}, + /*is_important=*/conn_type == ConnectionType::MANUAL, + proxy, + dummy, + me, + sock, + i2p_transient_session); } // Check any other resolved address (if any) if we fail to connect - if (!sock) { + if (!node_id.has_value()) { continue; } @@ -521,18 +490,13 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo std::vector whitelist_permissions = conn_type == ConnectionType::MANUAL ? vWhitelistedRangeOutgoing : std::vector{}; AddWhitelistPermissionFlags(permission_flags, target_addr, whitelist_permissions); - // Add node - NodeId id = GetNewNodeId(); - uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); - if (!addr_bind.IsValid()) { - addr_bind = GetBindAddress(*sock); - } - CNode* pnode = new CNode(id, + const uint64_t nonce{GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(node_id.value()).Finalize()}; + CNode* pnode = new CNode(node_id.value(), std::move(sock), target_addr, CalculateKeyedNetGroup(target_addr), nonce, - addr_bind, + me, pszDest ? pszDest : "", conn_type, /*inbound_onion=*/false, @@ -545,7 +509,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo pnode->AddRef(); // We're making a new connection, harvest entropy from the time (and our peer count) - RandAddEvent((uint32_t)id); + RandAddEvent(static_cast(node_id.value())); return pnode; } @@ -1834,7 +1798,6 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, bool CConnman::AddConnection(const std::string& address, ConnectionType conn_type, bool use_v2transport = false) { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); std::optional max_connections; switch (conn_type) { case ConnectionType::INBOUND: @@ -2474,7 +2437,6 @@ void CConnman::DumpAddresses() void CConnman::ProcessAddrFetch() { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); std::string strDest; { LOCK(m_addr_fetches_mutex); @@ -2594,7 +2556,6 @@ bool CConnman::MaybePickPreferredNetwork(std::optional& network) void CConnman::ThreadOpenConnections(const std::vector connect, Span seed_nodes) { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); AssertLockNotHeld(m_reconnections_mutex); FastRandomContext rng; // Connect to specific addresses @@ -3035,7 +2996,6 @@ std::vector CConnman::GetAddedNodeInfo(bool include_connected) co void CConnman::ThreadOpenAddedConnections() { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); AssertLockNotHeld(m_reconnections_mutex); while (true) { @@ -3065,7 +3025,6 @@ void CConnman::ThreadOpenAddedConnections() // if successful, this moves the passed grant to the constructed node void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant&& grant_outbound, const char *pszDest, ConnectionType conn_type, bool use_v2transport) { - AssertLockNotHeld(m_unused_i2p_sessions_mutex); assert(conn_type != ConnectionType::INBOUND); // @@ -3914,7 +3873,6 @@ uint64_t CConnman::CalculateKeyedNetGroup(const CNetAddr& address) const void CConnman::PerformReconnections() { AssertLockNotHeld(m_reconnections_mutex); - AssertLockNotHeld(m_unused_i2p_sessions_mutex); while (true) { // Move first element of m_reconnections to todo (avoiding an allocation inside the lock). decltype(m_reconnections) todo; diff --git a/src/net.h b/src/net.h index 477b1f14cdcc1..2b35de79af05f 100644 --- a/src/net.h +++ b/src/net.h @@ -42,7 +42,6 @@ #include #include #include -#include #include #include #include @@ -1136,7 +1135,7 @@ class CConnman : private SockMan bool GetNetworkActive() const { return fNetworkActive; }; bool GetUseAddrmanOutgoing() const { return m_use_addrman_outgoing; }; void SetNetworkActive(bool active); - void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant&& grant_outbound, const char* strDest, ConnectionType conn_type, bool use_v2transport) EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant&& grant_outbound, const char* strDest, ConnectionType conn_type, bool use_v2transport); bool CheckIncomingNonce(uint64_t nonce); void ASMapHealthCheck(); @@ -1221,7 +1220,7 @@ class CConnman : private SockMan * - Max total outbound connection capacity filled * - Max connection capacity for type is filled */ - bool AddConnection(const std::string& address, ConnectionType conn_type, bool use_v2transport) EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + bool AddConnection(const std::string& address, ConnectionType conn_type, bool use_v2transport); size_t GetNodeCount(ConnectionDirection) const; std::map getNetLocalAddresses() const; @@ -1280,10 +1279,10 @@ class CConnman : private SockMan bool Bind(const CService& addr, unsigned int flags, NetPermissionFlags permissions); bool InitBinds(const Options& options); - void ThreadOpenAddedConnections() EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex, !m_unused_i2p_sessions_mutex, !m_reconnections_mutex); + void ThreadOpenAddedConnections() EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex, !m_reconnections_mutex); void AddAddrFetch(const std::string& strDest) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex); - void ProcessAddrFetch() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_unused_i2p_sessions_mutex); - void ThreadOpenConnections(std::vector connect, Span seed_nodes) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !m_unused_i2p_sessions_mutex, !m_reconnections_mutex); + void ProcessAddrFetch() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex); + void ThreadOpenConnections(std::vector connect, Span seed_nodes) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !m_reconnections_mutex); void ThreadMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); /// Whether we are currently advertising our I2P address (via `AddLocal()`). @@ -1375,7 +1374,7 @@ class CConnman : private SockMan bool AlreadyConnectedToAddress(const CAddress& addr); bool AttemptToEvictConnection(); - CNode* ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type, bool use_v2transport) EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + CNode* ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type, bool use_v2transport); void AddWhitelistPermissionFlags(NetPermissionFlags& flags, const CNetAddr &addr, const std::vector& ranges) const; void DeleteNode(CNode* pnode); @@ -1591,20 +1590,6 @@ class CConnman : private SockMan */ bool whitelist_relay; - /** - * Mutex protecting m_i2p_sam_sessions. - */ - Mutex m_unused_i2p_sessions_mutex; - - /** - * A pool of created I2P SAM transient sessions that should be used instead - * of creating new ones in order to reduce the load on the I2P network. - * Creating a session in I2P is not cheap, thus if this is not empty, then - * pick an entry from it instead of creating a new session. If connecting to - * a host fails, then the created session is put to this pool for reuse. - */ - std::queue> m_unused_i2p_sessions GUARDED_BY(m_unused_i2p_sessions_mutex); - /** * Mutex protecting m_reconnections. */ @@ -1626,13 +1611,7 @@ class CConnman : private SockMan std::list m_reconnections GUARDED_BY(m_reconnections_mutex); /** Attempt reconnections, if m_reconnections non-empty. */ - void PerformReconnections() EXCLUSIVE_LOCKS_REQUIRED(!m_reconnections_mutex, !m_unused_i2p_sessions_mutex); - - /** - * Cap on the size of `m_unused_i2p_sessions`, to ensure it does not - * unexpectedly use too much memory. - */ - static constexpr size_t MAX_UNUSED_I2P_SESSIONS_SIZE{10}; + void PerformReconnections() EXCLUSIVE_LOCKS_REQUIRED(!m_reconnections_mutex); /** * RAII helper to atomically create a copy of `m_nodes` and add a reference diff --git a/src/test/util/net.h b/src/test/util/net.h index 99872508363af..6c07e98fd8174 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -85,8 +85,7 @@ struct ConnmanTestMsg : public CConnman { bool AlreadyConnectedPublic(const CAddress& addr) { return AlreadyConnectedToAddress(addr); }; - CNode* ConnectNodePublic(PeerManager& peerman, const char* pszDest, ConnectionType conn_type) - EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + CNode* ConnectNodePublic(PeerManager& peerman, const char* pszDest, ConnectionType conn_type); }; constexpr ServiceFlags ALL_SERVICE_FLAGS[]{ From 0070acd93755a170bb47dc32524f699235988c30 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 24 Sep 2024 09:41:47 +0200 Subject: [PATCH 18/30] net: tweak EventNewConnectionAccepted() Move `MaybeFlipIPv6toCJDNS()`, which is Bitcoin P2P specific from the callers of `CConnman::EventNewConnectionAccepted()` to inside that method. Move the IsSelectable check, the `TCP_NODELAY` option set and the generation of new node id out of `CConnman::EventNewConnectionAccepted()` because those are protocol agnostic. Move those to a new method `SockMan::NewSockAccepted()` which is called instead of `CConnman::EventNewConnectionAccepted()`. --- src/common/sockman.cpp | 22 +++++++++++++++++++++- src/common/sockman.h | 13 ++++++++++++- src/net.cpp | 37 ++++++++++++------------------------- src/net.h | 4 +++- 4 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index a696787e1b441..ca5292ef858c5 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -222,6 +222,26 @@ std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CServic return sock; } +void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) +{ + if (!sock->IsSelectable()) { + LogPrintf("connection from %s dropped: non-selectable socket\n", them.ToStringAddrPort()); + return; + } + + // According to the internet TCP_NODELAY is not carried into accepted sockets + // on all platforms. Set it again here just to be sure. + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogDebug(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", + them.ToStringAddrPort()); + } + + const NodeId node_id{GetNewNodeId()}; + + EventNewConnectionAccepted(node_id, std::move(sock), me, them); +} + NodeId SockMan::GetNewNodeId() { return m_next_node_id.fetch_add(1, std::memory_order_relaxed); @@ -272,7 +292,7 @@ void SockMan::ThreadI2PAccept() continue; } - EventNewConnectionAccepted(std::move(conn.sock), conn.me, conn.peer); + NewSockAccepted(std::move(conn.sock), conn.me, conn.peer); err_wait = err_wait_begin; } diff --git a/src/common/sockman.h b/src/common/sockman.h index c81d30a12aefb..19bc4929d6133 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -118,6 +118,15 @@ class SockMan */ std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); + /** + * After a new socket with a peer has been created, configure its flags, + * make a new node id and call `EventNewConnectionAccepted()`. + * @param[in] sock The newly created socket. + * @param[in] me Address at our end of the connection. + * @param[in] them Address of the new peer. + */ + void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them); + /** * Generate an id for a newly created node. */ @@ -162,11 +171,13 @@ class SockMan /** * Be notified when a new connection has been accepted. + * @param[in] node_id Id of the newly accepted connection. * @param[in] sock Connected socket to communicate with the peer. * @param[in] me The address and port at our side of the connection. * @param[in] them The address and port at the peer's side of the connection. */ - virtual void EventNewConnectionAccepted(std::unique_ptr&& sock, + virtual void EventNewConnectionAccepted(NodeId node_id, + std::unique_ptr&& sock, const CService& me, const CService& them) = 0; diff --git a/src/net.cpp b/src/net.cpp index f94b4207807b4..c64a344f5d68c 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1696,10 +1696,14 @@ bool CConnman::AttemptToEvictConnection() return false; } -void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, - const CService& addr_bind, - const CService& addr) +void CConnman::EventNewConnectionAccepted(NodeId node_id, + std::unique_ptr&& sock, + const CService& addr_bind_, + const CService& addr_) { + const CService addr_bind{MaybeFlipIPv6toCJDNS(addr_bind_)}; + const CService addr{MaybeFlipIPv6toCJDNS(addr_)}; + int nInbound = 0; NetPermissionFlags permission_flags = NetPermissionFlags::None; @@ -1722,19 +1726,6 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, return; } - if (!sock->IsSelectable()) { - LogPrintf("connection from %s dropped: non-selectable socket\n", addr.ToStringAddrPort()); - return; - } - - // According to the internet TCP_NODELAY is not carried into accepted sockets - // on all platforms. Set it again here just to be sure. - const int on{1}; - if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { - LogDebug(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", - addr.ToStringAddrPort()); - } - // Don't accept connections from banned peers. bool banned = m_banman && m_banman->IsBanned(addr); if (!NetPermissions::HasFlag(permission_flags, NetPermissionFlags::NoBan) && banned) @@ -1760,8 +1751,7 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, } } - NodeId id = GetNewNodeId(); - uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); + const uint64_t nonce{GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(node_id).Finalize()}; const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end(); // The V2Transport transparently falls back to V1 behavior when an incoming V1 connection is @@ -1769,7 +1759,7 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, ServiceFlags local_services = GetLocalServices(); const bool use_v2transport(local_services & NODE_P2P_V2); - CNode* pnode = new CNode(id, + CNode* pnode = new CNode(node_id, std::move(sock), CAddress{addr, NODE_NONE}, CalculateKeyedNetGroup(addr), @@ -1788,12 +1778,12 @@ void CConnman::EventNewConnectionAccepted(std::unique_ptr&& sock, m_msgproc->InitializeNode(*pnode, local_services); { LOCK(m_nodes_mutex); - m_nodes.emplace(id, pnode); + m_nodes.emplace(node_id, pnode); } LogDebug(BCLog::NET, "connection from %s accepted\n", addr.ToStringAddrPort()); // We received a new connection, harvest entropy from the time (and our peer count) - RandAddEvent((uint32_t)id); + RandAddEvent(static_cast(node_id)); } bool CConnman::AddConnection(const std::string& address, ConnectionType conn_type, bool use_v2transport = false) @@ -2256,10 +2246,7 @@ void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock auto sock_accepted{AcceptConnection(*sock, addr_accepted)}; if (sock_accepted) { - addr_accepted = MaybeFlipIPv6toCJDNS(addr_accepted); - const CService addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock))}; - - EventNewConnectionAccepted(std::move(sock_accepted), addr_bind, addr_accepted); + NewSockAccepted(std::move(sock_accepted), GetBindAddress(*sock), addr_accepted); } } } diff --git a/src/net.h b/src/net.h index 2b35de79af05f..2f435b2241056 100644 --- a/src/net.h +++ b/src/net.h @@ -1292,11 +1292,13 @@ class CConnman : private SockMan /** * Create a `CNode` object and add it to the `m_nodes` member. + * @param[in] node_id Id of the newly accepted connection. * @param[in] sock Connected socket to communicate with the peer. * @param[in] me The address and port at our side of the connection. * @param[in] them The address and port at the peer's side of the connection. */ - virtual void EventNewConnectionAccepted(std::unique_ptr&& sock, + virtual void EventNewConnectionAccepted(NodeId node_id, + std::unique_ptr&& sock, const CService& me, const CService& them) override; From 754fe4a333d489be50b4dcf680d22fb8cd95b777 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 23 Sep 2024 11:03:32 +0200 Subject: [PATCH 19/30] net: move sockets from CNode to SockMan Move `CNode::m_sock` and `CNode::m_i2p_sam_session` to `SockMan::m_connected`. Also move all the code that handles sockets to `SockMan`. `CNode::CloseSocketDisconnect()` becomes `CConnman::MarkAsDisconnectAndCloseConnection()`. `CConnman::SocketSendData()` is renamed to `CConnman::SendMessagesAsBytes()` and its sockets-touching bits are moved to `SockMan::SendBytes()`. `CConnman::GenerateWaitSockets()` goes to `SockMan::GenerateWaitSockets()`. `CConnman::ThreadSocketHandler()` and `CConnman::SocketHandler()` are combined into `SockMan::ThreadSocketHandler()`. `CConnman::SocketHandlerConnected()` goes to `SockMan::SocketHandlerConnected()`. `CConnman::SocketHandlerListening()` goes to `SockMan::SocketHandlerListening()`. --- src/common/sockman.cpp | 243 ++++++++++++++++++++- src/common/sockman.h | 157 ++++++++++++-- src/net.cpp | 280 ++++--------------------- src/net.h | 73 +------ src/test/denialofservice_tests.cpp | 6 - src/test/fuzz/connman.cpp | 5 +- src/test/fuzz/net.cpp | 3 - src/test/fuzz/p2p_handshake.cpp | 2 +- src/test/fuzz/p2p_headers_presync.cpp | 2 +- src/test/fuzz/process_message.cpp | 2 +- src/test/fuzz/process_messages.cpp | 2 +- src/test/fuzz/util/net.h | 3 - src/test/net_peer_connection_tests.cpp | 1 - src/test/net_tests.cpp | 9 - src/test/util/net.h | 6 + 15 files changed, 447 insertions(+), 347 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index ca5292ef858c5..fcbb29da87f35 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -10,7 +10,13 @@ #include #include -CService GetBindAddress(const Sock& sock) +#include + +// The set of sockets cannot be modified while waiting +// The sleep time needs to be small to avoid new sockets stalling +static constexpr auto SELECT_TIMEOUT{50ms}; + +static CService GetBindAddress(const Sock& sock) { CService addr_bind; struct sockaddr_storage sockaddr_bind; @@ -104,6 +110,8 @@ bool SockMan::BindAndStartListening(const CService& to, bilingual_str& errmsg) void SockMan::StartSocketsThreads(const Options& options) { + m_thread_socket_handler = std::thread(&util::TraceThread, "net", [this] { ThreadSocketHandler(); }); + if (options.i2p.has_value()) { m_i2p_sam_session = std::make_unique( options.i2p->private_key_file, options.i2p->sam_proxy, &interruptNet); @@ -118,6 +126,10 @@ void SockMan::JoinSocketsThreads() if (m_thread_i2p_accept.joinable()) { m_thread_i2p_accept.join(); } + + if (m_thread_socket_handler.joinable()) { + m_thread_socket_handler.join(); + } } std::optional @@ -125,12 +137,14 @@ SockMan::ConnectAndMakeNodeId(const std::variant& t bool is_important, const Proxy& proxy, bool& proxy_failed, - CService& me, - std::unique_ptr& sock, - std::unique_ptr& i2p_transient_session) + CService& me) { + AssertLockNotHeld(m_connected_mutex); AssertLockNotHeld(m_unused_i2p_sessions_mutex); + std::unique_ptr sock; + std::unique_ptr i2p_transient_session; + Assume(!me.IsValid()); if (std::holds_alternative(to)) { @@ -194,6 +208,12 @@ SockMan::ConnectAndMakeNodeId(const std::variant& t const NodeId node_id{GetNewNodeId()}; + { + LOCK(m_connected_mutex); + m_connected.emplace(node_id, std::make_shared(std::move(sock), + std::move(i2p_transient_session))); + } + return node_id; } @@ -224,6 +244,8 @@ std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CServic void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) { + AssertLockNotHeld(m_connected_mutex); + if (!sock->IsSelectable()) { LogPrintf("connection from %s dropped: non-selectable socket\n", them.ToStringAddrPort()); return; @@ -239,7 +261,14 @@ void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const NodeId node_id{GetNewNodeId()}; - EventNewConnectionAccepted(node_id, std::move(sock), me, them); + { + LOCK(m_connected_mutex); + m_connected.emplace(node_id, std::make_shared(std::move(sock))); + } + + if (!EventNewConnectionAccepted(node_id, me, them)) { + CloseConnection(node_id); + } } NodeId SockMan::GetNewNodeId() @@ -247,6 +276,52 @@ NodeId SockMan::GetNewNodeId() return m_next_node_id.fetch_add(1, std::memory_order_relaxed); } +bool SockMan::CloseConnection(NodeId node_id) +{ + LOCK(m_connected_mutex); + return m_connected.erase(node_id) > 0; +} + +ssize_t SockMan::SendBytes(NodeId node_id, + std::span data, + bool will_send_more, + std::string& errmsg) const +{ + AssertLockNotHeld(m_connected_mutex); + + if (data.empty()) { + return 0; + } + + auto node_sockets{GetNodeSockets(node_id)}; + if (!node_sockets) { + // Bail out immediately and just leave things in the caller's send queue. + return 0; + } + + int flags{MSG_NOSIGNAL | MSG_DONTWAIT}; +#ifdef MSG_MORE + if (will_send_more) { + flags |= MSG_MORE; + } +#endif + + const ssize_t sent{WITH_LOCK( + node_sockets->mutex, + return node_sockets->sock->Send(reinterpret_cast(data.data()), data.size(), flags);)}; + + if (sent >= 0) { + return sent; + } + + const int err{WSAGetLastError()}; + if (err == WSAEWOULDBLOCK || err == WSAEMSGSIZE || err == WSAEINTR || err == WSAEINPROGRESS) { + return 0; + } + errmsg = NetworkErrorString(err); + return -1; +} + void SockMan::CloseSockets() { m_listen.clear(); @@ -262,8 +337,17 @@ void SockMan::EventIOLoopCompletedForAllPeers() {} void SockMan::EventI2PListen(const CService&, bool) {} +void SockMan::TestOnlyAddExistentNode(NodeId node_id, std::unique_ptr&& sock) +{ + LOCK(m_connected_mutex); + const auto result{m_connected.emplace(node_id, std::make_shared(std::move(sock)))}; + assert(result.second); +} + void SockMan::ThreadI2PAccept() { + AssertLockNotHeld(m_connected_mutex); + static constexpr auto err_wait_begin = 1s; static constexpr auto err_wait_cap = 5min; auto err_wait = err_wait_begin; @@ -297,3 +381,152 @@ void SockMan::ThreadI2PAccept() err_wait = err_wait_begin; } } + +void SockMan::ThreadSocketHandler() +{ + AssertLockNotHeld(m_connected_mutex); + + while (!interruptNet) { + EventIOLoopCompletedForAllPeers(); + + // Check for the readiness of the already connected sockets and the + // listening sockets in one call ("readiness" as in poll(2) or + // select(2)). If none are ready, wait for a short while and return + // empty sets. + auto io_readiness{GenerateWaitSockets()}; + if (io_readiness.events_per_sock.empty() || + // WaitMany() may as well be a static method, the context of the first Sock in the vector is not relevant. + !io_readiness.events_per_sock.begin()->first->WaitMany(SELECT_TIMEOUT, + io_readiness.events_per_sock)) { + interruptNet.sleep_for(SELECT_TIMEOUT); + } + + // Service (send/receive) each of the already connected sockets. + SocketHandlerConnected(io_readiness); + + // Accept new connections from listening sockets. + SocketHandlerListening(io_readiness.events_per_sock); + } +} + +SockMan::IOReadiness SockMan::GenerateWaitSockets() +{ + AssertLockNotHeld(m_connected_mutex); + + IOReadiness io_readiness; + + for (const auto& sock : m_listen) { + io_readiness.events_per_sock.emplace(sock, Sock::Events{Sock::RECV}); + } + + auto connected_snapshot{WITH_LOCK(m_connected_mutex, return m_connected;)}; + + for (const auto& [node_id, node_sockets] : connected_snapshot) { + const bool select_recv{ShouldTryToRecv(node_id)}; + const bool select_send{ShouldTryToSend(node_id)}; + if (!select_recv && !select_send) continue; + + Sock::Event event = (select_send ? Sock::SEND : 0) | (select_recv ? Sock::RECV : 0); + io_readiness.events_per_sock.emplace(node_sockets->sock, Sock::Events{event}); + io_readiness.node_ids_per_sock.emplace(node_sockets->sock, node_id); + } + + return io_readiness; +} + +void SockMan::SocketHandlerConnected(const IOReadiness& io_readiness) +{ + AssertLockNotHeld(m_connected_mutex); + + for (const auto& [sock, events] : io_readiness.events_per_sock) { + if (interruptNet) { + return; + } + + auto it{io_readiness.node_ids_per_sock.find(sock)}; + if (it == io_readiness.node_ids_per_sock.end()) { + continue; + } + const NodeId node_id{it->second}; + + bool send_ready = events.occurred & Sock::SEND; + bool recv_ready = events.occurred & Sock::RECV; + bool err_ready = events.occurred & Sock::ERR; + + if (send_ready) { + bool cancel_recv; + + EventReadyToSend(node_id, cancel_recv); + + if (cancel_recv) { + recv_ready = false; + } + } + + if (recv_ready || err_ready) { + uint8_t buf[0x10000]; // typical socket buffer is 8K-64K + + auto node_sockets{GetNodeSockets(node_id)}; + if (!node_sockets) { + continue; + } + + const ssize_t nrecv{WITH_LOCK( + node_sockets->mutex, + return node_sockets->sock->Recv(buf, sizeof(buf), MSG_DONTWAIT);)}; + + switch (nrecv) { + case -1: { + const int err = WSAGetLastError(); + if (err != WSAEWOULDBLOCK && err != WSAEMSGSIZE && err != WSAEINTR && err != WSAEINPROGRESS) { + EventGotPermanentReadError(node_id, NetworkErrorString(err)); + } + break; + } + case 0: + EventGotEOF(node_id); + break; + default: + EventGotData(node_id, buf, nrecv); + break; + } + } + + EventIOLoopCompletedForNode(node_id); + } +} + +void SockMan::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) +{ + AssertLockNotHeld(m_connected_mutex); + + for (const auto& sock : m_listen) { + if (interruptNet) { + return; + } + const auto it = events_per_sock.find(sock); + if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { + CService addr_accepted; + + auto sock_accepted{AcceptConnection(*sock, addr_accepted)}; + + if (sock_accepted) { + NewSockAccepted(std::move(sock_accepted), GetBindAddress(*sock), addr_accepted); + } + } + } +} + +std::shared_ptr SockMan::GetNodeSockets(NodeId node_id) const +{ + LOCK(m_connected_mutex); + + auto it{m_connected.find(node_id)}; + if (it == m_connected.end()) { + // There is no socket in case we've already disconnected, or in test cases without + // real connections. + return {}; + } + + return it->second; +} diff --git a/src/common/sockman.h b/src/common/sockman.h index 19bc4929d6133..570467d965077 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -15,14 +15,13 @@ #include #include #include +#include #include #include #include typedef int64_t NodeId; -CService GetBindAddress(const Sock& sock); - /** * A socket manager class which handles socket operations. * To use this class, inherit from it and implement the pure virtual methods. @@ -31,6 +30,8 @@ CService GetBindAddress(const Sock& sock); * - starting of necessary threads to process socket operations * - accepting incoming connections * - making outbound connections + * - closing connections + * - waiting for IO readiness on sockets and doing send/recv accordingly */ class SockMan { @@ -97,18 +98,14 @@ class SockMan * proxy, then it will be set to true. * @param[out] me If the connection was successful then this is set to the address on the * local side of the socket. - * @param[out] sock Connected socket, if the operation is successful. - * @param[out] i2p_transient_session I2P session, if the operation is successful. * @return Newly generated node id, or std::nullopt if the operation fails. */ std::optional ConnectAndMakeNodeId(const std::variant& to, bool is_important, const Proxy& proxy, bool& proxy_failed, - CService& me, - std::unique_ptr& sock, - std::unique_ptr& i2p_transient_session) - EXCLUSIVE_LOCKS_REQUIRED(!m_unused_i2p_sessions_mutex); + CService& me) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex, !m_unused_i2p_sessions_mutex); /** * Accept a connection. @@ -125,13 +122,38 @@ class SockMan * @param[in] me Address at our end of the connection. * @param[in] them Address of the new peer. */ - void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them); + void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); /** * Generate an id for a newly created node. */ NodeId GetNewNodeId(); + /** + * Disconnect a given peer by closing its socket and release resources occupied by it. + * @return Whether the peer existed and its socket was closed by this call. + */ + bool CloseConnection(NodeId node_id) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Try to send some data to the given node. + * @param[in] node_id Identifier of the node to send to. + * @param[in] data The data to send, it might happen that only a prefix of this is sent. + * @param[in] will_send_more Used as an optimization if the caller knows that they will + * be sending more data soon after this call. + * @param[out] errmsg If <0 is returned then this will contain a human readable message + * explaining the error. + * @retval >=0 The number of bytes actually sent. + * @retval <0 A permanent error has occurred. + */ + ssize_t SendBytes(NodeId node_id, + std::span data, + bool will_send_more, + std::string& errmsg) const + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + /** * Close all sockets. */ @@ -157,6 +179,15 @@ class SockMan */ std::vector> m_listen; +protected: + + /** + * During some tests mocked sockets are created outside of `SockMan`, make it + * possible to add those so that send/recv can be exercised. + */ + void TestOnlyAddExistentNode(NodeId node_id, std::unique_ptr&& sock) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + private: /** @@ -172,12 +203,13 @@ class SockMan /** * Be notified when a new connection has been accepted. * @param[in] node_id Id of the newly accepted connection. - * @param[in] sock Connected socket to communicate with the peer. * @param[in] me The address and port at our side of the connection. * @param[in] them The address and port at the peer's side of the connection. + * @retval true The new connection was accepted at the higher level. + * @retval false The connection was refused at the higher level, so the + * associated socket and node_id should be discarded by `SockMan`. */ - virtual void EventNewConnectionAccepted(NodeId node_id, - std::unique_ptr&& sock, + virtual bool EventNewConnectionAccepted(NodeId node_id, const CService& me, const CService& them) = 0; @@ -263,17 +295,107 @@ class SockMan */ virtual void EventI2PListen(const CService& addr, bool success); + /** + * The sockets used by a connected node - a data socket and an optional I2P session. + */ + struct NodeSockets { + explicit NodeSockets(std::unique_ptr&& s) + : sock{std::move(s)} + { + } + + explicit NodeSockets(std::shared_ptr&& s, std::unique_ptr&& sess) + : sock{std::move(s)}, + i2p_transient_session{std::move(sess)} + { + } + + /** + * Mutex that serializes the Send() and Recv() calls on `sock`. + */ + Mutex mutex; + + /** + * Underlying socket. + * `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of the + * underlying file descriptor by one thread while another thread is poll(2)-ing + * it for activity. + * @see https://github.com/bitcoin/bitcoin/issues/21744 for details. + */ + std::shared_ptr sock; + + /** + * When transient I2P sessions are used, then each node has its own session, otherwise + * all nodes use the session from `m_i2p_sam_session` and share the same I2P address. + * I2P sessions involve a data/transport socket (in `sock`) and a control socket + * (in `i2p_transient_session`). For transient sessions, once the data socket `sock` is + * closed, the control socket is not going to be used anymore and would be just taking + * resources. Storing it here makes its deletion together with `sock` automatic. + */ + std::unique_ptr i2p_transient_session; + }; + + /** + * Info about which socket has which event ready and its node id. + */ + struct IOReadiness { + Sock::EventsPerSock events_per_sock; + std::unordered_map node_ids_per_sock; + }; + /** * Accept incoming I2P connections in a loop and call * `EventNewConnectionAccepted()` for each new connection. */ - void ThreadI2PAccept(); + void ThreadI2PAccept() + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Check connected and listening sockets for IO readiness and process them accordingly. + */ + void ThreadSocketHandler() + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Generate a collection of sockets to check for IO readiness. + * @return Sockets to check for readiness plus an aux map to find the + * corresponding node id given a socket. + */ + IOReadiness GenerateWaitSockets() + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Do the read/write for connected sockets that are ready for IO. + * @param[in] io_readiness Which sockets are ready and their node ids. + */ + void SocketHandlerConnected(const IOReadiness& io_readiness) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Accept incoming connections, one from each read-ready listening socket. + * @param[in] events_per_sock Sockets that are ready for IO. + */ + void SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Retrieve an entry from m_connected. + * @param[in] node_id Node id to search for. + * @return NodeSockets for the given node id or empty shared_ptr if not found. + */ + std::shared_ptr GetNodeSockets(NodeId node_id) const + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); /** * The id to assign to the next created node. Used to generate ids of nodes. */ std::atomic m_next_node_id{0}; + /** + * Thread that sends to and receives from sockets and accepts connections. + */ + std::thread m_thread_socket_handler; + /** * Thread that accepts incoming I2P connections in a loop, can be stopped via `interruptNet`. */ @@ -292,6 +414,15 @@ class SockMan * a host fails, then the created session is put to this pool for reuse. */ std::queue> m_unused_i2p_sessions GUARDED_BY(m_unused_i2p_sessions_mutex); + + mutable Mutex m_connected_mutex; + + /** + * Sockets for connected peers. + * The `shared_ptr` makes it possible to create a snapshot of this by simply copying + * it (under `m_connected_mutex`). + */ + std::unordered_map> m_connected GUARDED_BY(m_connected_mutex); }; #endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index c64a344f5d68c..63eb175febd0a 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -103,10 +103,6 @@ enum BindFlags { BF_DONT_ADVERTISE = (1U << 1), }; -// The set of sockets cannot be modified while waiting -// The sleep time needs to be small to avoid new sockets stalling -static const uint64_t SELECT_TIMEOUT_MILLISECONDS = 50; - const std::string NET_MESSAGE_TYPE_OTHER = "*other*"; static const uint64_t RANDOMIZER_ID_NETGROUP = 0x6c0edd8036ef4036ULL; // SHA256("netgroup")[0:8] @@ -437,10 +433,8 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo } // Connect - std::unique_ptr sock; Proxy proxy; assert(!proxy.IsValid()); - std::unique_ptr i2p_transient_session; std::optional node_id; CService me; @@ -458,9 +452,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo /*is_important=*/conn_type == ConnectionType::MANUAL, proxy, proxyConnectionFailed, - me, - sock, - i2p_transient_session); + me); if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -477,9 +469,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo /*is_important=*/conn_type == ConnectionType::MANUAL, proxy, dummy, - me, - sock, - i2p_transient_session); + me); } // Check any other resolved address (if any) if we fail to connect if (!node_id.has_value()) { @@ -492,7 +482,6 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo const uint64_t nonce{GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(node_id.value()).Finalize()}; CNode* pnode = new CNode(node_id.value(), - std::move(sock), target_addr, CalculateKeyedNetGroup(target_addr), nonce, @@ -502,7 +491,6 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo /*inbound_onion=*/false, CNodeOptions{ .permission_flags = permission_flags, - .i2p_sam_session = std::move(i2p_transient_session), .recv_flood_size = nReceiveFloodSize, .use_v2transport = use_v2transport, }); @@ -517,16 +505,6 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo return nullptr; } -void CNode::CloseSocketDisconnect() -{ - fDisconnect = true; - LOCK(m_sock_mutex); - if (m_sock) { - m_sock.reset(); - } - m_i2p_sam_session.reset(); -} - void CConnman::AddWhitelistPermissionFlags(NetPermissionFlags& flags, const CNetAddr &addr, const std::vector& ranges) const { for (const auto& subnet : ranges) { if (subnet.m_subnet.Match(addr)) { @@ -1548,7 +1526,7 @@ Transport::Info V2Transport::GetInfo() const noexcept return info; } -std::pair CConnman::SocketSendData(CNode& node) +std::pair CConnman::SendMessagesAsBytes(CNode& node) { AssertLockNotHeld(m_total_bytes_sent_mutex); @@ -1576,45 +1554,29 @@ std::pair CConnman::SocketSendData(CNode& node) if (expected_more.has_value()) Assume(!data.empty() == *expected_more); expected_more = more; data_left = !data.empty(); // will be overwritten on next loop if all of data gets sent - int nBytes = 0; - if (!data.empty()) { - LOCK(node.m_sock_mutex); - // There is no socket in case we've already disconnected, or in test cases without - // real connections. In these cases, we bail out immediately and just leave things - // in the send queue and transport. - if (!node.m_sock) { - break; - } - int flags = MSG_NOSIGNAL | MSG_DONTWAIT; -#ifdef MSG_MORE - if (more) { - flags |= MSG_MORE; - } -#endif - nBytes = node.m_sock->Send(reinterpret_cast(data.data()), data.size(), flags); - } - if (nBytes > 0) { + + std::string errmsg; + + const ssize_t sent{SendBytes(node.GetId(), data, more, errmsg)}; + + if (sent > 0) { node.m_last_send = GetTime(); - node.nSendBytes += nBytes; + node.nSendBytes += sent; // Notify transport that bytes have been processed. - node.m_transport->MarkBytesSent(nBytes); + node.m_transport->MarkBytesSent(sent); // Update statistics per message type. if (!msg_type.empty()) { // don't report v2 handshake bytes for now - node.AccountForSentBytes(msg_type, nBytes); + node.AccountForSentBytes(msg_type, sent); } - nSentSize += nBytes; - if ((size_t)nBytes != data.size()) { + nSentSize += sent; + if (static_cast(sent) != data.size()) { // could not send full message; stop sending more break; } } else { - if (nBytes < 0) { - // error - int nErr = WSAGetLastError(); - if (nErr != WSAEWOULDBLOCK && nErr != WSAEMSGSIZE && nErr != WSAEINTR && nErr != WSAEINPROGRESS) { - LogDebug(BCLog::NET, "socket send error, %s: %s\n", node.DisconnectMsg(fLogIPs), NetworkErrorString(nErr)); - node.CloseSocketDisconnect(); - } + if (sent < 0) { + LogDebug(BCLog::NET, "socket send error, %s: %s\n", node.DisconnectMsg(fLogIPs), errmsg); + MarkAsDisconnectAndCloseConnection(node); } break; } @@ -1696,8 +1658,7 @@ bool CConnman::AttemptToEvictConnection() return false; } -void CConnman::EventNewConnectionAccepted(NodeId node_id, - std::unique_ptr&& sock, +bool CConnman::EventNewConnectionAccepted(NodeId node_id, const CService& addr_bind_, const CService& addr_) { @@ -1723,7 +1684,7 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, if (!fNetworkActive) { LogDebug(BCLog::NET, "connection from %s dropped: not accepting new connections\n", addr.ToStringAddrPort()); - return; + return false; } // Don't accept connections from banned peers. @@ -1731,7 +1692,7 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, if (!NetPermissions::HasFlag(permission_flags, NetPermissionFlags::NoBan) && banned) { LogDebug(BCLog::NET, "connection from %s dropped (banned)\n", addr.ToStringAddrPort()); - return; + return false; } // Only accept connections from discouraged peers if our inbound slots aren't (almost) full. @@ -1739,7 +1700,7 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, if (!NetPermissions::HasFlag(permission_flags, NetPermissionFlags::NoBan) && nInbound + 1 >= m_max_inbound && discouraged) { LogDebug(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToStringAddrPort()); - return; + return false; } if (nInbound >= m_max_inbound) @@ -1747,7 +1708,7 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, if (!AttemptToEvictConnection()) { // No connection to evict, disconnect the new connection LogDebug(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n"); - return; + return false; } } @@ -1760,7 +1721,6 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, const bool use_v2transport(local_services & NODE_P2P_V2); CNode* pnode = new CNode(node_id, - std::move(sock), CAddress{addr, NODE_NONE}, CalculateKeyedNetGroup(addr), nonce, @@ -1784,6 +1744,8 @@ void CConnman::EventNewConnectionAccepted(NodeId node_id, // We received a new connection, harvest entropy from the time (and our peer count) RandAddEvent(static_cast(node_id)); + + return true; } bool CConnman::AddConnection(const std::string& address, ConnectionType conn_type, bool use_v2transport = false) @@ -1825,6 +1787,14 @@ bool CConnman::AddConnection(const std::string& address, ConnectionType conn_typ return true; } +void CConnman::MarkAsDisconnectAndCloseConnection(CNode& node) +{ + node.fDisconnect = true; + if (CloseConnection(node.GetId())) { + LogDebug(BCLog::NET, "%s\n", node.DisconnectMsg(fLogIPs)); + } +} + void CConnman::DisconnectNodes() { AssertLockNotHeld(m_nodes_mutex); @@ -1875,8 +1845,7 @@ void CConnman::DisconnectNodes() // release outbound grant (if any) pnode->grantOutbound.Release(); - // close socket and cleanup - pnode->CloseSocketDisconnect(); + MarkAsDisconnectAndCloseConnection(*pnode); // update connection count by network if (pnode->IsManualOrFullOutboundConn()) --m_network_conn_counts[pnode->addr.GetNetwork()]; @@ -1989,7 +1958,7 @@ void CConnman::EventReadyToSend(NodeId node_id, bool& cancel_recv) return; } - const auto [bytes_sent, data_left] = WITH_LOCK(node->cs_vSend, return SocketSendData(*node);); + const auto [bytes_sent, data_left] = WITH_LOCK(node->cs_vSend, return SendMessagesAsBytes(*node);); // If both receiving and (non-optimistic) sending were possible, we first attempt // sending. If that succeeds, but does not fully drain the send queue, do not @@ -2017,7 +1986,7 @@ void CConnman::EventGotData(NodeId node_id, const uint8_t* data, size_t n) "receiving message bytes failed, %s\n", node->DisconnectMsg(fLogIPs) ); - node->CloseSocketDisconnect(); + MarkAsDisconnectAndCloseConnection(*node); } RecordBytesRecv(n); if (notify) { @@ -2038,7 +2007,7 @@ void CConnman::EventGotEOF(NodeId node_id) if (!node->fDisconnect) { LogDebug(BCLog::NET, "socket closed for peer=%d\n", node_id); } - node->CloseSocketDisconnect(); + MarkAsDisconnectAndCloseConnection(*node); } void CConnman::EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) @@ -2053,7 +2022,7 @@ void CConnman::EventGotPermanentReadError(NodeId node_id, const std::string& err if (!node->fDisconnect) { LogDebug(BCLog::NET, "socket recv error for peer=%d: %s\n", node_id, errmsg); } - node->CloseSocketDisconnect(); + MarkAsDisconnectAndCloseConnection(*node); } bool CConnman::ShouldTryToSend(NodeId node_id) const @@ -2105,164 +2074,6 @@ void CConnman::EventIOLoopCompletedForAllPeers() DisconnectNodes(); NotifyNumConnectionsChanged(); } - -Sock::EventsPerSock CConnman::GenerateWaitSockets(Span nodes) -{ - AssertLockNotHeld(m_nodes_mutex); - - Sock::EventsPerSock events_per_sock; - - for (const auto& sock : m_listen) { - events_per_sock.emplace(sock, Sock::Events{Sock::RECV}); - } - - for (CNode* pnode : nodes) { - const bool select_recv{ShouldTryToRecv(pnode->GetId())}; - const bool select_send{ShouldTryToSend(pnode->GetId())}; - if (!select_recv && !select_send) continue; - - LOCK(pnode->m_sock_mutex); - if (pnode->m_sock) { - Sock::Event event = (select_send ? Sock::SEND : 0) | (select_recv ? Sock::RECV : 0); - events_per_sock.emplace(pnode->m_sock, Sock::Events{event}); - } - } - - return events_per_sock; -} - -void CConnman::SocketHandler() -{ - AssertLockNotHeld(m_total_bytes_sent_mutex); - - Sock::EventsPerSock events_per_sock; - - { - const NodesSnapshot snap{*this, /*shuffle=*/false}; - - const auto timeout = std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS); - - // Check for the readiness of the already connected sockets and the - // listening sockets in one call ("readiness" as in poll(2) or - // select(2)). If none are ready, wait for a short while and return - // empty sets. - events_per_sock = GenerateWaitSockets(snap.Nodes()); - if (events_per_sock.empty() || !events_per_sock.begin()->first->WaitMany(timeout, events_per_sock)) { - interruptNet.sleep_for(timeout); - } - - // Service (send/receive) each of the already connected nodes. - SocketHandlerConnected(snap.Nodes(), events_per_sock); - } - - // Accept new connections from listening sockets. - SocketHandlerListening(events_per_sock); -} - -void CConnman::SocketHandlerConnected(const std::vector& nodes, - const Sock::EventsPerSock& events_per_sock) -{ - AssertLockNotHeld(m_nodes_mutex); - AssertLockNotHeld(m_total_bytes_sent_mutex); - - for (CNode* pnode : nodes) { - if (interruptNet) - return; - - // - // Receive - // - bool recvSet = false; - bool sendSet = false; - bool errorSet = false; - { - LOCK(pnode->m_sock_mutex); - if (!pnode->m_sock) { - continue; - } - const auto it = events_per_sock.find(pnode->m_sock); - if (it != events_per_sock.end()) { - recvSet = it->second.occurred & Sock::RECV; - sendSet = it->second.occurred & Sock::SEND; - errorSet = it->second.occurred & Sock::ERR; - } - } - - if (sendSet) { - bool cancel_recv; - - EventReadyToSend(pnode->GetId(), cancel_recv); - - if (cancel_recv) { - recvSet = false; - } - } - - if (recvSet || errorSet) - { - // typical socket buffer is 8K-64K - uint8_t pchBuf[0x10000]; - int nBytes = 0; - { - LOCK(pnode->m_sock_mutex); - if (!pnode->m_sock) { - continue; - } - nBytes = pnode->m_sock->Recv(pchBuf, sizeof(pchBuf), MSG_DONTWAIT); - } - if (nBytes > 0) - { - EventGotData(pnode->GetId(), pchBuf, nBytes); - } - else if (nBytes == 0) - { - EventGotEOF(pnode->GetId()); - } - else if (nBytes < 0) - { - // error - int nErr = WSAGetLastError(); - if (nErr != WSAEWOULDBLOCK && nErr != WSAEMSGSIZE && nErr != WSAEINTR && nErr != WSAEINPROGRESS) - { - EventGotPermanentReadError(pnode->GetId(), NetworkErrorString(nErr)); - } - } - } - - EventIOLoopCompletedForNode(pnode->GetId()); - } -} - -void CConnman::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) -{ - for (const auto& sock : m_listen) { - if (interruptNet) { - return; - } - const auto it = events_per_sock.find(sock); - if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { - CService addr_accepted; - - auto sock_accepted{AcceptConnection(*sock, addr_accepted)}; - - if (sock_accepted) { - NewSockAccepted(std::move(sock_accepted), GetBindAddress(*sock), addr_accepted); - } - } - } -} - -void CConnman::ThreadSocketHandler() -{ - AssertLockNotHeld(m_nodes_mutex); - AssertLockNotHeld(m_total_bytes_sent_mutex); - - while (!interruptNet) - { - EventIOLoopCompletedForAllPeers(); - SocketHandler(); - } -} void CConnman::WakeMessageHandler() { @@ -3270,9 +3081,6 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) fMsgProcWake = false; } - // Send and receive from sockets, accept connections - threadSocketHandler = std::thread(&util::TraceThread, "net", [this] { ThreadSocketHandler(); }); - SockMan::Options sockman_options; Proxy i2p_sam; @@ -3370,8 +3178,6 @@ void CConnman::StopThreads() threadOpenAddedConnections.join(); if (threadDNSAddressSeed.joinable()) threadDNSAddressSeed.join(); - if (threadSocketHandler.joinable()) - threadSocketHandler.join(); } void CConnman::StopNodes() @@ -3394,8 +3200,7 @@ void CConnman::StopNodes() decltype(m_nodes) nodes; WITH_LOCK(m_nodes_mutex, nodes.swap(m_nodes)); for (auto& [id, pnode] : nodes) { - LogDebug(BCLog::NET, "%s\n", pnode->DisconnectMsg(fLogIPs)); - pnode->CloseSocketDisconnect(); + MarkAsDisconnectAndCloseConnection(*pnode); DeleteNode(pnode); } @@ -3713,7 +3518,6 @@ static std::unique_ptr MakeTransport(NodeId id, bool use_v2transport, } CNode::CNode(NodeId idIn, - std::shared_ptr sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, @@ -3724,7 +3528,6 @@ CNode::CNode(NodeId idIn, CNodeOptions&& node_opts) : m_transport{MakeTransport(idIn, node_opts.use_v2transport, conn_type_in == ConnectionType::INBOUND)}, m_permission_flags{node_opts.permission_flags}, - m_sock{sock}, m_connected{GetTime()}, addr{addrIn}, addrBind{addrBindIn}, @@ -3736,8 +3539,7 @@ CNode::CNode(NodeId idIn, m_conn_type{conn_type_in}, id{idIn}, nLocalHostNonce{nLocalHostNonceIn}, - m_recv_flood_size{node_opts.recv_flood_size}, - m_i2p_sam_session{std::move(node_opts.i2p_sam_session)} + m_recv_flood_size{node_opts.recv_flood_size} { if (inbound_onion) assert(conn_type_in == ConnectionType::INBOUND); @@ -3823,13 +3625,13 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) // If there was nothing to send before, and there is now (predicted by the "more" value // returned by the GetBytesToSend call above), attempt "optimistic write": - // because the poll/select loop may pause for SELECT_TIMEOUT_MILLISECONDS before actually + // because the poll/select loop may pause for a while before actually // doing a send, try sending from the calling thread if the queue was empty before. // With a V1Transport, more will always be true here, because adding a message always // results in sendable bytes there, but with V2Transport this is not the case (it may // still be in the handshake). if (queue_was_empty && more) { - SocketSendData(*pnode); + SendMessagesAsBytes(*pnode); } } } diff --git a/src/net.h b/src/net.h index 2f435b2241056..39200b9915821 100644 --- a/src/net.h +++ b/src/net.h @@ -661,7 +661,6 @@ class V2Transport final : public Transport struct CNodeOptions { NetPermissionFlags permission_flags = NetPermissionFlags::None; - std::unique_ptr i2p_sam_session = nullptr; bool prefer_evict = false; size_t recv_flood_size{DEFAULT_MAXRECEIVEBUFFER * 1000}; bool use_v2transport = false; @@ -677,16 +676,6 @@ class CNode const NetPermissionFlags m_permission_flags; - /** - * Socket used for communication with the node. - * May not own a Sock object (after `CloseSocketDisconnect()` or during tests). - * `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of - * the underlying file descriptor by one thread while another thread is - * poll(2)-ing it for activity. - * @see https://github.com/bitcoin/bitcoin/issues/21744 for details. - */ - std::shared_ptr m_sock GUARDED_BY(m_sock_mutex); - /** Sum of GetMemoryUsage of all vSendMsg entries. */ size_t m_send_memusage GUARDED_BY(cs_vSend){0}; /** Total number of bytes sent on the wire to this peer. */ @@ -878,7 +867,6 @@ class CNode std::atomic m_min_ping_time{std::chrono::microseconds::max()}; CNode(NodeId id, - std::shared_ptr sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, @@ -940,8 +928,6 @@ class CNode nRefCount--; } - void CloseSocketDisconnect() EXCLUSIVE_LOCKS_REQUIRED(!m_sock_mutex); - void CopyStats(CNodeStats& stats) EXCLUSIVE_LOCKS_REQUIRED(!m_subver_mutex, !m_addr_local_mutex, !cs_vSend, !cs_vRecv); std::string ConnectionTypeAsString() const { return ::ConnectionTypeAsString(m_conn_type); } @@ -986,18 +972,6 @@ class CNode mapMsgTypeSize mapSendBytesPerMsgType GUARDED_BY(cs_vSend); mapMsgTypeSize mapRecvBytesPerMsgType GUARDED_BY(cs_vRecv); - - /** - * If an I2P session is created per connection (for outbound transient I2P - * connections) then it is stored here so that it can be destroyed when the - * socket is closed. I2P sessions involve a data/transport socket (in `m_sock`) - * and a control socket (in `m_i2p_sam_session`). For transient sessions, once - * the data socket is closed, the control socket is not going to be used anymore - * and is just taking up resources. So better close it as soon as `m_sock` is - * closed. - * Otherwise this unique_ptr is empty. - */ - std::unique_ptr m_i2p_sam_session GUARDED_BY(m_sock_mutex); }; /** @@ -1293,15 +1267,21 @@ class CConnman : private SockMan /** * Create a `CNode` object and add it to the `m_nodes` member. * @param[in] node_id Id of the newly accepted connection. - * @param[in] sock Connected socket to communicate with the peer. * @param[in] me The address and port at our side of the connection. * @param[in] them The address and port at the peer's side of the connection. + * @retval true on success + * @retval false on failure, meaning that the associated socket and node_id should be discarded */ - virtual void EventNewConnectionAccepted(NodeId node_id, - std::unique_ptr&& sock, + virtual bool EventNewConnectionAccepted(NodeId node_id, const CService& me, const CService& them) override; + /** + * Mark a node as disconnected and close its connection with the peer. + * @param[in] node Node to disconnect. + */ + void MarkAsDisconnectAndCloseConnection(CNode& node); + void DisconnectNodes() EXCLUSIVE_LOCKS_REQUIRED(!m_reconnections_mutex, !m_nodes_mutex); void NotifyNumConnectionsChanged(); /** Return true if the peer is inactive and should be disconnected. */ @@ -1330,37 +1310,7 @@ class CConnman : private SockMan virtual void EventIOLoopCompletedForAllPeers() override EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_reconnections_mutex); - - /** - * Generate a collection of sockets to check for IO readiness. - * @param[in] nodes Select from these nodes' sockets. - * @return sockets to check for readiness - */ - Sock::EventsPerSock GenerateWaitSockets(Span nodes) - EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex); - - /** - * Check connected and listening sockets for IO readiness and process them accordingly. - */ - void SocketHandler() - EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_total_bytes_sent_mutex, !mutexMsgProc); - - /** - * Do the read/write for connected sockets that are ready for IO. - * @param[in] nodes Nodes to process. The socket of each node is checked against `what`. - * @param[in] events_per_sock Sockets that are ready for IO. - */ - void SocketHandlerConnected(const std::vector& nodes, - const Sock::EventsPerSock& events_per_sock) - EXCLUSIVE_LOCKS_REQUIRED(!m_nodes_mutex, !m_total_bytes_sent_mutex, !mutexMsgProc); - - /** - * Accept incoming connections, one from each read-ready listening socket. - * @param[in] events_per_sock Sockets that are ready for IO. - */ - void SocketHandlerListening(const Sock::EventsPerSock& events_per_sock); - void ThreadSocketHandler() EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc, !m_nodes_mutex, !m_reconnections_mutex); void ThreadDNSAddressSeed() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_nodes_mutex); uint64_t CalculateKeyedNetGroup(const CNetAddr& ad) const; @@ -1382,8 +1332,8 @@ class CConnman : private SockMan void DeleteNode(CNode* pnode); /** (Try to) send data from node's vSendMsg. Returns (bytes_sent, data_left). */ - std::pair SocketSendData(CNode& node) - EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend, !m_total_bytes_sent_mutex); + std::pair SendMessagesAsBytes(CNode& node) EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend) + EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); void DumpAddresses(); @@ -1558,7 +1508,6 @@ class CConnman : private SockMan std::atomic flagInterruptMsgProc{false}; std::thread threadDNSAddressSeed; - std::thread threadSocketHandler; std::thread threadOpenAddedConnections; std::thread threadOpenConnections; std::thread threadMessageHandler; diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index 9ee7e9c9fe24b..d2e62c7395cd8 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -55,7 +55,6 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) CAddress addr1(ip(0xa0b0c001), NODE_NONE); NodeId id{0}; CNode dummyNode1{id++, - /*sock=*/nullptr, addr1, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -121,7 +120,6 @@ void AddRandomOutboundPeer(NodeId& id, std::vector& vNodes, PeerManager& } vNodes.emplace_back(new CNode{id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -320,7 +318,6 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) banman->ClearBanned(); NodeId id{0}; nodes[0] = new CNode{id++, - /*sock=*/nullptr, addr[0], /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -340,7 +337,6 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) BOOST_CHECK(!banman->IsDiscouraged(other_addr)); // Different address, not discouraged nodes[1] = new CNode{id++, - /*sock=*/nullptr, addr[1], /*nKeyedNetGroupIn=*/1, /*nLocalHostNonceIn=*/1, @@ -370,7 +366,6 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) // Make sure non-IP peers are discouraged and disconnected properly. nodes[2] = new CNode{id++, - /*sock=*/nullptr, addr[2], /*nKeyedNetGroupIn=*/1, /*nLocalHostNonceIn=*/1, @@ -412,7 +407,6 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) CAddress addr(ip(0xa0b0c001), NODE_NONE); NodeId id{0}; CNode dummyNode{id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/4, /*nLocalHostNonceIn=*/4, diff --git a/src/test/fuzz/connman.cpp b/src/test/fuzz/connman.cpp index 5d2bdaf98b591..ce5ed6eacbc0e 100644 --- a/src/test/fuzz/connman.cpp +++ b/src/test/fuzz/connman.cpp @@ -65,13 +65,14 @@ FUZZ_TARGET(connman, .init = initialize_connman) CNetAddr random_netaddr; NodeId node_id{0}; - CNode random_node = ConsumeNode(fuzzed_data_provider, node_id++); + CNode& random_node{*ConsumeNodeAsUniquePtr(fuzzed_data_provider, node_id++).release()}; + connman.AddTestNode(random_node, std::make_unique(fuzzed_data_provider)); CSubNet random_subnet; std::string random_string; LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 100) { CNode& p2p_node{*ConsumeNodeAsUniquePtr(fuzzed_data_provider, node_id++).release()}; - connman.AddTestNode(p2p_node); + connman.AddTestNode(p2p_node, std::make_unique(fuzzed_data_provider)); } LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 10000) { diff --git a/src/test/fuzz/net.cpp b/src/test/fuzz/net.cpp index 1a0de7aa3631e..8d02e5c4efda6 100644 --- a/src/test/fuzz/net.cpp +++ b/src/test/fuzz/net.cpp @@ -42,9 +42,6 @@ FUZZ_TARGET(net, .init = initialize_net) LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 10000) { CallOneOf( fuzzed_data_provider, - [&] { - node.CloseSocketDisconnect(); - }, [&] { CNodeStats stats; node.CopyStats(stats); diff --git a/src/test/fuzz/p2p_handshake.cpp b/src/test/fuzz/p2p_handshake.cpp index d608efd87aceb..8d5d65655b415 100644 --- a/src/test/fuzz/p2p_handshake.cpp +++ b/src/test/fuzz/p2p_handshake.cpp @@ -65,7 +65,7 @@ FUZZ_TARGET(p2p_handshake, .init = ::initialize) const auto num_peers_to_add = fuzzed_data_provider.ConsumeIntegralInRange(1, 3); for (int i = 0; i < num_peers_to_add; ++i) { peers.push_back(ConsumeNodeAsUniquePtr(fuzzed_data_provider, i).release()); - connman.AddTestNode(*peers.back()); + connman.AddTestNode(*peers.back(), std::make_unique(fuzzed_data_provider)); peerman->InitializeNode( *peers.back(), static_cast(fuzzed_data_provider.ConsumeIntegral())); diff --git a/src/test/fuzz/p2p_headers_presync.cpp b/src/test/fuzz/p2p_headers_presync.cpp index ed7041ad1f1ad..94d3f290d4d4d 100644 --- a/src/test/fuzz/p2p_headers_presync.cpp +++ b/src/test/fuzz/p2p_headers_presync.cpp @@ -60,7 +60,7 @@ void HeadersSyncSetup::ResetAndInitialize() for (auto conn_type : conn_types) { CAddress addr{}; - m_connections.push_back(new CNode(id++, nullptr, addr, 0, 0, addr, "", conn_type, false)); + m_connections.push_back(new CNode(id++, addr, 0, 0, addr, "", conn_type, false)); CNode& p2p_node = *m_connections.back(); connman.Handshake( diff --git a/src/test/fuzz/process_message.cpp b/src/test/fuzz/process_message.cpp index 4bd38a1ac684b..e94f5b2b3d79b 100644 --- a/src/test/fuzz/process_message.cpp +++ b/src/test/fuzz/process_message.cpp @@ -68,7 +68,7 @@ FUZZ_TARGET(process_message, .init = initialize_process_message) } CNode& p2p_node = *ConsumeNodeAsUniquePtr(fuzzed_data_provider).release(); - connman.AddTestNode(p2p_node); + connman.AddTestNode(p2p_node, std::make_unique(fuzzed_data_provider)); FillNode(fuzzed_data_provider, connman, p2p_node); const auto mock_time = ConsumeTime(fuzzed_data_provider); diff --git a/src/test/fuzz/process_messages.cpp b/src/test/fuzz/process_messages.cpp index 0688868c02b21..dbb221e6056e5 100644 --- a/src/test/fuzz/process_messages.cpp +++ b/src/test/fuzz/process_messages.cpp @@ -60,7 +60,7 @@ FUZZ_TARGET(process_messages, .init = initialize_process_messages) FillNode(fuzzed_data_provider, connman, p2p_node); - connman.AddTestNode(p2p_node); + connman.AddTestNode(p2p_node, std::make_unique(fuzzed_data_provider)); } LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 30) diff --git a/src/test/fuzz/util/net.h b/src/test/fuzz/util/net.h index cc73cdff4b795..e2ea4a340770a 100644 --- a/src/test/fuzz/util/net.h +++ b/src/test/fuzz/util/net.h @@ -221,7 +221,6 @@ template auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional& node_id_in = std::nullopt) noexcept { const NodeId node_id = node_id_in.value_or(fuzzed_data_provider.ConsumeIntegralInRange(0, std::numeric_limits::max())); - const auto sock = std::make_shared(fuzzed_data_provider); const CAddress address = ConsumeAddress(fuzzed_data_provider); const uint64_t keyed_net_group = fuzzed_data_provider.ConsumeIntegral(); const uint64_t local_host_nonce = fuzzed_data_provider.ConsumeIntegral(); @@ -232,7 +231,6 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional(node_id, - sock, address, keyed_net_group, local_host_nonce, @@ -243,7 +241,6 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional& nodes, PeerManager& peerman, Connm const bool inbound_onion{onion_peer && conn_type == ConnectionType::INBOUND}; nodes.emplace_back(new CNode{++id, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 5f0f05c842ad4..b4d898b3eabb5 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -60,7 +60,6 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) std::string pszDest; std::unique_ptr pnode1 = std::make_unique(id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -78,7 +77,6 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK_EQUAL(pnode1->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode2 = std::make_unique(id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/1, /*nLocalHostNonceIn=*/1, @@ -96,7 +94,6 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK_EQUAL(pnode2->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode3 = std::make_unique(id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -114,7 +111,6 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK_EQUAL(pnode3->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode4 = std::make_unique(id++, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/1, /*nLocalHostNonceIn=*/1, @@ -613,7 +609,6 @@ BOOST_AUTO_TEST_CASE(ipv4_peer_with_ipv6_addrMe_test) ipv4AddrPeer.s_addr = 0xa0b0c001; CAddress addr = CAddress(CService(ipv4AddrPeer, 7777), NODE_NETWORK); std::unique_ptr pnode = std::make_unique(/*id=*/0, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -667,7 +662,6 @@ BOOST_AUTO_TEST_CASE(get_local_addr_for_peer_port) in_addr peer_out_in_addr; peer_out_in_addr.s_addr = htonl(0x01020304); CNode peer_out{/*id=*/0, - /*sock=*/nullptr, /*addrIn=*/CAddress{CService{peer_out_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -688,7 +682,6 @@ BOOST_AUTO_TEST_CASE(get_local_addr_for_peer_port) in_addr peer_in_in_addr; peer_in_in_addr.s_addr = htonl(0x05060708); CNode peer_in{/*id=*/0, - /*sock=*/nullptr, /*addrIn=*/CAddress{CService{peer_in_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -825,7 +818,6 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message) in_addr peer_in_addr; peer_in_addr.s_addr = htonl(0x01020304); CNode peer{/*id=*/0, - /*sock=*/nullptr, /*addrIn=*/CAddress{CService{peer_in_addr, 8333}, NODE_NETWORK}, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, @@ -900,7 +892,6 @@ BOOST_AUTO_TEST_CASE(advertise_local_address) { auto CreatePeer = [](const CAddress& addr) { return std::make_unique(/*id=*/0, - /*sock=*/nullptr, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, diff --git a/src/test/util/net.h b/src/test/util/net.h index 6c07e98fd8174..a211f2e097eef 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -48,6 +48,12 @@ struct ConnmanTestMsg : public CConnman { return m_nodes; } + void AddTestNode(CNode& node, std::unique_ptr&& sock) + { + TestOnlyAddExistentNode(node.GetId(), std::move(sock)); + AddTestNode(node); + } + void AddTestNode(CNode& node) { LOCK(m_nodes_mutex); From bcf1254e91782a33b5a1db542f3758831e0bfe2e Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 23 Sep 2024 11:05:59 +0200 Subject: [PATCH 20/30] net: move-only: improve encapsulation of SockMan `SockMan` members `AcceptConnection()` `NewSockAccepted()` `GetNewNodeId()` `m_i2p_sam_session` `m_listen private` are now used only by `SockMan`, thus make them private. --- src/common/sockman.cpp | 118 ++++++++++++++++++++--------------------- src/common/sockman.h | 70 ++++++++++++------------ 2 files changed, 94 insertions(+), 94 deletions(-) diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp index fcbb29da87f35..e9f9907d6d65d 100644 --- a/src/common/sockman.cpp +++ b/src/common/sockman.cpp @@ -217,65 +217,6 @@ SockMan::ConnectAndMakeNodeId(const std::variant& t return node_id; } -std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) -{ - sockaddr_storage storage; - socklen_t len{sizeof(storage)}; - - auto sock{listen_sock.Accept(reinterpret_cast(&storage), &len)}; - - if (!sock) { - const int err{WSAGetLastError()}; - if (err != WSAEWOULDBLOCK) { - LogPrintLevel(BCLog::NET, - BCLog::Level::Error, - "Cannot accept new connection: %s\n", - NetworkErrorString(err)); - } - return {}; - } - - if (!addr.SetSockAddr(reinterpret_cast(&storage))) { - LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); - } - - return sock; -} - -void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) -{ - AssertLockNotHeld(m_connected_mutex); - - if (!sock->IsSelectable()) { - LogPrintf("connection from %s dropped: non-selectable socket\n", them.ToStringAddrPort()); - return; - } - - // According to the internet TCP_NODELAY is not carried into accepted sockets - // on all platforms. Set it again here just to be sure. - const int on{1}; - if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { - LogDebug(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", - them.ToStringAddrPort()); - } - - const NodeId node_id{GetNewNodeId()}; - - { - LOCK(m_connected_mutex); - m_connected.emplace(node_id, std::make_shared(std::move(sock))); - } - - if (!EventNewConnectionAccepted(node_id, me, them)) { - CloseConnection(node_id); - } -} - -NodeId SockMan::GetNewNodeId() -{ - return m_next_node_id.fetch_add(1, std::memory_order_relaxed); -} - bool SockMan::CloseConnection(NodeId node_id) { LOCK(m_connected_mutex); @@ -409,6 +350,65 @@ void SockMan::ThreadSocketHandler() } } +std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) +{ + sockaddr_storage storage; + socklen_t len{sizeof(storage)}; + + auto sock{listen_sock.Accept(reinterpret_cast(&storage), &len)}; + + if (!sock) { + const int err{WSAGetLastError()}; + if (err != WSAEWOULDBLOCK) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Error, + "Cannot accept new connection: %s\n", + NetworkErrorString(err)); + } + return {}; + } + + if (!addr.SetSockAddr(reinterpret_cast(&storage))) { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); + } + + return sock; +} + +void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) +{ + AssertLockNotHeld(m_connected_mutex); + + if (!sock->IsSelectable()) { + LogPrintf("connection from %s dropped: non-selectable socket\n", them.ToStringAddrPort()); + return; + } + + // According to the internet TCP_NODELAY is not carried into accepted sockets + // on all platforms. Set it again here just to be sure. + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogDebug(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", + them.ToStringAddrPort()); + } + + const NodeId node_id{GetNewNodeId()}; + + { + LOCK(m_connected_mutex); + m_connected.emplace(node_id, std::make_shared(std::move(sock))); + } + + if (!EventNewConnectionAccepted(node_id, me, them)) { + CloseConnection(node_id); + } +} + +NodeId SockMan::GetNewNodeId() +{ + return m_next_node_id.fetch_add(1, std::memory_order_relaxed); +} + SockMan::IOReadiness SockMan::GenerateWaitSockets() { AssertLockNotHeld(m_connected_mutex); diff --git a/src/common/sockman.h b/src/common/sockman.h index 570467d965077..ec9b251464360 100644 --- a/src/common/sockman.h +++ b/src/common/sockman.h @@ -107,29 +107,6 @@ class SockMan CService& me) EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex, !m_unused_i2p_sessions_mutex); - /** - * Accept a connection. - * @param[in] listen_sock Socket on which to accept the connection. - * @param[out] addr Address of the peer that was accepted. - * @return Newly created socket for the accepted connection. - */ - std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); - - /** - * After a new socket with a peer has been created, configure its flags, - * make a new node id and call `EventNewConnectionAccepted()`. - * @param[in] sock The newly created socket. - * @param[in] me Address at our end of the connection. - * @param[in] them Address of the new peer. - */ - void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) - EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); - - /** - * Generate an id for a newly created node. - */ - NodeId GetNewNodeId(); - /** * Disconnect a given peer by closing its socket and release resources occupied by it. * @return Whether the peer existed and its socket was closed by this call. @@ -167,18 +144,6 @@ class SockMan */ CThreadInterrupt interruptNet; - /** - * I2P SAM session. - * Used to accept incoming and make outgoing I2P connections from a persistent - * address. - */ - std::unique_ptr m_i2p_sam_session; - - /** - * List of listening sockets. - */ - std::vector> m_listen; - protected: /** @@ -356,6 +321,29 @@ class SockMan void ThreadSocketHandler() EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + /** + * Accept a connection. + * @param[in] listen_sock Socket on which to accept the connection. + * @param[out] addr Address of the peer that was accepted. + * @return Newly created socket for the accepted connection. + */ + std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); + + /** + * After a new socket with a peer has been created, configure its flags, + * make a new node id and call `EventNewConnectionAccepted()`. + * @param[in] sock The newly created socket. + * @param[in] me Address at our end of the connection. + * @param[in] them Address of the new peer. + */ + void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Generate an id for a newly created node. + */ + NodeId GetNewNodeId(); + /** * Generate a collection of sockets to check for IO readiness. * @return Sockets to check for readiness plus an aux map to find the @@ -415,6 +403,18 @@ class SockMan */ std::queue> m_unused_i2p_sessions GUARDED_BY(m_unused_i2p_sessions_mutex); + /** + * I2P SAM session. + * Used to accept incoming and make outgoing I2P connections from a persistent + * address. + */ + std::unique_ptr m_i2p_sam_session; + + /** + * List of listening sockets. + */ + std::vector> m_listen; + mutable Mutex m_connected_mutex; /** From 731e063eeb0703a17dc5a889da1ca85f59b9f3ba Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Thu, 30 Nov 2023 13:58:48 +0100 Subject: [PATCH 21/30] Add sv2 log category for Stratum v2 --- src/logging.cpp | 1 + src/logging.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/logging.cpp b/src/logging.cpp index 5f055566ef5d8..3aff2a439a001 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -198,6 +198,7 @@ static const std::map> LOG_CATEGORIES_ {"blockstorage", BCLog::BLOCKSTORAGE}, {"txreconciliation", BCLog::TXRECONCILIATION}, {"scan", BCLog::SCAN}, + {"sv2", BCLog::SV2}, {"txpackages", BCLog::TXPACKAGES}, }; diff --git a/src/logging.h b/src/logging.h index fdc12c79b3281..f33f44f6ea0b5 100644 --- a/src/logging.h +++ b/src/logging.h @@ -71,6 +71,7 @@ namespace BCLog { TXRECONCILIATION = (CategoryMask{1} << 26), SCAN = (CategoryMask{1} << 27), TXPACKAGES = (CategoryMask{1} << 28), + SV2 = (CategoryMask{1} << 29), ALL = ~NONE, }; enum class Level { From d53193d33d738475ddf14d761cadec621a592bc2 Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Fri, 13 Sep 2024 11:05:43 +0200 Subject: [PATCH 22/30] build: libbitcoin_sv2 scaffold --- CMakeLists.txt | 3 +++ doc/design/libraries.md | 1 + src/CMakeLists.txt | 3 +++ src/sv2/CMakeLists.txt | 14 ++++++++++++++ src/test/CMakeLists.txt | 7 +++++++ src/test/fuzz/CMakeLists.txt | 9 +++++++++ 6 files changed, 37 insertions(+) create mode 100644 src/sv2/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index e542e217c5cae..035df3d6b4f9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -149,6 +149,8 @@ if(WITH_MULTIPROCESS) ) endif() +option(WITH_SV2 "Enable Stratum v2 functionality." ON) + cmake_dependent_option(BUILD_GUI_TESTS "Build test_bitcoin-qt executable." ON "BUILD_GUI;BUILD_TESTS" OFF) if(BUILD_GUI) set(qt_components Core Gui Widgets LinguistTools) @@ -633,6 +635,7 @@ message(" ZeroMQ .............................. ${WITH_ZMQ}") message(" USDT tracing ........................ ${WITH_USDT}") message(" QR code (GUI) ....................... ${WITH_QRENCODE}") message(" DBus (GUI, Linux only) .............. ${WITH_DBUS}") +message(" Stratum v2 .......................... ${WITH_SV2}") message("Tests:") message(" test_bitcoin ........................ ${BUILD_TESTS}") message(" test_bitcoin-qt ..................... ${BUILD_GUI_TESTS}") diff --git a/doc/design/libraries.md b/doc/design/libraries.md index 24185bf4776df..8448fb7011d69 100644 --- a/doc/design/libraries.md +++ b/doc/design/libraries.md @@ -14,6 +14,7 @@ | *libbitcoin_wallet* | Wallet functionality used by *bitcoind* and *bitcoin-wallet* executables. | | *libbitcoin_wallet_tool* | Lower-level wallet functionality used by *bitcoin-wallet* executable. | | *libbitcoin_zmq* | [ZeroMQ](../zmq.md) functionality used by *bitcoind* and *bitcoin-qt* executables. | +| *libbitcoin_sv2* | Stratum v2 functionality (usage TBD) | ## Conventions diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 89fdd855a4598..91ba1efdf3273 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -29,6 +29,9 @@ add_subdirectory(util) if(WITH_MULTIPROCESS) add_subdirectory(ipc) endif() +if(WITH_SV2) + add_subdirectory(sv2) +endif() #============================= # secp256k1 subtree diff --git a/src/sv2/CMakeLists.txt b/src/sv2/CMakeLists.txt new file mode 100644 index 0000000000000..e02c4c01fa877 --- /dev/null +++ b/src/sv2/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) 2024-present The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or https://opensource.org/license/mit/. + +add_library(bitcoin_sv2 STATIC EXCLUDE_FROM_ALL +) + +target_link_libraries(bitcoin_sv2 + PRIVATE + core_interface + bitcoin_clientversion + bitcoin_crypto + $<$:ws2_32> +) diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 859b913206782..83b9c403b1151 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -175,6 +175,13 @@ if(WITH_MULTIPROCESS) target_link_libraries(test_bitcoin bitcoin_ipc_test bitcoin_ipc) endif() +if(WITH_SV2) + target_sources(test_bitcoin + PRIVATE + ) + target_link_libraries(test_bitcoin bitcoin_sv2) +endif() + function(add_boost_test source_file) if(NOT EXISTS ${source_file}) return() diff --git a/src/test/fuzz/CMakeLists.txt b/src/test/fuzz/CMakeLists.txt index a261d3ecea238..c47a958f82b1e 100644 --- a/src/test/fuzz/CMakeLists.txt +++ b/src/test/fuzz/CMakeLists.txt @@ -151,3 +151,12 @@ target_link_libraries(fuzz if(ENABLE_WALLET) add_subdirectory(${PROJECT_SOURCE_DIR}/src/wallet/test/fuzz wallet) endif() + +if(WITH_SV2) + target_sources(fuzz + PRIVATE + ) + target_link_libraries(fuzz + bitcoin_sv2 + ) +endif() From 5f40f59abc402d1aeb3287eddf34b30d34fc9d13 Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Thu, 19 Dec 2024 12:36:07 +0700 Subject: [PATCH 23/30] Add sv2 noise protocol Co-Authored-By: Christopher Coverdale --- src/pubkey.h | 2 +- src/sv2/CMakeLists.txt | 1 + src/sv2/noise.cpp | 508 +++++++++++++++++++++++++++++++++++ src/sv2/noise.h | 299 +++++++++++++++++++++ src/test/CMakeLists.txt | 1 + src/test/fuzz/CMakeLists.txt | 1 + src/test/fuzz/sv2_noise.cpp | 169 ++++++++++++ src/test/sv2_noise_tests.cpp | 159 +++++++++++ 8 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 src/sv2/noise.cpp create mode 100644 src/sv2/noise.h create mode 100644 src/test/fuzz/sv2_noise.cpp create mode 100644 src/test/sv2_noise_tests.cpp diff --git a/src/pubkey.h b/src/pubkey.h index b4666aad228e1..798687de1fe5a 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -319,7 +319,7 @@ struct EllSwiftPubKey /** Construct a new ellswift public key from a given serialization. */ EllSwiftPubKey(Span ellswift) noexcept; - /** Decode to normal compressed CPubKey (for debugging purposes). */ + /** Decode to normal compressed CPubKey. */ CPubKey Decode() const; // Read-only access for serialization. diff --git a/src/sv2/CMakeLists.txt b/src/sv2/CMakeLists.txt index e02c4c01fa877..d6e44842e8c87 100644 --- a/src/sv2/CMakeLists.txt +++ b/src/sv2/CMakeLists.txt @@ -3,6 +3,7 @@ # file COPYING or https://opensource.org/license/mit/. add_library(bitcoin_sv2 STATIC EXCLUDE_FROM_ALL + noise.cpp ) target_link_libraries(bitcoin_sv2 diff --git a/src/sv2/noise.cpp b/src/sv2/noise.cpp new file mode 100644 index 0000000000000..798450a67246a --- /dev/null +++ b/src/sv2/noise.cpp @@ -0,0 +1,508 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include + +#include +#include +#include +#include +#include +#include + +Sv2SignatureNoiseMessage::Sv2SignatureNoiseMessage(uint16_t version, uint32_t valid_from, uint32_t valid_to, const XOnlyPubKey& static_key, const CKey& authority_key) : m_version{version}, m_valid_from{valid_from}, m_valid_to{valid_to}, m_static_key{static_key} +{ + SignSchnorr(authority_key, m_sig); +} + +uint256 Sv2SignatureNoiseMessage::GetHash() +{ + DataStream ss{}; + ss << m_version + << m_valid_from + << m_valid_to + << m_static_key; + + LogTrace(BCLog::SV2, "Certificate hashed data: %s\n", HexStr(ss)); + + CSHA256 hasher; + hasher.Write(reinterpret_cast(&(*ss.begin())), ss.end() - ss.begin()); + + uint256 hash_output; + hasher.Finalize(hash_output.begin()); + return hash_output; +} + +bool Sv2SignatureNoiseMessage::Validate(XOnlyPubKey authority_key) +{ + if (m_version > 0) { + LogTrace(BCLog::SV2, "Invalid certificate version: %d\n", m_version); + return false; + } + auto now{GetTime()}; + if (std::chrono::seconds{m_valid_from} > now) { + LogTrace(BCLog::SV2, "Certificate valid from is in the future: %d\n", m_valid_from); + return false; + } + if (std::chrono::seconds{m_valid_to} < now) { + LogTrace(BCLog::SV2, "Certificate expired: %d\n", m_valid_to); + return false; + } + + if (!authority_key.VerifySchnorr(this->GetHash(), m_sig)) { + LogTrace(BCLog::SV2, "Certificate signature is invalid\n"); + return false; + } + return true; +} + +void Sv2SignatureNoiseMessage::SignSchnorr(const CKey& authority_key, Span sig) +{ + authority_key.SignSchnorr(this->GetHash(), sig, nullptr, {}); +} + +Sv2CipherState::Sv2CipherState(NoiseHash&& key) : m_key(std::move(key)) {}; + +bool Sv2CipherState::DecryptWithAd(Span associated_data, Span ciphertext, Span plain) +{ + Assume(Sv2Cipher::EncryptedMessageSize(plain.size()) == ciphertext.size()); + + if (m_nonce == UINT64_MAX) { + // This nonce value is reserved, see chapter 5.1 of the Noise paper. + LogTrace(BCLog::SV2, "Nonce exceeds maximum value\n"); + return false; + } + AEADChaCha20Poly1305::Nonce96 nonce = {0, m_nonce}; + auto key = MakeByteSpan(m_key); + AEADChaCha20Poly1305 aead{key}; + if (!aead.Decrypt(ciphertext, associated_data, nonce, plain)) { + LogTrace(BCLog::SV2, "Message decryption failed\n"); + return false; + } + // Only increase nonce if decryption succeeded + m_nonce++; + return true; +} + +bool Sv2CipherState::EncryptWithAd(Span associated_data, Span plain, Span ciphertext) +{ + Assume(Sv2Cipher::EncryptedMessageSize(plain.size()) == ciphertext.size()); + + if (m_nonce == UINT64_MAX) { + // This nonce value is reserved, see chapter 5.1 of the Noise paper. + LogTrace(BCLog::SV2, "Nonce exceeds maximum value\n"); + return false; + } + AEADChaCha20Poly1305::Nonce96 nonce = {0, m_nonce++}; + auto key = MakeByteSpan(m_key); + AEADChaCha20Poly1305 aead{key}; + aead.Encrypt(plain, associated_data, nonce, ciphertext); + return true; +} + +bool Sv2CipherState::EncryptMessage(Span plain, Span ciphertext) +{ + Assume(ciphertext.size() == Sv2Cipher::EncryptedMessageSize(plain.size())); + + std::vector ad; // No associated data + + constexpr size_t max_chunk_size = NOISE_MAX_CHUNK_SIZE - Poly1305::TAGLEN; + size_t num_chunks = (plain.size() + max_chunk_size - 1) / max_chunk_size; + if (num_chunks > 1) { + LogTrace(BCLog::SV2, + "Split into %d chunks (max %d bytes)\n", + num_chunks, max_chunk_size); + } + + // Copy input bytes into output buffer + const std::vector padding(Poly1305::TAGLEN, std::byte(0)); + for (size_t i = 0; i < num_chunks; ++i) { + size_t chunk_start = i * max_chunk_size; + size_t chunk_end = std::min(chunk_start + max_chunk_size, plain.size()); + size_t chunk_size = chunk_end - chunk_start; + const auto encrypted_chunk_start = ciphertext.begin() + i * NOISE_MAX_CHUNK_SIZE; + std::copy(plain.begin() + chunk_start, plain.begin() + chunk_start + chunk_size, encrypted_chunk_start); + std::copy(padding.begin(), padding.end(), encrypted_chunk_start + chunk_size); + } + + // Encrypt each chunk + size_t bytes_written = 0; + for (size_t i = 0; i < num_chunks; ++i) { + size_t chunk_size = std::min(ciphertext.size() - bytes_written, NOISE_MAX_CHUNK_SIZE); + Span chunk = ciphertext.subspan(bytes_written, chunk_size); + Span chunk_plain = ciphertext.subspan(bytes_written, chunk_size - Poly1305::TAGLEN); + if (!EncryptWithAd(ad, chunk_plain, chunk)) { + return false; + } + bytes_written += chunk.size(); + } + + Assume(bytes_written == ciphertext.size()); + return true; +} + +bool Sv2CipherState::DecryptMessage(Span ciphertext, Span plain) +{ + Assume(Sv2Cipher::EncryptedMessageSize(plain.size()) == ciphertext.size()); + + size_t processed = 0; + size_t plain_position = 0; + std::vector ad; // No associated data + + while (processed < ciphertext.size()) { + size_t chunk_size = std::min(ciphertext.size() - processed, NOISE_MAX_CHUNK_SIZE); + Span chunk_cipher = ciphertext.subspan(processed, chunk_size); + Span chunk_plain = plain.subspan(plain_position, chunk_size - Poly1305::TAGLEN); + if (!DecryptWithAd(ad, chunk_cipher, chunk_plain)) return false; + processed += chunk_size; + plain_position += chunk_size - Poly1305::TAGLEN; + } + + return true; +} + +void Sv2SymmetricState::MixHash(const Span input) +{ + m_hash_output = (HashWriter{} << m_hash_output << input).GetSHA256(); +} + +void Sv2SymmetricState::MixKey(const Span input_key_material) +{ + NoiseHash out0; + NoiseHash out1; + HKDF2(input_key_material, out0, out1); + m_chaining_key = std::move(out0); + m_cipher_state = Sv2CipherState{std::move(out1)}; +} + +std::string Sv2SymmetricState::GetChainingKey() +{ + return HexStr(m_chaining_key); +} + +void Sv2SymmetricState::LogChainingKey() +{ + LogTrace(BCLog::SV2, "Chaining key: %s\n", GetChainingKey()); +} + +void Sv2SymmetricState::HKDF2(const Span input_key_material, NoiseHash& out0, NoiseHash& out1) +{ + NoiseHash tmp_key; + CHMAC_SHA256 tmp_mac(m_chaining_key.data(), m_chaining_key.size()); + tmp_mac.Write(UCharCast(input_key_material.data()), input_key_material.size()); + tmp_mac.Finalize(tmp_key.data()); + + CHMAC_SHA256 out0_mac(tmp_key.data(), tmp_key.size()); + uint8_t one[1]{0x1}; + out0_mac.Write(one, 1); + out0_mac.Finalize(out0.data()); + + std::vector in1; + in1.reserve(HASHLEN + 1); + std::copy(out0.begin(), out0.end(), std::back_inserter(in1)); + in1.push_back(0x02); + + CHMAC_SHA256 out1_mac(tmp_key.data(), tmp_key.size()); + out1_mac.Write(&in1[0], in1.size()); + out1_mac.Finalize(out1.data()); +} + +bool Sv2SymmetricState::EncryptAndHash(Span plain, Span ciphertext) +{ + Assume(Sv2Cipher::EncryptedMessageSize(plain.size()) == ciphertext.size()); + + if (!m_cipher_state.EncryptWithAd(MakeByteSpan(m_hash_output), plain, ciphertext)) { + return false; + } + MixHash(ciphertext); + return true; +} + +bool Sv2SymmetricState::DecryptAndHash(Span ciphertext, Span plain) +{ + Assume(Sv2Cipher::EncryptedMessageSize(plain.size()) == ciphertext.size()); + + // The handshake requires mix hashing the cipher text NOT the decrypted + // plaintext. + std::vector ciphertext_copy; + ciphertext_copy.assign(ciphertext.begin(), ciphertext.end()); + + bool res = m_cipher_state.DecryptWithAd(MakeByteSpan(m_hash_output), ciphertext, plain); + if (!res) return false; + MixHash(ciphertext_copy); + return true; +} + +std::array Sv2SymmetricState::Split() +{ + NoiseHash send_key; + NoiseHash recv_key; + HKDF2({}, send_key, recv_key); + return {Sv2CipherState{std::move(send_key)}, Sv2CipherState{std::move(recv_key)}}; +} + +uint256 Sv2SymmetricState::GetHashOutput() +{ + return m_hash_output; +} + +void Sv2HandshakeState::SetEphemeralKey(CKey&& key) +{ + m_ephemeral_key = key; + m_ephemeral_ellswift_pk = m_ephemeral_key.EllSwiftCreate(MakeByteSpan(GetRandHash())); +}; + +void Sv2HandshakeState::GenerateEphemeralKey() noexcept +{ + Assume(!m_ephemeral_key.size()); + LogTrace(BCLog::SV2, "Generate ephemeral key\n"); + SetEphemeralKey(GenerateRandomKey()); +}; + +void Sv2HandshakeState::WriteMsgEphemeralPK(Span msg) +{ + if (msg.size() < ELLSWIFT_PUB_KEY_SIZE) { + throw std::runtime_error(strprintf("Invalid message size: %d bytes < %d", msg.size(), ELLSWIFT_PUB_KEY_SIZE)); + } + + if (!m_ephemeral_key.IsValid()) { + GenerateEphemeralKey(); + } + + LogTrace(BCLog::SV2, "Write our ephemeral key\n"); + std::copy(m_ephemeral_ellswift_pk.begin(), m_ephemeral_ellswift_pk.end(), msg.begin()); + + m_symmetric_state.MixHash(msg.subspan(0, ELLSWIFT_PUB_KEY_SIZE)); + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + std::vector empty; + m_symmetric_state.MixHash(empty); +} + +void Sv2HandshakeState::ReadMsgEphemeralPK(Span msg) +{ + LogTrace(BCLog::SV2, "Read their ephemeral key\n"); + Assume(msg.size() == ELLSWIFT_PUB_KEY_SIZE); + m_remote_ephemeral_ellswift_pk = EllSwiftPubKey(msg); + + m_symmetric_state.MixHash(msg.subspan(0, ELLSWIFT_PUB_KEY_SIZE)); + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + std::vector empty; + m_symmetric_state.MixHash(empty); +} + +void Sv2HandshakeState::WriteMsgES(Span msg) +{ + if (msg.size() < HANDSHAKE_STEP2_SIZE) { + throw std::runtime_error(strprintf("Invalid message size: %d bytes < %d", msg.size(), HANDSHAKE_STEP2_SIZE)); + } + + ssize_t bytes_written = 0; + + if (!m_ephemeral_key.IsValid()) { + GenerateEphemeralKey(); + } + + // Send our ephemeral pk. + LogTrace(BCLog::SV2, "Write our ephemeral key\n"); + std::copy(m_ephemeral_ellswift_pk.begin(), m_ephemeral_ellswift_pk.end(), msg.begin()); + + m_symmetric_state.MixHash(msg.subspan(0, ELLSWIFT_PUB_KEY_SIZE)); + bytes_written += ELLSWIFT_PUB_KEY_SIZE; + + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + LogTrace(BCLog::SV2, "Perform ECDH with the remote ephemeral key\n"); + ECDHSecret ecdh_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk, + m_ephemeral_ellswift_pk, + /*initiating=*/false)}; + + LogTrace(BCLog::SV2, "Mix key with ECDH result: ephemeral ours -- remote ephemeral\n"); + m_symmetric_state.MixKey(ecdh_secret); + m_symmetric_state.LogChainingKey(); + + // Send our static pk. + LogTrace(BCLog::SV2, "Encrypt and write our static key\n"); + + if (!m_symmetric_state.EncryptAndHash(m_static_ellswift_pk, msg.subspan(ELLSWIFT_PUB_KEY_SIZE, ELLSWIFT_PUB_KEY_SIZE + Poly1305::TAGLEN))) { + // This should never happen + Assume(false); + throw std::runtime_error("Failed to encrypt our ephemeral key\n"); + } + + bytes_written += ELLSWIFT_PUB_KEY_SIZE + Poly1305::TAGLEN; + + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + LogTrace(BCLog::SV2, "Perform ECDH between our static and remote ephemeral key\n"); + ECDHSecret ecdh_static_secret{m_static_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk, + m_static_ellswift_pk, + /*initiating=*/false)}; + LogTrace(BCLog::SV2, "ECDH result: %s\n", HexStr(ecdh_static_secret)); + + LogTrace(BCLog::SV2, "Mix key with ECDH result: static ours -- remote ephemeral\n"); + m_symmetric_state.MixKey(ecdh_static_secret); + m_symmetric_state.LogChainingKey(); + + // Serialize our digital signature noise message and encrypt. + DataStream ss{}; + Assume(m_certificate); + ss << m_certificate.value(); + Assume(ss.size() == Sv2SignatureNoiseMessage::SIZE); + + LogTrace(BCLog::SV2, "Encrypt certificate: %s\n", HexStr(ss)); + if (!m_symmetric_state.EncryptAndHash(ss, msg.subspan(bytes_written, Sv2SignatureNoiseMessage::SIZE + Poly1305::TAGLEN))) { + // This should never happen + Assume(false); + throw std::runtime_error("Failed to encrypt our certificate\n"); + } + + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + bytes_written += Sv2SignatureNoiseMessage::SIZE + Poly1305::TAGLEN; + Assume(bytes_written == HANDSHAKE_STEP2_SIZE); +} + +bool Sv2HandshakeState::ReadMsgES(Span msg) +{ + Assume(msg.size() == HANDSHAKE_STEP2_SIZE); + ssize_t bytes_read = 0; + + // Read the remote ephemeral key from the msg and decrypt. + LogTrace(BCLog::SV2, "Read remote ephemeral key\n"); + m_remote_ephemeral_ellswift_pk = EllSwiftPubKey(msg.subspan(0, ELLSWIFT_PUB_KEY_SIZE)); + bytes_read += ELLSWIFT_PUB_KEY_SIZE; + + m_symmetric_state.MixHash(m_remote_ephemeral_ellswift_pk); + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + LogTrace(BCLog::SV2, "Perform ECDH with the remote ephemeral key\n"); + ECDHSecret ecdh_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_ephemeral_ellswift_pk, + m_ephemeral_ellswift_pk, + /*initiating=*/true)}; + + LogTrace(BCLog::SV2, "Mix key with ECDH result: ephemeral ours -- remote ephemeral\n"); + m_symmetric_state.MixKey(ecdh_secret); + m_symmetric_state.LogChainingKey(); + + LogTrace(BCLog::SV2, "Decrypt remote static key\n"); + std::array remote_static_key_bytes; + bool res = m_symmetric_state.DecryptAndHash(msg.subspan(ELLSWIFT_PUB_KEY_SIZE, ELLSWIFT_PUB_KEY_SIZE + Poly1305::TAGLEN), remote_static_key_bytes); + if (!res) return false; + bytes_read += ELLSWIFT_PUB_KEY_SIZE + Poly1305::TAGLEN; + + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + // Load remote static key from the decryted msg + m_remote_static_ellswift_pk = EllSwiftPubKey(remote_static_key_bytes); + + LogTrace(BCLog::SV2, "Perform ECDH on the remote static key\n"); + ECDHSecret ecdh_static_secret{m_ephemeral_key.ComputeBIP324ECDHSecret(m_remote_static_ellswift_pk, + m_ephemeral_ellswift_pk, + /*initiating=*/true)}; + LogTrace(BCLog::SV2, "ECDH result: %s\n", HexStr(ecdh_static_secret)); + + LogTrace(BCLog::SV2, "Mix key with ECDH result: ephemeral ours -- remote static\n"); + m_symmetric_state.MixKey(ecdh_static_secret); + m_symmetric_state.LogChainingKey(); + + LogTrace(BCLog::SV2, "Decrypt remote certificate\n"); + std::array remote_cert_bytes; + res = m_symmetric_state.DecryptAndHash(msg.subspan(bytes_read, Sv2SignatureNoiseMessage::SIZE + Poly1305::TAGLEN), remote_cert_bytes); + if (!res) return false; + bytes_read += (Sv2SignatureNoiseMessage::SIZE + Poly1305::TAGLEN); + LogTrace(BCLog::SV2, "Mix hash: %s\n", HexStr(m_symmetric_state.GetHashOutput())); + + LogTrace(BCLog::SV2, "Validate remote certificate\n"); + DataStream ss_cert(remote_cert_bytes); + Sv2SignatureNoiseMessage cert; + ss_cert >> cert; + cert.m_static_key = XOnlyPubKey(m_remote_static_ellswift_pk.Decode()); + Assume(m_authority_pubkey); + if (!cert.Validate(m_authority_pubkey.value())) { + // We initiated the connection, so it's safe to unconditionally log this: + LogWarning("Invalid certificate: %s\n", HexStr(remote_cert_bytes)); + return false; + } + + Assume(bytes_read == HANDSHAKE_STEP2_SIZE); + return true; +} + +std::array Sv2HandshakeState::SplitSymmetricState() +{ + return m_symmetric_state.Split(); +} + +uint256 Sv2HandshakeState::GetHashOutput() +{ + return m_symmetric_state.GetHashOutput(); +} + +Sv2Cipher::Sv2Cipher(CKey&& static_key, XOnlyPubKey authority_pubkey) +{ + m_handshake_state = std::make_unique(std::move(static_key), authority_pubkey); + m_initiator = true; +} + +Sv2Cipher::Sv2Cipher(CKey&& static_key, Sv2SignatureNoiseMessage&& certificate) +{ + m_handshake_state = std::make_unique(std::move(static_key), std::move(certificate)); + m_initiator = false; +} + +Sv2HandshakeState& Sv2Cipher::GetHandshakeState() +{ + Assume(m_handshake_state); + return *m_handshake_state; +} + +void Sv2Cipher::FinishHandshake() +{ + Assume(m_handshake_state); + + auto cipher_state{m_handshake_state->SplitSymmetricState()}; + + m_hash = m_handshake_state->GetHashOutput(); + + m_cs1 = std::move(cipher_state[0]); + m_cs2 = std::move(cipher_state[1]); + + m_handshake_state.reset(); +} + +size_t Sv2Cipher::EncryptedMessageSize(const size_t msg_len) +{ + constexpr size_t chunk_size = NOISE_MAX_CHUNK_SIZE - Poly1305::TAGLEN; + const size_t num_chunks = (msg_len + chunk_size - 1) / chunk_size; + return msg_len + (num_chunks * Poly1305::TAGLEN); +} + +bool Sv2Cipher::DecryptMessage(Span ciphertext, Span plain) +{ + Assume(EncryptedMessageSize(plain.size()) == ciphertext.size()); + + if (m_initiator) { + return m_cs2.DecryptMessage(ciphertext, plain); + } else { + return m_cs1.DecryptMessage(ciphertext, plain); + } +} + +bool Sv2Cipher::EncryptMessage(Span input, Span output) +{ + Assume(output.size() == Sv2Cipher::EncryptedMessageSize(input.size())); + + if (m_initiator) { + return m_cs1.EncryptMessage(input, output); + } else { + return m_cs2.EncryptMessage(input, output); + } +} + +uint256 Sv2Cipher::GetHash() const +{ + return m_hash; +} diff --git a/src/sv2/noise.h b/src/sv2/noise.h new file mode 100644 index 0000000000000..13036424438f5 --- /dev/null +++ b/src/sv2/noise.h @@ -0,0 +1,299 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_NOISE_H +#define BITCOIN_SV2_NOISE_H + +#include +#include +#include +#include +#include +#include +#include + +/** The Noise Protocol Framework + * https://noiseprotocol.org/noise.html + * Revision 38, 2018-07-11 + * + * Stratum v2 handshake and cipher specification: + * https://github.com/stratum-mining/sv2-spec/blob/main/04-Protocol-Security.md + */ + +/** Section 3: All Noise messages are less than or equal to 65535 bytes in length. */ +static constexpr size_t NOISE_MAX_CHUNK_SIZE = 65535; + +static constexpr size_t HASHLEN{32}; +using NoiseHash = std::array; + +/** Simple certificate for the static key signed by the authority key. + * See 4.5.2 and 4.5.3 of the Stratum v2 spec. + */ +class Sv2SignatureNoiseMessage +{ +public: + /** Size of a Schnorr signature. */ + static constexpr size_t SCHNORR_SIGNATURE_SIZE = 64; + /** Size of serialized message, which does not include the static key. */ + static constexpr size_t SIZE = 2 + 4 + 4 + SCHNORR_SIGNATURE_SIZE; + +private: + uint16_t m_version = 0; + uint32_t m_valid_from = 0; + uint32_t m_valid_to = 0; + std::array m_sig; + + /** Hash of version, valid from/to and the static key. */ + uint256 GetHash(); + void SignSchnorr(const CKey& authority_key, Span sig); + +public: + Sv2SignatureNoiseMessage() = default; + Sv2SignatureNoiseMessage(uint16_t version, uint32_t valid_from, uint32_t valid_to, const XOnlyPubKey& static_key, const CKey& authority_key); + + /* The certificate serializes pubkeys in x-only format, not EllSwift. */ + XOnlyPubKey m_static_key = {}; + + [[nodiscard]] bool Validate(XOnlyPubKey authority_key); + + template + // The static_key is signed for, but not serialized. + void Serialize(Stream& s) const + { + s << m_version + << m_valid_from + << m_valid_to + << m_sig; + } + template + void Unserialize(Stream& s) + { + s >> m_version + >> m_valid_from + >> m_valid_to + >> m_sig; + } +}; + +/* + * The CipherState uses m_key (k) and m_nonce (n) to encrypt and decrypt ciphertexts. + * During the handshake phase each party has a single CipherState, but during + * the transport phase each party has two CipherState objects: one for sending, + * and one for receiving. + * + * See chapter "5. Processing rules" of the Noise paper. + */ +class Sv2CipherState +{ +public: + Sv2CipherState() = default; + explicit Sv2CipherState(NoiseHash&& key); + + /** Decrypt message + * @param[in] associated_data associated data + * @param[in] ciphertext message with encrypted and authenticated chunks. + * @param[out] plain message (defragmented) + * @returns whether decryption succeeded + */ + [[nodiscard]] bool DecryptWithAd(Span associated_data, Span ciphertext, Span plain); + + /** Encrypt message + * @param[in] associated_data associated data + * @param[in] plain message + * @param[out] ciphertext message with encrypted and authenticated chunks. + * @returns whether encryption succeeded + */ + [[nodiscard]] bool EncryptWithAd(Span associated_data, Span plain, Span ciphertext); + + /** The message will be chunked in NOISE_MAX_CHUNK_SIZE parts and expanded + * by 16 bytes per chunk for its MAC. + * + * @param[in] plain message. Can't point to the same memory location as ciphertext, + * because each encrypted message chunk would override the + * start of the next plain text chunk. + * @param[out] ciphertext message with encrypted and authenticated chunks + * @return whether encryption succeeded. Only fails if nonce is uint64_max. + */ + [[nodiscard]] bool EncryptMessage(Span plain, Span ciphertext); + + /** Decrypt message. + * + * @param[in] ciphertext encrypted message + * @param[out] plain decrypted message. May point to the same memory location + * as ciphertext. The result is defragmented. + */ + [[nodiscard]] bool DecryptMessage(Span ciphertext, Span plain); + +private: + NoiseHash m_key{0}; + uint64_t m_nonce = 0; +}; + +/* + * A SymmetricState object contains a CipherState plus m_chaining_key (ck) and + * m_hash_output (h) variables. It is so-named because it encapsulates all the + * "symmetric crypto" used by Noise. During the handshake phase each party has + * a single SymmetricState, which can be deleted once the handshake is finished. + * + * See chapter "5. Processing rules" of the Noise paper. + */ +class Sv2SymmetricState +{ +public: + // Sha256 hash of the ascii encoding - "Noise_NX_Secp256k1+EllSwift_ChaChaPoly_SHA256". + // This is the first step required when setting up the chaining key. + static constexpr NoiseHash PROTOCOL_NAME_HASH = { + 46, 180, 120, 129, 32, 142, 158, 238, 31, 102, 159, 103, 198, 110, 231, 14, + 169, 234, 136, 9, 13, 80, 63, 232, 48, 220, 75, 200, 62, 41, 191, 16}; + + // The double hash of protocol name "Noise_NX_Secp256k1+EllSwift_ChaChaPoly_SHA256". + static constexpr NoiseHash PROTOCOL_NAME_DOUBLE_HASH = { + 146, 47, 163, 46, 79, 72, 124, 13, 89, 202, 163, 190, 215, 137, 156, 227, + 217, 141, 183, 225, 61, 189, 59, 124, 242, 210, 61, 212, 51, 220, 97, 4}; + + Sv2SymmetricState() : m_chaining_key{PROTOCOL_NAME_HASH} {} + + void MixHash(const Span input); + void MixKey(const Span input_key_material); + [[nodiscard]] bool EncryptAndHash(Span plain, Span ciphertext); + [[nodiscard]] bool DecryptAndHash(Span ciphertext, Span plain); + std::array Split(); + + uint256 GetHashOutput(); + + /* For testing */ + void LogChainingKey(); + std::string GetChainingKey(); + +private: + NoiseHash m_chaining_key; + uint256 m_hash_output{uint256(PROTOCOL_NAME_DOUBLE_HASH)}; + Sv2CipherState m_cipher_state; + + void HKDF2(const Span input_key_material, + NoiseHash& out0, + NoiseHash& out1); +}; + +/* + * A HandshakeState object contains a SymmetricState plus DH variables (s, e, rs, re) + * and a variable representing the handshake pattern. During the handshake phase + * each party has a single HandshakeState, which can be deleted once the handshake + * is finished. + * + * See chapter "5. Processing rules" of the Noise paper. + */ + +class Sv2HandshakeState +{ +public: + static constexpr size_t ELLSWIFT_PUB_KEY_SIZE{64}; + static constexpr size_t ECDH_OUTPUT_SIZE{32}; + + static constexpr size_t HANDSHAKE_STEP2_SIZE = ELLSWIFT_PUB_KEY_SIZE + ELLSWIFT_PUB_KEY_SIZE + + Poly1305::TAGLEN + Sv2SignatureNoiseMessage::SIZE + Poly1305::TAGLEN; + + /* + * If we are the initiator m_authority_pubkey must be set in order to verify + * the received certificate. + */ + Sv2HandshakeState(CKey&& static_key, + XOnlyPubKey authority_pubkey) : m_static_key{static_key}, + m_authority_pubkey{authority_pubkey} + { + m_static_ellswift_pk = static_key.EllSwiftCreate(MakeByteSpan(GetRandHash())); + }; + + /* + * If we are the responder, the certificate must be set + */ + Sv2HandshakeState(CKey&& static_key, + Sv2SignatureNoiseMessage&& certificate) : m_static_key{static_key}, + m_certificate{certificate} + { + m_static_ellswift_pk = static_key.EllSwiftCreate(MakeByteSpan(GetRandHash())); + }; + + /** Handshake step 1 for initiator: -> e */ + void WriteMsgEphemeralPK(Span msg); + /** Handshake step 1 for responder: -> e */ + void ReadMsgEphemeralPK(Span msg); + /** During handshake step 2, put our ephmeral key, static key + * and certificate in the buffer: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE + */ + void WriteMsgES(Span msg); + /** During handshake step 2, read the remote ephmeral key, static key + * and certificate. Verify their certificate. + * <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE + */ + [[nodiscard]] bool ReadMsgES(Span msg); + + std::array SplitSymmetricState(); + uint256 GetHashOutput(); + + void SetEphemeralKey(CKey&& key); + +private: + /** Our static key (s) */ + CKey m_static_key; + /** EllSwift encoded static key, for optimized ECDH */ + EllSwiftPubKey m_static_ellswift_pk; + /** Our ephemeral key (e) */ + CKey m_ephemeral_key; + /** EllSwift encoded ephemeral key, for optimized ECDH */ + EllSwiftPubKey m_ephemeral_ellswift_pk; + /** Remote static key (rs) */ + EllSwiftPubKey m_remote_static_ellswift_pk; + /** Remote ephemeral key (re) */ + EllSwiftPubKey m_remote_ephemeral_ellswift_pk; + Sv2SymmetricState m_symmetric_state; + /** Certificate signed by m_authority_pubkey. */ + std::optional m_certificate; + /** Authority public key. */ + std::optional m_authority_pubkey; + + /** Generate ephemeral key, sets set m_ephemeral_key and m_ephemeral_ellswift_pk */ + void GenerateEphemeralKey() noexcept; +}; + +/** + * Interface somewhat similar to BIP324Cipher for use by a Transport class. + * The initiator and responder roles have their own constructor. + * FinishHandshake() must be called after all handshake bytes have been processed. + */ +class Sv2Cipher +{ +public: + Sv2Cipher(CKey&& static_key, XOnlyPubKey authority_pubkey); + Sv2Cipher(CKey&& static_key, Sv2SignatureNoiseMessage&& certificate); + + Sv2Cipher(bool initiator, std::unique_ptr handshake_state) : m_initiator{initiator}, m_handshake_state{std::move(handshake_state)} {}; + + Sv2HandshakeState& GetHandshakeState(); + /** + * Populates m_hash, m_cs1 and m_cs2 from m_handshake_state and deletes the latter. + */ + void FinishHandshake(); + + /** Decrypts a message. May only be called after FinishHandshake() */ + bool DecryptMessage(Span ciphertext, Span plain); + /** Encrypts a message. May only be called after FinishHandshake() */ + [[nodiscard]] bool EncryptMessage(Span input, Span output); + + /* Expected size after chunking and with MAC */ + static size_t EncryptedMessageSize(const size_t msg_len); + + /* Test only */ + uint256 GetHash() const; + +private: + bool m_initiator; + std::unique_ptr m_handshake_state; + + uint256 m_hash; + Sv2CipherState m_cs1; + Sv2CipherState m_cs2; +}; + +#endif // BITCOIN_SV2_NOISE_H diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 83b9c403b1151..83ad6b5cbf3aa 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -178,6 +178,7 @@ endif() if(WITH_SV2) target_sources(test_bitcoin PRIVATE + sv2_noise_tests.cpp ) target_link_libraries(test_bitcoin bitcoin_sv2) endif() diff --git a/src/test/fuzz/CMakeLists.txt b/src/test/fuzz/CMakeLists.txt index c47a958f82b1e..a843f422a0383 100644 --- a/src/test/fuzz/CMakeLists.txt +++ b/src/test/fuzz/CMakeLists.txt @@ -155,6 +155,7 @@ endif() if(WITH_SV2) target_sources(fuzz PRIVATE + sv2_noise.cpp ) target_link_libraries(fuzz bitcoin_sv2 diff --git a/src/test/fuzz/sv2_noise.cpp b/src/test/fuzz/sv2_noise.cpp new file mode 100644 index 0000000000000..afd8a7ce99c15 --- /dev/null +++ b/src/test/fuzz/sv2_noise.cpp @@ -0,0 +1,169 @@ +// Copyright (c) 2024 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + + +void Initialize() +{ + // Add test context for debugging. Usage: + // --debug=sv2 --loglevel=sv2:trace --printtoconsole=1 + static const auto testing_setup = std::make_unique(); +} +} // namespace + +bool MaybeDamage(FuzzedDataProvider& provider, std::vector& transport) +{ + if (transport.size() == 0) return false; + + // Optionally damage 1 bit in the ciphertext. + const bool damage = provider.ConsumeBool(); + if (damage) { + unsigned damage_bit = provider.ConsumeIntegralInRange(0, + transport.size() * 8U - 1U); + unsigned damage_pos = damage_bit >> 3; + LogTrace(BCLog::SV2, "Damage byte %d of %d\n", damage_pos, transport.size()); + std::byte damage_val{(uint8_t)(1U << (damage_bit & 7))}; + transport.at(damage_pos) ^= damage_val; + } + return damage; +} + +FUZZ_TARGET(sv2_noise_cipher_roundtrip, .init = Initialize) +{ + SeedRandomStateForTest(SeedRand::ZEROS); + // Test that Sv2Noise's encryption and decryption agree. + + // To conserve fuzzer entropy, deterministically generate Alice and Bob keys. + FuzzedDataProvider provider(buffer.data(), buffer.size()); + auto seed_ent = provider.ConsumeBytes(32); + seed_ent.resize(32); + CExtKey seed; + seed.SetSeed(seed_ent); + + CExtKey tmp; + if (!seed.Derive(tmp, 0)) return; + CKey alice_authority_key{tmp.key}; + + if (!seed.Derive(tmp, 1)) return; + CKey alice_static_key{tmp.key}; + + if (!seed.Derive(tmp, 2)) return; + CKey alice_ephemeral_key{tmp.key}; + + if (!seed.Derive(tmp, 10)) return; + CKey bob_authority_key{tmp.key}; + + if (!seed.Derive(tmp, 11)) return; + CKey bob_static_key{tmp.key}; + + if (!seed.Derive(tmp, 12)) return; + CKey bob_ephemeral_key{tmp.key}; + + // Create certificate + // Pick random times in the past or future + uint32_t now = provider.ConsumeIntegralInRange(10000U, UINT32_MAX); + SetMockTime(now); + uint16_t version = provider.ConsumeBool() ? 0 : provider.ConsumeIntegral(); + uint32_t past = provider.ConsumeIntegralInRange(0, now); + uint32_t future = provider.ConsumeIntegralInRange(now, UINT32_MAX); + uint32_t valid_from = int32_t(provider.ConsumeBool() ? past : future); + uint32_t valid_to = int32_t(provider.ConsumeBool() ? future : past); + + auto bob_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(bob_static_key.GetPubKey()), bob_authority_key); + + bool valid_certificate = version == 0 && + (valid_from <= now) && + (valid_to >= now); + + LogTrace(BCLog::SV2, "valid_certificate: %d - version %u, past: %u, now %u, future: %u\n", valid_certificate, version, past, now, future); + + // Alice's static is not used in the test + // Alice needs to verify Bob's certificate, so we pass his authority key + auto alice_handshake = std::make_unique(std::move(alice_static_key), XOnlyPubKey(bob_authority_key.GetPubKey())); + alice_handshake->SetEphemeralKey(std::move(alice_ephemeral_key)); + // Bob is the responder and does not receive (or verify) Alice's certificate, + // so we don't pass her authority key. + auto bob_handshake = std::make_unique(std::move(bob_static_key), std::move(bob_certificate)); + bob_handshake->SetEphemeralKey(std::move(bob_ephemeral_key)); + + // Handshake Act 1: e -> + + std::vector transport; + transport.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + // Alice generates her ephemeral public key and write it into the buffer: + alice_handshake->WriteMsgEphemeralPK(transport); + + bool damage_e = MaybeDamage(provider, transport); + + // Bob reads the ephemeral key () + // With EllSwift encoding this step can't fail + bob_handshake->ReadMsgEphemeralPK(transport); + ClearShrink(transport); + + // Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE + transport.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + bob_handshake->WriteMsgES(transport); + + bool damage_es = MaybeDamage(provider, transport); + + // This ignores the remote possibility that the fuzzer finds two equivalent + // EllSwift encodings by flipping a single ephemeral key bit. + assert(alice_handshake->ReadMsgES(transport) == (valid_certificate && !damage_e && !damage_es)); + + if (!valid_certificate || damage_e || damage_es) return; + + // Construct Sv2Cipher from the Sv2HandshakeState and test transport + auto alice{Sv2Cipher(/*initiator=*/true, std::move(alice_handshake))}; + auto bob{Sv2Cipher(/*initiator=*/false, std::move(bob_handshake))}; + alice.FinishHandshake(); + bob.FinishHandshake(); + + // Use deterministic RNG to generate content rather than creating it from + // the fuzzer input. + InsecureRandomContext rng(provider.ConsumeIntegral()); + + LIMITED_WHILE(provider.remaining_bytes(), 1000) + { + ClearShrink(transport); + + // Alice or Bob sends a message + bool from_alice = provider.ConsumeBool(); + + // Set content length (slightly above NOISE_MAX_CHUNK_SIZE) + unsigned length = provider.ConsumeIntegralInRange(0, NOISE_MAX_CHUNK_SIZE + 100); + std::vector plain(length); + for (auto& val : plain) + val = std::byte{(uint8_t)rng()}; + + const size_t encrypted_size = Sv2Cipher::EncryptedMessageSize(plain.size()); + transport.resize(encrypted_size); + + assert((from_alice ? alice : bob).EncryptMessage(plain, transport)); + + const bool damage = MaybeDamage(provider, transport); + + std::vector plain_read; + plain_read.resize(plain.size()); + + bool ok = (from_alice ? bob : alice).DecryptMessage(transport, plain_read); + assert(!ok == damage); + if (!ok) break; + + assert(plain == plain_read); + } +} diff --git a/src/test/sv2_noise_tests.cpp b/src/test/sv2_noise_tests.cpp new file mode 100644 index 0000000000000..07da06e4e361d --- /dev/null +++ b/src/test/sv2_noise_tests.cpp @@ -0,0 +1,159 @@ +#include +#include +#include +#include +#include + +#include + +BOOST_FIXTURE_TEST_SUITE(sv2_noise_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(MixKey_test) +{ + Sv2SymmetricState i_ss; + Sv2SymmetricState r_ss; + BOOST_CHECK_EQUAL(r_ss.GetChainingKey(), i_ss.GetChainingKey()); + + CKey initiator_key{GenerateRandomKey()}; + auto initiator_pk = initiator_key.EllSwiftCreate(MakeByteSpan(GetRandHash())); + + CKey responder_key{GenerateRandomKey()}; + auto responder_pk = responder_key.EllSwiftCreate(MakeByteSpan(GetRandHash())); + + auto ecdh_output_1 = initiator_key.ComputeBIP324ECDHSecret(responder_pk, initiator_pk, true); + auto ecdh_output_2 = responder_key.ComputeBIP324ECDHSecret(initiator_pk, responder_pk, false); + + BOOST_CHECK(std::memcmp(&ecdh_output_1[0], &ecdh_output_2[0], 32) == 0); + + i_ss.MixKey(ecdh_output_1); + r_ss.MixKey(ecdh_output_2); + + BOOST_CHECK_EQUAL(r_ss.GetChainingKey(), i_ss.GetChainingKey()); +} + +BOOST_AUTO_TEST_CASE(certificate_test) +{ + auto alice_static_key{GenerateRandomKey()}; + auto alice_authority_key{GenerateRandomKey()}; + + // Create certificate + auto epoch_now = std::chrono::system_clock::now().time_since_epoch(); + uint32_t now = static_cast(std::chrono::duration_cast(epoch_now).count()); + uint16_t version = 0; + uint32_t valid_from = now; + uint32_t valid_to = std::numeric_limits::max(); + + auto alice_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(alice_static_key.GetPubKey()), alice_authority_key); + + BOOST_REQUIRE(alice_certificate.Validate(XOnlyPubKey(alice_authority_key.GetPubKey()))); + + auto malory_authority_key{GenerateRandomKey()}; + BOOST_REQUIRE(!alice_certificate.Validate(XOnlyPubKey(malory_authority_key.GetPubKey()))); + + // Check that certificate is not from the future + valid_from = now + 10000; + alice_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(alice_static_key.GetPubKey()), alice_authority_key); + BOOST_REQUIRE(!alice_certificate.Validate(XOnlyPubKey(alice_authority_key.GetPubKey()))); + + valid_from = now; + + // Check certificate expiration + valid_to = now - 10000; + alice_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(alice_static_key.GetPubKey()), alice_authority_key); + BOOST_REQUIRE(!alice_certificate.Validate(XOnlyPubKey(alice_authority_key.GetPubKey()))); + + valid_to = now; + + // Only version 0 is supported + version = 1; + alice_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(alice_static_key.GetPubKey()), alice_authority_key); + BOOST_REQUIRE(!alice_certificate.Validate(XOnlyPubKey(alice_authority_key.GetPubKey()))); +} + +BOOST_AUTO_TEST_CASE(handshake_and_transport_test) +{ + // Alice initiates a handshake with Bob + + auto alice_static_key{GenerateRandomKey()}; + auto bob_static_key{GenerateRandomKey()}; + auto bob_authority_key{GenerateRandomKey()}; + + // Create certificates + auto epoch_now = std::chrono::system_clock::now().time_since_epoch(); + uint16_t version = 0; + uint32_t valid_from = static_cast(std::chrono::duration_cast(epoch_now).count()); + uint32_t valid_to = std::numeric_limits::max(); + + auto bob_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(bob_static_key.GetPubKey()), + bob_authority_key); + + // Alice's static is not used in the test + // Alice needs to verify Bob's certificate, so we pass his authority key + auto alice_handshake = std::make_unique(std::move(alice_static_key), + XOnlyPubKey(bob_authority_key.GetPubKey())); + // Bob is the responder and does not receive (or verify) Alice's certificate, + // so we don't pass her authority key. + auto bob_handshake = std::make_unique(std::move(bob_static_key), + std::move(bob_certificate)); + + // Handshake Act 1: e -> + + std::vector transport; + transport.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + // Alice generates her ephemeral public key and write it into the buffer: + alice_handshake->WriteMsgEphemeralPK(transport); + EllSwiftPubKey alice_pubkey(transport); + + // Bob reads the ephemeral key + bob_handshake->ReadMsgEphemeralPK(transport); + + ClearShrink(transport); + + // Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE + transport.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + bob_handshake->WriteMsgES(transport); + BOOST_REQUIRE(alice_handshake->ReadMsgES(transport)); + + // Construct Sv2Cipher from the Sv2HandshakeState and test transport + auto alice{Sv2Cipher(/*initiator=*/true, std::move(alice_handshake))}; + auto bob{Sv2Cipher(/*initiator=*/false, std::move(bob_handshake))}; + alice.FinishHandshake(); + bob.FinishHandshake(); + + ClearShrink(transport); + + constexpr std::array TEST{ + // hello world + 0x68, + 0x65, + 0x6C, + 0x6C, + 0x6F, + 0x20, + 0x77, + 0x6F, + 0x72, + 0x6C, + 0x64, + }; + + const size_t encrypted_size = Sv2Cipher::EncryptedMessageSize(TEST.size()); + BOOST_CHECK_EQUAL(encrypted_size, TEST.size() + Poly1305::TAGLEN); + + transport.resize(encrypted_size); + + auto plain_send{MakeByteSpan(TEST)}; + BOOST_TEST_CHECKPOINT("Alice encrypts message"); + BOOST_REQUIRE(alice.EncryptMessage(plain_send, transport)); + + std::vector plain_receive(TEST.size(), std::byte(0)); + BOOST_TEST_CHECKPOINT("Bob decrypts message"); + BOOST_REQUIRE(bob.DecryptMessage(transport, plain_receive)); + BOOST_CHECK_EQUAL(HexStr(plain_receive), HexStr(TEST)); +} +BOOST_AUTO_TEST_SUITE_END() From 863448bc7a1913736eb091dde2dc71a532be3d1b Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Thu, 19 Sep 2024 16:25:19 +0200 Subject: [PATCH 24/30] Add sv2 message CoinbaseOutputConstraints This commit adds the simplest stratum v2 message. The remaining messages are introduced in later commits. --- src/sv2/messages.h | 185 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 src/sv2/messages.h diff --git a/src/sv2/messages.h b/src/sv2/messages.h new file mode 100644 index 0000000000000..3280b53ecaca8 --- /dev/null +++ b/src/sv2/messages.h @@ -0,0 +1,185 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_MESSAGES_H +#define BITCOIN_SV2_MESSAGES_H + +#include +#include +#include +namespace node { +/** + * A type used as the message length field in stratum v2 messages. + */ +using u24_t = uint8_t[3]; + +/** + * All the stratum v2 message types handled by the template provider. + */ +enum class Sv2MsgType : uint8_t { + COINBASE_OUTPUT_CONSTRAINTS = 0x70, +}; + +/** + * Set the coinbase outputs data len for the outputs that the client wants to add to the coinbase. + * The template provider MUST NOT provide NewWork messages which would represent consensus-invalid blocks once this + * additional size — along with a maximally-sized (100 byte) coinbase field — is added. + */ +struct Sv2CoinbaseOutputConstraintsMsg +{ + /** + * The default message type value for this Stratum V2 message. + */ + static constexpr auto m_msg_type = Sv2MsgType::COINBASE_OUTPUT_CONSTRAINTS; + + /** + * The maximum additional serialized bytes which the pool will add in coinbase transaction outputs. + */ + uint32_t m_coinbase_output_max_additional_size; + + /** + * The maximum additional sigops which the pool will add in coinbase transaction outputs. + */ + uint16_t m_coinbase_output_max_additional_sigops; + + template + void Serialize(Stream& s) const + { + s << m_coinbase_output_max_additional_size; + s << m_coinbase_output_max_additional_sigops; + }; + + + template + void Unserialize(Stream& s) + { + s >> m_coinbase_output_max_additional_size; + try { + // This field was added to the spec on ..., + // SRI roles before ... do not provide it. + s >> m_coinbase_output_max_additional_sigops; + } catch (...) { + // Just use the default if it's missing + m_coinbase_output_max_additional_sigops = 400; + } + } +}; + +/** + * Header for all stratum v2 messages. Each header must contain the message type, + * the length of the serialized message and a 2 byte extension field currently + * not utilised by the template provider. + */ +class Sv2NetHeader +{ +public: + /** + * Unique identifier of the message. + */ + Sv2MsgType m_msg_type; + + /** + * Serialized length of the message. + */ + uint32_t m_msg_len; + + Sv2NetHeader() = default; + explicit Sv2NetHeader(Sv2MsgType msg_type, uint32_t msg_len) : m_msg_type{msg_type}, m_msg_len{msg_len} {}; + + template + void Serialize(Stream& s) const + { + // The template provider currently does not use the extension_type field, + // but the field is still required for all headers. + uint16_t extension_type = 0; + + u24_t msg_len; + msg_len[2] = (m_msg_len >> 16) & 0xff; + msg_len[1] = (m_msg_len >> 8) & 0xff; + msg_len[0] = m_msg_len & 0xff; + + s << extension_type + << static_cast(m_msg_type) + << msg_len; + }; + + template + void Unserialize(Stream& s) + { + // Ignore the first 2 bytes (extension type) as the template provider currently doesn't + // interpret this field. + s.ignore(2); + + uint8_t msg_type; + s >> msg_type; + m_msg_type = static_cast(msg_type); + + u24_t msg_len_bytes; + for (unsigned int i = 0; i < sizeof(u24_t); ++i) { + s >> msg_len_bytes[i]; + } + + m_msg_len = msg_len_bytes[2]; + m_msg_len = m_msg_len << 8 | msg_len_bytes[1]; + m_msg_len = m_msg_len << 8 | msg_len_bytes[0]; + } +}; + +/** + * The networked form for all stratum v2 messages, contains a header and a serialized + * payload from a referenced stratum v2 message. + */ +class Sv2NetMsg +{ +public: + Sv2MsgType m_msg_type; + std::vector m_msg; + + explicit Sv2NetMsg(const Sv2MsgType msg_type, const std::vector&& msg) : m_msg_type{msg_type}, m_msg{msg} {}; + + /** + * Serializes the message M and sets an Sv2 network header. + * @throws std::ios_base or std::out_of_range errors. + */ + template + explicit Sv2NetMsg(const M& msg) + { + m_msg_type = msg.m_msg_type; + + // Serialize the sv2 message. + VectorWriter{m_msg, 0, msg}; + } + + unsigned char* data() { return m_msg.data(); } + size_t size() { return m_msg.size(); } + + operator Sv2NetHeader() + { + Sv2NetHeader hdr; + hdr.m_msg_type = m_msg_type; + hdr.m_msg_len = static_cast(m_msg.size()); + return hdr; + } + + template + void Unserialize(Stream& s) + { + uint8_t msg_type; + s >> msg_type; + m_msg_type = static_cast(msg_type); + s.read(MakeWritableByteSpan(m_msg)); + } + + template + void Serialize(Stream& s) const + { + s << static_cast(m_msg_type); + s.write(MakeByteSpan(m_msg)); + } + +}; + +} + +#endif // BITCOIN_SV2_MESSAGES_H From ed341206138ed27306835a81b4915969641e408c Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Fri, 20 Sep 2024 11:01:23 +0200 Subject: [PATCH 25/30] Move CNetMessage and Transport headers to common This avoids a circular dependency between bitcoin-sv2 and bitcoin-node. --- src/common/transport.h | 189 ++++++++++++++++++++++++++++++++++ src/net.h | 166 +---------------------------- src/node/connection_types.cpp | 1 + src/node/connection_types.h | 11 -- 4 files changed, 191 insertions(+), 176 deletions(-) create mode 100644 src/common/transport.h diff --git a/src/common/transport.h b/src/common/transport.h new file mode 100644 index 0000000000000..3a7987b102416 --- /dev/null +++ b/src/common/transport.h @@ -0,0 +1,189 @@ +// Copyright (c) 2024 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_COMMON_TRANSPORT_H +#define BITCOIN_COMMON_TRANSPORT_H + +#include +#include +#include +#include +#include +#include + +/** Transport layer version */ +enum class TransportProtocolType : uint8_t { + DETECTING, //!< Peer could be v1 or v2 + V1, //!< Unencrypted, plaintext protocol + V2, //!< BIP324 protocol +}; + +/** Convert TransportProtocolType enum to a string value */ +std::string TransportTypeAsString(TransportProtocolType transport_type); + +/** Transport protocol agnostic message container. + * Ideally it should only contain receive time, payload, + * type and size. + */ +class CNetMessage +{ +public: + DataStream m_recv; //!< received message data + std::chrono::microseconds m_time{0}; //!< time of message receipt + uint32_t m_message_size{0}; //!< size of the payload + uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) + std::string m_type; + + explicit CNetMessage(DataStream&& recv_in) : m_recv(std::move(recv_in)) {} + // Only one CNetMessage object will exist for the same message on either + // the receive or processing queue. For performance reasons we therefore + // delete the copy constructor and assignment operator to avoid the + // possibility of copying CNetMessage objects. + CNetMessage(CNetMessage&&) = default; + CNetMessage(const CNetMessage&) = delete; + CNetMessage& operator=(CNetMessage&&) = default; + CNetMessage& operator=(const CNetMessage&) = delete; + + /** Compute total memory usage of this object (own memory + any dynamic memory). */ + size_t GetMemoryUsage() const noexcept; +}; + +struct CSerializedNetMsg { + CSerializedNetMsg() = default; + CSerializedNetMsg(CSerializedNetMsg&&) = default; + CSerializedNetMsg& operator=(CSerializedNetMsg&&) = default; + // No implicit copying, only moves. + CSerializedNetMsg(const CSerializedNetMsg& msg) = delete; + CSerializedNetMsg& operator=(const CSerializedNetMsg&) = delete; + + CSerializedNetMsg Copy() const + { + CSerializedNetMsg copy; + copy.data = data; + copy.m_type = m_type; + return copy; + } + + std::vector data; + std::string m_type; + + /** Compute total memory usage of this object (own memory + any dynamic memory). */ + size_t GetMemoryUsage() const noexcept; +}; + +/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ +class Transport { +public: + virtual ~Transport() = default; + + struct Info + { + TransportProtocolType transport_type; + std::optional session_id; + }; + + /** Retrieve information about this transport. */ + virtual Info GetInfo() const noexcept = 0; + + // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol + // agnostic CNetMessage (message type & payload) objects. + + /** Returns true if the current message is complete (so GetReceivedMessage can be called). */ + virtual bool ReceivedMessageComplete() const = 0; + + /** Feed wire bytes to the transport. + * + * @return false if some bytes were invalid, in which case the transport can't be used anymore. + * + * Consumed bytes are chopped off the front of msg_bytes. + */ + virtual bool ReceivedBytes(Span& msg_bytes) = 0; + + /** Retrieve a completed message from transport. + * + * This can only be called when ReceivedMessageComplete() is true. + * + * If reject_message=true is returned the message itself is invalid, but (other than false + * returned by ReceivedBytes) the transport is not in an inconsistent state. + */ + virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; + + // 2. Sending side functions, for converting messages into bytes to be sent over the wire. + + /** Set the next message to send. + * + * If no message can currently be set (perhaps because the previous one is not yet done being + * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and + * possibly moved-from) and true is returned. + */ + virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; + + /** Return type for GetBytesToSend, consisting of: + * - Span to_send: span of bytes to be sent over the wire (possibly empty). + * - bool more: whether there will be more bytes to be sent after the ones in to_send are + * all sent (as signaled by MarkBytesSent()). + * - const std::string& m_type: message type on behalf of which this is being sent + * ("" for bytes that are not on behalf of any message). + */ + using BytesToSend = std::tuple< + Span /*to_send*/, + bool /*more*/, + const std::string& /*m_type*/ + >; + + /** Get bytes to send on the wire, if any, along with other information about it. + * + * As a const function, it does not modify the transport's observable state, and is thus safe + * to be called multiple times. + * + * @param[in] have_next_message If true, the "more" return value reports whether more will + * be sendable after a SetMessageToSend call. It is set by the caller when they know + * they have another message ready to send, and only care about what happens + * after that. The have_next_message argument only affects this "more" return value + * and nothing else. + * + * Effectively, there are three possible outcomes about whether there are more bytes + * to send: + * - Yes: the transport itself has more bytes to send later. For example, for + * V1Transport this happens during the sending of the header of a + * message, when there is a non-empty payload that follows. + * - No: the transport itself has no more bytes to send, but will have bytes to + * send if handed a message through SetMessageToSend. In V1Transport this + * happens when sending the payload of a message. + * - Blocked: the transport itself has no more bytes to send, and is also incapable + * of sending anything more at all now, if it were handed another + * message to send. This occurs in V2Transport before the handshake is + * complete, as the encryption ciphers are not set up for sending + * messages before that point. + * + * The boolean 'more' is true for Yes, false for Blocked, and have_next_message + * controls what is returned for No. + * + * @return a BytesToSend object. The to_send member returned acts as a stream which is only + * ever appended to. This means that with the exception of MarkBytesSent (which pops + * bytes off the front of later to_sends), operations on the transport can only append + * to what is being returned. Also note that m_type and to_send refer to data that is + * internal to the transport, and calling any non-const function on this object may + * invalidate them. + */ + virtual BytesToSend GetBytesToSend(bool have_next_message) const noexcept = 0; + + /** Report how many bytes returned by the last GetBytesToSend() have been sent. + * + * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. + * + * If bytes_sent=0, this call has no effect. + */ + virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; + + /** Return the memory usage of this transport attributable to buffered data to send. */ + virtual size_t GetSendMemoryUsage() const noexcept = 0; + + // 3. Miscellaneous functions. + + /** Whether upon disconnections, a reconnect with V1 is warranted. */ + virtual bool ShouldReconnectV1() const noexcept = 0; +}; + +#endif // BITCOIN_COMMON_TRANSPORT_H diff --git a/src/net.h b/src/net.h index 99a9d0da4b45d..fbac6b36a7586 100644 --- a/src/net.h +++ b/src/net.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -111,29 +112,6 @@ struct AddedNodeInfo { class CNodeStats; class CClientUIInterface; -struct CSerializedNetMsg { - CSerializedNetMsg() = default; - CSerializedNetMsg(CSerializedNetMsg&&) = default; - CSerializedNetMsg& operator=(CSerializedNetMsg&&) = default; - // No implicit copying, only moves. - CSerializedNetMsg(const CSerializedNetMsg& msg) = delete; - CSerializedNetMsg& operator=(const CSerializedNetMsg&) = delete; - - CSerializedNetMsg Copy() const - { - CSerializedNetMsg copy; - copy.data = data; - copy.m_type = m_type; - return copy; - } - - std::vector data; - std::string m_type; - - /** Compute total memory usage of this object (own memory + any dynamic memory). */ - size_t GetMemoryUsage() const noexcept; -}; - /** * Look up IP addresses from all interfaces on the machine and add them to the * list of local addresses to self-advertise. @@ -222,148 +200,6 @@ class CNodeStats std::string m_session_id; }; - -/** Transport protocol agnostic message container. - * Ideally it should only contain receive time, payload, - * type and size. - */ -class CNetMessage -{ -public: - DataStream m_recv; //!< received message data - std::chrono::microseconds m_time{0}; //!< time of message receipt - uint32_t m_message_size{0}; //!< size of the payload - uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) - std::string m_type; - - explicit CNetMessage(DataStream&& recv_in) : m_recv(std::move(recv_in)) {} - // Only one CNetMessage object will exist for the same message on either - // the receive or processing queue. For performance reasons we therefore - // delete the copy constructor and assignment operator to avoid the - // possibility of copying CNetMessage objects. - CNetMessage(CNetMessage&&) = default; - CNetMessage(const CNetMessage&) = delete; - CNetMessage& operator=(CNetMessage&&) = default; - CNetMessage& operator=(const CNetMessage&) = delete; - - /** Compute total memory usage of this object (own memory + any dynamic memory). */ - size_t GetMemoryUsage() const noexcept; -}; - -/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ -class Transport { -public: - virtual ~Transport() = default; - - struct Info - { - TransportProtocolType transport_type; - std::optional session_id; - }; - - /** Retrieve information about this transport. */ - virtual Info GetInfo() const noexcept = 0; - - // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol - // agnostic CNetMessage (message type & payload) objects. - - /** Returns true if the current message is complete (so GetReceivedMessage can be called). */ - virtual bool ReceivedMessageComplete() const = 0; - - /** Feed wire bytes to the transport. - * - * @return false if some bytes were invalid, in which case the transport can't be used anymore. - * - * Consumed bytes are chopped off the front of msg_bytes. - */ - virtual bool ReceivedBytes(Span& msg_bytes) = 0; - - /** Retrieve a completed message from transport. - * - * This can only be called when ReceivedMessageComplete() is true. - * - * If reject_message=true is returned the message itself is invalid, but (other than false - * returned by ReceivedBytes) the transport is not in an inconsistent state. - */ - virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; - - // 2. Sending side functions, for converting messages into bytes to be sent over the wire. - - /** Set the next message to send. - * - * If no message can currently be set (perhaps because the previous one is not yet done being - * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and - * possibly moved-from) and true is returned. - */ - virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; - - /** Return type for GetBytesToSend, consisting of: - * - Span to_send: span of bytes to be sent over the wire (possibly empty). - * - bool more: whether there will be more bytes to be sent after the ones in to_send are - * all sent (as signaled by MarkBytesSent()). - * - const std::string& m_type: message type on behalf of which this is being sent - * ("" for bytes that are not on behalf of any message). - */ - using BytesToSend = std::tuple< - Span /*to_send*/, - bool /*more*/, - const std::string& /*m_type*/ - >; - - /** Get bytes to send on the wire, if any, along with other information about it. - * - * As a const function, it does not modify the transport's observable state, and is thus safe - * to be called multiple times. - * - * @param[in] have_next_message If true, the "more" return value reports whether more will - * be sendable after a SetMessageToSend call. It is set by the caller when they know - * they have another message ready to send, and only care about what happens - * after that. The have_next_message argument only affects this "more" return value - * and nothing else. - * - * Effectively, there are three possible outcomes about whether there are more bytes - * to send: - * - Yes: the transport itself has more bytes to send later. For example, for - * V1Transport this happens during the sending of the header of a - * message, when there is a non-empty payload that follows. - * - No: the transport itself has no more bytes to send, but will have bytes to - * send if handed a message through SetMessageToSend. In V1Transport this - * happens when sending the payload of a message. - * - Blocked: the transport itself has no more bytes to send, and is also incapable - * of sending anything more at all now, if it were handed another - * message to send. This occurs in V2Transport before the handshake is - * complete, as the encryption ciphers are not set up for sending - * messages before that point. - * - * The boolean 'more' is true for Yes, false for Blocked, and have_next_message - * controls what is returned for No. - * - * @return a BytesToSend object. The to_send member returned acts as a stream which is only - * ever appended to. This means that with the exception of MarkBytesSent (which pops - * bytes off the front of later to_sends), operations on the transport can only append - * to what is being returned. Also note that m_type and to_send refer to data that is - * internal to the transport, and calling any non-const function on this object may - * invalidate them. - */ - virtual BytesToSend GetBytesToSend(bool have_next_message) const noexcept = 0; - - /** Report how many bytes returned by the last GetBytesToSend() have been sent. - * - * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. - * - * If bytes_sent=0, this call has no effect. - */ - virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; - - /** Return the memory usage of this transport attributable to buffered data to send. */ - virtual size_t GetSendMemoryUsage() const noexcept = 0; - - // 3. Miscellaneous functions. - - /** Whether upon disconnections, a reconnect with V1 is warranted. */ - virtual bool ShouldReconnectV1() const noexcept = 0; -}; - class V1Transport final : public Transport { private: diff --git a/src/node/connection_types.cpp b/src/node/connection_types.cpp index 5e4dc5bf2ef94..2d8dbec2f131c 100644 --- a/src/node/connection_types.cpp +++ b/src/node/connection_types.cpp @@ -2,6 +2,7 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include #include #include diff --git a/src/node/connection_types.h b/src/node/connection_types.h index a911b95f7e917..5e1abcace67d1 100644 --- a/src/node/connection_types.h +++ b/src/node/connection_types.h @@ -6,7 +6,6 @@ #define BITCOIN_NODE_CONNECTION_TYPES_H #include -#include /** Different types of connections to a peer. This enum encapsulates the * information we have available at the time of opening or accepting the @@ -80,14 +79,4 @@ enum class ConnectionType { /** Convert ConnectionType enum to a string value */ std::string ConnectionTypeAsString(ConnectionType conn_type); -/** Transport layer version */ -enum class TransportProtocolType : uint8_t { - DETECTING, //!< Peer could be v1 or v2 - V1, //!< Unencrypted, plaintext protocol - V2, //!< BIP324 protocol -}; - -/** Convert TransportProtocolType enum to a string value */ -std::string TransportTypeAsString(TransportProtocolType transport_type); - #endif // BITCOIN_NODE_CONNECTION_TYPES_H From b6ea21d748605476155a69276ccf5d68ed1cb0ce Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Fri, 21 Jun 2024 11:56:22 +0200 Subject: [PATCH 26/30] Convert between Sv2NetMsg and CSerializedNetMsg This allows us to subclass Transport. --- src/sv2/messages.h | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/sv2/messages.h b/src/sv2/messages.h index 3280b53ecaca8..277326687a1cf 100644 --- a/src/sv2/messages.h +++ b/src/sv2/messages.h @@ -5,9 +5,12 @@ #ifndef BITCOIN_SV2_MESSAGES_H #define BITCOIN_SV2_MESSAGES_H +#include // for CSerializedNetMsg and CNetMessage #include #include #include +#include + namespace node { /** * A type used as the message length field in stratum v2 messages. @@ -138,6 +141,40 @@ class Sv2NetMsg explicit Sv2NetMsg(const Sv2MsgType msg_type, const std::vector&& msg) : m_msg_type{msg_type}, m_msg{msg} {}; + // Unwrap CSerializedNetMsg + Sv2NetMsg(CSerializedNetMsg&& net_msg) + { + Assume(net_msg.m_type == ""); + DataStream ss(MakeByteSpan(net_msg.data)); + Unserialize(ss); + }; + + // Unwrap CNetMsg + Sv2NetMsg(CNetMessage net_msg) + { + Unserialize(net_msg.m_recv); + }; + + operator CSerializedNetMsg() + { + CSerializedNetMsg net_msg; + net_msg.m_type = ""; + DataStream ser; + Serialize(ser); + net_msg.data.resize(ser.size()); + std::transform(ser.begin(), ser.end(), net_msg.data.begin(), + [](std::byte b) { return static_cast(b); }); + return net_msg; + } + + operator CNetMessage() + { + DataStream msg; + Serialize(msg); + CNetMessage ret{std::move(msg)}; + return ret; + } + /** * Serializes the message M and sets an Sv2 network header. * @throws std::ios_base or std::out_of_range errors. @@ -168,6 +205,7 @@ class Sv2NetMsg uint8_t msg_type; s >> msg_type; m_msg_type = static_cast(msg_type); + m_msg.resize(s.size()); s.read(MakeWritableByteSpan(m_msg)); } From e8737e0f7727eeeb5ed56e2e1e1be092018298d8 Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Fri, 20 Sep 2024 11:12:08 +0200 Subject: [PATCH 27/30] Introduce Sv2Transport Implemented starting from a copy of V2Transport and the V2TransportTester, modifying it to fit Stratum v2 and Noise Protocol requirements. Co-Authored-By: Christopher Coverdale --- src/sv2/CMakeLists.txt | 1 + src/sv2/transport.cpp | 494 +++++++++++++++++++++++++++++++ src/sv2/transport.h | 194 ++++++++++++ src/test/CMakeLists.txt | 1 + src/test/sv2_transport_tests.cpp | 389 ++++++++++++++++++++++++ 5 files changed, 1079 insertions(+) create mode 100644 src/sv2/transport.cpp create mode 100644 src/sv2/transport.h create mode 100644 src/test/sv2_transport_tests.cpp diff --git a/src/sv2/CMakeLists.txt b/src/sv2/CMakeLists.txt index d6e44842e8c87..e61f2f3560834 100644 --- a/src/sv2/CMakeLists.txt +++ b/src/sv2/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(bitcoin_sv2 STATIC EXCLUDE_FROM_ALL noise.cpp + transport.cpp ) target_link_libraries(bitcoin_sv2 diff --git a/src/sv2/transport.cpp b/src/sv2/transport.cpp new file mode 100644 index 0000000000000..37a6e36ba19e0 --- /dev/null +++ b/src/sv2/transport.cpp @@ -0,0 +1,494 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +Sv2Transport::Sv2Transport(CKey static_key, Sv2SignatureNoiseMessage certificate) noexcept + : m_cipher{Sv2Cipher(std::move(static_key), std::move(certificate))}, m_initiating{false}, + m_recv_state{RecvState::HANDSHAKE_STEP_1}, + m_send_state{SendState::HANDSHAKE_STEP_2}, + m_message{Sv2NetMsg(Sv2NetHeader{})} +{ + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session receive state -> %s\n", + RecvStateAsString(m_recv_state)); +} + +Sv2Transport::Sv2Transport(CKey static_key, XOnlyPubKey responder_authority_key) noexcept + : m_cipher{Sv2Cipher(std::move(static_key), responder_authority_key)}, m_initiating{true}, + m_recv_state{RecvState::HANDSHAKE_STEP_2}, + m_send_state{SendState::HANDSHAKE_STEP_1}, + m_message{Sv2NetMsg(Sv2NetHeader{})} +{ + /** Start sending immediately since we're the initiator of the connection. + This only happens in test code. + */ + LOCK(m_send_mutex); + StartSendingHandshake(); + +} + +void Sv2Transport::SetReceiveState(RecvState recv_state) noexcept +{ + AssertLockHeld(m_recv_mutex); + // Enforce allowed state transitions. + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + Assume(recv_state == RecvState::HANDSHAKE_STEP_2); + break; + case RecvState::HANDSHAKE_STEP_2: + Assume(recv_state == RecvState::APP); + break; + case RecvState::APP: + Assume(recv_state == RecvState::APP_READY); + break; + case RecvState::APP_READY: + Assume(recv_state == RecvState::APP); + break; + } + // Change state. + m_recv_state = recv_state; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session receive state -> %s\n", + RecvStateAsString(m_recv_state)); + +} + +void Sv2Transport::SetSendState(SendState send_state) noexcept +{ + AssertLockHeld(m_send_mutex); + // Enforce allowed state transitions. + switch (m_send_state) { + case SendState::HANDSHAKE_STEP_1: + Assume(send_state == SendState::HANDSHAKE_STEP_2); + break; + case SendState::HANDSHAKE_STEP_2: + Assume(send_state == SendState::READY); + break; + case SendState::READY: + Assume(false); // Final state + break; + } + // Change state. + m_send_state = send_state; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session send state -> %s\n", + SendStateAsString(m_send_state)); +} + +void Sv2Transport::StartSendingHandshake() noexcept +{ + AssertLockHeld(m_send_mutex); + AssertLockNotHeld(m_recv_mutex); + Assume(m_send_state == SendState::HANDSHAKE_STEP_1); + Assume(m_send_buffer.empty()); + + m_send_buffer.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_cipher.GetHandshakeState().WriteMsgEphemeralPK(MakeWritableByteSpan(m_send_buffer)); + + m_send_state = SendState::HANDSHAKE_STEP_2; +} + +void Sv2Transport::SendHandshakeReply() noexcept +{ + AssertLockHeld(m_send_mutex); + AssertLockHeld(m_recv_mutex); + Assume(m_send_state == SendState::HANDSHAKE_STEP_2); + + Assume(m_send_buffer.empty()); + m_send_buffer.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + m_cipher.GetHandshakeState().WriteMsgES(MakeWritableByteSpan(m_send_buffer)); + + m_cipher.FinishHandshake(); + + // We can send and receive stuff now, unless the other side hangs up + SetSendState(SendState::READY); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_2); + SetReceiveState(RecvState::APP); +} + +Transport::BytesToSend Sv2Transport::GetBytesToSend(bool have_next_message) const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + const std::string dummy_m_type; // m_type is set to "" when wrapping Sv2NetMsg + + Assume(m_send_pos <= m_send_buffer.size()); + return { + Span{m_send_buffer}.subspan(m_send_pos), + // We only have more to send after the current m_send_buffer if there is a (next) + // message to be sent, and we're capable of sending packets. */ + have_next_message && m_send_state == SendState::READY, + dummy_m_type + }; +} + +void Sv2Transport::MarkBytesSent(size_t bytes_sent) noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + // if (m_send_state == SendState::AWAITING_KEY && m_send_pos == 0 && bytes_sent > 0) { + // LogPrint(BCLog::NET, "start sending v2 handshake to peer=%d\n", m_nodeid); + // } + + m_send_pos += bytes_sent; + Assume(m_send_pos <= m_send_buffer.size()); + // Wipe the buffer when everything is sent. + if (m_send_pos == m_send_buffer.size()) { + m_send_pos = 0; + ClearShrink(m_send_buffer); + } +} + +bool Sv2Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + // We only allow adding a new message to be sent when in the READY state (so the packet cipher + // is available) and the send buffer is empty. This limits the number of messages in the send + // buffer to just one, and leaves the responsibility for queueing them up to the caller. + if (m_send_state != SendState::READY) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "SendState is not READY\n"); + return false; + } + + if (!m_send_buffer.empty()) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send buffer is not empty\n"); + return false; + } + + // The Sv2NetMsg is wrapped inside a dummy CSerializedNetMsg, extract it: + Sv2NetMsg sv2_msg(std::move(msg)); + // Reconstruct the header: + Sv2NetHeader hdr(sv2_msg.m_msg_type, sv2_msg.size()); + + // Construct ciphertext in send buffer. + const size_t encrypted_msg_size = Sv2Cipher::EncryptedMessageSize(sv2_msg.size()); + m_send_buffer.resize(SV2_HEADER_ENCRYPTED_SIZE + encrypted_msg_size); + Span buffer_span{MakeWritableByteSpan(m_send_buffer)}; + + // Header + DataStream ss_header_plain{}; + ss_header_plain << hdr; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(ss_header_plain)); + Span header_encrypted{buffer_span.subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + if (!m_cipher.EncryptMessage(ss_header_plain, header_encrypted)) { + return false; + } + + // Payload + Span payload_plain = MakeByteSpan(sv2_msg); + // TODO: truncate very long messages, about 100 bytes at the start and end + // is probably enough for most debugging. + // LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload_plain)); + Span payload_encrypted{buffer_span.subspan(SV2_HEADER_ENCRYPTED_SIZE, encrypted_msg_size)}; + if (!m_cipher.EncryptMessage(payload_plain, payload_encrypted)) { + return false; + } + + // Release memory (not needed with std::move above) + // ClearShrink(msg.data); + + return true; +} + +size_t Sv2Transport::GetSendMemoryUsage() const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + return sizeof(m_send_buffer) + memusage::DynamicUsage(m_send_buffer); +} + +bool Sv2Transport::ReceivedBytes(Span& msg_bytes) noexcept +{ + AssertLockNotHeld(m_send_mutex); + AssertLockNotHeld(m_recv_mutex); + /** How many bytes to allocate in the receive buffer at most above what is received so far. */ + static constexpr size_t MAX_RESERVE_AHEAD = 256 * 1024; // TODO: reduce to NOISE_MAX_CHUNK_SIZE? + + LOCK(m_recv_mutex); + // Process the provided bytes in msg_bytes in a loop. In each iteration a nonzero number of + // bytes (decided by GetMaxBytesToProcess) are taken from the beginning om msg_bytes, and + // appended to m_recv_buffer. Then, depending on the receiver state, one of the + // ProcessReceived*Bytes functions is called to process the bytes in that buffer. + while (!msg_bytes.empty()) { + // Decide how many bytes to copy from msg_bytes to m_recv_buffer. + size_t max_read = GetMaxBytesToProcess(); + + // Reserve space in the buffer if there is not enough. + if (m_recv_buffer.size() + std::min(msg_bytes.size(), max_read) > m_recv_buffer.capacity()) { + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + m_recv_buffer.reserve(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + break; + case RecvState::HANDSHAKE_STEP_2: + m_recv_buffer.reserve(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + break; + case RecvState::APP: { + // During states where a packet is being received, as much as is expected but never + // more than MAX_RESERVE_AHEAD bytes in addition to what is received so far. + // This means attackers that want to cause us to waste allocated memory are limited + // to MAX_RESERVE_AHEAD above the largest allowed message contents size, and to + // MAX_RESERVE_AHEAD more than they've actually sent us. + size_t alloc_add = std::min(max_read, msg_bytes.size() + MAX_RESERVE_AHEAD); + m_recv_buffer.reserve(m_recv_buffer.size() + alloc_add); + break; + } + case RecvState::APP_READY: + // The buffer is empty in this state. + Assume(m_recv_buffer.empty()); + break; + } + } + + // Can't read more than provided input. + max_read = std::min(msg_bytes.size(), max_read); + // Copy data to buffer. + m_recv_buffer.insert(m_recv_buffer.end(), UCharCast(msg_bytes.data()), UCharCast(msg_bytes.data() + max_read)); + msg_bytes = msg_bytes.subspan(max_read); + + // Process data in the buffer. + switch (m_recv_state) { + + case RecvState::HANDSHAKE_STEP_1: + if (!ProcessReceivedEphemeralKeyBytes()) return false; + break; + + case RecvState::HANDSHAKE_STEP_2: + if (!ProcessReceivedHandshakeReplyBytes()) return false; + break; + + case RecvState::APP: + if (!ProcessReceivedPacketBytes()) return false; + break; + + case RecvState::APP_READY: + return true; + + } + // Make sure we have made progress before continuing. + Assume(max_read > 0); + } + + return true; +} + +bool Sv2Transport::ProcessReceivedEphemeralKeyBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + AssertLockNotHeld(m_send_mutex); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_1); + Assume(m_recv_buffer.size() <= Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + + if (m_recv_buffer.size() == Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE) { + // Other side's key has been fully received, and can now be Diffie-Hellman + // combined with our key. This is act 1 of the Noise Protocol handshake. + // TODO handle failure + // TODO: MakeByteSpan instead of MakeWritableByteSpan + m_cipher.GetHandshakeState().ReadMsgEphemeralPK(MakeWritableByteSpan(m_recv_buffer)); + m_recv_buffer.clear(); + SetReceiveState(RecvState::HANDSHAKE_STEP_2); + + LOCK(m_send_mutex); + Assume(m_send_buffer.size() == 0); + + // Send our act 2 handshake + SendHandshakeReply(); + } else { + // We still have to receive more key bytes. + } + return true; +} + +bool Sv2Transport::ProcessReceivedHandshakeReplyBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + AssertLockNotHeld(m_send_mutex); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_2); + Assume(m_recv_buffer.size() <= Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + + if (m_recv_buffer.size() == Sv2HandshakeState::HANDSHAKE_STEP2_SIZE) { + // TODO handle failure + // TODO: MakeByteSpan instead of MakeWritableByteSpan + bool res = m_cipher.GetHandshakeState().ReadMsgES(MakeWritableByteSpan(m_recv_buffer)); + if (!res) return false; + m_recv_buffer.clear(); + m_cipher.FinishHandshake(); + SetReceiveState(RecvState::APP); + + LOCK(m_send_mutex); + Assume(m_send_buffer.size() == 0); + + SetSendState(SendState::READY); + } else { + // We still have to receive more key bytes. + } + return true; +} + +size_t Sv2Transport::GetMaxBytesToProcess() noexcept +{ + AssertLockHeld(m_recv_mutex); + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + // In this state, we only allow the 64-byte key into the receive buffer. + Assume(m_recv_buffer.size() <= Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + return Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE - m_recv_buffer.size(); + case RecvState::HANDSHAKE_STEP_2: + // In this state, we only allow the handshake reply into the receive buffer. + Assume(m_recv_buffer.size() <= Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + return Sv2HandshakeState::HANDSHAKE_STEP2_SIZE - m_recv_buffer.size(); + case RecvState::APP: + // Decode a packet. Process the header first, + // so that we know where the current packet ends (and we don't process bytes from the next + // packet yet). Then, process the ciphertext bytes of the current packet. + if (m_recv_buffer.size() < SV2_HEADER_ENCRYPTED_SIZE) { + return SV2_HEADER_ENCRYPTED_SIZE - m_recv_buffer.size(); + } else { + // When transitioning from receiving the packet length to receiving its ciphertext, + // the encrypted header is left in the receive buffer. + size_t expanded_size_with_header = SV2_HEADER_ENCRYPTED_SIZE + Sv2Cipher::EncryptedMessageSize(m_header.m_msg_len); + return expanded_size_with_header - m_recv_buffer.size(); + } + case RecvState::APP_READY: + // No bytes can be processed until GetMessage() is called. + return 0; + } + Assume(false); // unreachable + return 0; +} + +bool Sv2Transport::ProcessReceivedPacketBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + Assume(m_recv_state == RecvState::APP); + + // The maximum permitted decrypted payload size for a packet + static constexpr size_t MAX_CONTENTS_LEN = 16777215; // 24 bit unsigned; + + Assume(m_recv_buffer.size() <= SV2_HEADER_ENCRYPTED_SIZE || m_header.m_msg_len > 0); + + if (m_recv_buffer.size() == SV2_HEADER_ENCRYPTED_SIZE) { + // Header received, decrypt it. + std::array header_plain; + if (!m_cipher.DecryptMessage(MakeWritableByteSpan(m_recv_buffer), header_plain)) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt header\n"); + return false; + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(header_plain)); + + // Decode header + DataStream ss_header{header_plain}; + node::Sv2NetHeader header; + ss_header >> header; + m_header = std::move(header); + + // TODO: 16 MB is pretty large, maybe set lower limits for most or all message types? + if (m_header.m_msg_len > MAX_CONTENTS_LEN) { + LogTrace(BCLog::SV2, "Packet too large (%u bytes)\n", m_header.m_msg_len); + return false; + } + + // Disconnect for empty messages (TODO: check the spec) + if (m_header.m_msg_len == 0) { + LogTrace(BCLog::SV2, "Empty message\n"); + return false; + } + LogTrace(BCLog::SV2, "Expecting %d bytes payload (plain)\n", m_header.m_msg_len); + } else if (m_recv_buffer.size() > SV2_HEADER_ENCRYPTED_SIZE && + m_recv_buffer.size() == SV2_HEADER_ENCRYPTED_SIZE + Sv2Cipher::EncryptedMessageSize(m_header.m_msg_len)) { + /** Ciphertext received: decrypt into decode_buffer and deserialize into m_message. + * + * Note that it is impossible to reach this branch without hitting the + * branch above first, as GetMaxBytesToProcess only allows up to + * SV2_HEADER_ENCRYPTED_SIZE into the buffer before that point. */ + std::vector payload; + payload.resize(m_header.m_msg_len); + + Span recv_span{MakeWritableByteSpan(m_recv_buffer).subspan(SV2_HEADER_ENCRYPTED_SIZE)}; + if (!m_cipher.DecryptMessage(recv_span, MakeWritableByteSpan(payload))) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt message payload\n"); + return false; + } + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload)); + + // Wipe the receive buffer where the next packet will be received into. + ClearShrink(m_recv_buffer); + + Sv2NetMsg message{m_header.m_msg_type, std::move(payload)}; + m_message = std::move(message); + + // At this point we have a valid message decrypted into m_message. + SetReceiveState(RecvState::APP_READY); + } else { + // We either have less than 22 bytes, so we don't know the packet's length yet, or more + // than 22 bytes but less than the packet's full ciphertext. Wait until those arrive. + LogTrace(BCLog::SV2, "Waiting for more bytes\n"); + } + return true; +} + +bool Sv2Transport::ReceivedMessageComplete() const noexcept +{ + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + + return m_recv_state == RecvState::APP_READY; +} + +CNetMessage Sv2Transport::GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept +{ + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + Assume(m_recv_state == RecvState::APP_READY); + + SetReceiveState(RecvState::APP); + return m_message; // Sv2NetMsg is wrapped in a CNetMessage +} + +Transport::Info Sv2Transport::GetInfo() const noexcept +{ + return {.transport_type = TransportProtocolType::V1, .session_id = {}}; +} + +std::string RecvStateAsString(Sv2Transport::RecvState state) +{ + switch (state) { + case Sv2Transport::RecvState::HANDSHAKE_STEP_1: + return "HANDSHAKE_STEP_1"; + case Sv2Transport::RecvState::HANDSHAKE_STEP_2: + return "HANDSHAKE_STEP_2"; + case Sv2Transport::RecvState::APP: + return "APP"; + case Sv2Transport::RecvState::APP_READY: + return "APP_READY"; + } // no default case, so the compiler can warn about missing cases + + assert(false); +} + +std::string SendStateAsString(Sv2Transport::SendState state) +{ + switch (state) { + case Sv2Transport::SendState::HANDSHAKE_STEP_1: + return "HANDSHAKE_STEP_1"; + case Sv2Transport::SendState::HANDSHAKE_STEP_2: + return "HANDSHAKE_STEP_2"; + case Sv2Transport::SendState::READY: + return "READY"; + } // no default case, so the compiler can warn about missing cases + + assert(false); +} diff --git a/src/sv2/transport.h b/src/sv2/transport.h new file mode 100644 index 0000000000000..b8948e71a9f42 --- /dev/null +++ b/src/sv2/transport.h @@ -0,0 +1,194 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_TRANSPORT_H +#define BITCOIN_SV2_TRANSPORT_H + +#include +#include +#include +#include + +static constexpr size_t SV2_HEADER_PLAIN_SIZE{6}; +static constexpr size_t SV2_HEADER_ENCRYPTED_SIZE{SV2_HEADER_PLAIN_SIZE + Poly1305::TAGLEN}; + +using node::Sv2NetHeader; +using node::Sv2NetMsg; + +class Sv2Transport final : public Transport +{ +public: + + // The sender side and receiver side of Sv2Transport are state machines that are transitioned + // through, based on what has been received. The receive state corresponds to the contents of, + // and bytes received to, the receive buffer. The send state controls what can be appended to + // the send buffer and what can be sent from it. + + /** State type that defines the current contents of the receive buffer and/or how the next + * received bytes added to it will be interpreted. + * + * Diagram: + * + * start(responder) + * | start(initiator) + * | | /---------\ + * | | | | + * v v v | + * HANDSHAKE_STEP_1 -> HANDSHAKE_STEP_2 -> APP -> APP_READY + */ + enum class RecvState : uint8_t { + /** Handshake Act 1: -> E */ + HANDSHAKE_STEP_1, + + /** Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE */ + HANDSHAKE_STEP_2, + + /** Application packet. + * + * A packet is received, and decrypted/verified. If that succeeds, the + * state becomes APP_READY and the decrypted message is kept in m_message + * until it is retrieved by GetMessage(). */ + APP, + + /** Nothing (an application packet is available for GetMessage()). + * + * Nothing can be received in this state. When the message is retrieved + * by GetMessage(), the state becomes APP again. */ + APP_READY, + }; + + /** State type that controls the sender side. + * + * Diagram: + * + * start(initiator) + * | start(responder) + * | | + * | | + * v v + * HANDSHAKE_STEP_1 -> HANDSHAKE_STEP_2 -> READY + */ + enum class SendState : uint8_t { + /** Handshake Act 1: -> E */ + HANDSHAKE_STEP_1, + + /** Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE */ + HANDSHAKE_STEP_2, + + /** Normal sending state. + * + * In this state, the ciphers are initialized, so packets can be sent. + * In this state a message can be provided if the send buffer is empty. */ + READY, + }; + +private: + + /** Cipher state. */ + Sv2Cipher m_cipher; + + /** Whether we are the initiator side. */ + const bool m_initiating; + + /** Lock for receiver-side fields. */ + mutable Mutex m_recv_mutex ACQUIRED_BEFORE(m_send_mutex); + /** Receive buffer; meaning is determined by m_recv_state. */ + std::vector m_recv_buffer GUARDED_BY(m_recv_mutex); + /** AAD expected in next received packet (currently used only for garbage). */ + std::vector m_recv_aad GUARDED_BY(m_recv_mutex); + /** Current receiver state. */ + RecvState m_recv_state GUARDED_BY(m_recv_mutex); + + /** Lock for sending-side fields. If both sending and receiving fields are accessed, + * m_recv_mutex must be acquired before m_send_mutex. */ + mutable Mutex m_send_mutex ACQUIRED_AFTER(m_recv_mutex); + /** The send buffer; meaning is determined by m_send_state. */ + std::vector m_send_buffer GUARDED_BY(m_send_mutex); + /** How many bytes from the send buffer have been sent so far. */ + uint32_t m_send_pos GUARDED_BY(m_send_mutex) {0}; + /** The garbage sent, or to be sent (MAYBE_V1 and AWAITING_KEY state only). */ + std::vector m_send_garbage GUARDED_BY(m_send_mutex); + /** Type of the message being sent. */ + std::string m_send_type GUARDED_BY(m_send_mutex); + /** Current sender state. */ + SendState m_send_state GUARDED_BY(m_send_mutex); + + /** Change the receive state. */ + void SetReceiveState(RecvState recv_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + /** Change the send state. */ + void SetSendState(SendState send_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex); + /** Given a packet's contents, find the message type (if valid), and strip it from contents. */ + static std::optional GetMessageType(Span& contents) noexcept; + /** Determine how many received bytes can be processed in one go (not allowed in V1 state). */ + size_t GetMaxBytesToProcess() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + /** Put our ephemeral public key in the send buffer. */ + void StartSendingHandshake() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex, !m_recv_mutex); + /** Put second part of the handshake in the send buffer. */ + void SendHandshakeReply() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex, m_recv_mutex); + /** Process bytes in m_recv_buffer, while in HANDSHAKE_STEP_1 state. */ + bool ProcessReceivedEphemeralKeyBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex); + /** Process bytes in m_recv_buffer, while in HANDSHAKE_STEP_2 state. */ + bool ProcessReceivedHandshakeReplyBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex); + + /** Process bytes in m_recv_buffer, while in VERSION/APP state. */ + bool ProcessReceivedPacketBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + + /** In APP, the decrypted header, if m_recv_buffer.size() >= + * SV2_HEADER_ENCRYPTED_SIZE. Unspecified otherwise. */ + Sv2NetHeader m_header GUARDED_BY(m_recv_mutex); + /* In APP_READY the last retrieved message. Unspecified otherwise */ + Sv2NetMsg m_message GUARDED_BY(m_recv_mutex); + +public: + /** Construct a Stratum v2 transport as the initiator + * + * @param[in] static_key a securely generated key + + */ + Sv2Transport(CKey static_key, XOnlyPubKey responder_authority_key) noexcept; + + /** Construct a Stratum v2 transport as the responder + * + * @param[in] static_key a securely generated key + + */ + Sv2Transport(CKey static_key, Sv2SignatureNoiseMessage certificate) noexcept; + + // Receive side functions. + bool ReceivedMessageComplete() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + bool ReceivedBytes(Span& msg_bytes) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex, !m_send_mutex); + + CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + + // Send side functions. + bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + BytesToSend GetBytesToSend(bool have_next_message) const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + size_t GetSendMemoryUsage() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + // Miscellaneous functions. + bool ShouldReconnectV1() const noexcept override { return false; }; + Info GetInfo() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + + // Test only + uint256 NoiseHash() const { return m_cipher.GetHash(); }; + RecvState GetRecvState() EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + return m_recv_state; + }; + SendState GetSendState() EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex) { + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + return m_send_state; + }; +}; + +/** Convert TransportProtocolType enum to a string value */ +std::string RecvStateAsString(Sv2Transport::RecvState state); +std::string SendStateAsString(Sv2Transport::SendState state); + +#endif // BITCOIN_SV2_TRANSPORT_H diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 83ad6b5cbf3aa..4196a417005b6 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -179,6 +179,7 @@ if(WITH_SV2) target_sources(test_bitcoin PRIVATE sv2_noise_tests.cpp + sv2_transport_tests.cpp ) target_link_libraries(test_bitcoin bitcoin_sv2) endif() diff --git a/src/test/sv2_transport_tests.cpp b/src/test/sv2_transport_tests.cpp new file mode 100644 index 0000000000000..c4c7da2cfd358 --- /dev/null +++ b/src/test/sv2_transport_tests.cpp @@ -0,0 +1,389 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace std::literals; +using node::Sv2NetMsg; +using node::Sv2CoinbaseOutputConstraintsMsg; +using node::Sv2MsgType; + +BOOST_FIXTURE_TEST_SUITE(sv2_transport_tests, RegTestingSetup) + +namespace { + +/** A class for scenario-based tests of Sv2Transport + * + * Each Sv2TransportTester encapsulates a Sv2Transport (the one being tested), + * and can be told to interact with it. To do so, it also encapsulates a Sv2Cipher + * to act as the other side. A second Sv2Transport is not used, as doing so would + * not permit scenarios that involve sending invalid data. + */ +class Sv2TransportTester +{ + FastRandomContext& m_rng; + std::unique_ptr m_transport; //!< Sv2Transport being tested + std::unique_ptr m_peer_cipher; //!< Cipher to help with the other side + bool m_test_initiator; //!< Whether m_transport is the initiator (true) or responder (false) + + std::vector m_to_send; //!< Bytes we have queued up to send to m_transport-> + std::vector m_received; //!< Bytes we have received from m_transport-> + std::deque m_msg_to_send; //!< Messages to be sent *by* m_transport to us. + +public: + /** Construct a tester object. test_initiator: whether the tested transport is initiator. */ + + explicit Sv2TransportTester(FastRandomContext& rng, bool test_initiator) : m_rng{rng}, m_test_initiator(test_initiator) + { + auto initiator_static_key{GenerateRandomKey()}; + auto responder_static_key{GenerateRandomKey()}; + auto responder_authority_key{GenerateRandomKey()}; + + // Create certificates + auto epoch_now = std::chrono::system_clock::now().time_since_epoch(); + uint16_t version = 0; + uint32_t valid_from = static_cast(std::chrono::duration_cast(epoch_now).count()); + uint32_t valid_to = std::numeric_limits::max(); + + auto responder_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(responder_static_key.GetPubKey()), responder_authority_key); + + if (test_initiator) { + m_transport = std::make_unique(initiator_static_key, XOnlyPubKey(responder_authority_key.GetPubKey())); + m_peer_cipher = std::make_unique(std::move(responder_static_key), std::move(responder_certificate)); + } else { + m_transport = std::make_unique(responder_static_key, responder_certificate); + m_peer_cipher = std::make_unique(std::move(initiator_static_key), XOnlyPubKey(responder_authority_key.GetPubKey())); + } + } + + /** Data type returned by Interact: + * + * - std::nullopt: transport error occurred + * - otherwise: a vector of + * - std::nullopt: invalid message received + * - otherwise: a Sv2NetMsg retrieved + */ + using InteractResult = std::optional>>; + + void LogProgress(bool should_progress, bool progress, bool pretend_no_progress) { + if (!should_progress) { + BOOST_TEST_MESSAGE("[Interact] !should_progress"); + } else if (!progress) { + BOOST_TEST_MESSAGE("[Interact] should_progress && !progress"); + } else if (pretend_no_progress) { + BOOST_TEST_MESSAGE("[Interact] pretend !progress"); + } + } + + /** Send/receive scheduled/available bytes and messages. + * + * This is the only function that interacts with the transport being tested; everything else is + * scheduling things done by Interact(), or processing things learned by it. + */ + InteractResult Interact() + { + std::vector> ret; + while (true) { + bool progress{false}; + // Send bytes from m_to_send to the transport. + if (!m_to_send.empty()) { + size_t n_bytes_to_send = 1 + m_rng.randrange(m_to_send.size()); + BOOST_TEST_MESSAGE(strprintf("[Interact] send %d of %d bytes", n_bytes_to_send, m_to_send.size())); + Span to_send = Span{m_to_send}.first(n_bytes_to_send); + size_t old_len = to_send.size(); + if (!m_transport->ReceivedBytes(to_send)) { + BOOST_TEST_MESSAGE("[Interact] transport error"); + return std::nullopt; + } + if (old_len != to_send.size()) { + progress = true; + m_to_send.erase(m_to_send.begin(), m_to_send.begin() + (old_len - to_send.size())); + } + } + // Retrieve messages received by the transport. + bool should_progress = m_transport->ReceivedMessageComplete(); + bool pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + bool dummy_reject_message = false; + CNetMessage net_msg = m_transport->GetReceivedMessage(std::chrono::microseconds(0), dummy_reject_message); + Sv2NetMsg msg(std::move(net_msg)); + ret.emplace_back(std::move(msg)); + progress = true; + } + // Enqueue a message to be sent by the transport to us. + should_progress = !m_msg_to_send.empty(); + pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + BOOST_TEST_MESSAGE("Shoehorn into CSerializedNetMsg"); + CSerializedNetMsg msg{m_msg_to_send.front()}; + BOOST_TEST_MESSAGE("Call SetMessageToSend"); + if (m_transport->SetMessageToSend(msg)) { + BOOST_TEST_MESSAGE("Finished SetMessageToSend"); + m_msg_to_send.pop_front(); + progress = true; + } + } + // Receive bytes from the transport. + const auto& [recv_bytes, _more, _m_type] = m_transport->GetBytesToSend(!m_msg_to_send.empty()); + should_progress = !recv_bytes.empty(); + pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + size_t to_receive = 1 + m_rng.randrange(recv_bytes.size()); + BOOST_TEST_MESSAGE(strprintf("[Interact] receive %d of %d bytes", to_receive, recv_bytes.size())); + m_received.insert(m_received.end(), recv_bytes.begin(), recv_bytes.begin() + to_receive); + progress = true; + m_transport->MarkBytesSent(to_receive); + } + if (!progress) break; + } + return ret; + } + + /** Schedule bytes to be sent to the transport. */ + void Send(Span data) + { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send: %s\n", HexStr(data)); + m_to_send.insert(m_to_send.end(), data.begin(), data.end()); + } + + /** Schedule bytes to be sent to the transport. */ + void Send(Span data) { Send(MakeUCharSpan(data)); } + + /** Schedule a message to be sent to us by the transport. */ + void AddMessage(Sv2NetMsg msg) + { + m_msg_to_send.push_back(std::move(msg)); + } + + /** + * If we are the initiator, the send buffer should contain our ephemeral public + * key. Pass this to the peer cipher and clear the buffer. + * + * If we are the responder, put the peer ephemeral public key on our receive buffer. + */ + void ProcessHandshake1() { + if (m_test_initiator) { + BOOST_REQUIRE(m_received.size() == Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_peer_cipher->GetHandshakeState().ReadMsgEphemeralPK(MakeWritableByteSpan(m_received)); + m_received.clear(); + } else { + BOOST_REQUIRE(m_to_send.empty()); + m_to_send.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_peer_cipher->GetHandshakeState().WriteMsgEphemeralPK(MakeWritableByteSpan(m_to_send)); + } + + } + + /** Expect key to have been received from transport and process it. + * + * Many other Sv2TransportTester functions cannot be called until after + * ProcessHandshake2() has been called, as no encryption keys are set up before that point. + */ + void ProcessHandshake2() + { + if (m_test_initiator) { + BOOST_REQUIRE(m_to_send.empty()); + + // Have the peer cypher write the second part of the handshake into our receive buffer + m_to_send.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + m_peer_cipher->GetHandshakeState().WriteMsgES(MakeWritableByteSpan(m_to_send)); + + // At this point the peer is done with the handshake: + m_peer_cipher->FinishHandshake(); + } else { + BOOST_REQUIRE(m_received.size() == Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + BOOST_REQUIRE(m_peer_cipher->GetHandshakeState().ReadMsgES(MakeWritableByteSpan(m_received))); + m_received.clear(); + + m_peer_cipher->FinishHandshake(); + } + } + + /** Schedule an encrypted packet with specified content to be sent to transport + * (only after ReceiveKey). */ + void SendPacket(Sv2NetMsg msg) + { + // TODO: randomly break stuff + + std::vector ciphertext; + const size_t encrypted_payload_size = Sv2Cipher::EncryptedMessageSize(msg.size()); + ciphertext.resize(SV2_HEADER_ENCRYPTED_SIZE + encrypted_payload_size); + Span buffer_span{MakeWritableByteSpan(ciphertext)}; + + // Header + DataStream ss_header_plain{}; + ss_header_plain << Sv2NetHeader(msg); + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(ss_header_plain)); + Span header_encrypted{buffer_span.subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + BOOST_REQUIRE(m_peer_cipher->EncryptMessage(ss_header_plain, header_encrypted)); + + // Payload + Span payload_plain = MakeByteSpan(msg); + // TODO: truncate very long messages, about 100 bytes at the start and end + // is probably enough for most debugging. + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload_plain)); + Span payload_encrypted{buffer_span.subspan(SV2_HEADER_ENCRYPTED_SIZE, encrypted_payload_size)}; + BOOST_REQUIRE(m_peer_cipher->EncryptMessage(payload_plain, payload_encrypted)); + + // Schedule it for sending. + Send(ciphertext); + } + + /** Expect application packet to have been received, with specified message type and payload. + * (only after ReceiveKey). */ + void ReceiveMessage(Sv2NetMsg expected_msg) + { + // When processing a packet, at least enough bytes for its length descriptor must be received. + BOOST_REQUIRE(m_received.size() >= SV2_HEADER_ENCRYPTED_SIZE); + + auto header_encrypted{MakeWritableByteSpan(m_received).subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + std::array header_plain; + BOOST_REQUIRE(m_peer_cipher->DecryptMessage(header_encrypted, header_plain)); + + // Decode header + DataStream ss_header{header_plain}; + node::Sv2NetHeader header; + ss_header >> header; + + BOOST_CHECK(header.m_msg_type == expected_msg.m_msg_type); + + size_t expanded_size = Sv2Cipher::EncryptedMessageSize(header.m_msg_len); + BOOST_REQUIRE(m_received.size() >= SV2_HEADER_ENCRYPTED_SIZE + expanded_size); + + Span encrypted_payload{MakeWritableByteSpan(m_received).subspan(SV2_HEADER_ENCRYPTED_SIZE, expanded_size)}; + Span payload = encrypted_payload.subspan(0, header.m_msg_len); + + BOOST_REQUIRE(m_peer_cipher->DecryptMessage(encrypted_payload, payload)); + + std::vector decode_buffer; + decode_buffer.resize(header.m_msg_len); + + std::transform(payload.begin(), payload.end(), decode_buffer.begin(), + [](std::byte b) { return static_cast(b); }); + + // TODO: clear the m_received we used + + Sv2NetMsg message{header.m_msg_type, std::move(decode_buffer)}; + + // TODO: compare payload + } + + /** Test whether the transport's m_hash matches the other side. */ + void CompareHash() const + { + BOOST_REQUIRE(m_transport); + BOOST_CHECK(m_transport->NoiseHash() == m_peer_cipher->GetHash()); + } + + void CheckRecvState(Sv2Transport::RecvState state) { + BOOST_REQUIRE(m_transport); + BOOST_CHECK_EQUAL(RecvStateAsString(m_transport->GetRecvState()), RecvStateAsString(state)); + } + + void CheckSendState(Sv2Transport::SendState state) { + BOOST_REQUIRE(m_transport); + BOOST_CHECK_EQUAL(SendStateAsString(m_transport->GetSendState()), SendStateAsString(state)); + } + + /** Introduce a bit error in the data scheduled to be sent. */ + // void Damage() + // { + // BOOST_TEST_MESSAGE("[Interact] introduce a bit error"); + // m_to_send[m_rng.randrange(m_to_send.size())] ^= (uint8_t{1} << m_rng.randrange(8)); + // } +}; + +} // namespace + +BOOST_AUTO_TEST_CASE(sv2_transport_initiator_test) +{ + // A mostly normal scenario, testing a transport in initiator mode. + // Interact() introduces randomness, so run multiple times + for (int i = 0; i < 10; ++i) { + BOOST_TEST_MESSAGE(strprintf("\nIteration %d (initiator)", i)); + Sv2TransportTester tester(m_rng, true); + // As the initiator, our ephemeral public key is immedidately put + // onto the buffer. + tester.CheckSendState(Sv2Transport::SendState::HANDSHAKE_STEP_2); + tester.CheckRecvState(Sv2Transport::RecvState::HANDSHAKE_STEP_2); + auto ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.ProcessHandshake1(); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.ProcessHandshake2(); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.CheckSendState(Sv2Transport::SendState::READY); + tester.CheckRecvState(Sv2Transport::RecvState::APP); + tester.CompareHash(); + } +} + +BOOST_AUTO_TEST_CASE(sv2_transport_responder_test) +{ + // Normal scenario, with a transport in responder node. + for (int i = 0; i < 10; ++i) { + BOOST_TEST_MESSAGE(strprintf("\nIteration %d (responder)", i)); + Sv2TransportTester tester(m_rng, false); + tester.CheckSendState(Sv2Transport::SendState::HANDSHAKE_STEP_2); + tester.CheckRecvState(Sv2Transport::RecvState::HANDSHAKE_STEP_1); + tester.ProcessHandshake1(); + auto ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.CheckSendState(Sv2Transport::SendState::READY); + tester.CheckRecvState(Sv2Transport::RecvState::APP); + + // Have the test cypher process our handshake reply + tester.ProcessHandshake2(); + tester.CompareHash(); + + // Handshake complete, have the initiator send us a message: + Sv2CoinbaseOutputConstraintsMsg body{4000, 400}; + Sv2NetMsg msg{body}; + BOOST_REQUIRE(msg.m_msg_type == Sv2MsgType::COINBASE_OUTPUT_CONSTRAINTS); + + tester.SendPacket(msg); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->size() == 1); + BOOST_CHECK((*ret)[0] && + (*ret)[0]->m_msg_type == Sv2MsgType::COINBASE_OUTPUT_CONSTRAINTS); + + tester.CompareHash(); + + // Send a message back to the initiator + tester.AddMessage(msg); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->size() == 0); + tester.ReceiveMessage(msg); + + // TODO: send / receive message larger than the chunk size + } +} + + +BOOST_AUTO_TEST_SUITE_END() From a0deeb192f36dcea0b0c356569fa1234ea184191 Mon Sep 17 00:00:00 2001 From: Sjors Provoost Date: Mon, 15 Jul 2024 13:30:39 +0200 Subject: [PATCH 28/30] Add sv2 SETUP_CONNECTION messages Co-Authored-By: Christopher Coverdale --- src/sv2/messages.h | 153 +++++++++++++++++++++++++++++++- src/test/CMakeLists.txt | 1 + src/test/sv2_messages_tests.cpp | 100 +++++++++++++++++++++ 3 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 src/test/sv2_messages_tests.cpp diff --git a/src/sv2/messages.h b/src/sv2/messages.h index 277326687a1cf..e74fda77f1688 100644 --- a/src/sv2/messages.h +++ b/src/sv2/messages.h @@ -6,11 +6,22 @@ #define BITCOIN_SV2_MESSAGES_H #include // for CSerializedNetMsg and CNetMessage +#include +#include +#include +#include