Skip to content

Commit d630605

Browse files
committed
feat: implement Interface for MCPServer and update server references
1 parent 0fdb197 commit d630605

File tree

5 files changed

+318
-7
lines changed

5 files changed

+318
-7
lines changed

examples/custom_server/main.go

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"log/slog"
7+
"net/http"
8+
"os"
9+
"os/signal"
10+
"syscall"
11+
"time"
12+
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// LoggingMCPServer wraps an Interface implementation with structured logging using slog
18+
type LoggingMCPServer struct {
19+
server server.Interface
20+
logger *slog.Logger
21+
}
22+
23+
// NewLoggingMCPServer creates a new logging wrapper around an Interface
24+
func NewLoggingMCPServer(server server.Interface, logger *slog.Logger) *LoggingMCPServer {
25+
return &LoggingMCPServer{
26+
server: server,
27+
logger: logger,
28+
}
29+
}
30+
31+
func (l *LoggingMCPServer) HandleMessage(ctx context.Context, message json.RawMessage) mcp.JSONRPCMessage {
32+
// Parse basic message info for logging
33+
var baseMsg struct {
34+
ID any `json:"id,omitempty"`
35+
Method mcp.MCPMethod `json:"method,omitempty"`
36+
}
37+
json.Unmarshal(message, &baseMsg)
38+
39+
start := time.Now()
40+
l.logger.InfoContext(ctx, "handling message",
41+
slog.String("method", string(baseMsg.Method)),
42+
slog.Any("id", baseMsg.ID),
43+
slog.Int("message_size", len(message)))
44+
45+
response := l.server.HandleMessage(ctx, message)
46+
duration := time.Since(start)
47+
48+
if response != nil {
49+
// Log response details
50+
responseBytes, _ := json.Marshal(response)
51+
l.logger.InfoContext(ctx, "message handled",
52+
slog.String("method", string(baseMsg.Method)),
53+
slog.Any("id", baseMsg.ID),
54+
slog.Duration("duration", duration),
55+
slog.Int("response_size", len(responseBytes)))
56+
} else {
57+
// Notification - no response
58+
l.logger.InfoContext(ctx, "notification handled",
59+
slog.String("method", string(baseMsg.Method)),
60+
slog.Duration("duration", duration))
61+
}
62+
63+
return response
64+
}
65+
66+
func (l *LoggingMCPServer) RegisterSession(ctx context.Context, session server.ClientSession) error {
67+
l.logger.InfoContext(ctx, "registering session",
68+
slog.String("session_id", session.SessionID()))
69+
70+
err := l.server.RegisterSession(ctx, session)
71+
if err != nil {
72+
l.logger.ErrorContext(ctx, "failed to register session",
73+
slog.String("session_id", session.SessionID()),
74+
slog.String("error", err.Error()))
75+
} else {
76+
l.logger.InfoContext(ctx, "session registered successfully",
77+
slog.String("session_id", session.SessionID()))
78+
}
79+
return err
80+
}
81+
82+
func (l *LoggingMCPServer) UnregisterSession(ctx context.Context, sessionID string) {
83+
l.logger.InfoContext(ctx, "unregistering session",
84+
slog.String("session_id", sessionID))
85+
l.server.UnregisterSession(ctx, sessionID)
86+
l.logger.InfoContext(ctx, "session unregistered",
87+
slog.String("session_id", sessionID))
88+
}
89+
90+
func (l *LoggingMCPServer) WithContext(ctx context.Context, session server.ClientSession) context.Context {
91+
return l.server.WithContext(ctx, session)
92+
}
93+
94+
func (l *LoggingMCPServer) SendNotificationToClient(ctx context.Context, method string, params map[string]any) error {
95+
l.logger.InfoContext(ctx, "sending notification to client",
96+
slog.String("method", method),
97+
slog.Any("params", params))
98+
99+
err := l.server.SendNotificationToClient(ctx, method, params)
100+
if err != nil {
101+
l.logger.ErrorContext(ctx, "failed to send notification to client",
102+
slog.String("method", method),
103+
slog.String("error", err.Error()))
104+
}
105+
return err
106+
}
107+
108+
func (l *LoggingMCPServer) SendNotificationToSpecificClient(sessionID string, method string, params map[string]any) error {
109+
l.logger.Info("sending notification to specific client",
110+
slog.String("session_id", sessionID),
111+
slog.String("method", method),
112+
slog.Any("params", params))
113+
114+
err := l.server.SendNotificationToSpecificClient(sessionID, method, params)
115+
if err != nil {
116+
l.logger.Error("failed to send notification to specific client",
117+
slog.String("session_id", sessionID),
118+
slog.String("method", method),
119+
slog.String("error", err.Error()))
120+
}
121+
return err
122+
}
123+
124+
func (l *LoggingMCPServer) SendNotificationToAllClients(method string, params map[string]any) {
125+
l.logger.Info("broadcasting notification to all clients",
126+
slog.String("method", method),
127+
slog.Any("params", params))
128+
l.server.SendNotificationToAllClients(method, params)
129+
}
130+
131+
func (l *LoggingMCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler server.ToolHandlerFunc) error {
132+
l.logger.Info("adding session tool",
133+
slog.String("session_id", sessionID),
134+
slog.String("tool_name", tool.Name),
135+
slog.String("tool_description", tool.Description))
136+
137+
err := l.server.AddSessionTool(sessionID, tool, handler)
138+
if err != nil {
139+
l.logger.Error("failed to add session tool",
140+
slog.String("session_id", sessionID),
141+
slog.String("tool_name", tool.Name),
142+
slog.String("error", err.Error()))
143+
}
144+
return err
145+
}
146+
147+
func (l *LoggingMCPServer) AddSessionTools(sessionID string, tools ...server.ServerTool) error {
148+
toolNames := make([]string, len(tools))
149+
for i, tool := range tools {
150+
toolNames[i] = tool.Tool.Name
151+
}
152+
153+
l.logger.Info("adding session tools",
154+
slog.String("session_id", sessionID),
155+
slog.Int("tool_count", len(tools)),
156+
slog.Any("tool_names", toolNames))
157+
158+
err := l.server.AddSessionTools(sessionID, tools...)
159+
if err != nil {
160+
l.logger.Error("failed to add session tools",
161+
slog.String("session_id", sessionID),
162+
slog.String("error", err.Error()))
163+
}
164+
return err
165+
}
166+
167+
func (l *LoggingMCPServer) DeleteSessionTools(sessionID string, names ...string) error {
168+
l.logger.Info("deleting session tools",
169+
slog.String("session_id", sessionID),
170+
slog.Any("tool_names", names))
171+
172+
err := l.server.DeleteSessionTools(sessionID, names...)
173+
if err != nil {
174+
l.logger.Error("failed to delete session tools",
175+
slog.String("session_id", sessionID),
176+
slog.Any("tool_names", names),
177+
slog.String("error", err.Error()))
178+
}
179+
return err
180+
}
181+
182+
func main() {
183+
// Configure structured logging with slog
184+
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
185+
Level: slog.LevelInfo,
186+
AddSource: true,
187+
}))
188+
189+
// Create the base MCP server with tools and resources
190+
mcpServer := server.NewMCPServer("example-server", "1.0.0",
191+
server.WithResourceCapabilities(true, true),
192+
server.WithToolCapabilities(true),
193+
server.WithPromptCapabilities(true),
194+
)
195+
196+
// Add some example tools
197+
mcpServer.AddTool(
198+
mcp.NewTool("time", mcp.WithDescription("Get current time")),
199+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
200+
logger.InfoContext(ctx, "time tool called")
201+
return mcp.NewToolResultText("Current time: " + time.Now().Format(time.RFC3339)), nil
202+
},
203+
)
204+
205+
// Add example resource
206+
mcpServer.AddResource(
207+
mcp.NewResource("example://info", "Server Info", mcp.WithResourceDescription("Information about this server")),
208+
func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
209+
logger.InfoContext(ctx, "info resource accessed")
210+
return []mcp.ResourceContents{
211+
mcp.TextResourceContents{
212+
URI: "example://info",
213+
MIMEType: "text/plain",
214+
Text: "This is an example MCP server with logging",
215+
},
216+
}, nil
217+
},
218+
)
219+
220+
// Wrap the server with logging
221+
customLoggingServer := NewLoggingMCPServer(mcpServer, logger)
222+
223+
// Create the StreamableHTTP server with the logging wrapper
224+
httpServer := server.NewStreamableHTTPServer(customLoggingServer,
225+
server.WithEndpointPath("/mcp"),
226+
server.WithHeartbeatInterval(30*time.Second),
227+
)
228+
229+
logger.Info("starting MCP server",
230+
slog.String("address", ":8080"),
231+
slog.String("endpoint", "/mcp"))
232+
233+
// Start server in a goroutine
234+
go func() {
235+
if err := httpServer.Start(":8080"); err != nil && err != http.ErrServerClosed {
236+
logger.Error("server failed to start", slog.String("error", err.Error()))
237+
os.Exit(1)
238+
}
239+
}()
240+
241+
// Wait for interrupt signal
242+
quit := make(chan os.Signal, 1)
243+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
244+
<-quit
245+
246+
logger.Info("shutting down server")
247+
248+
// Graceful shutdown
249+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
250+
defer cancel()
251+
252+
if err := httpServer.Shutdown(ctx); err != nil {
253+
logger.Error("server shutdown failed", slog.String("error", err.Error()))
254+
} else {
255+
logger.Info("server shutdown complete")
256+
}
257+
}

