Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Require list of expected algorithms when secret/publicKey is given #135

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ fn config_options<'a, 'b>() -> App<'a, 'b> {
.index(1)
.required(true),
).arg(
Arg::with_name("algorithm")
.help("the algorithm to use for signing the JWT")
Arg::with_name("algorithms")
.help("a comma-separated list of algorithms to be used for signature validation. All algorithms need to be of the same family (HMAC, RSA, EC).")
.require_delimiter(true)
.takes_value(true)
.long("alg")
.long("algs")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can probably just stay as alg? Less breaking changes that way

.short("A")
.possible_values(&SupportedAlgorithms::variants())
.default_value("HS256"),
).arg(
Arg::with_name("iso_dates")
.help("display unix timestamps as ISO 8601 dates")
Expand All @@ -264,7 +264,7 @@ fn config_options<'a, 'b>() -> App<'a, 'b> {
.takes_value(true)
.long("secret")
.short("S")
.default_value(""),
.requires("algorithms")
).arg(
Arg::with_name("json")
.help("render decoded JWT as JSON")
Expand Down Expand Up @@ -465,13 +465,6 @@ fn decode_token(
JWTResult<TokenData<Payload>>,
OutputFormat,
) {
let algorithm = translate_algorithm(SupportedAlgorithms::from_string(
matches.value_of("algorithm").unwrap(),
));
let secret = match matches.value_of("secret").map(|s| (s, !s.is_empty())) {
Some((secret, true)) => Some(decoding_key_from_secret(&algorithm, &secret)),
_ => None,
};
let jwt = matches
.value_of("jwt")
.map(|value| {
Expand All @@ -491,13 +484,7 @@ fn decode_token(
.trim()
.to_owned();

let secret_validator = Validation {
leeway: 1000,
algorithms: vec![algorithm],
validate_exp: !matches.is_present("ignore_exp"),
..Default::default()
};

// decode token without signature verification
let token_data = dangerous_insecure_decode::<Payload>(&jwt).map(|mut token| {
if matches.is_present("iso_dates") {
token.claims.convert_timestamps();
Expand All @@ -506,6 +493,31 @@ fn decode_token(
token
});

// get vector of allowed algorithms from command line argument
let algorithms: Vec<Algorithm> = match matches.values_of("algorithms") {
Some(algorithms) => algorithms
.map(|x| translate_algorithm(SupportedAlgorithms::from_string(x)))
.collect(),
None => vec![],
};

let secret_validator = Validation {
leeway: 1000,
algorithms: algorithms,
validate_exp: !matches.is_present("ignore_exp"),
..Default::default()
};

// get the shared secret/public key to be used for signature validation
let secret = match matches.value_of("secret").map(|s| (s, !s.is_empty())) {
Some((secret, true)) => Some(decoding_key_from_secret(
&token_data.as_ref().unwrap().header.alg, // decode key according to algorithm used in the JWT
&secret,
)),
_ => None,
};

// return validated token, non-validated token data and output format
(
match secret {
Some(secret_key) => decode::<Payload>(&jwt, &secret_key.unwrap(), &secret_validator),
Expand Down
196 changes: 187 additions & 9 deletions tests/jwt-cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand Down Expand Up @@ -257,7 +265,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand All @@ -279,7 +295,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, token_data, _) = decode_token(&decode_matches);
Expand All @@ -299,7 +323,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand Down Expand Up @@ -328,7 +360,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand Down Expand Up @@ -356,7 +396,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand All @@ -378,7 +426,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS512",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand All @@ -399,6 +455,8 @@ mod tests {
"decode",
"-S",
"1234567890",
"-A",
"HS256",
"--ignore-exp",
&encoded_token,
])
Expand All @@ -424,7 +482,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand Down Expand Up @@ -460,7 +526,15 @@ mod tests {
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token])
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (decoded_token, _, _) = decode_token(&decode_matches);
Expand Down Expand Up @@ -597,6 +671,108 @@ mod tests {
assert!(result.is_ok());
}

#[test]
fn encodes_and_decodes_a_token_with_multiple_algorithms() {
let body: String = "{\"field\":\"value\"}".to_string();
let encode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"encode",
"-A",
"HS256",
"--exp",
"-S",
"1234567890",
&body,
])
.unwrap();
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256,HS384,HS512",
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (result, _, _) = decode_token(&decode_matches);

assert!(result.is_ok());
}

#[test]
fn encodes_and_decodes_a_token_with_invalid_algorithms_family() {
let body: String = "{\"field\":\"value\"}".to_string();
let encode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"encode",
"-A",
"HS256",
"--exp",
"-S",
"1234567890",
&body,
])
.unwrap();
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"RS256,RS384,RS512", // invalid algorithm family
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (result, _, _) = decode_token(&decode_matches);

assert!(result.is_err());
}

#[test]
fn encodes_and_decodes_a_token_with_mixed_algorithms_family() {
let body: String = "{\"field\":\"value\"}".to_string();
let encode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"encode",
"-A",
"HS256",
"--exp",
"-S",
"1234567890",
&body,
])
.unwrap();
let encode_matches = encode_matcher.subcommand_matches("encode").unwrap();
let encoded_token = encode_token(&encode_matches).unwrap();
let decode_matcher = config_options()
.get_matches_from_safe(vec![
"jwt",
"decode",
"-S",
"1234567890",
"-A",
"HS256,RS512", // algorithms from incompatible algorithm families
&encoded_token,
])
.unwrap();
let decode_matches = decode_matcher.subcommand_matches("decode").unwrap();
let (result, _, _) = decode_token(&decode_matches);

assert!(result.is_err());
}

#[test]
fn encodes_and_decodes_an_rsa_token_using_key_from_file() {
let body: String = "{\"field\":\"value\"}".to_string();
Expand Down Expand Up @@ -705,6 +881,8 @@ mod tests {
"decode",
"-S",
"1234567890",
"-A",
"HS256",
"--iso8601",
&encoded_token,
])
Expand Down