Skip to content

Commit b2e8366

Browse files
committed
fix: remove useless tests
1 parent 65ffe29 commit b2e8366

12 files changed

+139
-170
lines changed

lti_consumer/lti_1p3/key_handlers.py

+47-35
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import copy
88
import json
99
import math
10-
import time
1110
import sys
11+
import time
1212
import logging
1313

1414
import jwt
15-
from Cryptodome.PublicKey import RSA
1615
from edx_django_utils.monitoring import function_trace
16+
from jwt.api_jwk import PyJWK
1717

1818
from . import exceptions
1919

@@ -52,7 +52,9 @@ def __init__(self, public_key=None, keyset_url=None):
5252
try:
5353
# Import Key and save to internal state
5454
algo_obj = jwt.get_algorithm_by_name('RS256')
55-
self.public_key = algo_obj.prepare_key(public_key)
55+
public_key = algo_obj.prepare_key(public_key)
56+
public_jwk = json.loads(algo_obj.to_jwk(public_key))
57+
self.public_key = PyJWK.from_dict(public_jwk)
5658
except ValueError as err:
5759
log.warning(
5860
'An error was encountered while loading the LTI tool\'s key from the public key. '
@@ -82,15 +84,16 @@ def _get_keyset(self, kid=None):
8284
'The RSA keys could not be loaded.'
8385
)
8486
raise exceptions.NoSuitableKeys() from err
85-
keyset.extend(keys)
87+
keyset.extend(keys.keys)
88+
89+
if self.public_key and kid:
90+
# Fill in key id of stored key.
91+
# This is needed because if the JWS is signed with a
92+
# key with a kid, pyjwkest doesn't match them with
93+
# keys without kid (kid=None) and fails verification
94+
self.public_key.kid = kid
8695

8796
if self.public_key:
88-
if kid:
89-
# Fill in key id of stored key.
90-
# This is needed because if the JWS is signed with a
91-
# key with a kid, pyjwkest doesn't match them with
92-
# keys without kid (kid=None) and fails verification
93-
self.public_key.kid = kid
9497
# Add to keyset
9598
keyset.append(self.public_key)
9699

@@ -105,25 +108,29 @@ def validate_and_decode(self, token):
105108
The authorization server decodes the JWT and MUST validate the values for the
106109
iss, sub, exp, aud and jti claims.
107110
"""
108-
try:
109-
key_set = self._get_keyset()
110-
if not key_set:
111-
raise exceptions.NoSuitableKeys()
112-
for i in range(len(key_set)):
113-
try:
114-
message = jwt.decode(
115-
token,
116-
key=key_set[i],
117-
algorithms=['RS256', 'RS512',],
118-
options={'verify_signature': True}
119-
)
120-
return message
121-
except Exception:
122-
if i == len(key_set) - 1:
123-
raise
124-
except Exception as token_error:
125-
exc_info = sys.exc_info()
126-
raise jwt.InvalidTokenError(exc_info[2]) from token_error
111+
key_set = self._get_keyset()
112+
113+
for i, obj in enumerate(key_set):
114+
try:
115+
if hasattr(obj.key, 'public_key'):
116+
key = obj.key.public_key()
117+
else:
118+
key = obj.key
119+
message = jwt.decode(
120+
token,
121+
key,
122+
algorithms=['RS256', 'RS512',],
123+
options={
124+
'verify_signature': True,
125+
'verify_aud': False
126+
}
127+
)
128+
return message
129+
except Exception: # pylint: disable=broad-except
130+
if i == len(key_set) - 1:
131+
raise
132+
133+
raise exceptions.NoSuitableKeys()
127134

128135

129136
class PlatformKeyHandler:
@@ -144,7 +151,10 @@ def __init__(self, key_pem, kid=None):
144151
# Import JWK from RSA key
145152
try:
146153
algo = jwt.get_algorithm_by_name('RS256')
147-
self.key = algo.prepare_key(key_pem)
154+
private_key = algo.prepare_key(key_pem)
155+
private_jwk = json.loads(algo.to_jwk(private_key))
156+
private_jwk['kid'] = kid
157+
self.key = PyJWK.from_dict(private_jwk)
148158
except ValueError as err:
149159
log.warning(
150160
'An error was encountered while loading the LTI platform\'s key. '
@@ -175,7 +185,7 @@ def encode_and_sign(self, message, expiration=None):
175185

176186
# The class instance that sets up the signing operation
177187
# An RS 256 key is required for LTI 1.3
178-
return jwt.encode(_message, self.key, algorithm="RS256")
188+
return jwt.encode(_message, self.key.key, algorithm="RS256")
179189

180190
def get_public_jwk(self):
181191
"""
@@ -186,11 +196,11 @@ def get_public_jwk(self):
186196
# Only append to keyset if a key exists
187197
if self.key:
188198
algo_obj = jwt.get_algorithm_by_name('RS256')
189-
public_key = algo_obj.prepare_key(self.key).public_key()
199+
public_key = algo_obj.prepare_key(self.key.key).public_key()
190200
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
191201
return jwk
192202

