Skip to content

Commit e33d723

Browse files
committed
chore: add pyjwt requirement
1 parent c5d9a5f commit e33d723

File tree

14 files changed

+173
-223
lines changed

14 files changed

+173
-223
lines changed

lti_consumer/lti_1p3/key_handlers.py

+58-137
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
This handles validating messages sent by the tool and generating
55
access token with LTI scopes.
66
"""
7-
import codecs
87
import copy
9-
import time
108
import json
9+
import math
10+
import time
11+
import sys
1112
import logging
1213

14+
import jwt
1315
from Cryptodome.PublicKey import RSA
14-
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
15-
from jwkest.jwk import RSAKey, load_jwks_from_url
16-
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm
17-
from jwkest.jwt import JWT
1816

1917
from . import exceptions
2018

@@ -50,14 +48,9 @@ def __init__(self, public_key=None, keyset_url=None):
5048
# Import from public key
5149
if public_key:
5250
try:
53-
new_key = RSAKey(use='sig')
54-
55-
# Unescape key before importing it
56-
raw_key = codecs.decode(public_key, 'unicode_escape')
57-
5851
# Import Key and save to internal state
59-
new_key.load_key(RSA.import_key(raw_key))
60-
self.public_key = new_key
52+
algo_obj = jwt.get_algorithm_by_name('RS256')
53+
self.public_key = algo_obj.prepare_key(public_key)
6154
except ValueError as err:
6255
log.warning(
6356
'An error was encountered while loading the LTI tool\'s key from the public key. '
@@ -76,7 +69,7 @@ def _get_keyset(self, kid=None):
7669

7770
if self.keyset_url:
7871
try:
79-
keys = load_jwks_from_url(self.keyset_url)
72+
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
8073
except Exception as err:
8174
# Broad Exception is required here because jwkest raises
8275
# an Exception object explicitly.
@@ -89,13 +82,13 @@ def _get_keyset(self, kid=None):
8982
raise exceptions.NoSuitableKeys() from err
9083
keyset.extend(keys)
9184

92-
if self.public_key and kid:
93-
# Fill in key id of stored key.
94-
# This is needed because if the JWS is signed with a
95-
# key with a kid, pyjwkest doesn't match them with
96-
# keys without kid (kid=None) and fails verification
97-
self.public_key.kid = kid
98-
85+
if self.public_key:
86+
if kid:
87+
# Fill in key id of stored key.
88+
# This is needed because if the JWS is signed with a
89+
# key with a kid, pyjwkest doesn't match them with
90+
# keys without kid (kid=None) and fails verification
91+
self.public_key.kid = kid
9992
# Add to keyset
10093
keyset.append(self.public_key)
10194

@@ -111,48 +104,24 @@ def validate_and_decode(self, token):
111104
iss, sub, exp, aud and jti claims.
112105
"""
113106
try:
114-
# Get KID from JWT header
115-
jwt = JWT().unpack(token)
116-
117-
# Verify message signature
118-
message = JWS().verify_compact(
119-
token,
120-
keys=self._get_keyset(
121-
jwt.headers.get('kid')
122-
)
123-
)
124-
125-
# If message is valid, check expiration from JWT
126-
if 'exp' in message and message['exp'] < time.time():
127-
log.warning(
128-
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
129-
'The JWT has expired.'
130-
)
131-
raise exceptions.TokenSignatureExpired()
132-
133-
# TODO: Validate other JWT claims
134-
135-
# Else returns decoded message
136-
return message
137-
138-
except NoSuitableSigningKeys as err:
139-
log.warning(
140-
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
141-
'There is no suitable signing key.'
142-
)
143-
raise exceptions.NoSuitableKeys() from err
144-
except (BadSyntax, WrongNumberOfParts) as err:
145-
log.warning(
146-
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
147-
'The JWT is malformed.'
148-
)
149-
raise exceptions.MalformedJwtToken() from err
150-
except BadSignature as err:
151-
log.warning(
152-
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
153-
'The JWT signature is incorrect.'
154-
)
155-
raise exceptions.BadJwtSignature() from err
107+
key_set = self._get_keyset()
108+
if not key_set:
109+
raise exceptions.NoSuitableKeys()
110+
for i in range(len(key_set)):
111+
try:
112+
message = jwt.decode(
113+
token,
114+
key=key_set[i],
115+
algorithms=['RS256', 'RS512',],
116+
options={'verify_signature': True}
117+
)
118+
return message
119+
except Exception:
120+
if i == len(key_set) - 1:
121+
raise
122+
except Exception as token_error:
123+
exc_info = sys.exc_info()
124+
raise jwt.InvalidTokenError(exc_info[2]) from token_error
156125

157126

