Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions codex-rs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ portable-pty = { workspace = true }
rand = { workspace = true }
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
rmcp = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha1 = { workspace = true }
Expand Down
5 changes: 4 additions & 1 deletion codex-rs/core/src/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ pub(crate) async fn stream_chat_completions(
let mut messages = Vec::<serde_json::Value>::new();

let full_instructions = prompt.get_full_instructions(model_family);
messages.push(json!({"role": "system", "content": full_instructions}));
// Only add system message if instructions are non-empty
if !full_instructions.is_empty() {
messages.push(json!({"role": "system", "content": full_instructions}));
}

let input = prompt.get_formatted_input();

Expand Down
5 changes: 5 additions & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ impl Session {
model_reasoning_summary,
conversation_id,
);

// Set the config in the MCP connection manager's sampling handler
// so that MCP servers can make LLM requests with proper settings.
mcp_connection_manager.set_config(config.clone()).await;

let turn_context = TurnContext {
client,
tools_config: ToolsConfig::new(&ToolsConfigParams {
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub mod git_info;
pub mod landlock;
pub mod mcp;
mod mcp_connection_manager;
mod mcp_sampling_handler;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
Expand Down
63 changes: 57 additions & 6 deletions codex-rs/core/src/mcp_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use tracing::warn;

use crate::config_types::McpServerConfig;
use crate::config_types::McpServerTransportConfig;
use crate::mcp_sampling_handler::CodexSamplingHandler;

/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
Expand Down Expand Up @@ -109,9 +110,12 @@ impl McpClientAdapter {
env: Option<HashMap<String, String>>,
params: mcp_types::InitializeRequestParams,
startup_timeout: Duration,
sampling_handler: Option<Arc<CodexSamplingHandler>>,
) -> Result<Self> {
if use_rmcp_client {
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?);
let handler: Option<Arc<dyn codex_rmcp_client::SamplingHandler>> =
sampling_handler.map(|h| h as Arc<dyn codex_rmcp_client::SamplingHandler>);
let client = Arc::new(RmcpClient::new_stdio_client(program, args, env, handler).await?);
client.initialize(params, Some(startup_timeout)).await?;
Ok(McpClientAdapter::Rmcp(client))
} else {
Expand All @@ -128,10 +132,19 @@ impl McpClientAdapter {
params: mcp_types::InitializeRequestParams,
startup_timeout: Duration,
store_mode: OAuthCredentialsStoreMode,
sampling_handler: Option<Arc<CodexSamplingHandler>>,
) -> Result<Self> {
let handler: Option<Arc<dyn codex_rmcp_client::SamplingHandler>> =
sampling_handler.map(|h| h as Arc<dyn codex_rmcp_client::SamplingHandler>);
let client = Arc::new(
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
.await?,
RmcpClient::new_streamable_http_client(
&server_name,
&url,
bearer_token,
store_mode,
handler,
)
.await?,
);
client.initialize(params, Some(startup_timeout)).await?;
Ok(McpClientAdapter::Rmcp(client))
Expand Down Expand Up @@ -172,6 +185,9 @@ pub(crate) struct McpConnectionManager {

/// Fully qualified tool name -> tool instance.
tools: HashMap<String, ToolInfo>,

/// Sampling handler shared across all MCP clients for LLM requests.
sampling_handler: Option<Arc<CodexSamplingHandler>>,
}

impl McpConnectionManager {
Expand All @@ -188,9 +204,23 @@ impl McpConnectionManager {
use_rmcp_client: bool,
store_mode: OAuthCredentialsStoreMode,
) -> Result<(Self, ClientStartErrors)> {
// Create sampling handler if using rmcp client
let sampling_handler = if use_rmcp_client {
Some(Arc::new(CodexSamplingHandler::new()))
} else {
None
};

// Early exit if no servers are configured.
if mcp_servers.is_empty() {
return Ok((Self::default(), ClientStartErrors::default()));
return Ok((
Self {
clients: HashMap::new(),
tools: HashMap::new(),
sampling_handler,
},
ClientStartErrors::default(),
));
}

// Launch all configured servers concurrently.
Expand Down Expand Up @@ -222,13 +252,15 @@ impl McpConnectionManager {
_ => Ok(None),
};

let sampling_handler = sampling_handler.clone();

join_set.spawn(async move {
let McpServerConfig { transport, .. } = cfg;
let params = mcp_types::InitializeRequestParams {
capabilities: ClientCapabilities {
experimental: None,
roots: None,
sampling: None,
sampling: Some(json!({})),
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
// indicates this should be an empty object.
elicitation: Some(json!({})),
Expand Down Expand Up @@ -256,6 +288,7 @@ impl McpConnectionManager {
env,
params,
startup_timeout,
sampling_handler.clone(),
)
.await
}
Expand All @@ -267,6 +300,7 @@ impl McpConnectionManager {
params,
startup_timeout,
store_mode,
sampling_handler,
)
.await
}
Expand Down Expand Up @@ -315,7 +349,24 @@ impl McpConnectionManager {

let tools = qualify_tools(all_tools);

Ok((Self { clients, tools }, errors))
Ok((
Self {
clients,
tools,
sampling_handler,
},
errors,
))
}

/// Set the Config for the sampling handler.
///
/// This must be called after initialization to enable MCP servers
/// to make LLM sampling requests with proper configuration.
pub async fn set_config(&self, config: Arc<crate::config::Config>) {
if let Some(handler) = &self.sampling_handler {
handler.set_config(config).await;
}
}

/// Returns a single map that contains **all** tools. Each key is the
Expand Down
Loading