Skip to content

Commit eb7153d

Browse files
Rishav9852KumarRishav Kumar
andauthored
Usage of JWKS with JWT (without using OpenID connect) (#5578)
Signed-off-by: Rishav Kumar <[email protected]> Co-authored-by: Rishav Kumar <[email protected]>
1 parent e8a82af commit eb7153d

File tree

9 files changed

+1358
-20
lines changed

9 files changed

+1358
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1414
- [Resource Sharing] Keep track of tenant for sharable resources by persisting user requested tenant with sharing info ([#5588](https://github.com/opensearch-project/security/pull/5588))
1515
- [SecurityPlugin Health Check] Add AuthZ initialization completion check in health check API [(#5626)](https://github.com/opensearch-project/security/pull/5626)
1616
- [Resource Sharing] Adds API to provide dashboards support for resource access management ([#5597](https://github.com/opensearch-project/security/pull/5597))
17+
- Direct JWKS (JSON Web Key Set) support in the JWT authentication backend ([#5578](https://github.com/opensearch-project/security/pull/5578))
1718

1819

1920
### Bug Fixes

config/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ config:
130130
type: jwt
131131
challenge: false
132132
config:
133+
jwks_uri: 'https://your-jwks-endpoint.com/.well-known/jwks.json'
133134
signing_key: "base64 encoded HMAC key or public RSA/ECDSA pem key"
134135
jwt_header: "Authorization"
135136
jwt_url_parameter: null
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
12+
package org.opensearch.security.auth.http.jwt.keybyjwks;
13+
14+
import java.nio.file.Path;
15+
import java.util.Collections;
16+
import java.util.Set;
17+
18+
import org.apache.logging.log4j.LogManager;
19+
import org.apache.logging.log4j.Logger;
20+
21+
import org.opensearch.OpenSearchSecurityException;
22+
import org.opensearch.common.settings.Settings;
23+
import org.opensearch.common.util.concurrent.ThreadContext;
24+
import org.opensearch.core.common.Strings;
25+
import org.opensearch.security.auth.http.jwt.AbstractHTTPJwtAuthenticator;
26+
import org.opensearch.security.auth.http.jwt.HTTPJwtAuthenticator;
27+
import org.opensearch.security.auth.http.jwt.keybyoidc.KeyProvider;
28+
import org.opensearch.security.auth.http.jwt.keybyoidc.KeySetRetriever;
29+
import org.opensearch.security.auth.http.jwt.keybyoidc.SelfRefreshingKeySet;
30+
import org.opensearch.security.filter.SecurityRequest;
31+
import org.opensearch.security.user.AuthCredentials;
32+
import org.opensearch.security.util.SettingsBasedSSLConfigurator;
33+
34+
/**
35+
* JWT authenticator that uses JWKS (JSON Web Key Set) endpoints for key retrieval.
36+
*
37+
* This authenticator extends AbstractHTTPJwtAuthenticator and provides JWKS-specific
38+
* key provider initialization. It supports direct JWKS endpoint access with caching,
39+
* SSL configuration, automatic key refresh, and enhanced security features to protect
40+
* against malicious JWKS endpoints.
41+
*
42+
* Security Features:
43+
* - Response size validation before parsing to prevent memory exhaustion
44+
* - Hard key count limit after parsing to reject oversized JWKS
45+
* - Configurable timeouts and rate limiting
46+
*
47+
* Configuration:
48+
* - jwks_uri: Direct JWKS endpoint URL (required)
49+
* - cache_jwks_endpoint: Enable/disable caching (default: true)
50+
* - jwks_request_timeout_ms: Request timeout in milliseconds (default: 5000)
51+
* - jwks_queued_thread_timeout_ms: Queued thread timeout (default: 2500)
52+
* - refresh_rate_limit_time_window_ms: Rate limit window (default: 10000)
53+
* - refresh_rate_limit_count: Max refreshes per window (default: 10)
54+
* - max_jwks_keys: HARD LIMIT - Rejects JWKS if exceeded (default: 10)
55+
* - max_jwks_response_size_bytes: Max HTTP response size (default: 1MB)
56+
*/
57+
public class HTTPJwtKeyByJWKSAuthenticator extends AbstractHTTPJwtAuthenticator {
58+
59+
private final static Logger log = LogManager.getLogger(HTTPJwtKeyByJWKSAuthenticator.class);
60+
61+
// Fallback to static JWT authenticator if jwks_uri is null
62+
private final HTTPJwtAuthenticator staticJwtAuthenticator;
63+
private final boolean useJwks;
64+
private final String jwtUrlParameter;
65+
66+
public HTTPJwtKeyByJWKSAuthenticator(Settings settings, Path configPath) {
67+
super(settings, configPath);
68+
69+
String jwksUri = settings.get("jwks_uri");
70+
this.useJwks = !Strings.isNullOrEmpty(jwksUri);
71+
this.jwtUrlParameter = settings.get("jwt_url_parameter");
72+
73+
// Initialize static JWT authenticator as fallback if jwks_uri is not configured
74+
if (!useJwks) {
75+
log.warn("jwks_uri is not configured, falling back to static JWT authentication");
76+
this.staticJwtAuthenticator = new HTTPJwtAuthenticator(settings, configPath);
77+
} else {
78+
this.staticJwtAuthenticator = null;
79+
}
80+
}
81+
82+
@Override
83+
protected KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception {
84+
String jwksUri = settings.get("jwks_uri");
85+
86+
// If jwks_uri is not configured, return null (will use static JWT fallback)
87+
if (jwksUri == null || jwksUri.isBlank()) {
88+
log.warn("jwks_uri is not configured, will use static JWT authentication fallback");
89+
return null;
90+
}
91+
92+
log.debug("Initializing JWKS key provider with endpoint: {}", jwksUri);
93+
94+
// Initialize configuration parameters
95+
int jwksRequestTimeoutMs = settings.getAsInt("jwks_request_timeout_ms", 5000);
96+
int jwksQueuedThreadTimeoutMs = settings.getAsInt("jwks_queued_thread_timeout_ms", 2500);
97+
int refreshRateLimitTimeWindowMs = settings.getAsInt("refresh_rate_limit_time_window_ms", 10000);
98+
int refreshRateLimitCount = settings.getAsInt("refresh_rate_limit_count", 10);
99+
boolean cacheJwksEndpoint = settings.getAsBoolean("cache_jwks_endpoint", true);
100+
int maxJwksKeys = settings.getAsInt("max_jwks_keys", -1);
101+
102+
log.warn("Initializing JWKS key provider with endpoint: {} (max keys: {})", jwksUri, maxJwksKeys);
103+
104+
// Add security configuration parameters
105+
long maxJwksResponseSizeBytes = settings.getAsLong("max_jwks_response_size_bytes", 1024L * 1024L); // 1MB default
106+
107+
// Create secure key set retriever with HARD LIMIT enforcement using maxJwksKeys
108+
KeySetRetriever keySetRetriever = KeySetRetriever.createForJwksUri(
109+
getSSLConfig(settings, configPath),
110+
cacheJwksEndpoint,
111+
jwksUri,
112+
maxJwksResponseSizeBytes,
113+
maxJwksKeys
114+
);
115+
keySetRetriever.setRequestTimeoutMs(jwksRequestTimeoutMs);
116+
117+
// Create self-refreshing key set with caching and rate limiting
118+
SelfRefreshingKeySet selfRefreshingKeySet = new SelfRefreshingKeySet(keySetRetriever);
119+
selfRefreshingKeySet.setRequestTimeoutMs(jwksRequestTimeoutMs);
120+
selfRefreshingKeySet.setQueuedThreadTimeoutMs(jwksQueuedThreadTimeoutMs);
121+
selfRefreshingKeySet.setRefreshRateLimitTimeWindowMs(refreshRateLimitTimeWindowMs);
122+
selfRefreshingKeySet.setRefreshRateLimitCount(refreshRateLimitCount);
123+
124+
return selfRefreshingKeySet;
125+
}
126+
127+
@Override
128+
public AuthCredentials extractCredentials(final SecurityRequest request, final ThreadContext context)
129+
throws OpenSearchSecurityException {
130+
131+
// If jwks_uri is not configured, delegate to static JWT authenticator
132+
if (!useJwks && staticJwtAuthenticator != null) {
133+
log.debug("Delegating to static JWT authenticator since jwks_uri is not configured");
134+
return staticJwtAuthenticator.extractCredentials(request, context);
135+
}
136+
137+
// Use the standard JWKS authentication flow
138+
return super.extractCredentials(request, context);
139+
}
140+
141+
private static SettingsBasedSSLConfigurator.SSLConfig getSSLConfig(Settings settings, Path configPath) throws Exception {
142+
return new SettingsBasedSSLConfigurator(settings, configPath, "jwks").buildSSLConfig();
143+
}
144+
145+
@Override
146+
public String getType() {
147+
return "jwt";
148+
}
149+
150+
@Override
151+
public Set<String> getSensitiveUrlParams() {
152+
if (jwtUrlParameter != null) {
153+
return Set.of(jwtUrlParameter);
154+
}
155+
return Collections.emptySet();
156+
}
157+
158+
}

src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/JwtVerifier.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExc
5757
String kid = escapedKid;
5858
if (!Strings.isNullOrEmpty(kid)) {
5959
kid = StringEscapeUtils.unescapeJava(escapedKid);
60+
} else {
61+
log.debug("JWT token is missing 'kid' (Key ID) claim in header. This may cause key selection issues.");
6062
}
6163
JWK key = keyProvider.getKey(kid);
6264

@@ -65,8 +67,10 @@ public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExc
6567

6668
if (!signatureValid && Strings.isNullOrEmpty(kid)) {
6769
key = keyProvider.getKeyAfterRefresh(null);
68-
signatureVerifier = getInitializedSignatureVerifier(key, jwt);
69-
signatureValid = jwt.verify(signatureVerifier);
70+
if (key != null) {
71+
signatureVerifier = getInitializedSignatureVerifier(key, jwt);
72+
signatureValid = jwt.verify(signatureVerifier);
73+
}
7074
}
7175

7276
if (!signatureValid) {

src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/KeySetRetriever.java

Lines changed: 117 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ public class KeySetRetriever implements KeySetProvider {
5757
private long lastCacheStatusLog = 0;
5858
private String jwksUri;
5959

60+
// Security validation settings (optional, for JWKS endpoints)
61+
private long maxResponseSizeBytes = -1; // -1 means no limit
62+
private int maxKeyCount = -1; // -1 means no limit
63+
private boolean enableSecurityValidation = false;
64+
6065
KeySetRetriever(String openIdConnectEndpoint, SSLConfig sslConfig, boolean useCacheForOidConnectEndpoint) {
6166
this.openIdConnectEndpoint = openIdConnectEndpoint;
6267
this.sslConfig = sslConfig;
@@ -71,10 +76,41 @@ public class KeySetRetriever implements KeySetProvider {
7176
configureCache(useCacheForOidConnectEndpoint);
7277
}
7378

79+
/**
80+
* Factory method to create a KeySetRetriever for JWKS endpoint access.
81+
* This method provides a public API for creating KeySetRetriever instances
82+
* with built-in security validation to protect against malicious JWKS endpoints.
83+
*
84+
* @param sslConfig SSL configuration for HTTPS connections
85+
* @param useCacheForJwksEndpoint whether to enable caching for JWKS endpoint
86+
* When true, JWKS responses will be cached to improve performance
87+
* and reduce network calls to the JWKS endpoint.
88+
* @param jwksUri the JWKS endpoint URI
89+
* @param maxResponseSizeBytes maximum allowed HTTP response size in bytes
90+
* @param maxKeyCount maximum number of keys allowed in JWKS
91+
* @return a new KeySetRetriever instance with security validation enabled
92+
*/
93+
public static KeySetRetriever createForJwksUri(
94+
SSLConfig sslConfig,
95+
boolean useCacheForJwksEndpoint,
96+
String jwksUri,
97+
long maxResponseSizeBytes,
98+
int maxKeyCount
99+
) {
100+
KeySetRetriever retriever = new KeySetRetriever(sslConfig, useCacheForJwksEndpoint, jwksUri);
101+
retriever.enableSecurityValidation = true;
102+
retriever.maxResponseSizeBytes = maxResponseSizeBytes;
103+
retriever.maxKeyCount = maxKeyCount;
104+
return retriever;
105+
}
106+
74107
public JWKSet get() throws AuthenticatorUnavailableException {
75108
String uri = getJwksUri();
76109

77-
try (CloseableHttpClient httpClient = createHttpClient(null)) {
110+
// Use cache storage if it's configured
111+
HttpCacheStorage cacheStorage = oidcHttpCacheStorage;
112+
113+
try (CloseableHttpClient httpClient = createHttpClient(cacheStorage)) {
78114

79115
HttpGet httpGet = new HttpGet(uri);
80116

@@ -85,7 +121,20 @@ public JWKSet get() throws AuthenticatorUnavailableException {
85121

86122
httpGet.setConfig(requestConfig);
87123

88-
try (CloseableHttpResponse response = httpClient.execute(httpGet)) {
124+
// Configure HTTP client to only accept JSON responses for JWKS endpoints
125+
if (enableSecurityValidation) {
126+
httpGet.setHeader("Accept", "application/json, application/jwk-set+json");
127+
}
128+
129+
HttpCacheContext httpContext = null;
130+
if (cacheStorage != null) {
131+
httpContext = new HttpCacheContext();
132+
}
133+
134+
try (CloseableHttpResponse response = httpClient.execute(httpGet, httpContext)) {
135+
if (httpContext != null) {
136+
logCacheResponseStatus(httpContext, true);
137+
}
89138
if (response.getCode() < 200 || response.getCode() >= 300) {
90139
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + response.getReasonPhrase());
91140
}
@@ -95,11 +144,41 @@ public JWKSet get() throws AuthenticatorUnavailableException {
95144
if (httpEntity == null) {
96145
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity");
97146
}
147+
148+
// Apply security validation if enabled (for JWKS endpoints)
149+
if (enableSecurityValidation) {
150+
// Validate response size
151+
if (maxResponseSizeBytes > 0) {
152+
long contentLength = httpEntity.getContentLength();
153+
if (contentLength > maxResponseSizeBytes) {
154+
throw new AuthenticatorUnavailableException(
155+
String.format(
156+
"JWKS response too large from %s: %d bytes (max: %d)",
157+
uri,
158+
contentLength,
159+
maxResponseSizeBytes
160+
)
161+
);
162+
}
163+
}
164+
}
165+
166+
// Load JWKS using Nimbus JOSE (handles JSON parsing and validation)
98167
JWKSet keySet = JWKSet.load(httpEntity.getContent());
99168

169+
// Apply minimal additional validation only for direct JWKS endpoints
170+
if (enableSecurityValidation) {
171+
// Simple key count validation - HARD LIMIT
172+
if (maxKeyCount > 0 && keySet.getKeys().size() > maxKeyCount) {
173+
throw new AuthenticatorUnavailableException(
174+
String.format("JWKS from %s contains %d keys, but max allowed is %d", uri, keySet.getKeys().size(), maxKeyCount)
175+
);
176+
}
177+
}
178+
100179
return keySet;
101180
} catch (ParseException e) {
102-
throw new RuntimeException(e);
181+
throw new AuthenticatorUnavailableException("Error parsing JWKS from " + uri + ": " + e.getMessage(), e);
103182
}
104183
} catch (IOException e) {
105184
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e);
@@ -177,21 +256,43 @@ public void setRequestTimeoutMs(int httpTimeoutMs) {
177256
}
178257

179258
private void logCacheResponseStatus(HttpCacheContext httpContext) {
259+
logCacheResponseStatus(httpContext, false);
260+
}
261+
262+
private void logCacheResponseStatus(HttpCacheContext httpContext, boolean isJwksRequest) {
180263
this.oidcRequests++;
181264

182-
switch (httpContext.getCacheResponseStatus()) {
183-
case CACHE_HIT:
184-
this.oidcCacheHits++;
185-
break;
186-
case CACHE_MODULE_RESPONSE:
187-
this.oidcCacheModuleResponses++;
188-
break;
189-
case CACHE_MISS:
265+
// Handle cache statistics based on the response status
266+
// For OIDC discovery flow, only count the JWKS request (not the discovery request)
267+
// For direct JWKS URI, count all requests
268+
boolean shouldCountStats = (jwksUri != null) || isJwksRequest;
269+
270+
if (!shouldCountStats) {
271+
log.debug("Skipping cache statistics for OIDC discovery request #{}", this.oidcRequests);
272+
return;
273+
}
274+
275+
if (httpContext.getCacheResponseStatus() == null) {
276+
if (oidcHttpCacheStorage != null) {
190277
this.oidcCacheMisses++;
191-
break;
192-
case VALIDATED:
193-
this.oidcCacheHitsValidated++;
194-
break;
278+
log.debug("Null cache status - counting as cache miss. Total misses: {}", this.oidcCacheMisses);
279+
}
280+
} else {
281+
switch (httpContext.getCacheResponseStatus()) {
282+
case CACHE_HIT:
283+
this.oidcCacheHits++;
284+
break;
285+
case CACHE_MODULE_RESPONSE:
286+
this.oidcCacheModuleResponses++;
287+
break;
288+
case CACHE_MISS:
289+
this.oidcCacheMisses++;
290+
break;
291+
case VALIDATED:
292+
this.oidcCacheHits++;
293+
this.oidcCacheHitsValidated++;
294+
break;
295+
}
195296
}
196297

197298
long now = System.currentTimeMillis();
@@ -208,7 +309,6 @@ private void logCacheResponseStatus(HttpCacheContext httpContext) {
208309
);
209310
lastCacheStatusLog = now;
210311
}
211-
212312
}
213313

214314
private CloseableHttpClient createHttpClient(HttpCacheStorage httpCacheStorage) {
@@ -255,4 +355,5 @@ public int getOidcCacheHitsValidated() {
255355
public int getOidcCacheModuleResponses() {
256356
return oidcCacheModuleResponses;
257357
}
358+
258359
}

0 commit comments

Comments
 (0)