158127
class PlatformKeyHandler:
@@ -171,14 +140,8 @@ def __init__(self, key_pem, kid=None):
171140
if key_pem:
172141
# Import JWK from RSA key
173142
try:
174-
self.key = RSAKey(
175-
# Using the same key ID as client id
176-
# This way we can easily serve multiple public
177-
# keys on the same endpoint and keep all
178-
# LTI 1.3 blocks working
179-
kid=kid,
180-
key=RSA.import_key(key_pem)
181-
)
143+
algo = jwt.get_algorithm_by_name('RS256')
144+
self.key = algo.prepare_key(key_pem)
182145
except ValueError as err:
183146
log.warning(
184147
'An error was encountered while loading the LTI platform\'s key. '
@@ -203,41 +166,26 @@ def encode_and_sign(self, message, expiration=None):
203166
# Set iat and exp if expiration is set
204167
if expiration:
205168
_message.update({
206-
"iat": int(round(time.time())),
207-
"exp": int(round(time.time()) + expiration),
169+
"iat": int(math.floor(time.time())),
170+
"exp": int(math.floor(time.time()) + expiration),
208171
})
209172

210173
# The class instance that sets up the signing operation
211174
# An RS 256 key is required for LTI 1.3
212-
_jws = JWS(_message, alg="RS256", cty="JWT")
213-
214-
try:
215-
# Encode and sign LTI message
216-
return _jws.sign_compact([self.key])
217-
except NoSuitableSigningKeys as err:
218-
log.warning(
219-
'An error was encountered while signing the OAuth 2.0 access token JWT. '
220-
'There is no suitable signing key.'
221-
)
222-
raise exceptions.NoSuitableKeys() from err
223-
except UnknownAlgorithm as err:
224-
log.warning(
225-
'An error was encountered while signing the OAuth 2.0 access token JWT. '
226-
'There algorithm is unknown.'
227-
)
228-
raise exceptions.MalformedJwtToken() from err
175+
return jwt.encode(_message, self.key, algorithm="RS256")
229176

230177
def get_public_jwk(self):
231178
"""
232179
Export Public JWK
233180
"""
234-
public_keys = jwk.KEYS()
181+
jwk = {"keys": []}
235182

236183
# Only append to keyset if a key exists
237184
if self.key:
238-
public_keys.append(self.key)
239-
240-
return json.loads(public_keys.dump_jwks())
185+
algo_obj = jwt.get_algorithm_by_name('RS256')
186+
public_key = algo_obj.prepare_key(self.key).public_key()
187+
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
188+
return jwk
241189

242190
def validate_and_decode(self, token, iss=None, aud=None):
243191
"""
@@ -246,49 +194,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
246194
Validates a token sent by the tool using the platform's RSA Key.
247195
Optionally validate iss and aud claims if provided.
248196
"""
197+
if not self.key:
198+
raise exceptions.RsaKeyNotSet()
249199
try:
250-
# Verify message signature
251-
message = JWS().verify_compact(token, keys=[self.key])
252-
253-
# If message is valid, check expiration from JWT
254-
if 'exp' in message and message['exp'] < time.time():
255-
log.warning(
256-
'An error was encountered while verifying the OAuth 2.0 access token. '
257-
'The JWT has expired.'
258-
)
259-
raise exceptions.TokenSignatureExpired()
260-
261-
# Validate issuer claim (if present)
262-
log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. '
263-
if iss:
264-
if 'iss' not in message or message['iss'] != iss:
265-
error_message = 'The required iss claim is missing or does not match the expected iss value. '
266-
log_message = log_message_base + error_message
267-
268-
log.warning(log_message)
269-
raise exceptions.InvalidClaimValue(error_message)
270-
271-
# Validate audience claim (if present)
272-
if aud:
273-
if 'aud' not in message or aud not in message['aud']:
274-
error_message = 'The required aud claim is missing.'
275-
log_message = log_message_base + error_message
276-
277-
log.warning(log_message)
278-
raise exceptions.InvalidClaimValue(error_message)
279-
280-
# Else return token contents
200+
message = jwt.decode(
201+
token,
202+
key=self.key.public_key(),
203+
audience=aud,
204+
issuer=iss,
205+
algorithms=['RS256', 'RS512'],
206+
options={
207+
'verify_signature': True,
208+
'verify_aud': True if aud else False
209+
}
210+
)
281211
return message
282212

283-
except NoSuitableSigningKeys as err:
284-
log.warning(
285-
'An error was encountered while verifying the OAuth 2.0 access token. '
286-
'There is no suitable signing key.'
287-
)
288-
raise exceptions.NoSuitableKeys() from err
289-
except BadSyntax as err:
290-
log.warning(
291-
'An error was encountered while verifying the OAuth 2.0 access token. '
292-
'The JWT is malformed.'
293-
)
294-
raise exceptions.MalformedJwtToken() from err
213+
except Exception as token_error:
214+
exc_info = sys.exc_info()
215+
raise jwt.InvalidTokenError(exc_info[2]) from token_error

lti_consumer/lti_1p3/tests/test_consumer.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
Unit tests for LTI 1.3 consumer implementation
33
"""
44

5-
import json
65
from unittest.mock import patch
76
from urllib.parse import parse_qs, urlparse
87
import uuid
98

109
import ddt
10+
import jwt
11+
import sys
1112
from Cryptodome.PublicKey import RSA
1213
from django.conf import settings
1314
from django.test.testcases import TestCase
1415
from edx_django_utils.cache import get_cache_key, TieredCache
15-
from jwkest.jwk import load_jwks
16-
from jwkest.jws import JWS
16+
from jwt.api_jwk import PyJWKSet
1717

1818
from lti_consumer.data import Lti1p3LaunchData
1919
from lti_consumer.lti_1p3 import exceptions
@@ -36,7 +36,9 @@
3636
STATE = "ABCD"
3737
# Consider storing a fixed key
3838
RSA_KEY_ID = "1"
39-
RSA_KEY = RSA.generate(2048).export_key('PEM')
39+
RSA_KEY = RSA.generate(2048)
40+
RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM')
41+
RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM')
4042

4143

4244
def _generate_token_request_data(token, scope):
@@ -69,11 +71,11 @@ def setUp(self):
6971
lti_launch_url=LAUNCH_URL,
7072
client_id=CLIENT_ID,
7173
deployment_id=DEPLOYMENT_ID,
72-
rsa_key=RSA_KEY,
74+
rsa_key=RSA_PRIVATE_KEY,
7375
rsa_key_id=RSA_KEY_ID,
7476
redirect_uris=REDIRECT_URIS,
7577
# Use the same key for testing purposes
76-
tool_key=RSA_KEY
78+
tool_key=RSA_PUBLIC_KEY
7779
)
7880

7981
def _setup_lti_launch_data(self):
@@ -118,9 +120,25 @@ def _decode_token(self, token):
118120
This also tests the public keyset function.
119121
"""
120122
public_keyset = self.lti_consumer.get_public_keyset()
121-
key_set = load_jwks(json.dumps(public_keyset))
122-
123-
return JWS().verify_compact(token, keys=key_set)
123+
keyset = PyJWKSet.from_dict(public_keyset).keys
124+
125+
for i in range(len(keyset)):
126+
try:
127+
message = jwt.decode(
128+
token,
129+
key=keyset[i].key,
130+
algorithms=['RS256', 'RS512'],
131+
options={
132+
'verify_signature': True,
133+
'verify_aud': False
134+
}
135+
)
136+
return message
137+
except Exception as token_error:
138+
if i < len(keyset) - 1:
139+
continue
140+
exc_info = sys.exc_info()
141+
raise jwt.InvalidTokenError(exc_info[2]) from token_error
124142

125143
@ddt.data(
126144
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
@@ -558,7 +576,7 @@ def test_access_token_invalid_jwt(self):
558576
"""
559577
request_data = _generate_token_request_data("invalid_jwt", "")
560578

561-
with self.assertRaises(exceptions.MalformedJwtToken):
579+
with self.assertRaises(jwt.exceptions.InvalidTokenError):
562580
self.lti_consumer.access_token(request_data)
563581

564582
def test_access_token_no_acs(self):
@@ -686,11 +704,11 @@ def setUp(self):
686704
lti_launch_url=LAUNCH_URL,
687705
client_id=CLIENT_ID,
688706
deployment_id=DEPLOYMENT_ID,
689-
rsa_key=RSA_KEY,
707+
rsa_key=RSA_PRIVATE_KEY,
690708
rsa_key_id=RSA_KEY_ID,
691709
redirect_uris=REDIRECT_URIS,
692710
# Use the same key for testing purposes
693-
tool_key=RSA_KEY
711+
tool_key=RSA_PUBLIC_KEY
694712
)
695713

696714
self.preflight_response = {}
@@ -930,11 +948,11 @@ def setUp(self):
930948
lti_launch_url=LAUNCH_URL,
931949
client_id=CLIENT_ID,
932950
deployment_id=DEPLOYMENT_ID,
933-
rsa_key=RSA_KEY,
951+
rsa_key=RSA_PRIVATE_KEY,
934952
rsa_key_id=RSA_KEY_ID,
935953
redirect_uris=REDIRECT_URIS,
936954
# Use the same key for testing purposes
937-
tool_key=RSA_KEY
955+
tool_key=RSA_PUBLIC_KEY
938956
)
939957

940958
self.preflight_response = {}

0 commit comments

Comments
 (0)