Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ ABSL_FLAG(string, availability_zone, "",
"server availability zone, used by clients to read from local-zone replicas");

ABSL_FLAG(bool, keep_legacy_memory_metrics, true, "legacy metrics format");
ABSL_FLAG(uint32_t, tls_reload_interval_secs, 0,
"If non-zero, periodically checks whether the TLS cert/key files on disk have changed "
"and reloads them automatically. Interval in seconds (minimum 60). 0 to disable.");
// TODO deprecate when flipped in production
ABSL_FLAG(bool, replicaof_no_one_start_journal, true,
"when set, preserves journal offsets after REPLICAOF NO ONE");
Expand All @@ -165,6 +168,8 @@ ABSL_DECLARE_FLAG(int32_t, port);
ABSL_DECLARE_FLAG(bool, cache_mode);
ABSL_DECLARE_FLAG(int32_t, hz);
ABSL_DECLARE_FLAG(bool, tls);
ABSL_DECLARE_FLAG(string, tls_cert_file);
ABSL_DECLARE_FLAG(string, tls_key_file);
ABSL_DECLARE_FLAG(string, tls_ca_cert_file);
ABSL_DECLARE_FLAG(string, tls_ca_cert_dir);
ABSL_DECLARE_FLAG(int, replica_priority);
Expand Down Expand Up @@ -1208,6 +1213,11 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vector<facade::Listen
return true;
});
create_snapshot_schedule_fb();

// Launch the TLS cert hot-reload fiber. It exits immediately when the
// interval is 0, so there is no cost when the feature is disabled.
tls_reload_fb_ =
service_.proactor_pool().GetNextProactor()->LaunchFiber([this] { TlsReloadScheduling(); });
}

void ServerFamily::LoadFromSnapshot() {
Expand Down Expand Up @@ -1258,6 +1268,9 @@ void ServerFamily::Shutdown() {

JoinSnapshotSchedule();

tls_reload_done_.Notify();
tls_reload_fb_.JoinIfNeeded();

bg_save_fb_.JoinIfNeeded();

if (save_on_shutdown_ && !absl::GetFlag(FLAGS_dbfilename).empty()) {
Expand Down Expand Up @@ -1499,6 +1512,76 @@ void ServerFamily::SnapshotScheduling() {
}
}

namespace {

time_t FileMtime(const std::string& path) {
if (path.empty())
return 0;
struct stat st;
if (stat(path.c_str(), &st) != 0)
return 0;
return st.st_mtim.tv_sec;
}

} // namespace

void ServerFamily::TlsReloadScheduling() {
uint32_t interval = absl::GetFlag(FLAGS_tls_reload_interval_secs);
if (interval == 0)
return;

if (!absl::GetFlag(FLAGS_tls)) {
LOG(WARNING) << "tls_reload_interval_secs is set but TLS is not enabled; "
"ignoring";
return;
}

if (interval < 60) {
LOG(WARNING) << "tls_reload_interval_secs must be >= 60, got " << interval << "; clamping";
interval = 60;
}

LOG(INFO) << "TLS cert hot reload enabled, checking every " << interval << "s";

// Snapshot the initial mtimes so we only reload on change.
time_t last_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_cert_file));
time_t last_key_mtime = FileMtime(absl::GetFlag(FLAGS_tls_key_file));
time_t last_ca_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_ca_cert_file));

while (true) {
if (tls_reload_done_.WaitFor(chrono::seconds(interval))) {
break;
}

time_t cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_cert_file));
time_t key_mtime = FileMtime(absl::GetFlag(FLAGS_tls_key_file));
time_t ca_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_ca_cert_file));

if (cert_mtime == last_cert_mtime && key_mtime == last_key_mtime &&
ca_cert_mtime == last_ca_cert_mtime) {
continue;
}

LOG(INFO) << "TLS cert/key file change detected, reloading";

bool ok = true;
for (facade::Listener* l : listeners_) {
if (!l->socket()->proactor()->Await([l] { return l->ReconfigureTLS(); })) {
LOG(WARNING) << "TLS hot reload failed on a listener";
ok = false;
break;
}
}

if (ok) {
LOG(INFO) << "TLS certs reloaded successfully";
last_cert_mtime = cert_mtime;
last_key_mtime = key_mtime;
last_ca_cert_mtime = ca_cert_mtime;
}
}
}

