Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions linkerd/meshtls/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Ensure that at least one TLS implementation feature is enabled.
static TLS_FEATURES: &[&str] = &["rustls"];
if !TLS_FEATURES
.iter()
.any(|f| std::env::var_os(&*format!("CARGO_FEATURE_{}", f.to_ascii_uppercase())).is_some())
{
return Err(format!(
"at least one of the following TLS implementations must be enabled: '{}'",
TLS_FEATURES.join("', '"),
)
.into());
}

Ok(())
Comment on lines +2 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I think this could be achieved without a build script by just sticking

#![cfg(not(any(feature = "rustls"))]
compile_error!("at least one of the following TLS implementations must be enabled: 'rustls}''")

in lib.rs (and adding the other implementation feature flags as needed)

}
106 changes: 41 additions & 65 deletions linkerd/meshtls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ impl NewService<ClientTls> for NewClient {
type Service = Connect;

fn new_service(&self, target: ClientTls) -> Self::Service {
#[cfg(feature = "rustls")]
if let Self::Rustls(new_client) = self {
return Connect::Rustls(new_client.new_service(target));
match self {
#[cfg(feature = "rustls")]
Self::Rustls(new_client) => Connect::Rustls(new_client.new_service(target)),
}

unreachable!()
}
}

Expand All @@ -61,22 +59,18 @@ where
type Future = ConnectFuture<I>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
#[cfg(feature = "rustls")]
if let Self::Rustls(connect) = self {
return <rustls::Connect as Service<I>>::poll_ready(connect, cx);
match self {
#[cfg(feature = "rustls")]
Self::Rustls(connect) => <rustls::Connect as Service<I>>::poll_ready(connect, cx),
}

unreachable!()
}

#[inline]
fn call(&mut self, io: I) -> Self::Future {
#[cfg(feature = "rustls")]
if let Self::Rustls(connect) = self {
return ConnectFuture::Rustls(connect.call(io));
match self {
#[cfg(feature = "rustls")]
Self::Rustls(connect) => ConnectFuture::Rustls(connect.call(io)),
}

unreachable!()
}
}

Expand All @@ -89,15 +83,13 @@ where
type Output = io::Result<ClientIo<I>>;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ConnectFutureProj::Rustls(f) = this {
let res = futures::ready!(f.poll(cx));
return Poll::Ready(res.map(ClientIo::Rustls));
match self.project() {
#[cfg(feature = "rustls")]
ConnectFutureProj::Rustls(f) => {
let res = futures::ready!(f.poll(cx));
Poll::Ready(res.map(ClientIo::Rustls))
}
}

unreachable!()
}
}

Expand All @@ -110,52 +102,36 @@ impl<I: io::AsyncRead + io::AsyncWrite + Unpin> io::AsyncRead for ClientIo<I> {
cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> io::Poll<()> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ClientIoProj::Rustls(io) = this {
return io.poll_read(cx, buf);
match self.project() {
#[cfg(feature = "rustls")]
ClientIoProj::Rustls(io) => io.poll_read(cx, buf),
}

unreachable!()
}
}

impl<I: io::AsyncRead + io::AsyncWrite + Unpin> io::AsyncWrite for ClientIo<I> {
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ClientIoProj::Rustls(io) = this {
return io.poll_flush(cx);
match self.project() {
#[cfg(feature = "rustls")]
ClientIoProj::Rustls(io) => io.poll_flush(cx),
}

unreachable!()
}

#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ClientIoProj::Rustls(io) = this {
return io.poll_shutdown(cx);
match self.project() {
#[cfg(feature = "rustls")]
ClientIoProj::Rustls(io) => io.poll_shutdown(cx),
}

unreachable!()
}

#[inline]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll<usize> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ClientIoProj::Rustls(io) = this {
return io.poll_write(cx, buf);
match self.project() {
#[cfg(feature = "rustls")]
ClientIoProj::Rustls(io) => io.poll_write(cx, buf),
}

unreachable!()
}

#[inline]
Expand All @@ -164,37 +140,37 @@ impl<I: io::AsyncRead + io::AsyncWrite + Unpin> io::AsyncWrite for ClientIo<I> {
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.project();

#[cfg(feature = "rustls")]
if let ClientIoProj::Rustls(io) = this {
return io.poll_write_vectored(cx, bufs);
match self.project() {
#[cfg(feature = "rustls")]
ClientIoProj::Rustls(io) => io.poll_write_vectored(cx, bufs),
}

unreachable!()
}

#[inline]
fn is_write_vectored(&self) -> bool {
unimplemented!()
match self {
#[cfg(feature = "rustls")]
Self::Rustls(io) => io.is_write_vectored(),
}
}
}