193-
def validate_and_decode(self, token, iss=None, aud=None):
203+
def validate_and_decode(self, token, iss=None, aud=None, exp=True):
194204
"""
195205
Check if a platform token is valid, and return allowed scopes.
196206
@@ -202,13 +212,15 @@ def validate_and_decode(self, token, iss=None, aud=None):
202212
try:
203213
message = jwt.decode(
204214
token,
205-
key=self.key.public_key(),
215+
key=self.key.key.public_key(),
206216
audience=aud,
207217
issuer=iss,
208218
algorithms=['RS256', 'RS512'],
209219
options={
210220
'verify_signature': True,
211-
'verify_aud': True if aud else False
221+
'verify_exp': bool(exp),
222+
'verify_iss': bool(iss),
223+
'verify_aud': bool(aud)
212224
}
213225
)
214226
return message

lti_consumer/lti_1p3/tests/test_consumer.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import ddt
1010
import jwt
11-
import sys
1211
from Cryptodome.PublicKey import RSA
1312
from django.conf import settings
1413
from django.test.testcases import TestCase
@@ -115,30 +114,26 @@ def _get_lti_message(
115114

116115
def _decode_token(self, token):
117116
"""
118-
Checks for a valid signarute and decodes JWT signed LTI message
117+
Checks for a valid signature and decodes JWT signed LTI message
119118
120119
This also tests the public keyset function.
121120
"""
122121
public_keyset = self.lti_consumer.get_public_keyset()
123122
keyset = PyJWKSet.from_dict(public_keyset).keys
124123

125-
for i in range(len(keyset)):
126-
try:
127-
message = jwt.decode(
128-
token,
129-
key=keyset[i].key,
130-
algorithms=['RS256', 'RS512'],
131-
options={
132-
'verify_signature': True,
133-
'verify_aud': False
134-
}
135-
)
136-
return message
137-
except Exception as token_error:
138-
if i < len(keyset) - 1:
139-
continue
140-
exc_info = sys.exc_info()
141-
raise jwt.InvalidTokenError(exc_info[2]) from token_error
124+
for obj in keyset:
125+
message = jwt.decode(
126+
token,
127+
key=obj.key,
128+
algorithms=['RS256', 'RS512'],
129+
options={
130+
'verify_signature': True,
131+
'verify_aud': False
132+
}
133+
)
134+
return message
135+
136+
return exceptions.NoSuitableKeys()
142137

143138
@ddt.data(
144139
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),

lti_consumer/lti_1p3/tests/test_key_handlers.py

+19-46
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
import json
66
import math
77
import time
8+
from datetime import datetime, timezone
89
from unittest.mock import patch
910

1011
import ddt
1112
import jwt
1213
from Cryptodome.PublicKey import RSA
1314
from django.test.testcases import TestCase
14-
from jwkest import BadSignature
15-
from jwkest.jwk import RSAKey, load_jwks
16-
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm
17-
15+
from jwt.api_jwk import PyJWK
1816

1917
from lti_consumer.lti_1p3 import exceptions
2018
from lti_consumer.lti_1p3.key_handlers import PlatformKeyHandler, ToolKeyHandler
@@ -39,16 +37,13 @@ def setUp(self):
3937
kid=self.rsa_key_id
4038
)
4139

42-
def _decode_token(self, token):
40+
def _decode_token(self, token, exp=True):
4341
"""
44-
Checks for a valid signarute and decodes JWT signed LTI message
42+
Checks for a valid signature and decodes JWT signed LTI message
4543
4644
This also touches the public keyset method.
4745
"""
48-
public_keyset = self.key_handler.get_public_jwk()
49-
key_set = load_jwks(json.dumps(public_keyset))
50-
51-
return JWS().verify_compact(token, keys=key_set)
46+
return self.key_handler.validate_and_decode(token, exp=exp)
5247

5348
def test_encode_and_sign(self):
5449
"""
@@ -59,7 +54,7 @@ def test_encode_and_sign(self):
5954
}
6055
signed_token = self.key_handler.encode_and_sign(message)
6156
self.assertEqual(
62-
self._decode_token(signed_token),
57+
self._decode_token(signed_token, exp=False),
6358
message
6459
)
6560

