From 9adc3631611b7f7b53e4a69557dae449c23d9474 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Sat, 4 Jan 2025 13:16:38 -0500 Subject: [PATCH] Refactor code: request/response logging middleware, entity, standard logging for troubleshooting --- common/string.go | 23 ++++++ entity/node.go | 19 +++++ entity/node_version.go | 24 ++++++ entity/publisher.go | 5 ++ common/types.go => entity/token.go | 2 +- server/implementation/registry.go | 65 +++------------- .../authentication/firebase_auth.go | 2 +- .../authentication/firebase_auth_test.go | 2 +- .../authentication/jwt_admin_auth.go | 2 +- .../authentication/jwt_admin_auth_test.go | 2 +- .../authentication/service_account_auth.go | 2 +- .../service_account_auth_test.go | 2 +- server/middleware/error_logger.go | 24 ------ server/middleware/metric/metric.go | 2 +- server/middleware/metric/metric_middleware.go | 2 +- server/middleware/request_logger.go | 38 ++++++++++ server/middleware/response_logger.go | 63 ++++++++++++++++ server/middleware/tracing_middleware.go | 2 +- server/server.go | 57 ++++---------- services/registry/registry_svc.go | 74 ++----------------- 20 files changed, 215 insertions(+), 197 deletions(-) create mode 100644 common/string.go create mode 100644 entity/node.go create mode 100644 entity/node_version.go create mode 100644 entity/publisher.go rename common/types.go => entity/token.go (95%) delete mode 100644 server/middleware/error_logger.go create mode 100644 server/middleware/request_logger.go create mode 100644 server/middleware/response_logger.go diff --git a/common/string.go b/common/string.go new file mode 100644 index 0000000..fa95638 --- /dev/null +++ b/common/string.go @@ -0,0 +1,23 @@ +package common + +import ( + "encoding/json" + "fmt" +) + +func PrettifyJSON(input string) (string, error) { + // First unmarshal the input string into a generic interface{} + var temp interface{} + err := json.Unmarshal([]byte(input), &temp) + if err != nil { + return "", fmt.Errorf("invalid JSON input: %v", err) + } + + // Marshal back to JSON with indentation + pretty, err := json.MarshalIndent(temp, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal JSON: %v", err) + } + + return string(pretty), nil +} diff --git a/entity/node.go b/entity/node.go new file mode 100644 index 0000000..d0ee5ea --- /dev/null +++ b/entity/node.go @@ -0,0 +1,19 @@ +package entity + +import "registry-backend/ent" + +// NodeFilter holds optional parameters for filtering node results +type NodeFilter struct { + PublisherID string + Search string + IncludeBanned bool +} + +// ListNodesResult is the structure that holds the paginated result of nodes +type ListNodesResult struct { + Total int `json:"total"` + Nodes []*ent.Node `json:"nodes"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalPages int `json:"totalPages"` +} diff --git a/entity/node_version.go b/entity/node_version.go new file mode 100644 index 0000000..5bd7ca2 --- /dev/null +++ b/entity/node_version.go @@ -0,0 +1,24 @@ +package entity + +import ( + "registry-backend/ent" + "registry-backend/ent/schema" + "time" +) + +type NodeVersionFilter struct { + NodeId string + Status []schema.NodeVersionStatus + IncludeStatusReason bool + MinAge time.Duration + PageSize int + Page int +} + +type ListNodeVersionsResult struct { + Total int `json:"total"` + NodeVersions []*ent.NodeVersion `json:"nodes"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalPages int `json:"totalPages"` +} diff --git a/entity/publisher.go b/entity/publisher.go new file mode 100644 index 0000000..44d27bc --- /dev/null +++ b/entity/publisher.go @@ -0,0 +1,5 @@ +package entity + +type PublisherFilter struct { + UserID string +} diff --git a/common/types.go b/entity/token.go similarity index 95% rename from common/types.go rename to entity/token.go index ffd64d1..6fe395b 100644 --- a/common/types.go +++ b/entity/token.go @@ -1,6 +1,6 @@ // File: common/types.go -package common +package entity type InviteTokenStatus string diff --git a/server/implementation/registry.go b/server/implementation/registry.go index 9b77d45..8d1cd2d 100644 --- a/server/implementation/registry.go +++ b/server/implementation/registry.go @@ -8,6 +8,7 @@ import ( "registry-backend/ent" "registry-backend/ent/publisher" "registry-backend/ent/schema" + "registry-backend/entity" "registry-backend/mapper" drip_services "registry-backend/services/registry" "time" @@ -20,7 +21,6 @@ import ( func (impl *DripStrictServerImplementation) ListPublishersForUser( ctx context.Context, request drip.ListPublishersForUserRequestObject) (drip.ListPublishersForUserResponseObject, error) { - log.Ctx(ctx).Debug().Msg("ListPublishersForUser called.") // Extract user ID from context userId, err := mapper.GetUserIDFromContext(ctx) @@ -31,7 +31,7 @@ func (impl *DripStrictServerImplementation) ListPublishersForUser( // Call the service to list publishers log.Ctx(ctx).Info().Msgf("Fetching publishers for user %s", userId) - publishers, err := impl.RegistryService.ListPublishers(ctx, impl.Client, &drip_services.PublisherFilter{ + publishers, err := impl.RegistryService.ListPublishers(ctx, impl.Client, &entity.PublisherFilter{ UserID: userId, }) if err != nil { @@ -53,9 +53,6 @@ func (impl *DripStrictServerImplementation) ListPublishersForUser( func (s *DripStrictServerImplementation) ValidatePublisher( ctx context.Context, request drip.ValidatePublisherRequestObject) (drip.ValidatePublisherResponseObject, error) { - // Log the incoming request for validation - log.Ctx(ctx).Info().Msgf("ValidatePublisher request with username: %s", request.Params.Username) - // Check if the username is empty name := request.Params.Username if name == "" { @@ -92,9 +89,6 @@ func (s *DripStrictServerImplementation) ValidatePublisher( func (s *DripStrictServerImplementation) CreatePublisher( ctx context.Context, request drip.CreatePublisherRequestObject) (drip.CreatePublisherResponseObject, error) { - // Log the incoming request - log.Ctx(ctx).Info().Msgf("CreatePublisher request called") - // Extract user ID from context userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { @@ -104,7 +98,7 @@ func (s *DripStrictServerImplementation) CreatePublisher( log.Ctx(ctx).Info().Msgf("Checking if user ID %s has reached the maximum number of publishers", userId) userPublishers, err := s.RegistryService.ListPublishers( - ctx, s.Client, &drip_services.PublisherFilter{UserID: userId}) + ctx, s.Client, &entity.PublisherFilter{UserID: userId}) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to list publishers for user ID %s w/ err: %v", userId, err) return drip.CreatePublisher500JSONResponse{Message: "Failed to list publishers", Error: err.Error()}, err @@ -164,10 +158,9 @@ func (s *DripStrictServerImplementation) DeletePublisher( func (s *DripStrictServerImplementation) GetPublisher( ctx context.Context, request drip.GetPublisherRequestObject) (drip.GetPublisherResponseObject, error) { - publisherId := request.PublisherId - log.Ctx(ctx).Info().Msgf("GetPublisher request received for publisher ID: %s", publisherId) - publisher, err := s.RegistryService.GetPublisher(ctx, s.Client, publisherId) + publisherId := request.PublisherId + publisher, err := s.RegistryService.GetPublisher(ctx, s.Client, request.PublisherId) if ent.IsNotFound(err) { log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", publisherId) return drip.GetPublisher404JSONResponse{Message: "Publisher not found"}, nil @@ -183,9 +176,7 @@ func (s *DripStrictServerImplementation) GetPublisher( func (s *DripStrictServerImplementation) UpdatePublisher( ctx context.Context, request drip.UpdatePublisherRequestObject) (drip.UpdatePublisherResponseObject, error) { - log.Ctx(ctx).Info().Msgf("UpdatePublisher called with publisher ID: %s", request.PublisherId) - log.Ctx(ctx).Info().Msgf("Updating publisher with ID %s", request.PublisherId) updateOne := mapper.ApiUpdatePublisherToUpdateFields(request.PublisherId, request.Body, s.Client) updatedPublisher, err := s.RegistryService.UpdatePublisher(ctx, s.Client, updateOne) if err != nil { @@ -199,7 +190,6 @@ func (s *DripStrictServerImplementation) UpdatePublisher( func (s *DripStrictServerImplementation) CreateNode( ctx context.Context, request drip.CreateNodeRequestObject) (drip.CreateNodeResponseObject, error) { - log.Ctx(ctx).Info().Msgf("CreateNode called with publisher ID: %s", request.PublisherId) node, err := s.RegistryService.CreateNode(ctx, s.Client, request.PublisherId, request.Body) if mapper.IsErrorBadRequest(err) || ent.IsConstraintError(err) { @@ -218,10 +208,9 @@ func (s *DripStrictServerImplementation) CreateNode( func (s *DripStrictServerImplementation) ListNodesForPublisher( ctx context.Context, request drip.ListNodesForPublisherRequestObject) (drip.ListNodesForPublisherResponseObject, error) { - log.Ctx(ctx).Info().Msgf("ListNodesForPublisher request received for publisher ID: %s", request.PublisherId) nodeResults, err := s.RegistryService.ListNodes( - ctx, s.Client /*page=*/, 1 /*limit=*/, 10, &drip_services.NodeFilter{ + ctx, s.Client /*page=*/, 1 /*limit=*/, 10, &entity.NodeFilter{ PublisherID: request.PublisherId, }) if err != nil { @@ -257,8 +246,6 @@ func (s *DripStrictServerImplementation) ListAllNodes( log.Ctx(ctx).Error().Msgf("Failed to track event w/ err: %v", err) } - log.Ctx(ctx).Info().Msg("ListAllNodes request received") - // Set default values for pagination parameters page := 1 if request.Params.Page != nil { @@ -270,7 +257,7 @@ func (s *DripStrictServerImplementation) ListAllNodes( } // Initialize the node filter - filter := &drip_services.NodeFilter{} + filter := &entity.NodeFilter{} if request.Params.IncludeBanned != nil { filter.IncludeBanned = *request.Params.IncludeBanned } @@ -321,7 +308,6 @@ func (s *DripStrictServerImplementation) ListAllNodes( // SearchNodes implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, request drip.SearchNodesRequestObject) (drip.SearchNodesResponseObject, error) { - log.Ctx(ctx).Info().Msg("SearchNodes request received") // Set default values for pagination parameters page := 1 @@ -333,7 +319,7 @@ func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, reques limit = *request.Params.Limit } - f := &drip_services.NodeFilter{} + f := &entity.NodeFilter{} if request.Params.Search != nil { f.Search = *request.Params.Search } @@ -386,8 +372,6 @@ func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, reques func (s *DripStrictServerImplementation) DeleteNode( ctx context.Context, request drip.DeleteNodeRequestObject) (drip.DeleteNodeResponseObject, error) { - log.Ctx(ctx).Info().Msgf("DeleteNode request received for node ID: %s", request.NodeId) - err := s.RegistryService.DeleteNode(ctx, s.Client, request.NodeId) if err != nil && !ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Failed to delete node %s w/ err: %v", request.NodeId, err) @@ -400,7 +384,6 @@ func (s *DripStrictServerImplementation) DeleteNode( func (s *DripStrictServerImplementation) GetNode( ctx context.Context, request drip.GetNodeRequestObject) (drip.GetNodeResponseObject, error) { - log.Ctx(ctx).Info().Msgf("GetNode request received for node ID: %s", request.NodeId) node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) if ent.IsNotFound(err) { @@ -429,8 +412,6 @@ func (s *DripStrictServerImplementation) GetNode( func (s *DripStrictServerImplementation) UpdateNode( ctx context.Context, request drip.UpdateNodeRequestObject) (drip.UpdateNodeResponseObject, error) { - log.Ctx(ctx).Info().Msgf("UpdateNode request received for node ID: %s", request.NodeId) - updateOneFunc := func(client *ent.Client) *ent.NodeUpdateOne { return mapper.ApiUpdateNodeToUpdateFields(request.NodeId, request.Body, client) } @@ -450,11 +431,10 @@ func (s *DripStrictServerImplementation) UpdateNode( func (s *DripStrictServerImplementation) ListNodeVersions( ctx context.Context, request drip.ListNodeVersionsRequestObject) (drip.ListNodeVersionsResponseObject, error) { - log.Ctx(ctx).Info().Msgf("ListNodeVersions request received for node ID: %s", request.NodeId) apiStatus := mapper.ApiNodeVersionStatusesToDbNodeVersionStatuses(request.Params.Statuses) - nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &drip_services.NodeVersionFilter{ + nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &entity.NodeVersionFilter{ NodeId: request.NodeId, Status: apiStatus, IncludeStatusReason: mapper.BoolPtrToBool(request.Params.IncludeStatusReason), @@ -475,7 +455,6 @@ func (s *DripStrictServerImplementation) ListNodeVersions( func (s *DripStrictServerImplementation) PublishNodeVersion( ctx context.Context, request drip.PublishNodeVersionRequestObject) (drip.PublishNodeVersionResponseObject, error) { - log.Ctx(ctx).Info().Msgf("PublishNodeVersion request received for node ID: %s", request.NodeId) // Check if node exists, create if not node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) @@ -532,9 +511,6 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( func (s *DripStrictServerImplementation) UpdateNodeVersion( ctx context.Context, request drip.UpdateNodeVersionRequestObject) (drip.UpdateNodeVersionResponseObject, error) { - log.Ctx(ctx).Info().Msgf("UpdateNodeVersion request received for node ID: "+ - "%s, version ID: %s", request.NodeId, request.VersionId) - // Update node version updateOne := mapper.ApiUpdateNodeVersionToUpdateFields(request.VersionId, request.Body, s.Client) version, err := s.RegistryService.UpdateNodeVersion(ctx, s.Client, updateOne) @@ -557,8 +533,6 @@ func (s *DripStrictServerImplementation) UpdateNodeVersion( // PostNodeVersionReview implements drip.StrictServerInterface. func (s *DripStrictServerImplementation) PostNodeReview(ctx context.Context, request drip.PostNodeReviewRequestObject) (drip.PostNodeReviewResponseObject, error) { - log.Ctx(ctx).Info().Msgf("PostNodeReview request received for "+ - "node ID: %s", request.NodeId) if request.Params.Star < 1 || request.Params.Star > 5 { log.Ctx(ctx).Error().Msgf("Invalid star received: %d", request.Params.Star) @@ -585,8 +559,6 @@ func (s *DripStrictServerImplementation) PostNodeReview(ctx context.Context, req func (s *DripStrictServerImplementation) DeleteNodeVersion( ctx context.Context, request drip.DeleteNodeVersionRequestObject) (drip.DeleteNodeVersionResponseObject, error) { - log.Ctx(ctx).Info().Msgf("DeleteNodeVersion request received for node ID: "+ - "%s, version ID: %s", request.NodeId, request.VersionId) // Directly return the message that node versions cannot be deleted errMessage := "Cannot delete node versions. Please deprecate it instead." @@ -598,8 +570,6 @@ func (s *DripStrictServerImplementation) DeleteNodeVersion( func (s *DripStrictServerImplementation) GetNodeVersion( ctx context.Context, request drip.GetNodeVersionRequestObject) (drip.GetNodeVersionResponseObject, error) { - log.Ctx(ctx).Info().Msgf("GetNodeVersion request received for "+ - "node ID: %s, version ID: %s", request.NodeId, request.VersionId) nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, request.VersionId) if ent.IsNotFound(err) { @@ -622,7 +592,6 @@ func (s *DripStrictServerImplementation) GetNodeVersion( func (s *DripStrictServerImplementation) ListPersonalAccessTokens( ctx context.Context, request drip.ListPersonalAccessTokensRequestObject) (drip.ListPersonalAccessTokensResponseObject, error) { - log.Ctx(ctx).Info().Msgf("ListPersonalAccessTokens request received for publisher ID: %s", request.PublisherId) // List personal access tokens personalAccessTokens, err := s.RegistryService.ListPersonalAccessTokens(ctx, s.Client, request.PublisherId) @@ -645,9 +614,6 @@ func (s *DripStrictServerImplementation) ListPersonalAccessTokens( func (s *DripStrictServerImplementation) CreatePersonalAccessToken( ctx context.Context, request drip.CreatePersonalAccessTokenRequestObject) (drip.CreatePersonalAccessTokenResponseObject, error) { - log.Ctx(ctx).Info().Msgf("CreatePersonalAccessToken request received "+ - "for publisher ID: %s", request.PublisherId) - // Create personal access token description := "" if request.Body.Description != nil { @@ -672,8 +638,6 @@ func (s *DripStrictServerImplementation) CreatePersonalAccessToken( func (s *DripStrictServerImplementation) DeletePersonalAccessToken( ctx context.Context, request drip.DeletePersonalAccessTokenRequestObject) (drip.DeletePersonalAccessTokenResponseObject, error) { - log.Ctx(ctx).Info().Msgf("DeletePersonalAccessToken request received for token ID: %s", request.TokenId) - // Retrieve user ID from context userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { @@ -713,8 +677,6 @@ func (s *DripStrictServerImplementation) DeletePersonalAccessToken( func (s *DripStrictServerImplementation) InstallNode( ctx context.Context, request drip.InstallNodeRequestObject) (drip.InstallNodeResponseObject, error) { // TODO(robinhuang): Refactor to separate class - log.Ctx(ctx).Info().Msgf("InstallNode request received for node ID: %s", request.NodeId) - // Get node node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) if ent.IsNotFound(err) { @@ -931,7 +893,7 @@ func (s *DripStrictServerImplementation) SecurityScan( maxNodes = *request.Params.MaxNodes } - nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &drip_services.NodeVersionFilter{ + nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &entity.NodeVersionFilter{ Status: []schema.NodeVersionStatus{schema.NodeVersionStatusPending}, MinAge: minAge, PageSize: maxNodes, @@ -955,7 +917,6 @@ func (s *DripStrictServerImplementation) SecurityScan( func (s *DripStrictServerImplementation) ListAllNodeVersions( ctx context.Context, request drip.ListAllNodeVersionsRequestObject) (drip.ListAllNodeVersionsResponseObject, error) { - log.Ctx(ctx).Info().Msgf("ListAllNodeVersions request received %+v", request.Params) // Default values for pagination page := 1 @@ -978,7 +939,7 @@ func (s *DripStrictServerImplementation) ListAllNodeVersions( pageSize = *request.Params.PageSize } - f := &drip_services.NodeVersionFilter{ + f := &entity.NodeVersionFilter{ Page: page, PageSize: pageSize, IncludeStatusReason: mapper.BoolPtrToBool(request.Params.IncludeStatusReason), @@ -1032,7 +993,6 @@ func (s *DripStrictServerImplementation) ListAllNodeVersions( } func (s *DripStrictServerImplementation) ReindexNodes(ctx context.Context, request drip.ReindexNodesRequestObject) (res drip.ReindexNodesResponseObject, err error) { - log.Ctx(ctx).Info().Msg("ReindexNodes request received") err = s.RegistryService.ReindexAllNodes(ctx, s.Client) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to reindex all nodes w/ err: %v", err) @@ -1045,7 +1005,6 @@ func (s *DripStrictServerImplementation) ReindexNodes(ctx context.Context, reque // CreateComfyNodes bulk-creates comfy-nodes for a node version func (impl *DripStrictServerImplementation) CreateComfyNodes(ctx context.Context, request drip.CreateComfyNodesRequestObject) (res drip.CreateComfyNodesResponseObject, err error) { - log.Ctx(ctx).Info().Msg("CreateComfyNodes request received") err = impl.RegistryService.CreateComfyNodes(ctx, impl.Client, request.NodeId, request.Version, *request.Body.Nodes) if ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Node or node version not found w/ err: %v", err) @@ -1066,7 +1025,6 @@ func (impl *DripStrictServerImplementation) CreateComfyNodes(ctx context.Context // GetComfyNode return a certain comfy-node of a certain node version func (impl *DripStrictServerImplementation) GetComfyNode(ctx context.Context, request drip.GetComfyNodeRequestObject) (res drip.GetComfyNodeResponseObject, err error) { - log.Ctx(ctx).Info().Msg("GetComfyNode request received") n, err := impl.RegistryService.GetComfyNode(ctx, impl.Client, request.NodeId, request.Version, request.ComfyNodeId) if ent.IsNotFound(err) { @@ -1086,7 +1044,6 @@ func (impl *DripStrictServerImplementation) GetComfyNode(ctx context.Context, re } func (impl *DripStrictServerImplementation) ComfyNodesBackfill(ctx context.Context, request drip.ComfyNodesBackfillRequestObject) (drip.ComfyNodesBackfillResponseObject, error) { - log.Ctx(ctx).Info().Msg("ComfyNodesBackfill request received") err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client, request.Params.MaxNode) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to trigger comfy nodes backfill w/ err: %v", err) diff --git a/server/middleware/authentication/firebase_auth.go b/server/middleware/authentication/firebase_auth.go index 0fd4ba6..49344c8 100644 --- a/server/middleware/authentication/firebase_auth.go +++ b/server/middleware/authentication/firebase_auth.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "context" diff --git a/server/middleware/authentication/firebase_auth_test.go b/server/middleware/authentication/firebase_auth_test.go index c466bc3..f6d3651 100644 --- a/server/middleware/authentication/firebase_auth_test.go +++ b/server/middleware/authentication/firebase_auth_test.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "net/http" diff --git a/server/middleware/authentication/jwt_admin_auth.go b/server/middleware/authentication/jwt_admin_auth.go index 0b1b886..14ffa8c 100644 --- a/server/middleware/authentication/jwt_admin_auth.go +++ b/server/middleware/authentication/jwt_admin_auth.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "context" diff --git a/server/middleware/authentication/jwt_admin_auth_test.go b/server/middleware/authentication/jwt_admin_auth_test.go index 10c05fe..2c9b80c 100644 --- a/server/middleware/authentication/jwt_admin_auth_test.go +++ b/server/middleware/authentication/jwt_admin_auth_test.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "net/http" diff --git a/server/middleware/authentication/service_account_auth.go b/server/middleware/authentication/service_account_auth.go index cb2e87d..bc7b63c 100644 --- a/server/middleware/authentication/service_account_auth.go +++ b/server/middleware/authentication/service_account_auth.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "net/http" diff --git a/server/middleware/authentication/service_account_auth_test.go b/server/middleware/authentication/service_account_auth_test.go index 3642e66..975fad7 100644 --- a/server/middleware/authentication/service_account_auth_test.go +++ b/server/middleware/authentication/service_account_auth_test.go @@ -1,4 +1,4 @@ -package drip_authentication +package authentication import ( "net/http" diff --git a/server/middleware/error_logger.go b/server/middleware/error_logger.go deleted file mode 100644 index 6fa22d3..0000000 --- a/server/middleware/error_logger.go +++ /dev/null @@ -1,24 +0,0 @@ -package drip_middleware - -import ( - "github.com/rs/zerolog/log" - - "github.com/labstack/echo/v4" -) - -func ErrorLoggingMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - err := next(c) - - if err != nil { - log.Ctx(c.Request().Context()). - Error(). - Err(err). - Msgf("Error occurred Path: %s, Method: %s\n", c.Path(), c.Request().Method) - } - - return err - } - } -} diff --git a/server/middleware/metric/metric.go b/server/middleware/metric/metric.go index ad3a900..7f0dd72 100644 --- a/server/middleware/metric/metric.go +++ b/server/middleware/metric/metric.go @@ -1,4 +1,4 @@ -package drip_metric +package metric import ( "context" diff --git a/server/middleware/metric/metric_middleware.go b/server/middleware/metric/metric_middleware.go index cf53b4f..492c143 100644 --- a/server/middleware/metric/metric_middleware.go +++ b/server/middleware/metric/metric_middleware.go @@ -1,4 +1,4 @@ -package drip_metric +package metric import ( "context" diff --git a/server/middleware/request_logger.go b/server/middleware/request_logger.go new file mode 100644 index 0000000..294dffe --- /dev/null +++ b/server/middleware/request_logger.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "bytes" + "fmt" + "github.com/labstack/echo/v4" + echo_middleware "github.com/labstack/echo/v4/middleware" + "github.com/rs/zerolog/log" + "io" +) + +func RequestLoggerMiddleware() echo.MiddlewareFunc { + return echo_middleware.RequestLoggerWithConfig(echo_middleware.RequestLoggerConfig{ + LogURI: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, v echo_middleware.RequestLoggerValues) error { + // Read the request body for logging + requestBody, err := io.ReadAll(c.Request().Body) + if err != nil { + log.Ctx(c.Request().Context()).Error().Err(err).Msg("Failed to read request body") + return err + } + // Reset the body for further use + c.Request().Body = io.NopCloser(bytes.NewReader(requestBody)) + + // Log request details including query parameters + log.Ctx(c.Request().Context()). + Info(). + Str("Method", c.Request().Method). + Str("Path", c.Path()). + Str("QueryParams", fmt.Sprintf("%v", c.QueryParams())). + Str("RequestBody", string(requestBody)). + Str("Headers", fmt.Sprintf("%v", c.Request().Header)). + Msg("Request received") + return nil + }, + }) +} diff --git a/server/middleware/response_logger.go b/server/middleware/response_logger.go new file mode 100644 index 0000000..51b358f --- /dev/null +++ b/server/middleware/response_logger.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "bytes" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" +) + +// Custom response writer to capture response body +type responseWriter struct { + http.ResponseWriter + body *bytes.Buffer +} + +func (rw *responseWriter) Write(p []byte) (n int, err error) { + // Capture the response body in the buffer + n, err = rw.body.Write(p) + if err != nil { + return n, err + } + // Write to the actual ResponseWriter + return rw.ResponseWriter.Write(p) +} + +// ResponseLoggerMiddleware will log response details and errors. +func ResponseLoggerMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // Create a custom response writer to capture the response body + rw := &responseWriter{ + ResponseWriter: c.Response().Writer, + body: new(bytes.Buffer), + } + c.Response().Writer = rw + + // Call the next handler in the chain + err := next(c) + + // Log any errors that occur during handling + if err != nil { + log.Ctx(c.Request().Context()). + Error(). + Err(err). + Str("Method", c.Request().Method). + Str("Path", c.Path()). + Msg("Error occurred during request handling") + } + + // Log the response details + log.Ctx(c.Request().Context()). + Info(). + Int("Status", c.Response().Status). + Str("ResponseBody", rw.body.String()). + Str("ResponseHeaders", fmt.Sprintf("%v", c.Response().Header())). + Msg("Response sent") + + return err + } + } +} diff --git a/server/middleware/tracing_middleware.go b/server/middleware/tracing_middleware.go index d7c79cd..d5b8e1a 100644 --- a/server/middleware/tracing_middleware.go +++ b/server/middleware/tracing_middleware.go @@ -1,4 +1,4 @@ -package drip_middleware +package middleware import ( "context" diff --git a/server/server.go b/server/server.go index 04ceaa8..6a08f67 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,10 @@ package server import ( + monitoring "cloud.google.com/go/monitoring/apiv3/v2" "context" + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" "registry-backend/config" generated "registry-backend/drip" "registry-backend/ent" @@ -12,18 +15,10 @@ import ( "registry-backend/gateways/storage" handler "registry-backend/server/handlers" "registry-backend/server/implementation" - drip_middleware "registry-backend/server/middleware" - drip_authentication "registry-backend/server/middleware/authentication" - drip_authorization "registry-backend/server/middleware/authorization" - drip_metric "registry-backend/server/middleware/metric" - "strings" - - monitoring "cloud.google.com/go/monitoring/apiv3/v2" - - "github.com/labstack/echo/v4/middleware" - "github.com/rs/zerolog/log" - - "github.com/labstack/echo/v4" + "registry-backend/server/middleware" + "registry-backend/server/middleware/authentication" + "registry-backend/server/middleware/authorization" + "registry-backend/server/middleware/metric" ) type ServerDependencies struct { @@ -97,30 +92,13 @@ func (s *Server) Start() error { e.HideBanner = true // Apply middleware - e.Use(drip_middleware.TracingMiddleware) - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - AllowMethods: []string{"*"}, - AllowHeaders: []string{"*"}, - AllowOriginFunc: func(origin string) (bool, error) { - return true, nil - }, - AllowCredentials: true, - })) - e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ - LogURI: true, - LogStatus: true, - LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { - if strings.HasPrefix(c.Request().URL.Path, "/vm/") { - return nil - } - - log.Ctx(c.Request().Context()).Debug(). - Str("URI: ", v.URI). - Int("status", v.Status).Msg("") - return nil - }, - })) + e.Use(middleware.TracingMiddleware) + e.Use(middleware.RequestLoggerMiddleware()) + e.Use(middleware.ResponseLoggerMiddleware()) + e.Use(metric.MetricsMiddleware(&s.Dependencies.MonitoringClient, s.Config)) + e.Use(authentication.FirebaseAuthMiddleware(s.Client)) + e.Use(authentication.ServiceAccountAuthMiddleware()) + e.Use(authentication.JWTAdminAuthMiddleware(s.Client, s.Config.JWTSecret)) // Attach implementation of the generated OAPI strict server impl := implementation.NewStrictServerImplementation( @@ -144,13 +122,6 @@ func (s *Server) Start() error { e.GET("/openapi", handler.SwaggerHandler) e.GET("/health", s.HealthCheckHandler) - // Apply global middlewares - e.Use(drip_metric.MetricsMiddleware(&s.Dependencies.MonitoringClient, s.Config)) - e.Use(drip_authentication.FirebaseAuthMiddleware(s.Client)) - e.Use(drip_authentication.ServiceAccountAuthMiddleware()) - e.Use(drip_authentication.JWTAdminAuthMiddleware(s.Client, s.Config.JWTSecret)) - e.Use(drip_middleware.ErrorLoggingMiddleware()) - // Start the server return e.Start(":8080") } diff --git a/services/registry/registry_svc.go b/services/registry/registry_svc.go index bfa31eb..4cdcba2 100644 --- a/services/registry/registry_svc.go +++ b/services/registry/registry_svc.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "registry-backend/common" "registry-backend/config" "registry-backend/db" "registry-backend/drip" @@ -21,6 +22,7 @@ import ( "registry-backend/ent/publisherpermission" "registry-backend/ent/schema" "registry-backend/ent/user" + "registry-backend/entity" "registry-backend/gateways/algolia" "registry-backend/gateways/discord" "registry-backend/gateways/pubsub" @@ -59,68 +61,8 @@ func NewRegistryService(storageSvc storage.StorageService, pubsubService pubsub. } } -type PublisherFilter struct { - UserID string -} - -// NodeFilter holds optional parameters for filtering node results -type NodeFilter struct { - PublisherID string - Search string - IncludeBanned bool -} - -type NodeVersionFilter struct { - NodeId string - Status []schema.NodeVersionStatus - IncludeStatusReason bool - MinAge time.Duration - PageSize int - Page int -} - -type NodeData struct { - ID string `json:"id"` - Name string `json:"name"` - PublisherID string `json:"publisherId"` -} - -// ListNodesResult is the structure that holds the paginated result of nodes -type ListNodesResult struct { - Total int `json:"total"` - Nodes []*ent.Node `json:"nodes"` - Page int `json:"page"` - Limit int `json:"limit"` - TotalPages int `json:"totalPages"` -} - -type ListNodeVersionsResult struct { - Total int `json:"total"` - NodeVersions []*ent.NodeVersion `json:"nodes"` - Page int `json:"page"` - Limit int `json:"limit"` - TotalPages int `json:"totalPages"` -} - -func PrettifyJSON(input string) (string, error) { - // First unmarshal the input string into a generic interface{} - var temp interface{} - err := json.Unmarshal([]byte(input), &temp) - if err != nil { - return "", fmt.Errorf("invalid JSON input: %v", err) - } - - // Marshal back to JSON with indentation - pretty, err := json.MarshalIndent(temp, "", " ") - if err != nil { - return "", fmt.Errorf("failed to marshal JSON: %v", err) - } - - return string(pretty), nil -} - // ListNodes retrieves a paginated list of nodes with optional filtering. -func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, page, limit int, filter *NodeFilter) (*ListNodesResult, error) { +func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, page, limit int, filter *entity.NodeFilter) (*entity.ListNodesResult, error) { // Ensure valid pagination parameters if page < 1 { page = 1 @@ -200,7 +142,7 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag } // Return the result - return &ListNodesResult{ + return &entity.ListNodesResult{ Total: total, Nodes: nodes, Page: page, @@ -210,7 +152,7 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag } // ListPublishers queries the Publisher table with an optional user ID filter via PublisherPermission -func (s *RegistryService) ListPublishers(ctx context.Context, client *ent.Client, filter *PublisherFilter) ([]*ent.Publisher, error) { +func (s *RegistryService) ListPublishers(ctx context.Context, client *ent.Client, filter *entity.PublisherFilter) ([]*ent.Publisher, error) { log.Ctx(ctx).Info().Msg("Listing publishers") query := client.Publisher.Query() @@ -454,7 +396,7 @@ type NodeVersionCreation struct { } func (s *RegistryService) ListNodeVersions( - ctx context.Context, client *ent.Client, filter *NodeVersionFilter) (*ListNodeVersionsResult, error) { + ctx context.Context, client *ent.Client, filter *entity.NodeVersionFilter) (*entity.ListNodeVersionsResult, error) { query := client.NodeVersion.Query(). WithStorageFile(). WithComfyNodes(). @@ -513,7 +455,7 @@ func (s *RegistryService) ListNodeVersions( totalPages = (total + filter.PageSize - 1) / filter.PageSize // Use ceiling division for total pages } - return &ListNodeVersionsResult{ + return &entity.ListNodeVersionsResult{ Total: total, NodeVersions: versions, Page: filter.Page, @@ -1073,7 +1015,7 @@ func (s *RegistryService) PerformSecurityCheck(ctx context.Context, client *ent. "Security issues found in node %s@%s. Updating to flagged.", nodeVersion.NodeID, nodeVersion.Version) log.Ctx(ctx).Info().Msgf( "List of security issues %s.", issues) // 500 character max. - prettyIssues, err := PrettifyJSON(issues) + prettyIssues, err := common.PrettifyJSON(issues) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("failed to prettify JSON issues") prettyIssues = issues // fallback to unprettified issues