Skip to content
15 changes: 14 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"slices"
"sync"
"sync/atomic"

Expand All @@ -22,6 +23,7 @@ type Client struct {
requestID atomic.Int64
clientCapabilities mcp.ClientCapabilities
serverCapabilities mcp.ServerCapabilities
protocolVersion string
samplingHandler SamplingHandler
}

Expand Down Expand Up @@ -176,8 +178,19 @@ func (c *Client) Initialize(
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

// Store serverCapabilities
// Validate protocol version
if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion}
}

// Store serverCapabilities and protocol version
c.serverCapabilities = result.Capabilities
c.protocolVersion = result.ProtocolVersion

// Set protocol version on HTTP transports
if httpConn, ok := c.transport.(transport.HTTPConnection); ok {
httpConn.SetProtocolVersion(result.ProtocolVersion)
}

// Send initialized notification
notification := mcp.JSONRPCNotification{
Expand Down
231 changes: 231 additions & 0 deletions client/protocol_negotiation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
package client

import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)

// mockProtocolTransport implements transport.Interface for testing protocol negotiation
type mockProtocolTransport struct {
responses map[string]string
notificationHandler func(mcp.JSONRPCNotification)
started bool
closed bool
}

func (m *mockProtocolTransport) Start(ctx context.Context) error {
m.started = true
return nil
}

func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
responseStr, ok := m.responses[request.Method]
if !ok {
return nil, fmt.Errorf("no mock response for method %s", request.Method)
}

return &transport.JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: json.RawMessage(responseStr),
}, nil
}

func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
return nil
}

func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
m.notificationHandler = handler
}

func (m *mockProtocolTransport) Close() error {
m.closed = true
return nil
}

func (m *mockProtocolTransport) GetSessionId() string {
return "mock-session"
}

func TestProtocolVersionNegotiation(t *testing.T) {
tests := []struct {
name string
serverVersion string
expectError bool
errorContains string
}{
{
name: "supported latest version",
serverVersion: mcp.LATEST_PROTOCOL_VERSION,
expectError: false,
},
{
name: "supported older version 2025-03-26",
serverVersion: "2025-03-26",
expectError: false,
},
{
name: "supported older version 2024-11-05",
serverVersion: "2024-11-05",
expectError: false,
},
{
name: "unsupported version",
serverVersion: "2023-01-01",
expectError: true,
errorContains: "unsupported protocol version",
},
{
name: "unsupported future version",
serverVersion: "2030-01-01",
expectError: true,
errorContains: "unsupported protocol version",
},
{
name: "empty protocol version",
serverVersion: "",
expectError: true,
errorContains: "unsupported protocol version",
},
{
name: "malformed protocol version - invalid format",
serverVersion: "not-a-date",
expectError: true,
errorContains: "unsupported protocol version",
},
{
name: "malformed protocol version - partial date",
serverVersion: "2025-06",
expectError: true,
errorContains: "unsupported protocol version",
},
{
name: "malformed protocol version - just numbers",
serverVersion: "20250618",
expectError: true,
errorContains: "unsupported protocol version",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock transport that returns specific version
mockTransport := &mockProtocolTransport{
responses: map[string]string{
"initialize": fmt.Sprintf(`{
"protocolVersion": "%s",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
}`, tt.serverVersion),
},
}

client := NewClient(mockTransport)

_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
Capabilities: mcp.ClientCapabilities{},
},
})

if tt.expectError {
if err == nil {
t.Errorf("expected error but got none")
} else if !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error())
}
// Verify it's the correct error type
if !mcp.IsUnsupportedProtocolVersion(err) {
t.Errorf("expected UnsupportedProtocolVersionError, got %T", err)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
// Verify the protocol version was stored
if client.protocolVersion != tt.serverVersion {
t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion)
}
}
})
}
}

// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection
type mockHTTPTransport struct {
mockProtocolTransport
protocolVersion string
}

func (m *mockHTTPTransport) SetProtocolVersion(version string) {
m.protocolVersion = version
}

func TestProtocolVersionHeaderSetting(t *testing.T) {
// Create mock HTTP transport
mockTransport := &mockHTTPTransport{
mockProtocolTransport: mockProtocolTransport{
responses: map[string]string{
"initialize": fmt.Sprintf(`{
"protocolVersion": "%s",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
}`, mcp.LATEST_PROTOCOL_VERSION),
},
},
}

client := NewClient(mockTransport)

_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
Capabilities: mcp.ClientCapabilities{},
},
})

