Skip to content

Commit

Permalink
fix(auth): sync claim/token times in SA creds (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolduc authored Jan 22, 2025
1 parent 2f601d5 commit e25a672
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 82 deletions.
22 changes: 15 additions & 7 deletions src/auth/src/credentials/service_account_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

use crate::credentials::dynamic::CredentialTrait;
use crate::credentials::util::jws::{JwsClaimsBuilder, JwsHeader, DEFAULT_TOKEN_TIMEOUT};
use crate::credentials::util::jws::{
JwsClaims, JwsHeader, CLOCK_SKEW_FUDGE, DEFAULT_TOKEN_TIMEOUT,
};
use crate::credentials::Result;
use crate::errors::CredentialError;
use crate::token::{Token, TokenProvider};
Expand Down Expand Up @@ -66,11 +68,17 @@ impl TokenProvider for ServiceAccountTokenProvider {
async fn get_token(&self) -> Result<Token> {
let signer = self.signer(&self.service_account_info.private_key)?;

let claims = JwsClaimsBuilder::default()
.iss(self.service_account_info.client_email.as_str())
.scope(DEFAULT_SCOPES)
.build()
.map_err(CredentialError::non_retryable)?;
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let exp = now + DEFAULT_TOKEN_TIMEOUT;
let claims = JwsClaims {
iss: self.service_account_info.client_email.clone(),
scope: Some(DEFAULT_SCOPES.map(|s| s.to_string()).to_vec()),
aud: None,
exp,
iat: now,
typ: None,
sub: None,
};

let encoded_header_claims = format!("{}.{}", DEFAULT_HEADER.encode()?, claims.encode()?);
let sig = signer
Expand All @@ -86,7 +94,7 @@ impl TokenProvider for ServiceAccountTokenProvider {
let token = Token {
token,
token_type: "Bearer".to_string(),
expires_at: Some(OffsetDateTime::now_utc() + DEFAULT_TOKEN_TIMEOUT),
expires_at: Some(exp),
metadata: None,
};
Ok(token)
Expand Down
145 changes: 70 additions & 75 deletions src/auth/src/credentials/util/jws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use crate::credentials::CredentialError;
use crate::credentials::Result;
use derive_builder::Builder;
use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
Expand All @@ -29,32 +28,28 @@ pub(crate) const CLOCK_SKEW_FUDGE: Duration = Duration::from_secs(10);
pub(crate) const DEFAULT_TOKEN_TIMEOUT: Duration = Duration::from_secs(3600);

/// JSON Web Signature for a token.
#[derive(Clone, Serialize, Default, Builder)]
#[builder(setter(into, strip_option), default)]
pub struct JwsClaims<'a> {
pub iss: &'a str,
#[derive(Serialize)]
pub struct JwsClaims {
pub iss: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<Vec<&'a str>>,
pub aud: Option<&'a str>,
#[serde(with = "time::serde::timestamp::option")]
pub exp: Option<OffsetDateTime>,
#[serde(with = "time::serde::timestamp::option")]
pub iat: Option<OffsetDateTime>,
pub scope: Option<Vec<String>>,
pub aud: Option<String>,
#[serde(with = "time::serde::timestamp")]
pub exp: OffsetDateTime,
#[serde(with = "time::serde::timestamp")]
pub iat: OffsetDateTime,
#[serde(skip_serializing_if = "Option::is_none")]
pub typ: Option<&'a str>,
pub typ: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<&'a str>,
pub sub: Option<String>,
}

impl JwsClaims<'_> {
impl JwsClaims {
pub fn encode(&self) -> Result<String> {
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let iat = self.iat.unwrap_or(now);
let exp = self.exp.unwrap_or_else(|| now + DEFAULT_TOKEN_TIMEOUT);
if exp < iat {
if self.exp < self.iat {
return Err(CredentialError::non_retryable(format!(
"expiration time {:?}, must be later than issued time {:?}",
exp, iat
self.exp, self.iat
)));
}

Expand All @@ -65,19 +60,13 @@ impl JwsClaims<'_> {
)));
}

let updated_jws_claim = JwsClaims {
iat: Some(iat),
exp: Some(exp),
..self.clone()
};
use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
let json =
serde_json::to_string(&updated_jws_claim).map_err(CredentialError::non_retryable)?;
let json = serde_json::to_string(&self).map_err(CredentialError::non_retryable)?;
Ok(BASE64_URL_SAFE_NO_PAD.encode(json.as_bytes()))
}
}

/// The header that describes who, what, how a token was created.
/// The header that describes who, what, and how a token was created.
#[derive(Serialize)]
pub struct JwsHeader<'a> {
pub alg: &'a str,
Expand All @@ -101,12 +90,19 @@ mod tests {
use serde_json::Value;

#[test]
fn test_jws_claims_encode_defaults() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.aud("test_aud")
.build()
.unwrap();
fn test_jws_claims_encode_partial() {
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: None,
aud: Some("test_aud".to_string()),
exp: then,
iat: now,
typ: None,
sub: None,
};

let encoded = claims.encode().unwrap();
let decoded = String::from_utf8(
Expand All @@ -116,44 +112,30 @@ mod tests {
)
.unwrap();

// 5 seconds is like 640KiB, good enough for everybody
const TOLERANCE: i64 = 5;
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let expected_iat = now.unix_timestamp();
let expected_iat = (expected_iat - TOLERANCE)..=(expected_iat + TOLERANCE);
let expected_exp = (now + DEFAULT_TOKEN_TIMEOUT).unix_timestamp();
let expected_exp = (expected_exp - TOLERANCE)..=(expected_exp + TOLERANCE);

let v: Value = serde_json::from_str(&decoded).unwrap();
assert_eq!(v["iss"], "test_iss");
assert_eq!(v.get("scope"), None);
assert_eq!(v["aud"], "test_aud");
assert!(
expected_iat.contains(&v["iat"].as_i64().unwrap()),
"The iat field in {v:?} should be in the {expected_iat:?} range"
);
assert!(
expected_exp.contains(&v["exp"].as_i64().unwrap()),
"The exp field in {v:?} should be in the {expected_exp:?} range"
);
assert_eq!(v["iat"], now.unix_timestamp());
assert_eq!(v["exp"], then.unix_timestamp());
assert_eq!(v.get("typ"), None);
assert_eq!(v.get("sub"), None);
}

#[test]
fn test_jws_claims_encode_custom() {
let iat_custom = OffsetDateTime::now_utc() - DEFAULT_TOKEN_TIMEOUT;
let exp_custom = OffsetDateTime::now_utc() + DEFAULT_TOKEN_TIMEOUT;

let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.iat(iat_custom)
.exp(exp_custom)
.typ("test_typ")
.sub("test_sub")
.scope(vec!["scope1", "scope2"])
.build()
.unwrap();
fn test_jws_claims_encode_full() {
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: Some(vec!["scope1".to_string(), "scope2".to_string()]),
aud: None,
exp: then,
iat: now,
typ: Some("test_typ".to_string()),
sub: Some("test_sub".to_string()),
};

let encoded = claims.encode().unwrap();
let decoded = String::from_utf8(
Expand All @@ -167,19 +149,26 @@ mod tests {
assert_eq!(v["iss"], "test_iss");
assert_eq!(v["scope"], serde_json::json!(["scope1", "scope2"]));

assert_eq!(v["iat"], iat_custom.unix_timestamp());
assert_eq!(v["exp"], exp_custom.unix_timestamp());
assert_eq!(v["iat"], now.unix_timestamp());
assert_eq!(v["exp"], then.unix_timestamp());
assert_eq!(v["typ"], "test_typ");
assert_eq!(v["sub"], "test_sub");
}

#[test]
fn test_jws_claims_encode_error_exp_before_iat() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.exp(OffsetDateTime::now_utc() - DEFAULT_TOKEN_TIMEOUT)
.build()
.unwrap();
let now = OffsetDateTime::now_utc();
let then = now - Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: None,
aud: None,
exp: then,
iat: now,
typ: None,
sub: None,
};
let expected_error_message = "must be later than issued time";
assert!(claims
.encode()
Expand All @@ -188,12 +177,18 @@ mod tests {

#[test]
fn test_jws_claims_encode_error_set_scope_and_aud() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.scope(vec!["scope1", "scope2"])
.aud("test_aud")
.build()
.unwrap();
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: Some(vec!["scope".to_string()]),
aud: Some("test-aud".to_string()),
exp: then,
iat: now,
typ: None,
sub: None,
};
let expected_error_message = "expecting only 1 of them to be set";
assert!(claims
.encode()
Expand Down

0 comments on commit e25a672

Please sign in to comment.