Skip to content

Commit 4d9cc16

Browse files
committed
feat: add support for custom client notifications
MCP servers, particularly ones that offer "experimental" capabilities, may wish to handle custom client notifications that are not part of the standard MCP specification. This change introduces a new `CustomClientNotification` type that allows a server to process such custom notifications. - introduces `CustomClientNotification` to carry arbitrary methods/params while still preserving meta/extensions; wires it into the `ClientNotification` union and `serde` so `params` can be decoded with `params_as` - allows server handlers to receive custom notifications via a new `on_custom_notification` hook - adds integration coverage that sends a custom client notification end-to-end and asserts the server sees the method and payload Test: ```shell cargo test -p rmcp --features client test_custom_client_notification_reaches_server ```
1 parent cf45070 commit 4d9cc16

File tree

5 files changed

+229
-7
lines changed

5 files changed

+229
-7
lines changed

crates/rmcp/src/handler/server.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ impl<H: ServerHandler> Service<RoleServer> for H {
8989
ClientNotification::RootsListChangedNotification(_notification) => {
9090
self.on_roots_list_changed(context).await
9191
}
92+
ClientNotification::CustomClientNotification(notification) => {
93+
self.on_custom_notification(notification, context).await
94+
}
9295
};
9396
Ok(())
9497
}
@@ -224,6 +227,14 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
224227
) -> impl Future<Output = ()> + Send + '_ {
225228
std::future::ready(())
226229
}
230+
fn on_custom_notification(
231+
&self,
232+
notification: CustomClientNotification,
233+
context: NotificationContext<RoleServer>,
234+
) -> impl Future<Output = ()> + Send + '_ {
235+
let _ = (notification, context);
236+
std::future::ready(())
237+
}
227238

228239
fn get_info(&self) -> ServerInfo {
229240
ServerInfo::default()

crates/rmcp/src/model.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,40 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled");
627627
pub type CancelledNotification =
628628
Notification<CancelledNotificationMethod, CancelledNotificationParam>;
629629

630+
/// A catch-all notification the client can use to send custom messages to a server.
631+
///
632+
/// This preserves the raw `method` name and `params` payload so handlers can
633+
/// deserialize them into domain-specific types.
634+
#[derive(Debug, Clone)]
635+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
636+
pub struct CustomClientNotification {
637+
pub method: String,
638+
pub params: Option<Value>,
639+
/// extensions will carry anything possible in the context, including [`Meta`]
640+
///
641+
/// this is similar with the Extensions in `http` crate
642+
#[cfg_attr(feature = "schemars", schemars(skip))]
643+
pub extensions: Extensions,
644+
}
645+
646+
impl CustomClientNotification {
647+
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
648+
Self {
649+
method: method.into(),
650+
params,
651+
extensions: Extensions::default(),
652+
}
653+
}
654+
655+
/// Deserialize `params` into a strongly-typed structure.
656+
pub fn params_as<T: DeserializeOwned>(&self) -> Result<Option<T>, serde_json::Error> {
657+
self.params
658+
.as_ref()
659+
.map(|params| serde_json::from_value(params.clone()))
660+
.transpose()
661+
}
662+
}
663+
630664
const_string!(InitializeResultMethod = "initialize");
631665
/// # Initialization
632666
/// This request is sent from the client to the server when it first connects, asking it to begin initialization.
@@ -1748,7 +1782,8 @@ ts_union!(
17481782
| CancelledNotification
17491783
| ProgressNotification
17501784
| InitializedNotification
1751-
| RootsListChangedNotification;
1785+
| RootsListChangedNotification
1786+
| CustomClientNotification;
17521787
);
17531788