impl<I> HasNegotiatedProtocol for ClientIo<I> {
#[inline]
fn negotiated_protocol(&self) -> Option<NegotiatedProtocolRef<'_>> {
unimplemented!()
match self {
#[cfg(feature = "rustls")]
Self::Rustls(io) => io.negotiated_protocol(),
}
}
}

impl<I: io::PeerAddr> io::PeerAddr for ClientIo<I> {
#[inline]
fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
#[cfg(feature = "rustls")]
if let Self::Rustls(io) = self {
return io.peer_addr();
match self {
#[cfg(feature = "rustls")]
Self::Rustls(io) => io.peer_addr(),
}

unreachable!()
}
}
48 changes: 18 additions & 30 deletions linkerd/meshtls/src/creds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,17 @@ pub enum Receiver {

impl Credentials for Store {
fn dns_name(&self) -> &Name {
#[cfg(feature = "rustls")]
if let Self::Rustls(store) = self {
return store.dns_name();
match self {
#[cfg(feature = "rustls")]
Self::Rustls(store) => store.dns_name(),
}

unreachable!()
}

fn gen_certificate_signing_request(&mut self) -> DerX509 {
#[cfg(feature = "rustls")]
if let Self::Rustls(store) = self {
return store.gen_certificate_signing_request();
match self {
#[cfg(feature = "rustls")]
Self::Rustls(store) => store.gen_certificate_signing_request(),
}

unreachable!()
}

fn set_certificate(
Expand All @@ -43,12 +39,10 @@ impl Credentials for Store {
chain: Vec<DerX509>,
expiry: std::time::SystemTime,
) -> Result<()> {
#[cfg(feature = "rustls")]
if let Self::Rustls(store) = self {
return store.set_certificate(leaf, chain, expiry);
match self {
#[cfg(feature = "rustls")]
Self::Rustls(store) => store.set_certificate(leaf, chain, expiry),
}

unreachable!()
}
}

Expand All @@ -63,29 +57,23 @@ impl From<rustls::creds::Receiver> for Receiver {

impl Receiver {
pub fn name(&self) -> &Name {
#[cfg(feature = "rustls")]
if let Self::Rustls(receiver) = self {
return receiver.name();
match self {
#[cfg(feature = "rustls")]
Self::Rustls(receiver) => receiver.name(),
}

unreachable!()
}

pub fn new_client(&self) -> NewClient {
#[cfg(feature = "rustls")]
if let Self::Rustls(receiver) = self {
return NewClient::Rustls(receiver.new_client());
match self {
#[cfg(feature = "rustls")]
Self::Rustls(receiver) => NewClient::Rustls(receiver.new_client()),
}

unreachable!()
}

pub fn server(&self) -> Server {
#[cfg(feature = "rustls")]
if let Self::Rustls(receiver) = self {
return Server::Rustls(receiver.server());
match self {
#[cfg(feature = "rustls")]
Self::Rustls(receiver) => Server::Rustls(receiver.server()),
}

unreachable!()
}
}
21 changes: 11 additions & 10 deletions linkerd/meshtls/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(irrefutable_let_patterns)]
#![deny(warnings, rust_2018_idioms)]
#![forbid(unsafe_code)]

mod client;
pub mod creds;
Expand Down Expand Up @@ -37,15 +38,15 @@ impl Mode {
key_pkcs8: &[u8],
csr: &[u8],
) -> Result<(creds::Store, creds::Receiver)> {
#[cfg(feature = "rustls")]
if let Self::Rustls = self {
let (store, receiver) = rustls::creds::watch(identity, roots_pem, key_pkcs8, csr)?;
return Ok((
creds::Store::Rustls(store),
creds::Receiver::Rustls(receiver),
));
match self {
#[cfg(feature = "rustls")]
Self::Rustls => {
let (store, receiver) = rustls::creds::watch(identity, roots_pem, key_pkcs8, csr)?;
Ok((
creds::Store::Rustls(store),
creds::Receiver::Rustls(receiver),
))
}
}

unreachable!()
}
}
Loading