Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions codex-rs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ portable-pty = { workspace = true }
rand = { workspace = true }
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
rmcp = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha1 = { workspace = true }
Expand Down
5 changes: 4 additions & 1 deletion codex-rs/core/src/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ pub(crate) async fn stream_chat_completions(
let mut messages = Vec::<serde_json::Value>::new();

let full_instructions = prompt.get_full_instructions(model_family);
messages.push(json!({"role": "system", "content": full_instructions}));
// Only add system message if instructions are non-empty
if !full_instructions.is_empty() {
messages.push(json!({"role": "system", "content": full_instructions}));
}

let input = prompt.get_formatted_input();

Expand Down
5 changes: 5 additions & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ impl Session {
model_reasoning_summary,
conversation_id,
);

// Set the config in the MCP connection manager's sampling handler
// so that MCP servers can make LLM requests with proper settings.
mcp_connection_manager.set_config(config.clone()).await;

let turn_context = TurnContext {
client,
tools_config: ToolsConfig::new(&ToolsConfigParams {
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub mod git_info;
pub mod landlock;
pub mod mcp;
mod mcp_connection_manager;
mod mcp_sampling_handler;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
Expand Down
59 changes: 53 additions & 6 deletions codex-rs/core/src/mcp_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use tracing::warn;

use crate::config_types::McpServerConfig;
use crate::config_types::McpServerTransportConfig;
use crate::mcp_sampling_handler::CodexSamplingHandler;

/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
Expand Down Expand Up @@ -109,9 +110,12 @@ impl McpClientAdapter {
env: Option<HashMap<String, String>>,
params: mcp_types::InitializeRequestParams,
startup_timeout: Duration,
sampling_handler: Option<Arc<CodexSamplingHandler>>,
) -> Result<Self> {
if use_rmcp_client {
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
let handler: Option<Arc<dyn codex_rmcp_client::SamplingHandler>> =
sampling_handler.map(|h| h as Arc<dyn codex_rmcp_client::SamplingHandler>);
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env, handler).await?);
client.initialize(params, Some(startup_timeout)).await?;
Ok(McpClientAdapter::Rmcp(client))
} else {
Expand All @@ -128,10 +132,19 @@ impl McpClientAdapter {
params: mcp_types::InitializeRequestParams,
startup_timeout: Duration,
store_mode: OAuthCredentialsStoreMode,
sampling_handler: Option<Arc<CodexSamplingHandler>>,
) -> Result<Self> {
let handler: Option<Arc<dyn codex_rmcp_client::SamplingHandler>> =
sampling_handler.map(|h| h as Arc<dyn codex_rmcp_client::SamplingHandler>);
let client = Arc::new(
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
.await?,
RmcpClient::new_streamable_http_client(
&server_name,
&url,
bearer_token,
store_mode,
handler,
)
.await?,
);
client.initialize(params, Some(startup_timeout)).await?;
Ok(McpClientAdapter::Rmcp(client))
Expand Down Expand Up @@ -172,6 +185,9 @@ pub(crate) struct McpConnectionManager {

/// Fully qualified tool name -> tool instance.
tools: HashMap<String, ToolInfo>,

/// Sampling handler shared across all MCP clients for LLM requests.
sampling_handler: Option<Arc<CodexSamplingHandler>>,
}

impl McpConnectionManager {
Expand All @@ -188,9 +204,19 @@ impl McpConnectionManager {
use_rmcp_client: bool,
store_mode: OAuthCredentialsStoreMode,
) -> Result<(Self, ClientStartErrors)> {
// Create sampling handler if using rmcp client
let sampling_handler = use_rmcp_client.then_some(Arc::new(CodexSamplingHandler::new()));

// Early exit if no servers are configured.
if mcp_servers.is_empty() {
return Ok((Self::default(), ClientStartErrors::default()));
return Ok((
Self {
clients: HashMap::new(),
tools: HashMap::new(),
sampling_handler,
},
ClientStartErrors::default(),
));
}

// Launch all configured servers concurrently.
Expand Down Expand Up @@ -222,13 +248,15 @@ impl McpConnectionManager {
_ => Ok(None),
};

let sampling_handler = sampling_handler.clone();

join_set.spawn(async move {
let McpServerConfig { transport, .. } = cfg;
let params = mcp_types::InitializeRequestParams {
capabilities: ClientCapabilities {
experimental: None,
roots: None,
sampling: None,
sampling: use_rmcp_client.then_some(json!({})),
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
// indicates this should be an empty object.
elicitation: Some(json!({})),
Expand Down Expand Up @@ -256,6 +284,7 @@ impl McpConnectionManager {
env,
params,
startup_timeout,
sampling_handler.clone(),
)
.await
}
Expand All @@ -267,6 +296,7 @@ impl McpConnectionManager {
params,
startup_timeout,
store_mode,
sampling_handler,
)
.await
}
Expand Down Expand Up @@ -315,7 +345,24 @@ impl McpConnectionManager {

let tools = qualify_tools(all_tools);

Ok((Self { clients, tools }, errors))
Ok((
Self {
clients,
tools,
sampling_handler,
},
errors,
))
}

/// Set the Config for the sampling handler.
///
/// This must be called after initialization to enable MCP servers
/// to make LLM sampling requests with proper configuration.
pub async fn set_config(&self, config: Arc<crate::config::Config>) {
if let Some(handler) = &self.sampling_handler {
handler.set_config(config).await;
}
}

/// Returns a single map that contains **all** tools. Each key is the
Expand Down
Loading