Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): sync claim/token times in SA creds #789

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/auth/src/credentials/service_account_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

use crate::credentials::dynamic::CredentialTrait;
use crate::credentials::util::jws::{JwsClaimsBuilder, JwsHeader, DEFAULT_TOKEN_TIMEOUT};
use crate::credentials::util::jws::{
JwsClaims, JwsHeader, CLOCK_SKEW_FUDGE, DEFAULT_TOKEN_TIMEOUT,
};
use crate::credentials::Result;
use crate::errors::CredentialError;
use crate::token::{Token, TokenProvider};
Expand Down Expand Up @@ -66,11 +68,17 @@ impl TokenProvider for ServiceAccountTokenProvider {
async fn get_token(&self) -> Result<Token> {
let signer = self.signer(&self.service_account_info.private_key)?;

let claims = JwsClaimsBuilder::default()
.iss(self.service_account_info.client_email.as_str())
.scope(DEFAULT_SCOPES)
.build()
.map_err(CredentialError::non_retryable)?;
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let exp = now + DEFAULT_TOKEN_TIMEOUT;
let claims = JwsClaims {
iss: self.service_account_info.client_email.clone(),
scope: Some(DEFAULT_SCOPES.map(|s| s.to_string()).to_vec()),
aud: None,
exp,
iat: now,
typ: None,
sub: None,
};
Comment on lines +71 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, you could create a JwsClaims::new(now: OffsetDateTime) -> Self function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are going to support options to override the scope, so that function will make less sense in the future.

(Also, we should not bother to have typ or sub if they are always None. I assume there will be some code path where they have concrete values in the future)


let encoded_header_claims = format!("{}.{}", DEFAULT_HEADER.encode()?, claims.encode()?);
let sig = signer
Expand All @@ -86,7 +94,7 @@ impl TokenProvider for ServiceAccountTokenProvider {
let token = Token {
token,
token_type: "Bearer".to_string(),
expires_at: Some(OffsetDateTime::now_utc() + DEFAULT_TOKEN_TIMEOUT),
expires_at: Some(exp),
metadata: None,
};
Ok(token)
Expand Down
145 changes: 70 additions & 75 deletions src/auth/src/credentials/util/jws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use crate::credentials::CredentialError;
use crate::credentials::Result;
use derive_builder::Builder;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I think there is an unused dependency in src/auth/Cargo.toml now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still used elsewhere:

use serde::Serialize;
use std::time::Duration;
use time::OffsetDateTime;
Expand All @@ -29,32 +28,28 @@ pub(crate) const CLOCK_SKEW_FUDGE: Duration = Duration::from_secs(10);
pub(crate) const DEFAULT_TOKEN_TIMEOUT: Duration = Duration::from_secs(3600);

/// JSON Web Signature for a token.
#[derive(Clone, Serialize, Default, Builder)]
#[builder(setter(into, strip_option), default)]
pub struct JwsClaims<'a> {
pub iss: &'a str,
#[derive(Serialize)]
pub struct JwsClaims {
pub iss: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the simplicity of using String. It is more expensive, but this function will be called like once per hour so we should not complicate things too much in search of small performance improvements.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack

#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<Vec<&'a str>>,
pub aud: Option<&'a str>,
#[serde(with = "time::serde::timestamp::option")]
pub exp: Option<OffsetDateTime>,
#[serde(with = "time::serde::timestamp::option")]
pub iat: Option<OffsetDateTime>,
pub scope: Option<Vec<String>>,
pub aud: Option<String>,
#[serde(with = "time::serde::timestamp")]
pub exp: OffsetDateTime,
#[serde(with = "time::serde::timestamp")]
pub iat: OffsetDateTime,
#[serde(skip_serializing_if = "Option::is_none")]
pub typ: Option<&'a str>,
pub typ: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sub: Option<&'a str>,
pub sub: Option<String>,
}

impl JwsClaims<'_> {
impl JwsClaims {
pub fn encode(&self) -> Result<String> {
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let iat = self.iat.unwrap_or(now);
let exp = self.exp.unwrap_or_else(|| now + DEFAULT_TOKEN_TIMEOUT);
if exp < iat {
if self.exp < self.iat {
return Err(CredentialError::non_retryable(format!(
"expiration time {:?}, must be later than issued time {:?}",
exp, iat
self.exp, self.iat
)));
}

Expand All @@ -65,19 +60,13 @@ impl JwsClaims<'_> {
)));
}

let updated_jws_claim = JwsClaims {
iat: Some(iat),
exp: Some(exp),
..self.clone()
};
use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
let json =
serde_json::to_string(&updated_jws_claim).map_err(CredentialError::non_retryable)?;
let json = serde_json::to_string(&self).map_err(CredentialError::non_retryable)?;
Ok(BASE64_URL_SAFE_NO_PAD.encode(json.as_bytes()))
}
}

