diff --git a/src/conduit/auth/client/models/__init__.py b/src/conduit/auth/client/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/conduit/auth/client/models/discovery.py b/src/conduit/auth/client/models/discovery.py new file mode 100644 index 0000000..3613847 --- /dev/null +++ b/src/conduit/auth/client/models/discovery.py @@ -0,0 +1,102 @@ +"""Discovery-related models for OAuth 2.1 server metadata. + +Contains models for Protected Resource Metadata (RFC 9728) and +Authorization Server Metadata (RFC 8414) discovery. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from urllib.parse import urlparse + +from pydantic import BaseModel, Field, field_validator + + +class ProtectedResourceMetadata(BaseModel): + """OAuth 2.0 Protected Resource Metadata (RFC 9728). + + Metadata returned by MCP servers to indicate their authorization servers + and resource configuration. + """ + + resource: str | None = None + authorization_servers: list[str] = Field(min_length=1) + + # Optional fields from RFC 9728 + bearer_methods_supported: list[str] | None = None + resource_documentation: str | None = None + resource_policy_uri: str | None = None + resource_tos_uri: str | None = None + + @field_validator("authorization_servers") + @classmethod + def validate_auth_servers(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("At least one authorization server is required") + return v + + +class AuthorizationServerMetadata(BaseModel): + """OAuth 2.0 Authorization Server Metadata (RFC 8414). + + Metadata returned by authorization servers describing their endpoints + and supported capabilities. + """ + + # Required by RFC 8414 + issuer: str + response_types_supported: list[str] = Field(min_length=1) + + # Required for authorization code flow (our use case) + authorization_endpoint: str + token_endpoint: str + + # PKCE support (required for OAuth 2.1) + code_challenge_methods_supported: list[str] = Field(default=["S256"]) + + # Dynamic registration (RFC 7591) + registration_endpoint: str | None = None + + # Optional but commonly used + revocation_endpoint: str | None = None + introspection_endpoint: str | None = None + scopes_supported: list[str] | None = None + grant_types_supported: list[str] = Field(default=["authorization_code"]) + + @field_validator("code_challenge_methods_supported") + @classmethod + def validate_pkce_support(cls, v: list[str]) -> list[str]: + if "S256" not in v: + raise ValueError("Authorization server must support S256 PKCE method") + return v + + +@dataclass(frozen=True) +class DiscoveryResult: + """Complete discovery results for an MCP server. + + Immutable result containing all metadata needed for OAuth flow. + Combines Protected Resource Metadata and Authorization Server Metadata. + """ + + server_url: str + protected_resource_metadata: ProtectedResourceMetadata + authorization_server_metadata: AuthorizationServerMetadata + auth_server_url: str + + def get_resource_url(self) -> str: + """Get the resource URL for RFC 8707 resource parameter. + + Uses the most specific URI (the actual server URL) as per MCP spec + guidance to provide the most specific URI possible. + """ + # Canonicalize the server URL per RFC 3986 and MCP spec + parsed = urlparse(self.server_url) + canonical = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}" + if parsed.path and parsed.path != "/": + canonical += parsed.path.rstrip("/") # Preserve case, remove trailing slash + + return canonical + + # TODO: Consider if we ever need PRM resource override + # The spec is unclear about when to use PRM resource vs server URL diff --git a/src/conduit/auth/client/models/errors.py b/src/conduit/auth/client/models/errors.py new file mode 100644 index 0000000..fa4719d --- /dev/null +++ b/src/conduit/auth/client/models/errors.py @@ -0,0 +1,99 @@ +"""Exception hierarchy for OAuth 2.1 authentication errors. + +Provides specific exception types for different failure modes to enable +precise error handling and recovery strategies. +""" + +from __future__ import annotations + + +class OAuth2Error(Exception): + """Base exception for all OAuth 2.1 related errors.""" + + pass + + +class DiscoveryError(OAuth2Error): + """Raised when OAuth server discovery fails.""" + + pass + + +class ProtectedResourceMetadataError(DiscoveryError): + """Raised when Protected Resource Metadata discovery fails.""" + + pass + + +class AuthorizationServerMetadataError(DiscoveryError): + """Raised when Authorization Server Metadata discovery fails.""" + + pass + + +class RegistrationError(OAuth2Error): + """Raised when dynamic client registration fails.""" + + pass + + +class TokenError(OAuth2Error): + """Raised when token operations fail.""" + + pass + + +class TokenRefreshError(TokenError): + """Raised when token refresh fails.""" + + pass + + +class TokenExchangeError(TokenError): + """Raised when authorization code to token exchange fails.""" + + pass + + +class AuthorizationError(OAuth2Error): + """Raised when user authorization fails.""" + + pass + + +class AuthorizationResponseError(OAuth2Error): + """Raised when authorization response is malformed or invalid.""" + + pass + + +class UserAuthCancelledError(AuthorizationError): + """Raised when user cancels the authorization flow.""" + + pass + + +class PKCEError(OAuth2Error): + """Raised when PKCE parameter generation or validation fails.""" + + pass + + +class AuthorizationCallbackError(OAuth2Error): + """Raised when authorization server callback data is malformed or invalid. + + This indicates the authorization server sent an invalid callback URL, + not that our callback handling code failed. + """ + + pass + + +class StateValidationError(AuthorizationCallbackError): + """Raised when OAuth state parameter validation fails. + + This indicates either a missing state parameter or a state mismatch, + which could indicate a CSRF attack or authorization server issue. + """ + + pass diff --git a/src/conduit/auth/client/models/flow.py b/src/conduit/auth/client/models/flow.py new file mode 100644 index 0000000..73911e8 --- /dev/null +++ b/src/conduit/auth/client/models/flow.py @@ -0,0 +1,56 @@ +"""Authorization flow models for OAuth 2.1. + +Contains models for authorization requests and callback handling. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from urllib.parse import urlencode + + +@dataclass(frozen=True) +class AuthorizationRequest: + """Authorization request parameters for OAuth 2.1 flow.""" + + authorization_endpoint: str + client_id: str + redirect_uri: str + code_challenge: str + code_challenge_method: str + state: str + resource: str | None = None # RFC 8707 + scope: str | None = None + + def build_authorization_url(self) -> str: + """Build the complete authorization URL.""" + params = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "code_challenge": self.code_challenge, + "code_challenge_method": self.code_challenge_method, + "state": self.state, + } + + if self.resource: + params["resource"] = self.resource + if self.scope: + params["scope"] = self.scope + + return f"{self.authorization_endpoint}?{urlencode(params)}" + + +@dataclass(frozen=True) +class AuthorizationResponse: + code: str | None = None + state: str | None = None + error: str | None = None + error_description: str | None = None + error_uri: str | None = None + + def is_success(self) -> bool: + return self.error is None and self.code is not None + + def is_error(self) -> bool: + return self.error is not None diff --git a/src/conduit/auth/client/models/registration.py b/src/conduit/auth/client/models/registration.py new file mode 100644 index 0000000..080817e --- /dev/null +++ b/src/conduit/auth/client/models/registration.py @@ -0,0 +1,80 @@ +"""Client registration models for OAuth 2.0 Dynamic Client Registration. + +Contains models for client metadata (RFC 7591) and registration results. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from urllib.parse import urlparse + +from pydantic import BaseModel, Field, field_validator + + +class ClientMetadata(BaseModel): + """OAuth 2.0 Client Metadata for dynamic registration (RFC 7591).""" + + client_name: str + redirect_uris: list[str] = Field(min_length=1) + + # Optional metadata + client_uri: str | None = None + logo_uri: str | None = None + scope: str | None = None + contacts: list[str] | None = None + tos_uri: str | None = None + policy_uri: str | None = None + + # OAuth 2.1 specific + token_endpoint_auth_method: str = "none" # Public client + grant_types: list[str] = Field(default=["authorization_code"]) + response_types: list[str] = Field(default=["code"]) + + @field_validator("redirect_uris") + @classmethod + def validate_redirect_uris(cls, v: list[str]) -> list[str]: + """Validate redirect URIs meet OAuth 2.1 security requirements.""" + for uri in v: + parsed = urlparse(uri) + # Must be HTTPS or localhost + if parsed.scheme == "http" and parsed.hostname != "localhost": + raise ValueError(f"Redirect URI must use HTTPS or localhost: {uri}") + return v + + @field_validator("client_uri", "logo_uri", "tos_uri", "policy_uri") + @classmethod + def validate_https_uris(cls, v: str | None) -> str | None: + """Validate optional URIs use HTTPS when provided.""" + if v is not None and not v.startswith("https://"): + raise ValueError(f"URI must use HTTPS: {v}") + return v + + +class ClientCredentials(BaseModel): + """OAuth 2.0 Client Credentials from registration response.""" + + client_id: str + client_secret: str | None = None # None for public clients + registration_access_token: str | None = None + registration_client_uri: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + def is_expired(self) -> bool: + """Check if client credentials have expired.""" + if self.client_secret_expires_at is None: + return False + return time.time() >= self.client_secret_expires_at + + +@dataclass(frozen=True) +class ClientRegistration: + """Complete client registration result. + + Immutable result containing client metadata and credentials. + """ + + metadata: ClientMetadata + credentials: ClientCredentials + registration_endpoint: str diff --git a/src/conduit/auth/client/models/security.py b/src/conduit/auth/client/models/security.py new file mode 100644 index 0000000..1fd87b1 --- /dev/null +++ b/src/conduit/auth/client/models/security.py @@ -0,0 +1,31 @@ +"""Security-related models for OAuth 2.1 authentication. + +Contains PKCE parameters and other cryptographic primitives needed +for secure OAuth flows. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class PKCEParameters: + """PKCE (Proof Key for Code Exchange) parameters for OAuth 2.1 security. + + Immutable parameters generated for each authorization flow to prevent + authorization code interception attacks (RFC 7636). + """ + + code_verifier: str = field() + code_challenge: str = field() + code_challenge_method: str = field(default="S256") + + def __post_init__(self) -> None: + """Validate PKCE parameters meet RFC 7636 requirements.""" + if not (43 <= len(self.code_verifier) <= 128): + raise ValueError("code_verifier must be 43-128 characters") + if not (43 <= len(self.code_challenge) <= 128): + raise ValueError("code_challenge must be 43-128 characters") + if self.code_challenge_method != "S256": + raise ValueError("Only S256 code challenge method is supported") diff --git a/src/conduit/auth/client/models/tokens.py b/src/conduit/auth/client/models/tokens.py new file mode 100644 index 0000000..7ee79e5 --- /dev/null +++ b/src/conduit/auth/client/models/tokens.py @@ -0,0 +1,205 @@ +"""Token state and lifecycle models for OAuth 2.1. + +Contains mutable token state management and token response handling. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + + +@dataclass +class TokenState: + """Mutable token state with lifecycle management. + + Represents the current state of OAuth tokens for a server. + Mutable to allow token refresh without recreating the entire auth provider. + """ + + access_token: str | None = None + refresh_token: str | None = None + token_type: str = "Bearer" + expires_at: float | None = None # Unix timestamp + scope: str | None = None + + def is_valid(self, buffer_seconds: float = 30.0) -> bool: + """Check if access token is valid with optional buffer. + + Args: + buffer_seconds: Refresh token this many seconds before expiry + """ + if not self.access_token: + return False + + if self.expires_at is None: + return True # No expiry means token doesn't expire + + return time.time() < (self.expires_at - buffer_seconds) + + def can_refresh(self) -> bool: + """Check if token can be refreshed.""" + return bool(self.refresh_token) + + def clear(self) -> None: + """Clear all token data.""" + self.access_token = None + self.refresh_token = None + self.expires_at = None + self.scope = None + + def update_from_response(self, token_response: dict[str, Any]) -> None: + """Update token state from OAuth token response.""" + self.access_token = token_response.get("access_token") + self.token_type = token_response.get("token_type", "Bearer") + self.scope = token_response.get("scope") + + # Update refresh token if provided + if "refresh_token" in token_response: + self.refresh_token = token_response["refresh_token"] + + # Calculate expiry time + if "expires_in" in token_response: + expires_in = int(token_response["expires_in"]) + self.expires_at = time.time() + expires_in + else: + self.expires_at = None + + +@dataclass(frozen=True) +class TokenRequest: + """OAuth 2.1 token exchange request parameters (RFC 6749 Section 4.1.3). + + Immutable request parameters for exchanging authorization codes for access tokens. + Includes PKCE code_verifier (RFC 7636) and resource parameter (RFC 8707). + """ + + # Required fields first + token_endpoint: str + code: str + redirect_uri: str + client_id: str + code_verifier: str # RFC 7636 PKCE + + # Optional fields with defaults last + grant_type: str = "authorization_code" + resource: str | None = None # RFC 8707 Resource Indicators + scope: str | None = None + + def to_form_data(self) -> dict[str, str]: + """Convert to form data for application/x-www-form-urlencoded request. + + Token requests must use form encoding, not JSON (RFC 6749 Section 4.1.3). + + Returns: + Dictionary suitable for httpx data parameter + """ + data = { + "grant_type": self.grant_type, + "code": self.code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + "code_verifier": self.code_verifier, + } + + # Add optional parameters + if self.resource: + data["resource"] = self.resource + if self.scope: + data["scope"] = self.scope + + return data + + +class TokenResponse(BaseModel): + """OAuth 2.1 token response (RFC 6749 Section 5). + + Represents the response from a token endpoint, including both + successful responses (Section 5.1) and error responses (Section 5.2). + """ + + # Success response fields (RFC 6749 Section 5.1) + access_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None # Seconds until expiry + refresh_token: str | None = None + scope: str | None = None + + # Error response fields (RFC 6749 Section 5.2) + error: str | None = None + error_description: str | None = None + error_uri: str | None = None + + def is_success(self) -> bool: + """Check if token response indicates success.""" + return self.error is None and self.access_token is not None + + def is_error(self) -> bool: + """Check if token response indicates an error.""" + return self.error is not None + + def calculate_expires_at(self) -> float | None: + """Calculate absolute expiry timestamp from expires_in. + + Returns: + Unix timestamp when token expires, or None if no expiry + """ + if self.expires_in is None: + return None + return time.time() + self.expires_in + + def to_token_state(self) -> TokenState: + """Convert successful token response to mutable TokenState. + + Returns: + TokenState for ongoing token management + + Raises: + ValueError: If response is not successful + """ + if not self.is_success(): + raise ValueError("Cannot convert error response to TokenState") + + return TokenState( + access_token=self.access_token, + refresh_token=self.refresh_token, + token_type=self.token_type, + expires_at=self.calculate_expires_at(), + scope=self.scope, + ) + + +@dataclass(frozen=True) +class RefreshTokenRequest: + """OAuth 2.1 refresh token request parameters (RFC 6749 Section 6). + + Immutable request parameters for refreshing access tokens. + """ + + # Required fields first + token_endpoint: str + refresh_token: str + client_id: str + + # Optional fields with defaults last + grant_type: str = "refresh_token" + resource: str | None = None # RFC 8707 Resource Indicators + scope: str | None = None + + def to_form_data(self) -> dict[str, str]: + """Convert to form data for application/x-www-form-urlencoded request.""" + data = { + "grant_type": self.grant_type, + "refresh_token": self.refresh_token, + "client_id": self.client_id, + } + + if self.resource: + data["resource"] = self.resource + if self.scope: + data["scope"] = self.scope + + return data diff --git a/src/conduit/auth/client/oauth_client.py b/src/conduit/auth/client/oauth_client.py new file mode 100644 index 0000000..fd60bec --- /dev/null +++ b/src/conduit/auth/client/oauth_client.py @@ -0,0 +1,289 @@ +"""Complete OAuth 2.1 client orchestration for MCP authentication. + +Coordinates discovery, registration, authorization, and token exchange +to provide a complete OAuth authentication flow. +""" + +from __future__ import annotations + +import logging +from typing import Protocol + +from conduit.auth.client.models.discovery import DiscoveryResult +from conduit.auth.client.models.registration import ClientMetadata, ClientRegistration +from conduit.auth.client.models.tokens import TokenRequest, TokenState +from conduit.auth.client.services.discovery import OAuth2Discovery +from conduit.auth.client.services.flow import OAuth2FlowManager +from conduit.auth.client.services.registration import OAuth2Registration +from conduit.auth.client.services.tokens import OAuth2TokenManager + +logger = logging.getLogger(__name__) + + +class AuthorizationHandler(Protocol): + """Protocol for handling user authorization step. + + Allows different strategies for browser interaction: + - Manual (return URL to developer) + - Browser automation (open browser + local server) + - Custom UI integration + """ + + async def handle_authorization(self, auth_url: str) -> str: + """Handle user authorization and return callback URL. + + Args: + auth_url: Authorization URL for user to visit + + Returns: + Callback URL received after user authorization + """ + ... + + +class ManualAuthorizationHandler: + """Authorization handler that requires manual user interaction. + + Returns the authorization URL and waits for developer to provide + the callback URL. Suitable for CLI tools and custom integrations. + """ + + def __init__(self, callback_handler: callable[[str], str] | None = None): + """Initialize manual authorization handler. + + Args: + callback_handler: Optional function to call with auth URL. + Should return the callback URL. + """ + self.callback_handler = callback_handler + + async def handle_authorization(self, auth_url: str) -> str: + """Handle authorization by delegating to callback handler or raising.""" + if self.callback_handler: + return await self.callback_handler(auth_url) + else: + raise NotImplementedError( + f"Please visit {auth_url} and provide the callback URL" + ) + + +class AuthenticatedSession: + """Represents an authenticated session with an MCP server. + + Manages token lifecycle including automatic refresh when needed. + """ + + def __init__( + self, + server_url: str, + token_state: TokenState, + token_manager: OAuth2TokenManager, + discovery_result: DiscoveryResult, + client_registration: ClientRegistration, + ): + self.server_url = server_url + self.token_state = token_state + self._token_manager = token_manager + self._discovery_result = discovery_result + self._client_registration = client_registration + + @property + def access_token(self) -> str | None: + """Get current access token.""" + return self.token_state.access_token + + @property + def is_valid(self) -> bool: + """Check if session has valid access token.""" + return self.token_state.is_valid() + + async def refresh_if_needed(self) -> bool: + """Refresh access token if needed and possible. + + Returns: + True if token was refreshed or is still valid, False if refresh failed + """ + if self.token_state.is_valid(): + return True + + if not self.token_state.can_refresh(): + logger.warning("Token expired and cannot be refreshed") + return False + + try: + # Attempt token refresh + from conduit.auth.client.models.tokens import RefreshTokenRequest + + refresh_request = RefreshTokenRequest( + token_endpoint=self._discovery_result.authorization_server_metadata.token_endpoint, + refresh_token=self.token_state.refresh_token, + client_id=self._client_registration.credentials.client_id, + resource=self._discovery_result.get_resource_url(), + ) + + token_response = await self._token_manager.refresh_access_token( + refresh_request + ) + + if token_response.is_success(): + # Update token state with new tokens + new_token_state = token_response.to_token_state() + self.token_state.access_token = new_token_state.access_token + self.token_state.refresh_token = ( + new_token_state.refresh_token or self.token_state.refresh_token + ) + self.token_state.expires_at = new_token_state.expires_at + self.token_state.scope = new_token_state.scope + + logger.info("Successfully refreshed access token") + return True + else: + logger.error(f"Token refresh failed: {token_response.error}") + return False + + except Exception as e: + logger.error(f"Token refresh error: {e}") + return False + + +class OAuth2Client: + """Complete OAuth 2.1 client for MCP authentication. + + Orchestrates the full OAuth flow from discovery through token exchange, + providing a high-level interface for MCP client authentication. + """ + + def __init__( + self, + authorization_handler: AuthorizationHandler | None = None, + redirect_uri: str = "http://localhost:8080/callback", + client_name: str = "MCP Client", + timeout: float = 30.0, + ): + """Initialize OAuth client. + + Args: + authorization_handler: Handler for user authorization step + redirect_uri: OAuth redirect URI for callbacks + client_name: Name to use for dynamic client registration + timeout: HTTP request timeout + """ + self.authorization_handler = ( + authorization_handler or ManualAuthorizationHandler() + ) + self.redirect_uri = redirect_uri + self.client_name = client_name + + # Initialize service components + self.discovery = OAuth2Discovery(timeout=timeout) + self.registration = OAuth2Registration(timeout=timeout) + self.flow_manager = OAuth2FlowManager() + self.token_manager = OAuth2TokenManager(timeout=timeout) + + async def authenticate_with_server( + self, + server_url: str, + scope: str | None = None, + client_metadata: ClientMetadata | None = None, + ) -> AuthenticatedSession: + """Authenticate with an MCP server using OAuth 2.1. + + Performs the complete OAuth flow: + 1. Discover OAuth configuration + 2. Register client dynamically + 3. Handle authorization flow + 4. Exchange code for tokens + 5. Return authenticated session + + Args: + server_url: MCP server URL to authenticate with + scope: Optional OAuth scope to request + client_metadata: Optional custom client metadata + + Returns: + AuthenticatedSession: Session with valid access tokens + + Raises: + Various OAuth errors if authentication fails + """ + logger.info(f"Starting OAuth authentication with {server_url}") + + # 1. Discover OAuth configuration + logger.debug("Discovering OAuth configuration") + discovery_result = await self.discovery.discover_from_url(server_url) + + # 2. Register client (if registration endpoint available) + logger.debug("Registering OAuth client") + client_reg = await self._register_client(discovery_result, client_metadata) + + # 3. Start authorization flow + logger.debug("Starting authorization flow") + auth_url, pkce_params, state = await self.flow_manager.start_authorization_flow( + discovery_result, client_reg.credentials, self.redirect_uri, scope + ) + + # 4. Handle user authorization + logger.debug("Handling user authorization") + callback_url = await self.authorization_handler.handle_authorization(auth_url) + + # 5. Process callback + logger.debug("Processing authorization callback") + auth_response = await self.flow_manager.handle_authorization_callback( + callback_url, state + ) + + if not auth_response.is_success(): + raise ValueError(f"Authorization failed: {auth_response.error}") + + # 6. Exchange code for tokens + logger.debug("Exchanging authorization code for tokens") + + token_request = TokenRequest( + token_endpoint=discovery_result.authorization_server_metadata.token_endpoint, + code=auth_response.code, + redirect_uri=self.redirect_uri, + client_id=client_reg.credentials.client_id, + code_verifier=pkce_params.code_verifier, + resource=discovery_result.get_resource_url(), + scope=scope, + ) + + token_response = await self.token_manager.exchange_code_for_token(token_request) + + if not token_response.is_success(): + raise ValueError(f"Token exchange failed: {token_response.error}") + + # 7. Create authenticated session + token_state = token_response.to_token_state() + session = AuthenticatedSession( + server_url, token_state, self.token_manager, discovery_result, client_reg + ) + + logger.info(f"Successfully authenticated with {server_url}") + return session + + async def _register_client( + self, + discovery_result: DiscoveryResult, + client_metadata: ClientMetadata | None = None, + ) -> ClientRegistration: + """Register OAuth client with authorization server.""" + if not discovery_result.authorization_server_metadata.registration_endpoint: + raise ValueError("Server does not support dynamic client registration") + + if client_metadata is None: + client_metadata = ClientMetadata( + client_name=self.client_name, + redirect_uris=[self.redirect_uri], + ) + + return await self.registration.register_client( + discovery_result.authorization_server_metadata.registration_endpoint, + client_metadata, + ) + + async def close(self) -> None: + """Close all service connections.""" + await self.discovery.close() + await self.registration.close() + await self.token_manager.close() diff --git a/src/conduit/auth/client/services/discovery.py b/src/conduit/auth/client/services/discovery.py new file mode 100644 index 0000000..e8bbe90 --- /dev/null +++ b/src/conduit/auth/client/services/discovery.py @@ -0,0 +1,299 @@ +"""OAuth 2.1 server discovery primitive. + +Implements RFC 9728 (Protected Resource Metadata) and RFC 8414 +(Authorization Server Metadata) discovery to find OAuth endpoints and capabilities +for MCP servers. +""" + +from __future__ import annotations + +import logging +import re +from urllib.parse import urljoin, urlparse + +import httpx +from pydantic import ValidationError + +from conduit.auth.client.models.discovery import ( + AuthorizationServerMetadata, + DiscoveryResult, + ProtectedResourceMetadata, +) +from conduit.auth.client.models.errors import ( + AuthorizationServerMetadataError, + DiscoveryError, + ProtectedResourceMetadataError, +) + +logger = logging.getLogger(__name__) + + +class OAuth2Discovery: + """Handles OAuth 2.1 server discovery for MCP authentication. + + Implements the two-step discovery process: + 1. Protected Resource Metadata (RFC 9728) - find authorization servers + 2. Authorization Server Metadata (RFC 8414) - find OAuth endpoints + + Supports discovery from 401 responses (WWW-Authenticate header) and + direct URL discovery with proper fallback strategies. + """ + + def __init__(self, timeout: float = 30.0): + """Initialize OAuth discovery. + + Args: + timeout: HTTP request timeout in seconds + """ + self.timeout = timeout + self._http_client = httpx.AsyncClient(timeout=timeout) + + async def discover_from_401(self, response: httpx.Response) -> DiscoveryResult: + """Discover OAuth configuration from a 401 Unauthorized response. + + Extracts the resource metadata URL from the WWW-Authenticate header + and performs the complete discovery chain. + + Args: + response: 401 response from MCP server + + Returns: + Complete discovery results + + Raises: + DiscoveryError: If discovery fails + """ + if response.status_code != 401: + raise DiscoveryError(f"Expected 401 response, got {response.status_code}") + + # Extract server URL from the original request + server_url = str(response.request.url) + + # Try to get resource metadata URL from WWW-Authenticate header + resource_metadata_url = self._extract_resource_metadata_from_www_auth(response) + + if resource_metadata_url: + logger.debug( + "Found resource metadata URL in WWW-Authenticate: " + f"{resource_metadata_url}" + ) + prm = await self._fetch_protected_resource_metadata(resource_metadata_url) + else: + logger.debug( + "No resource metadata URL in WWW-Authenticate, using fallbackdiscovery" + ) + prm = await self._discover_protected_resource_metadata(server_url) + + # Get authorization server metadata + auth_server_url = str(prm.authorization_servers[0]) + asm = await self._discover_authorization_server_metadata(auth_server_url) + + return DiscoveryResult( + server_url=server_url, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url=auth_server_url, + ) + + async def discover_from_url(self, server_url: str) -> DiscoveryResult: + """Discover OAuth configuration from an MCP server URL. + + Performs discovery using well-known endpoints without requiring + a 401 response first. + + Args: + server_url: MCP server URL to discover OAuth config for + + Returns: + Complete discovery results + + Raises: + DiscoveryError: If discovery fails + """ + # Discover protected resource metadata + prm = await self._discover_protected_resource_metadata(server_url) + + # Get authorization server metadata + auth_server_url = str(prm.authorization_servers[0]) + asm = await self._discover_authorization_server_metadata(auth_server_url) + + return DiscoveryResult( + server_url=server_url, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url=auth_server_url, + ) + + def _extract_resource_metadata_from_www_auth( + self, response: httpx.Response + ) -> str | None: + """Extract resource metadata URL from WWW-Authenticate header. + + RFC 9728 Section 5.1: WWW-Authenticate response should contain + resource_metadata parameter pointing to the metadata URL. + + Args: + response: HTTP response containing WWW-Authenticate header + + Returns: + Resource metadata URL if found, None otherwise + """ + www_auth_header = response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + async def _discover_protected_resource_metadata( + self, server_url: str + ) -> ProtectedResourceMetadata: + """Discover protected resource metadata using well-known endpoint. + + RFC 9728: Protected resource metadata should be available at + /.well-known/oauth-protected-resource + + Args: + server_url: MCP server URL + + Returns: + Protected resource metadata + + Raises: + ProtectedResourceMetadataError: If discovery fails + """ + # Build well-known URL + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + metadata_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + + return await self._fetch_protected_resource_metadata(metadata_url) + + async def _fetch_protected_resource_metadata( + self, metadata_url: str + ) -> ProtectedResourceMetadata: + """Fetch and parse protected resource metadata. + + Args: + metadata_url: URL to fetch metadata from + + Returns: + Parsed protected resource metadata + + Raises: + ProtectedResourceMetadataError: If fetch or parsing fails + httpx.HTTPStatusError: If an HTTP error occurs + """ + try: + logger.debug(f"Fetching protected resource metadata from: {metadata_url}") + response = await self._http_client.get(metadata_url) + response.raise_for_status() + + content = response.text + metadata = ProtectedResourceMetadata.model_validate_json(content) + + logger.debug( + f"Successfully discovered protected resource metadata: " + f"{len(metadata.authorization_servers)} auth servers" + ) + return metadata + + except httpx.HTTPStatusError as e: + raise ProtectedResourceMetadataError( + f"Failed to fetch protected resource metadata from {metadata_url}: {e}" + ) from e + except ValidationError as e: + raise ProtectedResourceMetadataError( + f"Invalid protected resource metadata from {metadata_url}: {e}" + ) from e + except Exception as e: + raise ProtectedResourceMetadataError( + f"Unexpected error fetching protected resource metadata: {e}" + ) from e + + async def _discover_authorization_server_metadata( + self, auth_server_url: str + ) -> AuthorizationServerMetadata: + """Discover authorization server metadata. + + RFC 8414: Authorization server metadata should be available at + /.well-known/oauth-authorization-server (with path-aware discovery) + + Args: + auth_server_url: Authorization server URL + + Returns: + Authorization server metadata + + Raises: + AuthorizationServerMetadataError: If discovery fails + """ + discovery_urls = self._build_discovery_urls(auth_server_url) + + for url in discovery_urls: + try: + logger.debug(f"Trying authorization server metadata discovery: {url}") + response = await self._http_client.get(url) + + if response.status_code == 200: + content = response.text + metadata = AuthorizationServerMetadata.model_validate_json(content) + logger.debug( + f"Successfully discovered authorization server metadata from: " + f"{url}" + ) + return metadata + elif response.status_code >= 500: + # Server error - don't try other URLs + break + + except ValidationError: + # Invalid metadata - try next URL + continue + except httpx.RequestError: + # Network error - try next URL + continue + + raise AuthorizationServerMetadataError( + f"Failed to discover authorization server metadata for {auth_server_url}. " + f"Tried URLs: {discovery_urls}" + ) + + def _build_discovery_urls(self, auth_server_url: str) -> list[str]: + """Build ordered list of discovery URLs to try. + + RFC 8414 Section 3: Path-aware discovery should be tried first, + then fallback to root discovery. + + Args: + auth_server_url: Authorization server URL + + Returns: + Ordered list of URLs to try for discovery + """ + parsed = urlparse(auth_server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + urls = [] + + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = ( + f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + ) + urls.append(urljoin(base_url, oauth_path)) + + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + + return urls + + async def close(self) -> None: + """Close the HTTP client.""" + await self._http_client.aclose() diff --git a/src/conduit/auth/client/services/flow.py b/src/conduit/auth/client/services/flow.py new file mode 100644 index 0000000..e50ebd1 --- /dev/null +++ b/src/conduit/auth/client/services/flow.py @@ -0,0 +1,199 @@ +"""OAuth 2.1 authorization flow orchestration service. + +Coordinates the complete authorization code flow including PKCE security, +state validation, and callback handling for MCP client authentication. +""" + +from __future__ import annotations + +import logging +from urllib.parse import parse_qs, urlparse + +from conduit.auth.client.models.discovery import DiscoveryResult +from conduit.auth.client.models.errors import ( + AuthorizationCallbackError, + AuthorizationError, + StateValidationError, +) +from conduit.auth.client.models.flow import AuthorizationRequest, AuthorizationResponse +from conduit.auth.client.models.registration import ClientCredentials +from conduit.auth.client.models.security import PKCEParameters +from conduit.auth.client.services.pkce import PKCEManager +from conduit.auth.client.services.security import generate_state, validate_state + +logger = logging.getLogger(__name__) + + +class OAuth2FlowManager: + """Orchestrates OAuth 2.1 authorization code flows for MCP authentication. + + Handles the complete authorization flow from initial request generation + through callback processing, including: + - PKCE parameter generation and validation + - State parameter security (CSRF protection) + - Authorization URL construction + - Callback URL parsing and validation + - Resource parameter handling (RFC 8707) + """ + + def __init__(self): + """Initialize the OAuth flow manager.""" + self._pkce_manager = PKCEManager() + + async def start_authorization_flow( + self, + discovery_result: DiscoveryResult, + client_credentials: ClientCredentials, + redirect_uri: str, + scope: str | None = None, + ) -> tuple[str, PKCEParameters, str]: + """Start an OAuth 2.1 authorization flow. + + Generates secure PKCE parameters and state, then builds the authorization + URL that users should visit to grant permissions. + + Args: + discovery_result: OAuth server discovery results + client_credentials: Registered client credentials + redirect_uri: URI to redirect to after authorization + scope: Optional scope to request + + Returns: + Tuple of (authorization_url, pkce_parameters, state) + - authorization_url: URL for user to visit + - pkce_parameters: Store these for token exchange + - state: Store this for callback validation + + Raises: + AuthorizationError: If flow setup fails + """ + try: + # Generate security parameters + pkce_params = self._pkce_manager.generate_parameters() + state = generate_state() + + # Get resource URL for RFC 8707 + resource_url = discovery_result.get_resource_url() + + logger.debug( + f"Starting authorization flow for client" + f"{client_credentials.client_id} with resource" + f" {resource_url}" + ) + + # Build authorization request + auth_endpoint = ( + discovery_result.authorization_server_metadata.authorization_endpoint + ) + auth_request = AuthorizationRequest( + authorization_endpoint=auth_endpoint, + client_id=client_credentials.client_id, + redirect_uri=redirect_uri, + code_challenge=pkce_params.code_challenge, + code_challenge_method=pkce_params.code_challenge_method, + state=state, + resource=resource_url, + scope=scope, + ) + + # Generate authorization URL + authorization_url = auth_request.build_authorization_url() + + logger.info( + f"Generated authorization URL for client {client_credentials.client_id}" + ) + + return authorization_url, pkce_params, state + + except Exception as e: + raise AuthorizationError(f"Failed to start authorization flow: {e}") from e + + async def handle_authorization_callback( + self, + callback_url: str, + expected_state: str, + ) -> AuthorizationResponse: + """Handle OAuth authorization callback from the authorization server. + + Parses the callback URL, validates the state parameter for CSRF protection, + and returns the authorization response for further processing. + + Args: + callback_url: Full callback URL received from authorization server + expected_state: State parameter that was sent in authorization request + + Returns: + AuthorizationResponse: Parsed callback response + + Raises: + CallbackError: If callback URL is malformed + StateValidationError: If state parameter doesn't match + """ + try: + logger.debug(f"Processing authorization callback: {callback_url}") + + # Parse callback URL + auth_response = self._parse_callback_url(callback_url) + + if auth_response.state is None: + raise StateValidationError( + "Authorization server callback missing required state parameter" + ) + + # Validate state parameter (CSRF protection) + validate_state(expected_state, auth_response.state) + + if auth_response.is_success(): + logger.info( + "Authorization callback successful - received authorization code" + ) + elif auth_response.is_error(): + logger.warning( + f"Authorization callback contained error: {auth_response.error} - " + f"{auth_response.error_description}" + ) + else: + logger.warning("Authorization callback missing both code and error") + + return auth_response + + except StateValidationError: + raise # Re-raise state validation errors as-is + except Exception as e: + raise AuthorizationCallbackError( + f"Failed to process authorization callback: {e}" + ) from e + + def _parse_callback_url(self, callback_url: str) -> AuthorizationResponse: + """Parse OAuth callback URL into AuthorizationResponse. + + Args: + callback_url: Full callback URL from authorization server + + Returns: + AuthorizationResponse: Parsed callback parameters + + Raises: + CallbackError: If URL is malformed + """ + try: + parsed = urlparse(callback_url) + query_params = parse_qs(parsed.query) + + # Extract single values from query parameter lists + def get_single_param(key: str) -> str | None: + values = query_params.get(key, []) + return values[0] if values else None + + return AuthorizationResponse( + code=get_single_param("code"), + state=get_single_param("state"), + error=get_single_param("error"), + error_description=get_single_param("error_description"), + error_uri=get_single_param("error_uri"), + ) + + except Exception as e: + raise AuthorizationCallbackError( + f"Failed to parse callback URL: {e}" + ) from e diff --git a/src/conduit/auth/client/services/pkce.py b/src/conduit/auth/client/services/pkce.py new file mode 100644 index 0000000..2d32f84 --- /dev/null +++ b/src/conduit/auth/client/services/pkce.py @@ -0,0 +1,82 @@ +"""PKCE (Proof Key for Code Exchange) manager for OAuth 2.1 security. + +Implements RFC 7636 PKCE parameters generation and validation to prevent +authorization code interception attacks. This is required for OAuth 2.1. +""" + +from __future__ import annotations + +import base64 +import hashlib +import secrets +import string + +from conduit.auth.client.models.errors import PKCEError +from conduit.auth.client.models.security import PKCEParameters + + +class PKCEManager: + """Manages PKCE parameter generation for OAuth 2.1 flows. + + PKCE (Proof Key for Code Exchange) is a security extension that prevents + authorization code interception attacks by requiring clients to prove + they initiated the authorization request. + + This implementation follows RFC 7636 requirements: + - Uses S256 code challenge method (SHA256 + base64url) + - Generates cryptographically secure code verifiers + """ + + def generate_parameters(self) -> PKCEParameters: + """Generate new PKCE parameters for an authorization flow. + + Creates a cryptographically secure code verifier and derives the + corresponding code challenge using SHA256. + + Returns: + PKCEParameters: Immutable PKCE parameters for the authorization flow + + Raises: + PKCEError: If parameter generation fails + """ + try: + code_verifier = self._generate_code_verifier() + code_challenge = self._generate_code_challenge(code_verifier) + + return PKCEParameters( + code_verifier=code_verifier, + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + except Exception as e: + raise PKCEError(f"Failed to generate PKCE parameters: {e}") from e + + def _generate_code_verifier(self) -> str: + """Generate a cryptographically secure code verifier. + + RFC 7636 Section 4.1: code verifier must be 43-128 characters long + and use only unreserved characters: + [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" + + Returns: + A 128-character code verifier (maximum length for best security) + """ + alphabet = string.ascii_letters + string.digits + "-._~" + return "".join(secrets.choice(alphabet) for _ in range(128)) + + def _generate_code_challenge(self, code_verifier: str) -> str: + """Generate code challenge from code verifier using S256 method. + + RFC 7636 Section 4.2: For S256, the code challenge is: + BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) + + Args: + code_verifier: The code verifier to hash + + Returns: + Base64url-encoded SHA256 hash of the code verifier + """ + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + return challenge diff --git a/src/conduit/auth/client/services/registration.py b/src/conduit/auth/client/services/registration.py new file mode 100644 index 0000000..06bd6bf --- /dev/null +++ b/src/conduit/auth/client/services/registration.py @@ -0,0 +1,196 @@ +"""OAuth 2.1 dynamic client registration service. + +Implements RFC 7591 (OAuth 2.0 Dynamic Client Registration Protocol) +to automatically register MCP clients with authorization servers. +""" + +from __future__ import annotations + +import logging + +import httpx +from pydantic import ValidationError + +from conduit.auth.client.models.errors import RegistrationError +from conduit.auth.client.models.registration import ( + ClientCredentials, + ClientMetadata, + ClientRegistration, +) + +logger = logging.getLogger(__name__) + + +class OAuth2Registration: + """Handles OAuth 2.1 dynamic client registration for MCP authentication. + + Implements RFC 7591 to automatically register clients with authorization + servers, eliminating the need for manual client configuration. + + Supports both confidential and public client registration with proper + error handling and validation. + """ + + def __init__(self, timeout: float = 30.0): + """Initialize OAuth registration. + + Args: + timeout: HTTP request timeout in seconds + """ + self.timeout = timeout + self._http_client = httpx.AsyncClient(timeout=timeout) + + async def register_client( + self, + registration_endpoint: str, + client_metadata: ClientMetadata, + initial_access_token: str | None = None, + ) -> ClientRegistration: + """Register a new OAuth client with the authorization server. + + Args: + registration_endpoint: Client registration endpoint URL + client_metadata: Client metadata to register + initial_access_token: Optional initial access token for protected + registration endpoints + + Returns: + Complete client registration result + + Raises: + RegistrationError: If registration fails + """ + logger.debug(f"Registering client at {registration_endpoint}") + + try: + # Prepare registration request + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + # Add initial access token if provided (RFC 7591 Section 3.1) + if initial_access_token: + headers["Authorization"] = f"Bearer {initial_access_token}" + + # Send registration request + response = await self._http_client.post( + registration_endpoint, + json=client_metadata.model_dump(exclude_none=True, mode="json"), + headers=headers, + ) + + # Handle response + if response.status_code == 201: + return await self._handle_successful_registration( + response, registration_endpoint, client_metadata + ) + else: + await self._handle_registration_error(response) + + except httpx.HTTPError as e: + raise RegistrationError(f"HTTP error during registration: {e}") from e + except Exception as e: + raise RegistrationError(f"Unexpected error during registration: {e}") from e + + async def _handle_successful_registration( + self, + response: httpx.Response, + registration_endpoint: str, + original_metadata: ClientMetadata, + ) -> ClientRegistration: + """Handle successful registration response. + + Args: + response: Successful HTTP response + registration_endpoint: Registration endpoint used + original_metadata: Original client metadata sent + + Returns: + Complete client registration result + + Raises: + RegistrationError: If response parsing fails + """ + try: + response_data = response.json() + + # Validate required fields are present + if "client_id" not in response_data: + raise RegistrationError( + "Registration response missing required client_id" + ) + + # Parse client credentials from response + credentials = ClientCredentials(**response_data) + + logger.info( + f"Successfully registered client {credentials.client_id} " + f"at {registration_endpoint}" + ) + + return ClientRegistration( + metadata=original_metadata, + credentials=credentials, + registration_endpoint=registration_endpoint, + ) + + except (ValueError, ValidationError) as e: + raise RegistrationError(f"Invalid registration response format: {e}") from e + except Exception as e: + raise RegistrationError( + f"Failed to parse registration response: {e}" + ) from e + + async def _handle_registration_error(self, response: httpx.Response) -> None: + """Handle registration error response. + + Args: + response: Error HTTP response + + Raises: + RegistrationError: Always raises with appropriate error message + """ + try: + error_data = response.json() + error_code = error_data.get("error", "unknown_error") + error_description = error_data.get( + "error_description", "No description provided" + ) + + logger.error( + f"Client registration failed with {response.status_code}: " + f"{error_code} - {error_description}" + ) + + # Map common OAuth error codes to more specific messages + if error_code == "invalid_client_metadata": + raise RegistrationError(f"Invalid client metadata: {error_description}") + elif error_code == "invalid_redirect_uri": + raise RegistrationError(f"Invalid redirect URI: {error_description}") + elif error_code == "invalid_client_uri": + raise RegistrationError(f"Invalid client URI: {error_description}") + elif response.status_code == 401: + raise RegistrationError( + "Registration endpoint requires authentication " + "(initial access token)" + ) + elif response.status_code == 403: + raise RegistrationError( + "Registration forbidden - check authorization server policy" + ) + else: + raise RegistrationError( + f"Registration failed ({response.status_code}): {error_code} - " + f"{error_description}" + ) + + except (ValueError, KeyError): + # Fallback for non-JSON error responses + raise RegistrationError( + f"Registration failed with HTTP {response.status_code}: {response.text}" + ) + + async def close(self) -> None: + """Close the HTTP client and clean up resources.""" + await self._http_client.aclose() diff --git a/src/conduit/auth/client/services/security.py b/src/conduit/auth/client/services/security.py new file mode 100644 index 0000000..b5f6452 --- /dev/null +++ b/src/conduit/auth/client/services/security.py @@ -0,0 +1,60 @@ +"""Security utilities for OAuth 2.1 flows. + +Provides cryptographically secure parameter generation and validation +for OAuth security mechanisms including state parameters and URI validation. +""" + +from __future__ import annotations + +import secrets +import string +from urllib.parse import urlparse + +from conduit.auth.client.models.errors import StateValidationError + + +def generate_state() -> str: + """Generate cryptographically secure state parameter. + + The state parameter provides CSRF protection by ensuring the callback + matches the original authorization request. + + Returns: + Cryptographically secure random state string (32 characters) + """ + # Generate 32 characters of URL-safe random data + alphabet = string.ascii_letters + string.digits + "-._~" + return "".join(secrets.choice(alphabet) for _ in range(32)) + + +def validate_state(expected: str, actual: str) -> None: + """Validate state parameter matches expected value. + + Args: + expected: State parameter from original authorization request + actual: State parameter from callback URL + + Raises: + StateValidationError: If state parameters don't match + """ + if not secrets.compare_digest(expected, actual): + raise StateValidationError("State parameter mismatch - possible CSRF attack") + + +def validate_redirect_uri(uri: str) -> bool: + """Validate redirect URI meets OAuth 2.1 security requirements. + + Args: + uri: Redirect URI to validate + + Returns: + True if URI is valid for OAuth 2.1 + """ + try: + parsed = urlparse(uri) + # Must be HTTPS or localhost + return parsed.scheme == "https" or ( + parsed.scheme == "http" and parsed.hostname == "localhost" + ) + except Exception: + return False diff --git a/src/conduit/auth/client/services/tokens.py b/src/conduit/auth/client/services/tokens.py new file mode 100644 index 0000000..781334c --- /dev/null +++ b/src/conduit/auth/client/services/tokens.py @@ -0,0 +1,188 @@ +"""OAuth 2.1 token exchange and management service. + +Implements RFC 6749 token endpoint interactions with PKCE (RFC 7636) +and Resource Indicators (RFC 8707) for secure MCP authentication. +""" + +from __future__ import annotations + +import logging + +import httpx +from pydantic import ValidationError + +from conduit.auth.client.models.errors import TokenError +from conduit.auth.client.models.tokens import ( + RefreshTokenRequest, + TokenRequest, + TokenResponse, +) + +logger = logging.getLogger(__name__) + + +class OAuth2TokenManager: + """Manages OAuth 2.1 token exchange and refresh operations. + + Handles the token endpoint interactions including: + - Authorization code to access token exchange (RFC 6749 Section 4.1.3) + - Access token refresh (RFC 6749 Section 6) + - PKCE code verification (RFC 7636) + - Resource parameter handling (RFC 8707) + + Uses application/x-www-form-urlencoded encoding as required by OAuth 2.1. + """ + + def __init__(self, timeout: float = 30.0): + """Initialize OAuth token manager. + + Args: + timeout: HTTP request timeout in seconds + """ + self.timeout = timeout + self._http_client = httpx.AsyncClient(timeout=timeout) + + async def exchange_code_for_token( + self, token_request: TokenRequest + ) -> TokenResponse: + """Exchange authorization code for access token. + + Implements RFC 6749 Section 4.1.3 - Access Token Request. + Includes PKCE code_verifier for security (RFC 7636). + + Args: + token_request: Token exchange request parameters + + Returns: + TokenResponse: Token response (success or error) + + Raises: + TokenError: If token exchange fails due to network/parsing issues + """ + logger.debug(f"Exchanging authorization code at {token_request.token_endpoint}") + + try: + # Prepare form-encoded request (RFC 6749 requires this format) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + form_data = token_request.to_form_data() + + # Log request details (without sensitive data) + logger.debug( + f"Token request: grant_type={form_data['grant_type']}, " + f"client_id={form_data['client_id']}, " + f"resource={form_data.get('resource', 'none')}" + ) + + # Send token request + response = await self._http_client.post( + token_request.token_endpoint, + data=form_data, + headers=headers, + ) + + # Parse response (both success and error responses are JSON) + return await self._parse_token_response(response) + + except httpx.HTTPError as e: + raise TokenError(f"HTTP error during token exchange: {e}") from e + except Exception as e: + raise TokenError(f"Unexpected error during token exchange: {e}") from e + + async def refresh_access_token( + self, refresh_request: RefreshTokenRequest + ) -> TokenResponse: + """Refresh an access token using a refresh token. + + Implements RFC 6749 Section 6 - Refreshing an Access Token. + + Args: + refresh_request: Refresh token request parameters + + Returns: + TokenResponse: New token response (success or error) + + Raises: + TokenError: If token refresh fails due to network/parsing issues + """ + logger.debug(f"Refreshing access token at {refresh_request.token_endpoint}") + + try: + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + form_data = refresh_request.to_form_data() + + logger.debug( + f"Refresh request: client_id={form_data['client_id']}, " + f"resource={form_data.get('resource', 'none')}" + ) + + response = await self._http_client.post( + refresh_request.token_endpoint, + data=form_data, + headers=headers, + ) + + return await self._parse_token_response(response) + + except httpx.HTTPError as e: + raise TokenError(f"HTTP error during token refresh: {e}") from e + except Exception as e: + raise TokenError(f"Unexpected error during token refresh: {e}") from e + + async def _parse_token_response(self, response: httpx.Response) -> TokenResponse: + """Parse token endpoint response into TokenResponse. + + Handles both successful responses (200) and error responses (400+) + according to RFC 6749 Section 5. + + Args: + response: HTTP response from token endpoint + + Returns: + TokenResponse: Parsed response (success or error) + + Raises: + TokenError: If response cannot be parsed + """ + try: + response_data = response.json() + + if response.status_code == 200: + # Successful token response (RFC 6749 Section 5.1) + logger.info("Token exchange successful") + + # Validate required fields + if "access_token" not in response_data: + raise TokenError("Token response missing required access_token") + + return TokenResponse(**response_data) + + else: + # Error response (RFC 6749 Section 5.2) + error_code = response_data.get("error", "unknown_error") + error_description = response_data.get( + "error_description", "No description provided" + ) + + logger.warning( + f"Token exchange failed with {response.status_code}: " + f"{error_code} - {error_description}" + ) + + return TokenResponse(**response_data) + + except (ValueError, ValidationError) as e: + raise TokenError(f"Invalid token response format: {e}") from e + except Exception as e: + raise TokenError(f"Failed to parse token response: {e}") from e + + async def close(self) -> None: + """Close the HTTP client and clean up resources.""" + await self._http_client.aclose() diff --git a/src/conduit/auth/oauth_client.py b/src/conduit/auth/oauth_client.py deleted file mode 100644 index 4202ecd..0000000 --- a/src/conduit/auth/oauth_client.py +++ /dev/null @@ -1,152 +0,0 @@ -""" -OAuth 2.1 client for MCP server authentication. - -Provides a simple, Pythonic interface for authenticating with MCP servers -that require OAuth 2.1 authorization following the MCP authorization specification. -""" - -from __future__ import annotations - -from typing import Protocol - -import httpx - - -class TokenStorage(Protocol): - """Protocol for token storage implementations.""" - - async def get_tokens(self) -> dict[str, str] | None: - """Get stored tokens for the server.""" - ... - - async def set_tokens(self, tokens: dict[str, str]) -> None: - """Store tokens for the server.""" - ... - - async def clear_tokens(self) -> None: - """Clear stored tokens.""" - ... - - -class AuthHandler(Protocol): - """Protocol for handling user authorization flow.""" - - async def handle_authorization(self, auth_url: str) -> tuple[str, str | None]: - """ - Handle user authorization. - - Args: - auth_url: The authorization URL to present to the user - - Returns: - Tuple of (authorization_code, state) from the callback - """ - ... - - -class MCPOAuthClient: - """ - OAuth 2.1 client for MCP servers. - - Handles the complete OAuth flow including: - - Server discovery (RFC 9728, RFC 8414) - - Dynamic client registration (RFC 7591) - - PKCE authorization flow (OAuth 2.1) - - Token management and refresh - - Example: - ```python - oauth_client = MCPOAuthClient("https://api.example.com/mcp") - http_client = await oauth_client.get_authenticated_http_client() - - # Use the authenticated client for MCP requests - response = await http_client.post("/mcp", json=mcp_request) - ``` - """ - - def __init__( - self, - server_url: str, - *, - client_name: str = "Conduit MCP Client", - storage: TokenStorage | None = None, - auth_handler: AuthHandler | None = None, - timeout: float = 300.0, - ) -> None: - """ - Initialize OAuth client for an MCP server. - - Args: - server_url: The MCP server URL to authenticate with - client_name: Name for dynamic client registration - storage: Token storage implementation (defaults to file-based) - auth_handler: Authorization handler (defaults to browser-based) - timeout: Timeout for the authorization flow in seconds - """ - self.server_url = server_url - self.client_name = client_name - self.timeout = timeout - - # Use defaults if not provided - self.storage = storage or self._default_storage() - self.auth_handler = auth_handler or self._default_auth_handler() - - async def get_authenticated_http_client(self) -> httpx.AsyncClient: - """ - Get an authenticated HTTP client for the MCP server. - - This method handles the complete OAuth flow if needed: - 1. Check for valid cached tokens - 2. Attempt token refresh if expired - 3. Perform full OAuth flow if no valid tokens - - Returns: - An httpx.AsyncClient configured with OAuth authentication - - Raises: - OAuthError: If authentication fails - """ - await self._ensure_authenticated() - return httpx.AsyncClient(auth=self._create_auth_provider()) - - async def clear_authentication(self) -> None: - """Clear stored authentication data.""" - await self.storage.clear_tokens() - - def _default_storage(self) -> TokenStorage: - """Create default file-based token storage.""" - # TODO: Implement FileTokenStorage - raise NotImplementedError("Default storage not yet implemented") - - def _default_auth_handler(self) -> AuthHandler: - """Create default browser-based auth handler.""" - # TODO: Implement BrowserAuthHandler - raise NotImplementedError("Default auth handler not yet implemented") - - async def _ensure_authenticated(self) -> None: - """Ensure we have valid authentication.""" - # TODO: Implement OAuth flow - raise NotImplementedError("OAuth flow not yet implemented") - - def _create_auth_provider(self) -> httpx.Auth: - """Create httpx auth provider with current tokens.""" - # TODO: Implement token-based auth provider - raise NotImplementedError("Auth provider not yet implemented") - - -class OAuthError(Exception): - """Base exception for OAuth-related errors.""" - - pass - - -class AuthenticationError(OAuthError): - """Raised when authentication fails.""" - - pass - - -class TokenError(OAuthError): - """Raised when token operations fail.""" - - pass diff --git a/tests/auth/client/test_discovery_urls.py b/tests/auth/client/test_discovery_urls.py new file mode 100644 index 0000000..72dbafb --- /dev/null +++ b/tests/auth/client/test_discovery_urls.py @@ -0,0 +1,328 @@ +"""Tests for URL parsing and building in OAuth discovery. + +Tests the mechanical URL operations that are prone to edge cases: +- WWW-Authenticate header parsing +- Discovery URL construction +- Resource URL canonicalization +""" + +from unittest.mock import Mock + +from conduit.auth.client.models.discovery import ( + AuthorizationServerMetadata, + DiscoveryResult, + ProtectedResourceMetadata, +) +from conduit.auth.client.services.discovery import OAuth2Discovery + + +class TestWWWAuthenticateHeaderParsing: + """Test extraction of resource metadata URLs from WWW-Authenticate headers.""" + + def setup_method(self): + self.discovery = OAuth2Discovery() + + def test_extract_quoted_resource_metadata_from_www_auth(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url == "https://api.example.com/.well-known/oauth-protected-resource" + + def test_extract_unquoted_resource_metadata_from_www_auth(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource" + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url == "https://api.example.com/.well-known/oauth-protected-resource" + + def test_complex_www_authenticate_header(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": 'Bearer realm="api", error="invalid_token",' + 'resource_metadata="https://api.example.com/.well-known/oauth-protected-resource",' + 'error_description="Token expired"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url == "https://api.example.com/.well-known/oauth-protected-resource" + + def test_no_resource_metadata_parameter(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": 'Bearer realm="api", error="invalid_token"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url is None + + def test_missing_www_authenticate_header(self): + # Arrange + response = Mock() + response.headers = {} + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url is None + + def test_url_with_path_and_query(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/v1/.well-known/oauth-protected-resource?version=2"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert ( + url + == "https://api.example.com/v1/.well-known/oauth-protected-resource?version=2" + ) + + +class TestDiscoveryURLBuilding: + """Test construction of discovery URLs with path-aware logic.""" + + def setup_method(self): + self.discovery = OAuth2Discovery() + + def test_root_authorization_server(self): + # Arrange & Act + urls = self.discovery._build_discovery_urls("https://auth.example.com") + + expected = [ + "https://auth.example.com/.well-known/oauth-authorization-server", + ] + + # Assert + assert urls == expected + + def test_authorization_server_with_path(self): + # Arrange & Act + urls = self.discovery._build_discovery_urls("https://auth.example.com/oauth") + + expected = [ + "https://auth.example.com/.well-known/oauth-authorization-server/oauth", + "https://auth.example.com/.well-known/oauth-authorization-server", + ] + + # Assert + assert urls == expected + + def test_authorization_server_with_nested_path(self): + # Arrange & Act + urls = self.discovery._build_discovery_urls("https://auth.example.com/v1/oauth") + + expected = [ + "https://auth.example.com/.well-known/oauth-authorization-server/v1/oauth", + "https://auth.example.com/.well-known/oauth-authorization-server", + ] + + # Assert + assert urls == expected + + def test_authorization_server_with_trailing_slash(self): + # Arrange & Act + urls = self.discovery._build_discovery_urls("https://auth.example.com/oauth/") + + expected = [ + "https://auth.example.com/.well-known/oauth-authorization-server/oauth", + "https://auth.example.com/.well-known/oauth-authorization-server", + ] + + # Assert + assert urls == expected + + def test_authorization_server_with_port(self): + # Arrange & Act + urls = self.discovery._build_discovery_urls( + "https://auth.example.com:8443/oauth" + ) + + expected = [ + "https://auth.example.com:8443/.well-known/oauth-authorization-server/oauth", + "https://auth.example.com:8443/.well-known/oauth-authorization-server", + ] + + # Assert + assert urls == expected + + +class TestResourceURLCanonicalization: + """Test resource URL canonicalization logic in DiscoveryResult.""" + + def test_basic_server_url_canonicalization(self): + # Arrange + prm = ProtectedResourceMetadata( + authorization_servers=["https://auth.example.com"] + ) + asm = AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ) + + result = DiscoveryResult( + server_url="https://api.example.com/mcp", + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url="https://auth.example.com", + ) + + # Act + resource_url = result.get_resource_url() + + # Assert + assert resource_url == "https://api.example.com/mcp" + + def test_server_url_with_trailing_slash_normalization(self): + # Arrange + prm = ProtectedResourceMetadata( + authorization_servers=["https://auth.example.com"] + ) + asm = AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ) + + result = DiscoveryResult( + server_url="https://api.example.com/mcp/", # Trailing slash + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url="https://auth.example.com", + ) + + # Act + resource_url = result.get_resource_url() + + # Assert + assert resource_url == "https://api.example.com/mcp" # No trailing slash + + def test_case_insensitive_scheme_and_host(self): + # Arrange + prm = ProtectedResourceMetadata( + authorization_servers=["https://auth.example.com"] + ) + asm = AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ) + + result = DiscoveryResult( + server_url="HTTPS://API.EXAMPLE.COM/MCP", # Uppercase + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url="https://auth.example.com", + ) + + # Act + resource_url = result.get_resource_url() + + # Assert + assert ( + resource_url == "https://api.example.com/MCP" + ) # Scheme/host lowercase, path preserved + + def test_port_preservation(self): + # Arrange + prm = ProtectedResourceMetadata( + authorization_servers=["https://auth.example.com"] + ) + asm = AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ) + + result = DiscoveryResult( + server_url="https://api.example.com:8443/mcp", + protected_resource_metadata=prm, + authorization_server_metadata=asm, + auth_server_url="https://auth.example.com", + ) + + # Act + resource_url = result.get_resource_url() + + # Assert + assert resource_url == "https://api.example.com:8443/mcp" + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def setup_method(self): + self.discovery = OAuth2Discovery() + + def test_malformed_www_authenticate_header(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": "Bearer resource_metadata=" # Missing value + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url is None + + def test_multiple_www_authenticate_headers(self): + # Arrange + response = Mock() + # httpx normalizes multiple headers into a single comma-separated value + response.headers = { + "WWW-Authenticate": 'Bearer realm="api", Basic realm="api", Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert url == "https://api.example.com/.well-known/oauth-protected-resource" + + def test_url_with_special_characters(self): + # Arrange + response = Mock() + response.headers = { + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource?client=test%20app"' + } + + # Act + url = self.discovery._extract_resource_metadata_from_www_auth(response) + + # Assert + assert ( + url + == "https://api.example.com/.well-known/oauth-protected-resource?client=test%20app" + ) diff --git a/tests/auth/client/test_flow_manager.py b/tests/auth/client/test_flow_manager.py new file mode 100644 index 0000000..060a375 --- /dev/null +++ b/tests/auth/client/test_flow_manager.py @@ -0,0 +1,214 @@ +"""Tests for OAuth 2.1 authorization flow orchestration. + +High-impact tests covering the complete authorization flow: +- Authorization URL generation with proper parameters +- Callback parsing and state validation +- Security parameter handling +- Error scenarios and edge cases +""" + +from urllib.parse import parse_qs, urlparse + +import pytest + +from conduit.auth.client.models.discovery import ( + AuthorizationServerMetadata, + DiscoveryResult, + ProtectedResourceMetadata, +) +from conduit.auth.client.models.errors import ( + StateValidationError, +) +from conduit.auth.client.models.registration import ClientCredentials +from conduit.auth.client.models.security import PKCEParameters +from conduit.auth.client.services.flow import OAuth2FlowManager + + +class TestStartAuthorizationFlow: + """Test authorization flow initiation and URL generation.""" + + def setup_method(self): + # Arrange + self.flow_manager = OAuth2FlowManager() + + # Mock discovery result + self.discovery_result = DiscoveryResult( + server_url="https://mcp.example.com", + protected_resource_metadata=ProtectedResourceMetadata( + authorization_servers=["https://auth.example.com"] + ), + authorization_server_metadata=AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ), + auth_server_url="https://auth.example.com", + ) + + # Mock client credentials + self.client_credentials = ClientCredentials(client_id="test-client-123") + + async def test_successful_flow_start_generates_valid_url(self): + """Test successful authorization flow start with all parameters.""" + # Act + auth_url, pkce_params, state = await self.flow_manager.start_authorization_flow( + self.discovery_result, + self.client_credentials, + "https://myapp.com/callback", + scope="read write", + ) + + # Assert - Parse the generated URL + parsed = urlparse(auth_url) + query_params = parse_qs(parsed.query) + + # Check base URL + assert parsed.scheme == "https" + assert parsed.netloc == "auth.example.com" + assert parsed.path == "/authorize" + + # Check required OAuth 2.1 parameters + assert query_params["response_type"] == ["code"] + assert query_params["client_id"] == ["test-client-123"] + assert query_params["redirect_uri"] == ["https://myapp.com/callback"] + assert query_params["code_challenge_method"] == ["S256"] + assert query_params["state"] == [state] + + # Check PKCE parameters + assert "code_challenge" in query_params + assert query_params["code_challenge"][0] == pkce_params.code_challenge + + # Check RFC 8707 resource parameter + assert query_params["resource"] == ["https://mcp.example.com"] + + # Check scope + assert query_params["scope"] == ["read write"] + + # Verify return values + assert isinstance(pkce_params, PKCEParameters) + assert len(state) == 32 # Security module generates 32-char state + assert auth_url.startswith("https://auth.example.com/authorize?") + + async def test_flow_start_without_scope(self): + """Test authorization flow start without optional scope parameter.""" + # Act + auth_url, pkce_params, state = await self.flow_manager.start_authorization_flow( + self.discovery_result, + self.client_credentials, + "https://myapp.com/callback", + # No scope parameter + ) + + # Assert + parsed = urlparse(auth_url) + query_params = parse_qs(parsed.query) + + # Scope should not be present + assert "scope" not in query_params + + # Other parameters should still be there + assert query_params["response_type"] == ["code"] + assert query_params["resource"] == ["https://mcp.example.com"] + + +class TestHandleAuthorizationCallback: + """Test callback URL parsing and state validation.""" + + def setup_method(self): + # Arrange + self.flow_manager = OAuth2FlowManager() + + async def test_successful_callback_with_authorization_code(self): + """Test successful callback parsing with authorization code.""" + # Arrange + expected_state = "test-state-123" + callback_url = ( + "https://myapp.com/callback?code=auth-code-456&state=test-state-123" + ) + + # Act + auth_response = await self.flow_manager.handle_authorization_callback( + callback_url, expected_state + ) + + # Assert + assert auth_response.is_success() + assert not auth_response.is_error() + assert auth_response.code == "auth-code-456" + assert auth_response.state == "test-state-123" + assert auth_response.error is None + + async def test_error_callback_with_oauth_error(self): + """Test callback parsing with OAuth error response.""" + # Arrange + expected_state = "test-state-123" + callback_url = ( + "https://myapp.com/callback?" + "error=access_denied&" + "error_description=User+denied+access&" + "state=test-state-123" + ) + + # Act + auth_response = await self.flow_manager.handle_authorization_callback( + callback_url, expected_state + ) + + # Assert + assert not auth_response.is_success() + assert auth_response.is_error() + assert auth_response.error == "access_denied" + assert auth_response.error_description == "User denied access" + assert auth_response.code is None + + async def test_state_parameter_mismatch_raises_error(self): + """Test state validation failure raises StateValidationError.""" + # Arrange + expected_state = "expected-state-123" + callback_url = ( + "https://myapp.com/callback?code=auth-code-456&state=wrong-state-456" + ) + + # Act & Assert + with pytest.raises(StateValidationError) as exc_info: + await self.flow_manager.handle_authorization_callback( + callback_url, expected_state + ) + + assert "State parameter mismatch" in str(exc_info.value) + assert "CSRF attack" in str(exc_info.value) + + async def test_missing_state_parameter_raises_error(self): + """Test missing state parameter raises StateValidationError.""" + # Arrange + expected_state = "expected-state-123" + callback_url = "https://myapp.com/callback?code=auth-code-456" # No state + + # Act & Assert + with pytest.raises(StateValidationError): + await self.flow_manager.handle_authorization_callback( + callback_url, expected_state + ) + + async def test_callback_url_parsing_edge_cases(self): + """Test edge cases in URL parsing that should work.""" + # Arrange + expected_state = "test-state-123" + + # These URLs are weird but should parse successfully: + weird_but_valid_urls = [ + "ftp://example.com/callback?code=abc&state=test-state-123", + "https://example.com:8080/callback?code=abc&state=test-state-123", + "https://example.com/path/callback?code=abc&state=test-state-123&extra=ignored", + ] + + for callback_url in weird_but_valid_urls: + # Act - Should not raise + auth_response = await self.flow_manager.handle_authorization_callback( + callback_url, expected_state + ) + + # Assert + assert auth_response.code == "abc" + assert auth_response.state == "test-state-123" diff --git a/tests/auth/client/test_pkce.py b/tests/auth/client/test_pkce.py new file mode 100644 index 0000000..ae982b7 --- /dev/null +++ b/tests/auth/client/test_pkce.py @@ -0,0 +1,40 @@ +import base64 +import hashlib + +from conduit.auth.client.services.pkce import PKCEManager + + +class TestPKCEManager: + def test_generate_parameters_crypto_requirements(self) -> None: + # Arrange + pkce_manager = PKCEManager() + + # Act + params = pkce_manager.generate_parameters() + + # Assert RFC 7636 requirements + assert 43 <= len(params.code_verifier) <= 128 + assert 43 <= len(params.code_challenge) <= 128 + assert params.code_challenge_method == "S256" + + # Verify code_challenge is base64url(sha256(code_verifier)) + expected_challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(params.code_verifier.encode("ascii")).digest() + ) + .decode("ascii") + .rstrip("=") + ) + assert params.code_challenge == expected_challenge + + def test_generate_parameters_uniqueness(self) -> None: + # Arrange + pkce_manager = PKCEManager() + + # Act - Generate multiple parameters + params1 = pkce_manager.generate_parameters() + params2 = pkce_manager.generate_parameters() + + # Assert - Each generation is unique + assert params1.code_verifier != params2.code_verifier + assert params1.code_challenge != params2.code_challenge diff --git a/tests/auth/client/test_registration.py b/tests/auth/client/test_registration.py new file mode 100644 index 0000000..0291b47 --- /dev/null +++ b/tests/auth/client/test_registration.py @@ -0,0 +1,327 @@ +"""Tests for OAuth 2.1 dynamic client registration. + +High-impact tests covering the core registration flow for public clients: +- Successful registration with proper request/response handling +- Error response parsing and appropriate exception raising +- Request body validation and header construction +- Edge cases and malformed responses +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from conduit.auth.client.models.errors import RegistrationError +from conduit.auth.client.models.registration import ( + ClientMetadata, + ClientRegistration, +) +from conduit.auth.client.services.registration import OAuth2Registration + + +class TestSuccessfulRegistration: + """Test successful registration flow for public clients.""" + + def setup_method(self): + # Arrange + self.registration_service = OAuth2Registration() + self.registration_service._http_client = AsyncMock() + + async def test_successful_public_client_registration(self): + """Test successful registration of public client with all fields.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Test MCP Client", + redirect_uris=["https://localhost:8080/callback"], + client_uri="https://example.com/client", + scope="mcp:read mcp:write", + ) + + # Mock successful response + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = { + "client_id": "generated-client-id-123", + "client_name": "Test MCP Client", + "redirect_uris": ["https://localhost:8080/callback"], + "client_uri": "https://example.com/client", + "scope": "mcp:read mcp:write", + "token_endpoint_auth_method": "none", + "grant_types": ["authorization_code"], + "response_types": ["code"], + "client_id_issued_at": 1640995200, + } + self.registration_service._http_client.post.return_value = mock_response + + # Act + result = await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + # Assert + assert isinstance(result, ClientRegistration) + assert result.credentials.client_id == "generated-client-id-123" + assert result.credentials.client_secret is None # Public client + assert result.metadata == client_metadata + assert result.registration_endpoint == "https://auth.example.com/register" + + # Verify HTTP request was made correctly + self.registration_service._http_client.post.assert_awaited_once() + call_args = self.registration_service._http_client.post.call_args + + # Check endpoint + assert call_args[0][0] == "https://auth.example.com/register" + + # Check request body contains client metadata + request_json = call_args[1]["json"] + assert request_json["client_name"] == "Test MCP Client" + assert request_json["redirect_uris"] == ["https://localhost:8080/callback"] + assert request_json["token_endpoint_auth_method"] == "none" + assert request_json["grant_types"] == ["authorization_code"] + + # Check headers + headers = call_args[1]["headers"] + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + + async def test_registration_with_initial_access_token(self): + """Test registration with initial access token for protected endpoints.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Protected Client", + redirect_uris=["https://localhost:8080/callback"], + ) + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = { + "client_id": "protected-client-123", + "client_name": "Protected Client", + "redirect_uris": ["https://localhost:8080/callback"], + } + self.registration_service._http_client.post.return_value = mock_response + + # Act + await self.registration_service.register_client( + "https://auth.example.com/register", + client_metadata, + initial_access_token="bearer-token-xyz", + ) + + # Assert + call_args = self.registration_service._http_client.post.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer bearer-token-xyz" + + async def test_minimal_client_metadata_registration(self): + """Test registration with only required fields.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Minimal Client", + redirect_uris=["https://localhost:8080/callback"], + ) + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = { + "client_id": "minimal-client-456", + "client_name": "Minimal Client", + "redirect_uris": ["https://localhost:8080/callback"], + } + self.registration_service._http_client.post.return_value = mock_response + + # Act + result = await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + # Assert + assert result.credentials.client_id == "minimal-client-456" + + # Verify request excludes None values + call_args = self.registration_service._http_client.post.call_args + request_json = call_args[1]["json"] + assert "client_uri" not in request_json # Should be excluded + assert "scope" not in request_json # Should be excluded + + +class TestRegistrationErrors: + """Test error handling and OAuth error response parsing.""" + + def setup_method(self): + # Arrange + self.registration_service = OAuth2Registration() + self.registration_service._http_client = AsyncMock() + + async def test_invalid_client_metadata_error(self): + """Test handling of invalid_client_metadata OAuth error.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Test Client", redirect_uris=["ftp://some.invalid.uri"] + ) + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_client_metadata", + "error_description": "redirect_uris must use HTTPS", + } + self.registration_service._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(RegistrationError): + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + async def test_unauthorized_registration_endpoint(self): + """Test handling of 401 Unauthorized (missing initial access token).""" + # Arrange + client_metadata = ClientMetadata( + client_name="Test Client", redirect_uris=["https://localhost:8080/callback"] + ) + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.json.return_value = { + "error": "invalid_token", + "error_description": "Initial access token required", + } + self.registration_service._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(RegistrationError): + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + async def test_forbidden_registration(self): + """Test handling of 403 Forbidden (policy rejection).""" + # Arrange + client_metadata = ClientMetadata( + client_name="Blocked Client", + redirect_uris=["https://localhost:8080/callback"], + ) + + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.json.return_value = { + "error": "access_denied", + "error_description": "Client registration not allowed for this issuer", + } + self.registration_service._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(RegistrationError): + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + async def test_missing_client_id_in_response(self): + """Test handling of malformed success response missing client_id.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Test Client", redirect_uris=["https://localhost:8080/callback"] + ) + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = { + "client_name": "Test Client", + # Missing client_id! + } + self.registration_service._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(RegistrationError): + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + +class TestRequestValidation: + """Test request body construction and validation.""" + + def setup_method(self): + # Arrange + self.registration_service = OAuth2Registration() + self.registration_service._http_client = AsyncMock() + + async def test_request_body_is_valid_json(self): + """Test that request body contains valid JSON with proper OAuth fields.""" + # Arrange + client_metadata = ClientMetadata( + client_name="JSON Test Client", + redirect_uris=[ + "https://localhost:8080/callback", + "https://localhost:8081/callback", + ], + client_uri="https://example.com/client", + scope="read write", + contacts=["admin@example.com"], + ) + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"client_id": "test-123"} + self.registration_service._http_client.post.return_value = mock_response + + # Act + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + # Assert + call_args = self.registration_service._http_client.post.call_args + request_json = call_args[1]["json"] + + # Verify OAuth 2.1 required fields + assert request_json["client_name"] == "JSON Test Client" + assert request_json["redirect_uris"] == [ + "https://localhost:8080/callback", + "https://localhost:8081/callback", + ] + assert request_json["token_endpoint_auth_method"] == "none" + assert request_json["grant_types"] == ["authorization_code"] + assert request_json["response_types"] == ["code"] + + # Verify optional fields + assert request_json["client_uri"] == "https://example.com/client" + assert request_json["scope"] == "read write" + assert request_json["contacts"] == ["admin@example.com"] + + async def test_exclude_none_values_from_request(self): + """Test that None values are excluded from the request body.""" + # Arrange + client_metadata = ClientMetadata( + client_name="Minimal Client", + redirect_uris=["https://localhost:8080/callback"], + # All other fields are None + ) + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"client_id": "minimal-123"} + self.registration_service._http_client.post.return_value = mock_response + + # Act + await self.registration_service.register_client( + "https://auth.example.com/register", client_metadata + ) + + # Assert + call_args = self.registration_service._http_client.post.call_args + request_json = call_args[1]["json"] + + # Required fields should be present + assert "client_name" in request_json + assert "redirect_uris" in request_json + assert "token_endpoint_auth_method" in request_json + + # None fields should be excluded + assert "client_uri" not in request_json + assert "logo_uri" not in request_json + assert "scope" not in request_json + assert "contacts" not in request_json + assert "tos_uri" not in request_json + assert "policy_uri" not in request_json diff --git a/tests/auth/client/test_token_manager.py b/tests/auth/client/test_token_manager.py new file mode 100644 index 0000000..51900e2 --- /dev/null +++ b/tests/auth/client/test_token_manager.py @@ -0,0 +1,414 @@ +"""Tests for OAuth 2.1 token exchange and management. + +High-impact tests covering the token exchange flow: +- Successful authorization code to token exchange +- Token refresh functionality +- Error response handling and OAuth error codes +- Form encoding and request validation +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from conduit.auth.client.models.errors import TokenError +from conduit.auth.client.models.tokens import ( + RefreshTokenRequest, + TokenRequest, +) +from conduit.auth.client.services.tokens import OAuth2TokenManager + + +class TestTokenExchange: + """Test authorization code to access token exchange.""" + + def setup_method(self): + # Arrange + self.token_manager = OAuth2TokenManager() + self.token_manager._http_client = AsyncMock() + self.code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + async def test_successful_token_exchange_with_all_fields(self): + """Test successful token exchange with complete response.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + resource="https://mcp.example.com", + scope="read write", + ) + + # Mock successful response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "access-token-xyz", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refresh-token-abc", + "scope": "read write", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.exchange_code_for_token(token_request) + + # Assert + assert token_response.is_success() + assert not token_response.is_error() + assert token_response.access_token == "access-token-xyz" + assert token_response.token_type == "Bearer" + assert token_response.expires_in == 3600 + assert token_response.refresh_token == "refresh-token-abc" + assert token_response.scope == "read write" + assert token_response.error is None + + # Verify HTTP request was made correctly + self.token_manager._http_client.post.assert_awaited_once() + call_args = self.token_manager._http_client.post.call_args + + # Check endpoint + assert call_args[0][0] == "https://auth.example.com/token" + + # Check form data + form_data = call_args[1]["data"] + assert form_data["grant_type"] == "authorization_code" + assert form_data["code"] == "auth-code-123" + assert form_data["redirect_uri"] == "https://myapp.com/callback" + assert form_data["client_id"] == "client-456" + assert form_data["code_verifier"] == self.code_verifier + assert form_data["resource"] == "https://mcp.example.com" + assert form_data["scope"] == "read write" + + # Check headers + headers = call_args[1]["headers"] + assert headers["Content-Type"] == "application/x-www-form-urlencoded" + assert headers["Accept"] == "application/json" + + async def test_token_exchange_without_optional_fields(self): + """Test token exchange with minimal required fields.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + # No resource or scope + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "access-token-xyz", + "token_type": "Bearer", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.exchange_code_for_token(token_request) + + # Assert + assert token_response.is_success() + assert token_response.access_token == "access-token-xyz" + + # Verify optional fields are excluded from form data + call_args = self.token_manager._http_client.post.call_args + form_data = call_args[1]["data"] + assert "resource" not in form_data + assert "scope" not in form_data + + async def test_token_response_conversion_to_token_state(self): + """Test converting successful token response to TokenState.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "access-token-xyz", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refresh-token-abc", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.exchange_code_for_token(token_request) + token_state = token_response.to_token_state() + + # Assert + assert token_state.access_token == "access-token-xyz" + assert token_state.refresh_token == "refresh-token-abc" + assert token_state.token_type == "Bearer" + assert token_state.expires_at is not None # Should calculate expiry time + assert token_state.is_valid() + assert token_state.can_refresh() + + +class TestTokenExchangeErrors: + """Test error handling in token exchange.""" + + def setup_method(self): + # Arrange + self.token_manager = OAuth2TokenManager() + self.token_manager._http_client = AsyncMock() + self.code_verifier = "dBjftJeZ4CVP" + + async def test_invalid_grant_error(self): + """Test handling of invalid_grant OAuth error.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="expired-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_grant", + "error_description": "Authorization code has expired", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.exchange_code_for_token(token_request) + + # Assert + assert not token_response.is_success() + assert token_response.is_error() + assert token_response.error == "invalid_grant" + assert token_response.error_description == "Authorization code has expired" + assert token_response.access_token is None + + async def test_invalid_client_error(self): + """Test handling of invalid_client OAuth error.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="invalid-client", + code_verifier=self.code_verifier, + ) + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.json.return_value = { + "error": "invalid_client", + "error_description": "Client authentication failed", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.exchange_code_for_token(token_request) + + # Assert + assert token_response.is_error() + assert token_response.error == "invalid_client" + assert token_response.error_description == "Client authentication failed" + + async def test_missing_access_token_in_success_response(self): + """Test handling of malformed success response missing access_token.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "token_type": "Bearer", + # Missing access_token! + } + self.token_manager._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(TokenError) as exc_info: + await self.token_manager.exchange_code_for_token(token_request) + + assert "missing required access_token" in str(exc_info.value) + + +class TestTokenRefresh: + """Test token refresh functionality.""" + + def setup_method(self): + # Arrange + self.token_manager = OAuth2TokenManager() + self.token_manager._http_client = AsyncMock() + self.code_verifier = "dBjftJeZ4CVP" + + async def test_successful_token_refresh(self): + """Test successful token refresh.""" + # Arrange + refresh_request = RefreshTokenRequest( + token_endpoint="https://auth.example.com/token", + refresh_token="refresh-token-abc", + client_id="client-456", + resource="https://mcp.example.com", + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new-access-token-xyz", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new-refresh-token-def", # New refresh token + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.refresh_access_token(refresh_request) + + # Assert + assert token_response.is_success() + assert token_response.access_token == "new-access-token-xyz" + assert token_response.refresh_token == "new-refresh-token-def" + + # Verify HTTP request + call_args = self.token_manager._http_client.post.call_args + form_data = call_args[1]["data"] + assert form_data["grant_type"] == "refresh_token" + assert form_data["refresh_token"] == "refresh-token-abc" + assert form_data["client_id"] == "client-456" + assert form_data["resource"] == "https://mcp.example.com" + + async def test_refresh_token_error(self): + """Test handling of refresh token errors.""" + # Arrange + refresh_request = RefreshTokenRequest( + token_endpoint="https://auth.example.com/token", + refresh_token="invalid-refresh-token", + client_id="client-456", + ) + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_grant", + "error_description": "Refresh token has expired", + } + self.token_manager._http_client.post.return_value = mock_response + + # Act + token_response = await self.token_manager.refresh_access_token(refresh_request) + + # Assert + assert token_response.is_error() + assert token_response.error == "invalid_grant" + assert token_response.error_description == "Refresh token has expired" + + +class TestFormEncoding: + """Test form encoding requirements.""" + + def setup_method(self): + # Arrange + self.token_manager = OAuth2TokenManager() + self.token_manager._http_client = AsyncMock() + self.code_verifier = "dBjftJeZ4CVP" + + async def test_form_encoding_content_type(self): + """Test that requests use proper form encoding content type.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "token-xyz"} + self.token_manager._http_client.post.return_value = mock_response + + # Act + await self.token_manager.exchange_code_for_token(token_request) + + # Assert + call_args = self.token_manager._http_client.post.call_args + headers = call_args[1]["headers"] + + # Must use form encoding, not JSON + assert headers["Content-Type"] == "application/x-www-form-urlencoded" + assert headers["Accept"] == "application/json" + + # Should use 'data' parameter (form), not 'json' parameter + assert "data" in call_args[1] + assert "json" not in call_args[1] + + +class TestHttpErrors: + """Test HTTP-level errors and network issues.""" + + def setup_method(self): + # Arrange + self.token_manager = OAuth2TokenManager() + self.token_manager._http_client = AsyncMock() + self.code_verifier = "dBjftJeZ4CVP" + + async def test_network_error_raises_token_error(self): + """Test that network errors are wrapped in TokenError.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + # Mock network failure + import httpx + + self.token_manager._http_client.post.side_effect = httpx.ConnectError( + "Connection failed" + ) + + # Act & Assert + with pytest.raises(TokenError): + await self.token_manager.exchange_code_for_token(token_request) + + async def test_non_json_error_response_raises_token_error(self): + """Test that non-JSON error responses raise TokenError.""" + # Arrange + token_request = TokenRequest( + token_endpoint="https://auth.example.com/token", + code="auth-code-123", + redirect_uri="https://myapp.com/callback", + client_id="client-456", + code_verifier=self.code_verifier, + ) + + # Mock response that returns HTML instead of JSON (common for 401/500 errors) + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.json.side_effect = ValueError( + "Not valid JSON" + ) # JSON parsing fails + + self.token_manager._http_client.post.return_value = mock_response + + # Act & Assert + with pytest.raises(TokenError): + await self.token_manager.exchange_code_for_token(token_request)