1818import com .nimbusds .jose .KeyLengthException ;
1919import io .jsonwebtoken .ExpiredJwtException ;
2020import io .jsonwebtoken .Jwts ;
21+ import org .testng .annotations .DataProvider ;
2122import org .testng .annotations .Test ;
2223
2324import java .net .URI ;
2425import java .security .GeneralSecurityException ;
26+ import java .security .NoSuchAlgorithmException ;
27+ import java .security .SecureRandom ;
2528import java .time .Clock ;
2629import java .time .Instant ;
2730import java .time .ZoneId ;
2831import java .time .ZonedDateTime ;
32+ import java .util .Base64 ;
2933import java .util .Calendar ;
3034import java .util .Date ;
3135import java .util .Map ;
3236import java .util .Optional ;
37+ import java .util .Random ;
3338
3439import static com .facebook .airlift .units .Duration .succinctDuration ;
3540import static com .facebook .presto .server .security .oauth2 .TokenPairSerializer .TokenPair .accessAndRefreshTokens ;
41+ import static com .facebook .presto .server .security .oauth2 .TokenPairSerializer .TokenPair .withAccessAndRefreshTokens ;
3642import static java .time .temporal .ChronoUnit .MILLIS ;
3743import static java .util .concurrent .TimeUnit .MINUTES ;
3844import static java .util .concurrent .TimeUnit .SECONDS ;
@@ -45,7 +51,7 @@ public class TestJweTokenSerializer
4551 public void testSerialization ()
4652 throws Exception
4753 {
48- JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ));
54+ JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), randomEncodedSecret () );
4955
5056 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
5157 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
@@ -56,14 +62,73 @@ public void testSerialization()
5662 assertThat (deserializedTokenPair .getRefreshToken ()).isEqualTo (Optional .of ("refresh_token" ));
5763 }
5864
65+ @ Test (dataProvider = "wrongSecretsProvider" )
66+ public void testDeserializationWithWrongSecret (String encryptionSecret , String decryptionSecret )
67+ {
68+ assertThatThrownBy (() -> assertRoundTrip (Optional .ofNullable (encryptionSecret ), Optional .ofNullable (decryptionSecret )))
69+ .isInstanceOf (IllegalArgumentException .class )
70+ .hasMessage ("Decryption failed" )
71+ .hasStackTraceContaining ("Tag mismatch!" );
72+ }
73+
74+ @ DataProvider
75+ public Object [][] wrongSecretsProvider ()
76+ {
77+ return new Object [][] {
78+ {randomEncodedSecret (), randomEncodedSecret ()},
79+ {randomEncodedSecret (16 ), randomEncodedSecret (24 )},
80+ {null , null }, // This will generate two different secret keys
81+ {null , randomEncodedSecret ()},
82+ {randomEncodedSecret (), null }
83+ };
84+ }
85+
86+ @ Test
87+ public void testSerializationDeserializationRoundTripWithDifferentKeyLengths ()
88+ throws Exception
89+ {
90+ for (int keySize : new int [] {16 , 24 , 32 }) {
91+ String secret = randomEncodedSecret (keySize );
92+ assertRoundTrip (secret , secret );
93+ }
94+ }
95+
96+ @ Test
97+ public void testSerializationFailsWithWrongKeySize ()
98+ {
99+ for (int wrongKeySize : new int [] {8 , 64 , 128 }) {
100+ String tooShortSecret = randomEncodedSecret (wrongKeySize );
101+ assertThatThrownBy (() -> assertRoundTrip (tooShortSecret , tooShortSecret ))
102+ .hasStackTraceContaining ("Secret key size must be either 16, 24 or 32 bytes but was " + wrongKeySize );
103+ }
104+ }
105+
106+ private void assertRoundTrip (String serializerSecret , String deserializerSecret )
107+ throws Exception
108+ {
109+ assertRoundTrip (Optional .of (serializerSecret ), Optional .of (deserializerSecret ));
110+ }
111+
112+ private void assertRoundTrip (Optional <String > serializerSecret , Optional <String > deserializerSecret )
113+ throws Exception
114+ {
115+ JweTokenSerializer serializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), serializerSecret );
116+ JweTokenSerializer deserializer = tokenSerializer (Clock .systemUTC (), succinctDuration (5 , SECONDS ), deserializerSecret );
117+ Date expiration = new Calendar .Builder ().setDate (2023 , 6 , 22 ).build ().getTime ();
118+ TokenPair tokenPair = withAccessAndRefreshTokens (randomEncodedSecret (), expiration , randomEncodedSecret ());
119+ assertThat (deserializer .deserialize (serializer .serialize (tokenPair )))
120+ .isEqualTo (tokenPair );
121+ }
122+
59123 @ Test
60124 public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension ()
61125 throws Exception
62126 {
63127 TestingClock clock = new TestingClock ();
64128 JweTokenSerializer serializer = tokenSerializer (
65129 clock ,
66- succinctDuration (12 , MINUTES ));
130+ succinctDuration (12 , MINUTES ),
131+ randomEncodedSecret ());
67132 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
68133 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
69134 clock .advanceBy (succinctDuration (10 , MINUTES ));
@@ -82,7 +147,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension()
82147
83148 JweTokenSerializer serializer = tokenSerializer (
84149 clock ,
85- succinctDuration (12 , MINUTES ));
150+ succinctDuration (12 , MINUTES ),
151+ randomEncodedSecret ());
86152 Date expiration = new Calendar .Builder ().setDate (2022 , 6 , 22 ).build ().getTime ();
87153 String serializedTokenPair = serializer .serialize (accessAndRefreshTokens ("access_token" , expiration , "refresh_token" ));
88154
@@ -104,6 +170,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration
104170 tokenExpiration );
105171 }
106172
173+ private JweTokenSerializer tokenSerializer (Clock clock , Duration tokenExpiration , String encodedSecretKey )
174+ throws GeneralSecurityException , KeyLengthException
175+ {
176+ return tokenSerializer (clock , tokenExpiration , Optional .of (encodedSecretKey ));
177+ }
178+
179+ private JweTokenSerializer tokenSerializer (Clock clock , Duration tokenExpiration , Optional <String > secretKey )
180+ throws NoSuchAlgorithmException , KeyLengthException
181+ {
182+ RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig ();
183+ secretKey .ifPresent (refreshTokensConfig ::setSecretKey );
184+ return new JweTokenSerializer (
185+ refreshTokensConfig ,
186+ new Oauth2ClientStub (),
187+ "trino_coordinator_test_version" ,
188+ "trino_coordinator" ,
189+ "sub" ,
190+ clock ,
191+ tokenExpiration );
192+ }
193+
194+ private static String randomEncodedSecret ()
195+ {
196+ return randomEncodedSecret (24 );
197+ }
198+
199+ private static String randomEncodedSecret (int length )
200+ {
201+ Random random = new SecureRandom ();
202+ final byte [] buffer = new byte [length ];
203+ random .nextBytes (buffer );
204+ return Base64 .getEncoder ().encodeToString (buffer );
205+ }
206+
107207 static class Oauth2ClientStub
108208 implements OAuth2Client
109209 {
0 commit comments