Skip to content

Commit 435146b

Browse files
Bring in changes from Trino 3c687c3
1 parent 0f2e467 commit 435146b

File tree

3 files changed

+127
-11
lines changed

3 files changed

+127
-11
lines changed

presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import com.nimbusds.jose.EncryptionMethod;
1818
import com.nimbusds.jose.JOSEException;
1919
import com.nimbusds.jose.JWEAlgorithm;
20+
import com.nimbusds.jose.JWEDecrypter;
21+
import com.nimbusds.jose.JWEEncrypter;
2022
import com.nimbusds.jose.JWEHeader;
2123
import com.nimbusds.jose.JWEObject;
2224
import com.nimbusds.jose.KeyLengthException;
@@ -49,8 +51,8 @@
4951
public class JweTokenSerializer
5052
implements TokenPairSerializer
5153
{
52-
private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
53-
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
54+
private final JWEHeader encryptionHeader;
55+
5456
private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
5557
private static final String ACCESS_TOKEN_KEY = "access_token";
5658
private static final String EXPIRATION_TIME_KEY = "expiration_time";
@@ -61,8 +63,8 @@ public class JweTokenSerializer
6163
private final String audience;
6264
private final Duration tokenExpiration;
6365
private final JwtParser parser;
64-
private final AESEncrypter jweEncrypter;
65-
private final AESDecrypter jweDecrypter;
66+
private final JWEEncrypter jweEncrypter;
67+
private final JWEDecrypter jweDecrypter;
6668
private final String principalField;
6769

6870
public JweTokenSerializer(
@@ -84,6 +86,7 @@ public JweTokenSerializer(
8486
this.audience = requireNonNull(audience, "issuer is null");
8587
this.clock = requireNonNull(clock, "clock is null");
8688
this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null");
89+
this.encryptionHeader = createEncryptionHeader(secretKey);
8790

8891
this.parser = newJwtParserBuilder()
8992
.setClock(() -> Date.from(clock.instant()))
@@ -93,11 +96,21 @@ public JweTokenSerializer(
9396
.build();
9497
}
9598

99+
private JWEHeader createEncryptionHeader(SecretKey key)
100+
{
101+
int keyLength = key.getEncoded().length;
102+
return switch (keyLength) {
103+
case 16 -> new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
104+
case 24 -> new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
105+
case 32 -> new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
106+
default -> throw new IllegalArgumentException("Secret key size must be either 16, 24 or 32 bytes but was %d".formatted(keyLength));
107+
};
108+
}
109+
96110
@Override
97111
public TokenPair deserialize(String token)
98112
{
99113
requireNonNull(token, "token is null");
100-
101114
try {
102115
JWEObject jwe = JWEObject.parse(token);
103116
jwe.decrypt(jweDecrypter);
@@ -139,9 +152,7 @@ public String serialize(TokenPair tokenPair)
139152
.compressWith(COMPRESSION_CODEC);
140153

141154
try {
142-
JWEObject jwe = new JWEObject(
143-
new JWEHeader(ALGORITHM, ENCRYPTION_METHOD),
144-
new Payload(jwt.compact()));
155+
JWEObject jwe = new JWEObject(encryptionHeader, new Payload(jwt.compact()));
145156
jwe.encrypt(jweEncrypter);
146157
return jwe.serialize();
147158
}

presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,10 @@ public Optional<String> getRefreshToken()
8787
{
8888
return refreshToken;
8989
}
90+
91+
public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken)
92+
{
93+
return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken));
94+
}
9095
}
9196
}

presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,27 @@
1818
import com.nimbusds.jose.KeyLengthException;
1919
import io.jsonwebtoken.ExpiredJwtException;
2020
import io.jsonwebtoken.Jwts;
21+
import org.testng.annotations.DataProvider;
2122
import org.testng.annotations.Test;
2223

2324
import java.net.URI;
2425
import java.security.GeneralSecurityException;
26+
import java.security.NoSuchAlgorithmException;
27+
import java.security.SecureRandom;
2528
import java.time.Clock;
2629
import java.time.Instant;
2730
import java.time.ZoneId;
2831
import java.time.ZonedDateTime;
32+
import java.util.Base64;
2933
import java.util.Calendar;
3034
import java.util.Date;
3135
import java.util.Map;
3236
import java.util.Optional;
37+
import java.util.Random;
3338

3439
import static com.facebook.airlift.units.Duration.succinctDuration;
3540
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.accessAndRefreshTokens;
41+
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.withAccessAndRefreshTokens;
3642
import static java.time.temporal.ChronoUnit.MILLIS;
3743
import static java.util.concurrent.TimeUnit.MINUTES;
3844
import static java.util.concurrent.TimeUnit.SECONDS;
@@ -45,7 +51,7 @@ public class TestJweTokenSerializer
4551
public void testSerialization()
4652
throws Exception
4753
{
48-
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS));
54+
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), randomEncodedSecret());
4955

