diff --git a/cmd/docker-mcp/commands/workingset.go b/cmd/docker-mcp/commands/workingset.go index f9c2eb04..3ae381e3 100644 --- a/cmd/docker-mcp/commands/workingset.go +++ b/cmd/docker-mcp/commands/workingset.go @@ -28,6 +28,7 @@ func workingSetCommand() *cobra.Command { cmd.AddCommand(pullWorkingSetCommand()) cmd.AddCommand(createWorkingSetCommand()) cmd.AddCommand(removeWorkingSetCommand()) + cmd.AddCommand(workingsetServerCommand()) cmd.AddCommand(configWorkingSetCommand()) return cmd } @@ -294,3 +295,76 @@ Use --workingset to show servers only from a specific working set.`, return cmd } + +func workingsetServerCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "server", + Short: "Manage servers in working sets", + } + + cmd.AddCommand(addServerCommand()) + cmd.AddCommand(removeServerCommand()) + + return cmd +} + +func addServerCommand() *cobra.Command { + var servers []string + + cmd := &cobra.Command{ + Use: "add --server --server ...", + Short: "Add MCP servers to a working set", + Long: "Add MCP servers to a working set.", + Example: ` # Add servers with OCI references + docker mcp workingset server add my-working-set --server docker://mcp/github:latest --server docker://mcp/slack:latest + + # Add servers with MCP Registry references + docker mcp workingset server add my-working-set --server http://registry.modelcontextprotocol.io/v0/servers/71de5a2a-6cfb-4250-a196-f93080ecc860 + + # Mix MCP Registry references and OCI references + docker mcp workingset server add my-working-set --server http://registry.modelcontextprotocol.io/v0/servers/71de5a2a-6cfb-4250-a196-f93080ecc860 --server docker://mcp/github:latest`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + dao, err := db.New() + if err != nil { + return err + } + registryClient := registryapi.NewClient() + ociService := oci.NewService() + return workingset.AddServers(cmd.Context(), dao, registryClient, ociService, args[0], servers) + }, + } + + flags := cmd.Flags() + flags.StringArrayVar(&servers, "server", []string{}, "Server to include: MCP Registry reference or OCI reference with docker:// prefix (can be specified multiple times)") + + return cmd +} + +func removeServerCommand() *cobra.Command { + var names []string + + cmd := &cobra.Command{ + Use: "remove --name --name ...", + Short: "Remove MCP servers from a working set", + Long: "Remove MCP servers from a working set by server name.", + Example: ` # Remove servers by name + docker mcp workingset server remove my-working-set --name github --name slack + + # Remove a single server + docker mcp workingset server remove my-working-set --name github`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + dao, err := db.New() + if err != nil { + return err + } + return workingset.RemoveServers(cmd.Context(), dao, args[0], names) + }, + } + + flags := cmd.Flags() + flags.StringArrayVar(&names, "name", []string{}, "Server name to remove (can be specified multiple times)") + + return cmd +} diff --git a/pkg/workingset/server.go b/pkg/workingset/server.go new file mode 100644 index 00000000..0fa73fef --- /dev/null +++ b/pkg/workingset/server.go @@ -0,0 +1,102 @@ +package workingset + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/docker/mcp-gateway/pkg/db" + "github.com/docker/mcp-gateway/pkg/oci" + "github.com/docker/mcp-gateway/pkg/registryapi" +) + +func AddServers(ctx context.Context, dao db.DAO, registryClient registryapi.Client, ociService oci.Service, id string, servers []string) error { + if len(servers) == 0 { + return fmt.Errorf("at least one server must be specified") + } + + dbWorkingSet, err := dao.GetWorkingSet(ctx, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("working set %s not found", id) + } + return fmt.Errorf("failed to get working set: %w", err) + } + + workingSet := NewFromDb(dbWorkingSet) + + newServers := make([]Server, len(servers)) + for i, server := range servers { + s, err := resolveServerFromString(ctx, registryClient, ociService, server) + if err != nil { + return fmt.Errorf("invalid server value: %w", err) + } + newServers[i] = s + } + + workingSet.Servers = append(workingSet.Servers, newServers...) + + if err := workingSet.Validate(); err != nil { + return fmt.Errorf("invalid working set: %w", err) + } + + err = dao.UpdateWorkingSet(ctx, workingSet.ToDb()) + if err != nil { + return fmt.Errorf("failed to update working set: %w", err) + } + + fmt.Printf("Added %d server(s) to working set %s\n", len(newServers), id) + + return nil +} + +func RemoveServers(ctx context.Context, dao db.DAO, id string, serverNames []string) error { + if len(serverNames) == 0 { + return fmt.Errorf("at least one server must be specified") + } + + dbWorkingSet, err := dao.GetWorkingSet(ctx, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("working set %s not found", id) + } + return fmt.Errorf("failed to get working set: %w", err) + } + + workingSet := NewFromDb(dbWorkingSet) + + namesToRemove := make(map[string]bool) + for _, name := range serverNames { + namesToRemove[name] = true + } + + originalCount := len(workingSet.Servers) + filtered := make([]Server, 0, len(workingSet.Servers)) + for _, server := range workingSet.Servers { + // TODO: Remove when Snapshot is required + if server.Snapshot == nil || !namesToRemove[server.Snapshot.Server.Name] { + filtered = append(filtered, server) + } + } + + removedCount := originalCount - len(filtered) + if removedCount == 0 { + return fmt.Errorf("no matching servers found to remove") + } + + workingSet.Servers = filtered + + if err := workingSet.Validate(); err != nil { + return fmt.Errorf("invalid working set: %w", err) + } + + err = dao.UpdateWorkingSet(ctx, workingSet.ToDb()) + if err != nil { + return fmt.Errorf("failed to update working set: %w", err) + } + + fmt.Printf("Removed %d server(s) from working set %s\n", removedCount, id) + + return nil +} diff --git a/pkg/workingset/server_test.go b/pkg/workingset/server_test.go new file mode 100644 index 00000000..1a0f4ea4 --- /dev/null +++ b/pkg/workingset/server_test.go @@ -0,0 +1,181 @@ +package workingset + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/mcp-gateway/pkg/db" +) + +var oneServerError = "at least one server must be specified" + +func TestAddOneServerToWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + err := dao.CreateWorkingSet(ctx, db.WorkingSet{ + ID: "test-set", + Name: "Test Working Set", + Servers: db.ServerList{}, + Secrets: db.SecretMap{}, + }) + require.NoError(t, err) + + servers := []string{ + "docker://myimage:latest", + } + + err = AddServers(ctx, dao, getMockRegistryClient(), getMockOciService(), "test-set", servers) + require.NoError(t, err) + + dbSet, err := dao.GetWorkingSet(ctx, "test-set") + require.NoError(t, err) + require.NotNil(t, dbSet) + assert.Equal(t, "My Image", dbSet.Servers[0].Snapshot.Server.Name) +} + +func TestAddMultipleServersToWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + err := dao.CreateWorkingSet(ctx, db.WorkingSet{ + ID: "test-set", + Name: "Test Working Set", + Servers: db.ServerList{}, + Secrets: db.SecretMap{}, + }) + require.NoError(t, err) + + servers := []string{ + "docker://myimage:latest", + "docker://anotherimage:v1.0", + } + + err = AddServers(ctx, dao, getMockRegistryClient(), getMockOciService(), "test-set", servers) + require.NoError(t, err) + + dbSet, err := dao.GetWorkingSet(ctx, "test-set") + require.NoError(t, err) + require.NotNil(t, dbSet) + assert.Equal(t, "My Image", dbSet.Servers[0].Snapshot.Server.Name) + assert.Equal(t, "Another Image", dbSet.Servers[1].Snapshot.Server.Name) +} + +func TestAddNoServersToWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + err := dao.CreateWorkingSet(ctx, db.WorkingSet{ + ID: "test-set", + Name: "Test Working Set", + Servers: db.ServerList{}, + Secrets: db.SecretMap{}, + }) + require.NoError(t, err) + + servers := []string{} + + err = AddServers(ctx, dao, getMockRegistryClient(), getMockOciService(), "test-set", servers) + require.Error(t, err) + assert.Contains(t, err.Error(), oneServerError) +} + +func TestRemoveOneServerFromWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + serverURI := "docker://myimage:latest" + setID := "test-set" + + err := Create(ctx, dao, getMockRegistryClient(), getMockOciService(), "test-set", "test-set", []string{ + serverURI, + }) + require.NoError(t, err) + + dbSet, err := dao.GetWorkingSet(ctx, setID) + require.NoError(t, err) + assert.Len(t, dbSet.Servers, 1) + + err = RemoveServers(ctx, dao, setID, []string{ + "My Image", + }) + require.NoError(t, err) + + dbSet, err = dao.GetWorkingSet(ctx, setID) + require.NoError(t, err) + + assert.Empty(t, dbSet.Servers) +} + +func TestRemoveMultipleServersFromWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + workingSetID := "test-set" + + servers := []string{ + "docker://myimage:latest", + "docker://anotherimage:v1.0", + } + + err := Create(ctx, dao, getMockRegistryClient(), getMockOciService(), workingSetID, "My Test Set", servers) + require.NoError(t, err) + + dbSet, err := dao.GetWorkingSet(ctx, workingSetID) + require.NoError(t, err) + assert.Len(t, dbSet.Servers, 2) + + err = RemoveServers(ctx, dao, workingSetID, []string{"My Image", "Another Image"}) + require.NoError(t, err) + + dbSet, err = dao.GetWorkingSet(ctx, workingSetID) + require.NoError(t, err) + assert.Empty(t, dbSet.Servers) +} + +func TestRemoveOneOfManyServerFromWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + workingSetID := "test-set" + + servers := []string{ + "docker://myimage:latest", + "docker://anotherimage:v1.0", + } + + err := Create(ctx, dao, getMockRegistryClient(), getMockOciService(), workingSetID, "My Test Set", servers) + require.NoError(t, err) + + dbSet, err := dao.GetWorkingSet(ctx, workingSetID) + require.NoError(t, err) + assert.Len(t, dbSet.Servers, 2) + + err = RemoveServers(ctx, dao, workingSetID, []string{"My Image"}) + require.NoError(t, err) + + dbSet, err = dao.GetWorkingSet(ctx, workingSetID) + require.NoError(t, err) + assert.Len(t, dbSet.Servers, 1) + assert.Equal(t, "Another Image", dbSet.Servers[0].Snapshot.Server.Name) +} + +func TestRemoveNoServersFromWorkingSet(t *testing.T) { + dao := setupTestDB(t) + ctx := t.Context() + + workingSetID := "test-set" + + servers := []string{ + "docker://myimage:latest", + } + + err := Create(ctx, dao, getMockRegistryClient(), getMockOciService(), workingSetID, "My Test Set", servers) + require.NoError(t, err) + + err = RemoveServers(ctx, dao, workingSetID, []string{}) + require.Error(t, err) + assert.Contains(t, err.Error(), oneServerError) +} diff --git a/pkg/workingset/workingset.go b/pkg/workingset/workingset.go index d3c5cc5f..5c57d499 100644 --- a/pkg/workingset/workingset.go +++ b/pkg/workingset/workingset.go @@ -150,7 +150,27 @@ func (workingSet WorkingSet) ToDb() db.WorkingSet { } func (workingSet *WorkingSet) Validate() error { - return validate.Get().Struct(workingSet) + err := validate.Get().Struct(workingSet) + if err != nil { + return err + } + return workingSet.validateUniqueServerNames() +} + +func (workingSet *WorkingSet) validateUniqueServerNames() error { + seen := make(map[string]bool) + for _, server := range workingSet.Servers { + // TODO: Update when Snapshot is required + if server.Snapshot == nil { + continue + } + name := server.Snapshot.Server.Name + if seen[name] { + return fmt.Errorf("duplicate server name %s", name) + } + seen[name] = true + } + return nil } func (workingSet *WorkingSet) FindServer(serverName string) *Server { diff --git a/pkg/workingset/workingset_test.go b/pkg/workingset/workingset_test.go index 7d906482..40d16b39 100644 --- a/pkg/workingset/workingset_test.go +++ b/pkg/workingset/workingset_test.go @@ -251,6 +251,35 @@ func TestWorkingSetValidate(t *testing.T) { }, expectErr: true, }, + { + name: "duplicate server name", + ws: WorkingSet{ + Version: CurrentWorkingSetVersion, + ID: "test-id", + Name: "Test", + Servers: []Server{ + { + Type: ServerTypeImage, + Image: "myimage:latest", + Snapshot: &ServerSnapshot{ + Server: catalog.Server{ + Name: "mcp.docker.com/test-server", + }, + }, + }, + { + Type: ServerTypeImage, + Image: "myimage:previous", + Snapshot: &ServerSnapshot{ + Server: catalog.Server{ + Name: "mcp.docker.com/test-server", + }, + }, + }, + }, + }, + expectErr: true, + }, } for _, tt := range tests {