Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnsdist: add support for a callback when a new tickets key is added #14327

Merged
merged 9 commits into from
Jul 4, 2024
4 changes: 2 additions & 2 deletions pdns/dnsdistdist/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,14 @@ endif

if HAVE_DNS_OVER_TLS
if HAVE_GNUTLS
dnsdist_LDADD += -lgnutls
dnsdist_LDADD += $(GNUTLS_LIBS)
endif
endif

if HAVE_DNS_OVER_HTTPS

if HAVE_GNUTLS
dnsdist_LDADD += -lgnutls
dnsdist_LDADD += $(GNUTLS_LIBS)
endif

if HAVE_LIBH2OEVLOOP
Expand Down
23 changes: 22 additions & 1 deletion pdns/dnsdistdist/dnsdist-lua-hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
#include "dnsdist-lua-hooks.hh"
#include "dnsdist-lua.hh"
#include "lock.hh"
#include "tcpiohandler.hh"

namespace dnsdist::lua::hooks
{
using MaintenanceCallback = std::function<void()>;
using TicketsKeyAddedHook = std::function<void(const char*, size_t)>;

static LockGuarded<std::vector<MaintenanceCallback>> s_maintenanceHooks;

void runMaintenanceHooks(const LuaContext& context)
Expand All @@ -15,7 +19,7 @@ void runMaintenanceHooks(const LuaContext& context)
}
}

void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback)
static void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback)
{
(void)context;
s_maintenanceHooks.lock()->push_back(std::move(callback));
Expand All @@ -26,12 +30,29 @@ void clearMaintenanceHooks()
s_maintenanceHooks.lock()->clear();
}

static void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook)
{
TLSCtx::setTicketsKeyAddedHook([hook](const std::string& key) {
try {
auto lua = g_lua.lock();
hook(key.c_str(), key.size());
}
catch (const std::exception& exp) {
warnlog("Error calling the Lua hook after new tickets key has been added: %s", exp.what());
}
});
}

void setupLuaHooks(LuaContext& luaCtx)
{
luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) {
setLuaSideEffect();
addMaintenanceCallback(luaCtx, callback);
});
luaCtx.writeFunction("setTicketsKeyAddedHook", [&luaCtx](const TicketsKeyAddedHook& hook) {
setLuaSideEffect();
setTicketsKeyAddedHook(luaCtx, hook);
});
}

}
2 changes: 0 additions & 2 deletions pdns/dnsdistdist/dnsdist-lua-hooks.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ class LuaContext;

namespace dnsdist::lua::hooks
{
using MaintenanceCallback = std::function<void()>;
void runMaintenanceHooks(const LuaContext& context);
void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback);
void clearMaintenanceHooks();
void setupLuaHooks(LuaContext& luaCtx);
}
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/dnsdist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
}
}

LockGuarded<LuaContext> g_lua{LuaContext()};
RecursiveLockGuarded<LuaContext> g_lua{LuaContext()};
ComboAddress g_serverControl{"127.0.0.1:5199"};

static void spoofResponseFromString(DNSQuestion& dnsQuestion, const string& spoofContent, bool raw)
Expand Down
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/dnsdist.hh
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ public:
using servers_t = vector<std::shared_ptr<DownstreamState>>;

void responderThread(std::shared_ptr<DownstreamState> dss);
extern LockGuarded<LuaContext> g_lua;
extern RecursiveLockGuarded<LuaContext> g_lua;
extern std::string g_outputBuffer; // locking for this is ok, as locked by g_luamutex

class DNSRule
Expand Down
11 changes: 11 additions & 0 deletions pdns/dnsdistdist/docs/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,17 @@ Other functions
Code is supplied as a string, not as a function object.
Note that this function does nothing in 'client' or 'config-check' modes.

.. function:: setTicketsKeyAddedHook(callback)

.. versionadded:: 1.9.6

Set a Lua function that will be called everytime a new tickets key is added. The function receives:

* the key content as a string
* the keylen as an integer

See :doc:`../advanced/tls-sessions-management` for more information.

.. function:: submitToMainThread(cmd, dict)

.. versionadded:: 1.8.0
Expand Down
2 changes: 1 addition & 1 deletion pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};

#include "ext/luawrapper/include/LuaContext.hpp"
LockGuarded<LuaContext> g_lua{LuaContext()};
RecursiveLockGuarded<LuaContext> g_lua{LuaContext()};

bool g_snmpEnabled{false};
bool g_snmpTrapsEnabled{false};
Expand Down
21 changes: 21 additions & 0 deletions pdns/libssl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

#undef CERT
#include "misc.hh"
#include "tcpiohandler.hh"

