1
- use std:: sync:: { Arc , LazyLock , Mutex } ;
1
+ use std:: {
2
+ collections:: HashMap ,
3
+ sync:: { Arc , LazyLock , Mutex } ,
4
+ } ;
2
5
3
- use anyhow:: anyhow;
4
- use anyhow:: Result ;
6
+ use anyhow:: { anyhow, Result } ;
5
7
use async_trait:: async_trait;
6
8
use cached:: { proc_macro:: cached, TimedCache } ;
7
9
use openidconnect:: {
8
- AccessTokenHash ,
9
- AuthorizationCode ,
10
- CsrfToken ,
11
- ClientId ,
12
- ClientSecret ,
13
- IssuerUrl ,
14
- Nonce ,
15
- OAuth2TokenResponse ,
16
- PkceCodeChallenge ,
17
- PkceCodeVerifier ,
18
- RedirectUrl ,
19
- TokenResponse ,
20
- } ;
21
- use openidconnect:: core:: {
22
- CoreAuthenticationFlow ,
23
- CoreClient ,
24
- CoreProviderMetadata ,
25
- CoreUserInfoClaims ,
10
+ core:: { CoreAuthenticationFlow , CoreClient , CoreProviderMetadata , CoreUserInfoClaims } ,
11
+ AccessTokenHash , AuthorizationCode , ClientId , ClientSecret , CsrfToken , IssuerUrl , Nonce ,
12
+ OAuth2TokenResponse , PkceCodeChallenge , PkceCodeVerifier , RedirectUrl , TokenResponse ,
26
13
} ;
27
14
use serde:: Deserialize ;
28
- use std:: collections:: HashMap ;
29
15
use tabby_schema:: auth:: { AuthenticationService , OAuthCredential , OAuthProvider } ;
30
16
31
17
use super :: OAuthClient ;
@@ -42,9 +28,8 @@ pub struct OAuthRequest {
42
28
pub pkce_verifier : String ,
43
29
}
44
30
45
- static AUTH_REQS : LazyLock < Mutex < HashMap < String , OAuthRequest > > > = LazyLock :: new ( || {
46
- Mutex :: new ( HashMap :: new ( ) )
47
- } ) ;
31
+ static AUTH_REQS : LazyLock < Mutex < HashMap < String , OAuthRequest > > > =
32
+ LazyLock :: new ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
48
33
49
34
impl GeneralClient {
50
35
pub fn new ( auth : Arc < dyn AuthenticationService > ) -> Self {
@@ -65,7 +50,7 @@ impl GeneralClient {
65
50
}
66
51
}
67
52
68
- async fn retrieve_provider_metadata ( & self , config_url : String ) -> Option < CoreProviderMetadata > {
53
+ async fn retrieve_provider_metadata ( & self , config_url : String ) -> Option < CoreProviderMetadata > {
69
54
retrieve_provider_metadata ( config_url) . await
70
55
}
71
56
}
@@ -98,14 +83,15 @@ impl OAuthClient for GeneralClient {
98
83
99
84
let client = reqwest:: Client :: new ( ) ;
100
85
let pkce_verifier = PkceCodeVerifier :: new ( auth_req. pkce_verifier . clone ( ) ) ;
101
- let token_response = oidc_client. exchange_code ( AuthorizationCode :: new ( code) ) ?
86
+ let token_response = oidc_client
87
+ . exchange_code ( AuthorizationCode :: new ( code) ) ?
102
88
. set_pkce_verifier ( pkce_verifier)
103
89
. request_async ( & client)
104
90
. await ?;
105
91
106
92
let id_token = token_response
107
- . id_token ( )
108
- . ok_or_else ( || anyhow ! ( "Invalid authentication token" ) ) ?;
93
+ . id_token ( )
94
+ . ok_or_else ( || anyhow ! ( "Invalid authentication token" ) ) ?;
109
95
110
96
let id_token_verifier = oidc_client. id_token_verifier ( ) ;
111
97
let nonce = Nonce :: new ( auth_req. nonce . clone ( ) ) ;
@@ -118,31 +104,29 @@ impl OAuthClient for GeneralClient {
118
104
id_token. signing_key ( & id_token_verifier) ?,
119
105
) ?;
120
106
if actual_access_token_hash != * expected_access_token_hash {
121
- bail ! ( "Invalid access token" ) ;
107
+ bail ! ( "Invalid access token" ) ;
122
108
}
123
109
}
124
110
125
111
let access_token = token_response. access_token ( ) . secret ( ) . to_string ( ) ;
126
112
127
113
// Get User info
128
- let user_info_response = oidc_client. user_info ( token_response. access_token ( ) . to_owned ( ) , None ) ?
129
- . request_async ( & client) . await ;
114
+ let user_info_response = oidc_client
115
+ . user_info ( token_response. access_token ( ) . to_owned ( ) , None ) ?
116
+ . request_async ( & client)
117
+ . await ;
130
118
131
119
let mut user_info = self . user_info . lock ( ) . unwrap ( ) ;
132
- * user_info = match user_info_response {
133
- Ok ( user_info) => Some ( user_info) ,
134
- Err ( _err) => None ,
135
- } ;
120
+ * user_info = user_info_response. ok ( ) ;
136
121
137
122
Ok ( access_token)
138
123
}
139
124
140
125
async fn fetch_user_email ( & self , _access_token : & str ) -> Result < String > {
141
126
let user_info = self . user_info . lock ( ) . unwrap ( ) ;
142
127
match & * user_info {
143
- Some ( user_info) =>{
144
- let end_user_email = user_info. email ( ) . unwrap ( )
145
- . to_owned ( ) ;
128
+ Some ( user_info) => {
129
+ let end_user_email = user_info. email ( ) . unwrap ( ) . to_owned ( ) ;
146
130
let email = end_user_email. to_string ( ) ;
147
131
Ok ( email)
148
132
}
@@ -154,8 +138,7 @@ impl OAuthClient for GeneralClient {
154
138
let user_info = self . user_info . lock ( ) . unwrap ( ) ;
155
139
match & * user_info {
156
140
Some ( user_info) => {
157
- let end_user_full_name = user_info. name ( ) . unwrap ( )
158
- . to_owned ( ) ;
141
+ let end_user_full_name = user_info. name ( ) . unwrap ( ) . to_owned ( ) ;
159
142
let full_name = end_user_full_name. get ( None ) . unwrap ( ) . to_string ( ) ;
160
143
Ok ( full_name)
161
144
}
@@ -171,16 +154,16 @@ impl OAuthClient for GeneralClient {
171
154
} ;
172
155
let provider_metadata = self . retrieve_provider_metadata ( config_url) . await . unwrap ( ) ;
173
156
174
- let redirect_uri = RedirectUrl :: new (
175
- self . auth . oauth_callback_url ( OAuthProvider :: General ) . await ?
176
- ) ?;
157
+ let redirect_uri =
158
+ RedirectUrl :: new ( self . auth . oauth_callback_url ( OAuthProvider :: General ) . await ?) ?;
177
159
let scopes_supported = provider_metadata. scopes_supported ( ) . unwrap ( ) . clone ( ) ;
178
160
179
161
let oidc_client = CoreClient :: from_provider_metadata (
180
162
provider_metadata,
181
163
ClientId :: new ( credential. client_id ) ,
182
164
Some ( ClientSecret :: new ( credential. client_secret ) ) ,
183
- ) . set_redirect_uri ( redirect_uri) ;
165
+ )
166
+ . set_redirect_uri ( redirect_uri) ;
184
167
185
168
let ( pkce_challenge, pkce_verifier) = PkceCodeChallenge :: new_random_sha256 ( ) ;
186
169
@@ -203,7 +186,7 @@ impl OAuthClient for GeneralClient {
203
186
let ( auth_uri, csrf_token, nonce) = authorization_request. url ( ) ;
204
187
let auth_req = OAuthRequest {
205
188
nonce : nonce. secret ( ) . clone ( ) ,
206
- pkce_verifier : pkce_verifier. into_secret ( )
189
+ pkce_verifier : pkce_verifier. into_secret ( ) ,
207
190
} ;
208
191
209
192
let mut auth_reqs = AUTH_REQS . lock ( ) . unwrap ( ) ;
@@ -213,20 +196,15 @@ impl OAuthClient for GeneralClient {
213
196
}
214
197
}
215
198
216
-
217
199
#[ cached(
218
200
type = "TimedCache<String, Option<CoreProviderMetadata>>" ,
219
201
create = "{ TimedCache::with_lifespan(3600 * 12) }"
220
202
) ]
221
203
async fn retrieve_provider_metadata ( config_url : String ) -> Option < CoreProviderMetadata > {
222
204
let client = reqwest:: Client :: new ( ) ;
223
- let provider_metadata = CoreProviderMetadata :: discover_async (
224
- IssuerUrl :: new ( config_url) . ok ( ) . unwrap ( ) ,
225
- & client,
226
- ) . await ;
227
-
228
- match provider_metadata {
229
- Ok ( provider_metadata) => Some ( provider_metadata) ,
230
- Err ( _) => None ,
231
- }
205
+ let provider_metadata =
206
+ CoreProviderMetadata :: discover_async ( IssuerUrl :: new ( config_url) . ok ( ) . unwrap ( ) , & client)
207
+ . await ;
208
+
209
+ provider_metadata. ok ( )
232
210
}
0 commit comments