Skip to content

Commit 8039516

Browse files
committed
chore: add pyjwt requirement
1 parent 1f1cc25 commit 8039516

File tree

14 files changed

+174
-174
lines changed

14 files changed

+174
-174
lines changed

lti_consumer/lti_1p3/key_handlers.py

+58-88
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
This handles validating messages sent by the tool and generating
55
access token with LTI scopes.
66
"""
7-
import codecs
87
import copy
9-
import time
108
import json
9+
import math
10+
import time
11+
import sys
1112

13+
import jwt
1214
from Cryptodome.PublicKey import RSA
13-
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
14-
from jwkest.jwk import RSAKey, load_jwks_from_url
15-
from jwkest.jws import JWS, NoSuitableSigningKeys
16-
from jwkest.jwt import JWT
1715

1816
from . import exceptions
1917

@@ -47,14 +45,9 @@ def __init__(self, public_key=None, keyset_url=None):
4745
# Import from public key
4846
if public_key:
4947
try:
50-
new_key = RSAKey(use='sig')
51-
52-
# Unescape key before importing it
53-
raw_key = codecs.decode(public_key, 'unicode_escape')
54-
5548
# Import Key and save to internal state
56-
new_key.load_key(RSA.import_key(raw_key))
57-
self.public_key = new_key
49+
algo_obj = jwt.get_algorithm_by_name('RS256')
50+
self.public_key = algo_obj.prepare_key(public_key)
5851
except ValueError as err:
5952
raise exceptions.InvalidRsaKey() from err
6053

@@ -69,7 +62,7 @@ def _get_keyset(self, kid=None):
6962

7063
if self.keyset_url:
7164
try:
72-
keys = load_jwks_from_url(self.keyset_url)
65+
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
7366
except Exception as err:
7467
# Broad Exception is required here because jwkest raises
7568
# an Exception object explicitly.
@@ -78,13 +71,13 @@ def _get_keyset(self, kid=None):
7871
raise exceptions.NoSuitableKeys() from err
7972
keyset.extend(keys)
8073

81-
if self.public_key and kid:
82-
# Fill in key id of stored key.
83-
# This is needed because if the JWS is signed with a
84-
# key with a kid, pyjwkest doesn't match them with
85-
# keys without kid (kid=None) and fails verification
86-
self.public_key.kid = kid
87-
74+
if self.public_key:
75+
if kid:
76+
# Fill in key id of stored key.
77+
# This is needed because if the JWS is signed with a
78+
# key with a kid, pyjwkest doesn't match them with
79+
# keys without kid (kid=None) and fails verification
80+
self.public_key.kid = kid
8881
# Add to keyset
8982
keyset.append(self.public_key)
9083

@@ -100,32 +93,24 @@ def validate_and_decode(self, token):
10093
iss, sub, exp, aud and jti claims.
10194
"""
10295
try:
103-
# Get KID from JWT header
104-
jwt = JWT().unpack(token)
105-
106-
# Verify message signature
107-
message = JWS().verify_compact(
108-
token,
109-
keys=self._get_keyset(
110-
jwt.headers.get('kid')
111-
)
112-
)
113-
114-
# If message is valid, check expiration from JWT
115-
if 'exp' in message and message['exp'] < time.time():
116-
raise exceptions.TokenSignatureExpired()
117-
118-
# TODO: Validate other JWT claims
119-
120-
# Else returns decoded message
121-
return message
122-
123-
except NoSuitableSigningKeys as err:
124-
raise exceptions.NoSuitableKeys() from err
125-
except (BadSyntax, WrongNumberOfParts) as err:
126-
raise exceptions.MalformedJwtToken() from err
127-
except BadSignature as err:
128-
raise exceptions.BadJwtSignature() from err
96+
key_set = self._get_keyset()
97+
if not key_set:
98+
raise exceptions.NoSuitableKeys()
99+
for i in range(len(key_set)):
100+
try:
101+
message = jwt.decode(
102+
token,
103+
key=key_set[i],
104+
algorithms=['RS256', 'RS512',],
105+
options={'verify_signature': True}
106+
)
107+
return message
108+
except Exception:
109+
if i == len(key_set) - 1:
110+
raise
111+
except Exception as token_error:
112+
exc_info = sys.exc_info()
113+
raise jwt.InvalidTokenError(exc_info[2]) from token_error
129114

