Skip to content

Commit ae6bda6

Browse files
authored
mcp: establish the streamable client standalone SSE stream in Connect (#604)
When Connect returns, client should be guaranteed that the streamable SSE stream is connected. Fixes #583
1 parent 80abbe6 commit ae6bda6

File tree

3 files changed

+118
-159
lines changed

3 files changed

+118
-159
lines changed

mcp/streamable.go

Lines changed: 100 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,7 +1346,44 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
13461346
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
13471347
// ID at initialization time, by including it in an Mcp-Session-Id header
13481348
// on the HTTP response containing the InitializeResult.
1349-
go c.handleSSE("standalone SSE stream", nil, true, nil)
1349+
c.connectStandaloneSSE()
1350+
}
1351+
1352+
func (c *streamableClientConn) connectStandaloneSSE() {
1353+
resp, err := c.connectSSE("")
1354+
if err != nil {
1355+
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
1356+
return
1357+
}
1358+
1359+
// [§2.2.3]: "The server MUST either return Content-Type:
1360+
// text/event-stream in response to this HTTP GET, or else return HTTP
1361+
// 405 Method Not Allowed, indicating that the server does not offer an
1362+
// SSE stream at this endpoint."
1363+
//
1364+
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1365+
if resp.StatusCode == http.StatusMethodNotAllowed {
1366+
// The server doesn't support the standalone SSE stream.
1367+
resp.Body.Close()
1368+
return
1369+
}
1370+
if resp.StatusCode == http.StatusNotFound && !c.strict {
1371+
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
1372+
// of MethodNotAllowed for the standalone SSE stream.
1373+
//
1374+
// Treat this like MethodNotAllowed in non-strict mode.
1375+
if c.logger != nil {
1376+
c.logger.Warn("got 404 instead of 405 for standalone SSE stream")
1377+
}
1378+
resp.Body.Close()
1379+
return
1380+
}
1381+
summary := "standalone SSE stream"
1382+
if err := c.checkResponse(summary, resp); err != nil {
1383+
c.fail(err)
1384+
return
1385+
}
1386+
go c.handleSSE(summary, resp, true, nil)
13501387
}
13511388

13521389
// fail handles an asynchronous error while reading.
@@ -1434,22 +1471,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
14341471
return fmt.Errorf("%s: %v", requestSummary, err)
14351472
}
14361473

1437-
// §2.5.3: "The server MAY terminate the session at any time, after
1438-
// which it MUST respond to requests containing that session ID with HTTP
1439-
// 404 Not Found."
1440-
if resp.StatusCode == http.StatusNotFound {
1441-
// Fail the session immediately, rather than relying on jsonrpc2 to fail
1442-
// (and close) it, because we want the call to Close to know that this
1443-
// session is missing (and therefore not send the DELETE).
1444-
err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing)
1474+
if err := c.checkResponse(requestSummary, resp); err != nil {
14451475
c.fail(err)
1446-
resp.Body.Close()
14471476
return err
14481477
}
1449-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1450-
resp.Body.Close()
1451-
return fmt.Errorf("broken session: %v", resp.Status)
1452-
}
14531478

14541479
if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" {
14551480
c.mu.Lock()
@@ -1463,6 +1488,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
14631488
return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID)
14641489
}
14651490
}
1491+
// TODO(rfindley): this logic isn't quite right.
1492+
// We should keep going even if the server returns 202, if we have a call.
14661493
if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted {
14671494
// [§2.1.4]: "If the input is a JSON-RPC response or notification:
14681495
// If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body."
@@ -1543,73 +1570,63 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
15431570
//
15441571
// If forCall is set, it is the call that initiated the stream, and the
15451572
// stream is complete when we receive its response.
1546-
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
1547-
resp := initialResp
1548-
var lastEventID string
1573+
func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
15491574
for {
1575+
// Connection was successful. Continue the loop with the new response.
15501576
// TODO: we should set a reasonable limit on the number of times we'll try
15511577
// getting a response for a given request.
15521578
//
15531579
// Eventually, if we don't get the response, we should stop trying and
15541580
// fail the request.
1555-
if resp != nil {
1556-
eventID, clientClosed := c.processStream(requestSummary, resp, forCall)
1557-
lastEventID = eventID
1581+
lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall)
15581582

1559-
// If the connection was closed by the client, we're done.
1560-
if clientClosed {
1561-
return
1562-
}
1563-
// If the stream has ended, then do not reconnect if the stream is
1564-
// temporary (POST initiated SSE).
1565-
if lastEventID == "" && !persistent {
1566-
return
1567-
}
1583+
// If the connection was closed by the client, we're done.
1584+
if clientClosed {
1585+
return
1586+
}
1587+
// If the stream has ended, then do not reconnect if the stream is
1588+
// temporary (POST initiated SSE).
1589+
if lastEventID == "" && !persistent {
1590+
return
15681591
}
15691592

15701593
// The stream was interrupted or ended by the server. Attempt to reconnect.
1571-
newResp, err := c.reconnect(lastEventID)
1594+
newResp, err := c.connectSSE(lastEventID)
15721595
if err != nil {
15731596
// All reconnection attempts failed: fail the connection.
15741597
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
15751598
return
15761599
}
15771600
resp = newResp
1578-
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
1579-
// [§2.2.3]: "The server MUST either return Content-Type:
1580-
// text/event-stream in response to this HTTP GET, or else return HTTP
1581-
// 405 Method Not Allowed, indicating that the server does not offer an
1582-
// SSE stream at this endpoint."
1583-
//
1584-
// [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server
1585-
1586-
// The server doesn't support the standalone SSE stream.
1587-
resp.Body.Close()
1588-
return
1589-
}
1590-
if resp.StatusCode == http.StatusNotFound && persistent && !c.strict {
1591-
// modelcontextprotocol/gosdk#393: some servers return NotFound instead
1592-
// of MethodNotAllowed for the standalone SSE stream.
1593-
//
1594-
// Treat this like MethodNotAllowed in non-strict mode.
1595-
if c.logger != nil {
1596-
c.logger.Warn("got 404 instead of 405 for standalonw SSE stream")
1597-
}
1598-
resp.Body.Close()
1599-
return
1600-
}
1601-
// (see equivalent handling in [streamableClientConn.Write]).
1602-
if resp.StatusCode == http.StatusNotFound {
1603-
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing))
1601+
if err := c.checkResponse(requestSummary, resp); err != nil {
1602+
c.fail(err)
16041603
return
16051604
}
1606-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1605+
}
1606+
}
1607+
1608+
// checkResponse checks the status code of the provided response, and
1609+
// translates it into an error if the request was unsuccessful.
1610+
//
1611+
// The response body is close if a non-nil error is returned.
1612+
func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) {
1613+
defer func() {
1614+
if err != nil {
16071615
resp.Body.Close()
1608-
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
1609-
return
16101616
}
1611-
// Reconnection was successful. Continue the loop with the new response.
1617+
}()
1618+
// §2.5.3: "The server MAY terminate the session at any time, after
1619+
// which it MUST respond to requests containing that session ID with HTTP
1620+
// 404 Not Found."
1621+
if resp.StatusCode == http.StatusNotFound {
1622+
// Return an errSessionMissing to avoid sending a redundant DELETE when the
1623+
// session is already gone.
1624+
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
1625+
}
1626+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1627+
return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
16121628
}
1629+
return nil
16131630
}
16141631

