From 557a80cb29874085859a38ddefbd4bb85f03be20 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 22 Jul 2025 22:23:55 +0300 Subject: [PATCH 1/7] feat: implement protocol version negotiation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement protocol version negotiation following the TypeScript SDK approach: - Update LATEST_PROTOCOL_VERSION to 2025-06-18 - Add client-side validation of server protocol version - Return UnsupportedProtocolVersionError for incompatible versions - Add Mcp-Protocol-Version header support for HTTP transports - Implement SetProtocolVersion method on HTTP connections - Add comprehensive tests for protocol negotiation This ensures both client and server agree on a mutually supported protocol version, preventing compatibility issues. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --- client/client.go | 15 ++- client/protocol_negotiation_test.go | 182 ++++++++++++++++++++++++++++ client/stdio_test.go | 2 +- client/transport/interface.go | 7 ++ client/transport/sse.go | 18 +++ client/transport/streamable_http.go | 17 ++- mcp/errors.go | 19 +++ mcp/types.go | 5 +- server/streamable_http.go | 3 +- server/streamable_http_test.go | 118 +++++++++--------- testdata/mockstdio_server.go | 2 +- 11 files changed, 320 insertions(+), 68 deletions(-) create mode 100644 client/protocol_negotiation_test.go create mode 100644 mcp/errors.go diff --git a/client/client.go b/client/client.go index 6990ae19c..5e00f2e5c 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "sync" "sync/atomic" @@ -22,6 +23,7 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + protocolVersion string samplingHandler SamplingHandler } @@ -176,8 +178,19 @@ func (c *Client) Initialize( return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - // Store serverCapabilities + // Validate protocol version + if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) { + return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion} + } + + // Store serverCapabilities and protocol version c.serverCapabilities = result.Capabilities + c.protocolVersion = result.ProtocolVersion + + // Set protocol version on HTTP transports + if httpConn, ok := c.transport.(transport.HTTPConnection); ok { + httpConn.SetProtocolVersion(result.ProtocolVersion) + } // Send initialized notification notification := mcp.JSONRPCNotification{ diff --git a/client/protocol_negotiation_test.go b/client/protocol_negotiation_test.go new file mode 100644 index 000000000..2dc35116c --- /dev/null +++ b/client/protocol_negotiation_test.go @@ -0,0 +1,182 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockProtocolTransport implements transport.Interface for testing protocol negotiation +type mockProtocolTransport struct { + responses map[string]string + notificationHandler func(mcp.JSONRPCNotification) + started bool + closed bool +} + +func (m *mockProtocolTransport) Start(ctx context.Context) error { + m.started = true + return nil +} + +func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + responseStr, ok := m.responses[request.Method] + if !ok { + return nil, fmt.Errorf("no mock response for method %s", request.Method) + } + + return &transport.JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: json.RawMessage(responseStr), + }, nil +} + +func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil +} + +func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + m.notificationHandler = handler +} + +func (m *mockProtocolTransport) Close() error { + m.closed = true + return nil +} + +func (m *mockProtocolTransport) GetSessionId() string { + return "mock-session" +} + +func TestProtocolVersionNegotiation(t *testing.T) { + tests := []struct { + name string + serverVersion string + expectError bool + errorContains string + }{ + { + name: "supported latest version", + serverVersion: mcp.LATEST_PROTOCOL_VERSION, + expectError: false, + }, + { + name: "supported older version 2025-03-26", + serverVersion: "2025-03-26", + expectError: false, + }, + { + name: "supported older version 2024-11-05", + serverVersion: "2024-11-05", + expectError: false, + }, + { + name: "unsupported version", + serverVersion: "2023-01-01", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "unsupported future version", + serverVersion: "2030-01-01", + expectError: true, + errorContains: "unsupported protocol version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock transport that returns specific version + mockTransport := &mockProtocolTransport{ + responses: map[string]string{ + "initialize": fmt.Sprintf(`{ + "protocolVersion": "%s", + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"} + }`, tt.serverVersion), + }, + } + + client := NewClient(mockTransport) + + _, err := client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"}, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } else if !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error()) + } + // Verify it's the correct error type + if !mcp.IsUnsupportedProtocolVersion(err) { + t.Errorf("expected UnsupportedProtocolVersionError, got %T", err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // Verify the protocol version was stored + if client.protocolVersion != tt.serverVersion { + t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion) + } + } + }) + } +} + +// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection +type mockHTTPTransport struct { + mockProtocolTransport + protocolVersion string +} + +func (m *mockHTTPTransport) SetProtocolVersion(version string) { + m.protocolVersion = version +} + +func TestProtocolVersionHeaderSetting(t *testing.T) { + // Create mock HTTP transport + mockTransport := &mockHTTPTransport{ + mockProtocolTransport: mockProtocolTransport{ + responses: map[string]string{ + "initialize": fmt.Sprintf(`{ + "protocolVersion": "%s", + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"} + }`, mcp.LATEST_PROTOCOL_VERSION), + }, + }, + } + + client := NewClient(mockTransport) + + _, err := client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"}, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify SetProtocolVersion was called on HTTP transport + if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("expected SetProtocolVersion to be called with %q, got %q", + mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion) + } +} diff --git a/client/stdio_test.go b/client/stdio_test.go index f41e48114..7eb6dd38a 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -93,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) { defer cancel() request := mcp.InitializeRequest{} - request.Params.ProtocolVersion = "1.0" + request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION request.Params.ClientInfo = mcp.Implementation{ Name: "test-client", Version: "1.0.0", diff --git a/client/transport/interface.go b/client/transport/interface.go index 5f8ed6180..e6feeb742 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -47,6 +47,13 @@ type BidirectionalInterface interface { SetRequestHandler(handler RequestHandler) } +// HTTPConnection is a Transport that runs over HTTP and supports +// protocol version headers. +type HTTPConnection interface { + Interface + SetProtocolVersion(version string) +} + type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` diff --git a/client/transport/sse.go b/client/transport/sse.go index ffe3247f0..6cd24b4ca 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -37,6 +37,7 @@ type SSE struct { started atomic.Bool closed atomic.Bool cancelSSEStream context.CancelFunc + protocolVersion atomic.Value // string // OAuth support oauthHandler *OAuthHandler @@ -324,6 +325,12 @@ func (c *SSE) SendRequest( // Set headers req.Header.Set("Content-Type", "application/json") + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(headerKeyProtocolVersion, version) + } + } for k, v := range c.headers { req.Header.Set(k, v) } @@ -434,6 +441,11 @@ func (c *SSE) GetSessionId() string { return "" } +// SetProtocolVersion sets the negotiated protocol version for this connection. +func (c *SSE) SetProtocolVersion(version string) { + c.protocolVersion.Store(version) +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { @@ -456,6 +468,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti } req.Header.Set("Content-Type", "application/json") + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(headerKeyProtocolVersion, version) + } + } // Set custom HTTP headers for k, v := range c.headers { req.Header.Set(k, v) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 245baba85..864d571a4 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -102,7 +102,8 @@ type StreamableHTTP struct { logger util.Logger getListeningEnabled bool - sessionID atomic.Value // string + sessionID atomic.Value // string + protocolVersion atomic.Value // string initialized chan struct{} initializedOnce sync.Once @@ -207,8 +208,14 @@ func (c *StreamableHTTP) Close() error { return nil } +// SetProtocolVersion sets the negotiated protocol version for this connection. +func (c *StreamableHTTP) SetProtocolVersion(version string) { + c.protocolVersion.Store(version) +} + const ( - headerKeySessionID = "Mcp-Session-Id" + headerKeySessionID = "Mcp-Session-Id" + headerKeyProtocolVersion = "Mcp-Protocol-Version" ) // ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required @@ -337,6 +344,12 @@ func (c *StreamableHTTP) sendHTTP( if sessionID != "" { req.Header.Set(headerKeySessionID, sessionID) } + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(headerKeyProtocolVersion, version) + } + } for k, v := range c.headers { req.Header.Set(k, v) } diff --git a/mcp/errors.go b/mcp/errors.go new file mode 100644 index 000000000..4d1639a9e --- /dev/null +++ b/mcp/errors.go @@ -0,0 +1,19 @@ +package mcp + +import "fmt" + +// UnsupportedProtocolVersionError is returned when the server responds with +// a protocol version that the client doesn't support. +type UnsupportedProtocolVersionError struct { + Version string +} + +func (e UnsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.Version) +} + +// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError +func IsUnsupportedProtocolVersion(err error) bool { + _, ok := err.(UnsupportedProtocolVersionError) + return ok +} diff --git a/mcp/types.go b/mcp/types.go index 909ebd892..0ef6811fd 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -97,12 +97,13 @@ func (t *URITemplate) UnmarshalJSON(data []byte) error { type JSONRPCMessage any // LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. -const LATEST_PROTOCOL_VERSION = "2025-03-26" +const LATEST_PROTOCOL_VERSION = "2025-06-18" // ValidProtocolVersions lists all known valid MCP protocol versions. var ValidProtocolVersions = []string{ - "2024-11-05", LATEST_PROTOCOL_VERSION, + "2025-03-26", + "2024-11-05", } // JSONRPC_VERSION is the version of JSON-RPC used by MCP. diff --git a/server/streamable_http.go b/server/streamable_http.go index b4d344abf..889326380 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -208,7 +208,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { // --- internal methods --- const ( - headerKeySessionID = "Mcp-Session-Id" + headerKeySessionID = "Mcp-Session-Id" + headerKeyProtocolVersion = "Mcp-Protocol-Version" ) func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 5be010a74..05e7fc4d6 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -28,8 +28,7 @@ var initRequest = map[string]any{ "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", - "clientInfo": map[string]any{ + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", }, @@ -146,8 +145,8 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } // get session id from header @@ -339,8 +338,8 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } // no session id from header @@ -565,8 +564,7 @@ func TestStreamableHTTP_HttpHandler(t *testing.T) { "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", - "clientInfo": map[string]any{ + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", }, @@ -586,8 +584,8 @@ func TestStreamableHTTP_HttpHandler(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } }) } @@ -844,56 +842,56 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) { } func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { - mcpServer := NewMCPServer("test-mcp-server", "1.0") - - var receivedHeaders struct { - contentType string - customHeader string - } - mcpServer.AddTool( - mcp.NewTool("check-headers"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - receivedHeaders.contentType = request.Header.Get("Content-Type") - receivedHeaders.customHeader = request.Header.Get("X-Custom-Header") - return mcp.NewToolResultText("ok"), nil - }, - ) - - server := NewTestStreamableHTTPServer(mcpServer) - defer server.Close() - - // Initialize to get session - resp, _ := postJSON(server.URL, initRequest) - sessionID := resp.Header.Get(headerKeySessionID) - resp.Body.Close() - - // Test header passthrough - toolRequest := map[string]any{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": map[string]any{ - "name": "check-headers", - }, - } - toolBody, _ := json.Marshal(toolRequest) - req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody)) - - const expectedContentType = "application/json" - const expectedCustomHeader = "test-value" - req.Header.Set("Content-Type", expectedContentType) - req.Header.Set("X-Custom-Header", expectedCustomHeader) - req.Header.Set(headerKeySessionID, sessionID) - - resp, _ = server.Client().Do(req) - resp.Body.Close() - - if receivedHeaders.contentType != expectedContentType { - t.Errorf("Expected Content-Type header '%s', got '%s'", expectedContentType, receivedHeaders.contentType) - } - if receivedHeaders.customHeader != expectedCustomHeader { - t.Errorf("Expected X-Custom-Header '%s', got '%s'", expectedCustomHeader, receivedHeaders.customHeader) - } + mcpServer := NewMCPServer("test-mcp-server", "1.0") + + var receivedHeaders struct { + contentType string + customHeader string + } + mcpServer.AddTool( + mcp.NewTool("check-headers"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + receivedHeaders.contentType = request.Header.Get("Content-Type") + receivedHeaders.customHeader = request.Header.Get("X-Custom-Header") + return mcp.NewToolResultText("ok"), nil + }, + ) + + server := NewTestStreamableHTTPServer(mcpServer) + defer server.Close() + + // Initialize to get session + resp, _ := postJSON(server.URL, initRequest) + sessionID := resp.Header.Get(headerKeySessionID) + resp.Body.Close() + + // Test header passthrough + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "check-headers", + }, + } + toolBody, _ := json.Marshal(toolRequest) + req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody)) + + const expectedContentType = "application/json" + const expectedCustomHeader = "test-value" + req.Header.Set("Content-Type", expectedContentType) + req.Header.Set("X-Custom-Header", expectedCustomHeader) + req.Header.Set(headerKeySessionID, sessionID) + + resp, _ = server.Client().Do(req) + resp.Body.Close() + + if receivedHeaders.contentType != expectedContentType { + t.Errorf("Expected Content-Type header '%s', got '%s'", expectedContentType, receivedHeaders.contentType) + } + if receivedHeaders.customHeader != expectedCustomHeader { + t.Errorf("Expected X-Custom-Header '%s', got '%s'", expectedCustomHeader, receivedHeaders.customHeader) + } } func postJSON(url string, bodyObject any) (*http.Response, error) { diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index f561285e9..30bf0c001 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -52,7 +52,7 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { switch request.Method { case "initialize": response.Result = map[string]any{ - "protocolVersion": "1.0", + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "serverInfo": map[string]any{ "name": "mock-server", "version": "1.0.0", From 2be3442f26c50c22fbeea568cc11b6b4016ab535 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 22 Jul 2025 22:50:10 +0300 Subject: [PATCH 2/7] fmt --- client/transport/sse_test.go | 6 +++--- client/transport/streamable_http_test.go | 4 ++-- mcp/tools.go | 16 ++++++++-------- mcptest/mcptest.go | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index a672e02fe..f72c8e8c8 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -408,9 +408,9 @@ func TestSSE(t *testing.T) { t.Run("SSEEventWithoutEventField", func(t *testing.T) { // Test that SSE events with only data field (no event field) are processed correctly // This tests the fix for issue #369 - + var messageReceived chan struct{} - + // Create a custom mock server that sends SSE events without event field sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") @@ -449,7 +449,7 @@ func TestSSE(t *testing.T) { messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - + // Signal that message was received close(messageReceived) }) diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 25962940e..4831d5ecc 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -418,7 +418,7 @@ func TestStreamableHTTP(t *testing.T) { t.Run("SSEEventWithoutEventField", func(t *testing.T) { // Test that SSE events with only data field (no event field) are processed correctly // This tests the fix for issue #369 - + // Create a custom mock server that sends SSE events without event field handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -438,7 +438,7 @@ func TestStreamableHTTP(t *testing.T) { // This should be processed as a "message" event according to SSE spec w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + response := map[string]any{ "jsonrpc": "2.0", "id": request["id"], diff --git a/mcp/tools.go b/mcp/tools.go index 3227a0f61..7ec7f8f26 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -467,24 +467,24 @@ func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { // MarshalJSON implements custom JSON marshaling for CallToolResult func (r CallToolResult) MarshalJSON() ([]byte, error) { m := make(map[string]any) - + // Marshal Meta if present if r.Meta != nil { m["_meta"] = r.Meta } - + // Marshal Content array content := make([]any, len(r.Content)) for i, c := range r.Content { content[i] = c } m["content"] = content - + // Marshal IsError if true if r.IsError { m["isError"] = r.IsError } - + return json.Marshal(m) } @@ -494,14 +494,14 @@ func (r *CallToolResult) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &raw); err != nil { return err } - + // Unmarshal Meta if meta, ok := raw["_meta"]; ok { if metaMap, ok := meta.(map[string]any); ok { r.Meta = metaMap } } - + // Unmarshal Content array if contentRaw, ok := raw["content"]; ok { if contentArray, ok := contentRaw.([]any); ok { @@ -519,14 +519,14 @@ func (r *CallToolResult) UnmarshalJSON(data []byte) error { } } } - + // Unmarshal IsError if isError, ok := raw["isError"]; ok { if isErrorBool, ok := isError.(bool); ok { r.IsError = isErrorBool } } - + return nil } diff --git a/mcptest/mcptest.go b/mcptest/mcptest.go index bc7ccc0fa..31bf3c886 100644 --- a/mcptest/mcptest.go +++ b/mcptest/mcptest.go @@ -142,7 +142,7 @@ func (s *Server) Start(ctx context.Context) error { mcpServer.AddTools(s.tools...) mcpServer.AddPrompts(s.prompts...) mcpServer.AddResources(s.resources...) - + for _, template := range s.resourceTemplates { mcpServer.AddResourceTemplate(template.Template, template.Handler) } From 2a6718186ad7132638c0cbe49c84a2d840d92998 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Wed, 23 Jul 2025 10:15:42 +0300 Subject: [PATCH 3/7] refactor: improve protocol negotiation implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move HTTP header constants to common location to avoid duplication - Add errors.Is interface to UnsupportedProtocolVersionError for better Go error handling - Add comprehensive edge case tests for empty and malformed protocol versions - Ensure consistent header constant usage across client and server packages 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --- client/protocol_negotiation_test.go | 49 +++++++++++++++++++++++++++++ client/transport/constants.go | 7 +++++ client/transport/sse.go | 4 +-- client/transport/streamable_http.go | 13 +++----- mcp/errors.go | 6 ++++ server/constants.go | 7 +++++ server/streamable_http.go | 13 +++----- server/streamable_http_test.go | 24 +++++++------- 8 files changed, 91 insertions(+), 32 deletions(-) create mode 100644 client/transport/constants.go create mode 100644 server/constants.go diff --git a/client/protocol_negotiation_test.go b/client/protocol_negotiation_test.go index 2dc35116c..022b7fc6d 100644 --- a/client/protocol_negotiation_test.go +++ b/client/protocol_negotiation_test.go @@ -88,6 +88,30 @@ func TestProtocolVersionNegotiation(t *testing.T) { expectError: true, errorContains: "unsupported protocol version", }, + { + name: "empty protocol version", + serverVersion: "", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - invalid format", + serverVersion: "not-a-date", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - partial date", + serverVersion: "2025-06", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - just numbers", + serverVersion: "20250618", + expectError: true, + errorContains: "unsupported protocol version", + }, } for _, tt := range tests { @@ -180,3 +204,28 @@ func TestProtocolVersionHeaderSetting(t *testing.T) { mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion) } } + +func TestUnsupportedProtocolVersionError_Is(t *testing.T) { + // Test that errors.Is works correctly with UnsupportedProtocolVersionError + err1 := mcp.UnsupportedProtocolVersionError{Version: "2023-01-01"} + err2 := mcp.UnsupportedProtocolVersionError{Version: "2024-01-01"} + + // Test Is method + if !err1.Is(err2) { + t.Error("expected UnsupportedProtocolVersionError.Is to return true for same error type") + } + + // Test with different error type + otherErr := fmt.Errorf("some other error") + if err1.Is(otherErr) { + t.Error("expected UnsupportedProtocolVersionError.Is to return false for different error type") + } + + // Test IsUnsupportedProtocolVersion helper + if !mcp.IsUnsupportedProtocolVersion(err1) { + t.Error("expected IsUnsupportedProtocolVersion to return true") + } + if mcp.IsUnsupportedProtocolVersion(otherErr) { + t.Error("expected IsUnsupportedProtocolVersion to return false for different error type") + } +} diff --git a/client/transport/constants.go b/client/transport/constants.go new file mode 100644 index 000000000..2fb503084 --- /dev/null +++ b/client/transport/constants.go @@ -0,0 +1,7 @@ +package transport + +// Common HTTP header constants used across transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/client/transport/sse.go b/client/transport/sse.go index 6cd24b4ca..97f78192f 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -328,7 +328,7 @@ func (c *SSE) SendRequest( // Set protocol version header if negotiated if v := c.protocolVersion.Load(); v != nil { if version, ok := v.(string); ok && version != "" { - req.Header.Set(headerKeyProtocolVersion, version) + req.Header.Set(HeaderKeyProtocolVersion, version) } } for k, v := range c.headers { @@ -471,7 +471,7 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti // Set protocol version header if negotiated if v := c.protocolVersion.Load(); v != nil { if version, ok := v.(string); ok && version != "" { - req.Header.Set(headerKeyProtocolVersion, version) + req.Header.Set(HeaderKeyProtocolVersion, version) } } // Set custom HTTP headers diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 864d571a4..8a4568c1d 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -195,7 +195,7 @@ func (c *StreamableHTTP) Close() error { c.logger.Errorf("failed to create close request: %v", err) return } - req.Header.Set(headerKeySessionID, sessionId) + req.Header.Set(HeaderKeySessionID, sessionId) res, err := c.httpClient.Do(req) if err != nil { c.logger.Errorf("failed to send close request: %v", err) @@ -213,11 +213,6 @@ func (c *StreamableHTTP) SetProtocolVersion(version string) { c.protocolVersion.Store(version) } -const ( - headerKeySessionID = "Mcp-Session-Id" - headerKeyProtocolVersion = "Mcp-Protocol-Version" -) - // ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required") @@ -290,7 +285,7 @@ func (c *StreamableHTTP) SendRequest( if request.Method == string(mcp.MethodInitialize) { // saved the received session ID in the response // empty session ID is allowed - if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { + if sessionID := resp.Header.Get(HeaderKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } @@ -342,12 +337,12 @@ func (c *StreamableHTTP) sendHTTP( req.Header.Set("Accept", acceptType) sessionID := c.sessionID.Load().(string) if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) } // Set protocol version header if negotiated if v := c.protocolVersion.Load(); v != nil { if version, ok := v.(string); ok && version != "" { - req.Header.Set(headerKeyProtocolVersion, version) + req.Header.Set(HeaderKeyProtocolVersion, version) } } for k, v := range c.headers { diff --git a/mcp/errors.go b/mcp/errors.go index 4d1639a9e..01888bf5b 100644 --- a/mcp/errors.go +++ b/mcp/errors.go @@ -12,6 +12,12 @@ func (e UnsupportedProtocolVersionError) Error() string { return fmt.Sprintf("unsupported protocol version: %q", e.Version) } +// Is implements the errors.Is interface for better error handling +func (e UnsupportedProtocolVersionError) Is(target error) bool { + _, ok := target.(UnsupportedProtocolVersionError) + return ok +} + // IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError func IsUnsupportedProtocolVersion(err error) bool { _, ok := err.(UnsupportedProtocolVersionError) diff --git a/server/constants.go b/server/constants.go new file mode 100644 index 000000000..e071b2ef4 --- /dev/null +++ b/server/constants.go @@ -0,0 +1,7 @@ +package server + +// Common HTTP header constants used across server transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/server/streamable_http.go b/server/streamable_http.go index 889326380..b0bfac00a 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -207,11 +207,6 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { // --- internal methods --- -const ( - headerKeySessionID = "Mcp-Session-Id" - headerKeyProtocolVersion = "Mcp-Protocol-Version" -) - func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { // post request carry request/notification message @@ -248,7 +243,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } else { // Get session ID from header. // Stateful servers need the client to carry the session ID. - sessionID = r.Header.Get(headerKeySessionID) + sessionID = r.Header.Get(HeaderKeySessionID) isTerminated, err := s.sessionIdManager.Validate(sessionID) if err != nil { http.Error(w, "Invalid session ID", http.StatusBadRequest) @@ -348,7 +343,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "application/json") if isInitializeRequest && sessionID != "" { // send the session ID back to the client - w.Header().Set(headerKeySessionID, sessionID) + w.Header().Set(HeaderKeySessionID, sessionID) } w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(response) @@ -362,7 +357,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) // get request is for listening to notifications // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server - sessionID := r.Header.Get(headerKeySessionID) + sessionID := r.Header.Get(HeaderKeySessionID) // the specification didn't say we should validate the session id if sessionID == "" { @@ -461,7 +456,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { // delete request terminate the session - sessionID := r.Header.Get(headerKeySessionID) + sessionID := r.Header.Get(HeaderKeySessionID) notAllowed, err := s.sessionIdManager.Terminate(sessionID) if err != nil { http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 05e7fc4d6..6f4a6edad 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -150,7 +150,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { } // get session id from header - sessionID = resp.Header.Get(headerKeySessionID) + sessionID = resp.Header.Get(HeaderKeySessionID) if sessionID == "" { t.Fatalf("Expected session id in header, got %s", sessionID) } @@ -170,7 +170,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { @@ -215,7 +215,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { t.Fatalf("Failed to send message: %v", err) @@ -245,7 +245,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "dummy-session-id") + req.Header.Set(HeaderKeySessionID, "dummy-session-id") resp, err := server.Client().Do(req) if err != nil { @@ -274,7 +274,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { @@ -343,7 +343,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { } // no session id from header - sessionID := resp.Header.Get(headerKeySessionID) + sessionID := resp.Header.Get(HeaderKeySessionID) if sessionID != "" { t.Fatalf("Expected no session id in header, got %s", sessionID) } @@ -432,7 +432,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "dummy-session-id") + req.Header.Set(HeaderKeySessionID, "dummy-session-id") resp, err := server.Client().Do(req) if err != nil { @@ -472,7 +472,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "mcp-session-2c44d701-fd50-44ce-92b8-dec46185a741") + req.Header.Set(HeaderKeySessionID, "mcp-session-2c44d701-fd50-44ce-92b8-dec46185a741") resp, err := server.Client().Do(req) if err != nil { @@ -747,7 +747,7 @@ func TestStreamableHTTP_SessionWithLogging(t *testing.T) { t.Fatalf("Failed to send init request: %v", err) } defer initResp.Body.Close() - sessionID := initResp.Header.Get(headerKeySessionID) + sessionID := initResp.Header.Get(HeaderKeySessionID) if sessionID == "" { t.Fatal("Expected session id in header") } @@ -767,7 +767,7 @@ func TestStreamableHTTP_SessionWithLogging(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := testServer.Client().Do(req) if err != nil { @@ -862,7 +862,7 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { // Initialize to get session resp, _ := postJSON(server.URL, initRequest) - sessionID := resp.Header.Get(headerKeySessionID) + sessionID := resp.Header.Get(HeaderKeySessionID) resp.Body.Close() // Test header passthrough @@ -881,7 +881,7 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { const expectedCustomHeader = "test-value" req.Header.Set("Content-Type", expectedContentType) req.Header.Set("X-Custom-Header", expectedCustomHeader) - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, _ = server.Client().Do(req) resp.Body.Close() From 79460df6e6d2461ccdc06f53f7084fd5060fae83 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Wed, 23 Jul 2025 10:31:30 +0300 Subject: [PATCH 4/7] fmt --- server/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/session_test.go b/server/session_test.go index 35f9b8db2..9bd8bc9fa 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -902,7 +902,7 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) }{ { - name: "no tool capabilities provided", + name: "no tool capabilities provided", serverOptions: []ServerOption{ // No WithToolCapabilities }, From 278238b03fcfee02b9c8ce83551f3e4b4930b30d Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 24 Jul 2025 15:51:46 +0300 Subject: [PATCH 5/7] fix: include protocol version header in DELETE request for session termination MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As per MCP specification, the MCP-Protocol-Version header must be included on all subsequent requests to the MCP server, including DELETE requests for terminating sessions. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --- client/transport/streamable_http.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 8a4568c1d..4b3613262 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -196,6 +196,12 @@ func (c *StreamableHTTP) Close() error { return } req.Header.Set(HeaderKeySessionID, sessionId) + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(HeaderKeyProtocolVersion, version) + } + } res, err := c.httpClient.Do(req) if err != nil { c.logger.Errorf("failed to send close request: %v", err) From 6a868bc71c7bcb12a527e4161c113c016a5ecea5 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 24 Jul 2025 16:02:12 +0300 Subject: [PATCH 6/7] fix: maintain backward compatibility for protocol version negotiation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the server does not receive an MCP-Protocol-Version header, it should assume protocol version 2025-03-26 for backward compatibility as per the MCP specification. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --- server/server.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server/server.go b/server/server.go index 896eb9b75..a98a2132b 100644 --- a/server/server.go +++ b/server/server.go @@ -604,6 +604,14 @@ func (s *MCPServer) handleInitialize( } func (s *MCPServer) protocolVersion(clientVersion string) string { + // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header, + // and has no other way to identify the version - for example, by relying on the protocol version negotiated + // during initialization - the server SHOULD assume protocol version 2025-03-26 + // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + if len(clientVersion) == 0 { + clientVersion = "2025-03-26" + } + if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { return clientVersion } From b5b7feaa41b732fc5e09d51661d6a09ebc72e4df Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 24 Jul 2025 16:07:42 +0300 Subject: [PATCH 7/7] test: update tests to reflect backward compatibility behavior MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests now expect protocol version 2025-03-26 when no protocol version is provided during initialization, as per MCP specification for backward compatibility. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --- server/server_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index bd7dd4719..aca99ef60 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -42,7 +42,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -70,7 +70,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -107,7 +107,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -407,7 +407,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name)