Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
}
}

// WithDisableSSEUpgrade disables automatic upgrade to SSE when notifications are sent.
// When enabled, responses will always be returned as direct JSON responses,
// making it compatible with HTTP streaming clients like the TypeScript MCP SDK.
// The default is false (SSE upgrade enabled for backward compatibility).
func WithDisableSSEUpgrade(disable bool) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
s.disableSSEUpgrade = disable
}
}

// StreamableHTTPServer implements a Streamable-http based MCP server.
// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
Expand Down Expand Up @@ -127,6 +137,7 @@ type StreamableHTTPServer struct {
sessionIdManager SessionIdManager
listenHeartbeatInterval time.Duration
logger util.Logger
disableSSEUpgrade bool
}

// NewStreamableHTTPServer creates a new streamable-http server instance
Expand Down Expand Up @@ -253,7 +264,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
}
}

session := newStreamableHttpSession(sessionID, s.sessionTools)
session := newStreamableHttpSession(sessionID, s.sessionTools, s)

// Set the client context before handling the message
ctx := s.server.WithContext(r.Context(), session)
Expand Down Expand Up @@ -363,7 +374,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
sessionID = uuid.New().String()
}

session := newStreamableHttpSession(sessionID, s.sessionTools)
session := newStreamableHttpSession(sessionID, s.sessionTools, s)
if err := s.server.RegisterSession(r.Context(), session); err != nil {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -547,13 +558,15 @@ type streamableHttpSession struct {
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
tools *sessionToolsStore
upgradeToSSE atomic.Bool
server *StreamableHTTPServer // reference to server for configuration access
}

func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, server *StreamableHTTPServer) *streamableHttpSession {
return &streamableHttpSession{
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
server: server,
}
}

Expand Down Expand Up @@ -588,6 +601,10 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
var _ SessionWithTools = (*streamableHttpSession)(nil)

func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
// Check if SSE upgrade is disabled on the server
if s.server != nil && s.server.disableSSEUpgrade {
return // Don't upgrade to SSE
}
s.upgradeToSSE.Store(true)
}

Expand Down
119 changes: 119 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,125 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) {
})
}

// TestStreamableHTTPServer_DisableSSEUpgrade tests that notifications don't upgrade to SSE
// when WithDisableSSEUpgrade(true) is set, ensuring compatibility with HTTP streaming clients
func TestStreamableHTTPServer_DisableSSEUpgrade(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
mcpServer.AddTool(mcp.Tool{
Name: "test_tool",
Description: "Test tool that sends notifications",
InputSchema: mcp.ToolInputSchema{Type: "object"},
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Send a notification during tool execution
server := ServerFromContext(ctx)
err := server.SendNotificationToClient(ctx, "test/notification", map[string]any{
"message": "test notification",
})
if err != nil {
return nil, err
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "tool completed"},
},
}, nil
})

// Create server with SSE upgrade disabled
server := NewTestStreamableHTTPServer(mcpServer, WithDisableSSEUpgrade(true))
defer server.Close()

// Send initialize request
initResp, err := postJSON(server.URL, map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-03-26",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "test", "version": "1.0.0"},
},
})
if err != nil {
t.Fatalf("Failed to send initialize request: %v", err)
}
defer initResp.Body.Close()

if initResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for initialize, got %d", initResp.StatusCode)
}

sessionID := initResp.Header.Get("Mcp-Session-Id")
if sessionID == "" {
t.Fatal("Expected session ID in response header")
}

// Send tool call request that triggers notification
toolReq, _ := http.NewRequest("POST", server.URL, strings.NewReader(`{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": {
"name": "test_tool",
"arguments": {}
}
}`))
toolReq.Header.Set("Content-Type", "application/json")
toolReq.Header.Set("Mcp-Session-Id", sessionID)

resp, err := http.DefaultClient.Do(toolReq)
if err != nil {
t.Fatalf("Failed to send tool call request: %v", err)
}
defer resp.Body.Close()

// Should receive JSON response (200 OK), not SSE (202 Accepted)
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status 200, got %d. Response: %s", resp.StatusCode, string(bodyBytes))
}

// Should be JSON, not SSE
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected content-type application/json, got %s", contentType)
}

// Read and verify the response contains the tool result
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}

var response map[string]any
if err := json.Unmarshal(responseBody, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}

// Check it's a proper JSON-RPC response
if response["id"].(float64) != 2 {
t.Errorf("Expected id 2, got %v", response["id"])
}

if response["jsonrpc"] != "2.0" {
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
}

// Verify the tool result is present
result, ok := response["result"]
if !ok {
t.Error("Expected result field in response")
} else {
resultMap := result.(map[string]any)
content := resultMap["content"].([]any)
firstContent := content[0].(map[string]any)
if firstContent["text"] != "tool completed" {
t.Errorf("Expected tool completed text, got %v", firstContent["text"])
}
}
}

func postJSON(url string, bodyObject any) (*http.Response, error) {
jsonBody, _ := json.Marshal(bodyObject)
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
Expand Down