-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add service account middleware authentication
- Loading branch information
1 parent
6fd8b9c
commit 8f72676
Showing
3 changed files
with
138 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package drip_authentication | ||
|
||
import ( | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/labstack/echo/v4" | ||
"github.com/rs/zerolog/log" | ||
"google.golang.org/api/idtoken" | ||
) | ||
|
||
func ServiceAccountAuthMiddleware() echo.MiddlewareFunc { | ||
// Handlers in here should be checked by this middleware. | ||
var checklist = map[string][]string{ | ||
"/security-scan": {"GET"}, | ||
"/nodes/reindex": {"POST"}, | ||
} | ||
|
||
return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
return func(ctx echo.Context) error { | ||
// Check if the request path and method are in the checklist | ||
path := ctx.Request().URL.Path | ||
method := ctx.Request().Method | ||
|
||
methods, ok := checklist[path] | ||
if !ok { | ||
return next(ctx) | ||
} | ||
|
||
for _, m := range methods { | ||
if method == m { | ||
ok = true | ||
break | ||
} | ||
} | ||
if !ok { | ||
return next(ctx) | ||
} | ||
|
||
// validate token | ||
authHeader := ctx.Request().Header.Get("Authorization") | ||
token := "" | ||
if strings.HasPrefix(authHeader, "Bearer ") { | ||
token = authHeader[7:] // Skip the "Bearer " part | ||
} | ||
|
||
if token == "" { | ||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing token") | ||
} | ||
|
||
log.Ctx(ctx.Request().Context()).Info().Msgf("Validating google id token %s for path %s and method %s", token, path, method) | ||
|
||
payload, err := idtoken.Validate(ctx.Request().Context(), token, "https://api.comfy.org") | ||
if err != nil { | ||
log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid token") | ||
return ctx.JSON(http.StatusUnauthorized, "Invalid token") | ||
} | ||
|
||
email, _ := payload.Claims["email"].(string) | ||
if email != "[email protected]" { | ||
log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid email") | ||
return ctx.JSON(http.StatusUnauthorized, "Invalid email") | ||
} | ||
|
||
log.Ctx(ctx.Request().Context()).Info().Msgf("Service Account Email: %s", email) | ||
return next(ctx) | ||
} | ||
} | ||
} |
67 changes: 67 additions & 0 deletions
67
server/middleware/authentication/service_account_auth_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package drip_authentication | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/labstack/echo/v4" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestServiceAccountAllowList(t *testing.T) { | ||
e := echo.New() | ||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
c := e.NewContext(req, rec) | ||
|
||
middleware := ServiceAccountAuthMiddleware() | ||
|
||
tests := []struct { | ||
name string | ||
path string | ||
method string | ||
allowed bool | ||
}{ | ||
{"OpenAPI GET", "/openapi", "GET", true}, | ||
{"Session DELETE", "/users/sessions", "DELETE", true}, | ||
{"Health GET", "/health", "GET", true}, | ||
{"VM ANY", "/vm", "POST", true}, | ||
{"VM ANY GET", "/vm", "GET", true}, | ||
{"Artifact POST", "/upload-artifact", "POST", true}, | ||
{"Git Commit POST", "/gitcommit", "POST", true}, | ||
{"Git Commit GET", "/gitcommit", "GET", true}, | ||
{"Branch GET", "/branch", "GET", true}, | ||
{"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true}, | ||
{"Publisher POST", "/publishers", "POST", true}, | ||
{"Unauthorized Path", "/nonexistent", "GET", true}, | ||
{"Get All Nodes", "/nodes", "GET", true}, | ||
{"Install Nodes", "/nodes/node-id/install", "GET", true}, | ||
|
||
{"Reindex Nodes", "/nodes/reindex", "POST", false}, | ||
{"Reindex Nodes", "/security-scan", "GET", false}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
req := httptest.NewRequest(tt.method, tt.path, nil) | ||
c.SetRequest(req) | ||
handled := false | ||
next := echo.HandlerFunc(func(c echo.Context) error { | ||
handled = true | ||
return nil | ||
}) | ||
err := middleware(next)(c) | ||
if tt.allowed { | ||
assert.True(t, handled, "Request should be allowed through") | ||
assert.Nil(t, err) | ||
} else { | ||
assert.False(t, handled, "Request should not be allowed through") | ||
assert.NotNil(t, err) | ||
httpError, ok := err.(*echo.HTTPError) | ||
assert.True(t, ok, "Error should be HTTPError") | ||
assert.Equal(t, http.StatusUnauthorized, httpError.Code) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters