Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rama-crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ rustls-pki-types = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
x509-parser = { workspace = true }
byteorder = { workspace = true }

[dev-dependencies]
tokio = { workspace = true, features = ["full"] }
Expand Down
69 changes: 67 additions & 2 deletions rama-crypto/src/jose/jwa.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use std::ops::Deref;

use aws_lc_rs::signature::{
ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING, EcdsaSigningAlgorithm,
EcdsaVerificationAlgorithm,
RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_2048_8192_SHA384, RSA_PKCS1_2048_8192_SHA512,
RSA_PSS_2048_8192_SHA256, RSA_PSS_2048_8192_SHA384, RSA_PSS_2048_8192_SHA512,
VerificationAlgorithm,
};
use aws_lc_rs::{
hmac,
hmac::{HMAC_SHA256, HMAC_SHA384, HMAC_SHA512},
signature::{
ECDSA_P256_SHA256_FIXED_SIGNING, ECDSA_P384_SHA384_FIXED_SIGNING, EcdsaSigningAlgorithm,
EcdsaVerificationAlgorithm, RSA_PKCS1_SHA256, RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
RSA_PSS_SHA256, RSA_PSS_SHA384, RSA_PSS_SHA512, RsaEncoding,
},
};
use rama_core::error::OpaqueError;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -99,3 +109,58 @@ impl TryFrom<JWA> for &'static EcdsaVerificationAlgorithm {
Ok(signing_algo.deref())
}
}

impl TryFrom<JWA> for &'static hmac::Algorithm {
type Error = OpaqueError;

fn try_from(value: JWA) -> Result<Self, Self::Error> {
match value {
JWA::HS256 => Ok(&HMAC_SHA256),
JWA::HS384 => Ok(&HMAC_SHA384),
JWA::HS512 => Ok(&HMAC_SHA512),
_ => Err(OpaqueError::from_display(
"Non-Hmac algorithm cannot be converted to hmac types",
)),
}
}
}

impl TryFrom<JWA> for &'static dyn RsaEncoding {
type Error = OpaqueError;