if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Verify SetProtocolVersion was called on HTTP transport
if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION {
t.Errorf("expected SetProtocolVersion to be called with %q, got %q",
mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion)
}
}

func TestUnsupportedProtocolVersionError_Is(t *testing.T) {
// Test that errors.Is works correctly with UnsupportedProtocolVersionError
err1 := mcp.UnsupportedProtocolVersionError{Version: "2023-01-01"}
err2 := mcp.UnsupportedProtocolVersionError{Version: "2024-01-01"}

// Test Is method
if !err1.Is(err2) {
t.Error("expected UnsupportedProtocolVersionError.Is to return true for same error type")
}

// Test with different error type
otherErr := fmt.Errorf("some other error")
if err1.Is(otherErr) {
t.Error("expected UnsupportedProtocolVersionError.Is to return false for different error type")
}

// Test IsUnsupportedProtocolVersion helper
if !mcp.IsUnsupportedProtocolVersion(err1) {
t.Error("expected IsUnsupportedProtocolVersion to return true")
}
if mcp.IsUnsupportedProtocolVersion(otherErr) {
t.Error("expected IsUnsupportedProtocolVersion to return false for different error type")
}
}
2 changes: 1 addition & 1 deletion client/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) {
defer cancel()

request := mcp.InitializeRequest{}
request.Params.ProtocolVersion = "1.0"
request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
request.Params.ClientInfo = mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
Expand Down
7 changes: 7 additions & 0 deletions client/transport/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package transport

// Common HTTP header constants used across transports
const (
HeaderKeySessionID = "Mcp-Session-Id"
HeaderKeyProtocolVersion = "Mcp-Protocol-Version"
)
7 changes: 7 additions & 0 deletions client/transport/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ type BidirectionalInterface interface {
SetRequestHandler(handler RequestHandler)
}

// HTTPConnection is a Transport that runs over HTTP and supports
// protocol version headers.
type HTTPConnection interface {
Interface
SetProtocolVersion(version string)
}

type JSONRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID mcp.RequestId `json:"id"`
Expand Down
18 changes: 18 additions & 0 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type SSE struct {
started atomic.Bool
closed atomic.Bool
cancelSSEStream context.CancelFunc
protocolVersion atomic.Value // string

// OAuth support
oauthHandler *OAuthHandler
Expand Down Expand Up @@ -324,6 +325,12 @@ func (c *SSE) SendRequest(

// Set headers
req.Header.Set("Content-Type", "application/json")
// Set protocol version header if negotiated
if v := c.protocolVersion.Load(); v != nil {
if version, ok := v.(string); ok && version != "" {
req.Header.Set(HeaderKeyProtocolVersion, version)
}
}
for k, v := range c.headers {
req.Header.Set(k, v)
}
Expand Down Expand Up @@ -434,6 +441,11 @@ func (c *SSE) GetSessionId() string {
return ""
}

// SetProtocolVersion sets the negotiated protocol version for this connection.
func (c *SSE) SetProtocolVersion(version string) {
c.protocolVersion.Store(version)
}

// SendNotification sends a JSON-RPC notification to the server without expecting a response.
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
if c.endpoint == nil {
Expand All @@ -456,6 +468,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
}

req.Header.Set("Content-Type", "application/json")
// Set protocol version header if negotiated
if v := c.protocolVersion.Load(); v != nil {
if version, ok := v.(string); ok && version != "" {
req.Header.Set(HeaderKeyProtocolVersion, version)
}
}
// Set custom HTTP headers
for k, v := range c.headers {
req.Header.Set(k, v)
Expand Down
6 changes: 3 additions & 3 deletions client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,9 @@ func TestSSE(t *testing.T) {
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
// Test that SSE events with only data field (no event field) are processed correctly
// This tests the fix for issue #369

var messageReceived chan struct{}

// Create a custom mock server that sends SSE events without event field
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down Expand Up @@ -449,7 +449,7 @@ func TestSSE(t *testing.T) {
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)

// Signal that message was received
close(messageReceived)
})
Expand Down
Loading