Skip to content

Commit

Permalink
Improve SSL upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
aramperes committed Jul 31, 2021
1 parent 137251d commit 60cd6a8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 61 deletions.
47 changes: 0 additions & 47 deletions nut-client/src/blocking/filter.rs

This file was deleted.

23 changes: 9 additions & 14 deletions nut-client/src/blocking/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::io::{BufRead, BufReader, Write};
use std::net::{SocketAddr, TcpStream};

use crate::blocking::filter::ConnectionPipeline;
use crate::blocking::stream::ConnectionStream;
use crate::cmd::{Command, Response};
use crate::{ClientError, Config, Host, NutError, Variable};

mod filter;
mod stream;

/// A blocking NUT client connection.
pub enum Connection {
Expand Down Expand Up @@ -62,7 +62,7 @@ impl Connection {
/// A blocking TCP NUT client connection.
pub struct TcpConnection {
config: Config,
pipeline: ConnectionPipeline,
pipeline: ConnectionStream,
}

impl TcpConnection {
Expand All @@ -71,11 +71,11 @@ impl TcpConnection {
let tcp_stream = TcpStream::connect_timeout(socket_addr, config.timeout)?;
let mut connection = Self {
config,
pipeline: ConnectionPipeline::Tcp(tcp_stream),
pipeline: ConnectionStream::Plain(tcp_stream),
};

// Initialize SSL connection
connection.enable_ssl()?;
connection = connection.enable_ssl()?;

// Attempt login using `config.auth`
connection.login()?;
Expand All @@ -84,7 +84,7 @@ impl TcpConnection {
}

#[cfg(feature = "ssl")]
fn enable_ssl(&mut self) -> crate::Result<()> {
fn enable_ssl(mut self) -> crate::Result<Self> {
if self.config.ssl {
// Send TLS request and check for 'OK'
self.write_cmd(Command::StartTLS)?;
Expand All @@ -110,17 +110,12 @@ impl TcpConnection {
let sess = rustls::ClientSession::new(&std::sync::Arc::new(config), dns_name);

// Wrap and override the TCP stream
let tcp = self
.pipeline
.tcp()
.ok_or_else(|| ClientError::from(NutError::SslNotSupported))?;
let tls = rustls::StreamOwned::new(sess, tcp);
self.pipeline = ConnectionPipeline::Ssl(tls);
self.pipeline = self.pipeline.upgrade_ssl(sess)?;

// Send a test command
self.get_network_version()?;
}
Ok(())
Ok(self)
}

#[cfg(not(feature = "ssl"))]
Expand Down Expand Up @@ -190,7 +185,7 @@ impl TcpConnection {
}

fn parse_line(
reader: &mut BufReader<&mut ConnectionPipeline>,
reader: &mut BufReader<&mut ConnectionStream>,
debug: bool,
) -> crate::Result<Vec<String>> {
let mut raw = String::new();
Expand Down
50 changes: 50 additions & 0 deletions nut-client/src/blocking/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::io::{Read, Write};
use std::net::TcpStream;

/// A wrapper for various synchronous stream types.
pub enum ConnectionStream {
/// A plain TCP stream.
Plain(TcpStream),

/// A stream wrapped with SSL using `rustls`.
#[cfg(feature = "ssl")]
Ssl(Box<rustls::StreamOwned<rustls::ClientSession, ConnectionStream>>),
}

impl ConnectionStream {
/// Wraps the current stream with SSL using `rustls`.
#[cfg(feature = "ssl")]
pub fn upgrade_ssl(self, session: rustls::ClientSession) -> crate::Result<ConnectionStream> {
Ok(ConnectionStream::Ssl(Box::new(rustls::StreamOwned::new(
session, self,
))))
}
}

impl Read for ConnectionStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.read(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.read(buf),
}
}
}

impl Write for ConnectionStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Plain(stream) => stream.write(buf),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.write(buf),
}
}

fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(stream) => stream.flush(),
#[cfg(feature = "ssl")]
Self::Ssl(stream) => stream.flush(),
}
}
}

0 comments on commit 60cd6a8

Please sign in to comment.