#if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL)
/* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
Expand Down Expand Up @@ -631,6 +632,13 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default;
void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey)
{
d_ticketKeys.write_lock()->push_front(std::move(newKey));
if (TLSCtx::hasTicketsKeyAddedHook()) {
auto key = d_ticketKeys.read_lock()->front();
auto keyContent = key->content();
TLSCtx::getTicketsKeyAddedHook()(keyContent);
chbruyand marked this conversation as resolved.
Show resolved Hide resolved
// fills mem with 0's
OPENSSL_cleanse(keyContent.data(), keyContent.size());
}
}

std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
Expand Down Expand Up @@ -737,6 +745,19 @@ bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_N
return (memcmp(d_name, name, sizeof(d_name)) == 0);
}

std::string OpenSSLTLSTicketKey::content() const
{
std::string result{};
result.reserve(TLS_TICKETS_KEY_NAME_SIZE + TLS_TICKETS_CIPHER_KEY_SIZE + TLS_TICKETS_MAC_KEY_SIZE);
// NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
result.append(reinterpret_cast<const char*>(d_name), TLS_TICKETS_KEY_NAME_SIZE);
result.append(reinterpret_cast<const char*>(d_cipherKey), TLS_TICKETS_CIPHER_KEY_SIZE);
result.append(reinterpret_cast<const char*>(d_hmacKey), TLS_TICKETS_MAC_KEY_SIZE);
// NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)

return result;
}

#if OPENSSL_VERSION_MAJOR >= 3
static const std::string sha256KeyName{"sha256"};
#endif
Expand Down
2 changes: 1 addition & 1 deletion pdns/libssl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public:
#if OPENSSL_VERSION_MAJOR >= 3
int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
[[nodiscard]] std::string content() const;
#else
int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
Expand All @@ -124,7 +125,6 @@ public:

private:
void addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey);

SharedLockGuarded<boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > > d_ticketKeys;
};

Expand Down
105 changes: 105 additions & 0 deletions pdns/lock.hh
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,111 @@ private:
T d_value;
};

template <typename T>
class RecursiveLockGuardedHolder
{
public:
explicit RecursiveLockGuardedHolder(T& value, std::recursive_mutex& mutex) :
d_lock(mutex), d_value(value)
{
}

T& operator*() const noexcept
{
return d_value;
}

T* operator->() const noexcept
{
return &d_value;
}

private:
std::lock_guard<std::recursive_mutex> d_lock;
T& d_value;
};

template <typename T>
class RecursiveLockGuardedTryHolder
{
public:
explicit RecursiveLockGuardedTryHolder(T& value, std::recursive_mutex& mutex) :
d_lock(mutex, std::try_to_lock), d_value(value)
{
}

T& operator*() const
{
if (!owns_lock()) {
throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired");
}
return d_value;
}

T* operator->() const
{
if (!owns_lock()) {
throw std::runtime_error("Trying to access data protected by a mutex while the lock has not been acquired");
}
return &d_value;
}

operator bool() const noexcept
{
return d_lock.owns_lock();
}

[[nodiscard]] bool owns_lock() const noexcept
{
return d_lock.owns_lock();
}

void lock()
{
d_lock.lock();
}

private:
std::unique_lock<std::recursive_mutex> d_lock;
T& d_value;
};

template <typename T>
class RecursiveLockGuarded
{
public:
explicit RecursiveLockGuarded(const T& value) :
d_value(value)
{
}

explicit RecursiveLockGuarded(T&& value) :
d_value(std::move(value))
{
}

explicit RecursiveLockGuarded() = default;

RecursiveLockGuardedTryHolder<T> try_lock()
{
return RecursiveLockGuardedTryHolder<T>(d_value, d_mutex);
}

RecursiveLockGuardedHolder<T> lock()
{
return RecursiveLockGuardedHolder<T>(d_value, d_mutex);
}

RecursiveLockGuardedHolder<const T> read_only_lock()
{
return RecursiveLockGuardedHolder<const T>(d_value, d_mutex);
}

private:
std::recursive_mutex d_mutex;
T d_value;
};

template <typename T>
class SharedLockGuardedHolder
{
Expand Down
39 changes: 29 additions & 10 deletions pdns/tcpiohandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false;
#include <sodium.h>
#endif /* HAVE_LIBSODIUM */

TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};

#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
#ifdef HAVE_LIBSSL

Expand Down Expand Up @@ -987,6 +989,16 @@ class GnuTLSTicketsKey
throw;
}
}
[[nodiscard]] std::string content() const
{
std::string result{};
if (d_key.data != nullptr && d_key.size > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
chbruyand marked this conversation as resolved.
Show resolved Hide resolved
safe_memory_lock(result.data(), result.size());
}
return result;
}

~GnuTLSTicketsKey()
{
Expand Down Expand Up @@ -1730,37 +1742,44 @@ class GnuTLSIOCtx: public TLSCtx
return connection;
}

void rotateTicketsKey(time_t now) override
void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>();

{
*(d_ticketsKey.write_lock()) = std::move(newKey);
}

if (d_ticketsKeyRotationDelay > 0) {
d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
}

if (TLSCtx::hasTicketsKeyAddedHook()) {
auto ticketsKey = *(d_ticketsKey.read_lock());
auto content = ticketsKey->content();
TLSCtx::getTicketsKeyAddedHook()(content);
chbruyand marked this conversation as resolved.
Show resolved Hide resolved
safe_memory_release(content.data(), content.size());
}
}
void rotateTicketsKey(time_t now) override
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>();
addTicketsKey(now, std::move(newKey));
}
void loadTicketsKeys(const std::string& file) final
{
if (!d_enableTickets) {
return;
}

auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
{
*(d_ticketsKey.write_lock()) = std::move(newKey);
}

if (d_ticketsKeyRotationDelay > 0) {
d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
}
addTicketsKey(time(nullptr), std::move(newKey));
}

size_t getTicketsKeysCount() override
Expand Down
Loading
Loading