From aa8479057ce0e7a5a21606792e47b22aa89037f5 Mon Sep 17 00:00:00 2001 From: Nitish Agarwal Date: Tue, 16 Sep 2025 11:29:40 +0530 Subject: [PATCH] feat: implement user-based daily publish rate limiting (#21) - Rate limit by authenticated user (authMethodSubject) instead of namespace - Admin bypass via hasGlobalPermissions parameter from auth handler - Atomic database operations with separate publish_attempts table - Integrated rate limiting directly into registry service - Support for rate limit exemptions with wildcard patterns - Comprehensive test coverage including concurrent request handling Configuration: - MCP_REGISTRY_RATE_LIMIT_ENABLED: Enable/disable rate limiting (default: true) - MCP_REGISTRY_RATE_LIMIT_PER_DAY: Daily publish limit per user (default: 10) - MCP_REGISTRY_RATE_LIMIT_EXEMPTIONS: Comma-separated exempt users/patterns Database changes: - New table: publish_attempts tracking auth_method_subject instead of namespace - Atomic check-and-increment operation prevents race conditions Testing: - All existing tests updated for new method signatures - New tests for concurrent requests, exemptions, and user-specific limits --- .env.example | 12 + .../guides/administration/admin-operations.md | 24 + docs/guides/publishing/publish-server.md | 3 + docs/reference/faq.md | 20 +- internal/api/handlers/v0/edit_test.go | 6 +- internal/api/handlers/v0/publish.go | 11 +- internal/api/handlers/v0/publish_test.go | 2 +- internal/api/handlers/v0/servers_test.go | 32 +- internal/api/handlers/v0/telemetry_test.go | 2 +- internal/config/config.go | 5 + internal/database/database.go | 7 + internal/database/memory.go | 61 ++- .../005_add_publish_rate_limiting.sql | 21 + internal/database/postgres.go | 74 +++ internal/database/rate_limit_db_test.go | 483 ++++++++++++++++++ internal/service/publish_rate_limit_test.go | 314 ++++++++++++ internal/service/registry_service.go | 71 ++- internal/service/registry_service_test.go | 2 +- internal/service/service.go | 2 +- .../docker-compose.integration-test.yml | 1 + tests/integration/main.go | 64 +++ 21 files changed, 1184 insertions(+), 33 deletions(-) create mode 100644 internal/database/migrations/005_add_publish_rate_limiting.sql create mode 100644 internal/database/rate_limit_db_test.go create mode 100644 internal/service/publish_rate_limit_test.go diff --git a/.env.example b/.env.example index 24d32d71..390bf636 100644 --- a/.env.example +++ b/.env.example @@ -39,3 +39,15 @@ MCP_REGISTRY_OIDC_EXTRA_CLAIMS=[{"hd":"modelcontextprotocol.io"}] # Grant admin permissions to OIDC-authenticated users MCP_REGISTRY_OIDC_EDIT_PERMISSIONS=* MCP_REGISTRY_OIDC_PUBLISH_PERMISSIONS=* + +# Rate Limiting Configuration +# Enable/disable rate limiting for publish operations +MCP_REGISTRY_RATE_LIMIT_ENABLED=true + +# Maximum number of servers a user can publish per day +MCP_REGISTRY_RATE_LIMIT_PER_DAY=10 + +# Comma-separated list of authenticated users (auth subjects) exempt from rate limiting +# Supports wildcards: anthropic/* to exempt all users under anthropic domain +# Examples: modelcontextprotocol, anthropic/*, specific-username +MCP_REGISTRY_RATE_LIMIT_EXEMPTIONS= diff --git a/docs/guides/administration/admin-operations.md b/docs/guides/administration/admin-operations.md index 0d9f8c37..544c590d 100644 --- a/docs/guides/administration/admin-operations.md +++ b/docs/guides/administration/admin-operations.md @@ -44,3 +44,27 @@ export SERVER_ID="" ``` This soft deletes the server. If you need to delete the content of a server (usually only where legally necessary), use the edit workflow above to scrub it all. + +## Rate Limiting Configuration + +The registry enforces daily publish rate limits to prevent abuse: + +### Environment Variables + +- `MCP_REGISTRY_RATE_LIMIT_ENABLED`: Enable/disable rate limiting (default: true) +- `MCP_REGISTRY_RATE_LIMIT_PER_DAY`: Maximum publishes per user per day (default: 10) +- `MCP_REGISTRY_RATE_LIMIT_EXEMPTIONS`: Comma-separated list of exempt users or patterns + +### Exemption Patterns + +Exemptions support wildcard patterns: +- Exact match: `anthropic` (exempts user "anthropic") +- Wildcard: `anthropic/*` (exempts "anthropic", "anthropic.claude", etc.) +- Multiple exemptions: `anthropic/*,modelcontextprotocol,github/*` + +### Notes + +- Rate limits are per authenticated user (not per namespace) +- Users with global admin permissions automatically bypass rate limits +- Limits reset on a rolling 24-hour window +- The counter is stored in the `publish_attempts` database table diff --git a/docs/guides/publishing/publish-server.md b/docs/guides/publishing/publish-server.md index 69bd6d2f..a44bb267 100644 --- a/docs/guides/publishing/publish-server.md +++ b/docs/guides/publishing/publish-server.md @@ -414,6 +414,9 @@ With authentication complete, publish your server: mcp-publisher publish ``` +> [!NOTE] +> **Rate Limits**: The registry enforces a limit of 10 publishes per user per day to prevent abuse. If you exceed this limit, you'll receive an error message with your current count. If you need a higher limit for legitimate use cases, please [open an issue](https://github.com/modelcontextprotocol/registry/issues). + You'll see output like: ``` ✓ Successfully published diff --git a/docs/reference/faq.md b/docs/reference/faq.md index c45520e1..8f623f6c 100644 --- a/docs/reference/faq.md +++ b/docs/reference/faq.md @@ -90,6 +90,18 @@ Yes, extensions under the `x-publisher` property are preserved when publishing t At time of last update, this was open for discussion in [#104](https://github.com/modelcontextprotocol/registry/issues/104). +### What are the rate limits for publishing? + +The registry enforces daily rate limits to prevent abuse: + +- **Default limit**: 10 publishes per authenticated user per day (rolling 24-hour window) +- **Who is affected**: All users except those with global admin permissions +- **What counts**: Each successful publish counts toward your daily limit +- **Exemptions**: Specific users or organizations can be exempted from rate limiting +- **Error message**: If you exceed the limit, you'll receive an error with your current count + +If you need a higher limit for legitimate use cases, please open an issue at https://github.com/modelcontextprotocol/registry/issues + ### Can I publish a private server? Private servers are those that are only accessible to a narrow set of users. For example, servers published on a private network (like `mcp.acme-corp.internal`) or on private package registries (e.g. `npx -y @acme/mcp --registry https://artifactory.acme-corp.internal/npm`). @@ -118,9 +130,15 @@ The MVP delegates security scanning to: - Namespace authentication requirements - Character limits and regex validation on free-form fields - Manual takedown of spam or malicious servers +- Daily publish rate limiting per authenticated user (10 publishes per day by default) + +The rate limiting system: +- Limits are per authenticated user (not per namespace) +- Default limit is 10 publishes per 24-hour period +- Administrators with global permissions bypass rate limits +- Specific users or patterns can be exempted from rate limiting In future we might explore: -- Stricter rate limiting (e.g., 10 new servers per user per day) - Potential AI-based spam detection - Community reporting and admin blacklisting capabilities diff --git a/internal/api/handlers/v0/edit_test.go b/internal/api/handlers/v0/edit_test.go index ef8562f2..d5b142d3 100644 --- a/internal/api/handlers/v0/edit_test.go +++ b/internal/api/handlers/v0/edit_test.go @@ -36,7 +36,7 @@ func TestEditServerEndpoint(t *testing.T) { }, Version: "1.0.0", } - published, err := registryService.Publish(testServer) + published, err := registryService.Publish(testServer, "testuser", false) assert.NoError(t, err) assert.NotNil(t, published) assert.NotNil(t, published.Meta) @@ -56,7 +56,7 @@ func TestEditServerEndpoint(t *testing.T) { }, Version: "1.0.0", } - otherPublished, err := registryService.Publish(otherServer) + otherPublished, err := registryService.Publish(otherServer, "testuser", false) assert.NoError(t, err) assert.NotNil(t, otherPublished) assert.NotNil(t, otherPublished.Meta) @@ -76,7 +76,7 @@ func TestEditServerEndpoint(t *testing.T) { }, Version: "1.0.0", } - deletedPublished, err := registryService.Publish(deletedServer) + deletedPublished, err := registryService.Publish(deletedServer, "testuser", false) assert.NoError(t, err) assert.NotNil(t, deletedPublished) assert.NotNil(t, deletedPublished.Meta) diff --git a/internal/api/handlers/v0/publish.go b/internal/api/handlers/v0/publish.go index 864f308d..302cde55 100644 --- a/internal/api/handlers/v0/publish.go +++ b/internal/api/handlers/v0/publish.go @@ -53,8 +53,17 @@ func RegisterPublishEndpoint(api huma.API, registry service.RegistryService, cfg return nil, huma.Error403Forbidden(buildPermissionErrorMessage(input.Body.Name, claims.Permissions)) } + // Check if user has global permissions (admin) + hasGlobalPermissions := false + for _, perm := range claims.Permissions { + if perm.ResourcePattern == "*" { + hasGlobalPermissions = true + break + } + } + // Publish the server with extensions - publishedServer, err := registry.Publish(input.Body) + publishedServer, err := registry.Publish(input.Body, claims.AuthMethodSubject, hasGlobalPermissions) if err != nil { return nil, huma.Error400BadRequest("Failed to publish server", err) } diff --git a/internal/api/handlers/v0/publish_test.go b/internal/api/handlers/v0/publish_test.go index 1811301b..5a377f38 100644 --- a/internal/api/handlers/v0/publish_test.go +++ b/internal/api/handlers/v0/publish_test.go @@ -192,7 +192,7 @@ func TestPublishEndpoint(t *testing.T) { ID: "example/test-server-existing", }, } - _, _ = registry.Publish(existingServer) + _, _ = registry.Publish(existingServer, "testuser", false) }, expectedStatus: http.StatusBadRequest, expectedError: "invalid version: cannot publish duplicate version", diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index 5b96e253..19e2fe9f 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -52,8 +52,8 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "2.0.0", } - _, _ = registry.Publish(server1) - _, _ = registry.Publish(server2) + _, _ = registry.Publish(server1, "testuser", false) + _, _ = registry.Publish(server2, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -71,7 +71,7 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "1.5.0", } - _, _ = registry.Publish(server) + _, _ = registry.Publish(server, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -141,8 +141,8 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "1.0.0", } - _, _ = registry.Publish(server1) - _, _ = registry.Publish(server2) + _, _ = registry.Publish(server1, "testuser", false) + _, _ = registry.Publish(server2, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -160,7 +160,7 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "1.0.0", } - _, _ = registry.Publish(server) + _, _ = registry.Publish(server, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -188,8 +188,8 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "2.0.0", } - _, _ = registry.Publish(server1) - _, _ = registry.Publish(server2) // This will be marked as latest + _, _ = registry.Publish(server1, "testuser", false) + _, _ = registry.Publish(server2, "testuser", false) // This will be marked as latest }, expectedStatus: http.StatusOK, }, @@ -217,8 +217,8 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "1.0.0", } - _, _ = registry.Publish(server1) - _, _ = registry.Publish(server2) + _, _ = registry.Publish(server1, "testuser", false) + _, _ = registry.Publish(server2, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -274,10 +274,10 @@ func TestServersListEndpoint(t *testing.T) { }, Version: "3.0.0", } - _, _ = registry.Publish(server1v1) - _, _ = registry.Publish(server1v2) - _, _ = registry.Publish(server2) - _, _ = registry.Publish(server3) + _, _ = registry.Publish(server1v1, "testuser", false) + _, _ = registry.Publish(server1v2, "testuser", false) + _, _ = registry.Publish(server2, "testuser", false) + _, _ = registry.Publish(server3, "testuser", false) }, expectedStatus: http.StatusOK, }, @@ -384,7 +384,7 @@ func TestServersDetailEndpoint(t *testing.T) { Name: "com.example/test-server", Description: "A test server", Version: "1.0.0", - }) + }, "testuser", false) assert.NoError(t, err) testCases := []struct { @@ -472,7 +472,7 @@ func TestServersEndpointsIntegration(t *testing.T) { Version: "1.0.0", } - published, err := registryService.Publish(testServer) + published, err := registryService.Publish(testServer, "testuser", false) assert.NoError(t, err) assert.NotNil(t, published) diff --git a/internal/api/handlers/v0/telemetry_test.go b/internal/api/handlers/v0/telemetry_test.go index 8b00127c..0dff92f4 100644 --- a/internal/api/handlers/v0/telemetry_test.go +++ b/internal/api/handlers/v0/telemetry_test.go @@ -31,7 +31,7 @@ func TestPrometheusHandler(t *testing.T) { ID: "example/test-server", }, Version: "2.0.0", - }) + }, "testuser", false) assert.NoError(t, err) cfg := config.NewConfig() diff --git a/internal/config/config.go b/internal/config/config.go index 38d77d13..646b245b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,6 +33,11 @@ type Config struct { OIDCExtraClaims string `env:"OIDC_EXTRA_CLAIMS" envDefault:""` OIDCEditPerms string `env:"OIDC_EDIT_PERMISSIONS" envDefault:""` OIDCPublishPerms string `env:"OIDC_PUBLISH_PERMISSIONS" envDefault:""` + + // Rate Limiting Configuration + RateLimitEnabled bool `env:"RATE_LIMIT_ENABLED" envDefault:"true"` + RateLimitPerDay int `env:"RATE_LIMIT_PER_DAY" envDefault:"10"` + RateLimitExemptions string `env:"RATE_LIMIT_EXEMPTIONS" envDefault:""` // comma-separated } // NewConfig creates a new configuration with default values diff --git a/internal/database/database.go b/internal/database/database.go index 0f9dc88a..97f5fed3 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -40,6 +40,13 @@ type Database interface { UpdateServer(ctx context.Context, id string, server *apiv0.ServerJSON) (*apiv0.ServerJSON, error) // Close closes the database connection Close() error + + // Rate limiting methods + IncrementPublishCount(ctx context.Context, authMethodSubject string) error + GetPublishCount(ctx context.Context, authMethodSubject string, date time.Time) (int, error) + // CheckAndIncrementPublishCount atomically checks if the count is under the limit and increments if so + // Returns the current count and whether the increment was successful + CheckAndIncrementPublishCount(ctx context.Context, authMethodSubject string, limit int) (currentCount int, incrementSuccessful bool, err error) } // ConnectionType represents the type of database connection diff --git a/internal/database/memory.go b/internal/database/memory.go index d0ba3a62..df7a2e05 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -6,21 +6,24 @@ import ( "sort" "strings" "sync" + "time" apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" ) // MemoryDB is an in-memory implementation of the Database interface type MemoryDB struct { - entries map[string]*apiv0.ServerJSON // maps registry metadata ID to ServerJSON - mu sync.RWMutex + entries map[string]*apiv0.ServerJSON // maps registry metadata ID to ServerJSON + publishAttempts map[string]map[string]int // authMethodSubject -> date -> count + mu sync.RWMutex } func NewMemoryDB() *MemoryDB { // Convert input ServerJSON entries to have proper metadata serverRecords := make(map[string]*apiv0.ServerJSON) return &MemoryDB{ - entries: serverRecords, + entries: serverRecords, + publishAttempts: make(map[string]map[string]int), } } @@ -139,6 +142,58 @@ func (db *MemoryDB) UpdateServer(ctx context.Context, id string, server *apiv0.S return server, nil } +// IncrementPublishCount increments the publish count for an authenticated user today +func (db *MemoryDB) IncrementPublishCount(_ context.Context, authMethodSubject string) error { + db.mu.Lock() + defer db.mu.Unlock() + + today := time.Now().Format(time.DateOnly) + if db.publishAttempts[authMethodSubject] == nil { + db.publishAttempts[authMethodSubject] = make(map[string]int) + } + db.publishAttempts[authMethodSubject][today]++ + return nil +} + +// GetPublishCount returns the number of publishes for an authenticated user on a specific date +func (db *MemoryDB) GetPublishCount(_ context.Context, authMethodSubject string, date time.Time) (int, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + dateStr := date.Format(time.DateOnly) + if db.publishAttempts[authMethodSubject] == nil { + return 0, nil + } + return db.publishAttempts[authMethodSubject][dateStr], nil +} + +// CheckAndIncrementPublishCount atomically checks if the count is under the limit and increments if so +func (db *MemoryDB) CheckAndIncrementPublishCount(_ context.Context, authMethodSubject string, limit int) (currentCount int, incrementSuccessful bool, err error) { + db.mu.Lock() + defer db.mu.Unlock() + + today := time.Now().Format(time.DateOnly) + + // Initialize authMethodSubject map if needed + if db.publishAttempts[authMethodSubject] == nil { + db.publishAttempts[authMethodSubject] = make(map[string]int) + } + + // Get current count + currentCount = db.publishAttempts[authMethodSubject][today] + + // Check if under limit and increment if so + if currentCount < limit { + db.publishAttempts[authMethodSubject][today]++ + currentCount++ // Return the new count after increment + incrementSuccessful = true + } else { + incrementSuccessful = false + } + + return currentCount, incrementSuccessful, nil +} + // For an in-memory database, this is a no-op func (db *MemoryDB) Close() error { return nil diff --git a/internal/database/migrations/005_add_publish_rate_limiting.sql b/internal/database/migrations/005_add_publish_rate_limiting.sql new file mode 100644 index 00000000..72b19655 --- /dev/null +++ b/internal/database/migrations/005_add_publish_rate_limiting.sql @@ -0,0 +1,21 @@ +-- Add rate limiting table to track publish attempts by authenticated user and date +CREATE TABLE publish_attempts ( + auth_method_subject VARCHAR(255) NOT NULL, + attempt_date DATE NOT NULL DEFAULT CURRENT_DATE, + attempt_count INTEGER NOT NULL DEFAULT 0, + first_attempt_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + last_attempt_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + PRIMARY KEY (auth_method_subject, attempt_date) +); + +-- Index for efficient lookups by auth_method_subject +CREATE INDEX idx_publish_attempts_auth_subject ON publish_attempts(auth_method_subject); + +-- Index for cleanup queries by date +CREATE INDEX idx_publish_attempts_date ON publish_attempts(attempt_date); + +-- Comment for documentation +COMMENT ON TABLE publish_attempts IS 'Tracks daily publish attempts per authenticated user for rate limiting'; +COMMENT ON COLUMN publish_attempts.auth_method_subject IS 'The authenticated user identifier (e.g., GitHub username)'; +COMMENT ON COLUMN publish_attempts.attempt_date IS 'The date of the attempts (resets daily)'; +COMMENT ON COLUMN publish_attempts.attempt_count IS 'Number of successful publishes on this date'; \ No newline at end of file diff --git a/internal/database/postgres.go b/internal/database/postgres.go index 420fc1e4..6a57b843 100644 --- a/internal/database/postgres.go +++ b/internal/database/postgres.go @@ -281,6 +281,80 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, id string, server *apiv0 return server, nil } +// IncrementPublishCount atomically increments the publish count for an authenticated user today +func (db *PostgreSQL) IncrementPublishCount(ctx context.Context, authMethodSubject string) error { + query := ` + INSERT INTO publish_attempts (auth_method_subject, attempt_date, attempt_count, first_attempt_at, last_attempt_at) + VALUES ($1, CURRENT_DATE, 1, NOW(), NOW()) + ON CONFLICT (auth_method_subject, attempt_date) + DO UPDATE SET + attempt_count = publish_attempts.attempt_count + 1, + last_attempt_at = NOW() + ` + _, err := db.pool.Exec(ctx, query, authMethodSubject) + if err != nil { + return fmt.Errorf("failed to increment publish count: %w", err) + } + return nil +} + +// GetPublishCount returns the number of publishes for an authenticated user on a specific date +func (db *PostgreSQL) GetPublishCount(ctx context.Context, authMethodSubject string, date time.Time) (int, error) { + var count int + query := ` + SELECT COALESCE(attempt_count, 0) + FROM publish_attempts + WHERE auth_method_subject = $1 AND attempt_date = $2 + ` + err := db.pool.QueryRow(ctx, query, authMethodSubject, date.Format(time.DateOnly)).Scan(&count) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return 0, nil // No attempts yet + } + return 0, fmt.Errorf("failed to get publish count: %w", err) + } + return count, nil +} + +// CheckAndIncrementPublishCount atomically checks if the count is under the limit and increments if so +func (db *PostgreSQL) CheckAndIncrementPublishCount(ctx context.Context, authMethodSubject string, limit int) (currentCount int, incrementSuccessful bool, err error) { + // Use a CTE to atomically check and increment + query := ` + WITH current_state AS ( + SELECT COALESCE(attempt_count, 0) as count + FROM publish_attempts + WHERE auth_method_subject = $1 AND attempt_date = CURRENT_DATE + UNION ALL + SELECT 0 WHERE NOT EXISTS ( + SELECT 1 FROM publish_attempts + WHERE auth_method_subject = $1 AND attempt_date = CURRENT_DATE + ) + LIMIT 1 + ), + increment AS ( + INSERT INTO publish_attempts (auth_method_subject, attempt_date, attempt_count, first_attempt_at, last_attempt_at) + SELECT $1, CURRENT_DATE, 1, NOW(), NOW() + WHERE (SELECT count FROM current_state) < $2 + ON CONFLICT (auth_method_subject, attempt_date) + DO UPDATE SET + attempt_count = publish_attempts.attempt_count + 1, + last_attempt_at = NOW() + WHERE publish_attempts.attempt_count < $2 + RETURNING attempt_count + ) + SELECT + COALESCE((SELECT attempt_count FROM increment), (SELECT count FROM current_state)) as final_count, + EXISTS (SELECT 1 FROM increment) as was_incremented + ` + + err = db.pool.QueryRow(ctx, query, authMethodSubject, limit).Scan(¤tCount, &incrementSuccessful) + if err != nil { + return 0, false, fmt.Errorf("failed to check and increment publish count: %w", err) + } + + return currentCount, incrementSuccessful, nil +} + // Close closes the database connection func (db *PostgreSQL) Close() error { db.pool.Close() diff --git a/internal/database/rate_limit_db_test.go b/internal/database/rate_limit_db_test.go new file mode 100644 index 00000000..f0f04711 --- /dev/null +++ b/internal/database/rate_limit_db_test.go @@ -0,0 +1,483 @@ +package database_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testAuthMethodSubject = "io.github.testuser" + +// TestDatabaseRateLimitPersistence tests that rate limit data persists correctly +func TestDatabaseRateLimitPersistence(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + // PostgreSQL tests would require a test database connection + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + today := time.Now() + + // Initial count should be 0 + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Initial count should be 0") + + // Increment count + err = tc.db.IncrementPublishCount(ctx, authMethodSubject) + assert.NoError(t, err) + + // Count should be 1 + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 1, count, "Count should be 1 after first increment") + + // Increment multiple times + for i := 0; i < 4; i++ { + err = tc.db.IncrementPublishCount(ctx, authMethodSubject) + assert.NoError(t, err) + } + + // Count should be 5 + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 5, count, "Count should be 5 after 5 increments") + + // Different authMethodSubject should have independent count + authMethodSubject2 := "io.github.otheruser" + count, err = tc.db.GetPublishCount(ctx, authMethodSubject2, today) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Different authMethodSubject should have 0 count") + + err = tc.db.IncrementPublishCount(ctx, authMethodSubject2) + assert.NoError(t, err) + + count, err = tc.db.GetPublishCount(ctx, authMethodSubject2, today) + assert.NoError(t, err) + assert.Equal(t, 1, count, "Different authMethodSubject should have independent count") + + // Original authMethodSubject count should be unchanged + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 5, count, "Original authMethodSubject count should be unchanged") + }) + } +} + +// TestDatabaseRateLimitDateIsolation tests that counts are isolated by date +func TestDatabaseRateLimitDateIsolation(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + today := time.Now() + yesterday := today.AddDate(0, 0, -1) + tomorrow := today.AddDate(0, 0, 1) + + // Increment count for today + for i := 0; i < 3; i++ { + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + assert.NoError(t, err) + } + + // Today's count should be 3 + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 3, count, "Today's count should be 3") + + // Yesterday's count should be 0 + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, yesterday) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Yesterday's count should be 0") + + // Tomorrow's count should be 0 + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, tomorrow) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Tomorrow's count should be 0") + }) + } +} + +// TestDatabaseRateLimitConcurrency tests concurrent increments +func TestDatabaseRateLimitConcurrency(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + today := time.Now() + numGoroutines := 100 + incrementsPerGoroutine := 10 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Launch concurrent increments + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < incrementsPerGoroutine; j++ { + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + if err != nil { + t.Errorf("Unexpected error during increment: %v", err) + } + } + }() + } + + wg.Wait() + + // Verify final count is correct + expectedCount := numGoroutines * incrementsPerGoroutine + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, expectedCount, count, "Final count should be %d", expectedCount) + }) + } +} + +// TestDatabaseRateLimitMultipleAuthMethodSubjects tests operations across multiple authMethodSubjects +func TestDatabaseRateLimitMultipleAuthMethodSubjects(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + today := time.Now() + authMethodSubjects := []string{ + "io.github.user1", + "io.github.user2", + "com.example.app", + "org.nonprofit.project", + } + + // Increment each authMethodSubject a different number of times + for i, ns := range authMethodSubjects { + for j := 0; j <= i; j++ { + err := tc.db.IncrementPublishCount(ctx, ns) + assert.NoError(t, err) + } + } + + // Verify each authMethodSubject has the correct count + for i, ns := range authMethodSubjects { + count, err := tc.db.GetPublishCount(ctx, ns, today) + assert.NoError(t, err) + assert.Equal(t, i+1, count, "AuthMethodSubject %s should have count %d", ns, i+1) + } + }) + } +} + +// TestDatabaseRateLimitAtomicity tests that increments are atomic +func TestDatabaseRateLimitAtomicity(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + today := time.Now() + + // Run many concurrent increments and reads + numOperations := 1000 + var wg sync.WaitGroup + wg.Add(numOperations * 2) + + incrementCount := 0 + var mu sync.Mutex + + // Half doing increments + for i := 0; i < numOperations; i++ { + go func() { + defer wg.Done() + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + if err == nil { + mu.Lock() + incrementCount++ + mu.Unlock() + } + }() + } + + // Half doing reads (to test read/write races) + counts := make([]int, numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + defer wg.Done() + count, _ := tc.db.GetPublishCount(ctx, authMethodSubject, today) + counts[idx] = count + }(i) + } + + wg.Wait() + + // Final count should match successful increments + finalCount, err := tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, incrementCount, finalCount, "Final count should match successful increments") + + // All read counts should be valid (between 0 and final count) + for _, count := range counts { + assert.GreaterOrEqual(t, count, 0, "Count should never be negative") + assert.LessOrEqual(t, count, finalCount, "Count should never exceed final count") + } + }) + } +} + +// TestDatabaseRateLimitZeroAndNegative tests edge cases with zero and boundary values +func TestDatabaseRateLimitZeroAndNegative(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + today := time.Now() + + // Empty authMethodSubject (edge case) + emptyNS := "" + count, err := tc.db.GetPublishCount(ctx, emptyNS, today) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Empty authMethodSubject should return 0") + + _ = tc.db.IncrementPublishCount(ctx, emptyNS) + // The behavior for empty authMethodSubject might vary by implementation + // Just ensure it doesn't panic + + // Very long authMethodSubject + longNS := "com." + string(make([]byte, 1000)) + count, err = tc.db.GetPublishCount(ctx, longNS, today) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Long authMethodSubject should return 0 initially") + + // Special characters in authMethodSubject + specialNS := "io.github.user-_.test@#$%" + err = tc.db.IncrementPublishCount(ctx, specialNS) + assert.NoError(t, err) + + count, err = tc.db.GetPublishCount(ctx, specialNS, today) + assert.NoError(t, err) + assert.Equal(t, 1, count, "Special character authMethodSubject should work") + }) + } +} + +// TestDatabaseRateLimitTimestamps tests timestamp recording functionality +func TestDatabaseRateLimitTimestamps(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + today := time.Now() + + // First increment + beforeFirst := time.Now() + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + require.NoError(t, err) + afterFirst := time.Now() + + // Small delay + time.Sleep(10 * time.Millisecond) + + // Second increment + beforeSecond := time.Now() + err = tc.db.IncrementPublishCount(ctx, authMethodSubject) + require.NoError(t, err) + afterSecond := time.Now() + + // Verify count is correct + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, today) + assert.NoError(t, err) + assert.Equal(t, 2, count) + + // Note: Actual timestamp verification would require access to the + // first_attempt_at and last_attempt_at fields, which aren't exposed + // through the current interface. This test ensures the operations + // complete successfully with timing considerations. + _ = beforeFirst + _ = afterFirst + _ = beforeSecond + _ = afterSecond + }) + } +} + +// TestDatabaseRateLimitLargeScale tests with a large number of authMethodSubjects and operations +func TestDatabaseRateLimitLargeScale(t *testing.T) { + if testing.Short() { + t.Skip("Skipping large scale test in short mode") + } + + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + today := time.Now() + numAuthMethodSubjects := 100 + incrementsPerAuthMethodSubject := 50 + + // Create many authMethodSubjects + authMethodSubjects := make([]string, numAuthMethodSubjects) + for i := 0; i < numAuthMethodSubjects; i++ { + authMethodSubjects[i] = fmt.Sprintf("io.github.user%d", i) + } + + // Increment each authMethodSubject multiple times concurrently + var wg sync.WaitGroup + wg.Add(numAuthMethodSubjects) + + for _, ns := range authMethodSubjects { + go func(authMethodSubject string, expectedCount int) { + defer wg.Done() + for j := 0; j < expectedCount; j++ { + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + if err != nil { + t.Errorf("Error incrementing %s: %v", authMethodSubject, err) + } + } + }(ns, incrementsPerAuthMethodSubject) + } + + wg.Wait() + + // Verify all counts + for _, ns := range authMethodSubjects { + count, err := tc.db.GetPublishCount(ctx, ns, today) + assert.NoError(t, err) + assert.Equal(t, incrementsPerAuthMethodSubject, count, "AuthMethodSubject %s should have correct count", ns) + } + }) + } +} + +// TestDatabaseRateLimitDateBoundaries tests behavior around date boundaries +func TestDatabaseRateLimitDateBoundaries(t *testing.T) { + testCases := []struct { + name string + db database.Database + }{ + { + name: "MemoryDB", + db: database.NewMemoryDB(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + authMethodSubject := testAuthMethodSubject + + // Test with various dates + dates := []time.Time{ + time.Now(), // Today + time.Now().AddDate(0, 0, -1), // Yesterday + time.Now().AddDate(0, 0, 1), // Tomorrow + time.Now().AddDate(-1, 0, 0), // Last year + time.Now().AddDate(1, 0, 0), // Next year + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), // Y2K + time.Date(2038, 1, 19, 3, 14, 7, 0, time.UTC), // Near Unix timestamp limit + } + + for _, date := range dates { + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, date) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Count for date %v should be 0", date) + } + + // Increment today's count + err := tc.db.IncrementPublishCount(ctx, authMethodSubject) + assert.NoError(t, err) + + // Only today should have a count + count, err := tc.db.GetPublishCount(ctx, authMethodSubject, time.Now()) + assert.NoError(t, err) + assert.Equal(t, 1, count, "Today's count should be 1") + + // All other dates should still be 0 + for _, date := range dates[1:] { + count, err = tc.db.GetPublishCount(ctx, authMethodSubject, date) + assert.NoError(t, err) + assert.Equal(t, 0, count, "Count for date %v should still be 0", date) + } + }) + } +} + diff --git a/internal/service/publish_rate_limit_test.go b/internal/service/publish_rate_limit_test.go new file mode 100644 index 00000000..b8c12160 --- /dev/null +++ b/internal/service/publish_rate_limit_test.go @@ -0,0 +1,314 @@ +package service_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/config" + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/service" + apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + "github.com/modelcontextprotocol/registry/pkg/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createTestServer(name string, version string) apiv0.ServerJSON { + return apiv0.ServerJSON{ + Name: name, + Description: "A test server", + Repository: model.Repository{ + URL: "https://github.com/testuser/test-repo", + Source: "github", + ID: "testuser/test-repo", + }, + Version: version, + } +} + +func TestPublish_RateLimiting(t *testing.T) { + tests := []struct { + name string + authMethodSubject string + hasGlobalPermissions bool + existingCount int + limit int + enabled bool + exemptions string + expectError bool + errorContains string + }{ + { + name: "under limit allows publish", + authMethodSubject: "testuser", + hasGlobalPermissions: false, + existingCount: 5, + limit: 10, + enabled: true, + expectError: false, + }, + { + name: "at limit blocks publish", + authMethodSubject: "testuser", + hasGlobalPermissions: false, + existingCount: 10, + limit: 10, + enabled: true, + expectError: true, + errorContains: "rate limit exceeded", + }, + { + name: "disabled allows any", + authMethodSubject: "testuser", + hasGlobalPermissions: false, + existingCount: 100, + limit: 10, + enabled: false, + expectError: false, + }, + { + name: "global permissions bypasses", + authMethodSubject: "testuser", + hasGlobalPermissions: true, + existingCount: 100, + limit: 10, + enabled: true, + expectError: false, + }, + { + name: "exempt user bypasses", + authMethodSubject: "exemptuser", + hasGlobalPermissions: false, + existingCount: 100, + limit: 10, + enabled: true, + exemptions: "exemptuser", + expectError: false, + }, + { + name: "wildcard exemption works", + authMethodSubject: "anthropic.claude", + hasGlobalPermissions: false, + existingCount: 100, + limit: 10, + enabled: true, + exemptions: "anthropic/*", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + db := database.NewMemoryDB() + cfg := &config.Config{ + RateLimitEnabled: tt.enabled, + RateLimitPerDay: tt.limit, + RateLimitExemptions: tt.exemptions, + EnableRegistryValidation: false, + } + + // Pre-populate existing attempts + ctx := context.Background() + for i := 0; i < tt.existingCount; i++ { + err := db.IncrementPublishCount(ctx, tt.authMethodSubject) + require.NoError(t, err) + } + + // Test + svc := service.NewRegistryService(db, cfg) + testServer := createTestServer("io.github.test/server", "1.0.0") + _, err := svc.Publish(testServer, tt.authMethodSubject, tt.hasGlobalPermissions) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPublish_ConcurrentRateLimiting(t *testing.T) { + ctx := context.Background() + db := database.NewMemoryDB() + cfg := &config.Config{ + RateLimitEnabled: true, + RateLimitPerDay: 10, + EnableRegistryValidation: false, + } + svc := service.NewRegistryService(db, cfg) + + authMethodSubject := "testuser" + numGoroutines := 20 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + successCount := 0 + var mu sync.Mutex + errors := make([]error, 0) + + // Launch concurrent publish attempts + for i := 0; i < numGoroutines; i++ { + go func(version int) { + defer wg.Done() + testServer := createTestServer("io.github.test/server", fmt.Sprintf("1.0.%d", version)) + _, err := svc.Publish(testServer, authMethodSubject, false) + mu.Lock() + defer mu.Unlock() + if err == nil { + successCount++ + } else { + errors = append(errors, err) + } + }(i) + } + + wg.Wait() + + // With the atomic check-and-increment operation, exactly 10 should succeed + assert.Equal(t, 10, successCount, "Expected exactly 10 successful publishes") + assert.Equal(t, 10, len(errors), "Expected exactly 10 rate limit errors") + + // Verify all errors are rate limit errors + for _, err := range errors { + assert.Contains(t, err.Error(), "rate limit exceeded") + } + + // Verify the database count is exactly 10 + count, err := db.GetPublishCount(ctx, authMethodSubject, time.Now()) + require.NoError(t, err) + assert.Equal(t, 10, count, "Database should show exactly 10 publishes") +} + +func TestPublish_DifferentUsersIndependentLimits(t *testing.T) { + cfg := &config.Config{ + RateLimitEnabled: true, + RateLimitPerDay: 3, + EnableRegistryValidation: false, + } + db := database.NewMemoryDB() + svc := service.NewRegistryService(db, cfg) + + user1 := "user1" + user2 := "user2" + + // Fill up user1's limit + for i := 0; i < 3; i++ { + testServer := createTestServer("io.github.test/server1", fmt.Sprintf("1.0.%d", i)) + _, err := svc.Publish(testServer, user1, false) + assert.NoError(t, err) + } + + // user1 should be blocked + testServer := createTestServer("io.github.test/server1", "1.0.99") + _, err := svc.Publish(testServer, user1, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "rate limit exceeded") + + // user2 should still have full quota + for i := 0; i < 3; i++ { + testServer := createTestServer("io.github.test/server2", fmt.Sprintf("1.0.%d", i)) + _, err := svc.Publish(testServer, user2, false) + assert.NoError(t, err, "user2 publish %d should succeed", i+1) + } + + // Now user2 should also be blocked + testServer = createTestServer("io.github.test/server2", "1.0.99") + _, err = svc.Publish(testServer, user2, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "rate limit exceeded") +} + +func TestPublish_ExemptionPatterns(t *testing.T) { + tests := []struct { + name string + exemptions string + authMethodSubject string + isExempt bool + }{ + { + name: "exact match exemption", + exemptions: "modelcontextprotocol", + authMethodSubject: "modelcontextprotocol", + isExempt: true, + }, + { + name: "wildcard root match", + exemptions: "anthropic/*", + authMethodSubject: "anthropic", + isExempt: true, + }, + { + name: "wildcard subdomain match", + exemptions: "anthropic/*", + authMethodSubject: "anthropic.claude", + isExempt: true, + }, + { + name: "wildcard deep subdomain match", + exemptions: "anthropic/*", + authMethodSubject: "anthropic.claude.test.deep", + isExempt: true, + }, + { + name: "no match for partial", + exemptions: "anthropic/*", + authMethodSubject: "anthropi", + isExempt: false, + }, + { + name: "multiple exemptions", + exemptions: "test1,example/*,foo.bar", + authMethodSubject: "example.app", + isExempt: true, + }, + { + name: "empty exemptions", + exemptions: "", + authMethodSubject: "anyone", + isExempt: false, + }, + { + name: "whitespace in exemptions", + exemptions: "test1, example/*, foo.bar", + authMethodSubject: "example.app", + isExempt: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + RateLimitEnabled: true, + RateLimitPerDay: 1, + RateLimitExemptions: tt.exemptions, + EnableRegistryValidation: false, + } + db := database.NewMemoryDB() + svc := service.NewRegistryService(db, cfg) + + // Pre-fill the limit + ctx := context.Background() + err := db.IncrementPublishCount(ctx, tt.authMethodSubject) + require.NoError(t, err) + + // Try to publish + testServer := createTestServer("io.github.test/server", "1.0.0") + _, err = svc.Publish(testServer, tt.authMethodSubject, false) + + if tt.isExempt { + assert.NoError(t, err, "Exempt user should be able to publish") + } else { + assert.Error(t, err, "Non-exempt user should be rate limited") + assert.Contains(t, err.Error(), "rate limit exceeded") + } + }) + } +} \ No newline at end of file diff --git a/internal/service/registry_service.go b/internal/service/registry_service.go index 86dc0108..beb10c49 100644 --- a/internal/service/registry_service.go +++ b/internal/service/registry_service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/google/uuid" @@ -71,11 +72,31 @@ func (s *registryServiceImpl) GetByID(id string) (*apiv0.ServerJSON, error) { } // Publish publishes a server with flattened _meta extensions -func (s *registryServiceImpl) Publish(req apiv0.ServerJSON) (*apiv0.ServerJSON, error) { +//nolint:cyclop // Complexity is necessary for validation, rate limiting, and version management +func (s *registryServiceImpl) Publish(req apiv0.ServerJSON, authMethodSubject string, hasGlobalPermissions bool) (*apiv0.ServerJSON, error) { // Create a timeout context for the database operation ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + // Check rate limiting (skip for admins with global permissions or disabled rate limiting) + if authMethodSubject != "" && s.cfg.RateLimitEnabled && !hasGlobalPermissions { + // Check if user is exempt from rate limiting + isExempt := s.isExemptFromRateLimit(authMethodSubject) + + if !isExempt { + // Check and increment the publish count atomically + currentCount, incrementSuccessful, err := s.db.CheckAndIncrementPublishCount(ctx, authMethodSubject, s.cfg.RateLimitPerDay) + if err != nil { + return nil, fmt.Errorf("failed to check rate limit: %w", err) + } + + if !incrementSuccessful { + return nil, fmt.Errorf("rate limit exceeded: you have published %d servers today (limit: %d per day). If you need a higher limit, please open an issue at https://github.com/modelcontextprotocol/registry/issues", + currentCount, s.cfg.RateLimitPerDay) + } + } + } + // Validate the request if err := validators.ValidatePublishRequest(req, s.cfg); err != nil { return nil, err @@ -153,10 +174,23 @@ func (s *registryServiceImpl) Publish(req apiv0.ServerJSON) (*apiv0.ServerJSON, existingLatestID = existingLatest.Meta.Official.ID } if existingLatestID != "" { - // Update the existing server to set is_latest = false - existingLatest.Meta.Official.IsLatest = false - existingLatest.Meta.Official.UpdatedAt = time.Now() - if _, err := s.db.UpdateServer(ctx, existingLatestID, existingLatest); err != nil { + // Create a deep copy to avoid race conditions + updatedLatest := *existingLatest + if updatedLatest.Meta != nil { + // Create a copy of the Meta structure + metaCopy := *updatedLatest.Meta + updatedLatest.Meta = &metaCopy + + if updatedLatest.Meta.Official != nil { + // Create a copy of the Official metadata + officialCopy := *updatedLatest.Meta.Official + officialCopy.IsLatest = false + officialCopy.UpdatedAt = time.Now() + updatedLatest.Meta.Official = &officialCopy + } + } + + if _, err := s.db.UpdateServer(ctx, existingLatestID, &updatedLatest); err != nil { return nil, err } } @@ -200,6 +234,33 @@ func (s *registryServiceImpl) getCurrentLatestVersion(existingServerVersions []* return nil } +// isExemptFromRateLimit checks if an auth subject is exempt from rate limiting +func (s *registryServiceImpl) isExemptFromRateLimit(authMethodSubject string) bool { + if s.cfg.RateLimitExemptions == "" { + return false + } + + exemptions := strings.Split(s.cfg.RateLimitExemptions, ",") + for _, exemption := range exemptions { + exemption = strings.TrimSpace(exemption) + if exemption == "" { + continue + } + + // Handle wildcard exemptions + if strings.HasSuffix(exemption, "/*") { + prefix := strings.TrimSuffix(exemption, "/*") + // Match auth subject exactly or with a separator + if authMethodSubject == prefix || strings.HasPrefix(authMethodSubject, prefix+".") || strings.HasPrefix(authMethodSubject, prefix+"/") { + return true + } + } else if exemption == authMethodSubject { + return true + } + } + return false +} + // EditServer updates an existing server with new details (admin operation) func (s *registryServiceImpl) EditServer(id string, req apiv0.ServerJSON) (*apiv0.ServerJSON, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/internal/service/registry_service_test.go b/internal/service/registry_service_test.go index 94ffdcc6..811f5cd1 100644 --- a/internal/service/registry_service_test.go +++ b/internal/service/registry_service_test.go @@ -38,7 +38,7 @@ func TestValidateNoDuplicateRemoteURLs(t *testing.T) { service := NewRegistryService(memDB, &config.Config{EnableRegistryValidation: false}) for _, server := range existingServers { - _, err := service.Publish(*server) + _, err := service.Publish(*server, "testuser", false) if err != nil { t.Fatalf("failed to publish server: %v", err) } diff --git a/internal/service/service.go b/internal/service/service.go index fd6ebfdb..97eda59e 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -12,7 +12,7 @@ type RegistryService interface { // Retrieve a single server by registry metadata ID GetByID(id string) (*apiv0.ServerJSON, error) // Publish a server - Publish(req apiv0.ServerJSON) (*apiv0.ServerJSON, error) + Publish(req apiv0.ServerJSON, authMethodSubject string, hasGlobalPermissions bool) (*apiv0.ServerJSON, error) // Update an existing server EditServer(id string, req apiv0.ServerJSON) (*apiv0.ServerJSON, error) } diff --git a/tests/integration/docker-compose.integration-test.yml b/tests/integration/docker-compose.integration-test.yml index f5d05ec8..43cd5861 100644 --- a/tests/integration/docker-compose.integration-test.yml +++ b/tests/integration/docker-compose.integration-test.yml @@ -3,6 +3,7 @@ services: environment: - MCP_REGISTRY_SEED_FROM= - MCP_REGISTRY_ENABLE_REGISTRY_VALIDATION=false + - MCP_REGISTRY_RATE_LIMIT_PER_DAY=15 healthcheck: test: ["CMD", "wget", "-qO-", "http://localhost:8080/v0/servers"] interval: 1s diff --git a/tests/integration/main.go b/tests/integration/main.go index 994d35f1..c45d85de 100644 --- a/tests/integration/main.go +++ b/tests/integration/main.go @@ -88,6 +88,13 @@ func publish(examples []example) error { return errors.New(msg) } log.Println(msg) + + // Verify rate limiting is still enforced by attempting one more publish + // This should fail since we set the limit to 15 and have 13 examples + if err := verifyRateLimitEnforced(); err != nil { + return fmt.Errorf("rate limit verification failed: %w", err) + } + return nil } @@ -237,6 +244,63 @@ func findServerIDByName(serverName string) (string, error) { return "", fmt.Errorf("could not find any server with name %s", serverName) } +func verifyRateLimitEnforced() error { + log.Println("Verifying rate limit enforcement...") + + // Create test servers for rate limit verification + // We've published 13 examples, limit is 15, so we can publish 2 more + testServers := []struct { + name string + shouldFail bool + }{ + {"io.modelcontextprotocol.anonymous/rate-limit-test-1", false}, // 14th - should succeed + {"io.modelcontextprotocol.anonymous/rate-limit-test-2", false}, // 15th - should succeed + {"io.modelcontextprotocol.anonymous/rate-limit-test-3", true}, // 16th - should fail + } + + for i, test := range testServers { + server := &apiv0.ServerJSON{ + Name: test.name, + Description: fmt.Sprintf("Rate limit test server %d", i+1), + Status: "available", + Version: "1.0.0", + } + + content, _ := json.Marshal(server) + p := filepath.Join("bin", fmt.Sprintf("rate-limit-test-%d.json", i+1)) + if err := os.WriteFile(p, content, 0600); err != nil { + return fmt.Errorf("failed to write test file: %w", err) + } + defer os.Remove(p) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "./bin/publisher", "publish", p) + cmd.WaitDelay = 100 * time.Millisecond + + out, err := cmd.CombinedOutput() + + if test.shouldFail { + if err == nil { + return fmt.Errorf("expected rate limit error for %s but got success", test.name) + } + if !strings.Contains(string(out), "rate limit") && !strings.Contains(string(out), "Rate limit") { + return fmt.Errorf("expected rate limit error for %s but got: %s", test.name, string(out)) + } + log.Printf(" ✅ Rate limit correctly enforced on attempt %d (16th total)", i+1) + } else { + if err != nil { + return fmt.Errorf("unexpected error for %s: %s", test.name, string(out)) + } + log.Printf(" ✅ Successfully published %s (within rate limit)", test.name) + } + } + + log.Println(" ✅ Rate limit enforcement verified successfully") + return nil +} + func verifyPublishedServer(id string, expected *apiv0.ServerJSON) error { log.Printf(" 🔍 Verifying server with ID: %s", id) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)