diff --git a/Cargo.lock b/Cargo.lock index 3ff524537..18bbaf5cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2640,6 +2640,7 @@ version = "0.3.0-alpha.3" dependencies = [ "aws-lc-rs", "base64", + "byteorder", "rama-core", "rama-utils", "rcgen", diff --git a/rama-crypto/Cargo.toml b/rama-crypto/Cargo.toml index 13d7f0849..41b242226 100644 --- a/rama-crypto/Cargo.toml +++ b/rama-crypto/Cargo.toml @@ -30,6 +30,7 @@ rustls-pki-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } x509-parser = { workspace = true } +byteorder = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["full"] } diff --git a/rama-crypto/src/jose/jwa.rs b/rama-crypto/src/jose/jwa.rs index b645624d1..f23c15af9 100644 --- a/rama-crypto/src/jose/jwa.rs +++ b/rama-crypto/src/jose/jwa.rs @@ -1,8 +1,18 @@ use std::ops::Deref; use aws_lc_rs::signature::{ - ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING, EcdsaSigningAlgorithm, - EcdsaVerificationAlgorithm, + RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_2048_8192_SHA384, RSA_PKCS1_2048_8192_SHA512, + RSA_PSS_2048_8192_SHA256, RSA_PSS_2048_8192_SHA384, RSA_PSS_2048_8192_SHA512, + VerificationAlgorithm, +}; +use aws_lc_rs::{ + hmac, + hmac::{HMAC_SHA256, HMAC_SHA384, HMAC_SHA512}, + signature::{ + ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING, EcdsaSigningAlgorithm, + EcdsaVerificationAlgorithm, RSA_PKCS1_SHA256, RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, + RSA_PSS_SHA256, RSA_PSS_SHA384, RSA_PSS_SHA512, RsaEncoding, + }, }; use rama_core::error::OpaqueError; use serde::{Deserialize, Serialize}; @@ -99,3 +109,58 @@ impl TryFrom for &'static EcdsaVerificationAlgorithm { Ok(signing_algo.deref()) } } + +impl TryFrom for &'static hmac::Algorithm { + type Error = OpaqueError; + + fn try_from(value: JWA) -> Result { + match value { + JWA::HS256 => Ok(&HMAC_SHA256), + JWA::HS384 => Ok(&HMAC_SHA384), + JWA::HS512 => Ok(&HMAC_SHA512), + _ => Err(OpaqueError::from_display( + "Non-Hmac algorithm cannot be converted to hmac types", + )), + } + } +} + +impl TryFrom for &'static dyn RsaEncoding { + type Error = OpaqueError; + + fn try_from(value: JWA) -> Result { + match value { + JWA::RS256 => Ok(&RSA_PKCS1_SHA256), + JWA::RS384 => Ok(&RSA_PKCS1_SHA384), + JWA::RS512 => Ok(&RSA_PKCS1_SHA512), + JWA::PS256 => Ok(&RSA_PSS_SHA256), + JWA::PS384 => Ok(&RSA_PSS_SHA384), + JWA::PS512 => Ok(&RSA_PSS_SHA512), + _ => Err(OpaqueError::from_display( + "Non-RSA algorithm cannot be converted to rsa types", + )), + } + } +} + +impl TryFrom for &'static dyn VerificationAlgorithm { + type Error = OpaqueError; + + fn try_from(value: JWA) -> Result { + match value { + JWA::RS256 => Ok(&RSA_PKCS1_2048_8192_SHA256), + JWA::RS384 => Ok(&RSA_PKCS1_2048_8192_SHA384), + JWA::RS512 => Ok(&RSA_PKCS1_2048_8192_SHA512), + JWA::PS256 => Ok(&RSA_PSS_2048_8192_SHA256), + JWA::PS384 => Ok(&RSA_PSS_2048_8192_SHA384), + JWA::PS512 => Ok(&RSA_PSS_2048_8192_SHA512), + JWA::ES256 | JWA::ES384 | JWA::ES512 => { + let signing_algo: &'static EcdsaSigningAlgorithm = value.try_into()?; + Ok(signing_algo.deref()) + } + _ => Err(OpaqueError::from_display( + "Verification algorithm is not supported", + )), + } + } +} diff --git a/rama-crypto/src/jose/jwk.rs b/rama-crypto/src/jose/jwk.rs index 0233ebf00..eccc3cc82 100644 --- a/rama-crypto/src/jose/jwk.rs +++ b/rama-crypto/src/jose/jwk.rs @@ -1,3 +1,6 @@ +use aws_lc_rs::encoding::{AsDer, Pkcs8V1Der}; +use aws_lc_rs::rsa::KeySize; +use aws_lc_rs::signature::RsaKeyPair; use aws_lc_rs::{ digest::{Digest, SHA256, digest}, pkcs8::Document, @@ -156,7 +159,40 @@ impl JWK { &self, ) -> Result>, OpaqueError> { match &self.key_type { - JWKType::RSA { .. } => Err(OpaqueError::from_display("currently not supported")), + JWKType::RSA { n, e } => { + let n_bytes = BASE64_URL_SAFE_NO_PAD + .decode(n) + .context("decode RSA modulus (n)")?; + let e_bytes = BASE64_URL_SAFE_NO_PAD + .decode(e) + .context("decode RSA exponent (e)")?; + let n_der_encoded = utils::encode_integer(n_bytes); + let e_der_encoded = utils::encode_integer(e_bytes); + + let rsa_public_key_sequence = + Self::create_public_key_sequence(n_der_encoded, e_der_encoded); + + let rsa_key_len = utils::encode_der_length(rsa_public_key_sequence.len()); + + let der_rsa_key = Self::create_der_rsa_key(rsa_key_len, rsa_public_key_sequence); + + let bit_string_payload = Self::create_bit_string_payload(der_rsa_key); + let bit_string_len = utils::encode_der_length(bit_string_payload.len()); + + let bit_string = Self::create_bit_string(bit_string_len, bit_string_payload); + + let algorithm_identifier = [ + 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, + 0x05, 0x00, + ]; + + let final_sequence = Self::create_final_sequence(algorithm_identifier, bit_string); + + Ok(signature::UnparsedPublicKey::new( + self.alg.try_into()?, + final_sequence, + )) + } JWKType::OCT { .. } => Err(OpaqueError::from_display( "Symmetric key cannot be converted to public key", )), @@ -180,6 +216,70 @@ impl JWK { } } } + + /// Creates a new [`JWK`] from a given [`RSAKeyPair`] + pub fn new_from_rsa_key_pair(rsa_key_pair: &RsaKeyPair, alg: JWA) -> Result { + let n = rsa_key_pair.public_key().modulus(); + let e = rsa_key_pair.public_key().exponent(); + Ok(Self { + alg, + key_type: JWKType::RSA { + n: BASE64_URL_SAFE_NO_PAD.encode(n.big_endian_without_leading_zero()), + e: BASE64_URL_SAFE_NO_PAD.encode(e.big_endian_without_leading_zero()), + }, + + r#use: Some(JWKUse::Signature), + key_ops: None, + x5c: None, + x5t: None, + x5t_sha256: None, + }) + } + + fn create_public_key_sequence(n_der_encoded: Vec, e_der_encoded: Vec) -> Vec { + let mut rsa_public_key_sequence = + Vec::with_capacity(n_der_encoded.len() + e_der_encoded.len()); + rsa_public_key_sequence.extend(n_der_encoded); + rsa_public_key_sequence.extend(e_der_encoded); + rsa_public_key_sequence + } + + fn create_der_rsa_key(rsa_key_len: Vec, rsa_public_key_sequence: Vec) -> Vec { + let mut rsa_key_der = + Vec::with_capacity(1 + rsa_key_len.len() + rsa_public_key_sequence.len()); + rsa_key_der.push(0x30); + rsa_key_der.extend(rsa_key_len); + rsa_key_der.extend(rsa_public_key_sequence); + rsa_key_der + } + + fn create_bit_string_payload(rsa_key_der: Vec) -> Vec { + let mut bit_string_payload = Vec::with_capacity(1 + rsa_key_der.len()); + bit_string_payload.push(0x00); + bit_string_payload.extend(rsa_key_der); + bit_string_payload + } + + fn create_bit_string(bit_string_len: Vec, bit_string_payload: Vec) -> Vec { + let mut bit_string = + Vec::with_capacity(1 + bit_string_len.len() + bit_string_payload.len()); + bit_string.push(0x03); + bit_string.extend(bit_string_len); + bit_string.extend(bit_string_payload); + bit_string + } + + fn create_final_sequence(algorithm_identifier: [u8; 15], bit_string: Vec) -> Vec { + let mut final_sequence = Vec::with_capacity(algorithm_identifier.len() + bit_string.len()); + final_sequence.extend(algorithm_identifier); + final_sequence.extend(bit_string); + let final_sequence_len = utils::encode_der_length(final_sequence.len()); + let mut result = Vec::with_capacity(1 + final_sequence_len.len() + final_sequence.len()); + result.push(0x30); + result.extend(final_sequence_len); + result.extend(final_sequence); + result + } } /// [`EcdsaKey`] which is used to identify and authenticate our requests @@ -252,7 +352,7 @@ impl EcdsaKey { } #[derive(Serialize)] -struct EcdsaKeySigningHeaders<'a> { +struct SigningHeaders<'a> { alg: JWA, jwk: &'a JWK, } @@ -267,7 +367,7 @@ impl Signer for EcdsaKey { _unprotected_headers: &mut super::jws::Headers, ) -> Result<(), Self::Error> { let jwk = self.create_jwk(); - protected_headers.try_set_headers(EcdsaKeySigningHeaders { + protected_headers.try_set_headers(SigningHeaders { alg: jwk.alg, jwk: &jwk, })?; @@ -284,6 +384,123 @@ impl Signer for EcdsaKey { } } +pub struct RsaKey { + rng: SystemRandom, + alg: JWA, + inner: RsaKeyPair, +} + +impl RsaKey { + /// Create a new [`RsaKey`] from the given [`RsaKeyPair`] + pub fn new(key_pair: RsaKeyPair, alg: JWA, rng: SystemRandom) -> Result { + Ok(Self { + rng, + alg, + inner: key_pair, + }) + } + + /// Generate a new [`RsaKey`] from a newly generated [`RsaKeyPair`] + pub fn generate(key_size: KeySize) -> Result { + let key_pair = RsaKeyPair::generate(key_size).context("error generating rsa key pair")?; + + Self::new(key_pair, JWA::RS256, SystemRandom::new()) + } + + /// Generate a new [`RsaKey`] from the given pkcs8 der + pub fn from_pkcs8_der( + pkcs8_der: &[u8], + alg: JWA, + rng: SystemRandom, + ) -> Result { + let key_pair = RsaKeyPair::from_pkcs8(pkcs8_der).context("create RSAKeyPair from pkcs8")?; + + Self::new(key_pair, alg, rng) + } + + /// Create pkcs8 der for the current [`RsaKeyPair`] + pub fn pkcs8_der(&self) -> Result<(JWA, Pkcs8V1Der<'static>), OpaqueError> { + let doc = self + .inner + .as_der() + .context("error creating pkcs8 der from rsa keypair")?; + Ok((self.alg, doc)) + } + + /// Create a [`JWK`] for this [`RsaKey`] + #[must_use] + pub fn create_jwk(&self) -> JWK { + JWK::new_from_rsa_key_pair(&self.inner, self.alg) + .expect("error creating jwa from rsa keypair") + } + + #[must_use] + pub fn rng(&self) -> &SystemRandom { + &self.rng + } +} + +impl Signer for RsaKey { + type Signature = Vec; + type Error = OpaqueError; + + fn set_headers( + &self, + protected_headers: &mut super::jws::Headers, + _unprotected_headers: &mut super::jws::Headers, + ) -> Result<(), Self::Error> { + let jwk = self.create_jwk(); + protected_headers.try_set_headers(SigningHeaders { + alg: jwk.alg, + jwk: &jwk, + })?; + Ok(()) + } + + fn sign(&self, data: &str) -> Result { + let mut sig = vec![0; self.inner.public_modulus_len()]; + self.inner + .sign(self.alg.try_into()?, self.rng(), data.as_bytes(), &mut sig) + .context("sign protected data")?; + Ok(sig) + } +} + +mod utils { + pub(super) fn encode_der_length(len: usize) -> Vec { + if len < 128 { + vec![len as u8] + } else { + let mut len_bytes = len.to_be_bytes().to_vec(); + while len_bytes[0] == 0 { + len_bytes.remove(0); + } + let first_byte = 0x80 | len_bytes.len() as u8; + let mut result = vec![first_byte]; + result.extend_from_slice(&len_bytes); + result + } + } + + /// This function should only be used for parsing JWK encoded RSA values. + /// The function should **not** be used for general ASN1 encoded values. + /// The function assumes the input is in minimal form, not empty and is a + /// positive integer. + pub(super) fn encode_integer(bytes: Vec) -> Vec { + let needs_leading_zero = bytes[0] & 0x80 != 0; + let value_len = bytes.len() + needs_leading_zero as usize; + let len_bytes = encode_der_length(value_len); + let mut result = Vec::with_capacity(1 + len_bytes.len() + value_len); + result.push(0x02); + result.extend_from_slice(&len_bytes); + if needs_leading_zero { + result.push(0); + } + result.extend(bytes); + result + } +} + #[cfg(test)] mod tests { use super::*; @@ -322,4 +539,16 @@ mod tests { assert_eq!(key.create_jwk(), recreated_key.create_jwk()) } + + #[test] + fn test_n_and_e_are_base64_encoded() { + let rsa_key_pair = RsaKey::generate(KeySize::Rsa4096).unwrap(); + let jwk = JWK::new_from_rsa_key_pair(&rsa_key_pair.inner, JWA::PS512).unwrap(); + let (n, e) = match jwk.key_type { + JWKType::RSA { n, e } => (n, e), + _ => panic!("JWK type not RSA"), + }; + assert!(BASE64_URL_SAFE_NO_PAD.decode(n).is_ok()); + assert!(BASE64_URL_SAFE_NO_PAD.decode(e).is_ok()); + } } diff --git a/rama-crypto/src/jose/mod.rs b/rama-crypto/src/jose/mod.rs index b8ce8601b..05d5c467f 100644 --- a/rama-crypto/src/jose/mod.rs +++ b/rama-crypto/src/jose/mod.rs @@ -33,7 +33,7 @@ mod jwa; pub use jwa::JWA; mod jwk; -pub use jwk::{EcdsaKey, JWK, JWKEllipticCurves, JWKType, JWKUse}; +pub use jwk::{EcdsaKey, JWK, JWKEllipticCurves, JWKType, JWKUse, RsaKey}; mod jws; pub use jws::{