Skip to content

Commit

Permalink
Refactor to handle authorization in middleware for all registry endpo…
Browse files Browse the repository at this point in the history
…ints (#45)

* Refactor to handle authorization in middleware for all registry endpoints

* fmt
  • Loading branch information
james03160927 authored Jun 16, 2024
1 parent 431070f commit 192b63a
Show file tree
Hide file tree
Showing 9 changed files with 440 additions and 202 deletions.
59 changes: 36 additions & 23 deletions integration-tests/ban_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"registry-backend/ent/schema"
"registry-backend/mock/gateways"
"registry-backend/server/implementation"
drip_middleware "registry-backend/server/middleware"
drip_authorization "registry-backend/server/middleware/authorization"
"testing"

strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo"
Expand Down Expand Up @@ -51,9 +51,11 @@ func TestBan(t *testing.T) {
return nil, errNotBanned
}
}
authorizationManager := drip_authorization.NewAuthorizationManager(client, impl.RegistryService)
authz := authorizationManager.AuthorizationMiddleware()
wrapped := drip.NewStrictHandler(impl, []strictecho.StrictEchoMiddlewareFunc{
notBanned,
drip_middleware.AuthorizationMiddleware(client),
authz,
})

t.Run("Publisher", func(t *testing.T) {
Expand Down Expand Up @@ -148,17 +150,17 @@ func TestBan(t *testing.T) {
fn: wrapped.CreatePublisher,
},
{
name: "UpdatePublisher",
name: "DeleteNodeVersion",
req: func(ctx context.Context) (req *http.Request) {
payloadBuf := new(bytes.Buffer)
json.NewEncoder(payloadBuf).Encode(drip.UpdatePublisherJSONRequestBody{})
json.NewEncoder(payloadBuf).Encode(drip.DeleteNodeVersionRequestObject{})

req = httptest.NewRequest(http.MethodPost, "/", payloadBuf).WithContext(ctx)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
return
},
fn: func(ctx echo.Context) error {
return wrapped.UpdatePublisher(ctx, "")
return wrapped.DeleteNodeVersion(ctx, "", "", "")
},
},
}
Expand Down Expand Up @@ -225,7 +227,7 @@ func TestBan(t *testing.T) {
nodeTags := []string{"test-node-tag"}
icon := "https://wwww.github.com/test-icon.svg"
githubUrl := "https://www.github.com/test-github-url"
_, err = impl.CreateNode(ctx, drip.CreateNodeRequestObject{
_, err = withMiddleware(authz, "CreateNode", impl.CreateNode)(ctx, drip.CreateNodeRequestObject{
PublisherId: publisherId,
Body: &drip.Node{
Id: &nodeId,
Expand Down Expand Up @@ -276,42 +278,53 @@ func TestBan(t *testing.T) {

t.Run("Operate", func(t *testing.T) {
t.Run("Get", func(t *testing.T) {
res, err := impl.GetNode(ctx, drip.GetNodeRequestObject{NodeId: nodeId})
require.NoError(t, err)
require.IsType(t, drip.GetNode403JSONResponse{}, res)
f := withMiddleware(authz, "GetNode", impl.GetNode)
_, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId})
require.Error(t, err)
require.IsType(t, &echo.HTTPError{}, err)
assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden)
})
t.Run("Update", func(t *testing.T) {
res, err := impl.UpdateNode(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId})
require.NoError(t, err)
require.IsType(t, drip.UpdateNode403JSONResponse{}, res)
f := withMiddleware(authz, "UpdateNode", impl.UpdateNode)
_, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId})
require.Error(t, err)
require.IsType(t, &echo.HTTPError{}, err)
assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden)
})
t.Run("ListNodeVersion", func(t *testing.T) {
res, err := impl.ListNodeVersions(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId})
require.NoError(t, err)
require.IsType(t, drip.ListNodeVersions403JSONResponse{}, res)
f := withMiddleware(authz, "ListNodeVersions", impl.ListNodeVersions)
_, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId})
require.Error(t, err)
require.IsType(t, &echo.HTTPError{}, err)
assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden)
})
t.Run("PublishNodeVersion", func(t *testing.T) {
res, err := impl.PublishNodeVersion(ctx, drip.PublishNodeVersionRequestObject{
f := withMiddleware(authz, "PublishNodeVersion", impl.PublishNodeVersion)
_, err := f(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: publisherId, NodeId: nodeId,
Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat},
})
require.NoError(t, err)
require.IsType(t, drip.PublishNodeVersion403JSONResponse{}, res)
require.Error(t, err)
require.IsType(t, &echo.HTTPError{}, err)
assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden)
})
t.Run("InstallNode", func(t *testing.T) {
res, err := impl.InstallNode(ctx, drip.InstallNodeRequestObject{NodeId: nodeId})
require.NoError(t, err)
require.IsType(t, drip.InstallNode403JSONResponse{}, res)
f := withMiddleware(authz, "InstallNode", impl.InstallNode)
_, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId})
require.Error(t, err)
require.IsType(t, &echo.HTTPError{}, err)
assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden)
})
t.Run("SearchNodes", func(t *testing.T) {
res, err := impl.SearchNodes(ctx, drip.SearchNodesRequestObject{
f := withMiddleware(authz, "SearchNodes", impl.SearchNodes)
res, err := f(ctx, drip.SearchNodesRequestObject{
Params: drip.SearchNodesParams{},
})
require.NoError(t, err)
require.IsType(t, drip.SearchNodes200JSONResponse{}, res)
require.Empty(t, res.(drip.SearchNodes200JSONResponse).Nodes)

res, err = impl.SearchNodes(ctx, drip.SearchNodesRequestObject{
res, err = f(ctx, drip.SearchNodesRequestObject{
Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)},
})
require.NoError(t, err)
Expand Down
118 changes: 0 additions & 118 deletions integration-tests/jwt_auth_wrapper_test.go

This file was deleted.

22 changes: 22 additions & 0 deletions integration-tests/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package integration
import (
"context"
"fmt"
"github.com/labstack/echo/v4"
"net"
"net/http"
"net/http/httptest"
"registry-backend/drip"
auth "registry-backend/server/middleware/authentication"

"registry-backend/ent"
Expand Down Expand Up @@ -117,5 +121,23 @@ func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration)
conn.Close()
return
}
}

func withMiddleware[R any, S any](mw drip.StrictMiddlewareFunc, opname string, h func(ctx context.Context, req R) (res S, err error)) func(ctx context.Context, req R) (res S, err error) {
handler := func(ctx echo.Context, request interface{}) (interface{}, error) {
return h(ctx.Request().Context(), request.(R))
}

return func(ctx context.Context, req R) (res S, err error) {
fakeReq := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
fakeRes := httptest.NewRecorder()
fakeCtx := echo.New().NewContext(fakeReq, fakeRes)

f := mw(handler, opname)
r, err := f(fakeCtx, req)
if r == nil {
return *new(S), err
}
return r.(S), err
}
}
9 changes: 3 additions & 6 deletions server/middleware/authentication/jwt_admin_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,20 @@ func JWTAdminAuthMiddleware(entClient *ent.Client, secret string) echo.Middlewar
// Get the Authorization header
header := c.Request().Header.Get("Authorization")
if header == "" {
// No Authorization header, pass to the next auth middleware
return next(c)
return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token")
}

// Extract the JWT token from the header
splitToken := strings.Split(header, "Bearer ")
if len(splitToken) != 2 {
// Invalid format, pass to the next auth middleware
return next(c)
return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token")
}
token := splitToken[1]

// Parse and validate the JWT token
tokenData, err := jwt.Parse(token, keyfunc)
if err != nil {
// Invalid token, pass to the next auth middleware
return next(c)
return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token")
}

// Extract claims from the token
Expand Down
70 changes: 70 additions & 0 deletions server/middleware/authentication/jwt_admin_auth_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,71 @@
package drip_authentication

import (
"net/http"
"net/http/httptest"
"registry-backend/ent"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestJWTAdminAllowlist(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

// Mock ent.Client
mockEntClient := &ent.Client{}

middleware := JWTAdminAuthMiddleware(mockEntClient, "secret")

tests := []struct {
name string
path string
method string
allowed bool
}{
{"OpenAPI GET", "/openapi", "GET", true},
{"Session DELETE", "/users/sessions", "DELETE", true},
{"Health GET", "/health", "GET", true},
{"VM ANY", "/vm", "POST", true},
{"VM ANY GET", "/vm", "GET", true},
{"Artifact POST", "/upload-artifact", "POST", true},
{"Git Commit POST", "/gitcommit", "POST", true},
{"Git Commit GET", "/gitcommit", "GET", true},
{"Branch GET", "/branch", "GET", true},
{"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true},
{"Publisher POST", "/publishers", "POST", true},
{"Unauthorized Path", "/nonexistent", "GET", true},
{"Get All Nodes", "/nodes", "GET", true},
{"Install Nodes", "/nodes/node-id/install", "GET", true},

{"Ban Publisher", "/publishers/publisher-id/ban", "POST", false},
{"Ban Node", "/publishers/publisher-id/nodes/node-id/ban", "POST", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
c.SetRequest(req)
handled := false
next := echo.HandlerFunc(func(c echo.Context) error {
handled = true
return nil
})
err := middleware(next)(c)
if tt.allowed {
assert.True(t, handled, "Request should be allowed through")
assert.Nil(t, err)
} else {
assert.False(t, handled, "Request should not be allowed through")
assert.NotNil(t, err)
httpError, ok := err.(*echo.HTTPError)
assert.True(t, ok, "Error should be HTTPError")
assert.Equal(t, http.StatusUnauthorized, httpError.Code)
}
})
}
}
Loading

0 comments on commit 192b63a

Please sign in to comment.