Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/conduit/client/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Coroutine
from typing import Any, Awaitable, Callable, TypeVar

from conduit.client.request_context import RequestContext
from conduit.client.message_context import MessageContext
from conduit.client.server_manager import ServerManager
from conduit.protocol.base import (
INTERNAL_ERROR,
Expand All @@ -35,9 +35,9 @@
TResult = TypeVar("TResult", bound=Result)
TNotification = TypeVar("TNotification", bound=Notification)

RequestHandler = Callable[[RequestContext, TRequest], Awaitable[TResult | Error]]
RequestHandler = Callable[[MessageContext, TRequest], Awaitable[TResult | Error]]
NotificationHandler = Callable[
[RequestContext, TNotification], Coroutine[Any, Any, None]
[MessageContext, TNotification], Coroutine[Any, Any, None]
]


Expand Down Expand Up @@ -137,14 +137,14 @@ def _on_message_loop_done(self, task: asyncio.Task[None]) -> None:
# Build context
# ================================

def _build_context(self, server_id: str) -> RequestContext:
def _build_context(self, server_id: str) -> MessageContext:
"""Builds context for a request.

Args:
server_id: ID of the server making the request

Returns:
RequestContext: Context with server state and helpers
MessageContext: Context with server state and helpers

Raises:
ValueError: If the server is not registered with the client
Expand All @@ -153,7 +153,7 @@ def _build_context(self, server_id: str) -> RequestContext:
if server_state is None:
raise ValueError(f"Server {server_id} not registered")

return RequestContext(
return MessageContext(
server_id=server_id,
server_state=server_state,
server_manager=self.server_manager,
Expand Down Expand Up @@ -250,7 +250,7 @@ async def _route_request(
async def _execute_request_handler(
self,
handler: RequestHandler,
context: RequestContext,
context: MessageContext,
request_id: str | int,
request: Request,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


@dataclass
class RequestContext:
"""Rich context for handling server -> client requests.
class MessageContext:
"""Rich context for handling server -> client messages.

Provides immediate access to server state, capabilities, and helper methods.
"""
Expand Down Expand Up @@ -60,6 +60,6 @@ def get_server_display_name(self) -> str:
def __str__(self) -> str:
"""String representation for logging."""
return (
f"RequestContext(server={self.get_server_display_name()},"
f"MessageContext(server={self.get_server_display_name()},"
f"id={self.server_id})"
)
4 changes: 2 additions & 2 deletions src/conduit/client/protocol/elicitation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Awaitable, Callable

from conduit.client.request_context import RequestContext
from conduit.client.message_context import MessageContext
from conduit.protocol.elicitation import ElicitRequest, ElicitResult


Expand All @@ -19,7 +19,7 @@ def __init__(self):
self.logger = logging.getLogger("conduit.client.protocol.elicitation")

async def handle_elicitation(
self, context: RequestContext, request: ElicitRequest
self, context: MessageContext, request: ElicitRequest
) -> ElicitResult:
"""Elicit a response from the user.

Expand Down
4 changes: 2 additions & 2 deletions src/conduit/client/protocol/roots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from copy import deepcopy

from conduit.client.request_context import RequestContext
from conduit.client.message_context import MessageContext
from conduit.protocol.roots import ListRootsRequest, ListRootsResult, Root


Expand Down Expand Up @@ -78,7 +78,7 @@ def cleanup_server(self, server_id: str) -> None:
# ================================

async def handle_list_roots(
self, context: RequestContext, request: ListRootsRequest
self, context: MessageContext, request: ListRootsRequest
) -> ListRootsResult:
"""List the roots available to the server making the request."""
roots = self.get_server_roots(context.server_id)
Expand Down
4 changes: 2 additions & 2 deletions src/conduit/client/protocol/sampling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Awaitable, Callable

from conduit.client.request_context import RequestContext
from conduit.client.message_context import MessageContext
from conduit.protocol.sampling import CreateMessageRequest, CreateMessageResult


Expand All @@ -19,7 +19,7 @@ def __init__(self):
self.logger = logging.getLogger("conduit.client.protocol.sampling")

async def handle_create_message(
self, context: RequestContext, request: CreateMessageRequest
self, context: MessageContext, request: CreateMessageRequest
) -> CreateMessageResult:
"""Sample the host LLM for the server.

Expand Down
24 changes: 12 additions & 12 deletions src/conduit/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from conduit.client.callbacks import CallbackManager
from conduit.client.coordinator import MessageCoordinator
from conduit.client.message_context import MessageContext
from conduit.client.protocol.elicitation import (
ElicitationManager,
ElicitationNotConfiguredError,
)
from conduit.client.protocol.roots import RootsManager
from conduit.client.protocol.sampling import SamplingManager, SamplingNotConfiguredError
from conduit.client.request_context import RequestContext
from conduit.client.server_manager import ServerManager
from conduit.protocol.base import (
INTERNAL_ERROR,
Expand Down Expand Up @@ -293,7 +293,7 @@ def _validate_protocol_version(self, result: InitializeResult) -> None:
# ================================

async def _handle_ping(
self, context: RequestContext, request: PingRequest
self, context: MessageContext, request: PingRequest
) -> EmptyResult:
"""Returns an empty result."""

Expand All @@ -304,7 +304,7 @@ async def _handle_ping(
# ================================

async def _handle_list_roots(
self, context: RequestContext, request: ListRootsRequest
self, context: MessageContext, request: ListRootsRequest
) -> ListRootsResult | Error:
"""Returns the roots available to the server.

Expand All @@ -324,7 +324,7 @@ async def _handle_list_roots(
# ================================

async def _handle_sampling(
self, context: RequestContext, request: CreateMessageRequest
self, context: MessageContext, request: CreateMessageRequest
) -> CreateMessageResult | Error:
"""Creates a message using the configured sampling handler.

Expand Down Expand Up @@ -354,7 +354,7 @@ async def _handle_sampling(
# ================================

async def _handle_elicitation(
self, context: RequestContext, request: ElicitRequest
self, context: MessageContext, request: ElicitRequest
) -> ElicitResult | Error:
"""Returns an elicitation result using the configured elicitation handler.

Expand Down Expand Up @@ -384,7 +384,7 @@ async def _handle_elicitation(
# ================================

async def _handle_cancelled(
self, context: RequestContext, notification: CancelledNotification
self, context: MessageContext, notification: CancelledNotification
) -> None:
"""Cancels a request from the server and calls the registered callback."""
request_exists = (
Expand All @@ -400,13 +400,13 @@ async def _handle_cancelled(
await self.callbacks.call_cancelled(context.server_id, notification)

async def _handle_progress(
self, context: RequestContext, notification: ProgressNotification
self, context: MessageContext, notification: ProgressNotification
) -> None:
"""Calls the registered callback for progress updates."""
await self.callbacks.call_progress(context.server_id, notification)

async def _handle_prompts_list_changed(
self, context: RequestContext, notification: PromptListChangedNotification
self, context: MessageContext, notification: PromptListChangedNotification
) -> None:
"""Fetches the updated prompts list and calls the registered callback.

Expand All @@ -433,7 +433,7 @@ async def _handle_prompts_list_changed(
)

async def _handle_resources_list_changed(
self, context: RequestContext, notification: ResourceListChangedNotification
self, context: MessageContext, notification: ResourceListChangedNotification
) -> None:
"""Fetches the updated resources/templates and calls the registered callback.

Expand Down Expand Up @@ -479,7 +479,7 @@ async def _handle_resources_list_changed(
)

async def _handle_resources_updated(
self, context: RequestContext, notification: ResourceUpdatedNotification
self, context: MessageContext, notification: ResourceUpdatedNotification
) -> None:
"""Reads the updated resource content and calls the registered callback.

Expand All @@ -502,7 +502,7 @@ async def _handle_resources_updated(
)

async def _handle_tools_list_changed(
self, context: RequestContext, notification: ToolListChangedNotification
self, context: MessageContext, notification: ToolListChangedNotification
) -> None:
"""Fetches the updated tools list and calls the registered callback.

Expand All @@ -529,7 +529,7 @@ async def _handle_tools_list_changed(
)

async def _handle_logging_message(
self, context: RequestContext, notification: LoggingMessageNotification
self, context: MessageContext, notification: LoggingMessageNotification
) -> None:
"""Calls the registered callback for logging messages."""
await self.callbacks.call_logging_message(context.server_id, notification)
Expand Down
16 changes: 8 additions & 8 deletions src/conduit/server/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@
JSONRPCResponse,
)
from conduit.server.client_manager import ClientManager
from conduit.server.request_context import RequestContext
from conduit.server.message_context import MessageContext
from conduit.shared.message_parser import MessageParser
from conduit.transport.server import ClientMessage, ServerTransport, TransportContext

TRequest = TypeVar("TRequest", bound=Request)
TResult = TypeVar("TResult", bound=Result)
TNotification = TypeVar("TNotification", bound=Notification)

RequestHandler = Callable[[RequestContext, TRequest], Awaitable[TResult | Error]]
RequestHandler = Callable[[MessageContext, TRequest], Awaitable[TResult | Error]]
NotificationHandler = Callable[
[RequestContext, TNotification], Coroutine[Any, Any, None]
[MessageContext, TNotification], Coroutine[Any, Any, None]
]


Expand Down Expand Up @@ -140,14 +140,14 @@ def _on_message_loop_done(self, task: asyncio.Task[None]) -> None:

def _build_context(
self, client_id: str, originating_request_id: str | int | None = None
) -> RequestContext:
"""Builds context for a request.
) -> MessageContext:
"""Builds context for handling a message.

Args:
client_id: ID of the client making the request

Returns:
RequestContext: Rich context with client state and helpers
MessageContext: Rich context with client state and helpers

Raises:
ValueError: If client is not registered
Expand All @@ -156,7 +156,7 @@ def _build_context(
if client_state is None:
raise ValueError(f"Client {client_id} not registered")

return RequestContext(
return MessageContext(
client_id=client_id,
client_state=client_state,
client_manager=self.client_manager,
Expand Down Expand Up @@ -256,7 +256,7 @@ async def _route_request(
async def _execute_request_handler(
self,
handler: RequestHandler,
context: RequestContext,
context: MessageContext,
request_id: str | int,
request: Request,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


@dataclass
class RequestContext:
"""Rich context for handling client -> server requests.
class MessageContext:
"""Rich context for handling client -> server messages.

Provides immediate access to client state, capabilities, and helper methods
instead of requiring handlers to work with bare client_id strings.

This context is built once at the coordinator level and threaded through
the request pipeline, giving handlers everything they need to make
the message pipeline, giving handlers everything they need to make
informed decisions about client capabilities and state.
"""

Expand Down Expand Up @@ -104,6 +104,6 @@ def get_client_display_name(self) -> str:
def __str__(self) -> str:
"""String representation for logging."""
return (
f"RequestContext(client={self.get_client_display_name()},"
f"MessageContext(client={self.get_client_display_name()},"
f"id={self.client_id})"
)
6 changes: 3 additions & 3 deletions src/conduit/server/protocol/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from conduit.protocol.completions import CompleteRequest, CompleteResult

if TYPE_CHECKING:
from conduit.server.request_context import RequestContext
from conduit.server.message_context import MessageContext


class CompletionNotConfiguredError(Exception):
Expand All @@ -14,13 +14,13 @@ class CompletionNotConfiguredError(Exception):
class CompletionManager:
def __init__(self):
self.completion_handler: (
Callable[["RequestContext", CompleteRequest], Awaitable[CompleteResult]]
Callable[["MessageContext", CompleteRequest], Awaitable[CompleteResult]]
| None
) = None
self.logger = logging.getLogger("conduit.server.protocol.completions")

async def handle_complete(
self, context: "RequestContext", request: CompleteRequest
self, context: "MessageContext", request: CompleteRequest
) -> CompleteResult:
"""Generate a completion for a given argument.

Expand Down
4 changes: 2 additions & 2 deletions src/conduit/server/protocol/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from conduit.protocol.logging import LoggingLevel, SetLevelRequest

if TYPE_CHECKING:
from conduit.server.request_context import RequestContext
from conduit.server.message_context import MessageContext


class LoggingManager:
Expand Down Expand Up @@ -55,7 +55,7 @@ def cleanup_client(self, client_id: str) -> None:
self._client_log_levels.pop(client_id, None)

async def handle_set_level(
self, context: "RequestContext", request: SetLevelRequest
self, context: "MessageContext", request: SetLevelRequest
) -> EmptyResult:
"""Set the MCP protocol logging level for a specific client.

Expand Down
8 changes: 4 additions & 4 deletions src/conduit/server/protocol/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
)

if TYPE_CHECKING:
from conduit.server.request_context import RequestContext
from conduit.server.message_context import MessageContext

PromptHandler = Callable[
["RequestContext", GetPromptRequest], Awaitable[GetPromptResult]
["MessageContext", GetPromptRequest], Awaitable[GetPromptResult]
]


Expand Down Expand Up @@ -170,7 +170,7 @@ def cleanup_client(self, client_id: str) -> None:
# ================================

async def handle_list_prompts(
self, context: "RequestContext", request: ListPromptsRequest
self, context: "MessageContext", request: ListPromptsRequest
) -> ListPromptsResult:
"""List all prompts available to this client.

Expand All @@ -188,7 +188,7 @@ async def handle_list_prompts(
return ListPromptsResult(prompts=list(prompts.values()))

async def handle_get_prompt(
self, context: "RequestContext", request: GetPromptRequest
self, context: "MessageContext", request: GetPromptRequest
) -> GetPromptResult:
"""Execute a prompt request for a specific client.

Expand Down
Loading
Loading