17541789
ts_union!(
@@ -1857,6 +1892,38 @@ mod tests {
18571892
assert_eq!(json, raw);
18581893
}
18591894

1895+
#[test]
1896+
fn test_custom_client_notification_roundtrip() {
1897+
let raw = json!( {
1898+
"jsonrpc": JsonRpcVersion2_0,
1899+
"method": "notifications/custom",
1900+
"params": {"foo": "bar"},
1901+
});
1902+
1903+
let message: ClientJsonRpcMessage =
1904+
serde_json::from_value(raw.clone()).expect("invalid notification");
1905+
match &message {
1906+
ClientJsonRpcMessage::Notification(JsonRpcNotification {
1907+
notification: ClientNotification::CustomClientNotification(notification),
1908+
..
1909+
}) => {
1910+
assert_eq!(notification.method, "notifications/custom");
1911+
assert_eq!(
1912+
notification
1913+
.params
1914+
.as_ref()
1915+
.and_then(|p| p.get("foo"))
1916+
.expect("foo present"),
1917+
"bar"
1918+
);
1919+
}
1920+
_ => panic!("Expected custom client notification"),
1921+
}
1922+
1923+
let json = serde_json::to_value(message).expect("valid json");
1924+
assert_eq!(json, raw);
1925+
}
1926+
18601927
#[test]
18611928
fn test_request_conversion() {
18621929
let raw = json!( {

crates/rmcp/src/model/meta.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
44
use serde_json::Value;
55

66
use super::{
7-
ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString,
8-
ProgressToken, ServerNotification, ServerRequest,
7+
ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject,
8+
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
99
};
1010

1111
pub trait GetMeta {
@@ -18,6 +18,26 @@ pub trait GetExtensions {
1818
fn extensions_mut(&mut self) -> &mut Extensions;
1919
}
2020

21+
impl GetExtensions for CustomClientNotification {
22+
fn extensions(&self) -> &Extensions {
23+
&self.extensions
24+
}
25+
fn extensions_mut(&mut self) -> &mut Extensions {
26+
&mut self.extensions
27+
}
28+
}
29+
30+
impl GetMeta for CustomClientNotification {
31+
fn get_meta_mut(&mut self) -> &mut Meta {
32+
self.extensions_mut().get_or_insert_default()
33+
}
34+
fn get_meta(&self) -> &Meta {
35+
self.extensions()
36+
.get::<Meta>()
37+
.unwrap_or(Meta::static_empty())
38+
}
39+
}
40+
2141
macro_rules! variant_extension {
2242
(
2343
$Enum: ident {
@@ -84,6 +104,7 @@ variant_extension! {
84104
ProgressNotification
85105
InitializedNotification
86106
RootsListChangedNotification
107+
CustomClientNotification
87108
}
88109
}
89110

crates/rmcp/src/model/serde_impl.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::borrow::Cow;
33
use serde::{Deserialize, Serialize};
44

55
use super::{
6-
Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam,
7-
RequestOptionalParam,
6+
CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
7+
RequestNoParam, RequestOptionalParam,
88
};
99
#[derive(Serialize, Deserialize)]
1010
struct WithMeta<'a, P> {
@@ -249,6 +249,59 @@ where
249249
}
250250
}
251251

252+
impl Serialize for CustomClientNotification {
253+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
254+
where
255+
S: serde::Serializer,
256+
{
257+
let extensions = &self.extensions;
258+
let _meta = extensions.get::<Meta>().map(Cow::Borrowed);
259+
let params = self.params.as_ref();
260+
261+
let params = if _meta.is_some() || params.is_some() {
262+
Some(WithMeta {
263+
_meta,
264+
_rest: &self.params,
265+
})
266+
} else {
267+
None
268+
};
269+
270+
ProxyOptionalParam::serialize(
271+
&ProxyOptionalParam {
272+
method: &self.method,
273+
params,
274+
},
275+
serializer,
276+
)
277+
}
278+
}
279+
280+
impl<'de> Deserialize<'de> for CustomClientNotification {
281+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
282+
where
283+
D: serde::Deserializer<'de>,
284+
{
285+
let body =
286+
ProxyOptionalParam::<'_, _, Option<serde_json::Value>>::deserialize(deserializer)?;
287+
let mut params = None;
288+
let mut _meta = None;
289+
if let Some(body_params) = body.params {
290+
params = body_params._rest;
291+
_meta = body_params._meta.map(|m| m.into_owned());
292+
}
293+
let mut extensions = Extensions::new();
294+
if let Some(meta) = _meta {
295+
extensions.insert(meta);
296+
}
297+
Ok(CustomClientNotification {
298+
extensions,
299+
method: body.method,
300+
params,
301+
})
302+
}
303+
}
304+
252305
#[cfg(test)]
253306
mod test {
254307
use serde_json::json;

crates/rmcp/tests/test_notification.rs

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ use std::sync::Arc;
33
use rmcp::{
44
ClientHandler, ServerHandler, ServiceExt,
55
model::{
6-
ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam,
6+
ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam,
7+
ServerCapabilities, ServerInfo, SubscribeRequestParam,
78
},
89
};
9-
use tokio::sync::Notify;
10+
use serde_json::json;
11+
use tokio::sync::{Mutex, Notify};
1012
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
1113

1214
pub struct Server {}
@@ -93,3 +95,71 @@ async fn test_server_notification() -> anyhow::Result<()> {
9395
client.cancel().await?;
9496
Ok(())
9597
}
98+
99+
struct CustomServer {
100+
receive_signal: Arc<Notify>,
101+
payload: Arc<Mutex<Option<(String, Option<serde_json::Value>)>>>,
102+
}
103+
104+
impl ServerHandler for CustomServer {
105+
async fn on_custom_notification(
106+
&self,
107+
notification: CustomClientNotification,
108+
_context: rmcp::service::NotificationContext<rmcp::RoleServer>,
109+
) {
110+
let CustomClientNotification { method, params, .. } = notification;
111+
let mut payload = self.payload.lock().await;
112+
*payload = Some((method, params));
113+
self.receive_signal.notify_one();
114+
}
115+
}
116+
117+
#[tokio::test]
118+
async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> {
119+
let _ = tracing_subscriber::registry()
120+
.with(
121+
tracing_subscriber::EnvFilter::try_from_default_env()
122+
.unwrap_or_else(|_| "debug".to_string().into()),
123+
)
124+
.with(tracing_subscriber::fmt::layer())
125+
.try_init();
126+
127+
let (server_transport, client_transport) = tokio::io::duplex(4096);
128+
let receive_signal = Arc::new(Notify::new());
129+
let payload = Arc::new(Mutex::new(None));
130+
131+
{
132+
let receive_signal = receive_signal.clone();
133+
let payload = payload.clone();
134+
tokio::spawn(async move {
135+
let server = CustomServer {
136+
receive_signal,
137+
payload,
138+
}
139+
.serve(server_transport)
140+
.await?;
141+
server.waiting().await?;
142+
anyhow::Ok(())
143+
});
144+
}
145+
146+
let client = ().serve(client_transport).await?;
147+
148+
client
149+
.send_notification(ClientNotification::CustomClientNotification(
150+
CustomClientNotification::new(
151+
"notifications/custom-test",
152+
Some(json!({ "foo": "bar" })),
153+
),
154+
))
155+
.await?;
156+
157+
tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?;
158+
159+
let (method, params) = payload.lock().await.clone().expect("payload set");
160+
assert_eq!("notifications/custom-test", method);
161+
assert_eq!(Some(json!({ "foo": "bar" })), params);
162+
163+
client.cancel().await?;
164+
Ok(())
165+
}

0 commit comments

Comments
 (0)