Skip to content

Commit b97de01

Browse files
authored
fix(jwt): Relax typing to allow Sequence for Token.aud (#4241)
1 parent b105cff commit b97de01

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

litestar/security/jwt/token.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
from collections.abc import Sequence # noqa: TC003
45
from dataclasses import asdict, dataclass, field
56
from datetime import datetime, timezone
67
from typing import TYPE_CHECKING, Any, TypedDict
@@ -11,8 +12,6 @@
1112
from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException
1213

1314
if TYPE_CHECKING:
14-
from collections.abc import Sequence
15-
1615
from typing_extensions import Self
1716

1817
__all__ = (
@@ -59,8 +58,8 @@ class Token:
5958
"""Issued at - should always be current now."""
6059
iss: str | None = field(default=None)
6160
"""Issuer - optional unique identifier for the issuer."""
62-
aud: str | None = field(default=None)
63-
"""Audience - intended audience."""
61+
aud: str | Sequence[str] | None = field(default=None)
62+
"""Audience - intended audience(s)."""
6463
jti: str | None = field(default=None)
6564
"""JWT ID - a unique identifier of the JWT between different issuers."""
6665
extras: dict[str, Any] = field(default_factory=dict)

tests/unit/test_security/test_jwt/test_token.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,34 @@ def test_strict_aud_with_one_element_sequence(audience: str | list[str]) -> None
199199
)
200200

201201

202+
@pytest.mark.parametrize(
203+
"audience",
204+
[
205+
pytest.param(None, id="None"),
206+
pytest.param("foo", id="String"),
207+
pytest.param("not-foo", id="InvalidAudience"),
208+
pytest.param(["foo", "bar"], id="List"),
209+
],
210+
)
211+
def test_validate_audience(audience: str | list[str]) -> None:
212+
secret = secrets.token_hex()
213+
encoded = Token(exp=datetime.now() + timedelta(days=1), sub="foo", aud=["foo", "bar"]).encode(secret, "HS256")
214+
215+
def decode() -> None:
216+
Token.decode(
217+
encoded,
218+
secret=secret,
219+
algorithm="HS256",
220+
audience=audience,
221+
)
222+
223+
if audience != "not-foo":
224+
decode()
225+
else:
226+
with pytest.raises(NotAuthorizedException):
227+
decode()
228+
229+
202230
def test_custom_decode_payload() -> None:
203231
@dataclasses.dataclass
204232
class CustomToken(Token):

0 commit comments

Comments
 (0)