std::error_code ServerFamily::LoadRdb(const std::string& rdb_file, LoadExistingKeys existing_keys,
LoadOptions* load_opts, RdbLoadContext* load_context,
detail::SnapshotStorage* storage) {
Expand Down
7 changes: 7 additions & 0 deletions src/server/server_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,11 @@ class ServerFamily {

void SnapshotScheduling() ABSL_LOCKS_EXCLUDED(loading_stats_mu_);

// Periodically checks whether the TLS cert/key files on disk have changed
// (by mtime) and calls ReconfigureTLS() on every listener when they have.
// Driven by --tls_reload_interval_secs. Exits immediately if the flag is 0.
void TlsReloadScheduling();

void SendInvalidationMessages() const;

std::optional<SaveCmdOptions> GetSaveCmdOpts(CmdArgList args, CommandContext* cmd_cntx);
Expand Down Expand Up @@ -420,6 +425,7 @@ class ServerFamily {
void ChangeConnectionAccept(bool accept);

util::fb2::Fiber snapshot_schedule_fb_;
util::fb2::Fiber tls_reload_fb_;
util::fb2::Fiber load_fiber_;

Service& service_;
Expand Down Expand Up @@ -450,6 +456,7 @@ class ServerFamily {
bool save_on_shutdown_{true};

util::fb2::Done schedule_done_;
util::fb2::Done tls_reload_done_;
std::unique_ptr<util::fb2::FiberQueueThreadPool> fq_threadpool_;
std::shared_ptr<detail::SnapshotStorage> snapshot_storage_;

Expand Down
80 changes: 80 additions & 0 deletions tests/dragonfly/tls_conf_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import asyncio
import os
import shutil

import pytest
import redis
from .utility import *
Expand Down Expand Up @@ -164,3 +168,79 @@ async def test_config_disable_tls(
# Connecting without TLS should succeed.
async with server.client() as client_unauth:
await client_unauth.ping()


async def test_tls_hot_reload(df_factory, with_tls_ca_cert_args, tmp_dir):
"""Verify --tls_reload_interval_secs detects cert file changes on disk.

1. Start the server with TLS certs signed by CA-A and a short reload interval.
2. Overwrite the cert/key files on disk with certs signed by CA-B.
3. Wait for the reload timer to fire.
4. New connections with CA-A should fail; connections with CA-B should work.
5. The existing (pre-reload) connection stays alive throughout.
"""
# Paths the server will read — these get overwritten mid-test.
server_key = os.path.join(tmp_dir, "reload-df-key.pem")
server_cert = os.path.join(tmp_dir, "reload-df-cert.pem")
server_req = os.path.join(tmp_dir, "reload-df-req.pem")

# Generate initial cert/key from the session CA.
gen_certificate(
with_tls_ca_cert_args["ca_key"],
with_tls_ca_cert_args["ca_cert"],
server_req,
server_key,
server_cert,
)

server_args = {
"tls": None,
"tls_key_file": server_key,
"tls_cert_file": server_cert,
"requirepass": "XXX",
"tls_reload_interval_secs": 60,
}

with df_factory.create(**server_args) as server:
# Establish a connection with the original CA — should work.
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as client:
await client.ping()

# Generate a brand-new CA and server cert, then overwrite the
# files in-place. The flag paths don't change — only the contents.
new_ca_key = os.path.join(tmp_dir, "reload-ca-key-new.pem")
new_ca_cert = os.path.join(tmp_dir, "reload-ca-cert-new.pem")
gen_ca_cert(new_ca_key, new_ca_cert)

new_key_tmp = os.path.join(tmp_dir, "reload-df-key-new.pem")
new_cert_tmp = os.path.join(tmp_dir, "reload-df-cert-new.pem")
new_req_tmp = os.path.join(tmp_dir, "reload-df-req-new.pem")
gen_certificate(
new_ca_key, new_ca_cert, new_req_tmp, new_key_tmp, new_cert_tmp
)

# Atomic-ish overwrite of the files the server is watching.
shutil.copy2(new_key_tmp, server_key)
shutil.copy2(new_cert_tmp, server_cert)

# Wait for the reload timer to pick up the mtime change.
await asyncio.sleep(65)

# Existing connection should still be alive (SSL session unchanged).
await client.ping()

# New connection with the OLD CA should fail — server now presents
# a cert signed by the new CA.
with pytest.raises(redis.exceptions.ConnectionError):
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"]
) as bad_client:
await bad_client.ping()

# New connection with the NEW CA should succeed.
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=new_ca_cert
) as new_client:
await new_client.ping()
Loading