5056
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
5157
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
@@ -56,14 +62,73 @@ public void testSerialization()
5662
assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token"));
5763
}
5864

65+
@Test(dataProvider = "wrongSecretsProvider")
66+
public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret)
67+
{
68+
assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret)))
69+
.isInstanceOf(IllegalArgumentException.class)
70+
.hasMessage("Decryption failed")
71+
.hasStackTraceContaining("Tag mismatch!");
72+
}
73+
74+
@DataProvider
75+
public Object[][] wrongSecretsProvider()
76+
{
77+
return new Object[][] {
78+
{randomEncodedSecret(), randomEncodedSecret()},
79+
{randomEncodedSecret(16), randomEncodedSecret(24)},
80+
{null, null}, // This will generate two different secret keys
81+
{null, randomEncodedSecret()},
82+
{randomEncodedSecret(), null}
83+
};
84+
}
85+
86+
@Test
87+
public void testSerializationDeserializationRoundTripWithDifferentKeyLengths()
88+
throws Exception
89+
{
90+
for (int keySize : new int[] {16, 24, 32}) {
91+
String secret = randomEncodedSecret(keySize);
92+
assertRoundTrip(secret, secret);
93+
}
94+
}
95+
96+
@Test
97+
public void testSerializationFailsWithWrongKeySize()
98+
{
99+
for (int wrongKeySize : new int[] {8, 64, 128}) {
100+
String tooShortSecret = randomEncodedSecret(wrongKeySize);
101+
assertThatThrownBy(() -> assertRoundTrip(tooShortSecret, tooShortSecret))
102+
.hasStackTraceContaining("Secret key size must be either 16, 24 or 32 bytes but was " + wrongKeySize);
103+
}
104+
}
105+
106+
private void assertRoundTrip(String serializerSecret, String deserializerSecret)
107+
throws Exception
108+
{
109+
assertRoundTrip(Optional.of(serializerSecret), Optional.of(deserializerSecret));
110+
}
111+
112+
private void assertRoundTrip(Optional<String> serializerSecret, Optional<String> deserializerSecret)
113+
throws Exception
114+
{
115+
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), serializerSecret);
116+
JweTokenSerializer deserializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), deserializerSecret);
117+
Date expiration = new Calendar.Builder().setDate(2023, 6, 22).build().getTime();
118+
TokenPair tokenPair = withAccessAndRefreshTokens(randomEncodedSecret(), expiration, randomEncodedSecret());
119+
assertThat(deserializer.deserialize(serializer.serialize(tokenPair)))
120+
.isEqualTo(tokenPair);
121+
}
122+
59123
@Test
60124
public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension()
61125
throws Exception
62126
{
63127
TestingClock clock = new TestingClock();
64128
JweTokenSerializer serializer = tokenSerializer(
65129
clock,
66-
succinctDuration(12, MINUTES));
130+
succinctDuration(12, MINUTES),
131+
randomEncodedSecret());
67132
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
68133
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
69134
clock.advanceBy(succinctDuration(10, MINUTES));
@@ -82,7 +147,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension()
82147

83148
JweTokenSerializer serializer = tokenSerializer(
84149
clock,
85-
succinctDuration(12, MINUTES));
150+
succinctDuration(12, MINUTES),
151+
randomEncodedSecret());
86152
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
87153
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
88154

@@ -104,6 +170,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration
104170
tokenExpiration);
105171
}
106172

173+
private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, String encodedSecretKey)
174+
throws GeneralSecurityException, KeyLengthException
175+
{
176+
return tokenSerializer(clock, tokenExpiration, Optional.of(encodedSecretKey));
177+
}
178+
179+
private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, Optional<String> secretKey)
180+
throws NoSuchAlgorithmException, KeyLengthException
181+
{
182+
RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig();
183+
secretKey.ifPresent(refreshTokensConfig::setSecretKey);
184+
return new JweTokenSerializer(
185+
refreshTokensConfig,
186+
new Oauth2ClientStub(),
187+
"trino_coordinator_test_version",
188+
"trino_coordinator",
189+
"sub",
190+
clock,
191+
tokenExpiration);
192+
}
193+
194+
private static String randomEncodedSecret()
195+
{
196+
return randomEncodedSecret(24);
197+
}
198+
199+
private static String randomEncodedSecret(int length)
200+
{
201+
Random random = new SecureRandom();
202+
final byte[] buffer = new byte[length];
203+
random.nextBytes(buffer);
204+
return Base64.getEncoder().encodeToString(buffer);
205+
}
206+
107207
static class Oauth2ClientStub
108208
implements OAuth2Client
109209
{

0 commit comments

Comments
 (0)