server/server_interface.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"github.com/mark3labs/mcp-go/mcp"
8+
)
9+
10+
// Ensure MCPServer implements the Interface
11+
var _ Interface = (*MCPServer)(nil)
12+
13+
// Interface defines the essential interface that all MCP server transports depend on.
14+
// This allows for custom implementations of the core server logic while maintaining
15+
// compatibility with all existing transports (SSE, Stdio, StreamableHTTP).
16+
type Interface interface {
17+
// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response.
18+
// This is the core method that processes MCP protocol messages.
19+
HandleMessage(ctx context.Context, message json.RawMessage) mcp.JSONRPCMessage
20+
21+
// RegisterSession registers a new client session with the server.
22+
// Returns an error if the session already exists or registration fails.
23+
RegisterSession(ctx context.Context, session ClientSession) error
24+
25+
// UnregisterSession removes a session from the server by session ID.
26+
UnregisterSession(ctx context.Context, sessionID string)
27+
28+
// WithContext creates a new context with the given session attached.
29+
// This allows the session to be retrieved later using ClientSessionFromContext.
30+
WithContext(ctx context.Context, session ClientSession) context.Context
31+
32+
// SendNotificationToClient sends a notification to the client associated with the given context.
33+
// Returns an error if no session is found or the notification cannot be sent.
34+
SendNotificationToClient(ctx context.Context, method string, params map[string]any) error
35+
36+
// SendNotificationToSpecificClient sends a notification to a specific client by session ID.
37+
// Returns an error if the session is not found or the notification cannot be sent.
38+
SendNotificationToSpecificClient(sessionID string, method string, params map[string]any) error
39+
40+
// SendNotificationToAllClients broadcasts a notification to all currently active sessions.
41+
SendNotificationToAllClients(method string, params map[string]any)
42+
43+
// AddSessionTool adds a tool for a specific session.
44+
// Returns an error if the session doesn't exist or doesn't support tools.
45+
AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error
46+
47+
// AddSessionTools adds multiple tools for a specific session.
48+
// Returns an error if the session doesn't exist or doesn't support tools.
49+
AddSessionTools(sessionID string, tools ...ServerTool) error
50+
51+
// DeleteSessionTools removes tools from a specific session.
52+
// Returns an error if the session doesn't exist or doesn't support tools.
53+
DeleteSessionTools(sessionID string, names ...string) error
54+
}

