Skip to content

Commit

Permalink
tls: Add optional builder + future-wait to cert reload callback + exp…
Browse files Browse the repository at this point in the history
…ose rebuild

Refs scylladb#2513

Adds a more advanced callback type, which takes both actual reloading builder as
argument (into which new files are loaded), and allows proper future-wait in
callback.

Exposes certificates rebuilding (via builder) to allow "manual", quick, reload of certs.

The point of these seemingly small changes is to allow client code to, for example,
limit actual reloadable_certs (and by extension inotify watches) to shard 0 (or whatever),
and simply use this as a trigger for manual reload of other shards.

Note: we cannot do any magical "shard-0-only" file monitor in the objects themselves,
not without making the certs/builders sharded or similarly stored (which contradict the
general design of light objects, copyable between shards etc). But with this, a calling
app in which certs _are_ held in sharded manners, we can fairly easily delegate non-shard-0
ops in a way that fits that topology.

Note: a builder can be _called_ from any shard (as long as it is safe in its originating
shard), but the objects returned are only valid on the current shard.

Similarly, it is safe to share the reloading builder across shards _in the callback_,
since rebuilding is blocked for the duration of the call.
  • Loading branch information
Calle Wilund committed Dec 11, 2024
1 parent 7068d03 commit 9a0b88f
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 13 deletions.
12 changes: 10 additions & 2 deletions include/seastar/net/tls.hh
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ namespace tls {
};

class reloadable_credentials_base;
class credentials_builder;

using reload_callback = std::function<void(const std::unordered_set<sstring>&, std::exception_ptr)>;
using reload_callback_ex = std::function<future<>(const credentials_builder&, const std::unordered_set<sstring>&, std::exception_ptr)>;