16151632
// processStream reads from a single response body, sending events to the
@@ -1620,6 +1637,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
16201637
defer resp.Body.Close()
16211638
for evt, err := range scanEvents(resp.Body) {
16221639
if err != nil {
1640+
// TODO: we should differentiate EOF from other errors here.
16231641
break
16241642
}
16251643

@@ -1664,39 +1682,48 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
16641682
return lastEventID, false
16651683
}
16661684

1667-
// reconnect handles the logic of retrying a connection with an exponential
1668-
// backoff strategy. It returns a new, valid HTTP response if successful, or
1669-
// an error if all retries are exhausted.
1670-
func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
1685+
// connectSSE handles the logic of connecting a text/event-stream connection.
1686+
//
1687+
// If lastEventID is set, it is the last-event ID of a stream being resumed.
1688+
//
1689+
// If connection fails, connectSSE retries with an exponential backoff
1690+
// strategy. It returns a new, valid HTTP response if successful, or an error
1691+
// if all retries are exhausted.
1692+
func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) {
16711693
var finalErr error
1672-
1673-
// We can reach the 'reconnect' path through the standlone SSE request, in which case
1674-
// lastEventID will be "".
1675-
//
1676-
// In this case, we need an initial attempt.
1694+
// If lastEventID is set, we've already connected successfully once, so
1695+
// consider that to be the first attempt.
16771696
attempt := 0
16781697
if lastEventID != "" {
16791698
attempt = 1
16801699
}
1681-
16821700
for ; attempt <= c.maxRetries; attempt++ {
16831701
select {
16841702
case <-c.done:
16851703
return nil, fmt.Errorf("connection closed by client during reconnect")
16861704
case <-time.After(calculateReconnectDelay(attempt)):
1687-
resp, err := c.establishSSE(lastEventID)
1705+
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1706+
if err != nil {
1707+
return nil, err
1708+
}
1709+
c.setMCPHeaders(req)
1710+
if lastEventID != "" {
1711+
req.Header.Set("Last-Event-ID", lastEventID)
1712+
}
1713+
req.Header.Set("Accept", "text/event-stream")
1714+
resp, err := c.client.Do(req)
16881715
if err != nil {
16891716
finalErr = err // Store the error and try again.
16901717
continue
16911718
}
16921719
return resp, nil
16931720
}
16941721
}
1695-
// If the loop completes, all retries have failed.
1722+
// If the loop completes, all retries have failed, or the client is closing.
16961723
if finalErr != nil {
16971724
return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr)
16981725
}
1699-
return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries)
1726+
return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries)
17001727
}
17011728

17021729
// Close implements the [Connection] interface.
@@ -1723,23 +1750,6 @@ func (c *streamableClientConn) Close() error {
17231750
return c.closeErr
17241751
}
17251752

1726-
// establishSSE establishes the persistent SSE listening stream.
1727-
// It is used for reconnect attempts using the Last-Event-ID header to
1728-
// resume a broken stream where it left off.
1729-
func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) {
1730-
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
1731-
if err != nil {
1732-
return nil, err
1733-
}
1734-
c.setMCPHeaders(req)
1735-
if lastEventID != "" {
1736-
req.Header.Set("Last-Event-ID", lastEventID)
1737-
}
1738-
req.Header.Set("Accept", "text/event-stream")
1739-
1740-
return c.client.Do(req)
1741-
}
1742-
17431753
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
17441754
func calculateReconnectDelay(attempt int) time.Duration {
17451755
if attempt == 0 {

0 commit comments

Comments
 (0)