server/sse.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ var (
118118
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
119119
// It provides real-time communication capabilities over HTTP using the SSE protocol.
120120
type SSEServer struct {
121-
server *MCPServer
121+
server Interface
122122
baseURL string
123123
basePath string
124124
appendQueryToMessageEndpoint bool
@@ -258,7 +258,7 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
258258
}
259259

260260
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
261-
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
261+
func NewSSEServer(server Interface, opts ...SSEOption) *SSEServer {
262262
s := &SSEServer{
263263
server: server,
264264
sseEndpoint: "/sse",
@@ -277,7 +277,7 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
277277
}
278278

279279
// NewTestServer creates a test server for testing purposes
280-
func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
280+
func NewTestServer(server Interface, opts ...SSEOption) *httptest.Server {
281281
sseServer := NewSSEServer(server, opts...)
282282

283283
testServer := httptest.NewServer(sseServer)

server/stdio.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type StdioContextFunc func(ctx context.Context) context.Context
2525
// It provides a simple way to create command-line MCP servers that
2626
// communicate via standard input/output streams using JSON-RPC messages.
2727
type StdioServer struct {
28-
server *MCPServer
28+
server Interface
2929
errLogger *log.Logger
3030
contextFunc StdioContextFunc
3131
}
@@ -112,7 +112,7 @@ var stdioSessionInstance = stdioSession{
112112

113113
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
114114
// It initializes the server with a default error logger that discards all output.
115-
func NewStdioServer(server *MCPServer) *StdioServer {
115+
func NewStdioServer(server Interface) *StdioServer {
116116
return &StdioServer{
117117
server: server,
118118
errLogger: log.New(

server/streamable_http.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
117117
// - Batching of requests/notifications/responses in arrays.
118118
// - Stream Resumability
119119
type StreamableHTTPServer struct {
120-
server *MCPServer
120+
server Interface
121121
sessionTools *sessionToolsStore
122122
sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)
123123

@@ -132,7 +132,7 @@ type StreamableHTTPServer struct {
132132
}
133133

134134
// NewStreamableHTTPServer creates a new streamable-http server instance
135-
func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
135+
func NewStreamableHTTPServer(server Interface, opts ...StreamableHTTPOption) *StreamableHTTPServer {
136136
s := &StreamableHTTPServer{
137137
server: server,
138138
sessionTools: newSessionToolsStore(),

0 commit comments

Comments
 (0)