Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
102 changes: 102 additions & 0 deletions src/conduit/auth/client/models/discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Discovery-related models for OAuth 2.1 server metadata.

Contains models for Protected Resource Metadata (RFC 9728) and
Authorization Server Metadata (RFC 8414) discovery.
"""

from __future__ import annotations

from dataclasses import dataclass
from urllib.parse import urlparse

from pydantic import BaseModel, Field, field_validator


class ProtectedResourceMetadata(BaseModel):
"""OAuth 2.0 Protected Resource Metadata (RFC 9728).

Metadata returned by MCP servers to indicate their authorization servers
and resource configuration.
"""

resource: str | None = None
authorization_servers: list[str] = Field(min_length=1)

# Optional fields from RFC 9728
bearer_methods_supported: list[str] | None = None
resource_documentation: str | None = None
resource_policy_uri: str | None = None
resource_tos_uri: str | None = None

@field_validator("authorization_servers")
@classmethod
def validate_auth_servers(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("At least one authorization server is required")
return v


class AuthorizationServerMetadata(BaseModel):
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).

Metadata returned by authorization servers describing their endpoints
and supported capabilities.
"""

# Required by RFC 8414
issuer: str
response_types_supported: list[str] = Field(min_length=1)

# Required for authorization code flow (our use case)
authorization_endpoint: str
token_endpoint: str

# PKCE support (required for OAuth 2.1)
code_challenge_methods_supported: list[str] = Field(default=["S256"])

# Dynamic registration (RFC 7591)
registration_endpoint: str | None = None

# Optional but commonly used
revocation_endpoint: str | None = None
introspection_endpoint: str | None = None
scopes_supported: list[str] | None = None
grant_types_supported: list[str] = Field(default=["authorization_code"])

@field_validator("code_challenge_methods_supported")
@classmethod
def validate_pkce_support(cls, v: list[str]) -> list[str]:
if "S256" not in v:
raise ValueError("Authorization server must support S256 PKCE method")
return v


@dataclass(frozen=True)
class DiscoveryResult:
"""Complete discovery results for an MCP server.

Immutable result containing all metadata needed for OAuth flow.
Combines Protected Resource Metadata and Authorization Server Metadata.
"""

server_url: str
protected_resource_metadata: ProtectedResourceMetadata
authorization_server_metadata: AuthorizationServerMetadata
auth_server_url: str

def get_resource_url(self) -> str:
"""Get the resource URL for RFC 8707 resource parameter.

Uses the most specific URI (the actual server URL) as per MCP spec
guidance to provide the most specific URI possible.
"""
# Canonicalize the server URL per RFC 3986 and MCP spec
parsed = urlparse(self.server_url)
canonical = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}"
if parsed.path and parsed.path != "/":
canonical += parsed.path.rstrip("/") # Preserve case, remove trailing slash

return canonical

# TODO: Consider if we ever need PRM resource override
# The spec is unclear about when to use PRM resource vs server URL
99 changes: 99 additions & 0 deletions src/conduit/auth/client/models/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Exception hierarchy for OAuth 2.1 authentication errors.

Provides specific exception types for different failure modes to enable
precise error handling and recovery strategies.
"""

from __future__ import annotations


class OAuth2Error(Exception):
"""Base exception for all OAuth 2.1 related errors."""

pass


class DiscoveryError(OAuth2Error):
"""Raised when OAuth server discovery fails."""

pass


class ProtectedResourceMetadataError(DiscoveryError):
"""Raised when Protected Resource Metadata discovery fails."""

pass


class AuthorizationServerMetadataError(DiscoveryError):
"""Raised when Authorization Server Metadata discovery fails."""

pass


class RegistrationError(OAuth2Error):
"""Raised when dynamic client registration fails."""

pass


class TokenError(OAuth2Error):
"""Raised when token operations fail."""

pass


class TokenRefreshError(TokenError):
"""Raised when token refresh fails."""

pass


class TokenExchangeError(TokenError):
"""Raised when authorization code to token exchange fails."""

pass


class AuthorizationError(OAuth2Error):
"""Raised when user authorization fails."""

pass


class AuthorizationResponseError(OAuth2Error):
"""Raised when authorization response is malformed or invalid."""

pass


class UserAuthCancelledError(AuthorizationError):
"""Raised when user cancels the authorization flow."""

pass


class PKCEError(OAuth2Error):
"""Raised when PKCE parameter generation or validation fails."""

pass


class AuthorizationCallbackError(OAuth2Error):
"""Raised when authorization server callback data is malformed or invalid.

This indicates the authorization server sent an invalid callback URL,
not that our callback handling code failed.
"""

pass


class StateValidationError(AuthorizationCallbackError):
"""Raised when OAuth state parameter validation fails.