130115

131116
class PlatformKeyHandler:
@@ -144,14 +129,8 @@ def __init__(self, key_pem, kid=None):
144129
if key_pem:
145130
# Import JWK from RSA key
146131
try:
147-
self.key = RSAKey(
148-
# Using the same key ID as client id
149-
# This way we can easily serve multiple public
150-
# keys on teh same endpoint and keep all
151-
# LTI 1.3 blocks working
152-
kid=kid,
153-
key=RSA.import_key(key_pem)
154-
)
132+
algo = jwt.get_algorithm_by_name('RS256')
133+
self.key = algo.prepare_key(key_pem)
155134
except ValueError as err:
156135
raise exceptions.InvalidRsaKey() from err
157136

@@ -167,28 +146,26 @@ def encode_and_sign(self, message, expiration=None):
167146
# Set iat and exp if expiration is set
168147
if expiration:
169148
_message.update({
170-
"iat": int(round(time.time())),
171-
"exp": int(round(time.time()) + expiration),
149+
"iat": int(math.floor(time.time())),
150+
"exp": int(math.floor(time.time()) + expiration),
172151
})
173152

174153
# The class instance that sets up the signing operation
175154
# An RS 256 key is required for LTI 1.3
176-
_jws = JWS(_message, alg="RS256", cty="JWT")
177-
178-
# Encode and sign LTI message
179-
return _jws.sign_compact([self.key])
155+
return jwt.encode(_message, self.key, algorithm="RS256")
180156

181157
def get_public_jwk(self):
182158
"""
183159
Export Public JWK
184160
"""
185-
public_keys = jwk.KEYS()
161+
jwk = {"keys": []}
186162

187163
# Only append to keyset if a key exists
188164
if self.key:
189-
public_keys.append(self.key)
190-
191-
return json.loads(public_keys.dump_jwks())
165+
algo_obj = jwt.get_algorithm_by_name('RS256')
166+
public_key = algo_obj.prepare_key(self.key).public_key()
167+
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
168+
return jwk
192169

193170
def validate_and_decode(self, token, iss=None, aud=None):
194171
"""
@@ -197,29 +174,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
197174
Validates a token sent by the tool using the platform's RSA Key.
198175
Optionally validate iss and aud claims if provided.
199176
"""
177+
if not self.key:
178+
raise exceptions.RsaKeyNotSet()
200179
try:
201-
# Verify message signature
202-
message = JWS().verify_compact(token, keys=[self.key])
203-
204-
# If message is valid, check expiration from JWT
205-
if 'exp' in message and message['exp'] < time.time():
206-
raise exceptions.TokenSignatureExpired()
207-
208-
# Validate issuer claim (if present)
209-
if iss:
210-
if 'iss' not in message or message['iss'] != iss:
211-
raise exceptions.InvalidClaimValue('The required iss claim is either missing or does '
212-
'not match the expected iss value.')
213-
214-
# Validate audience claim (if present)
215-
if aud:
216-
if 'aud' not in message or aud not in message['aud']:
217-
raise exceptions.InvalidClaimValue('The required aud claim is missing.')
218-
219-
# Else return token contents
180+
message = jwt.decode(
181+
token,
182+
key=self.key.public_key(),
183+
audience=aud,
184+
issuer=iss,
185+
algorithms=['RS256', 'RS512'],
186+
options={
187+
'verify_signature': True,
188+
'verify_aud': True if aud else False
189+
}
190+
)
220191
return message
221192

222-
except NoSuitableSigningKeys as err:
223-
raise exceptions.NoSuitableKeys() from err
224-
except BadSyntax as err:
225-
raise exceptions.MalformedJwtToken() from err
193+
except Exception as token_error:
194+
exc_info = sys.exc_info()
195+
raise jwt.InvalidTokenError(exc_info[2]) from token_error

