4
4
This handles validating messages sent by the tool and generating
5
5
access token with LTI scopes.
6
6
"""
7
- import codecs
8
7
import copy
9
- import time
10
8
import json
9
+ import math
10
+ import time
11
+ import sys
11
12
import logging
12
13
14
+ import jwt
13
15
from Cryptodome .PublicKey import RSA
14
- from jwkest import BadSignature , BadSyntax , WrongNumberOfParts , jwk
15
- from jwkest .jwk import RSAKey , load_jwks_from_url
16
- from jwkest .jws import JWS , NoSuitableSigningKeys , UnknownAlgorithm
17
- from jwkest .jwt import JWT
18
16
19
17
from . import exceptions
20
18
@@ -50,14 +48,9 @@ def __init__(self, public_key=None, keyset_url=None):
50
48
# Import from public key
51
49
if public_key :
52
50
try :
53
- new_key = RSAKey (use = 'sig' )
54
-
55
- # Unescape key before importing it
56
- raw_key = codecs .decode (public_key , 'unicode_escape' )
57
-
58
51
# Import Key and save to internal state
59
- new_key . load_key ( RSA . import_key ( raw_key ) )
60
- self .public_key = new_key
52
+ algo_obj = jwt . get_algorithm_by_name ( 'RS256' )
53
+ self .public_key = algo_obj . prepare_key ( public_key )
61
54
except ValueError as err :
62
55
log .warning (
63
56
'An error was encountered while loading the LTI tool\' s key from the public key. '
@@ -76,7 +69,7 @@ def _get_keyset(self, kid=None):
76
69
77
70
if self .keyset_url :
78
71
try :
79
- keys = load_jwks_from_url (self .keyset_url )
72
+ keys = jwt . PyJWKClient (self .keyset_url ). get_jwk_set ( )
80
73
except Exception as err :
81
74
# Broad Exception is required here because jwkest raises
82
75
# an Exception object explicitly.
@@ -89,13 +82,13 @@ def _get_keyset(self, kid=None):
89
82
raise exceptions .NoSuitableKeys () from err
90
83
keyset .extend (keys )
91
84
92
- if self .public_key and kid :
93
- # Fill in key id of stored key.
94
- # This is needed because if the JWS is signed with a
95
- # key with a kid, pyjwkest doesn't match them with
96
- # keys without kid ( kid=None) and fails verification
97
- self . public_key . kid = kid
98
-
85
+ if self .public_key :
86
+ if kid :
87
+ # Fill in key id of stored key.
88
+ # This is needed because if the JWS is signed with a
89
+ # key with a kid, pyjwkest doesn't match them with
90
+ # keys without kid (kid=None) and fails verification
91
+ self . public_key . kid = kid
99
92
# Add to keyset
100
93
keyset .append (self .public_key )
101
94
@@ -111,48 +104,24 @@ def validate_and_decode(self, token):
111
104
iss, sub, exp, aud and jti claims.
112
105
"""
113
106
try :
114
- # Get KID from JWT header
115
- jwt = JWT ().unpack (token )
116
-
117
- # Verify message signature
118
- message = JWS ().verify_compact (
119
- token ,
120
- keys = self ._get_keyset (
121
- jwt .headers .get ('kid' )
122
- )
123
- )
124
-
125
- # If message is valid, check expiration from JWT
126
- if 'exp' in message and message ['exp' ] < time .time ():
127
- log .warning (
128
- 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
129
- 'The JWT has expired.'
130
- )
131
- raise exceptions .TokenSignatureExpired ()
132
-
133
- # TODO: Validate other JWT claims
134
-
135
- # Else returns decoded message
136
- return message
137
-
138
- except NoSuitableSigningKeys as err :
139
- log .warning (
140
- 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
141
- 'There is no suitable signing key.'
142
- )
143
- raise exceptions .NoSuitableKeys () from err
144
- except (BadSyntax , WrongNumberOfParts ) as err :
145
- log .warning (
146
- 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
147
- 'The JWT is malformed.'
148
- )
149
- raise exceptions .MalformedJwtToken () from err
150
- except BadSignature as err :
151
- log .warning (
152
- 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
153
- 'The JWT signature is incorrect.'
154
- )
155
- raise exceptions .BadJwtSignature () from err
107
+ key_set = self ._get_keyset ()
108
+ if not key_set :
109
+ raise exceptions .NoSuitableKeys ()
110
+ for i in range (len (key_set )):
111
+ try :
112
+ message = jwt .decode (
113
+ token ,
114
+ key = key_set [i ],
115
+ algorithms = ['RS256' , 'RS512' ,],
116
+ options = {'verify_signature' : True }
117
+ )
118
+ return message
119
+ except Exception :
120
+ if i == len (key_set ) - 1 :
121
+ raise
122
+ except Exception as token_error :
123
+ exc_info = sys .exc_info ()
124
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
156
125
157
126
158
127
class PlatformKeyHandler :
@@ -171,14 +140,8 @@ def __init__(self, key_pem, kid=None):
171
140
if key_pem :
172
141
# Import JWK from RSA key
173
142
try :
174
- self .key = RSAKey (
175
- # Using the same key ID as client id
176
- # This way we can easily serve multiple public
177
- # keys on the same endpoint and keep all
178
- # LTI 1.3 blocks working
179
- kid = kid ,
180
- key = RSA .import_key (key_pem )
181
- )
143
+ algo = jwt .get_algorithm_by_name ('RS256' )
144
+ self .key = algo .prepare_key (key_pem )
182
145
except ValueError as err :
183
146
log .warning (
184
147
'An error was encountered while loading the LTI platform\' s key. '
@@ -203,41 +166,26 @@ def encode_and_sign(self, message, expiration=None):
203
166
# Set iat and exp if expiration is set
204
167
if expiration :
205
168
_message .update ({
206
- "iat" : int (round (time .time ())),
207
- "exp" : int (round (time .time ()) + expiration ),
169
+ "iat" : int (math . floor (time .time ())),
170
+ "exp" : int (math . floor (time .time ()) + expiration ),
208
171
})
209
172
210
173
# The class instance that sets up the signing operation
211
174
# An RS 256 key is required for LTI 1.3
212
- _jws = JWS (_message , alg = "RS256" , cty = "JWT" )
213
-
214
- try :
215
- # Encode and sign LTI message
216
- return _jws .sign_compact ([self .key ])
217
- except NoSuitableSigningKeys as err :
218
- log .warning (
219
- 'An error was encountered while signing the OAuth 2.0 access token JWT. '
220
- 'There is no suitable signing key.'
221
- )
222
- raise exceptions .NoSuitableKeys () from err
223
- except UnknownAlgorithm as err :
224
- log .warning (
225
- 'An error was encountered while signing the OAuth 2.0 access token JWT. '
226
- 'There algorithm is unknown.'
227
- )
228
- raise exceptions .MalformedJwtToken () from err
175
+ return jwt .encode (_message , self .key , algorithm = "RS256" )
229
176
230
177
def get_public_jwk (self ):
231
178
"""
232
179
Export Public JWK
233
180
"""
234
- public_keys = jwk . KEYS ()
181
+ jwk = { "keys" : []}
235
182
236
183
# Only append to keyset if a key exists
237
184
if self .key :
238
- public_keys .append (self .key )
239
-
240
- return json .loads (public_keys .dump_jwks ())
185
+ algo_obj = jwt .get_algorithm_by_name ('RS256' )
186
+ public_key = algo_obj .prepare_key (self .key ).public_key ()
187
+ jwk ['keys' ].append (json .loads (algo_obj .to_jwk (public_key )))
188
+ return jwk
241
189
242
190
def validate_and_decode (self , token , iss = None , aud = None ):
243
191
"""
@@ -246,49 +194,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
246
194
Validates a token sent by the tool using the platform's RSA Key.
247
195
Optionally validate iss and aud claims if provided.
248
196
"""
197
+ if not self .key :
198
+ raise exceptions .RsaKeyNotSet ()
249
199
try :
250
- # Verify message signature
251
- message = JWS ().verify_compact (token , keys = [self .key ])
252
-
253
- # If message is valid, check expiration from JWT
254
- if 'exp' in message and message ['exp' ] < time .time ():
255
- log .warning (
256
- 'An error was encountered while verifying the OAuth 2.0 access token. '
257
- 'The JWT has expired.'
258
- )
259
- raise exceptions .TokenSignatureExpired ()
260
-
261
- # Validate issuer claim (if present)
262
- log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. '
263
- if iss :
264
- if 'iss' not in message or message ['iss' ] != iss :
265
- error_message = 'The required iss claim is missing or does not match the expected iss value. '
266
- log_message = log_message_base + error_message
267
-
268
- log .warning (log_message )
269
- raise exceptions .InvalidClaimValue (error_message )
270
-
271
- # Validate audience claim (if present)
272
- if aud :
273
- if 'aud' not in message or aud not in message ['aud' ]:
274
- error_message = 'The required aud claim is missing.'
275
- log_message = log_message_base + error_message
276
-
277
- log .warning (log_message )
278
- raise exceptions .InvalidClaimValue (error_message )
279
-
280
- # Else return token contents
200
+ message = jwt .decode (
201
+ token ,
202
+ key = self .key .public_key (),
203
+ audience = aud ,
204
+ issuer = iss ,
205
+ algorithms = ['RS256' , 'RS512' ],
206
+ options = {
207
+ 'verify_signature' : True ,
208
+ 'verify_aud' : True if aud else False
209
+ }
210
+ )
281
211
return message
282
212
283
- except NoSuitableSigningKeys as err :
284
- log .warning (
285
- 'An error was encountered while verifying the OAuth 2.0 access token. '
286
- 'There is no suitable signing key.'
287
- )
288
- raise exceptions .NoSuitableKeys () from err
289
- except BadSyntax as err :
290
- log .warning (
291
- 'An error was encountered while verifying the OAuth 2.0 access token. '
292
- 'The JWT is malformed.'
293
- )
294
- raise exceptions .MalformedJwtToken () from err
213
+ except Exception as token_error :
214
+ exc_info = sys .exc_info ()
215
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
0 commit comments