7
7
import copy
8
8
import json
9
9
import math
10
- import time
11
10
import sys
11
+ import time
12
12
import logging
13
13
14
14
import jwt
15
- from Cryptodome .PublicKey import RSA
16
15
from edx_django_utils .monitoring import function_trace
16
+ from jwt .api_jwk import PyJWK
17
17
18
18
from . import exceptions
19
19
@@ -52,7 +52,9 @@ def __init__(self, public_key=None, keyset_url=None):
52
52
try :
53
53
# Import Key and save to internal state
54
54
algo_obj = jwt .get_algorithm_by_name ('RS256' )
55
- self .public_key = algo_obj .prepare_key (public_key )
55
+ public_key = algo_obj .prepare_key (public_key )
56
+ public_jwk = json .loads (algo_obj .to_jwk (public_key ))
57
+ self .public_key = PyJWK .from_dict (public_jwk )
56
58
except ValueError as err :
57
59
log .warning (
58
60
'An error was encountered while loading the LTI tool\' s key from the public key. '
@@ -82,15 +84,16 @@ def _get_keyset(self, kid=None):
82
84
'The RSA keys could not be loaded.'
83
85
)
84
86
raise exceptions .NoSuitableKeys () from err
85
- keyset .extend (keys )
87
+ keyset .extend (keys .keys )
88
+
89
+ if self .public_key and kid :
90
+ # Fill in key id of stored key.
91
+ # This is needed because if the JWS is signed with a
92
+ # key with a kid, pyjwkest doesn't match them with
93
+ # keys without kid (kid=None) and fails verification
94
+ self .public_key .kid = kid
86
95
87
96
if self .public_key :
88
- if kid :
89
- # Fill in key id of stored key.
90
- # This is needed because if the JWS is signed with a
91
- # key with a kid, pyjwkest doesn't match them with
92
- # keys without kid (kid=None) and fails verification
93
- self .public_key .kid = kid
94
97
# Add to keyset
95
98
keyset .append (self .public_key )
96
99
@@ -105,25 +108,29 @@ def validate_and_decode(self, token):
105
108
The authorization server decodes the JWT and MUST validate the values for the
106
109
iss, sub, exp, aud and jti claims.
107
110
"""
108
- try :
109
- key_set = self ._get_keyset ()
110
- if not key_set :
111
- raise exceptions .NoSuitableKeys ()
112
- for i in range (len (key_set )):
113
- try :
114
- message = jwt .decode (
115
- token ,
116
- key = key_set [i ],
117
- algorithms = ['RS256' , 'RS512' ,],
118
- options = {'verify_signature' : True }
119
- )
120
- return message
121
- except Exception :
122
- if i == len (key_set ) - 1 :
123
- raise
124
- except Exception as token_error :
125
- exc_info = sys .exc_info ()
126
- raise jwt .InvalidTokenError (exc_info [2 ]) from token_error
111
+ key_set = self ._get_keyset ()
112
+
113
+ for i , obj in enumerate (key_set ):
114
+ try :
115
+ if hasattr (obj .key , 'public_key' ):
116
+ key = obj .key .public_key ()
117
+ else :
118
+ key = obj .key
119
+ message = jwt .decode (
120
+ token ,
121
+ key ,
122
+ algorithms = ['RS256' , 'RS512' ,],
123
+ options = {
124
+ 'verify_signature' : True ,
125
+ 'verify_aud' : False
126
+ }
127
+ )
128
+ return message
129
+ except Exception : # pylint: disable=broad-except
130
+ if i == len (key_set ) - 1 :
131
+ raise
132
+
133
+ raise exceptions .NoSuitableKeys ()
127
134
128
135
129
136
class PlatformKeyHandler :
@@ -144,7 +151,10 @@ def __init__(self, key_pem, kid=None):
144
151
# Import JWK from RSA key
145
152
try :
146
153
algo = jwt .get_algorithm_by_name ('RS256' )
147
- self .key = algo .prepare_key (key_pem )
154
+ private_key = algo .prepare_key (key_pem )
155
+ private_jwk = json .loads (algo .to_jwk (private_key ))
156
+ private_jwk ['kid' ] = kid
157
+ self .key = PyJWK .from_dict (private_jwk )
148
158
except ValueError as err :
149
159
log .warning (
150
160
'An error was encountered while loading the LTI platform\' s key. '
@@ -175,7 +185,7 @@ def encode_and_sign(self, message, expiration=None):
175
185
176
186
# The class instance that sets up the signing operation
177
187
# An RS 256 key is required for LTI 1.3
178
- return jwt .encode (_message , self .key , algorithm = "RS256" )
188
+ return jwt .encode (_message , self .key . key , algorithm = "RS256" )
179
189
180
190
def get_public_jwk (self ):
181
191
"""
@@ -186,11 +196,11 @@ def get_public_jwk(self):
186
196
# Only append to keyset if a key exists
187
197
if self .key :
188
198
algo_obj = jwt .get_algorithm_by_name ('RS256' )
189
- public_key = algo_obj .prepare_key (self .key ).public_key ()
199
+ public_key = algo_obj .prepare_key (self .key . key ).public_key ()
190
200
jwk ['keys' ].append (json .loads (algo_obj .to_jwk (public_key )))
191
201
return jwk
192
202
193
- def validate_and_decode (self , token , iss = None , aud = None ):
203
+ def validate_and_decode (self , token , iss = None , aud = None , exp = True ):
194
204
"""
195
205
Check if a platform token is valid, and return allowed scopes.
196
206
@@ -202,13 +212,15 @@ def validate_and_decode(self, token, iss=None, aud=None):
202
212
try :
203
213
message = jwt .decode (
204
214
token ,
205
- key = self .key .public_key (),
215
+ key = self .key .key . public_key (),
206
216
audience = aud ,
207
217
issuer = iss ,
208
218
algorithms = ['RS256' , 'RS512' ],
209
219
options = {
210
220
'verify_signature' : True ,
211
- 'verify_aud' : True if aud else False
221
+ 'verify_exp' : bool (exp ),
222
+ 'verify_iss' : bool (iss ),
223
+ 'verify_aud' : bool (aud )
212
224
}
213
225
)
214
226
return message
0 commit comments