fn try_from(value: JWA) -> Result<Self, Self::Error> {
match value {
JWA::RS256 => Ok(&RSA_PKCS1_SHA256),
JWA::RS384 => Ok(&RSA_PKCS1_SHA384),
JWA::RS512 => Ok(&RSA_PKCS1_SHA512),
JWA::PS256 => Ok(&RSA_PSS_SHA256),
JWA::PS384 => Ok(&RSA_PSS_SHA384),
JWA::PS512 => Ok(&RSA_PSS_SHA512),
_ => Err(OpaqueError::from_display(
"Non-RSA algorithm cannot be converted to rsa types",
)),
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you also add these it should solve the issue you have with converting

impl TryFrom<JWA> for &'static dyn VerificationAlgorithm {
    type Error = OpaqueError;

    fn try_from(value: JWA) -> Result<Self, Self::Error> {
        match value {
            JWA::RS256 => Ok(&RSA_PKCS1_2048_8192_SHA256),

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And once that is is added, the logic from this one, should also be placed in there
impl TryFrom<JWA> for &'static EcdsaVerificationAlgorithm {

Something along these lines:

           JWA::PS512 => Ok(&RSA_PSS_2048_8192_SHA512),
            JWA::ES256 | JWA::ES384 ... => {
                let signing_algo: &'static EcdsaSigningAlgorithm = value.try_into()?;
                Ok(signing_algo.deref())
            }

Otherwise there might be issues with try_into overlapping for the use case of verification


impl TryFrom<JWA> for &'static dyn VerificationAlgorithm {
type Error = OpaqueError;

fn try_from(value: JWA) -> Result<Self, Self::Error> {
match value {
JWA::RS256 => Ok(&RSA_PKCS1_2048_8192_SHA256),
JWA::RS384 => Ok(&RSA_PKCS1_2048_8192_SHA384),
JWA::RS512 => Ok(&RSA_PKCS1_2048_8192_SHA512),
JWA::PS256 => Ok(&RSA_PSS_2048_8192_SHA256),
JWA::PS384 => Ok(&RSA_PSS_2048_8192_SHA384),
JWA::PS512 => Ok(&RSA_PSS_2048_8192_SHA512),
JWA::ES256 | JWA::ES384 | JWA::ES512 => {
let signing_algo: &'static EcdsaSigningAlgorithm = value.try_into()?;
Ok(signing_algo.deref())
}
_ => Err(OpaqueError::from_display(
"Verification algorithm is not supported",
)),
}
}
}
235 changes: 232 additions & 3 deletions rama-crypto/src/jose/jwk.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use aws_lc_rs::encoding::{AsDer, Pkcs8V1Der};
use aws_lc_rs::rsa::KeySize;
use aws_lc_rs::signature::RsaKeyPair;
use aws_lc_rs::{
digest::{Digest, SHA256, digest},
pkcs8::Document,
Expand Down Expand Up @@ -156,7 +159,40 @@ impl JWK {
&self,
) -> Result<signature::UnparsedPublicKey<Vec<u8>>, OpaqueError> {
match &self.key_type {
JWKType::RSA { .. } => Err(OpaqueError::from_display("currently not supported")),
JWKType::RSA { n, e } => {
let n_bytes = BASE64_URL_SAFE_NO_PAD
.decode(n)
.context("decode RSA modulus (n)")?;
let e_bytes = BASE64_URL_SAFE_NO_PAD
.decode(e)
.context("decode RSA exponent (e)")?;
let n_der_encoded = utils::encode_integer(n_bytes);
let e_der_encoded = utils::encode_integer(e_bytes);

let rsa_public_key_sequence =
Self::create_public_key_sequence(n_der_encoded, e_der_encoded);

let rsa_key_len = utils::encode_der_length(rsa_public_key_sequence.len());

let der_rsa_key = Self::create_der_rsa_key(rsa_key_len, rsa_public_key_sequence);

let bit_string_payload = Self::create_bit_string_payload(der_rsa_key);
let bit_string_len = utils::encode_der_length(bit_string_payload.len());

let bit_string = Self::create_bit_string(bit_string_len, bit_string_payload);

let algorithm_identifier = [
0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
0x05, 0x00,
];

let final_sequence = Self::create_final_sequence(algorithm_identifier, bit_string);

Ok(signature::UnparsedPublicKey::new(
self.alg.try_into()?,
final_sequence,
))
}
JWKType::OCT { .. } => Err(OpaqueError::from_display(
"Symmetric key cannot be converted to public key",
)),
Expand All @@ -180,6 +216,70 @@ impl JWK {
}
}
}

/// Creates a new [`JWK`] from a given [`RSAKeyPair`]
pub fn new_from_rsa_key_pair(rsa_key_pair: &RsaKeyPair, alg: JWA) -> Result<Self, OpaqueError> {
let n = rsa_key_pair.public_key().modulus();
let e = rsa_key_pair.public_key().exponent();
Ok(Self {
alg,
key_type: JWKType::RSA {
n: BASE64_URL_SAFE_NO_PAD.encode(n.big_endian_without_leading_zero()),
e: BASE64_URL_SAFE_NO_PAD.encode(e.big_endian_without_leading_zero()),
},

r#use: Some(JWKUse::Signature),
key_ops: None,
x5c: None,
x5t: None,
x5t_sha256: None,
})
}

fn create_public_key_sequence(n_der_encoded: Vec<u8>, e_der_encoded: Vec<u8>) -> Vec<u8> {
let mut rsa_public_key_sequence =
Vec::with_capacity(n_der_encoded.len() + e_der_encoded.len());
rsa_public_key_sequence.extend(n_der_encoded);
rsa_public_key_sequence.extend(e_der_encoded);
rsa_public_key_sequence
}

fn create_der_rsa_key(rsa_key_len: Vec<u8>, rsa_public_key_sequence: Vec<u8>) -> Vec<u8> {
let mut rsa_key_der =
Vec::with_capacity(1 + rsa_key_len.len() + rsa_public_key_sequence.len());
rsa_key_der.push(0x30);
rsa_key_der.extend(rsa_key_len);
rsa_key_der.extend(rsa_public_key_sequence);
rsa_key_der
}

fn create_bit_string_payload(rsa_key_der: Vec<u8>) -> Vec<u8> {
let mut bit_string_payload = Vec::with_capacity(1 + rsa_key_der.len());
bit_string_payload.push(0x00);
bit_string_payload.extend(rsa_key_der);
bit_string_payload
}

fn create_bit_string(bit_string_len: Vec<u8>, bit_string_payload: Vec<u8>) -> Vec<u8> {
let mut bit_string =
Vec::with_capacity(1 + bit_string_len.len() + bit_string_payload.len());
bit_string.push(0x03);
bit_string.extend(bit_string_len);
bit_string.extend(bit_string_payload);
bit_string
}

fn create_final_sequence(algorithm_identifier: [u8; 15], bit_string: Vec<u8>) -> Vec<u8> {
let mut final_sequence = Vec::with_capacity(algorithm_identifier.len() + bit_string.len());
final_sequence.extend(algorithm_identifier);
final_sequence.extend(bit_string);
let final_sequence_len = utils::encode_der_length(final_sequence.len());
let mut result = Vec::with_capacity(1 + final_sequence_len.len() + final_sequence.len());
result.push(0x30);
result.extend(final_sequence_len);
result.extend(final_sequence);
result
}
}

/// [`EcdsaKey`] which is used to identify and authenticate our requests
Expand Down Expand Up @@ -252,7 +352,7 @@ impl EcdsaKey {
}

#[derive(Serialize)]
struct EcdsaKeySigningHeaders<'a> {
struct SigningHeaders<'a> {
alg: JWA,
jwk: &'a JWK,
}
Expand All @@ -267,7 +367,7 @@ impl Signer for EcdsaKey {
_unprotected_headers: &mut super::jws::Headers,
) -> Result<(), Self::Error> {
let jwk = self.create_jwk();
protected_headers.try_set_headers(EcdsaKeySigningHeaders {
protected_headers.try_set_headers(SigningHeaders {
alg: jwk.alg,
jwk: &jwk,
})?;
Expand All @@ -284,6 +384,123 @@ impl Signer for EcdsaKey {
}
}

pub struct RsaKey {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one will also have to import/exported in mod.rs

When you run linting/clippy/tests

just qa

or only linting/clippy without tests

just qq

It should also detect these things, same for unused imports

rng: SystemRandom,
alg: JWA,
inner: RsaKeyPair,
}

impl RsaKey {
/// Create a new [`RsaKey`] from the given [`RsaKeyPair`]
pub fn new(key_pair: RsaKeyPair, alg: JWA, rng: SystemRandom) -> Result<Self, OpaqueError> {
Ok(Self {
rng,
alg,
inner: key_pair,
})
}

/// Generate a new [`RsaKey`] from a newly generated [`RsaKeyPair`]
pub fn generate(key_size: KeySize) -> Result<Self, OpaqueError> {
let key_pair = RsaKeyPair::generate(key_size).context("error generating rsa key pair")?;

Self::new(key_pair, JWA::RS256, SystemRandom::new())
}

/// Generate a new [`RsaKey`] from the given pkcs8 der
pub fn from_pkcs8_der(
pkcs8_der: &[u8],
alg: JWA,
rng: SystemRandom,
) -> Result<Self, OpaqueError> {
let key_pair = RsaKeyPair::from_pkcs8(pkcs8_der).context("create RSAKeyPair from pkcs8")?;

Self::new(key_pair, alg, rng)
}

/// Create pkcs8 der for the current [`RsaKeyPair`]
pub fn pkcs8_der(&self) -> Result<(JWA, Pkcs8V1Der<'static>), OpaqueError> {
let doc = self
.inner
.as_der()
.context("error creating pkcs8 der from rsa keypair")?;
Ok((self.alg, doc))
}

/// Create a [`JWK`] for this [`RsaKey`]
#[must_use]
pub fn create_jwk(&self) -> JWK {
JWK::new_from_rsa_key_pair(&self.inner, self.alg)
.expect("error creating jwa from rsa keypair")
}

#[must_use]
pub fn rng(&self) -> &SystemRandom {
&self.rng
}
}

impl Signer for RsaKey {
type Signature = Vec<u8>;
type Error = OpaqueError;

fn set_headers(
&self,
protected_headers: &mut super::jws::Headers,
_unprotected_headers: &mut super::jws::Headers,
) -> Result<(), Self::Error> {
let jwk = self.create_jwk();
protected_headers.try_set_headers(SigningHeaders {
alg: jwk.alg,
jwk: &jwk,
})?;
Ok(())
}

fn sign(&self, data: &str) -> Result<Self::Signature, Self::Error> {
let mut sig = vec![0; self.inner.public_modulus_len()];
self.inner
.sign(self.alg.try_into()?, self.rng(), data.as_bytes(), &mut sig)
.context("sign protected data")?;
Ok(sig)
}
}

mod utils {
pub(super) fn encode_der_length(len: usize) -> Vec<u8> {
if len < 128 {
vec![len as u8]
} else {
let mut len_bytes = len.to_be_bytes().to_vec();
while len_bytes[0] == 0 {
len_bytes.remove(0);
}
let first_byte = 0x80 | len_bytes.len() as u8;
let mut result = vec![first_byte];
result.extend_from_slice(&len_bytes);
result
}
}

/// This function should only be used for parsing JWK encoded RSA values.
/// The function should **not** be used for general ASN1 encoded values.
/// The function assumes the input is in minimal form, not empty and is a
/// positive integer.
pub(super) fn encode_integer(bytes: Vec<u8>) -> Vec<u8> {
let needs_leading_zero = bytes[0] & 0x80 != 0;
let value_len = bytes.len() + needs_leading_zero as usize;
let len_bytes = encode_der_length(value_len);
let mut result = Vec::with_capacity(1 + len_bytes.len() + value_len);
result.push(0x02);
result.extend_from_slice(&len_bytes);
if needs_leading_zero {
result.push(0);
}
result.extend(bytes);
result
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -322,4 +539,16 @@ mod tests {

assert_eq!(key.create_jwk(), recreated_key.create_jwk())
}

#[test]
fn test_n_and_e_are_base64_encoded() {
let rsa_key_pair = RsaKey::generate(KeySize::Rsa4096).unwrap();
let jwk = JWK::new_from_rsa_key_pair(&rsa_key_pair.inner, JWA::PS512).unwrap();
let (n, e) = match jwk.key_type {
JWKType::RSA { n, e } => (n, e),
_ => panic!("JWK type not RSA"),
};
assert!(BASE64_URL_SAFE_NO_PAD.decode(n).is_ok());
assert!(BASE64_URL_SAFE_NO_PAD.decode(e).is_ok());
}
}
2 changes: 1 addition & 1 deletion rama-crypto/src/jose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mod jwa;
pub use jwa::JWA;

mod jwk;
pub use jwk::{EcdsaKey, JWK, JWKEllipticCurves, JWKType, JWKUse};
pub use jwk::{EcdsaKey, JWK, JWKEllipticCurves, JWKType, JWKUse, RsaKey};

mod jws;
pub use jws::{
Expand Down
Loading