diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 3ecc4b1eb5be..0a8ae79c3b54 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -157,6 +157,9 @@ ABSL_FLAG(string, availability_zone, "", "server availability zone, used by clients to read from local-zone replicas"); ABSL_FLAG(bool, keep_legacy_memory_metrics, true, "legacy metrics format"); +ABSL_FLAG(uint32_t, tls_reload_interval_secs, 0, + "If non-zero, periodically checks whether the TLS cert/key files on disk have changed " + "and reloads them automatically. Interval in seconds (minimum 60). 0 to disable."); // TODO deprecate when flipped in production ABSL_FLAG(bool, replicaof_no_one_start_journal, true, "when set, preserves journal offsets after REPLICAOF NO ONE"); @@ -165,6 +168,8 @@ ABSL_DECLARE_FLAG(int32_t, port); ABSL_DECLARE_FLAG(bool, cache_mode); ABSL_DECLARE_FLAG(int32_t, hz); ABSL_DECLARE_FLAG(bool, tls); +ABSL_DECLARE_FLAG(string, tls_cert_file); +ABSL_DECLARE_FLAG(string, tls_key_file); ABSL_DECLARE_FLAG(string, tls_ca_cert_file); ABSL_DECLARE_FLAG(string, tls_ca_cert_dir); ABSL_DECLARE_FLAG(int, replica_priority); @@ -1208,6 +1213,11 @@ void ServerFamily::Init(util::AcceptServer* acceptor, std::vectorLaunchFiber([this] { TlsReloadScheduling(); }); } void ServerFamily::LoadFromSnapshot() { @@ -1258,6 +1268,9 @@ void ServerFamily::Shutdown() { JoinSnapshotSchedule(); + tls_reload_done_.Notify(); + tls_reload_fb_.JoinIfNeeded(); + bg_save_fb_.JoinIfNeeded(); if (save_on_shutdown_ && !absl::GetFlag(FLAGS_dbfilename).empty()) { @@ -1499,6 +1512,76 @@ void ServerFamily::SnapshotScheduling() { } } +namespace { + +time_t FileMtime(const std::string& path) { + if (path.empty()) + return 0; + struct stat st; + if (stat(path.c_str(), &st) != 0) + return 0; + return st.st_mtim.tv_sec; +} + +} // namespace + +void ServerFamily::TlsReloadScheduling() { + uint32_t interval = absl::GetFlag(FLAGS_tls_reload_interval_secs); + if (interval == 0) + return; + + if (!absl::GetFlag(FLAGS_tls)) { + LOG(WARNING) << "tls_reload_interval_secs is set but TLS is not enabled; " + "ignoring"; + return; + } + + if (interval < 60) { + LOG(WARNING) << "tls_reload_interval_secs must be >= 60, got " << interval << "; clamping"; + interval = 60; + } + + LOG(INFO) << "TLS cert hot reload enabled, checking every " << interval << "s"; + + // Snapshot the initial mtimes so we only reload on change. + time_t last_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_cert_file)); + time_t last_key_mtime = FileMtime(absl::GetFlag(FLAGS_tls_key_file)); + time_t last_ca_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_ca_cert_file)); + + while (true) { + if (tls_reload_done_.WaitFor(chrono::seconds(interval))) { + break; + } + + time_t cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_cert_file)); + time_t key_mtime = FileMtime(absl::GetFlag(FLAGS_tls_key_file)); + time_t ca_cert_mtime = FileMtime(absl::GetFlag(FLAGS_tls_ca_cert_file)); + + if (cert_mtime == last_cert_mtime && key_mtime == last_key_mtime && + ca_cert_mtime == last_ca_cert_mtime) { + continue; + } + + LOG(INFO) << "TLS cert/key file change detected, reloading"; + + bool ok = true; + for (facade::Listener* l : listeners_) { + if (!l->socket()->proactor()->Await([l] { return l->ReconfigureTLS(); })) { + LOG(WARNING) << "TLS hot reload failed on a listener"; + ok = false; + break; + } + } + + if (ok) { + LOG(INFO) << "TLS certs reloaded successfully"; + last_cert_mtime = cert_mtime; + last_key_mtime = key_mtime; + last_ca_cert_mtime = ca_cert_mtime; + } + } +} + std::error_code ServerFamily::LoadRdb(const std::string& rdb_file, LoadExistingKeys existing_keys, LoadOptions* load_opts, RdbLoadContext* load_context, detail::SnapshotStorage* storage) { diff --git a/src/server/server_family.h b/src/server/server_family.h index 8862d7122924..b718dce3e506 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -393,6 +393,11 @@ class ServerFamily { void SnapshotScheduling() ABSL_LOCKS_EXCLUDED(loading_stats_mu_); + // Periodically checks whether the TLS cert/key files on disk have changed + // (by mtime) and calls ReconfigureTLS() on every listener when they have. + // Driven by --tls_reload_interval_secs. Exits immediately if the flag is 0. + void TlsReloadScheduling(); + void SendInvalidationMessages() const; std::optional GetSaveCmdOpts(CmdArgList args, CommandContext* cmd_cntx); @@ -420,6 +425,7 @@ class ServerFamily { void ChangeConnectionAccept(bool accept); util::fb2::Fiber snapshot_schedule_fb_; + util::fb2::Fiber tls_reload_fb_; util::fb2::Fiber load_fiber_; Service& service_; @@ -450,6 +456,7 @@ class ServerFamily { bool save_on_shutdown_{true}; util::fb2::Done schedule_done_; + util::fb2::Done tls_reload_done_; std::unique_ptr fq_threadpool_; std::shared_ptr snapshot_storage_; diff --git a/tests/dragonfly/tls_conf_test.py b/tests/dragonfly/tls_conf_test.py index e7eba74b41db..fb14c06ecf39 100644 --- a/tests/dragonfly/tls_conf_test.py +++ b/tests/dragonfly/tls_conf_test.py @@ -1,3 +1,7 @@ +import asyncio +import os +import shutil + import pytest import redis from .utility import * @@ -164,3 +168,79 @@ async def test_config_disable_tls( # Connecting without TLS should succeed. async with server.client() as client_unauth: await client_unauth.ping() + + +async def test_tls_hot_reload(df_factory, with_tls_ca_cert_args, tmp_dir): + """Verify --tls_reload_interval_secs detects cert file changes on disk. + + 1. Start the server with TLS certs signed by CA-A and a short reload interval. + 2. Overwrite the cert/key files on disk with certs signed by CA-B. + 3. Wait for the reload timer to fire. + 4. New connections with CA-A should fail; connections with CA-B should work. + 5. The existing (pre-reload) connection stays alive throughout. + """ + # Paths the server will read — these get overwritten mid-test. + server_key = os.path.join(tmp_dir, "reload-df-key.pem") + server_cert = os.path.join(tmp_dir, "reload-df-cert.pem") + server_req = os.path.join(tmp_dir, "reload-df-req.pem") + + # Generate initial cert/key from the session CA. + gen_certificate( + with_tls_ca_cert_args["ca_key"], + with_tls_ca_cert_args["ca_cert"], + server_req, + server_key, + server_cert, + ) + + server_args = { + "tls": None, + "tls_key_file": server_key, + "tls_cert_file": server_cert, + "requirepass": "XXX", + "tls_reload_interval_secs": 60, + } + + with df_factory.create(**server_args) as server: + # Establish a connection with the original CA — should work. + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as client: + await client.ping() + + # Generate a brand-new CA and server cert, then overwrite the + # files in-place. The flag paths don't change — only the contents. + new_ca_key = os.path.join(tmp_dir, "reload-ca-key-new.pem") + new_ca_cert = os.path.join(tmp_dir, "reload-ca-cert-new.pem") + gen_ca_cert(new_ca_key, new_ca_cert) + + new_key_tmp = os.path.join(tmp_dir, "reload-df-key-new.pem") + new_cert_tmp = os.path.join(tmp_dir, "reload-df-cert-new.pem") + new_req_tmp = os.path.join(tmp_dir, "reload-df-req-new.pem") + gen_certificate( + new_ca_key, new_ca_cert, new_req_tmp, new_key_tmp, new_cert_tmp + ) + + # Atomic-ish overwrite of the files the server is watching. + shutil.copy2(new_key_tmp, server_key) + shutil.copy2(new_cert_tmp, server_cert) + + # Wait for the reload timer to pick up the mtime change. + await asyncio.sleep(65) + + # Existing connection should still be alive (SSL session unchanged). + await client.ping() + + # New connection with the OLD CA should fail — server now presents + # a cert signed by the new CA. + with pytest.raises(redis.exceptions.ConnectionError): + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=with_tls_ca_cert_args["ca_cert"] + ) as bad_client: + await bad_client.ping() + + # New connection with the NEW CA should succeed. + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=new_ca_cert + ) as new_client: + await new_client.ping()