/// The header that describes who, what, how a token was created.
/// The header that describes who, what, and how a token was created.
#[derive(Serialize)]
pub struct JwsHeader<'a> {
pub alg: &'a str,
Expand All @@ -101,12 +90,19 @@ mod tests {
use serde_json::Value;

#[test]
fn test_jws_claims_encode_defaults() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.aud("test_aud")
.build()
.unwrap();
fn test_jws_claims_encode_partial() {
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: None,
aud: Some("test_aud".to_string()),
exp: then,
iat: now,
typ: None,
sub: None,
};

let encoded = claims.encode().unwrap();
let decoded = String::from_utf8(
Expand All @@ -116,44 +112,30 @@ mod tests {
)
.unwrap();

// 5 seconds is like 640KiB, good enough for everybody
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

const TOLERANCE: i64 = 5;
let now = OffsetDateTime::now_utc() - CLOCK_SKEW_FUDGE;
let expected_iat = now.unix_timestamp();
let expected_iat = (expected_iat - TOLERANCE)..=(expected_iat + TOLERANCE);
let expected_exp = (now + DEFAULT_TOKEN_TIMEOUT).unix_timestamp();
let expected_exp = (expected_exp - TOLERANCE)..=(expected_exp + TOLERANCE);

let v: Value = serde_json::from_str(&decoded).unwrap();
assert_eq!(v["iss"], "test_iss");
assert_eq!(v.get("scope"), None);
assert_eq!(v["aud"], "test_aud");
assert!(
expected_iat.contains(&v["iat"].as_i64().unwrap()),
"The iat field in {v:?} should be in the {expected_iat:?} range"
);
assert!(
expected_exp.contains(&v["exp"].as_i64().unwrap()),
"The exp field in {v:?} should be in the {expected_exp:?} range"
);
assert_eq!(v["iat"], now.unix_timestamp());
assert_eq!(v["exp"], then.unix_timestamp());
assert_eq!(v.get("typ"), None);
assert_eq!(v.get("sub"), None);
}

#[test]
fn test_jws_claims_encode_custom() {
let iat_custom = OffsetDateTime::now_utc() - DEFAULT_TOKEN_TIMEOUT;
let exp_custom = OffsetDateTime::now_utc() + DEFAULT_TOKEN_TIMEOUT;

let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.iat(iat_custom)
.exp(exp_custom)
.typ("test_typ")
.sub("test_sub")
.scope(vec!["scope1", "scope2"])
.build()
.unwrap();
fn test_jws_claims_encode_full() {
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: Some(vec!["scope1".to_string(), "scope2".to_string()]),
aud: None,
exp: then,
iat: now,
typ: Some("test_typ".to_string()),
sub: Some("test_sub".to_string()),
};

let encoded = claims.encode().unwrap();
let decoded = String::from_utf8(
Expand All @@ -167,19 +149,26 @@ mod tests {
assert_eq!(v["iss"], "test_iss");
assert_eq!(v["scope"], serde_json::json!(["scope1", "scope2"]));

assert_eq!(v["iat"], iat_custom.unix_timestamp());
assert_eq!(v["exp"], exp_custom.unix_timestamp());
assert_eq!(v["iat"], now.unix_timestamp());
assert_eq!(v["exp"], then.unix_timestamp());
assert_eq!(v["typ"], "test_typ");
assert_eq!(v["sub"], "test_sub");
}

#[test]
fn test_jws_claims_encode_error_exp_before_iat() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.exp(OffsetDateTime::now_utc() - DEFAULT_TOKEN_TIMEOUT)
.build()
.unwrap();
let now = OffsetDateTime::now_utc();
let then = now - Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: None,
aud: None,
exp: then,
iat: now,
typ: None,
sub: None,
};
let expected_error_message = "must be later than issued time";
assert!(claims
.encode()
Expand All @@ -188,12 +177,18 @@ mod tests {

#[test]
fn test_jws_claims_encode_error_set_scope_and_aud() {
let claims = JwsClaimsBuilder::default()
.iss("test_iss")
.scope(vec!["scope1", "scope2"])
.aud("test_aud")
.build()
.unwrap();
let now = OffsetDateTime::now_utc();
let then = now + Duration::from_secs(4200);

let claims = JwsClaims {
iss: "test_iss".to_string(),
scope: Some(vec!["scope".to_string()]),
aud: Some("test-aud".to_string()),
exp: then,
iat: now,
typ: None,
sub: None,
};
let expected_error_message = "expecting only 1 of them to be set";
assert!(claims
.encode()
Expand Down
Loading