diff --git a/crates/arroyo-storage/src/aws.rs b/crates/arroyo-storage/src/aws.rs index f7701a748..92fba14bf 100644 --- a/crates/arroyo-storage/src/aws.rs +++ b/crates/arroyo-storage/src/aws.rs @@ -15,7 +15,7 @@ type TemporaryToken = (Arc, Option, Instant); #[derive(Clone)] pub struct ArroyoCredentialProvider { - cache: Arc>>, + cache: Arc>, provider: SharedCredentialsProvider, refresh_task: Arc>>>, } @@ -63,8 +63,12 @@ impl ArroyoCredentialProvider { .clone(); info!("Creating credential provider"); + + let initial_token = get_token(&credentials).await?; + let cache = Arc::new(Mutex::new(initial_token)); + Ok::(Self { - cache: Default::default(), + cache, refresh_task: Default::default(), provider: credentials, }) @@ -105,20 +109,21 @@ impl CredentialProvider for ArroyoCredentialProvider { type Credential = AwsCredential; async fn get_credential(&self) -> object_store::Result> { - let token = self.cache.lock().await.clone(); + let token = { self.cache.lock().unwrap().clone() }; + match token { - Some((token, Some(expiration), last_refreshed)) => { + (token, Some(expiration), last_refreshed) => { let expires_in = expiration .duration_since(SystemTime::now()) .unwrap_or_default(); if expires_in < Duration::from_millis(100) { info!("AWS token has expired, immediately refreshing"); - let lock = self.cache.try_lock(); let token = get_token(&self.provider).await?; + let lock = self.cache.try_lock(); if let Ok(mut lock) = lock { - *lock = Some(token.clone()); + *lock = token.clone(); } return Ok(token.0); } @@ -141,26 +146,15 @@ impl CredentialProvider for ArroyoCredentialProvider { .await .unwrap_or_else(|e| panic!("Failed to refresh AWS token: {:?}", e)); - let mut lock = our_lock.lock().await; - *lock = Some(token); + let mut lock = our_lock.lock().unwrap(); + *lock = token; })); } } Ok(token) } - Some((token, None, _)) => Ok(token), - None => { - // get the initial token - let mut cache = self.cache.lock().await; - if let Some((token, _, _)) = &*cache { - return Ok(token.clone()); - } - - let token = get_token(&self.provider).await?; - *cache = Some(token.clone()); - Ok(token.0) - } + (token, None, _) => Ok(token), } } }