diff --git a/Cargo.lock b/Cargo.lock index 3afe4a81..c18bd858 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 3 [[package]] name = "a2" -version = "0.10.0" +version = "0.10.2" dependencies = [ "argparse", "base64", diff --git a/Cargo.toml b/Cargo.toml index 29bdaeeb..8c639b6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "a2" -version = "0.10.0" +version = "0.10.2" authors = [ "Harry Bairstow ", "Julius de Bruijn ", diff --git a/src/client.rs b/src/client.rs index 9da3c8d4..c060d766 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,8 +16,10 @@ use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::client::legacy::Client as HttpClient; use hyper_util::rt::TokioExecutor; +pub use rustls::client::ResolvesClientCert; use std::convert::Infallible; use std::io::Read; +use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; @@ -69,6 +71,9 @@ pub struct ClientConfig { pub request_timeout_secs: Option, /// The timeout for idle sockets being kept alive pub pool_idle_timeout_secs: Option, + /// We use TLS 1.3 by default, setting this to `true` overrides it to use TLS 1.2. + /// Defaults to `false`. + pub use_tls_12_override: bool, } impl Default for ClientConfig { @@ -77,6 +82,7 @@ impl Default for ClientConfig { endpoint: Endpoint::Production, request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), pool_idle_timeout_secs: Some(600), + use_tls_12_override: false, } } } @@ -88,6 +94,14 @@ impl ClientConfig { ..Default::default() } } + + pub fn get_tls_config_builder(&self) -> rustls::ConfigBuilder { + if self.use_tls_12_override { + rustls::client::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS12]) + } else { + rustls::client::ClientConfig::builder() + } + } } #[derive(Debug, Clone)] @@ -130,6 +144,7 @@ impl ClientBuilder { endpoint, request_timeout_secs, pool_idle_timeout_secs, + use_tls_12_override: _, }, signer, connector, @@ -188,7 +203,31 @@ impl Client { let Some((cert, pkey)) = pkcs.cert.zip(pkcs.pkey) else { return Err(Error::InvalidCertificate); }; - let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; + let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?, &config)?; + + Ok(Self::builder().connector(connector).config(config).build()) + } + + /// Create a connection to APNs using the provider client certificate which + /// you obtain from your [Apple developer + /// account](https://developer.apple.com/account/), chosen dynamically via + /// the rustls `ResolvesClientCert` trait. Prefer certificate() over this; use + /// this if you're using a key management service and don't have the private + /// key available. + pub fn certificate_resolver( + client_auth_cert_resolver: Arc, + config: ClientConfig, + ) -> Result { + let tls_config = config + .get_tls_config_builder() + .with_webpki_roots() + .with_client_cert_resolver(client_auth_cert_resolver); + + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); Ok(Self::builder().connector(connector).config(config).build()) } @@ -197,7 +236,7 @@ impl Client { /// key, extracted from the provider client certificate you obtain from your /// [Apple developer account](https://developer.apple.com/account/) pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], config: ClientConfig) -> Result { - let connector = client_cert_connector(cert_pem, key_pem)?; + let connector = client_cert_connector(cert_pem, key_pem, &config)?; Ok(Self::builder().config(config).connector(connector).build()) } @@ -309,7 +348,11 @@ fn default_connector() -> HyperConnector { .build() } -fn client_cert_connector(mut cert_pem: &[u8], mut key_pem: &[u8]) -> Result { +fn client_cert_connector( + mut cert_pem: &[u8], + mut key_pem: &[u8], + client_config: &ClientConfig, +) -> Result { let private_key_error = || io::Error::new(io::ErrorKind::InvalidData, "private key"); let key = rustls_pemfile::pkcs8_private_keys(&mut key_pem) @@ -320,7 +363,8 @@ fn client_cert_connector(mut cert_pem: &[u8], mut key_pem: &[u8]) -> Result, _> = rustls_pemfile::certs(&mut cert_pem).collect(); let cert_chain = cert_chain.map_err(|_| private_key_error())?; - let config = rustls::client::ClientConfig::builder() + let config = client_config + .get_tls_config_builder() .with_webpki_roots() .with_client_auth_cert(cert_chain, key.into())?; diff --git a/src/lib.rs b/src/lib.rs index 513fac04..c87b841b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,6 +131,6 @@ pub use crate::request::notification::{ pub use crate::response::{ErrorBody, ErrorReason, Response}; -pub use crate::client::{Client, ClientConfig, Endpoint}; +pub use crate::client::{Client, ClientConfig, Endpoint, ResolvesClientCert}; pub use crate::error::Error;