diff --git a/src/common/decrypted_read_handler.rs b/src/common/decrypted_read_handler.rs index 1def261f..4239ca65 100644 --- a/src/common/decrypted_read_handler.rs +++ b/src/common/decrypted_read_handler.rs @@ -53,7 +53,9 @@ impl DecryptedReadHandler<'_> { } ServerRecord::ChangeCipherSpec(_) => Err(TlsError::InternalError), ServerRecord::Handshake(ServerHandshake::NewSessionTicket(_)) => { - // Ignore + // TODO: we should validate extensions and abort. We can do this automatically + // as long as the connection is unsplit, however, split connections must be aborted + // by the user. Ok(()) } _ => { diff --git a/src/connection.rs b/src/connection.rs index dfa4adef..fb0ddb6f 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -238,18 +238,18 @@ impl<'a> State { let record = record_reader .read(transport, key_schedule.read_state()) .await?; - process_server_hello(handshake, key_schedule, record) + let result = process_server_hello(handshake, key_schedule, record); + + handle_processing_error(result, transport, key_schedule, tx_buf).await } State::ServerVerify => { - /*info!( - "SIZE of server record queue : {}", - core::mem::size_of_val(&records) - );*/ let record = record_reader .read(transport, key_schedule.read_state()) .await?; - process_server_verify(handshake, key_schedule, config, record) + let result = process_server_verify(handshake, key_schedule, config, record); + + handle_processing_error(result, transport, key_schedule, tx_buf).await } State::ClientCert => { let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?; @@ -296,16 +296,17 @@ impl<'a> State { } State::ServerHello => { let record = record_reader.read_blocking(transport, key_schedule.read_state())?; - process_server_hello(handshake, key_schedule, record) + + let result = process_server_hello(handshake, key_schedule, record); + + handle_processing_error_blocking(result, transport, key_schedule, tx_buf) } State::ServerVerify => { - /*info!( - "SIZE of server record queue : {}", - core::mem::size_of_val(&records) - );*/ let record = record_reader.read_blocking(transport, key_schedule.read_state())?; - process_server_verify(handshake, key_schedule, config, record) + let result = process_server_verify(handshake, key_schedule, config, record); + + handle_processing_error_blocking(result, transport, key_schedule, tx_buf) } State::ClientCert => { let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?; @@ -326,6 +327,29 @@ impl<'a> State { } } +fn handle_processing_error_blocking( + result: Result, + transport: &mut impl BlockingWrite, + key_schedule: &mut KeySchedule, + tx_buf: &mut WriteBuffer, +) -> Result +where + CipherSuite: TlsCipherSuite, +{ + if let Err(TlsError::AbortHandshake(level, description)) = result { + let (write_key_schedule, read_key_schedule) = key_schedule.as_split(); + let tx = tx_buf.write_record( + &ClientRecord::Alert(Alert { level, description }, false), + write_key_schedule, + Some(read_key_schedule), + )?; + + respond_blocking(tx, transport, key_schedule)?; + } + + result +} + fn respond_blocking( tx: &[u8], transport: &mut impl BlockingWrite, @@ -345,6 +369,30 @@ where Ok(()) } +#[cfg(feature = "async")] +async fn handle_processing_error<'a, CipherSuite>( + result: Result, + transport: &mut impl AsyncWrite, + key_schedule: &mut KeySchedule, + tx_buf: &mut WriteBuffer<'a>, +) -> Result +where + CipherSuite: TlsCipherSuite, +{ + if let Err(TlsError::AbortHandshake(level, description)) = result { + let (write_key_schedule, read_key_schedule) = key_schedule.as_split(); + let tx = tx_buf.write_record( + &ClientRecord::Alert(Alert { level, description }, false), + write_key_schedule, + Some(read_key_schedule), + )?; + + respond(tx, transport, key_schedule).await?; + } + + result +} + #[cfg(feature = "async")] async fn respond( tx: &[u8], diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 16a9ddf4..8d9e4a98 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -10,7 +10,7 @@ use crate::supported_versions::ProtocolVersions; use crate::TlsError; use heapless::Vec; -#[derive(Debug)] +#[derive(Debug, PartialEq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ExtensionType { ServerName = 0, @@ -19,7 +19,7 @@ pub enum ExtensionType { SupportedGroups = 10, SignatureAlgorithms = 13, UseSrtp = 14, - Heatbeat = 15, + Heartbeat = 15, ApplicationLayerProtocolNegotiation = 16, SignedCertificateTimestamp = 18, ClientCertificateType = 19, @@ -47,7 +47,7 @@ impl ExtensionType { 10 => Some(Self::SupportedGroups), 13 => Some(Self::SignatureAlgorithms), 14 => Some(Self::UseSrtp), - 15 => Some(Self::Heatbeat), + 15 => Some(Self::Heartbeat), 16 => Some(Self::ApplicationLayerProtocolNegotiation), 18 => Some(Self::SignedCertificateTimestamp), 19 => Some(Self::ClientCertificateType), diff --git a/src/extensions/server.rs b/src/extensions/server.rs index 1db7ecce..ef2b57ed 100644 --- a/src/extensions/server.rs +++ b/src/extensions/server.rs @@ -1,3 +1,4 @@ +use crate::alert::{AlertDescription, AlertLevel}; use crate::extensions::common::KeyShareEntry; use crate::extensions::ExtensionType; use crate::parse_buffer::{ParseBuffer, ParseError}; @@ -11,6 +12,9 @@ pub enum ServerExtension<'a> { SupportedVersion(SupportedVersion), KeyShare(KeyShare<'a>), PreSharedKey(u16), + + SupportedGroups, + ServerName, } #[derive(Debug)] @@ -26,6 +30,29 @@ impl SupportedVersion { } } +pub struct ServerExtensionParserIterator<'a, 'b> { + buffer: &'b mut ParseBuffer<'a>, + allowed: &'b [ExtensionType], +} + +impl<'a, 'b> ServerExtensionParserIterator<'a, 'b> { + pub fn new(buffer: &'b mut ParseBuffer<'a>, allowed: &'b [ExtensionType]) -> Self { + Self { buffer, allowed } + } +} + +impl<'a, 'b> Iterator for ServerExtensionParserIterator<'a, 'b> { + type Item = Result>, TlsError>; + + fn next(&mut self) -> Option { + if self.buffer.is_empty() { + return None; + } + + Some(ServerExtension::parse(&mut self.buffer, &self.allowed)) + } +} + #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct KeyShare<'a>(pub(crate) KeyShareEntry<'a>); @@ -37,75 +64,91 @@ impl<'a> KeyShare<'a> { } impl<'a> ServerExtension<'a> { - pub fn parse_vector( + pub fn parse( buf: &mut ParseBuffer<'a>, - ) -> Result, 16>, TlsError> { - let mut extensions = Vec::new(); + allowed: &[ExtensionType], + ) -> Result>, TlsError> { + let extension_type = + ExtensionType::of(buf.read_u16().map_err(|_| TlsError::UnknownExtensionType)?) + .ok_or(TlsError::UnknownExtensionType)?; - loop { - if buf.is_empty() { - break; - } + trace!("extension type {:?}", extension_type); + + if !allowed.contains(&extension_type) { + warn!( + "{:?} extension is not allowed in this context", + extension_type + ); + + // Section 4.2. Extensions + // If an implementation receives an extension + // which it recognizes and which is not specified for the message in + // which it appears, it MUST abort the handshake with an + // "illegal_parameter" alert. + return Err(TlsError::AbortHandshake( + AlertLevel::Fatal, + AlertDescription::IllegalParameter, + )); + } - let extension_type = - ExtensionType::of(buf.read_u16().map_err(|_| TlsError::UnknownExtensionType)?) - .ok_or(TlsError::UnknownExtensionType)?; - - //info!("extension type {:?}", extension_type); - - let extension_length = buf - .read_u16() - .map_err(|_| TlsError::InvalidExtensionsLength)?; - - //info!("extension length {}", extension_length); - - match extension_type { - ExtensionType::SupportedVersions => { - extensions - .push(ServerExtension::SupportedVersion( - SupportedVersion::parse( - &mut buf - .slice(extension_length as usize) - .map_err(|_| TlsError::InvalidExtensionsLength)?, - ) - .map_err(|_| TlsError::InvalidSupportedVersions)?, - )) - .map_err(|_| TlsError::DecodeError)?; - } - ExtensionType::KeyShare => { - extensions - .push(ServerExtension::KeyShare( - KeyShare::parse( - &mut buf - .slice(extension_length as usize) - .map_err(|_| TlsError::InvalidExtensionsLength)?, - ) - .map_err(|_| TlsError::InvalidKeyShare)?, - )) - .map_err(|_| TlsError::DecodeError)?; - } - ExtensionType::SupportedGroups => { - let _ = buf.slice(extension_length as usize); - } - ExtensionType::ServerName => { - let _ = buf.slice(extension_length as usize); - } - ExtensionType::PreSharedKey => { - let data = buf - .slice(extension_length as usize) - .map_err(|_| TlsError::DecodeError)?; - let data = data.as_slice(); - let value = u16::from_be_bytes([data[0], data[1]]); - extensions - .push(ServerExtension::PreSharedKey(value)) - .map_err(|_| TlsError::DecodeError)?; - } - t => { - info!("Unsupported extension type {:?}", t); - return Err(TlsError::Unimplemented); - } + let extension_length = buf + .read_u16() + .map_err(|_| TlsError::InvalidExtensionsLength)?; + + trace!("extension length {}", extension_length); + + Self::from_type_and_data(extension_type, &mut buf.slice(extension_length as usize)?) + } + + pub fn parse_vector( + buf: &mut ParseBuffer<'a>, + allowed: &[ExtensionType], + ) -> Result, N>, TlsError> { + let extensions_len = buf + .read_u16() + .map_err(|_| TlsError::InvalidExtensionsLength)?; + + let mut ext_buf = buf.slice(extensions_len as usize)?; + + let mut iter = ServerExtensionParserIterator::new(&mut ext_buf, allowed); + + let mut extensions = Vec::new(); + + while let Some(extension) = iter.next() { + if let Some(extension) = extension? { + extensions + .push(extension) + .map_err(|_| TlsError::DecodeError)?; } } + Ok(extensions) } + + fn from_type_and_data<'b>( + extension_type: ExtensionType, + data: &mut ParseBuffer<'b>, + ) -> Result>, TlsError> { + let extension = match extension_type { + ExtensionType::SupportedVersions => ServerExtension::SupportedVersion( + SupportedVersion::parse(data).map_err(|_| TlsError::InvalidSupportedVersions)?, + ), + ExtensionType::KeyShare => ServerExtension::KeyShare( + KeyShare::parse(data).map_err(|_| TlsError::InvalidKeyShare)?, + ), + ExtensionType::PreSharedKey => { + let value = data.read_u16()?; + + ServerExtension::PreSharedKey(value) + } + ExtensionType::SupportedGroups => ServerExtension::SupportedGroups, + ExtensionType::ServerName => ServerExtension::ServerName, + t => { + warn!("Unimplemented extension: {:?}", t); + return Ok(None); + } + }; + + Ok(Some(extension)) + } } diff --git a/src/handshake/certificate.rs b/src/handshake/certificate.rs index e88ca9ed..351f1cce 100644 --- a/src/handshake/certificate.rs +++ b/src/handshake/certificate.rs @@ -1,4 +1,6 @@ use crate::buffer::CryptoBuffer; +use crate::extensions::server::ServerExtension; +use crate::extensions::ExtensionType; use crate::parse_buffer::ParseBuffer; use crate::TlsError; use heapless::Vec; @@ -68,6 +70,12 @@ pub enum CertificateEntryRef<'a> { } impl<'a> CertificateEntryRef<'a> { + // Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CT + const ALLOWED_EXTENSIONS: &[ExtensionType] = &[ + ExtensionType::StatusRequest, + ExtensionType::SignedCertificateTimestamp, + ]; + pub fn parse_vector( buf: &mut ParseBuffer<'a>, ) -> Result, 16>, TlsError> { @@ -88,9 +96,8 @@ impl<'a> CertificateEntryRef<'a> { .push(CertificateEntryRef::X509(cert.as_slice())) .map_err(|_| TlsError::DecodeError)?; - let _extensions_len = buf - .read_u16() - .map_err(|_| TlsError::InvalidExtensionsLength)?; + // Validate extensions + ServerExtension::parse_vector::<2>(buf, Self::ALLOWED_EXTENSIONS)?; if buf.is_empty() { break; diff --git a/src/handshake/certificate_request.rs b/src/handshake/certificate_request.rs index f20ab415..17c52995 100644 --- a/src/handshake/certificate_request.rs +++ b/src/handshake/certificate_request.rs @@ -1,3 +1,5 @@ +use crate::extensions::server::ServerExtension; +use crate::extensions::ExtensionType; use crate::parse_buffer::ParseBuffer; use crate::TlsError; use heapless::Vec; @@ -9,6 +11,16 @@ pub struct CertificateRequestRef<'a> { } impl<'a> CertificateRequestRef<'a> { + // Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CR + const ALLOWED_EXTENSIONS: &[ExtensionType] = &[ + ExtensionType::StatusRequest, + ExtensionType::SignatureAlgorithms, + ExtensionType::SignedCertificateTimestamp, + ExtensionType::CertificateAuthorities, + ExtensionType::OidFilters, + ExtensionType::SignatureAlgorithmsCert, + ]; + pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, TlsError> { let request_context_len = buf .read_u8() @@ -22,12 +34,8 @@ impl<'a> CertificateRequestRef<'a> { .map_err(|_| TlsError::InvalidExtensionsLength)?; //info!("sh 5 {}", extensions_length); - buf.slice(_extensions_length as usize) - .map_err(|_| TlsError::DecodeError)?; - // info!("Cert request parsing"); - // TODO - //let extensions = ServerExtension::parse_vector(buf)?; - //info!("Cert request parsing done"); + // Validate extensions + ServerExtension::parse_vector::<6>(buf, Self::ALLOWED_EXTENSIONS)?; Ok(Self { request_context: request_context.as_slice(), diff --git a/src/handshake/client_hello.rs b/src/handshake/client_hello.rs index f2f3fda5..2e3385b9 100644 --- a/src/handshake/client_hello.rs +++ b/src/handshake/client_hello.rs @@ -66,6 +66,10 @@ where // extensions (1+) buf.with_u16_length(|buf| { + // Section 4.2.1. Supported Versions + // Implementations of this specification MUST send this extension in the + // ClientHello containing all versions of TLS which they are prepared to + // negotiate ClientExtension::SupportedVersions { versions: Vec::from_slice(&[TLS13]).unwrap(), } @@ -92,13 +96,16 @@ where } .encode(buf)?; - if let Some(server_name) = self.config.server_name.as_ref() { + if let Some(server_name) = self.config.server_name { // TODO Add SNI extension ClientExtension::ServerName { server_name }.encode(buf)?; } - // IMPORTANT: The pre shared keys must be encoded last, since we encode the binders - // at a later stage + // Section 4.2 + // When multiple extensions of different types are present, the + // extensions MAY appear in any order, with the exception of + // "pre_shared_key" which MUST be the last extension in + // the ClientHello. if let Some((_, identities)) = &self.config.psk { ClientExtension::PreSharedKey { identities: identities.clone(), diff --git a/src/handshake/encrypted_extensions.rs b/src/handshake/encrypted_extensions.rs index 021b4515..d6160271 100644 --- a/src/handshake/encrypted_extensions.rs +++ b/src/handshake/encrypted_extensions.rs @@ -1,4 +1,5 @@ use crate::extensions::server::ServerExtension; +use crate::extensions::ExtensionType; use crate::parse_buffer::ParseBuffer; use crate::TlsError; @@ -11,14 +12,21 @@ pub struct EncryptedExtensions<'a> { } impl<'a> EncryptedExtensions<'a> { + // Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with EE + const ALLOWED_EXTENSIONS: &[ExtensionType] = &[ + ExtensionType::ServerName, + ExtensionType::MaxFragmentLength, + ExtensionType::SupportedGroups, + ExtensionType::UseSrtp, + ExtensionType::Heartbeat, + ExtensionType::ApplicationLayerProtocolNegotiation, + ExtensionType::ClientCertificateType, + ExtensionType::ServerCertificateType, + ExtensionType::EarlyData, + ]; + pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, TlsError> { - //let extensions_len = u16::from_be_bytes([buf[0], buf[1]]) as usize; - let extensions_len = buf - .read_u16() - .map_err(|_| TlsError::InvalidExtensionsLength)?; - // info!("extensions length: {}", extensions_len); - let extensions = - ServerExtension::parse_vector(&mut buf.slice(extensions_len as usize).unwrap())?; - Ok(Self { extensions }) + ServerExtension::parse_vector(buf, Self::ALLOWED_EXTENSIONS) + .map(|extensions| Self { extensions }) } } diff --git a/src/handshake/new_session_ticket.rs b/src/handshake/new_session_ticket.rs index 8e11f835..8433805d 100644 --- a/src/handshake/new_session_ticket.rs +++ b/src/handshake/new_session_ticket.rs @@ -1,6 +1,7 @@ use heapless::Vec; use crate::extensions::server::ServerExtension; +use crate::extensions::ExtensionType; use crate::parse_buffer::ParseBuffer; use crate::TlsError; @@ -15,6 +16,9 @@ pub struct NewSessionTicket<'a> { } impl<'a> NewSessionTicket<'a> { + // Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with NST + const ALLOWED_EXTENSIONS: &[ExtensionType] = &[ExtensionType::EarlyData]; + pub fn parse(buf: &mut ParseBuffer<'a>) -> Result, TlsError> { let lifetime = buf.read_u32()?; let age_add = buf.read_u32()?; @@ -29,10 +33,7 @@ impl<'a> NewSessionTicket<'a> { .slice(ticket_length as usize) .map_err(|_| TlsError::InvalidTicketLength)?; - let _extensions_length = buf - .read_u16() - .map_err(|_| TlsError::InvalidExtensionsLength)?; - let extensions = ServerExtension::parse_vector(buf)?; + let extensions = ServerExtension::parse_vector(buf, Self::ALLOWED_EXTENSIONS)?; Ok(Self { lifetime, diff --git a/src/handshake/server_hello.rs b/src/handshake/server_hello.rs index bfc8eaa3..c07610d1 100644 --- a/src/handshake/server_hello.rs +++ b/src/handshake/server_hello.rs @@ -4,6 +4,7 @@ use crate::cipher_suites::CipherSuite; use crate::crypto_engine::CryptoEngine; use crate::extensions::common::KeyShareEntry; use crate::extensions::server::ServerExtension; +use crate::extensions::ExtensionType; use crate::handshake::Random; use crate::parse_buffer::ParseBuffer; use crate::TlsError; @@ -21,6 +22,14 @@ pub struct ServerHello<'a> { } impl<'a> ServerHello<'a> { + // Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with SH + const ALLOWED_EXTENSIONS: &[ExtensionType] = &[ + ExtensionType::KeyShare, + ExtensionType::PreSharedKey, + ExtensionType::SupportedVersions, + ExtensionType::PostHandshakeAuth, + ]; + pub fn read(buf: &'a [u8], digest: &mut D) -> Result, TlsError> { //trace!("server hello hash [{:x?}]", &buf[..]); digest.update(buf); @@ -54,14 +63,7 @@ impl<'a> ServerHello<'a> { // skip compression method, it's 0. buf.read_u8()?; - //info!("sh 4"); - let _extensions_length = buf - .read_u16() - .map_err(|_| TlsError::InvalidExtensionsLength)?; - //info!("sh 5 {}", extensions_length); - - let extensions = ServerExtension::parse_vector(buf)?; - //info!("sh 6"); + let extensions = ServerExtension::parse_vector(buf, Self::ALLOWED_EXTENSIONS)?; // info!("server random {:x?}", random); // info!("server session-id {:x?}", session_id.as_slice()); diff --git a/src/lib.rs b/src/lib.rs index a3100765..badb8322 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,7 @@ pub enum TlsError { Unimplemented, MissingHandshake, HandshakeAborted(alert::AlertLevel, alert::AlertDescription), + AbortHandshake(alert::AlertLevel, alert::AlertDescription), IoError, InternalError, InvalidRecord, diff --git a/tests/early_data_test.rs b/tests/early_data_test.rs new file mode 100644 index 00000000..2751f739 --- /dev/null +++ b/tests/early_data_test.rs @@ -0,0 +1,92 @@ +#![macro_use] +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] +#![feature(impl_trait_projections)] +use embedded_io::adapters::FromStd; +use embedded_io::blocking::{Read, Write}; +use rand_core::OsRng; +use std::net::SocketAddr; +use std::sync::Once; + +mod tlsserver; + +static INIT: Once = Once::new(); +static mut ADDR: Option = None; + +fn setup() -> SocketAddr { + use mio::net::TcpListener; + INIT.call_once(|| { + env_logger::init(); + + let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap(); + + let listener = TcpListener::bind(addr).expect("cannot listen on port"); + let addr = listener + .local_addr() + .expect("error retrieving socket address"); + + std::thread::spawn(move || { + use tlsserver::*; + + let versions = &[&rustls::version::TLS13]; + + let test_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests"); + + let certs = load_certs(&test_dir.join("data").join("server-cert.pem")); + let privkey = load_private_key(&test_dir.join("data").join("server-key.pem")); + + let mut config = rustls::ServerConfig::builder() + .with_cipher_suites(rustls::ALL_CIPHER_SUITES) + .with_kx_groups(&rustls::ALL_KX_GROUPS) + .with_protocol_versions(versions) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, privkey) + .unwrap(); + + config.max_early_data_size = 512; + + run_with_config(listener, config); + }); + unsafe { ADDR.replace(addr) }; + }); + unsafe { ADDR.unwrap() } +} + +#[test] +fn early_data_ignored() { + use embedded_tls::blocking::*; + use std::net::TcpStream; + + let addr = setup(); + let pem = include_str!("data/ca-cert.pem"); + let der = pem_parser::pem_to_der(pem); + + let stream = TcpStream::connect(addr).expect("error connecting to server"); + + log::info!("Connected"); + let mut read_record_buffer = [0; 16384]; + let mut write_record_buffer = [0; 16384]; + let config = TlsConfig::new() + .with_ca(Certificate::X509(&der[..])) + .with_server_name("localhost"); + + let mut tls: TlsConnection, Aes128GcmSha256> = TlsConnection::new( + FromStd::new(stream), + &mut read_record_buffer, + &mut write_record_buffer, + ); + + tls.open::(TlsContext::new(&config, &mut OsRng)) + .expect("error establishing TLS connection"); + + tls.write_all(b"ping").expect("Failed to write data"); + tls.flush().expect("Failed to flush"); + + let mut buffer = [0; 4]; + tls.read_exact(&mut buffer).expect("Failed to read data"); + + tls.close() + .map_err(|(_, e)| e) + .expect("error closing session"); +} diff --git a/tests/tlsserver.rs b/tests/tlsserver.rs index 867a98b9..ad6f4973 100644 --- a/tests/tlsserver.rs +++ b/tests/tlsserver.rs @@ -9,18 +9,18 @@ use std::io::{BufReader, Read, Write}; use std::net; // Token for our listening socket. -const LISTENER: mio::Token = mio::Token(0); +pub const LISTENER: mio::Token = mio::Token(0); // Which mode the server operates in. #[derive(Clone)] -enum ServerMode { +pub enum ServerMode { /// Write back received bytes Echo, } /// This binds together a TCP listening socket, some outstanding /// connections, and a TLS server configuration. -struct TlsServer { +pub struct TlsServer { server: TcpListener, connections: HashMap, next_id: usize, @@ -29,7 +29,7 @@ struct TlsServer { } impl TlsServer { - fn new(server: TcpListener, mode: ServerMode, cfg: Arc) -> TlsServer { + pub fn new(server: TcpListener, mode: ServerMode, cfg: Arc) -> TlsServer { TlsServer { server, connections: HashMap::new(), @@ -39,7 +39,7 @@ impl TlsServer { } } - fn accept(&mut self, registry: &mio::Registry) -> Result<(), io::Error> { + pub fn accept(&mut self, registry: &mio::Registry) -> Result<(), io::Error> { loop { match self.server.accept() { Ok((socket, addr)) => { @@ -68,7 +68,7 @@ impl TlsServer { } } - fn conn_event(&mut self, registry: &mio::Registry, event: &mio::event::Event) { + pub fn conn_event(&mut self, registry: &mio::Registry, event: &mio::event::Event) { let token = event.token(); if self.connections.contains_key(&token) { @@ -197,7 +197,6 @@ impl Connection { self.do_tls_write_and_handle_error(); self.closing = true; - } } @@ -270,7 +269,6 @@ impl Connection { if rc.is_err() { log::warn!("write failed {:?}", rc); self.closing = true; - } } @@ -326,7 +324,7 @@ impl Connection { } } -fn load_certs(filename: &PathBuf) -> Vec { +pub fn load_certs(filename: &PathBuf) -> Vec { let certfile = fs::File::open(filename).expect("cannot open certificate file"); let mut reader = BufReader::new(certfile); rustls_pemfile::certs(&mut reader) @@ -336,7 +334,7 @@ fn load_certs(filename: &PathBuf) -> Vec { .collect() } -fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey { +pub fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey { let keyfile = fs::File::open(filename).expect("cannot open private key file"); let mut reader = BufReader::new(keyfile); @@ -355,7 +353,8 @@ fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey { ); } -pub fn run(mut listener: TcpListener) { +#[allow(dead_code)] +pub fn run(listener: TcpListener) { let versions = &[&rustls::version::TLS13]; let test_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests"); @@ -372,6 +371,10 @@ pub fn run(mut listener: TcpListener) { .with_single_cert(certs, privkey) .unwrap(); + run_with_config(listener, config) +} + +pub fn run_with_config(mut listener: TcpListener, config: rustls::ServerConfig) { let mut poll = mio::Poll::new().unwrap(); poll.registry() .register(&mut listener, LISTENER, mio::Interest::READABLE)