/**
* Intentionally "primitive", and more importantly, copyable
Expand Down Expand Up @@ -320,10 +322,16 @@ namespace tls {
shared_ptr<certificate_credentials> build_certificate_credentials() const;
shared_ptr<server_credentials> build_server_credentials() const;

void rebuild(certificate_credentials&) const;
void rebuild(server_credentials&) const;

// same as above, but any files used for certs/keys etc will be watched
// for modification and reloaded if changed
future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback_ex = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback_ex = {}, std::optional<std::chrono::milliseconds> tolerance = {}) const;

future<shared_ptr<certificate_credentials>> build_reloadable_certificate_credentials(reload_callback, std::optional<std::chrono::milliseconds> tolerance = {}) const;
future<shared_ptr<server_credentials>> build_reloadable_server_credentials(reload_callback, std::optional<std::chrono::milliseconds> tolerance = {}) const;
private:
friend class reloadable_credentials_base;

Expand Down
44 changes: 33 additions & 11 deletions src/net/tls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ class tls::reloadable_credentials_base {
public:
using time_point = std::chrono::system_clock::time_point;

reloading_builder(credentials_builder b, reload_callback cb, reloadable_credentials_base* creds, delay_type delay)
reloading_builder(credentials_builder b, reload_callback_ex cb, reloadable_credentials_base* creds, delay_type delay)
: credentials_builder(std::move(b))
, _cb(std::move(cb))
, _creds(creds)
Expand Down Expand Up @@ -955,7 +955,7 @@ class tls::reloadable_credentials_base {
}
void do_callback(std::exception_ptr ep = {}) {
if (_cb && !_files.empty()) {
_cb(boost::copy_range<std::unordered_set<sstring>>(_files | boost::adaptors::map_keys), std::move(ep));
_cb(*this, boost::copy_range<std::unordered_set<sstring>>(_files | boost::adaptors::map_keys), std::move(ep)).get();
}
}
// called from seastar::thread
Expand Down Expand Up @@ -988,7 +988,7 @@ class tls::reloadable_credentials_base {
});
}

reload_callback _cb;
reload_callback_ex _cb;
reloadable_credentials_base* _creds;
fsnotifier _fsn;
std::unordered_map<fsnotifier::watch_token, std::pair<fsnotifier::watch, sstring>> _watches;
Expand All @@ -997,7 +997,7 @@ class tls::reloadable_credentials_base {
timer<> _timer;
delay_type _delay;
};
reloadable_credentials_base(credentials_builder builder, reload_callback cb, delay_type delay = default_tolerance)
reloadable_credentials_base(credentials_builder builder, reload_callback_ex cb, delay_type delay = default_tolerance)
: _builder(seastar::make_shared<reloading_builder>(std::move(builder), std::move(cb), this, delay))
{
_builder->start();
Expand All @@ -1016,7 +1016,7 @@ class tls::reloadable_credentials_base {
template<typename Base>
class tls::reloadable_credentials : public Base, public tls::reloadable_credentials_base {
public:
reloadable_credentials(credentials_builder builder, reload_callback cb, Base b, delay_type delay = default_tolerance)
reloadable_credentials(credentials_builder builder, reload_callback_ex cb, Base b, delay_type delay = default_tolerance)
: Base(std::move(b))
, tls::reloadable_credentials_base(std::move(builder), std::move(cb), delay)
{}
Expand All @@ -1025,30 +1025,52 @@ class tls::reloadable_credentials : public Base, public tls::reloadable_credenti

template<>
void tls::reloadable_credentials<tls::certificate_credentials>::rebuild(const credentials_builder& builder) {
auto tmp = builder.build_certificate_credentials();
this->_impl = std::move(tmp->_impl);
builder.rebuild(*this);
}

template<>
void tls::reloadable_credentials<tls::server_credentials>::rebuild(const credentials_builder& builder) {
auto tmp = builder.build_server_credentials();
this->_impl = std::move(tmp->_impl);
builder.rebuild(*this);
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
void tls::credentials_builder::rebuild(certificate_credentials& creds) const {
auto tmp = build_certificate_credentials();
creds._impl = std::move(tmp->_impl);
}

void tls::credentials_builder::rebuild(server_credentials& creds) const {
auto tmp = build_server_credentials();
creds._impl = std::move(tmp->_impl);
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback_ex cb, std::optional<std::chrono::milliseconds> tolerance) const {
auto creds = seastar::make_shared<reloadable_credentials<tls::certificate_credentials>>(*this, std::move(cb), std::move(*build_certificate_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance));
return creds->init().then([creds] {
return make_ready_future<shared_ptr<tls::certificate_credentials>>(creds);
});
}

future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback_ex cb, std::optional<std::chrono::milliseconds> tolerance) const {
auto creds = seastar::make_shared<reloadable_credentials<tls::server_credentials>>(*this, std::move(cb), std::move(*build_server_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance));
return creds->init().then([creds] {
return make_ready_future<shared_ptr<tls::server_credentials>>(creds);
});
}

future<shared_ptr<tls::certificate_credentials>> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
return build_reloadable_certificate_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set<sstring>& files, std::exception_ptr p) {
cb(files, p);
return make_ready_future<>();
}, tolerance);
}

future<shared_ptr<tls::server_credentials>> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional<std::chrono::milliseconds> tolerance) const {
return build_reloadable_server_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set<sstring>& files, std::exception_ptr p) {
cb(files, p);
return make_ready_future<>();
}, tolerance);
}

namespace tls {

/**
Expand Down
124 changes: 124 additions & 0 deletions tests/unit/tls_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <seastar/net/inet_address.hh>
#include <seastar/testing/test_case.hh>
#include <seastar/testing/thread_test_case.hh>
#include <seastar/util/defer.hh>

#include <boost/dll.hpp>

Expand Down Expand Up @@ -1595,3 +1596,126 @@ SEASTAR_THREAD_TEST_CASE(test_tls13_session_tickets) {
}

}

SEASTAR_THREAD_TEST_CASE(test_reload_certificates_with_only_shard0_notify) {
tmpdir tmp;

namespace fs = std::filesystem;

// copy the wrong certs. We don't trust these
// blocking calls, but this is a test and seastar does not have a copy
// util and I am lazy...
fs::copy_file(certfile("other.crt"), tmp.path() / "test.crt");
fs::copy_file(certfile("other.key"), tmp.path() / "test.key");

auto cert = (tmp.path() / "test.crt").native();
auto key = (tmp.path() / "test.key").native();
promise<> p;

tls::credentials_builder b;
b.set_x509_key_file(cert, key, tls::x509_crt_format::PEM).get();
b.set_dh_level();

auto certs = b.build_server_credentials();

auto shard_1_certs = smp::submit_to(1, [&]() -> future<shared_ptr<tls::server_credentials>> {
co_return co_await b.build_reloadable_server_credentials([&, changed = std::unordered_set<sstring>{}](const tls::credentials_builder& builder, const std::unordered_set<sstring>& files, std::exception_ptr ep) mutable -> future<> {
if (ep) {
co_return;
}
changed.insert(files.begin(), files.end());
if (changed.count(cert) && changed.count(key)) {
// shard one certs are not reloadable. We issue a reload of them from shard 0
// - to save inotify instances.
co_await smp::submit_to(0, [&] {
builder.rebuild(*certs);
p.set_value();
});
}
});
}).get();

auto def = defer([&]() noexcept {
try {
smp::submit_to(0, [&] {
shard_1_certs = nullptr;
}).get();
} catch (...) {}
});

::listen_options opts;
opts.reuse_address = true;
auto addr = ::make_ipv4_address( {0x7f000001, 4712});
auto server = tls::listen(certs, addr, opts);

tls::credentials_builder b2;
b2.set_x509_trust_file(certfile("catest.pem"), tls::x509_crt_format::PEM).get();

{
auto sa = server.accept();
auto c = tls::connect(b2.build_certificate_credentials(), addr).get();
auto s = sa.get();
auto in = s.connection.input();

output_stream<char> out(c.output().detach(), 4096);

try {
out.write("apa").get();
auto f = out.flush();
auto f2 = in.read();

try {
f.get();
BOOST_FAIL("should not reach");
} catch (tls::verification_error&) {
// ok
}
try {
out.close().get();
} catch (...) {
}

try {
f2.get();
BOOST_FAIL("should not reach");
} catch (...) {
// ok
}
try {
in.close().get();
} catch (...) {
}
} catch (tls::verification_error&) {
// ok
}
}

// copy the right (trusted) certs over the old ones.
fs::copy_file(certfile("test.crt"), tmp.path() / "test0.crt");
fs::copy_file(certfile("test.key"), tmp.path() / "test0.key");

rename_file((tmp.path() / "test0.crt").native(), (tmp.path() / "test.crt").native()).get();
rename_file((tmp.path() / "test0.key").native(), (tmp.path() / "test.key").native()).get();

p.get_future().get();

// now it should work
{
auto sa = server.accept();
auto c = tls::connect(b2.build_certificate_credentials(), addr).get();
auto s = sa.get();
auto in = s.connection.input();

output_stream<char> out(c.output().detach(), 4096);

out.write("apa").get();
auto f = out.flush();
auto buf = in.read().get();
f.get();
out.close().get();
in.read().get(); // ignore - just want eof
in.close().get();

BOOST_CHECK_EQUAL(sstring(buf.begin(), buf.end()), "apa");
}
}

0 comments on commit 9a0b88f

Please sign in to comment.