diff --git a/README.md b/README.md index 766224ab..ecc52cea 100644 --- a/README.md +++ b/README.md @@ -298,6 +298,63 @@ Runtime server logs are emitted by FastMCP/Uvicorn. +
+🔐 Remote Deployment with OAuth + +When deploying the server remotely (e.g. on Cloud Run, Fly.io, Railway), +enable OAuth 2.1 to protect the MCP endpoint. + +**Quick Start:** + +```bash +docker run --rm -i \ + -v ${HOME}/.linkedin-mcp:/home/pwuser/.linkedin-mcp \ + -e TRANSPORT=streamable-http \ + -e HOST=0.0.0.0 \ + -e AUTH=oauth \ + -e OAUTH_BASE_URL=https://your-server.example.com \ + -e OAUTH_PASSWORD=your-secret-password \ + -p 8000:8000 \ + stickerdaniel/linkedin-mcp-server +``` + +**Adding as a Claude.ai Custom Connector:** + +1. Deploy the server with OAuth enabled +2. In claude.ai, go to **Settings → Connectors → Add custom connector** +3. Enter the **full MCP endpoint URL** including `/mcp`: + `https://your-server.example.com/mcp` + > **Important:** Use the `/mcp` path, not the base URL — claude.ai will return "no tools" if you omit it. +4. Claude.ai will discover the OAuth endpoints automatically +5. You'll be redirected to the login page — enter your `OAUTH_PASSWORD` +6. The connection is now authenticated + +**Retrieving the OAuth password (if stored in GCP Secret Manager):** + +```bash +gcloud secrets versions access latest --secret=linkedin-mcp-oauth-password --project=YOUR_PROJECT +``` + +**Environment Variables:** + +| Variable | Description | +|----------|-------------| +| `AUTH` | Set to `oauth` to enable OAuth 2.1 authentication | +| `OAUTH_BASE_URL` | Public URL of your server (e.g. `https://my-mcp.example.com`) | +| `OAUTH_PASSWORD` | Password for the OAuth login page | + +**CLI Flags:** + +| Flag | Description | +|------|-------------| +| `--auth oauth` | Enable OAuth 2.1 authentication | +| `--oauth-base-url URL` | Public URL of your server | +| `--oauth-password PASSWORD` | Password for the login page | + +> **Note:** OAuth state is stored in-memory. Deploy with a single instance (`--max-instances 1` on Cloud Run) — multi-instance setups will break the login flow because `/authorize` and `/login` may land on different instances. + +
+

