4
4
This handles validating messages sent by the tool and generating
5
5
access token with LTI scopes.
6
6
"""
7
- import codecs
8
7
import copy
9
- import time
10
8
import json
9
+ import math
10
+ import time
11
+ import sys
11
12
13
+ import jwt
12
14
from Cryptodome .PublicKey import RSA
13
- from jwkest import BadSignature , BadSyntax , WrongNumberOfParts , jwk
14
- from jwkest .jwk import RSAKey , load_jwks_from_url
15
- from jwkest .jws import JWS , NoSuitableSigningKeys
16
- from jwkest .jwt import JWT
17
15
18
16
from . import exceptions
19
17
@@ -47,14 +45,9 @@ def __init__(self, public_key=None, keyset_url=None):
47
45
# Import from public key
48
46
if public_key :
49
47
try :
50
- new_key = RSAKey (use = 'sig' )
51
-
52
- # Unescape key before importing it
53
- raw_key = codecs .decode (public_key , 'unicode_escape' )
54
-
55
48
# Import Key and save to internal state
56
- new_key . load_key ( RSA . import_key ( raw_key ) )
57
- self .public_key = new_key
49
+ algo_obj = jwt . get_algorithm_by_name ( 'RS256' )
50
+ self .public_key = algo_obj . prepare_key ( public_key )
58
51
except ValueError as err :
59
52
raise exceptions .InvalidRsaKey () from err
60
53
@@ -69,7 +62,7 @@ def _get_keyset(self, kid=None):
69
62
70
63
if self .keyset_url :
71
64
try :
72
- keys = load_jwks_from_url (self .keyset_url )
65
+ keys = jwt . PyJWKClient (self .keyset_url ). get_jwk_set ( )
73
66
except Exception as err :
74
67
# Broad Exception is required here because jwkest raises
75
68
# an Exception object explicitly.
@@ -78,13 +71,13 @@ def _get_keyset(self, kid=None):
78
71
raise exceptions .NoSuitableKeys () from err
79
72
keyset .extend (keys )
80
73
81
- if self .public_key and kid :
82
- # Fill in key id of stored key.
83
- # This is needed because if the JWS is signed with a
84
- # key with a kid, pyjwkest doesn't match them with
85
- # keys without kid ( kid=None) and fails verification
86
- self . public_key . kid = kid
87
-
74
+ if self .public_key :
75
+ if kid :
76
+ # Fill in key id of stored key.
77
+ # This is needed because if the JWS is signed with a
78
+ # key with a kid, pyjwkest doesn't match them with
79
+ # keys without kid (kid=None) and fails verification
80
+ self . public_key . kid = kid
88
81
# Add to keyset
89
82
keyset .append (self .public_key )
90
83
@@ -100,32 +93,22 @@ def validate_and_decode(self, token):
100
93
iss, sub, exp, aud and jti claims.
101
94
"""
102
95
try :
103
- # Get KID from JWT header
104
- jwt = JWT ().unpack (token )
105
-
106
- # Verify message signature
107
- message = JWS ().verify_compact (
108
- token ,
109
- keys = self ._get_keyset (
110
- jwt .headers .get ('kid' )
111
- )
112
- )
113
-
114
- # If message is valid, check expiration from JWT
115
- if 'exp' in message and message ['exp' ] < time .time ():
116
- raise exceptions .TokenSignatureExpired ()
117
-
118
- # TODO: Validate other JWT claims
119
-
120
- # Else returns decoded message
121
- return message
122
-
123
- except NoSuitableSigningKeys as err :
124
- raise exceptions .NoSuitableKeys () from err
125
- except (BadSyntax , WrongNumberOfParts ) as err :
126
- raise exceptions .MalformedJwtToken () from err
127
- except BadSignature as err :
128
- raise exceptions .BadJwtSignature () from err
96
+ key_set = self ._get_keyset ()
97
+ for i in range (len (key_set )):
98
+ try :
99
+ message = jwt .decode (
100
+ token ,
101
+ key = key_set [i ],
102
+ algorithms = ['RS256' , 'RS512' ,],
103
+ options = {'verify_signature' : True }
104
+ )
105
+ return message
106
+ except Exception :
107
+ if i == len (key_set ) - 1 :
108
+ raise
109
+ except Exception as token_error :
110
+ exc_info = sys .exc_info ()
111
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
129
112
130
113
131
114
class PlatformKeyHandler :
@@ -144,14 +127,17 @@ def __init__(self, key_pem, kid=None):
144
127
if key_pem :
145
128
# Import JWK from RSA key
146
129
try :
147
- self .key = RSAKey (
148
- # Using the same key ID as client id
149
- # This way we can easily serve multiple public
150
- # keys on teh same endpoint and keep all
151
- # LTI 1.3 blocks working
152
- kid = kid ,
153
- key = RSA .import_key (key_pem )
154
- )
130
+ algo = jwt .get_algorithm_by_name ('RS256' )
131
+ self .key = algo .prepare_key (key_pem )
132
+
133
+ # self.key = RSAKey(
134
+ # # Using the same key ID as client id
135
+ # # This way we can easily serve multiple public
136
+ # # keys on teh same endpoint and keep all
137
+ # # LTI 1.3 blocks working
138
+ # kid=kid,
139
+ # key=RSA.import_key(key_pem)
140
+ # )
155
141
except ValueError as err :
156
142
raise exceptions .InvalidRsaKey () from err
157
143
@@ -167,28 +153,26 @@ def encode_and_sign(self, message, expiration=None):
167
153
# Set iat and exp if expiration is set
168
154
if expiration :
169
155
_message .update ({
170
- "iat" : int (round (time .time ())),
156
+ "iat" : int (math . floor (time .time ())),
171
157
"exp" : int (round (time .time ()) + expiration ),
172
158
})
173
159
174
160
# The class instance that sets up the signing operation
175
161
# An RS 256 key is required for LTI 1.3
176
- _jws = JWS (_message , alg = "RS256" , cty = "JWT" )
177
-
178
- # Encode and sign LTI message
179
- return _jws .sign_compact ([self .key ])
162
+ return jwt .encode (_message , self .key , algorithm = "RS256" )
180
163
181
164
def get_public_jwk (self ):
182
165
"""
183
166
Export Public JWK
184
167
"""
185
- public_keys = jwk . KEYS ()
168
+ jwk = { "keys" : []}
186
169
187
170
# Only append to keyset if a key exists
188
171
if self .key :
189
- public_keys .append (self .key )
190
-
191
- return json .loads (public_keys .dump_jwks ())
172
+ algo_obj = jwt .get_algorithm_by_name ('RS256' )
173
+ public_key = algo_obj .prepare_key (self .key ).public_key ()
174
+ jwk ['keys' ].append (json .loads (algo_obj .to_jwk (public_key )))
175
+ return jwk
192
176
193
177
def validate_and_decode (self , token , iss = None , aud = None ):
194
178
"""
@@ -198,28 +182,18 @@ def validate_and_decode(self, token, iss=None, aud=None):
198
182
Optionally validate iss and aud claims if provided.
199
183
"""
200
184
try :
201
- # Verify message signature
202
- message = JWS ().verify_compact (token , keys = [self .key ])
203
-
204
- # If message is valid, check expiration from JWT
205
- if 'exp' in message and message ['exp' ] < time .time ():
206
- raise exceptions .TokenSignatureExpired ()
207
-
208
- # Validate issuer claim (if present)
209
- if iss :
210
- if 'iss' not in message or message ['iss' ] != iss :
211
- raise exceptions .InvalidClaimValue ('The required iss claim is either missing or does '
212
- 'not match the expected iss value.' )
213
-
214
- # Validate audience claim (if present)
215
- if aud :
216
- if 'aud' not in message or aud not in message ['aud' ]:
217
- raise exceptions .InvalidClaimValue ('The required aud claim is missing.' )
218
-
219
- # Else return token contents
185
+ message = jwt .decode (
186
+ token ,
187
+ key = self .key .public_key (),
188
+ audience = aud ,
189
+ algorithms = ['RS256' , 'RS512' ],
190
+ options = {
191
+ 'verify_signature' : True ,
192
+ 'verify_aud' : True if aud else False
193
+ }
194
+ )
220
195
return message
221
196
222
- except NoSuitableSigningKeys as err :
223
- raise exceptions .NoSuitableKeys () from err
224
- except BadSyntax as err :
225
- raise exceptions .MalformedJwtToken () from err
197
+ except Exception as token_error :
198
+ exc_info = sys .exc_info ()
199
+ raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
0 commit comments