-
-
Notifications
You must be signed in to change notification settings - Fork 76
implement jwa algorithms #650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
use aws_lc_rs::encoding::{AsDer, Pkcs8V1Der}; | ||
use aws_lc_rs::rsa::KeySize; | ||
use aws_lc_rs::signature::RsaKeyPair; | ||
use aws_lc_rs::{ | ||
digest::{Digest, SHA256, digest}, | ||
pkcs8::Document, | ||
|
@@ -156,7 +159,40 @@ impl JWK { | |
&self, | ||
) -> Result<signature::UnparsedPublicKey<Vec<u8>>, OpaqueError> { | ||
match &self.key_type { | ||
JWKType::RSA { .. } => Err(OpaqueError::from_display("currently not supported")), | ||
JWKType::RSA { n, e } => { | ||
let n_bytes = BASE64_URL_SAFE_NO_PAD | ||
.decode(n) | ||
.context("decode RSA modulus (n)")?; | ||
let e_bytes = BASE64_URL_SAFE_NO_PAD | ||
.decode(e) | ||
.context("decode RSA exponent (e)")?; | ||
let n_der_encoded = utils::encode_integer(n_bytes); | ||
let e_der_encoded = utils::encode_integer(e_bytes); | ||
|
||
let rsa_public_key_sequence = | ||
Self::create_public_key_sequence(n_der_encoded, e_der_encoded); | ||
|
||
let rsa_key_len = utils::encode_der_length(rsa_public_key_sequence.len()); | ||
|
||
let der_rsa_key = Self::create_der_rsa_key(rsa_key_len, rsa_public_key_sequence); | ||
|
||
let bit_string_payload = Self::create_bit_string_payload(der_rsa_key); | ||
let bit_string_len = utils::encode_der_length(bit_string_payload.len()); | ||
|
||
let bit_string = Self::create_bit_string(bit_string_len, bit_string_payload); | ||
|
||
let algorithm_identifier = [ | ||
0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, | ||
0x05, 0x00, | ||
]; | ||
|
||
let final_sequence = Self::create_final_sequence(algorithm_identifier, bit_string); | ||
|
||
Ok(signature::UnparsedPublicKey::new( | ||
self.alg.try_into()?, | ||
final_sequence, | ||
)) | ||
} | ||
JWKType::OCT { .. } => Err(OpaqueError::from_display( | ||
"Symmetric key cannot be converted to public key", | ||
)), | ||
|
@@ -180,6 +216,70 @@ impl JWK { | |
} | ||
} | ||
} | ||
|
||
/// Creates a new [`JWK`] from a given [`RSAKeyPair`] | ||
pub fn new_from_rsa_key_pair(rsa_key_pair: &RsaKeyPair, alg: JWA) -> Result<Self, OpaqueError> { | ||
let n = rsa_key_pair.public_key().modulus(); | ||
let e = rsa_key_pair.public_key().exponent(); | ||
Ok(Self { | ||
alg, | ||
key_type: JWKType::RSA { | ||
n: BASE64_URL_SAFE_NO_PAD.encode(n.big_endian_without_leading_zero()), | ||
e: BASE64_URL_SAFE_NO_PAD.encode(e.big_endian_without_leading_zero()), | ||
}, | ||
|
||
r#use: Some(JWKUse::Signature), | ||
key_ops: None, | ||
x5c: None, | ||
x5t: None, | ||
x5t_sha256: None, | ||
}) | ||
} | ||
|
||
fn create_public_key_sequence(n_der_encoded: Vec<u8>, e_der_encoded: Vec<u8>) -> Vec<u8> { | ||
let mut rsa_public_key_sequence = | ||
Vec::with_capacity(n_der_encoded.len() + e_der_encoded.len()); | ||
rsa_public_key_sequence.extend(n_der_encoded); | ||
rsa_public_key_sequence.extend(e_der_encoded); | ||
rsa_public_key_sequence | ||
} | ||
|
||
fn create_der_rsa_key(rsa_key_len: Vec<u8>, rsa_public_key_sequence: Vec<u8>) -> Vec<u8> { | ||
let mut rsa_key_der = | ||
Vec::with_capacity(1 + rsa_key_len.len() + rsa_public_key_sequence.len()); | ||
rsa_key_der.push(0x30); | ||
rsa_key_der.extend(rsa_key_len); | ||
rsa_key_der.extend(rsa_public_key_sequence); | ||
rsa_key_der | ||
} | ||
|
||
fn create_bit_string_payload(rsa_key_der: Vec<u8>) -> Vec<u8> { | ||
let mut bit_string_payload = Vec::with_capacity(1 + rsa_key_der.len()); | ||
bit_string_payload.push(0x00); | ||
bit_string_payload.extend(rsa_key_der); | ||
bit_string_payload | ||
} | ||
|
||
fn create_bit_string(bit_string_len: Vec<u8>, bit_string_payload: Vec<u8>) -> Vec<u8> { | ||
let mut bit_string = | ||
Vec::with_capacity(1 + bit_string_len.len() + bit_string_payload.len()); | ||
bit_string.push(0x03); | ||
bit_string.extend(bit_string_len); | ||
bit_string.extend(bit_string_payload); | ||
bit_string | ||
} | ||
|
||
fn create_final_sequence(algorithm_identifier: [u8; 15], bit_string: Vec<u8>) -> Vec<u8> { | ||
let mut final_sequence = Vec::with_capacity(algorithm_identifier.len() + bit_string.len()); | ||
final_sequence.extend(algorithm_identifier); | ||
final_sequence.extend(bit_string); | ||
let final_sequence_len = utils::encode_der_length(final_sequence.len()); | ||
let mut result = Vec::with_capacity(1 + final_sequence_len.len() + final_sequence.len()); | ||
result.push(0x30); | ||
result.extend(final_sequence_len); | ||
result.extend(final_sequence); | ||
result | ||
} | ||
} | ||
|
||
/// [`EcdsaKey`] which is used to identify and authenticate our requests | ||
|
@@ -252,7 +352,7 @@ impl EcdsaKey { | |
} | ||
|
||
#[derive(Serialize)] | ||
struct EcdsaKeySigningHeaders<'a> { | ||
struct SigningHeaders<'a> { | ||
alg: JWA, | ||
jwk: &'a JWK, | ||
} | ||
|
@@ -267,7 +367,7 @@ impl Signer for EcdsaKey { | |
_unprotected_headers: &mut super::jws::Headers, | ||
) -> Result<(), Self::Error> { | ||
let jwk = self.create_jwk(); | ||
protected_headers.try_set_headers(EcdsaKeySigningHeaders { | ||
protected_headers.try_set_headers(SigningHeaders { | ||
alg: jwk.alg, | ||
jwk: &jwk, | ||
})?; | ||
|
@@ -284,6 +384,123 @@ impl Signer for EcdsaKey { | |
} | ||
} | ||
|
||
pub struct RsaKey { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one will also have to import/exported in When you run linting/clippy/tests
or only linting/clippy without tests
It should also detect these things, same for unused imports |
||
rng: SystemRandom, | ||
alg: JWA, | ||
inner: RsaKeyPair, | ||
} | ||
|
||
impl RsaKey { | ||
/// Create a new [`RsaKey`] from the given [`RsaKeyPair`] | ||
pub fn new(key_pair: RsaKeyPair, alg: JWA, rng: SystemRandom) -> Result<Self, OpaqueError> { | ||
Ok(Self { | ||
rng, | ||
alg, | ||
inner: key_pair, | ||
}) | ||
} | ||
|
||
/// Generate a new [`RsaKey`] from a newly generated [`RsaKeyPair`] | ||
pub fn generate(key_size: KeySize) -> Result<Self, OpaqueError> { | ||
let key_pair = RsaKeyPair::generate(key_size).context("error generating rsa key pair")?; | ||
|
||
Self::new(key_pair, JWA::RS256, SystemRandom::new()) | ||
} | ||
|
||
/// Generate a new [`RsaKey`] from the given pkcs8 der | ||
pub fn from_pkcs8_der( | ||
pkcs8_der: &[u8], | ||
alg: JWA, | ||
rng: SystemRandom, | ||
) -> Result<Self, OpaqueError> { | ||
let key_pair = RsaKeyPair::from_pkcs8(pkcs8_der).context("create RSAKeyPair from pkcs8")?; | ||
|
||
Self::new(key_pair, alg, rng) | ||
} | ||
|
||
/// Create pkcs8 der for the current [`RsaKeyPair`] | ||
pub fn pkcs8_der(&self) -> Result<(JWA, Pkcs8V1Der<'static>), OpaqueError> { | ||
let doc = self | ||
.inner | ||
.as_der() | ||
.context("error creating pkcs8 der from rsa keypair")?; | ||
Ok((self.alg, doc)) | ||
} | ||
|
||
/// Create a [`JWK`] for this [`RsaKey`] | ||
#[must_use] | ||
pub fn create_jwk(&self) -> JWK { | ||
JWK::new_from_rsa_key_pair(&self.inner, self.alg) | ||
.expect("error creating jwa from rsa keypair") | ||
} | ||
|
||
#[must_use] | ||
pub fn rng(&self) -> &SystemRandom { | ||
&self.rng | ||
} | ||
} | ||
|
||
impl Signer for RsaKey { | ||
type Signature = Vec<u8>; | ||
type Error = OpaqueError; | ||
|
||
fn set_headers( | ||
&self, | ||
protected_headers: &mut super::jws::Headers, | ||
_unprotected_headers: &mut super::jws::Headers, | ||
) -> Result<(), Self::Error> { | ||
let jwk = self.create_jwk(); | ||
protected_headers.try_set_headers(SigningHeaders { | ||
alg: jwk.alg, | ||
jwk: &jwk, | ||
})?; | ||
Ok(()) | ||
} | ||
|
||
fn sign(&self, data: &str) -> Result<Self::Signature, Self::Error> { | ||
let mut sig = vec![0; self.inner.public_modulus_len()]; | ||
self.inner | ||
.sign(self.alg.try_into()?, self.rng(), data.as_bytes(), &mut sig) | ||
.context("sign protected data")?; | ||
Ok(sig) | ||
} | ||
} | ||
|
||
mod utils { | ||
pub(super) fn encode_der_length(len: usize) -> Vec<u8> { | ||
if len < 128 { | ||
vec![len as u8] | ||
} else { | ||
let mut len_bytes = len.to_be_bytes().to_vec(); | ||
while len_bytes[0] == 0 { | ||
len_bytes.remove(0); | ||
} | ||
let first_byte = 0x80 | len_bytes.len() as u8; | ||
let mut result = vec![first_byte]; | ||
result.extend_from_slice(&len_bytes); | ||
result | ||
} | ||
} | ||
|
||
/// This function should only be used for parsing JWK encoded RSA values. | ||
/// The function should **not** be used for general ASN1 encoded values. | ||
/// The function assumes the input is in minimal form, not empty and is a | ||
/// positive integer. | ||
pub(super) fn encode_integer(bytes: Vec<u8>) -> Vec<u8> { | ||
let needs_leading_zero = bytes[0] & 0x80 != 0; | ||
let value_len = bytes.len() + needs_leading_zero as usize; | ||
let len_bytes = encode_der_length(value_len); | ||
let mut result = Vec::with_capacity(1 + len_bytes.len() + value_len); | ||
result.push(0x02); | ||
result.extend_from_slice(&len_bytes); | ||
if needs_leading_zero { | ||
result.push(0); | ||
} | ||
result.extend(bytes); | ||
result | ||
soundofspace marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
@@ -322,4 +539,16 @@ mod tests { | |
|
||
assert_eq!(key.create_jwk(), recreated_key.create_jwk()) | ||
} | ||
|
||
#[test] | ||
fn test_n_and_e_are_base64_encoded() { | ||
let rsa_key_pair = RsaKey::generate(KeySize::Rsa4096).unwrap(); | ||
let jwk = JWK::new_from_rsa_key_pair(&rsa_key_pair.inner, JWA::PS512).unwrap(); | ||
let (n, e) = match jwk.key_type { | ||
JWKType::RSA { n, e } => (n, e), | ||
_ => panic!("JWK type not RSA"), | ||
}; | ||
assert!(BASE64_URL_SAFE_NO_PAD.decode(n).is_ok()); | ||
assert!(BASE64_URL_SAFE_NO_PAD.decode(e).is_ok()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you also add these it should solve the issue you have with converting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And once that is is added, the logic from this one, should also be placed in there
impl TryFrom<JWA> for &'static EcdsaVerificationAlgorithm {
Something along these lines:
Otherwise there might be issues with try_into overlapping for the use case of verification