diff --git a/mcp/prompts.go b/mcp/prompts.go index a63a21450..6c5cf5708 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -68,6 +68,8 @@ type PromptArgument struct { // Whether this argument must be provided. // If true, clients must include this argument when calling prompts/get. Required bool `json:"required,omitempty"` + // Optional CompletionHandlerFunc for autocompleting the argument value. + CompletionHandler *CompletionHandlerFunc `json:"-"` } // Role represents the sender or recipient of messages and data in a @@ -168,3 +170,10 @@ func RequiredArgument() ArgumentOption { arg.Required = true } } + +// ArgumentCompletion configures an autocomplete handler for the argument. +func ArgumentCompletion(handler CompletionHandlerFunc) ArgumentOption { + return func(arg *PromptArgument) { + arg.CompletionHandler = &handler + } +} diff --git a/mcp/resources.go b/mcp/resources.go index 07a59a322..ef5ce4a1d 100644 --- a/mcp/resources.go +++ b/mcp/resources.go @@ -97,3 +97,14 @@ func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplate t.Annotations.Priority = priority } } + +// WithTemplateArgumentCompletion adds an autocomplete handler for the specified argument of ResourceTemplate. +// The argument should be one of the variables referenced by the URI template. +func WithTemplateArgumentCompletion(argument string, handler CompletionHandlerFunc) ResourceTemplateOption { + return func(t *ResourceTemplate) { + if t.URITemplate.ArgumentCompletionHandlers == nil { + t.URITemplate.ArgumentCompletionHandlers = make(map[string]CompletionHandlerFunc) + } + t.URITemplate.ArgumentCompletionHandlers[argument] = handler + } +} diff --git a/mcp/types.go b/mcp/types.go index d4f6132c8..7c73f3138 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -3,6 +3,7 @@ package mcp import ( + "context" "encoding/json" "fmt" "maps" @@ -50,6 +51,10 @@ const ( // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsCall MCPMethod = "tools/call" + // MethodCompletion provides autocompletion suggestions for URI arguments. + // https://modelcontextprotocol.io/specification/draft/server/utilities/completion + MethodCompletion MCPMethod = "completion/complete" + // MethodSetLogLevel configures the minimum log level for client // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging MethodSetLogLevel MCPMethod = "logging/setLevel" @@ -69,8 +74,14 @@ const ( MethodNotificationToolsListChanged = "notifications/tools/list_changed" ) +// CompletionHandlerFunc handles completion requests. +type CompletionHandlerFunc func(ctx context.Context, request CompleteRequest) (*CompleteResult, error) + type URITemplate struct { *uritemplate.Template + + // Optional mapping of URI template arguments to CompletionHandlerFunc for autocompleting each argument's value. + ArgumentCompletionHandlers map[string]CompletionHandlerFunc `json:"-"` } func (t *URITemplate) MarshalJSON() ([]byte, error) { @@ -456,6 +467,8 @@ type ServerCapabilities struct { Experimental map[string]any `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. Logging *struct{} `json:"logging,omitempty"` + // Present if the server supports autocompletion for prompt and resource template arguments. + Completion *struct{} `json:"completion,omitempty"` // Present if the server offers any prompt templates. Prompts *struct { // Whether this server supports notifications for changes to the prompt list. @@ -965,37 +978,51 @@ type CompleteParams struct { // The value of the argument to use for completion matching. Value string `json:"value"` } `json:"argument"` + Context struct { + // Previously completed arguments for this reference. + Arguments map[string]string `json:"arguments,omitempty"` + } `json:"context,omitempty"` } // CompleteResult is the server's response to a completion/complete request type CompleteResult struct { Result - Completion struct { - // An array of completion values. Must not exceed 100 items. - Values []string `json:"values"` - // The total number of completion options available. This can exceed the - // number of values actually sent in the response. - Total int `json:"total,omitempty"` - // Indicates whether there are additional completion options beyond those - // provided in the current response, even if the exact total is unknown. - HasMore bool `json:"hasMore,omitempty"` - } `json:"completion"` + Completion Completion `json:"completion"` +} + +// Completion represents the resulting completion values for a completion/complete request +type Completion struct { + // An array of completion values. Must not exceed 100 items. + Values []string `json:"values"` + // The total number of completion options available. This can exceed the + // number of values actually sent in the response. + Total int `json:"total,omitempty"` + // Indicates whether there are additional completion options beyond those + // provided in the current response, even if the exact total is unknown. + HasMore bool `json:"hasMore,omitempty"` } // ResourceReference is a reference to a resource or resource template definition. type ResourceReference struct { - Type string `json:"type"` + Type RefType `json:"type"` // The URI or URI template of the resource. URI string `json:"uri"` } // PromptReference identifies a prompt. type PromptReference struct { - Type string `json:"type"` + Type RefType `json:"type"` // The name of the prompt or prompt template Name string `json:"name"` } +type RefType string + +const ( + RefTypeResource RefType = "ref/resource" + RefTypePrompt RefType = "ref/prompt" +) + /* Roots */ // ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow diff --git a/mcp/utils.go b/mcp/utils.go index 3e652efd7..54529da8b 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -514,6 +514,35 @@ func ParseContent(contentMap map[string]any) (Content, error) { return nil, fmt.Errorf("unsupported content type: %s", contentType) } +func ParseCompletionReference(req CompleteRequest) (*PromptReference, *ResourceReference, error) { + ref, ok := req.Params.Ref.(map[string]interface{}) + if !ok { + return nil, nil, fmt.Errorf("params.ref must be a mapping") + } + + refType, ok := ref["type"].(string) + if !ok { + return nil, nil, fmt.Errorf("params.ref.type must be a string") + } + + switch RefType(refType) { + case RefTypeResource: + if uri, ok := ref["uri"].(string); !ok { + return nil, nil, fmt.Errorf("params.ref.uri must be a string") + } else { + return nil, &ResourceReference{Type: RefType(refType), URI: uri}, nil + } + case RefTypePrompt: + if name, ok := ref["name"].(string); !ok { + return nil, nil, fmt.Errorf("params.ref.name must be a string") + } else { + return &PromptReference{Type: RefType(refType), Name: name}, nil, nil + } + default: + return nil, nil, fmt.Errorf("unexpected value for params.ref.type: %s", refType) + } +} + func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { if rawMessage == nil { return nil, fmt.Errorf("response is nil") diff --git a/server/hooks.go b/server/hooks.go index 4baa1c4e0..a8ecfa4bf 100644 --- a/server/hooks.go +++ b/server/hooks.go @@ -91,6 +91,9 @@ type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToo type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) +type OnBeforeCompleteFunc func(ctx context.Context, id any, message *mcp.CompleteRequest) +type OnAfterCompleteFunc func(ctx context.Context, id any, message *mcp.CompleteRequest, result *mcp.CompleteResult) + type Hooks struct { OnRegisterSession []OnRegisterSessionHookFunc OnUnregisterSession []OnUnregisterSessionHookFunc @@ -118,6 +121,8 @@ type Hooks struct { OnAfterListTools []OnAfterListToolsFunc OnBeforeCallTool []OnBeforeCallToolFunc OnAfterCallTool []OnAfterCallToolFunc + OnBeforeComplete []OnBeforeCompleteFunc + OnAfterComplete []OnAfterCompleteFunc } func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { @@ -530,3 +535,30 @@ func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallTool hook(ctx, id, message, result) } } +func (c *Hooks) AddBeforeComplete(hook OnBeforeCompleteFunc) { + c.OnBeforeComplete = append(c.OnBeforeComplete, hook) +} + +func (c *Hooks) AddAfterComplete(hook OnAfterCompleteFunc) { + c.OnAfterComplete = append(c.OnAfterComplete, hook) +} + +func (c *Hooks) beforeComplete(ctx context.Context, id any, message *mcp.CompleteRequest) { + c.beforeAny(ctx, id, mcp.MethodCompletion, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeComplete { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterComplete(ctx context.Context, id any, message *mcp.CompleteRequest, result *mcp.CompleteResult) { + c.onSuccess(ctx, id, mcp.MethodCompletion, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterComplete { + hook(ctx, id, message, result) + } +} diff --git a/server/internal/gen/data.go b/server/internal/gen/data.go index a468f4605..e5029b8e3 100644 --- a/server/internal/gen/data.go +++ b/server/internal/gen/data.go @@ -107,5 +107,15 @@ var MCPRequestTypes = []MCPRequestType{ HookName: "CallTool", UnmarshalError: "invalid call tool request", HandlerFunc: "handleToolCall", + }, { + MethodName: "MethodCompletion", + ParamType: "CompleteRequest", + ResultType: "CompleteResult", + Group: "completions", + GroupName: "Completions", + GroupHookName: "Completion", + HookName: "Complete", + UnmarshalError: "invalid completion request", + HandlerFunc: "handleCompletion", }, } diff --git a/server/request_handler.go b/server/request_handler.go index 25f6ef14f..4058ba76c 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -310,6 +310,31 @@ func (s *MCPServer) HandleMessage( } s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) return createResponse(baseMessage.ID, *result) + case mcp.MethodCompletion: + var request mcp.CompleteRequest + var result *mcp.CompleteResult + if s.capabilities.completions == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("completions %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeComplete(ctx, baseMessage.ID, &request) + result, err = s.handleCompletion(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterComplete(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) default: return createErrorResponse( baseMessage.ID, diff --git a/server/server.go b/server/server.go index 46e6d9c57..b4dbfccd4 100644 --- a/server/server.go +++ b/server/server.go @@ -171,10 +171,11 @@ func WithPaginationLimit(limit int) ServerOption { // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { - tools *toolCapabilities - resources *resourceCapabilities - prompts *promptCapabilities - logging *bool + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + completions *bool + logging *bool } // resourceCapabilities defines the supported resource-related features @@ -281,6 +282,13 @@ func WithLogging() ServerOption { } } +// WithCompletions enables autocomplete capabilities for the server +func WithCompletions() ServerOption { + return func(s *MCPServer) { + s.capabilities.completions = mcp.ToBoolPtr(true) + } +} + // WithInstructions sets the server instructions for the client returned in the initialize response func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { @@ -374,6 +382,10 @@ func (s *MCPServer) AddResourceTemplate( } s.resourcesMu.Unlock() + if len(template.URITemplate.ArgumentCompletionHandlers) > 0 { + s.implicitlyRegisterCompletionCapabilities() + } + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification if s.capabilities.resources.listChanged { // Send notification to all initialized sessions @@ -392,6 +404,15 @@ func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { } s.promptsMu.Unlock() + for _, entry := range prompts { + for _, arg := range entry.Prompt.Arguments { + if arg.CompletionHandler != nil { + s.implicitlyRegisterCompletionCapabilities() + break + } + } + } + // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.prompts.listChanged { // Send notification to all initialized sessions @@ -402,6 +423,12 @@ func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { // AddPrompt registers a new prompt handler with the given name func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) + for _, arg := range prompt.Arguments { + if arg.CompletionHandler != nil { + s.implicitlyRegisterCompletionCapabilities() + break + } + } } // DeletePrompts removes prompts from the server @@ -446,6 +473,13 @@ func (s *MCPServer) implicitlyRegisterResourceCapabilities() { ) } +func (s *MCPServer) implicitlyRegisterCompletionCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.completions != nil }, + func() { s.capabilities.completions = mcp.ToBoolPtr(true) }, + ) +} + func (s *MCPServer) implicitlyRegisterPromptCapabilities() { s.implicitlyRegisterCapabilities( func() bool { return s.capabilities.prompts != nil }, @@ -562,6 +596,10 @@ func (s *MCPServer) handleInitialize( capabilities.Logging = &struct{}{} } + if s.capabilities.completions != nil && *s.capabilities.completions { + capabilities.Completion = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ @@ -1037,6 +1075,71 @@ func (s *MCPServer) handleToolCall( return result, nil } +func (s *MCPServer) handleCompletion( + ctx context.Context, + id any, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, *requestError) { + + promptRef, resourceRef, err := mcp.ParseCompletionReference(request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + + var handler *mcp.CompletionHandlerFunc + if promptRef != nil { + s.promptsMu.RLock() + if prompt, ok := s.prompts[promptRef.Name]; ok { + for _, arg := range prompt.Arguments { + if arg.Name == request.Params.Argument.Name { + handler = arg.CompletionHandler + } + } + } else { + s.promptsMu.RUnlock() + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("prompt '%s' not found: %w", promptRef.Name, ErrPromptNotFound), + } + } + s.promptsMu.RUnlock() + } else if resourceRef != nil { + s.resourcesMu.RLock() + if resourceTemplate, ok := s.resourceTemplates[resourceRef.URI]; ok { + if h, ok := resourceTemplate.template.URITemplate.ArgumentCompletionHandlers[request.Params.Argument.Name]; ok { + handler = &h + } + } else { + s.resourcesMu.RUnlock() + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("resource template '%s' not found: %w", resourceRef.URI, ErrResourceNotFound), + } + } + s.resourcesMu.RUnlock() + } + + if handler == nil { + return &mcp.CompleteResult{Completion: mcp.Completion{Values: make([]string, 0)}}, nil + } + + result, err := (*handler)(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil +} + func (s *MCPServer) handleNotification( ctx context.Context, notification mcp.JSONRPCNotification, diff --git a/server/server_test.go b/server/server_test.go index 1c81d18dd..5ef0bd84a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -59,6 +59,7 @@ func TestMCPServer_Capabilities(t *testing.T) { WithPromptCapabilities(true), WithToolCapabilities(true), WithLogging(), + WithCompletions(), }, validate: func(t *testing.T, response mcp.JSONRPCMessage) { resp, ok := response.(mcp.JSONRPCResponse) @@ -87,6 +88,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.True(t, initResult.Capabilities.Tools.ListChanged) assert.NotNil(t, initResult.Capabilities.Logging) + assert.NotNil(t, initResult.Capabilities.Completion) }, }, { @@ -1458,6 +1460,119 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { }) } +func TestMCPServer_HandleCompletion(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + server.AddPrompt( + mcp.NewPrompt("test-prompt", + mcp.WithArgument("test-arg", mcp.ArgumentCompletion(func(_ context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return &mcp.CompleteResult{ + Completion: mcp.Completion{ + Values: []string{fmt.Sprintf("%sbar", req.Params.Argument.Value)}, + }, + }, nil + })), + ), + func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return nil, nil + }, + ) + + tests := []struct { + name string + message string + validate func(t *testing.T, response mcp.JSONRPCMessage) + }{ + { + name: "Prompt argument completion", + message: `{ + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "test-prompt" + }, + "argument": { + "name": "test-arg", + "value": "foo" + } + } + }`, + validate: func(t *testing.T, response mcp.JSONRPCMessage) { + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + result, ok := resp.Result.(mcp.CompleteResult) + assert.True(t, ok) + + assert.Len(t, result.Completion.Values, 1) + assert.Equal(t, "foobar", result.Completion.Values[0]) + }, + }, + { + name: "No completion for prompt argument", + message: `{ + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "test-prompt" + }, + "argument": { + "name": "another-arg", + "value": "foo" + } + } + }`, + validate: func(t *testing.T, response mcp.JSONRPCMessage) { + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + result, ok := resp.Result.(mcp.CompleteResult) + assert.True(t, ok) + + assert.NotNil(t, result.Completion.Values) + assert.Len(t, result.Completion.Values, 0) + }, + }, + { + name: "Prompt not found", + message: `{ + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "unknown-prompt" + }, + "argument": { + "name": "test-arg", + "value": "foo" + } + } + }`, + validate: func(t *testing.T, response mcp.JSONRPCMessage) { + resp, ok := response.(mcp.JSONRPCError) + assert.True(t, ok) + assert.Equal(t, resp.Error.Code, mcp.INVALID_PARAMS) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response := server.HandleMessage(context.Background(), []byte(tt.message)) + assert.NotNil(t, response) + tt.validate(t, response) + }) + } +} + func createTestServer() *MCPServer { server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(true, true), diff --git a/www/docs/pages/servers/prompts.mdx b/www/docs/pages/servers/prompts.mdx index f1801b548..d3b96f7ea 100644 --- a/www/docs/pages/servers/prompts.mdx +++ b/www/docs/pages/servers/prompts.mdx @@ -166,6 +166,76 @@ Please provide a comprehensive analysis including: ## Prompt Arguments +### Basic Argument Completion + +```go +import ( + "context" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +var dictionary = map[string][]string{ + "adjective": {"funny", "silly", "clever", "witty", "absurd"}, + "subject": {"cats", "dogs", "programmers", "aliens", "politicians"}, +} + +func main() { + s := server.NewMCPServer("Personal Assistant", "1.0.0", + server.WithPromptCapabilities(true), + ) + + // Tell joke prompt + tellJokePrompt := mcp.NewPrompt("tell_joke", + mcp.WithArgument("adjective", mcp.ArgumentCompletion(handleCompleteFromDictionary)), + mcp.WithArgument("subject", mcp.RequiredArgument(), mcp.ArgumentCompletion(handleCompleteFromDictionary)), + ) + + s.AddPrompt(tellJokePrompt, handleTellJoke) + server.ServeStdio(s) +} + +func handleCompleteFromDictionary(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { + input := strings.ToLower(req.Params.Argument.Value) + options := dictionary[req.Params.Argument.Name] // Get options for argument + + // Initialize results container + result := &mcp.CompleteResult{ + Completion: mcp.Completion{ + Values: make([]string, 0, len(options)), + }, + } + + // Simple match logic + for _, option := range options { + if strings.Contains(strings.ToLower(option), input) { + result.Completion.Values = append(result.Completion.Values, option) + } + } + return result, nil +} + +func handleTellJoke(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + adjective := req.Params.Arguments["adjective"] + subject, ok := req.Params.Arguments["subject"] + if !ok || strings.TrimSpace(subject) == "" { + return nil, fmt.Errorf("subject argument is required and cannot be empty") + } + + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: "user", + Content: mcp.NewTextContent(fmt.Sprintf("Tell a %s joke about %s.", adjective, subject)), + }, + }, + }, nil +} +``` + ### Flexible Parameter Handling ```go