diff --git a/include/seastar/net/tls.hh b/include/seastar/net/tls.hh index 704a50d5d0..491f8c5357 100644 --- a/include/seastar/net/tls.hh +++ b/include/seastar/net/tls.hh @@ -283,8 +283,10 @@ namespace tls { }; class reloadable_credentials_base; + class credentials_builder; using reload_callback = std::function&, std::exception_ptr)>; + using reload_callback_ex = std::function(const credentials_builder&, const std::unordered_set&, std::exception_ptr)>; /** * Intentionally "primitive", and more importantly, copyable @@ -320,10 +322,16 @@ namespace tls { shared_ptr build_certificate_credentials() const; shared_ptr build_server_credentials() const; + void rebuild(certificate_credentials&) const; + void rebuild(server_credentials&) const; + // same as above, but any files used for certs/keys etc will be watched // for modification and reloaded if changed - future> build_reloadable_certificate_credentials(reload_callback = {}, std::optional tolerance = {}) const; - future> build_reloadable_server_credentials(reload_callback = {}, std::optional tolerance = {}) const; + future> build_reloadable_certificate_credentials(reload_callback_ex = {}, std::optional tolerance = {}) const; + future> build_reloadable_server_credentials(reload_callback_ex = {}, std::optional tolerance = {}) const; + + future> build_reloadable_certificate_credentials(reload_callback, std::optional tolerance = {}) const; + future> build_reloadable_server_credentials(reload_callback, std::optional tolerance = {}) const; private: friend class reloadable_credentials_base; diff --git a/src/net/tls.cc b/src/net/tls.cc index f1e3410d53..0d99ab7234 100644 --- a/src/net/tls.cc +++ b/src/net/tls.cc @@ -796,7 +796,7 @@ class tls::reloadable_credentials_base { public: using time_point = std::chrono::system_clock::time_point; - reloading_builder(credentials_builder b, reload_callback cb, reloadable_credentials_base* creds, delay_type delay) + reloading_builder(credentials_builder b, reload_callback_ex cb, reloadable_credentials_base* creds, delay_type delay) : credentials_builder(std::move(b)) , _cb(std::move(cb)) , _creds(creds) @@ -955,7 +955,7 @@ class tls::reloadable_credentials_base { } void do_callback(std::exception_ptr ep = {}) { if (_cb && !_files.empty()) { - _cb(boost::copy_range>(_files | boost::adaptors::map_keys), std::move(ep)); + _cb(*this, boost::copy_range>(_files | boost::adaptors::map_keys), std::move(ep)).get(); } } // called from seastar::thread @@ -988,7 +988,7 @@ class tls::reloadable_credentials_base { }); } - reload_callback _cb; + reload_callback_ex _cb; reloadable_credentials_base* _creds; fsnotifier _fsn; std::unordered_map> _watches; @@ -997,7 +997,7 @@ class tls::reloadable_credentials_base { timer<> _timer; delay_type _delay; }; - reloadable_credentials_base(credentials_builder builder, reload_callback cb, delay_type delay = default_tolerance) + reloadable_credentials_base(credentials_builder builder, reload_callback_ex cb, delay_type delay = default_tolerance) : _builder(seastar::make_shared(std::move(builder), std::move(cb), this, delay)) { _builder->start(); @@ -1016,7 +1016,7 @@ class tls::reloadable_credentials_base { template class tls::reloadable_credentials : public Base, public tls::reloadable_credentials_base { public: - reloadable_credentials(credentials_builder builder, reload_callback cb, Base b, delay_type delay = default_tolerance) + reloadable_credentials(credentials_builder builder, reload_callback_ex cb, Base b, delay_type delay = default_tolerance) : Base(std::move(b)) , tls::reloadable_credentials_base(std::move(builder), std::move(cb), delay) {} @@ -1025,30 +1025,52 @@ class tls::reloadable_credentials : public Base, public tls::reloadable_credenti template<> void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { - auto tmp = builder.build_certificate_credentials(); - this->_impl = std::move(tmp->_impl); + builder.rebuild(*this); } template<> void tls::reloadable_credentials::rebuild(const credentials_builder& builder) { - auto tmp = builder.build_server_credentials(); - this->_impl = std::move(tmp->_impl); + builder.rebuild(*this); } -future> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional tolerance) const { +void tls::credentials_builder::rebuild(certificate_credentials& creds) const { + auto tmp = build_certificate_credentials(); + creds._impl = std::move(tmp->_impl); +} + +void tls::credentials_builder::rebuild(server_credentials& creds) const { + auto tmp = build_server_credentials(); + creds._impl = std::move(tmp->_impl); +} + +future> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback_ex cb, std::optional tolerance) const { auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_certificate_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); return creds->init().then([creds] { return make_ready_future>(creds); }); } -future> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional tolerance) const { +future> tls::credentials_builder::build_reloadable_server_credentials(reload_callback_ex cb, std::optional tolerance) const { auto creds = seastar::make_shared>(*this, std::move(cb), std::move(*build_server_credentials()), tolerance.value_or(reloadable_credentials_base::default_tolerance)); return creds->init().then([creds] { return make_ready_future>(creds); }); } +future> tls::credentials_builder::build_reloadable_certificate_credentials(reload_callback cb, std::optional tolerance) const { + return build_reloadable_certificate_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set& files, std::exception_ptr p) { + cb(files, p); + return make_ready_future<>(); + }, tolerance); +} + +future> tls::credentials_builder::build_reloadable_server_credentials(reload_callback cb, std::optional tolerance) const { + return build_reloadable_server_credentials([cb = std::move(cb)](const credentials_builder&, const std::unordered_set& files, std::exception_ptr p) { + cb(files, p); + return make_ready_future<>(); + }, tolerance); +} + namespace tls { /** diff --git a/tests/unit/tls_test.cc b/tests/unit/tls_test.cc index 431de2c73c..b2dab764b9 100644 --- a/tests/unit/tls_test.cc +++ b/tests/unit/tls_test.cc @@ -40,6 +40,7 @@ #include #include #include +#include #include @@ -1595,3 +1596,126 @@ SEASTAR_THREAD_TEST_CASE(test_tls13_session_tickets) { } } + +SEASTAR_THREAD_TEST_CASE(test_reload_certificates_with_only_shard0_notify) { + tmpdir tmp; + + namespace fs = std::filesystem; + + // copy the wrong certs. We don't trust these + // blocking calls, but this is a test and seastar does not have a copy + // util and I am lazy... + fs::copy_file(certfile("other.crt"), tmp.path() / "test.crt"); + fs::copy_file(certfile("other.key"), tmp.path() / "test.key"); + + auto cert = (tmp.path() / "test.crt").native(); + auto key = (tmp.path() / "test.key").native(); + promise<> p; + + tls::credentials_builder b; + b.set_x509_key_file(cert, key, tls::x509_crt_format::PEM).get(); + b.set_dh_level(); + + auto certs = b.build_server_credentials(); + + auto shard_1_certs = smp::submit_to(1, [&]() -> future> { + co_return co_await b.build_reloadable_server_credentials([&, changed = std::unordered_set{}](const tls::credentials_builder& builder, const std::unordered_set& files, std::exception_ptr ep) mutable -> future<> { + if (ep) { + co_return; + } + changed.insert(files.begin(), files.end()); + if (changed.count(cert) && changed.count(key)) { + // shard one certs are not reloadable. We issue a reload of them from shard 0 + // - to save inotify instances. + co_await smp::submit_to(0, [&] { + builder.rebuild(*certs); + p.set_value(); + }); + } + }); + }).get(); + + auto def = defer([&]() noexcept { + try { + smp::submit_to(0, [&] { + shard_1_certs = nullptr; + }).get(); + } catch (...) {} + }); + + ::listen_options opts; + opts.reuse_address = true; + auto addr = ::make_ipv4_address( {0x7f000001, 4712}); + auto server = tls::listen(certs, addr, opts); + + tls::credentials_builder b2; + b2.set_x509_trust_file(certfile("catest.pem"), tls::x509_crt_format::PEM).get(); + + { + auto sa = server.accept(); + auto c = tls::connect(b2.build_certificate_credentials(), addr).get(); + auto s = sa.get(); + auto in = s.connection.input(); + + output_stream out(c.output().detach(), 4096); + + try { + out.write("apa").get(); + auto f = out.flush(); + auto f2 = in.read(); + + try { + f.get(); + BOOST_FAIL("should not reach"); + } catch (tls::verification_error&) { + // ok + } + try { + out.close().get(); + } catch (...) { + } + + try { + f2.get(); + BOOST_FAIL("should not reach"); + } catch (...) { + // ok + } + try { + in.close().get(); + } catch (...) { + } + } catch (tls::verification_error&) { + // ok + } + } + + // copy the right (trusted) certs over the old ones. + fs::copy_file(certfile("test.crt"), tmp.path() / "test0.crt"); + fs::copy_file(certfile("test.key"), tmp.path() / "test0.key"); + + rename_file((tmp.path() / "test0.crt").native(), (tmp.path() / "test.crt").native()).get(); + rename_file((tmp.path() / "test0.key").native(), (tmp.path() / "test.key").native()).get(); + + p.get_future().get(); + + // now it should work + { + auto sa = server.accept(); + auto c = tls::connect(b2.build_certificate_credentials(), addr).get(); + auto s = sa.get(); + auto in = s.connection.input(); + + output_stream out(c.output().detach(), 4096); + + out.write("apa").get(); + auto f = out.flush(); + auto buf = in.read().get(); + f.get(); + out.close().get(); + in.read().get(); // ignore - just want eof + in.close().get(); + + BOOST_CHECK_EQUAL(sstring(buf.begin(), buf.end()), "apa"); + } +} \ No newline at end of file