diff --git a/integration-tests/ban_test.go b/integration-tests/ban_test.go index b097635..812b259 100644 --- a/integration-tests/ban_test.go +++ b/integration-tests/ban_test.go @@ -12,7 +12,7 @@ import ( "registry-backend/ent/schema" "registry-backend/mock/gateways" "registry-backend/server/implementation" - drip_middleware "registry-backend/server/middleware" + drip_authorization "registry-backend/server/middleware/authorization" "testing" strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" @@ -51,9 +51,11 @@ func TestBan(t *testing.T) { return nil, errNotBanned } } + authorizationManager := drip_authorization.NewAuthorizationManager(client, impl.RegistryService) + authz := authorizationManager.AuthorizationMiddleware() wrapped := drip.NewStrictHandler(impl, []strictecho.StrictEchoMiddlewareFunc{ notBanned, - drip_middleware.AuthorizationMiddleware(client), + authz, }) t.Run("Publisher", func(t *testing.T) { @@ -148,17 +150,17 @@ func TestBan(t *testing.T) { fn: wrapped.CreatePublisher, }, { - name: "UpdatePublisher", + name: "DeleteNodeVersion", req: func(ctx context.Context) (req *http.Request) { payloadBuf := new(bytes.Buffer) - json.NewEncoder(payloadBuf).Encode(drip.UpdatePublisherJSONRequestBody{}) + json.NewEncoder(payloadBuf).Encode(drip.DeleteNodeVersionRequestObject{}) req = httptest.NewRequest(http.MethodPost, "/", payloadBuf).WithContext(ctx) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) return }, fn: func(ctx echo.Context) error { - return wrapped.UpdatePublisher(ctx, "") + return wrapped.DeleteNodeVersion(ctx, "", "", "") }, }, } @@ -225,7 +227,7 @@ func TestBan(t *testing.T) { nodeTags := []string{"test-node-tag"} icon := "https://wwww.github.com/test-icon.svg" githubUrl := "https://www.github.com/test-github-url" - _, err = impl.CreateNode(ctx, drip.CreateNodeRequestObject{ + _, err = withMiddleware(authz, "CreateNode", impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ PublisherId: publisherId, Body: &drip.Node{ Id: &nodeId, @@ -276,42 +278,53 @@ func TestBan(t *testing.T) { t.Run("Operate", func(t *testing.T) { t.Run("Get", func(t *testing.T) { - res, err := impl.GetNode(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) - require.NoError(t, err) - require.IsType(t, drip.GetNode403JSONResponse{}, res) + f := withMiddleware(authz, "GetNode", impl.GetNode) + _, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) }) t.Run("Update", func(t *testing.T) { - res, err := impl.UpdateNode(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) - require.NoError(t, err) - require.IsType(t, drip.UpdateNode403JSONResponse{}, res) + f := withMiddleware(authz, "UpdateNode", impl.UpdateNode) + _, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) }) t.Run("ListNodeVersion", func(t *testing.T) { - res, err := impl.ListNodeVersions(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) - require.NoError(t, err) - require.IsType(t, drip.ListNodeVersions403JSONResponse{}, res) + f := withMiddleware(authz, "ListNodeVersions", impl.ListNodeVersions) + _, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) }) t.Run("PublishNodeVersion", func(t *testing.T) { - res, err := impl.PublishNodeVersion(ctx, drip.PublishNodeVersionRequestObject{ + f := withMiddleware(authz, "PublishNodeVersion", impl.PublishNodeVersion) + _, err := f(ctx, drip.PublishNodeVersionRequestObject{ PublisherId: publisherId, NodeId: nodeId, Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat}, }) - require.NoError(t, err) - require.IsType(t, drip.PublishNodeVersion403JSONResponse{}, res) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) }) t.Run("InstallNode", func(t *testing.T) { - res, err := impl.InstallNode(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) - require.NoError(t, err) - require.IsType(t, drip.InstallNode403JSONResponse{}, res) + f := withMiddleware(authz, "InstallNode", impl.InstallNode) + _, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) }) t.Run("SearchNodes", func(t *testing.T) { - res, err := impl.SearchNodes(ctx, drip.SearchNodesRequestObject{ + f := withMiddleware(authz, "SearchNodes", impl.SearchNodes) + res, err := f(ctx, drip.SearchNodesRequestObject{ Params: drip.SearchNodesParams{}, }) require.NoError(t, err) require.IsType(t, drip.SearchNodes200JSONResponse{}, res) require.Empty(t, res.(drip.SearchNodes200JSONResponse).Nodes) - res, err = impl.SearchNodes(ctx, drip.SearchNodesRequestObject{ + res, err = f(ctx, drip.SearchNodesRequestObject{ Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)}, }) require.NoError(t, err) diff --git a/integration-tests/jwt_auth_wrapper_test.go b/integration-tests/jwt_auth_wrapper_test.go deleted file mode 100644 index f49bdae..0000000 --- a/integration-tests/jwt_auth_wrapper_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package integration - -import ( - "context" - "net/http" - "net/http/httptest" - "registry-backend/server/middleware/authentication" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" - "github.com/rs/zerolog/log" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestJwtAuthWrapper(t *testing.T) { - clientCtx := context.Background() - client, postgresContainer := setupDB(t, clientCtx) - // Cleanup - defer func() { - if err := postgresContainer.Terminate(clientCtx); err != nil { - log.Ctx(clientCtx).Error().Msgf("failed to terminate container: %s", err) - } - }() - - jwtsecret := "test" - - newHandler := func() (echo.HandlerFunc, *bool) { - invoked := false - - return func(c echo.Context) error { - invoked = true - return nil - }, &invoked - } - - e := echo.New() - jwtmw := drip_authentication.JWTAdminAuthMiddleware(client, jwtsecret) - - t.Run("No JWT", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/publishers/test-publisher/ban", nil) - c := e.NewContext(req, httptest.NewRecorder()) - - next, nextivc := newHandler() - err := jwtmw(next)(c) - require.NoError(t, err) - assert.True(t, *nextivc, "should invoke the handler") - }) - - t.Run("Invalid JWT", func(t *testing.T) { - _, user := setUpTest(client) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": user.ID, - "nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(), - }) - // Sign and get the complete encoded token as a string using the secret - tokenString, err := token.SignedString([]byte("invalid")) - require.NoError(t, err, "should not return error") - - req := httptest.NewRequest(http.MethodGet, "/publishers/test-publisher/ban", nil) - req.Header.Set("Authorization", "Bearer "+tokenString) - c := e.NewContext(req, httptest.NewRecorder()) - next, nextivc := newHandler() - err = jwtmw(next)(c) - require.NoError(t, err, "should not return error") - assert.True(t, *nextivc, "should invoke the wrapped middleware") - }) - - t.Run("Valid JWT Invalid User", func(t *testing.T) { - _, user := setUpTest(client) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": user.ID + "Invalid", - "nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(), - }) - // Sign and get the complete encoded token as a string using the secret - tokenString, err := token.SignedString([]byte(jwtsecret)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/publishers/test-publisher/ban", nil) - req.Header.Set("Authorization", "Bearer "+tokenString) - c := e.NewContext(req, httptest.NewRecorder()) - next, nextivc := newHandler() - err = jwtmw(next)(c) - require.Error(t, err, "should return error") - assert.False(t, *nextivc, "should not invoke the handler") - }) - - t.Run("Valid JWT", func(t *testing.T) { - _, user := setUpTest(client) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "sub": user.ID, - "nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(), - }) - // Sign and get the complete encoded token as a string using the secret - tokenString, err := token.SignedString([]byte(jwtsecret)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/publishers/test-publisher/ban", nil) - req.Header.Set("Authorization", "Bearer "+tokenString) - c := e.NewContext(req, httptest.NewRecorder()) - next, nextivc := newHandler() - err = jwtmw(next)(c) - require.NoError(t, err, "should not return error") - assert.True(t, *nextivc, "should invoke the handler") - }) - - t.Run("Non-Protected Endpoint", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/non-protected-endpoint", nil) - c := e.NewContext(req, httptest.NewRecorder()) - - next, nextivc := newHandler() - err := jwtmw(next)(c) - require.NoError(t, err) - assert.True(t, *nextivc, "should invoke the handler for non-protected endpoint") - }) -} diff --git a/integration-tests/test_util.go b/integration-tests/test_util.go index 1e7d04a..6555865 100644 --- a/integration-tests/test_util.go +++ b/integration-tests/test_util.go @@ -3,7 +3,11 @@ package integration import ( "context" "fmt" + "github.com/labstack/echo/v4" "net" + "net/http" + "net/http/httptest" + "registry-backend/drip" auth "registry-backend/server/middleware/authentication" "registry-backend/ent" @@ -117,5 +121,23 @@ func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration) conn.Close() return } +} + +func withMiddleware[R any, S any](mw drip.StrictMiddlewareFunc, opname string, h func(ctx context.Context, req R) (res S, err error)) func(ctx context.Context, req R) (res S, err error) { + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return h(ctx.Request().Context(), request.(R)) + } + return func(ctx context.Context, req R) (res S, err error) { + fakeReq := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + fakeRes := httptest.NewRecorder() + fakeCtx := echo.New().NewContext(fakeReq, fakeRes) + + f := mw(handler, opname) + r, err := f(fakeCtx, req) + if r == nil { + return *new(S), err + } + return r.(S), err + } } diff --git a/server/middleware/authentication/jwt_admin_auth.go b/server/middleware/authentication/jwt_admin_auth.go index 8c8b9d8..0b1b886 100644 --- a/server/middleware/authentication/jwt_admin_auth.go +++ b/server/middleware/authentication/jwt_admin_auth.go @@ -52,23 +52,20 @@ func JWTAdminAuthMiddleware(entClient *ent.Client, secret string) echo.Middlewar // Get the Authorization header header := c.Request().Header.Get("Authorization") if header == "" { - // No Authorization header, pass to the next auth middleware - return next(c) + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") } // Extract the JWT token from the header splitToken := strings.Split(header, "Bearer ") if len(splitToken) != 2 { - // Invalid format, pass to the next auth middleware - return next(c) + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") } token := splitToken[1] // Parse and validate the JWT token tokenData, err := jwt.Parse(token, keyfunc) if err != nil { - // Invalid token, pass to the next auth middleware - return next(c) + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") } // Extract claims from the token diff --git a/server/middleware/authentication/jwt_admin_auth_test.go b/server/middleware/authentication/jwt_admin_auth_test.go index d6da53a..10c05fe 100644 --- a/server/middleware/authentication/jwt_admin_auth_test.go +++ b/server/middleware/authentication/jwt_admin_auth_test.go @@ -1 +1,71 @@ package drip_authentication + +import ( + "net/http" + "net/http/httptest" + "registry-backend/ent" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestJWTAdminAllowlist(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Mock ent.Client + mockEntClient := &ent.Client{} + + middleware := JWTAdminAuthMiddleware(mockEntClient, "secret") + + tests := []struct { + name string + path string + method string + allowed bool + }{ + {"OpenAPI GET", "/openapi", "GET", true}, + {"Session DELETE", "/users/sessions", "DELETE", true}, + {"Health GET", "/health", "GET", true}, + {"VM ANY", "/vm", "POST", true}, + {"VM ANY GET", "/vm", "GET", true}, + {"Artifact POST", "/upload-artifact", "POST", true}, + {"Git Commit POST", "/gitcommit", "POST", true}, + {"Git Commit GET", "/gitcommit", "GET", true}, + {"Branch GET", "/branch", "GET", true}, + {"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true}, + {"Publisher POST", "/publishers", "POST", true}, + {"Unauthorized Path", "/nonexistent", "GET", true}, + {"Get All Nodes", "/nodes", "GET", true}, + {"Install Nodes", "/nodes/node-id/install", "GET", true}, + + {"Ban Publisher", "/publishers/publisher-id/ban", "POST", false}, + {"Ban Node", "/publishers/publisher-id/nodes/node-id/ban", "POST", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + c.SetRequest(req) + handled := false + next := echo.HandlerFunc(func(c echo.Context) error { + handled = true + return nil + }) + err := middleware(next)(c) + if tt.allowed { + assert.True(t, handled, "Request should be allowed through") + assert.Nil(t, err) + } else { + assert.False(t, handled, "Request should not be allowed through") + assert.NotNil(t, err) + httpError, ok := err.(*echo.HTTPError) + assert.True(t, ok, "Error should be HTTPError") + assert.Equal(t, http.StatusUnauthorized, httpError.Code) + } + }) + } +} diff --git a/server/middleware/authorization/authorization_manager.go b/server/middleware/authorization/authorization_manager.go new file mode 100644 index 0000000..0405075 --- /dev/null +++ b/server/middleware/authorization/authorization_manager.go @@ -0,0 +1,130 @@ +package drip_authorization + +import ( + "github.com/labstack/echo/v4" + strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" + "github.com/rs/zerolog/log" + "net/http" + "registry-backend/drip" + "registry-backend/ent" + "registry-backend/ent/schema" + drip_authentication "registry-backend/server/middleware/authentication" + drip_services "registry-backend/services/registry" +) + +// AuthorizationManager manages authorization-related tasks +type AuthorizationManager struct { + EntClient *ent.Client + RegistryService *drip_services.RegistryService +} + +// NewAuthorizationManager creates a new instance of AuthorizationManager +func NewAuthorizationManager( + entClient *ent.Client, registryService *drip_services.RegistryService) *AuthorizationManager { + return &AuthorizationManager{ + EntClient: entClient, + RegistryService: registryService, + } +} + +// assertUserBanned checks if the user is banned +func (m *AuthorizationManager) assertUserBanned() drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + v := ctx.Value(drip_authentication.UserContextKey) + userDetails, ok := v.(*drip_authentication.UserDetails) + if !ok { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + + u, err := m.EntClient.User.Get(ctx, userDetails.ID) + if err != nil { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + + if u.Status == schema.UserStatusTypeBanned { + return nil, echo.NewHTTPError(http.StatusForbidden, "user/publisher is banned") + } + + return f(c, request) + } + } +} + +// assertPublisherPermission checks if the user has the required permissions for the publisher +func (m *AuthorizationManager) assertPublisherPermission( + permissions []schema.PublisherPermissionType, extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + v := ctx.Value(drip_authentication.UserContextKey) + userDetails, ok := v.(*drip_authentication.UserDetails) + if !ok { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + publisherID := extractor(request) + + log.Ctx(ctx).Info().Msgf("Checking if user ID %s has permission "+ + "to update publisher ID %s", userDetails.ID, publisherID) + err = m.RegistryService.AssertPublisherPermissions(ctx, m.EntClient, publisherID, userDetails.ID, permissions) + switch { + case ent.IsNotFound(err): + log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", publisherID) + return nil, echo.NewHTTPError(http.StatusNotFound, "Publisher Not Found") + + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf("Permission denied for user ID %s on "+ + "publisher ID %s w/ err: %v", userDetails.ID, publisherID, err) + return nil, echo.NewHTTPError(http.StatusForbidden, "Permission denied") + + case err != nil: + log.Ctx(ctx).Error().Msgf("Failed to assert publisher "+ + "permission %s w/ err: %v", publisherID, err) + return nil, err + } + + return f(c, request) + } + } +} + +// assertNodeBanned checks if the node is banned +func (m *AuthorizationManager) assertNodeBanned(extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + nodeID := extractor(request) + err = m.RegistryService.AssertNodeBanned(ctx, m.EntClient, nodeID) + switch { + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf("Node %s banned", nodeID) + return nil, echo.NewHTTPError(http.StatusForbidden, "Node Banned") + + case err != nil: + log.Ctx(ctx).Error().Msgf("Failed to assert node ban status %s w/ err: %v", nodeID, err) + return nil, err + } + + return f(c, request) + } + } +} + +// assertPublisherBanned checks if the publisher is banned +func (m *AuthorizationManager) assertPublisherBanned(extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + publisherID := extractor(request) + + pub, _ := m.RegistryService.GetPublisher(ctx, m.EntClient, publisherID) + if pub != nil && pub.Status == schema.PublisherStatusTypeBanned { + log.Ctx(ctx).Error().Msgf("Publisher %s banned", publisherID) + return nil, echo.NewHTTPError(http.StatusForbidden, "Publisher Banned") + } + + return f(c, request) + } + } +} diff --git a/server/middleware/authorization/authorization_middleware.go b/server/middleware/authorization/authorization_middleware.go new file mode 100644 index 0000000..f3ebdf9 --- /dev/null +++ b/server/middleware/authorization/authorization_middleware.go @@ -0,0 +1,176 @@ +package drip_authorization + +import ( + "registry-backend/drip" + "registry-backend/ent/schema" + "slices" + + strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" +) + +func (m *AuthorizationManager) AuthorizationMiddleware() drip.StrictMiddlewareFunc { + subMiddlewares := map[string][]drip.StrictMiddlewareFunc{ + "CreatePublisher": { + m.assertUserBanned(), + }, + "UpdatePublisher": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdatePublisherRequestObject).PublisherId + }, + ), + }, + "CreateNode": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.CreateNodeRequestObject).PublisherId + }, + ), + }, + "DeleteNode": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.DeleteNodeRequestObject).PublisherId + }, + ), + }, + "UpdateNode": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.UpdateNodeRequestObject).NodeId + }, + ), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeRequestObject).PublisherId + }, + ), + }, + "GetNode": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.GetNodeRequestObject).NodeId + }, + ), + }, + "PublishNodeVersion": { + m.assertPublisherBanned( + func(req interface{}) (publisherID string) { + return req.(drip.PublishNodeVersionRequestObject).PublisherId + }), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.PublishNodeVersionRequestObject).NodeId + }, + ), + }, + "UpdateNodeVersion": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.UpdateNodeVersionRequestObject).NodeId + }, + ), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeVersionRequestObject).PublisherId + }, + ), + }, + "DeleteNodeVersion": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.DeleteNodeVersionRequestObject).NodeId + }, + ), + }, + "GetNodeVersion": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.GetNodeVersionRequestObject).NodeId + }, + ), + }, + "ListNodeVersions": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.ListNodeVersionsRequestObject).NodeId + }, + ), + }, + "InstallNode": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.InstallNodeRequestObject).NodeId + }, + ), + }, + "CreatePersonalAccessToken": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.CreatePersonalAccessTokenRequestObject).PublisherId + }, + ), + }, + "DeletePersonalAccessToken": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.DeletePersonalAccessTokenRequestObject).PublisherId + }, + ), + }, + "ListPersonalAccessTokens": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.ListPersonalAccessTokensRequestObject).PublisherId + }, + ), + }, + "GetPermissionOnPublisherNodes": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.GetPermissionOnPublisherNodesRequestObject).PublisherId + }, + ), + }, + "GetPermissionOnPublisher": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.GetPermissionOnPublisherRequestObject).PublisherId + }, + ), + }, + } + for _, v := range subMiddlewares { + slices.Reverse(v) + } + + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + middlewares, ok := subMiddlewares[operationID] + if !ok { + return f + } + + for _, mw := range middlewares { + f = mw(f, operationID) + } + return f + } +} diff --git a/server/middleware/authorization_middleware.go b/server/middleware/authorization_middleware.go deleted file mode 100644 index b75c203..0000000 --- a/server/middleware/authorization_middleware.go +++ /dev/null @@ -1,54 +0,0 @@ -package drip_middleware - -import ( - "github.com/labstack/echo/v4" - strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" - "net/http" - "registry-backend/drip" - "registry-backend/ent" - "registry-backend/ent/schema" - "registry-backend/server/middleware/authentication" -) - -func AuthorizationMiddleware(entClient *ent.Client) drip.StrictMiddlewareFunc { - restrictedOperationsForBannedUsers := map[string]struct{}{ - "CreatePublisher": {}, - "UpdatePublisher": {}, - "CreateNode": {}, - "DeleteNode": {}, - "UpdateNode": {}, - //"PublishNodeVersion": {}, - "UpdateNodeVersion": {}, - "DeleteNodeVersion": {}, - "CreatePersonalAccessToken": {}, - "DeletePersonalAccessToken": {}, - } - return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { - return func(c echo.Context, request interface{}) (response interface{}, err error) { - // Bypass authorization for non-write operations - if _, ok := restrictedOperationsForBannedUsers[operationID]; !ok { - return f(c, request) - } - - // Get user details from the context - ctx := c.Request().Context() - v := ctx.Value(drip_authentication.UserContextKey) - userDetails, ok := v.(*drip_authentication.UserDetails) - if !ok { - return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") - } - - u, err := entClient.User.Get(ctx, userDetails.ID) - if err != nil { - return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") - } - - if _, ok := restrictedOperationsForBannedUsers[operationID]; ok && u.Status == schema.UserStatusTypeBanned { - return nil, echo.NewHTTPError(http.StatusForbidden, "user/publisher is banned") - } - - return f(c, request) - } - - } -} diff --git a/server/server.go b/server/server.go index 99e6984..ca41304 100644 --- a/server/server.go +++ b/server/server.go @@ -12,6 +12,7 @@ import ( "registry-backend/server/implementation" drip_middleware "registry-backend/server/middleware" drip_authentication "registry-backend/server/middleware/authentication" + drip_authorization "registry-backend/server/middleware/authorization" "strings" monitoring "cloud.google.com/go/monitoring/apiv3/v2" @@ -95,8 +96,9 @@ func (s *Server) Start() error { impl := implementation.NewStrictServerImplementation(s.Client, s.Config, storageService, slackService) // Define middlewares in the order of operations + authorizationManager := drip_authorization.NewAuthorizationManager(s.Client, impl.RegistryService) middlewares := []generated.StrictMiddlewareFunc{ - drip_middleware.AuthorizationMiddleware(s.Client), + authorizationManager.AuthorizationMiddleware(), } wrapped := generated.NewStrictHandler(impl, middlewares)