lti_consumer/lti_1p3/tests/test_consumer.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
Unit tests for LTI 1.3 consumer implementation
33
"""
44

5-
import json
65
from unittest.mock import patch
76
from urllib.parse import parse_qs, urlparse
87

98
import ddt
9+
import jwt
10+
import sys
1011
from Cryptodome.PublicKey import RSA
1112
from django.test.testcases import TestCase
1213
from edx_django_utils.cache import get_cache_key, TieredCache
13-
from jwkest.jwk import load_jwks
14-
from jwkest.jws import JWS
14+
from jwt.api_jwk import PyJWKSet
1515

1616
from lti_consumer.data import Lti1p3LaunchData
1717
from lti_consumer.lti_1p3 import exceptions
@@ -34,7 +34,9 @@
3434
STATE = "ABCD"
3535
# Consider storing a fixed key
3636
RSA_KEY_ID = "1"
37-
RSA_KEY = RSA.generate(2048).export_key('PEM')
37+
RSA_KEY = RSA.generate(2048)
38+
RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM')
39+
RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM')
3840

3941

4042
# Test classes
@@ -53,11 +55,11 @@ def setUp(self):
5355
lti_launch_url=LAUNCH_URL,
5456
client_id=CLIENT_ID,
5557
deployment_id=DEPLOYMENT_ID,
56-
rsa_key=RSA_KEY,
58+
rsa_key=RSA_PRIVATE_KEY,
5759
rsa_key_id=RSA_KEY_ID,
5860
redirect_uris=REDIRECT_URIS,
5961
# Use the same key for testing purposes
60-
tool_key=RSA_KEY
62+
tool_key=RSA_PUBLIC_KEY
6163
)
6264

6365
def _setup_lti_launch_data(self):
@@ -102,9 +104,25 @@ def _decode_token(self, token):
102104
This also tests the public keyset function.
103105
"""
104106
public_keyset = self.lti_consumer.get_public_keyset()
105-
key_set = load_jwks(json.dumps(public_keyset))
106-
107-
return JWS().verify_compact(token, keys=key_set)
107+
keyset = PyJWKSet.from_dict(public_keyset).keys
108+
109+
for i in range(len(keyset)):
110+
try:
111+
message = jwt.decode(
112+
token,
113+
key=keyset[i].key,
114+
algorithms=['RS256', 'RS512'],
115+
options={
116+
'verify_signature': True,
117+
'verify_aud': False
118+
}
119+
)
120+
return message
121+
except Exception as token_error:
122+
if i < len(keyset) - 1:
123+
continue
124+
exc_info = sys.exc_info()
125+
raise jwt.InvalidTokenError(exc_info[2]) from token_error
108126

109127
@ddt.data(
110128
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
@@ -526,7 +544,7 @@ def test_access_token_invalid_jwt(self):
526544
"scope": "",
527545
}
528546

529-
with self.assertRaises(exceptions.MalformedJwtToken):
547+
with self.assertRaises(jwt.exceptions.InvalidTokenError):
530548
self.lti_consumer.access_token(request_data)
531549

532550
def test_access_token(self):
@@ -641,11 +659,11 @@ def setUp(self):
641659
lti_launch_url=LAUNCH_URL,
642660
client_id=CLIENT_ID,
643661
deployment_id=DEPLOYMENT_ID,
644-
rsa_key=RSA_KEY,
662+
rsa_key=RSA_PRIVATE_KEY,
645663
rsa_key_id=RSA_KEY_ID,
646664
redirect_uris=REDIRECT_URIS,
647665
# Use the same key for testing purposes
648-
tool_key=RSA_KEY
666+
tool_key=RSA_PUBLIC_KEY
649667
)
650668

651669
self.preflight_response = {}
@@ -884,11 +902,11 @@ def setUp(self):
884902
lti_launch_url=LAUNCH_URL,
885903
client_id=CLIENT_ID,
886904
deployment_id=DEPLOYMENT_ID,
887-
rsa_key=RSA_KEY,
905+
rsa_key=RSA_PRIVATE_KEY,
888906
rsa_key_id=RSA_KEY_ID,
889907
redirect_uris=REDIRECT_URIS,
890908
# Use the same key for testing purposes
891-
tool_key=RSA_KEY
909+
tool_key=RSA_PUBLIC_KEY
892910
)
893911

894912
self.preflight_response = {}

0 commit comments

Comments
 (0)