diff --git a/docs/docker-hub.md b/docs/docker-hub.md index 2ceb526d..1a617be2 100644 --- a/docs/docker-hub.md +++ b/docs/docker-hub.md @@ -65,6 +65,9 @@ This opens a browser window where you log in manually (5 minute timeout for 2FA, | `SLOW_MO` | `0` | Delay between browser actions in ms (debugging) | | `VIEWPORT` | `1280x720` | Browser viewport size as WIDTHxHEIGHT | | `CHROME_PATH` | - | Path to Chrome/Chromium executable (rarely needed in Docker) | +| `AUTH` | - | Set to `oauth` to enable OAuth 2.1 authentication for remote deployments | +| `OAUTH_BASE_URL` | - | Public URL of the server (required when `AUTH=oauth`) | +| `OAUTH_PASSWORD` | - | Password for the OAuth login page (required when `AUTH=oauth`) | | `LINKEDIN_EXPERIMENTAL_PERSIST_DERIVED_SESSION` | `false` | Experimental: reuse checkpointed derived Linux runtime profiles across Docker restarts instead of fresh-bridging each startup | | `LINKEDIN_TRACE_MODE` | `on_error` | Trace/log retention mode: `on_error` keeps ephemeral artifacts only when a failure occurs, `always` keeps every run, `off` disables trace persistence | diff --git a/linkedin_mcp_server/auth.py b/linkedin_mcp_server/auth.py new file mode 100644 index 00000000..5f91e7df --- /dev/null +++ b/linkedin_mcp_server/auth.py @@ -0,0 +1,268 @@ +""" +OAuth 2.1 provider with password-based login for remote MCP deployments. + +Subclasses FastMCP's InMemoryOAuthProvider to add a login page in the +authorization flow. All other OAuth infrastructure (DCR, PKCE, token +management, .well-known endpoints) is handled by the parent class. +""" + +import html +import secrets +import time + +from mcp.server.auth.provider import AuthorizationParams +from mcp.shared.auth import OAuthClientInformationFull +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response +from starlette.routing import Route + +from fastmcp.server.auth.providers.in_memory import ( + AuthorizationCode, + InMemoryOAuthProvider, + construct_redirect_uri, +) + +# Pending auth requests expire after 10 minutes +_PENDING_REQUEST_TTL_SECONDS = 600 + +# Global rate limiting: max failed attempts across all request_ids in a time window +_GLOBAL_MAX_FAILED_ATTEMPTS = 20 +_GLOBAL_RATE_LIMIT_WINDOW_SECONDS = 300 # 5 minutes +_GLOBAL_LOCKOUT_SECONDS = 60 + +_LOGIN_SECURITY_HEADERS = { + "X-Frame-Options": "DENY", + "Content-Security-Policy": "default-src 'none'; style-src 'unsafe-inline'; frame-ancestors 'none'", + "X-Content-Type-Options": "nosniff", +} + + +def _html_response(content: str, status_code: int = 200) -> Response: + """HTMLResponse with security headers to prevent clickjacking and XSS.""" + from starlette.responses import HTMLResponse + + return HTMLResponse( + content, status_code=status_code, headers=_LOGIN_SECURITY_HEADERS + ) + + +# Max failed password attempts before the request is invalidated +_MAX_FAILED_ATTEMPTS = 5 + + +class PasswordOAuthProvider(InMemoryOAuthProvider): + """OAuth provider that requires a password before issuing authorization codes. + + When a client (e.g. claude.ai) hits /authorize, the user is redirected to + a login page. After entering the correct password, the authorization code + is issued and the user is redirected back to the client's callback URL. + """ + + def __init__( + self, + *, + base_url: str, + password: str, + **kwargs, + ): + from mcp.server.auth.settings import ClientRegistrationOptions + + super().__init__( + base_url=base_url, + client_registration_options=ClientRegistrationOptions(enabled=True), + **kwargs, + ) + self._password = password + self._pending_auth_requests: dict[str, dict] = {} + self._global_failed_attempts: list[float] = [] # timestamps of failures + self._global_lockout_until: float = 0.0 + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Redirect to login page instead of auto-approving.""" + self._cleanup_expired_requests() + + request_id = secrets.token_urlsafe(32) + self._pending_auth_requests[request_id] = { + "client_id": client.client_id, + "params": params, + "created_at": time.time(), + } + + base = str(self.base_url).rstrip("/") + return f"{base}/login?request_id={request_id}" + + def get_login_routes(self) -> list[Route]: + """Return Starlette routes for the login page.""" + return [ + Route("/login", endpoint=self._handle_login, methods=["GET", "POST"]), + ] + + def get_routes(self, mcp_path: str | None = None) -> list[Route]: + """Extend parent routes with login page.""" + routes = super().get_routes(mcp_path) + routes.extend(self.get_login_routes()) + return routes + + async def _handle_login(self, request: Request) -> Response: + if request.method == "GET": + return await self._render_login(request) + return await self._process_login(request) + + async def _render_login(self, request: Request) -> Response: + request_id = request.query_params.get("request_id", "") + pending = self._pending_auth_requests.get(request_id) if request_id else None + if not pending: + return _html_response("Invalid or expired login request.", status_code=400) + + if time.time() - pending["created_at"] > _PENDING_REQUEST_TTL_SECONDS: + del self._pending_auth_requests[request_id] + return _html_response( + "Login request expired. Please restart the authorization flow.", + status_code=400, + ) + + return _html_response(self._login_html(request_id)) + + async def _process_login(self, request: Request) -> Response: + form = await request.form() + request_id = str(form.get("request_id", "")) + password = str(form.get("password", "")) + + pending = self._pending_auth_requests.get(request_id) + if not pending: + return _html_response("Invalid or expired login request.", status_code=400) + + # Enforce TTL at submission time (not only during cleanup) + if time.time() - pending["created_at"] > _PENDING_REQUEST_TTL_SECONDS: + del self._pending_auth_requests[request_id] + return _html_response( + "Login request expired. Please restart the authorization flow.", + status_code=400, + ) + + # Global rate limit: reject if locked out + now = time.time() + if now < self._global_lockout_until: + return _html_response( + "Too many failed login attempts. Please try again later.", + status_code=429, + ) + + if not secrets.compare_digest(password, self._password): + # Track per-request failures + pending["failed_attempts"] = pending.get("failed_attempts", 0) + 1 + if pending["failed_attempts"] >= _MAX_FAILED_ATTEMPTS: + del self._pending_auth_requests[request_id] + + # Track global failures and trigger lockout if threshold exceeded + self._global_failed_attempts = [ + t + for t in self._global_failed_attempts + if now - t < _GLOBAL_RATE_LIMIT_WINDOW_SECONDS + ] + self._global_failed_attempts.append(now) + if len(self._global_failed_attempts) >= _GLOBAL_MAX_FAILED_ATTEMPTS: + self._global_lockout_until = now + _GLOBAL_LOCKOUT_SECONDS + return _html_response( + "Too many failed login attempts. Please try again later, " + "then restart the authorization flow from your client.", + status_code=429, + ) + + if pending.get("failed_attempts", 0) >= _MAX_FAILED_ATTEMPTS: + return _html_response( + "Too many failed attempts. Please restart the authorization flow.", + status_code=403, + ) + remaining = _MAX_FAILED_ATTEMPTS - pending["failed_attempts"] + return _html_response( + self._login_html( + request_id, + error=f"Invalid password. {remaining} attempt(s) remaining.", + ), + status_code=200, + ) + + # Password correct — create the authorization code and redirect + del self._pending_auth_requests[request_id] + + client = await self.get_client(pending["client_id"]) + if not client: + return _html_response( + "Client registration not found. " + "Please restart the authorization flow from your client.", + status_code=400, + ) + + params: AuthorizationParams = pending["params"] + scopes_list = params.scopes if params.scopes is not None else [] + + auth_code_value = f"auth_code_{secrets.token_hex(16)}" + expires_at = time.time() + 300 # 5 min + + auth_code = AuthorizationCode( + code=auth_code_value, + client_id=pending["client_id"], + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + scopes=scopes_list, + expires_at=expires_at, + code_challenge=params.code_challenge, + ) + self.auth_codes[auth_code_value] = auth_code + + redirect_url = construct_redirect_uri( + str(params.redirect_uri), code=auth_code_value, state=params.state + ) + return RedirectResponse(redirect_url, status_code=302) + + def _cleanup_expired_requests(self) -> None: + now = time.time() + expired = [ + rid + for rid, data in self._pending_auth_requests.items() + if now - data["created_at"] > _PENDING_REQUEST_TTL_SECONDS + ] + for rid in expired: + del self._pending_auth_requests[rid] + + @staticmethod + def _login_html(request_id: str, error: str = "") -> str: + error_html = ( + f'

{html.escape(error)}

' if error else "" + ) + return f""" + + + + +LinkedIn MCP Server — Login + + + +
+

LinkedIn MCP Server

+

Enter the server password to authorize this connection.

+ {error_html} +
+ + + + +
+
+ +""" diff --git a/linkedin_mcp_server/cli_main.py b/linkedin_mcp_server/cli_main.py index f25a7d22..ca3e9a02 100644 --- a/linkedin_mcp_server/cli_main.py +++ b/linkedin_mcp_server/cli_main.py @@ -380,7 +380,7 @@ def main() -> None: transport = choose_transport_interactive() # Create and run the MCP server - mcp = create_mcp_server() + mcp = create_mcp_server(oauth_config=config.server.oauth) if transport == "streamable-http": mcp.run( diff --git a/linkedin_mcp_server/config/__init__.py b/linkedin_mcp_server/config/__init__.py index 6dc34bb0..d82b3b77 100644 --- a/linkedin_mcp_server/config/__init__.py +++ b/linkedin_mcp_server/config/__init__.py @@ -8,7 +8,7 @@ import logging from .loaders import load_config -from .schema import AppConfig, BrowserConfig, ServerConfig +from .schema import AppConfig, BrowserConfig, OAuthConfig, ServerConfig logger = logging.getLogger(__name__) @@ -35,6 +35,7 @@ def reset_config() -> None: __all__ = [ "AppConfig", "BrowserConfig", + "OAuthConfig", "ServerConfig", "get_config", "reset_config", diff --git a/linkedin_mcp_server/config/loaders.py b/linkedin_mcp_server/config/loaders.py index e3c86133..ce7dfd90 100644 --- a/linkedin_mcp_server/config/loaders.py +++ b/linkedin_mcp_server/config/loaders.py @@ -47,6 +47,9 @@ class EnvironmentKeys: VIEWPORT = "VIEWPORT" CHROME_PATH = "CHROME_PATH" USER_DATA_DIR = "USER_DATA_DIR" + AUTH = "AUTH" + OAUTH_BASE_URL = "OAUTH_BASE_URL" + OAUTH_PASSWORD = "OAUTH_PASSWORD" def is_interactive_environment() -> bool: @@ -147,6 +150,19 @@ def load_from_env(config: AppConfig) -> AppConfig: if chrome_path_env := os.environ.get(EnvironmentKeys.CHROME_PATH): config.browser.chrome_path = chrome_path_env + # OAuth authentication + if auth_env := os.environ.get(EnvironmentKeys.AUTH): + if auth_env == "oauth": + config.server.oauth.enabled = True + else: + raise ConfigurationError(f"Invalid AUTH: '{auth_env}'. Must be 'oauth'.") + + if oauth_base_url := os.environ.get(EnvironmentKeys.OAUTH_BASE_URL): + config.server.oauth.base_url = oauth_base_url + + if oauth_password := os.environ.get(EnvironmentKeys.OAUTH_PASSWORD): + config.server.oauth.password = oauth_password + return config @@ -263,6 +279,30 @@ def load_from_args(config: AppConfig) -> AppConfig: help="Path to persistent browser profile directory (default: ~/.linkedin-mcp/profile)", ) + # OAuth authentication + parser.add_argument( + "--auth", + choices=["oauth"], + default=None, + help="Enable authentication (oauth for OAuth 2.1)", + ) + + parser.add_argument( + "--oauth-base-url", + type=str, + default=None, + metavar="URL", + help="Public URL of this server for OAuth (e.g. https://my-mcp.example.com)", + ) + + parser.add_argument( + "--oauth-password", + type=str, + default=None, + metavar="PASSWORD", + help="Password for the OAuth login page (visible in process list; prefer OAUTH_PASSWORD env var)", + ) + args = parser.parse_args() # Update configuration with parsed arguments @@ -322,6 +362,16 @@ def load_from_args(config: AppConfig) -> AppConfig: if args.user_data_dir: config.browser.user_data_dir = args.user_data_dir + # OAuth authentication + if args.auth == "oauth": + config.server.oauth.enabled = True + + if args.oauth_base_url: + config.server.oauth.base_url = args.oauth_base_url + + if args.oauth_password: + config.server.oauth.password = args.oauth_password + return config diff --git a/linkedin_mcp_server/config/schema.py b/linkedin_mcp_server/config/schema.py index 01fd37cf..82a4152a 100644 --- a/linkedin_mcp_server/config/schema.py +++ b/linkedin_mcp_server/config/schema.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Literal +from urllib.parse import urlparse class ConfigurationError(Exception): @@ -53,6 +54,17 @@ def validate(self) -> None: ) +@dataclass +class OAuthConfig: + """OAuth 2.1 authentication configuration for remote deployments.""" + + enabled: bool = False + base_url: str | None = ( + None # Public URL of this server (e.g. https://my-mcp.example.com) + ) + password: str | None = None # Password for the OAuth login page + + @dataclass class ServerConfig: """MCP server configuration.""" @@ -67,6 +79,8 @@ class ServerConfig: host: str = "127.0.0.1" port: int = 8000 path: str = "/mcp" + # OAuth authentication + oauth: OAuthConfig = field(default_factory=OAuthConfig) @dataclass @@ -84,6 +98,7 @@ def validate(self) -> None: self._validate_transport_config() self._validate_path_format() self._validate_port_range() + self._validate_oauth() def _validate_transport_config(self) -> None: """Validate transport configuration is consistent.""" @@ -109,3 +124,35 @@ def _validate_path_format(self) -> None: raise ConfigurationError( f"HTTP path '{self.server.path}' must be at least 2 characters" ) + + def _validate_oauth(self) -> None: + """Validate OAuth configuration when enabled. + + Skipped for command-only modes (--login, --status, --logout) that exit + before starting the server, so AUTH=oauth in the environment doesn't + break maintenance commands. + """ + if not self.server.oauth.enabled: + return + if self.server.login or self.server.status or self.server.logout: + return + if self.server.transport != "streamable-http": + raise ConfigurationError("OAuth requires --transport streamable-http") + if not self.server.oauth.base_url: + raise ConfigurationError( + "OAuth requires OAUTH_BASE_URL (the public URL of this server)" + ) + if not self.server.oauth.base_url.startswith("https://"): + raise ConfigurationError( + "OAuth requires OAUTH_BASE_URL to use HTTPS (e.g. https://my-mcp.example.com)" + ) + parsed = urlparse(self.server.oauth.base_url) + if parsed.path not in ("", "/"): + raise ConfigurationError( + "OAuth base URL must not contain a path component " + "(e.g. https://my-mcp.example.com, not https://my-mcp.example.com/api)" + ) + if not self.server.oauth.password: + raise ConfigurationError( + "OAuth requires OAUTH_PASSWORD (password for the login page)" + ) diff --git a/linkedin_mcp_server/server.py b/linkedin_mcp_server/server.py index 11025d2a..ff66b8c8 100644 --- a/linkedin_mcp_server/server.py +++ b/linkedin_mcp_server/server.py @@ -5,10 +5,15 @@ person profiles, company data, job information, and session management capabilities. """ +from __future__ import annotations + import logging -from typing import Any, AsyncIterator +from typing import TYPE_CHECKING, Any, AsyncIterator from fastmcp import FastMCP + +if TYPE_CHECKING: + from linkedin_mcp_server.config.schema import OAuthConfig from fastmcp.server.lifespan import lifespan from linkedin_mcp_server.constants import TOOL_TIMEOUT_SECONDS @@ -46,12 +51,26 @@ async def auth_lifespan(app: FastMCP) -> AsyncIterator[dict[str, Any]]: yield {} -def create_mcp_server() -> FastMCP: +def create_mcp_server(oauth_config: "OAuthConfig | None" = None) -> FastMCP: """Create and configure the MCP server with all LinkedIn tools.""" + auth = None + if oauth_config and oauth_config.enabled: + from linkedin_mcp_server.auth import PasswordOAuthProvider + + if oauth_config.base_url is None: + raise ValueError("oauth_config.base_url must be set when OAuth is enabled") + if oauth_config.password is None: + raise ValueError("oauth_config.password must be set when OAuth is enabled") + auth = PasswordOAuthProvider( + base_url=oauth_config.base_url, + password=oauth_config.password, + ) + mcp = FastMCP( "linkedin_scraper", lifespan=auth_lifespan | browser_lifespan, mask_error_details=True, + auth=auth, ) mcp.add_middleware(SequentialToolExecutionMiddleware()) diff --git a/manifest.json b/manifest.json index 92b35b65..ad5ca8e9 100644 --- a/manifest.json +++ b/manifest.json @@ -60,7 +60,20 @@ "description": "Properly close browser session and clean up resources" } ], - "user_config": {}, + "user_config": { + "AUTH": { + "description": "Set to 'oauth' to enable OAuth 2.1 authentication for remote deployments", + "required": false + }, + "OAUTH_BASE_URL": { + "description": "Public URL of the server (required when AUTH=oauth)", + "required": false + }, + "OAUTH_PASSWORD": { + "description": "Password for the OAuth login page (required when AUTH=oauth)", + "required": false + } + }, "compatibility": { "claude_desktop": ">=0.10.0", "platforms": ["darwin", "linux", "win32"] diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 00000000..7aa47108 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,340 @@ +import time + +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from linkedin_mcp_server.auth import PasswordOAuthProvider + + +@pytest.fixture +def provider(): + return PasswordOAuthProvider( + base_url="http://localhost:8000", + password="test-secret", + ) + + +class TestPasswordOAuthProvider: + def test_init_stores_password(self, provider): + assert provider._password == "test-secret" + + async def test_authorize_returns_login_url(self, provider): + from mcp.server.auth.provider import AuthorizationParams + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import AnyUrl + + client_info = OAuthClientInformationFull( + client_id="test-client", + client_name="Test", + redirect_uris=[AnyUrl("https://claude.ai/api/mcp/auth_callback")], + grant_types=["authorization_code"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + provider.clients["test-client"] = client_info + + params = AuthorizationParams( + state="test-state", + scopes=[], + code_challenge="test-challenge", + redirect_uri=AnyUrl("https://claude.ai/api/mcp/auth_callback"), + redirect_uri_provided_explicitly=True, + ) + + result = await provider.authorize(client_info, params) + assert "/login?" in result + assert "request_id=" in result + + async def test_authorize_stores_pending_request(self, provider): + from mcp.server.auth.provider import AuthorizationParams + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import AnyUrl + + client_info = OAuthClientInformationFull( + client_id="test-client", + client_name="Test", + redirect_uris=[AnyUrl("https://example.com/callback")], + grant_types=["authorization_code"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + provider.clients["test-client"] = client_info + + params = AuthorizationParams( + state="s", + scopes=[], + code_challenge="c", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + + await provider.authorize(client_info, params) + assert len(provider._pending_auth_requests) == 1 + + +class TestLoginRoutes: + @pytest.fixture + def app(self, provider): + routes = provider.get_login_routes() + return Starlette(routes=routes) + + @pytest.fixture + def client(self, app): + return TestClient(app) + + def test_get_login_renders_form(self, client, provider): + provider._pending_auth_requests["req123"] = { + "client_id": "test", + "params": None, + "created_at": time.time(), + } + + response = client.get("/login?request_id=req123") + assert response.status_code == 200 + assert "password" in response.text + assert "req123" in response.text + + def test_get_login_invalid_request_id(self, client): + response = client.get("/login?request_id=nonexistent") + assert response.status_code == 400 + + def test_get_login_missing_request_id(self, client): + response = client.get("/login") + assert response.status_code == 400 + + def test_login_page_has_security_headers(self, client, provider): + provider._pending_auth_requests["req-hdr"] = { + "client_id": "test", + "params": None, + "created_at": time.time(), + } + response = client.get("/login?request_id=req-hdr") + assert response.headers["X-Frame-Options"] == "DENY" + assert "frame-ancestors 'none'" in response.headers["Content-Security-Policy"] + assert response.headers["X-Content-Type-Options"] == "nosniff" + + def test_post_login_correct_password(self, client, provider): + from mcp.server.auth.provider import AuthorizationParams + from mcp.shared.auth import OAuthClientInformationFull + from pydantic import AnyUrl + + params = AuthorizationParams( + state="test-state", + scopes=[], + code_challenge="test-challenge", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + provider._pending_auth_requests["req123"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time(), + } + provider.clients["test-client"] = OAuthClientInformationFull( + client_id="test-client", + client_name="Test", + redirect_uris=[AnyUrl("https://example.com/callback")], + grant_types=["authorization_code"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + + response = client.post( + "/login", + data={"request_id": "req123", "password": "test-secret"}, + follow_redirects=False, + ) + assert response.status_code == 302 + assert "code=" in response.headers["location"] + assert "state=test-state" in response.headers["location"] + # Pending request consumed + assert "req123" not in provider._pending_auth_requests + + def test_post_login_wrong_password(self, client, provider): + from mcp.server.auth.provider import AuthorizationParams + from pydantic import AnyUrl + + params = AuthorizationParams( + state="s", + scopes=[], + code_challenge="c", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + provider._pending_auth_requests["req123"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time(), + } + + response = client.post( + "/login", + data={"request_id": "req123", "password": "wrong"}, + follow_redirects=False, + ) + assert response.status_code == 200 + assert "invalid" in response.text.lower() + assert "4 attempt(s) remaining" in response.text + # Pending request NOT consumed + assert "req123" in provider._pending_auth_requests + + def test_get_login_expired_request_rejected(self, client, provider): + provider._pending_auth_requests["req-expired-get"] = { + "client_id": "test-client", + "params": None, + "created_at": time.time() - 700, # 11+ minutes ago + } + + response = client.get("/login?request_id=req-expired-get") + assert response.status_code == 400 + assert "expired" in response.text.lower() + assert "req-expired-get" not in provider._pending_auth_requests + + def test_post_login_expired_request_rejected(self, client, provider): + from mcp.server.auth.provider import AuthorizationParams + from pydantic import AnyUrl + + params = AuthorizationParams( + state="s", + scopes=[], + code_challenge="c", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + provider._pending_auth_requests["req-expired"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time() - 700, # 11+ minutes ago + } + + response = client.post( + "/login", + data={"request_id": "req-expired", "password": "test-secret"}, + follow_redirects=False, + ) + assert response.status_code == 400 + assert "expired" in response.text.lower() + assert "req-expired" not in provider._pending_auth_requests + + def test_post_login_global_rate_limit(self, client, provider): + from mcp.server.auth.provider import AuthorizationParams + from pydantic import AnyUrl + + params = AuthorizationParams( + state="s", + scopes=[], + code_challenge="c", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + + # Simulate 20 global failures (across different request_ids) + provider._global_failed_attempts = [time.time()] * 19 + provider._pending_auth_requests["req-global"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time(), + } + + # This 20th failure should trigger global lockout + response = client.post( + "/login", + data={"request_id": "req-global", "password": "wrong"}, + follow_redirects=False, + ) + assert response.status_code == 429 + assert "try again later" in response.text.lower() + assert "restart" in response.text.lower() + + # Subsequent attempts also blocked even with new request_id + provider._pending_auth_requests["req-blocked"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time(), + } + response = client.post( + "/login", + data={"request_id": "req-blocked", "password": "test-secret"}, + follow_redirects=False, + ) + assert response.status_code == 429 + + def test_post_login_lockout_after_max_attempts(self, client, provider): + from mcp.server.auth.provider import AuthorizationParams + from pydantic import AnyUrl + + params = AuthorizationParams( + state="s", + scopes=[], + code_challenge="c", + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + provider._pending_auth_requests["req-lock"] = { + "client_id": "test-client", + "params": params, + "created_at": time.time(), + } + + # Exhaust all 5 attempts + for i in range(5): + response = client.post( + "/login", + data={"request_id": "req-lock", "password": "wrong"}, + follow_redirects=False, + ) + + assert response.status_code == 403 + assert "too many" in response.text.lower() + # Request invalidated + assert "req-lock" not in provider._pending_auth_requests + + +class TestOAuthIntegration: + """Integration tests verifying OAuth through the HTTP layer.""" + + @pytest.fixture + def oauth_mcp(self, provider): + from fastmcp import FastMCP + + mcp = FastMCP("test-oauth", auth=provider) + + @mcp.tool + async def echo(message: str) -> dict: + return {"echo": message} + + return mcp + + @pytest.fixture + def http_client(self, oauth_mcp): + app = oauth_mcp.http_app(transport="streamable-http") + return TestClient(app) + + def test_unauthenticated_request_returns_401(self, http_client): + response = http_client.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}}, + headers={"Accept": "application/json, text/event-stream"}, + ) + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + + def test_well_known_oauth_metadata_accessible(self, http_client): + response = http_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + data = response.json() + assert "authorization_endpoint" in data + assert "token_endpoint" in data + assert "registration_endpoint" in data + + def test_login_page_accessible_without_auth(self, http_client, provider): + """Login page should be reachable without a bearer token.""" + provider._pending_auth_requests["int-req"] = { + "client_id": "test", + "params": None, + "created_at": time.time(), + } + response = http_client.get("/login?request_id=int-req") + assert response.status_code == 200 + assert "password" in response.text diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index e428ff7a..4dded60b 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -47,7 +47,9 @@ def test_main_non_interactive_stdio_has_no_human_stdout( ) _patch_main_dependencies(monkeypatch, config) mcp = MagicMock() - monkeypatch.setattr("linkedin_mcp_server.cli_main.create_mcp_server", lambda: mcp) + monkeypatch.setattr( + "linkedin_mcp_server.cli_main.create_mcp_server", lambda **_: mcp + ) cli_main.main() @@ -68,7 +70,9 @@ def test_main_interactive_prompts_when_transport_not_explicit( "linkedin_mcp_server.cli_main.choose_transport_interactive", choose_transport ) mcp = MagicMock() - monkeypatch.setattr("linkedin_mcp_server.cli_main.create_mcp_server", lambda: mcp) + monkeypatch.setattr( + "linkedin_mcp_server.cli_main.create_mcp_server", lambda **_: mcp + ) cli_main.main() @@ -95,7 +99,9 @@ def test_main_explicit_transport_skips_prompt( "linkedin_mcp_server.cli_main.choose_transport_interactive", choose_transport ) mcp = MagicMock() - monkeypatch.setattr("linkedin_mcp_server.cli_main.create_mcp_server", lambda: mcp) + monkeypatch.setattr( + "linkedin_mcp_server.cli_main.create_mcp_server", lambda **_: mcp + ) cli_main.main() @@ -118,7 +124,9 @@ def test_main_streamable_http_passes_host_port_path( config.server.path = "/custom-mcp" _patch_main_dependencies(monkeypatch, config) mcp = MagicMock() - monkeypatch.setattr("linkedin_mcp_server.cli_main.create_mcp_server", lambda: mcp) + monkeypatch.setattr( + "linkedin_mcp_server.cli_main.create_mcp_server", lambda **_: mcp + ) cli_main.main() diff --git a/tests/test_config.py b/tests/test_config.py index 48b78fd4..2c1d3a5e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,7 @@ AppConfig, BrowserConfig, ConfigurationError, + OAuthConfig, ServerConfig, ) @@ -42,6 +43,84 @@ def test_validate_invalid_port(self): config.validate() +class TestOAuthConfig: + def test_defaults(self): + config = OAuthConfig() + assert config.enabled is False + assert config.base_url is None + assert config.password is None + + def test_validate_requires_base_url_when_enabled(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig(enabled=True, password="secret") + with pytest.raises(ConfigurationError, match="OAUTH_BASE_URL"): + config.validate() + + def test_validate_requires_password_when_enabled(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig(enabled=True, base_url="https://example.com") + with pytest.raises(ConfigurationError, match="OAUTH_PASSWORD"): + config.validate() + + def test_validate_passes_when_fully_configured(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig( + enabled=True, base_url="https://example.com", password="secret" + ) + config.validate() # No error + + def test_validate_requires_streamable_http_transport(self): + config = AppConfig() + config.server.transport = "stdio" + config.server.oauth = OAuthConfig( + enabled=True, base_url="https://example.com", password="secret" + ) + with pytest.raises(ConfigurationError, match="streamable-http"): + config.validate() + + def test_validate_rejects_http_base_url(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig( + enabled=True, base_url="http://example.com", password="secret" + ) + with pytest.raises(ConfigurationError, match="HTTPS"): + config.validate() + + def test_validate_rejects_base_url_with_path(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig( + enabled=True, base_url="https://example.com/api", password="secret" + ) + with pytest.raises(ConfigurationError, match="path component"): + config.validate() + + def test_validate_accepts_base_url_with_trailing_slash(self): + config = AppConfig() + config.server.transport = "streamable-http" + config.server.oauth = OAuthConfig( + enabled=True, base_url="https://example.com/", password="secret" + ) + config.validate() # No error — trailing slash is fine + + def test_validate_passes_when_disabled(self): + config = AppConfig() + config.server.oauth = OAuthConfig(enabled=False) + config.validate() # No error + + @pytest.mark.parametrize("flag", ["login", "status", "logout"]) + def test_validate_skips_oauth_in_command_only_modes(self, flag): + """OAuth validation should not block --login, --status, --logout.""" + config = AppConfig() + config.server.oauth = OAuthConfig(enabled=True) # Missing base_url + password + setattr(config.server, flag, True) + config.validate() # No error — skipped for command-only modes + + class TestConfigSingleton: def test_get_config_returns_same_instance(self, monkeypatch): # Mock sys.argv to prevent argparse from parsing pytest's arguments @@ -156,3 +235,29 @@ def test_load_from_env_user_data_dir(self, monkeypatch): config = load_from_env(AppConfig()) assert config.browser.user_data_dir == "/custom/profile" + + def test_load_from_env_oauth_enabled(self, monkeypatch): + monkeypatch.setenv("AUTH", "oauth") + monkeypatch.setenv("OAUTH_BASE_URL", "https://example.com") + monkeypatch.setenv("OAUTH_PASSWORD", "secret123") + from linkedin_mcp_server.config.loaders import load_from_env + + config = load_from_env(AppConfig()) + assert config.server.oauth.enabled is True + assert config.server.oauth.base_url == "https://example.com" + assert config.server.oauth.password == "secret123" + + def test_load_from_env_oauth_disabled_by_default(self, monkeypatch): + for var in ["AUTH", "OAUTH_BASE_URL", "OAUTH_PASSWORD"]: + monkeypatch.delenv(var, raising=False) + from linkedin_mcp_server.config.loaders import load_from_env + + config = load_from_env(AppConfig()) + assert config.server.oauth.enabled is False + + def test_load_from_env_invalid_auth_mode(self, monkeypatch): + monkeypatch.setenv("AUTH", "invalid") + from linkedin_mcp_server.config.loaders import load_from_env + + with pytest.raises(ConfigurationError, match="Invalid AUTH"): + load_from_env(AppConfig()) diff --git a/tests/test_server.py b/tests/test_server.py index 4ed40244..8e9862fd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,6 +5,7 @@ from fastmcp import FastMCP from fastmcp.server.middleware import MiddlewareContext +from linkedin_mcp_server.config.schema import OAuthConfig from linkedin_mcp_server.sequential_tool_middleware import ( SequentialToolExecutionMiddleware, ) @@ -87,3 +88,20 @@ async def test_sequential_tool_middleware_reports_queue_progress(self): ), ] ) + + +class TestServerAuth: + async def test_create_mcp_server_no_auth_by_default(self): + mcp = create_mcp_server() + assert mcp.auth is None + + async def test_create_mcp_server_with_oauth(self): + from linkedin_mcp_server.auth import PasswordOAuthProvider + + oauth_config = OAuthConfig( + enabled=True, + base_url="https://example.com", + password="secret", + ) + mcp = create_mcp_server(oauth_config=oauth_config) + assert isinstance(mcp.auth, PasswordOAuthProvider)