4
4
This handles validating messages sent by the tool and generating
5
5
access token with LTI scopes.
6
6
"""
7
- import codecs
8
7
import copy
9
- import time
10
8
import json
9
+ import math
10
+ import time
11
+ import sys
11
12
13
+ import jwt
12
14
from Cryptodome .PublicKey import RSA
13
- from jwkest import BadSignature , BadSyntax , WrongNumberOfParts , jwk
14
- from jwkest .jwk import RSAKey , load_jwks_from_url
15
- from jwkest .jws import JWS , NoSuitableSigningKeys
16
- from jwkest .jwt import JWT
17
15
18
16
from . import exceptions
19
17
@@ -47,14 +45,9 @@ def __init__(self, public_key=None, keyset_url=None):
47
45
# Import from public key
48
46
if public_key :
49
47
try :
50
- new_key = RSAKey (use = 'sig' )
51
-
52
- # Unescape key before importing it
53
- raw_key = codecs .decode (public_key , 'unicode_escape' )
54
-
55
48
# Import Key and save to internal state
56
- new_key . load_key ( RSA . import_key ( raw_key ) )
57
- self .public_key = new_key
49
+ algo_obj = jwt . get_algorithm_by_name ( 'RS256' )
50
+ self .public_key = algo_obj . prepare_key ( public_key )
58
51
except ValueError as err :
59
52
raise exceptions .InvalidRsaKey () from err
60
53
@@ -69,7 +62,7 @@ def _get_keyset(self, kid=None):
69
62
70
63
if self .keyset_url :
71
64
try :
72
- keys = load_jwks_from_url (self .keyset_url )
65
+ keys = jwt . PyJWKClient (self .keyset_url ). get_jwk_set ( )
73
66
except Exception as err :
74
67
# Broad Exception is required here because jwkest raises
75
68
# an Exception object explicitly.
@@ -78,13 +71,13 @@ def _get_keyset(self, kid=None):
78
71
raise exceptions .NoSuitableKeys () from err
79
72
keyset .extend (keys )
80
73
81
- if self .public_key and kid :
82
- # Fill in key id of stored key.
83
- # This is needed because if the JWS is signed with a
84
- # key with a kid, pyjwkest doesn't match them with
85
- # keys without kid ( kid=None) and fails verification
86
- self . public_key . kid = kid
87
-
74
+ if self .public_key :
75
+ if kid :
76
+ # Fill in key id of stored key.
77
+ # This is needed because if the JWS is signed with a
78
+ # key with a kid, pyjwkest doesn't match them with
79
+ # keys without kid (kid=None) and fails verification
80
+ self . public_key . kid = kid
88
81
# Add to keyset
89
82
keyset .append (self .public_key )
90
83
@@ -100,32 +93,24 @@ def validate_and_decode(self, token):
100
93
iss, sub, exp, aud and jti claims.
101
94
"""
102
95
try :
103
- # Get KID from JWT header
104
- jwt = JWT ().unpack (token )
105
-
106
- # Verify message signature
107
- message = JWS ().verify_compact (
108
- token ,
109
- keys = self ._get_keyset (
110
- jwt .headers .get ('kid' )
111
- )
112
- )
113
-
114
- # If message is valid, check expiration from JWT
115
- if 'exp' in message and message ['exp' ] < time .time ():
116
- raise exceptions .TokenSignatureExpired ()
117
-
118
- # TODO: Validate other JWT claims
119
-
120
- # Else returns decoded message
121
- return message
122
-
123
- except NoSuitableSigningKeys as err :
124
- raise exceptions .NoSuitableKeys () from err
125
- except (BadSyntax , WrongNumberOfParts ) as err :
126
- raise exceptions .MalformedJwtToken () from err
127
- except BadSignature as err :
128
- raise exceptions .BadJwtSignature () from err
96
+ key_set = self ._get_keyset ()
97
+ if not key_set :
98
+ raise exceptions .NoSuitableKeys ()
99
+ for i in range (len (key_set )):
100
+ try :
101
+ message = jwt .decode (
102
+ token ,
103
+ key = key_set [i ],
104
+ algorithms = ['RS256' , 'RS512' ,],
105
+ options = {'verify_signature' : True }
106
+ )
107
+ return message
108
+ except Exception :
109
+ if i == len (key_set ) - 1 :
110
+ raise
111
+ except Exception as token_error :
112
+ exc_info = sys .exc_info ()
113
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
129
114
130
115
131
116
class PlatformKeyHandler :
@@ -144,14 +129,8 @@ def __init__(self, key_pem, kid=None):
144
129
if key_pem :
145
130
# Import JWK from RSA key
146
131
try :
147
- self .key = RSAKey (
148
- # Using the same key ID as client id
149
- # This way we can easily serve multiple public
150
- # keys on teh same endpoint and keep all
151
- # LTI 1.3 blocks working
152
- kid = kid ,
153
- key = RSA .import_key (key_pem )
154
- )
132
+ algo = jwt .get_algorithm_by_name ('RS256' )
133
+ self .key = algo .prepare_key (key_pem )
155
134
except ValueError as err :
156
135
raise exceptions .InvalidRsaKey () from err
157
136
@@ -167,28 +146,26 @@ def encode_and_sign(self, message, expiration=None):
167
146
# Set iat and exp if expiration is set
168
147
if expiration :
169
148
_message .update ({
170
- "iat" : int (round (time .time ())),
171
- "exp" : int (round (time .time ()) + expiration ),
149
+ "iat" : int (math . floor (time .time ())),
150
+ "exp" : int (math . floor (time .time ()) + expiration ),
172
151
})
173
152
174
153
# The class instance that sets up the signing operation
175
154
# An RS 256 key is required for LTI 1.3
176
- _jws = JWS (_message , alg = "RS256" , cty = "JWT" )
177
-
178
- # Encode and sign LTI message
179
- return _jws .sign_compact ([self .key ])
155
+ return jwt .encode (_message , self .key , algorithm = "RS256" )
180
156
181
157
def get_public_jwk (self ):
182
158
"""
183
159
Export Public JWK
184
160
"""
185
- public_keys = jwk . KEYS ()
161
+ jwk = { "keys" : []}
186
162
187
163
# Only append to keyset if a key exists
188
164
if self .key :
189
- public_keys .append (self .key )
190
-
191
- return json .loads (public_keys .dump_jwks ())
165
+ algo_obj = jwt .get_algorithm_by_name ('RS256' )
166
+ public_key = algo_obj .prepare_key (self .key ).public_key ()
167
+ jwk ['keys' ].append (json .loads (algo_obj .to_jwk (public_key )))
168
+ return jwk
192
169
193
170
def validate_and_decode (self , token , iss = None , aud = None ):
194
171
"""
@@ -197,29 +174,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
197
174
Validates a token sent by the tool using the platform's RSA Key.
198
175
Optionally validate iss and aud claims if provided.
199
176
"""
177
+ if not self .key :
178
+ raise exceptions .RsaKeyNotSet ()
200
179
try :
201
- # Verify message signature
202
- message = JWS ().verify_compact (token , keys = [self .key ])
203
-
204
- # If message is valid, check expiration from JWT
205
- if 'exp' in message and message ['exp' ] < time .time ():
206
- raise exceptions .TokenSignatureExpired ()
207
-
208
- # Validate issuer claim (if present)
209
- if iss :
210
- if 'iss' not in message or message ['iss' ] != iss :
211
- raise exceptions .InvalidClaimValue ('The required iss claim is either missing or does '
212
- 'not match the expected iss value.' )
213
-
214
- # Validate audience claim (if present)
215
- if aud :
216
- if 'aud' not in message or aud not in message ['aud' ]:
217
- raise exceptions .InvalidClaimValue ('The required aud claim is missing.' )
218
-
219
- # Else return token contents
180
+ message = jwt .decode (
181
+ token ,
182
+ key = self .key .public_key (),
183
+ audience = aud ,
184
+ issuer = iss ,
185
+ algorithms = ['RS256' , 'RS512' ],
186
+ options = {
187
+ 'verify_signature' : True ,
188
+ 'verify_aud' : True if aud else False
189
+ }
190
+ )
220
191
return message
221
192
222
- except NoSuitableSigningKeys as err :
223
- raise exceptions .NoSuitableKeys () from err
224
- except BadSyntax as err :
225
- raise exceptions .MalformedJwtToken () from err
193
+ except Exception as token_error :
194
+ exc_info = sys .exc_info ()
195
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
0 commit comments