@@ -57,6 +57,11 @@ public class KeySetRetriever implements KeySetProvider {
57
57
private long lastCacheStatusLog = 0 ;
58
58
private String jwksUri ;
59
59
60
+ // Security validation settings (optional, for JWKS endpoints)
61
+ private long maxResponseSizeBytes = -1 ; // -1 means no limit
62
+ private int maxKeyCount = -1 ; // -1 means no limit
63
+ private boolean enableSecurityValidation = false ;
64
+
60
65
KeySetRetriever (String openIdConnectEndpoint , SSLConfig sslConfig , boolean useCacheForOidConnectEndpoint ) {
61
66
this .openIdConnectEndpoint = openIdConnectEndpoint ;
62
67
this .sslConfig = sslConfig ;
@@ -71,10 +76,41 @@ public class KeySetRetriever implements KeySetProvider {
71
76
configureCache (useCacheForOidConnectEndpoint );
72
77
}
73
78
79
+ /**
80
+ * Factory method to create a KeySetRetriever for JWKS endpoint access.
81
+ * This method provides a public API for creating KeySetRetriever instances
82
+ * with built-in security validation to protect against malicious JWKS endpoints.
83
+ *
84
+ * @param sslConfig SSL configuration for HTTPS connections
85
+ * @param useCacheForJwksEndpoint whether to enable caching for JWKS endpoint
86
+ * When true, JWKS responses will be cached to improve performance
87
+ * and reduce network calls to the JWKS endpoint.
88
+ * @param jwksUri the JWKS endpoint URI
89
+ * @param maxResponseSizeBytes maximum allowed HTTP response size in bytes
90
+ * @param maxKeyCount maximum number of keys allowed in JWKS
91
+ * @return a new KeySetRetriever instance with security validation enabled
92
+ */
93
+ public static KeySetRetriever createForJwksUri (
94
+ SSLConfig sslConfig ,
95
+ boolean useCacheForJwksEndpoint ,
96
+ String jwksUri ,
97
+ long maxResponseSizeBytes ,
98
+ int maxKeyCount
99
+ ) {
100
+ KeySetRetriever retriever = new KeySetRetriever (sslConfig , useCacheForJwksEndpoint , jwksUri );
101
+ retriever .enableSecurityValidation = true ;
102
+ retriever .maxResponseSizeBytes = maxResponseSizeBytes ;
103
+ retriever .maxKeyCount = maxKeyCount ;
104
+ return retriever ;
105
+ }
106
+
74
107
public JWKSet get () throws AuthenticatorUnavailableException {
75
108
String uri = getJwksUri ();
76
109
77
- try (CloseableHttpClient httpClient = createHttpClient (null )) {
110
+ // Use cache storage if it's configured
111
+ HttpCacheStorage cacheStorage = oidcHttpCacheStorage ;
112
+
113
+ try (CloseableHttpClient httpClient = createHttpClient (cacheStorage )) {
78
114
79
115
HttpGet httpGet = new HttpGet (uri );
80
116
@@ -85,7 +121,20 @@ public JWKSet get() throws AuthenticatorUnavailableException {
85
121
86
122
httpGet .setConfig (requestConfig );
87
123
88
- try (CloseableHttpResponse response = httpClient .execute (httpGet )) {
124
+ // Configure HTTP client to only accept JSON responses for JWKS endpoints
125
+ if (enableSecurityValidation ) {
126
+ httpGet .setHeader ("Accept" , "application/json, application/jwk-set+json" );
127
+ }
128
+
129
+ HttpCacheContext httpContext = null ;
130
+ if (cacheStorage != null ) {
131
+ httpContext = new HttpCacheContext ();
132
+ }
133
+
134
+ try (CloseableHttpResponse response = httpClient .execute (httpGet , httpContext )) {
135
+ if (httpContext != null ) {
136
+ logCacheResponseStatus (httpContext , true );
137
+ }
89
138
if (response .getCode () < 200 || response .getCode () >= 300 ) {
90
139
throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": " + response .getReasonPhrase ());
91
140
}
@@ -95,11 +144,41 @@ public JWKSet get() throws AuthenticatorUnavailableException {
95
144
if (httpEntity == null ) {
96
145
throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": Empty response entity" );
97
146
}
147
+
148
+ // Apply security validation if enabled (for JWKS endpoints)
149
+ if (enableSecurityValidation ) {
150
+ // Validate response size
151
+ if (maxResponseSizeBytes > 0 ) {
152
+ long contentLength = httpEntity .getContentLength ();
153
+ if (contentLength > maxResponseSizeBytes ) {
154
+ throw new AuthenticatorUnavailableException (
155
+ String .format (
156
+ "JWKS response too large from %s: %d bytes (max: %d)" ,
157
+ uri ,
158
+ contentLength ,
159
+ maxResponseSizeBytes
160
+ )
161
+ );
162
+ }
163
+ }
164
+ }
165
+
166
+ // Load JWKS using Nimbus JOSE (handles JSON parsing and validation)
98
167
JWKSet keySet = JWKSet .load (httpEntity .getContent ());
99
168
169
+ // Apply minimal additional validation only for direct JWKS endpoints
170
+ if (enableSecurityValidation ) {
171
+ // Simple key count validation - HARD LIMIT
172
+ if (maxKeyCount > 0 && keySet .getKeys ().size () > maxKeyCount ) {
173
+ throw new AuthenticatorUnavailableException (
174
+ String .format ("JWKS from %s contains %d keys, but max allowed is %d" , uri , keySet .getKeys ().size (), maxKeyCount )
175
+ );
176
+ }
177
+ }
178
+
100
179
return keySet ;
101
180
} catch (ParseException e ) {
102
- throw new RuntimeException ( e );
181
+ throw new AuthenticatorUnavailableException ( "Error parsing JWKS from " + uri + ": " + e . getMessage (), e );
103
182
}
104
183
} catch (IOException e ) {
105
184
throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": " + e , e );
@@ -177,21 +256,43 @@ public void setRequestTimeoutMs(int httpTimeoutMs) {
177
256
}
178
257
179
258
private void logCacheResponseStatus (HttpCacheContext httpContext ) {
259
+ logCacheResponseStatus (httpContext , false );
260
+ }
261
+
262
+ private void logCacheResponseStatus (HttpCacheContext httpContext , boolean isJwksRequest ) {
180
263
this .oidcRequests ++;
181
264
182
- switch (httpContext .getCacheResponseStatus ()) {
183
- case CACHE_HIT :
184
- this .oidcCacheHits ++;
185
- break ;
186
- case CACHE_MODULE_RESPONSE :
187
- this .oidcCacheModuleResponses ++;
188
- break ;
189
- case CACHE_MISS :
265
+ // Handle cache statistics based on the response status
266
+ // For OIDC discovery flow, only count the JWKS request (not the discovery request)
267
+ // For direct JWKS URI, count all requests
268
+ boolean shouldCountStats = (jwksUri != null ) || isJwksRequest ;
269
+
270
+ if (!shouldCountStats ) {
271
+ log .debug ("Skipping cache statistics for OIDC discovery request #{}" , this .oidcRequests );
272
+ return ;
273
+ }
274
+
275
+ if (httpContext .getCacheResponseStatus () == null ) {
276
+ if (oidcHttpCacheStorage != null ) {
190
277
this .oidcCacheMisses ++;
191
- break ;
192
- case VALIDATED :
193
- this .oidcCacheHitsValidated ++;
194
- break ;
278
+ log .debug ("Null cache status - counting as cache miss. Total misses: {}" , this .oidcCacheMisses );
279
+ }
280
+ } else {
281
+ switch (httpContext .getCacheResponseStatus ()) {
282
+ case CACHE_HIT :
283
+ this .oidcCacheHits ++;
284
+ break ;
285
+ case CACHE_MODULE_RESPONSE :
286
+ this .oidcCacheModuleResponses ++;
287
+ break ;
288
+ case CACHE_MISS :
289
+ this .oidcCacheMisses ++;
290
+ break ;
291
+ case VALIDATED :
292
+ this .oidcCacheHits ++;
293
+ this .oidcCacheHitsValidated ++;
294
+ break ;
295
+ }
195
296
}
196
297
197
298
long now = System .currentTimeMillis ();
@@ -208,7 +309,6 @@ private void logCacheResponseStatus(HttpCacheContext httpContext) {
208
309
);
209
310
lastCacheStatusLog = now ;
210
311
}
211
-
212
312
}
213
313
214
314
private CloseableHttpClient createHttpClient (HttpCacheStorage httpCacheStorage ) {
@@ -255,4 +355,5 @@ public int getOidcCacheHitsValidated() {
255
355
public int getOidcCacheModuleResponses () {
256
356
return oidcCacheModuleResponses ;
257
357
}
358
+
258
359
}
0 commit comments