Skip to content

Commit 3de6e90

Browse files
authored
Merge pull request #7 from tidewave-ai/sd-override-protocol-version
Add --override-protocol-version
2 parents 5ea4ea7 + ec8d292 commit 3de6e90

File tree

8 files changed

+187
-32
lines changed

8 files changed

+187
-32
lines changed

Cargo.lock

Lines changed: 10 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "0.2.0"
44
edition = "2024"
55

66
[dependencies]
7-
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "076dc2c2cd8910bee56bae13f29bbcff8c279666", features = [
7+
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "4c34b64b7f8dcabf94d52a9c6518c6b49c1f0451", features = [
88
"server",
99
"client",
1010
"reqwest",
@@ -29,7 +29,7 @@ version = "0.9"
2929
features = ["vendored"]
3030

3131
[dev-dependencies]
32-
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "076dc2c2cd8910bee56bae13f29bbcff8c279666", features = [
32+
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "4c34b64b7f8dcabf94d52a9c6518c6b49c1f0451", features = [
3333
"server",
3434
"client",
3535
"reqwest",

examples/echo_streamable.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use anyhow::Context;
22
use clap::Parser;
3-
use rmcp::transport::StreamableHttpServer;
3+
use rmcp::transport::streamable_http_server::{
4+
StreamableHttpService, session::local::LocalSessionManager,
5+
};
46
use tracing_subscriber::FmtSubscriber;
57

68
use rmcp::{
@@ -12,6 +14,10 @@ use rmcp::{
1214
pub struct Echo;
1315
#[tool(tool_box)]
1416
impl Echo {
17+
pub fn new() -> Self {
18+
Self {}
19+
}
20+
1521
#[tool(description = "Echo a message")]
1622
fn echo(&self, #[tool(param)] message: String) -> String {
1723
message
@@ -49,11 +55,17 @@ async fn main() -> anyhow::Result<()> {
4955

5056
tracing::subscriber::set_global_default(subscriber).context("Failed to set up logging")?;
5157

52-
let ct = StreamableHttpServer::serve(args.address)
53-
.await?
54-
.with_service(Echo::default);
58+
let service = StreamableHttpService::new(
59+
|| Ok(Echo::new()),
60+
LocalSessionManager::default().into(),
61+
Default::default(),
62+
);
63+
64+
let router = axum::Router::new().nest_service("/mcp", service);
65+
let tcp_listener = tokio::net::TcpListener::bind(args.address).await?;
66+
let _ = axum::serve(tcp_listener, router)
67+
.with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() })
68+
.await;
5569

56-
tokio::signal::ctrl_c().await?;
57-
ct.cancel();
5870
Ok(())
5971
}

src/cli.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ pub struct Args {
1818
/// Initial retry interval in seconds. Default is 5 seconds
1919
#[arg(long, default_value = "5")]
2020
pub initial_retry_interval: u64,
21+
22+
#[arg(long)]
23+
/// Override the protocol version returned to the client
24+
pub override_protocol_version: Option<String>,
2125
}

src/core.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,9 @@ pub(crate) async fn connect_with_streamable(app_state: &AppState) -> Result<SseC
6464
rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
6565
uri: app_state.url.clone().into(),
6666
// we don't want the sdk to perform any retries
67-
retry_config: std::sync::Arc::new(
68-
rmcp::transport::common::client_side_sse::FixedInterval {
69-
max_times: Some(0),
70-
duration: Duration::from_millis(0),
71-
},
72-
),
67+
retry_config: std::sync::Arc::new(rmcp::transport::common::client_side_sse::NeverRetry),
7368
channel_buffer_capacity: 16,
69+
allow_stateless: true,
7470
},
7571
);
7672

src/main.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::{Context, Result, anyhow};
22
use clap::Parser;
33
use futures::StreamExt;
44
use rmcp::{
5-
model::{ClientJsonRpcMessage, ErrorCode, ServerJsonRpcMessage},
5+
model::{ClientJsonRpcMessage, ErrorCode, ProtocolVersion, ServerJsonRpcMessage},
66
transport::{StreamableHttpClientTransport, Transport, sse_client::SseClientTransport},
77
};
88
use std::env;
@@ -128,12 +128,33 @@ async fn main() -> Result<()> {
128128
debug!("Starting MCP proxy with URL: {}", sse_url);
129129
debug!("Max disconnected time: {:?}s", args.max_disconnected_time);
130130

131+
// Parse protocol version override if provided
132+
let override_protocol_version = if let Some(version_str) = args.override_protocol_version {
133+
let protocol_version = match version_str.as_str() {
134+
"2024-11-05" => ProtocolVersion::V_2024_11_05,
135+
"2025-03-26" => ProtocolVersion::V_2025_03_26,
136+
_ => {
137+
return Err(anyhow!(
138+
"Unsupported protocol version: {}. Supported versions are: 2024-11-05, 2025-03-26",
139+
version_str
140+
));
141+
}
142+
};
143+
Some(protocol_version)
144+
} else {
145+
None
146+
};
147+
131148
// Set up communication channels
132149
let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(10);
133150
let (timer_tx, mut timer_rx) = tokio::sync::mpsc::channel(10);
134151

135152
// Initialize application state
136-
let mut app_state = AppState::new(sse_url.clone(), args.max_disconnected_time);
153+
let mut app_state = AppState::new(
154+
sse_url.clone(),
155+
args.max_disconnected_time,
156+
override_protocol_version,
157+
);
137158
// Pass channel senders to state
138159
app_state.reconnect_tx = Some(reconnect_tx.clone());
139160
app_state.timer_tx = Some(timer_tx.clone());

src/state.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use anyhow::Result;
88
use futures::SinkExt;
99
use rmcp::model::{
1010
ClientJsonRpcMessage, ClientNotification, ClientRequest, EmptyResult, InitializedNotification,
11-
InitializedNotificationMethod, RequestId, ServerJsonRpcMessage,
11+
InitializedNotificationMethod, ProtocolVersion, RequestId, ServerJsonRpcMessage, ServerResult,
1212
};
1313
use std::collections::HashMap;
1414
use std::time::{Duration, Instant};
@@ -49,6 +49,8 @@ pub struct AppState {
4949
pub url: String,
5050
/// Maximum time to try reconnecting in seconds (None = infinity)
5151
pub max_disconnected_time: Option<u64>,
52+
/// Override protocol version
53+
pub override_protocol_version: Option<ProtocolVersion>,
5254
/// When we were disconnected
5355
pub disconnected_since: Option<Instant>,
5456
/// Current state of the application
@@ -78,10 +80,15 @@ pub struct AppState {
7880
}
7981

8082
impl AppState {
81-
pub fn new(url: String, max_disconnected_time: Option<u64>) -> Self {
83+
pub fn new(
84+
url: String,
85+
max_disconnected_time: Option<u64>,
86+
override_protocol_version: Option<ProtocolVersion>,
87+
) -> Self {
8288
Self {
8389
url,
8490
max_disconnected_time,
91+
override_protocol_version,
8592
disconnected_since: None,
8693
state: ProxyState::Connecting,
8794
connect_tries: 0,
@@ -286,6 +293,7 @@ impl AppState {
286293
"Initial connection successful, received init response. Waiting for client initialized."
287294
);
288295
self.state = ProxyState::WaitingForClientInitialized;
296+
message = self.maybe_overwrite_protocol_version(message);
289297
}
290298
}
291299
// --- End Initialization Response Handling ---
@@ -537,4 +545,25 @@ impl AppState {
537545
// Not a response/error, return Some(original_message)
538546
Some(message)
539547
}
548+
549+
fn maybe_overwrite_protocol_version(
550+
&mut self,
551+
message: ServerJsonRpcMessage,
552+
) -> ServerJsonRpcMessage {
553+
if let Some(protocol_version) = &self.override_protocol_version {
554+
match message {
555+
ServerJsonRpcMessage::Response(mut resp) => {
556+
if let ServerResult::InitializeResult(mut initialize_result) = resp.result {
557+
initialize_result.protocol_version = protocol_version.clone();
558+
resp.result = ServerResult::InitializeResult(initialize_result);
559+
return ServerJsonRpcMessage::Response(resp);
560+
}
561+
ServerJsonRpcMessage::Response(resp)
562+
}
563+
other => other,
564+
}
565+
} else {
566+
message
567+
}
568+
}
540569
}

tests/advanced_test.rs

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async fn create_sse_server(
110110
address: SocketAddr,
111111
) -> Result<(tokio::process::Child, String)> {
112112
let url = if server_name == "echo_streamable" {
113-
format!("http://{}", address)
113+
format!("http://{}/mcp", address)
114114
} else {
115115
format!("http://{}/sse", address)
116116
};
@@ -348,7 +348,7 @@ async fn initial_connection_retry(server_name: &str) -> Result<()> {
348348

349349
const BIND_ADDRESS: &str = "127.0.0.1:8184";
350350
let server_url = if server_name == "echo_streamable" {
351-
format!("http://{}", BIND_ADDRESS)
351+
format!("http://{}/mcp", BIND_ADDRESS)
352352
} else {
353353
format!("http://{}/sse", BIND_ADDRESS)
354354
};
@@ -545,3 +545,98 @@ async fn test_ping_when_disconnected() -> Result<()> {
545545

546546
Ok(())
547547
}
548+
549+
async fn protocol_version_override(server_name: &str) -> Result<()> {
550+
const BIND_ADDRESS: &str = "127.0.0.1:8186";
551+
552+
// Phase 1: Test normal behavior (no override)
553+
{
554+
let (server_handle, server_url) =
555+
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
556+
let (child, mut reader, stderr_reader, mut stdin) =
557+
spawn_proxy(&server_url, vec![]).await?;
558+
let stderr_buffer = collect_stderr(stderr_reader);
559+
let _guard = TestGuard::new(child, server_handle, stderr_buffer);
560+
561+
// Send initialization message with 2025-03-26
562+
let init_message = r#"{"jsonrpc":"2.0","id":"init-normal","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#;
563+
stdin.write_all(init_message.as_bytes()).await?;
564+
stdin.write_all(b"\n").await?;
565+
566+
// Read the initialization response
567+
let mut response = String::new();
568+
timeout(Duration::from_secs(10), reader.read_line(&mut response)).await??;
569+
570+
// Verify the response contains the original protocol version
571+
assert!(
572+
response.contains("\"protocolVersion\":\"2025-03-26\""),
573+
"Expected server to respond with 2025-03-26 protocol version, got: {}",
574+
response
575+
);
576+
577+
// Send initialized notification
578+
let initialized_message = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
579+
stdin.write_all(initialized_message.as_bytes()).await?;
580+
stdin.write_all(b"\n").await?;
581+
582+
// Clean shutdown
583+
drop(stdin);
584+
}
585+
586+
// Give a moment for cleanup
587+
sleep(Duration::from_millis(500)).await;
588+
589+
// Phase 2: Test with protocol version override
590+
{
591+
let (server_handle, server_url) =
592+
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
593+
let (child, mut reader, stderr_reader, mut stdin) = spawn_proxy(
594+
&server_url,
595+
vec!["--override-protocol-version", "2024-11-05"],
596+
)
597+
.await?;
598+
let stderr_buffer = collect_stderr(stderr_reader);
599+
let _guard = TestGuard::new(child, server_handle, stderr_buffer);
600+
601+
// Send initialization message with 2025-03-26 (same as phase 1)
602+
let init_message = r#"{"jsonrpc":"2.0","id":"init-override","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#;
603+
stdin.write_all(init_message.as_bytes()).await?;
604+
stdin.write_all(b"\n").await?;
605+
606+
// Read the initialization response
607+
let mut response = String::new();
608+
timeout(Duration::from_secs(10), reader.read_line(&mut response)).await??;
609+
610+
// Verify the response contains the overridden protocol version
611+
assert!(
612+
response.contains("\"protocolVersion\":\"2024-11-05\""),
613+
"Expected proxy to override protocol version to 2024-11-05, got: {}",
614+
response
615+
);
616+
617+
// Verify it does NOT contain the original version
618+
assert!(
619+
!response.contains("\"protocolVersion\":\"2025-03-26\""),
620+
"Protocol version should have been overridden from 2025-03-26 to 2024-11-05, got: {}",
621+
response
622+
);
623+
624+
// Send initialized notification
625+
let initialized_message = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
626+
stdin.write_all(initialized_message.as_bytes()).await?;
627+
stdin.write_all(b"\n").await?;
628+
629+
// Clean shutdown
630+
drop(stdin);
631+
}
632+
633+
Ok(())
634+
}
635+
636+
#[tokio::test]
637+
async fn test_protocol_version_override() -> Result<()> {
638+
protocol_version_override("echo").await?;
639+
protocol_version_override("echo_streamable").await?;
640+
641+
Ok(())
642+
}

0 commit comments

Comments
 (0)