This indicates either a missing state parameter or a state mismatch,
which could indicate a CSRF attack or authorization server issue.
"""

pass
56 changes: 56 additions & 0 deletions src/conduit/auth/client/models/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Authorization flow models for OAuth 2.1.

Contains models for authorization requests and callback handling.
"""

from __future__ import annotations

from dataclasses import dataclass
from urllib.parse import urlencode


@dataclass(frozen=True)
class AuthorizationRequest:
"""Authorization request parameters for OAuth 2.1 flow."""

authorization_endpoint: str
client_id: str
redirect_uri: str
code_challenge: str
code_challenge_method: str
state: str
resource: str | None = None # RFC 8707
scope: str | None = None

def build_authorization_url(self) -> str:
"""Build the complete authorization URL."""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"code_challenge": self.code_challenge,
"code_challenge_method": self.code_challenge_method,
"state": self.state,
}

if self.resource:
params["resource"] = self.resource
if self.scope:
params["scope"] = self.scope

return f"{self.authorization_endpoint}?{urlencode(params)}"


@dataclass(frozen=True)
class AuthorizationResponse:
code: str | None = None
state: str | None = None
error: str | None = None
error_description: str | None = None
error_uri: str | None = None

def is_success(self) -> bool:
return self.error is None and self.code is not None

def is_error(self) -> bool:
return self.error is not None
80 changes: 80 additions & 0 deletions src/conduit/auth/client/models/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Client registration models for OAuth 2.0 Dynamic Client Registration.

Contains models for client metadata (RFC 7591) and registration results.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from urllib.parse import urlparse

from pydantic import BaseModel, Field, field_validator


class ClientMetadata(BaseModel):
"""OAuth 2.0 Client Metadata for dynamic registration (RFC 7591)."""

client_name: str
redirect_uris: list[str] = Field(min_length=1)

# Optional metadata
client_uri: str | None = None
logo_uri: str | None = None
scope: str | None = None
contacts: list[str] | None = None
tos_uri: str | None = None
policy_uri: str | None = None

# OAuth 2.1 specific
token_endpoint_auth_method: str = "none" # Public client
grant_types: list[str] = Field(default=["authorization_code"])
response_types: list[str] = Field(default=["code"])

@field_validator("redirect_uris")
@classmethod
def validate_redirect_uris(cls, v: list[str]) -> list[str]:
"""Validate redirect URIs meet OAuth 2.1 security requirements."""
for uri in v:
parsed = urlparse(uri)
# Must be HTTPS or localhost
if parsed.scheme == "http" and parsed.hostname != "localhost":
raise ValueError(f"Redirect URI must use HTTPS or localhost: {uri}")
return v

@field_validator("client_uri", "logo_uri", "tos_uri", "policy_uri")
@classmethod
def validate_https_uris(cls, v: str | None) -> str | None:
"""Validate optional URIs use HTTPS when provided."""
if v is not None and not v.startswith("https://"):
raise ValueError(f"URI must use HTTPS: {v}")
return v


class ClientCredentials(BaseModel):
"""OAuth 2.0 Client Credentials from registration response."""

client_id: str
client_secret: str | None = None # None for public clients
registration_access_token: str | None = None
registration_client_uri: str | None = None
client_id_issued_at: int | None = None
client_secret_expires_at: int | None = None

def is_expired(self) -> bool:
"""Check if client credentials have expired."""
if self.client_secret_expires_at is None:
return False
return time.time() >= self.client_secret_expires_at


@dataclass(frozen=True)
class ClientRegistration:
"""Complete client registration result.

Immutable result containing client metadata and credentials.
"""

metadata: ClientMetadata
credentials: ClientCredentials
registration_endpoint: str
31 changes: 31 additions & 0 deletions src/conduit/auth/client/models/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Security-related models for OAuth 2.1 authentication.

Contains PKCE parameters and other cryptographic primitives needed
for secure OAuth flows.
"""

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class PKCEParameters:
"""PKCE (Proof Key for Code Exchange) parameters for OAuth 2.1 security.

Immutable parameters generated for each authorization flow to prevent
authorization code interception attacks (RFC 7636).
"""

code_verifier: str = field()
code_challenge: str = field()
code_challenge_method: str = field(default="S256")

def __post_init__(self) -> None:
"""Validate PKCE parameters meet RFC 7636 requirements."""
if not (43 <= len(self.code_verifier) <= 128):
raise ValueError("code_verifier must be 43-128 characters")
if not (43 <= len(self.code_challenge) <= 128):
raise ValueError("code_challenge must be 43-128 characters")
if self.code_challenge_method != "S256":
raise ValueError("Only S256 code challenge method is supported")
Loading
Loading