@@ -72,45 +67,21 @@ def test_encode_and_sign_with_exp(self, mock_time):
7267
message = {
7368
"test": "test"
7469
}
75-
70+
expiration = int(datetime.now(tz=timezone.utc).timestamp())
7671
signed_token = self.key_handler.encode_and_sign(
7772
message,
78-
expiration=1000
73+
expiration=expiration
7974
)
8075

8176
self.assertEqual(
8277
self._decode_token(signed_token),
8378
{
8479
"test": "test",
8580
"iat": 1000,
86-
"exp": 2000
81+
"exp": expiration + 1000
8782
}
8883
)
8984

90-
def test_encode_and_sign_no_suitable_keys(self):
91-
"""
92-
Test if an exception is raised when there are no suitable keys when signing the JWT.
93-
"""
94-
message = {
95-
"test": "test"
96-
}
97-
98-
with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=NoSuitableSigningKeys):
99-
with self.assertRaises(exceptions.NoSuitableKeys):
100-
self.key_handler.encode_and_sign(message)
101-
102-
def test_encode_and_sign_unknown_algorithm(self):
103-
"""
104-
Test if an exception is raised when the signing algorithm is unknown when signing the JWT.
105-
"""
106-
message = {
107-
"test": "test"
108-
}
109-
110-
with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=UnknownAlgorithm):
111-
with self.assertRaises(exceptions.MalformedJwtToken):
112-
self.key_handler.encode_and_sign(message)
113-
11485
def test_invalid_rsa_key(self):
11586
"""
11687
Check that class raises when trying to import invalid RSA Key.
@@ -217,10 +188,14 @@ def setUp(self):
217188
self.rsa_key_id = "1"
218189

219190
# Generate RSA and save exports
220-
rsa_key = RSA.generate(2048).export_key('PEM')
191+
rsa_key = RSA.generate(2048)
221192
algo_obj = jwt.get_algorithm_by_name('RS256')
222-
self.key = algo_obj.prepare_key(rsa_key)
223-
self.public_key = self.key.public_key()
193+
private_key = algo_obj.prepare_key(rsa_key.export_key())
194+
private_jwk = json.loads(algo_obj.to_jwk(private_key))
195+
private_jwk['kid'] = self.rsa_key_id
196+
self.key = PyJWK.from_dict(private_jwk)
197+
198+
self.public_key = rsa_key.publickey().export_key()
224199

225200
# Key handler
226201
self.key_handler = None
@@ -318,12 +293,10 @@ def test_validate_and_decode_no_keys(self):
318293
signed = create_jwt(self.key, message)
319294

320295
# Decode and check results
321-
with self.assertRaises(jwt.InvalidTokenError):
296+
with self.assertRaises(exceptions.NoSuitableKeys):
322297
key_handler.validate_and_decode(signed)
323298

324-
@patch("lti_consumer.lti_1p3.key_handlers.jwt.decode")
325-
def test_validate_and_decode_bad_signature(self, mock_jwt_decode):
326-
mock_jwt_decode.side_effect = Exception()
299+
def test_validate_and_decode_bad_signature(self):
327300
self._setup_key_handler()
328301

329302
message = {
@@ -333,5 +306,5 @@ def test_validate_and_decode_bad_signature(self, mock_jwt_decode):
333306
}
334307
signed = create_jwt(self.key, message)
335308

336-
with self.assertRaises(jwt.InvalidTokenError):
309+
with self.assertRaises(jwt.exceptions.ExpiredSignatureError):
337310
self.key_handler.validate_and_decode(signed)

lti_consumer/lti_1p3/tests/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ def create_jwt(key, message):
99
Uses private key to create a JWS from a dict.
1010
"""
1111
token = jwt.encode(
12-
message, key, algorithm='RS256'
12+
message, key.key, algorithm='RS256'
1313
)
1414
return token

0 commit comments

Comments
 (0)