Skip to content

Commit 8e15abc

Browse files
committed
Add test for auth context middleware
1 parent 5230180 commit 8e15abc

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Tests for the AuthContext middleware components.
3+
"""
4+
5+
import time
6+
7+
import pytest
8+
from starlette.types import Message, Receive, Scope, Send
9+
10+
from mcp.server.auth.middleware.auth_context import (
11+
AuthContextMiddleware,
12+
auth_context_var,
13+
get_access_token,
14+
)
15+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
16+
from mcp.server.auth.provider import AccessToken
17+
18+
19+
class MockApp:
20+
"""Mock ASGI app for testing."""
21+
22+
def __init__(self):
23+
self.called = False
24+
self.scope: Scope | None = None
25+
self.receive: Receive | None = None
26+
self.send: Send | None = None
27+
self.access_token_during_call: AccessToken | None = None
28+
29+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
30+
self.called = True
31+
self.scope = scope
32+
self.receive = receive
33+
self.send = send
34+
# Check the context during the call
35+
self.access_token_during_call = get_access_token()
36+
37+
38+
@pytest.fixture
39+
def valid_access_token() -> AccessToken:
40+
"""Create a valid access token."""
41+
return AccessToken(
42+
token="valid_token",
43+
client_id="test_client",
44+
scopes=["read", "write"],
45+
expires_at=int(time.time()) + 3600, # 1 hour from now
46+
)
47+
48+
49+
@pytest.mark.anyio
50+
class TestAuthContextMiddleware:
51+
"""Tests for the AuthContextMiddleware class."""
52+
53+
async def test_with_authenticated_user(self, valid_access_token: AccessToken):
54+
"""Test middleware with an authenticated user in scope."""
55+
app = MockApp()
56+
middleware = AuthContextMiddleware(app)
57+
58+
# Create an authenticated user
59+
user = AuthenticatedUser(valid_access_token)
60+
61+
scope: Scope = {"type": "http", "user": user}
62+
63+
# Create dummy async functions for receive and send
64+
async def receive() -> Message:
65+
return {"type": "http.request"}
66+
67+
async def send(message: Message) -> None:
68+
pass
69+
70+
# Verify context is empty before middleware
71+
assert auth_context_var.get() is None
72+
assert get_access_token() is None
73+
74+
# Run the middleware
75+
await middleware(scope, receive, send)
76+
77+
# Verify the app was called
78+
assert app.called
79+
assert app.scope == scope
80+
assert app.receive == receive
81+
assert app.send == send
82+
83+
# Verify the access token was available during the call
84+
assert app.access_token_during_call == valid_access_token
85+
86+
# Verify context is reset after middleware
87+
assert auth_context_var.get() is None
88+
assert get_access_token() is None
89+
90+
async def test_with_no_user(self):
91+
"""Test middleware with no user in scope."""
92+
app = MockApp()
93+
middleware = AuthContextMiddleware(app)
94+
95+
scope: Scope = {"type": "http"} # No user
96+
97+
# Create dummy async functions for receive and send
98+
async def receive() -> Message:
99+
return {"type": "http.request"}
100+
101+
async def send(message: Message) -> None:
102+
pass
103+
104+
# Verify context is empty before middleware
105+
assert auth_context_var.get() is None
106+
assert get_access_token() is None
107+
108+
# Run the middleware
109+
await middleware(scope, receive, send)
110+
111+
# Verify the app was called
112+
assert app.called
113+
assert app.scope == scope
114+
assert app.receive == receive
115+
assert app.send == send
116+
117+
# Verify the access token was not available during the call
118+
assert app.access_token_during_call is None
119+
120+
# Verify context is still empty after middleware
121+
assert auth_context_var.get() is None
122+
assert get_access_token() is None

0 commit comments

Comments
 (0)