Skip to content

Commit 621999f

Browse files
committed
Add caching to metadata retrieval
1 parent 135e8fb commit 621999f

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

ee/tabby-webserver/src/oauth/general.rs

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::{Arc, LazyLock, Mutex};
33
use anyhow::anyhow;
44
use anyhow::Result;
55
use async_trait::async_trait;
6+
use cached::{proc_macro::cached, TimedCache};
67
use openidconnect::{
78
AccessTokenHash,
89
AuthorizationCode,
@@ -64,21 +65,8 @@ impl GeneralClient {
6465
}
6566
}
6667

67-
// TODO: Ensure that the HTTP client *does not* follow redirects.
68-
// TODO: Cache the HTTP response so we do not hit the endpoint every time we need the OIDC Discovery Endpoint
69-
async fn retrieve_provider_metadata(&self, config_url: Option<String>) ->Result<CoreProviderMetadata, anyhow::Error> {
70-
let config_url = config_url.unwrap_or_else(|| "".to_owned());
71-
72-
let client = reqwest::Client::new();
73-
let provider_metadata = CoreProviderMetadata::discover_async(
74-
IssuerUrl::new(config_url).ok().unwrap(),
75-
&client,
76-
).await;
77-
78-
match provider_metadata {
79-
Ok(provider_metadata) => Ok(provider_metadata),
80-
Err(e) => bail!(e),
81-
}
68+
async fn retrieve_provider_metadata( &self, config_url: String) -> Option<CoreProviderMetadata> {
69+
retrieve_provider_metadata(config_url).await
8270
}
8371
}
8472

@@ -97,12 +85,11 @@ impl OAuthClient for GeneralClient {
9785
};
9886

9987
let credential = self.read_credential().await?;
100-
let config_url = credential.config_url;
101-
let provider_metadata = match self.retrieve_provider_metadata(config_url).await
102-
{
103-
Ok(config) => config,
104-
Err(err) => bail!(err),
88+
let config_url = match credential.config_url {
89+
Some(config_url) => config_url,
90+
None => bail!("No config url found."),
10591
};
92+
let provider_metadata = self.retrieve_provider_metadata(config_url).await.unwrap();
10693
let oidc_client = CoreClient::from_provider_metadata(
10794
provider_metadata,
10895
ClientId::new(credential.client_id),
@@ -178,12 +165,11 @@ impl OAuthClient for GeneralClient {
178165

179166
async fn get_authorization_url(&self) -> Result<String> {
180167
let credential = self.read_credential().await?;
181-
let config_url = credential.config_url;
182-
let provider_metadata = match self.retrieve_provider_metadata(config_url).await
183-
{
184-
Ok(config) => config,
185-
Err(err) => bail!(err),
168+
let config_url = match credential.config_url {
169+
Some(config_url) => config_url,
170+
None => bail!("No config url found."),
186171
};
172+
let provider_metadata = self.retrieve_provider_metadata(config_url).await.unwrap();
187173

188174
let redirect_uri = RedirectUrl::new(
189175
self.auth.oauth_callback_url(OAuthProvider::General).await?
@@ -226,3 +212,21 @@ impl OAuthClient for GeneralClient {
226212
Ok(auth_uri.to_string())
227213
}
228214
}
215+
216+
217+
#[cached(
218+
type = "TimedCache<String, Option<CoreProviderMetadata>>",
219+
create = "{ TimedCache::with_lifespan(3600 * 12) }"
220+
)]
221+
async fn retrieve_provider_metadata(config_url: String) -> Option<CoreProviderMetadata> {
222+
let client = reqwest::Client::new();
223+
let provider_metadata = CoreProviderMetadata::discover_async(
224+
IssuerUrl::new(config_url).ok().unwrap(),
225+
&client,
226+
).await;
227+
228+
match provider_metadata {
229+
Ok(provider_metadata) => Some(provider_metadata),
230+
Err(_) => None,
231+
}
232+
}

0 commit comments

Comments
 (0)