Skip to content

Commit

Permalink
Parameterize backfill
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 committed Jan 3, 2025
1 parent aa7a13c commit 7063521
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 224 deletions.
239 changes: 123 additions & 116 deletions drip/api.gen.go

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion infrastructure/modules/node-pack-extract-trigger/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ resource "google_cloud_scheduler_job" "backfill" {

http_target {
http_method = "POST"
uri = "${var.registry_backend_url}/comfy-nodes/backfill"
uri = "${var.registry_backend_url}/comfy-nodes/backfill?max_node=${var.backfill_job_max_node}"


oidc_token {
service_account_email = data.google_service_account.cloudbuild_service_account.email
Expand Down
6 changes: 6 additions & 0 deletions infrastructure/modules/node-pack-extract-trigger/variable.tf
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ variable "backfill_job_schedule" {
default = "30 3 29 2 *"
}

variable "backfill_job_max_node" {
type = number
default = 10
description = "maximum number of nodes to be backfilled"
}

variable "git_repo_uri" {
type = string
description = "Connected git repo containing the cloud build pipeline. See https://cloud.google.com/build/docs/repositories"
Expand Down
39 changes: 23 additions & 16 deletions integration-tests/registry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ func TestRegistryComfyNode(t *testing.T) {
nodeVersionToBeBackfill := []*drip.NodeVersion{
randomNodeVersion(1),
randomNodeVersion(2),
randomNodeVersion(3),
randomNodeVersion(4),
randomNodeVersion(5),
}
for _, nv := range nodeVersionToBeBackfill {
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
Expand Down Expand Up @@ -976,17 +979,6 @@ func TestRegistryComfyNode(t *testing.T) {
}
})

t.Run("Conflict", func(t *testing.T) {
body := drip.CreateComfyNodesJSONRequestBody(comfyNodes)
res, err := withMiddleware(authz, impl.CreateComfyNodes)(ctx, drip.CreateComfyNodesRequestObject{
NodeId: *node.Id,
Version: *nodeVersion.Version,
Body: &body,
})
require.NoError(t, err)
require.IsType(t, drip.CreateComfyNodes409JSONResponse{}, res)
})

t.Run("GetNodeVersion", func(t *testing.T) {
res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{
NodeId: *node.Id,
Expand Down Expand Up @@ -1044,10 +1036,25 @@ func TestRegistryComfyNode(t *testing.T) {
})

t.Run("TriggerBackfill", func(t *testing.T) {
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill))
mockCalled := 0
t.Run("Unlimited", func(t *testing.T) {
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill)+mockCalled)
mockCalled += len(nodeVersionToBeBackfill)
})

t.Run("Limited", func(t *testing.T) {
limit := 2
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{Params: drip.ComfyNodesBackfillParams{MaxNode: &limit}})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", limit+mockCalled)
mockCalled += limit
})
})

}
4 changes: 2 additions & 2 deletions node-pack-extract/cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ steps:
- -c
- gcloud auth print-identity-token --audiences="$_REGISTRY_BACKEND_URL" | tee /workspace/token

- name: "gcr.io/cloud-builders/curl"
entrypoint: "bash"
- name: "curlimages/curl"
entrypoint: "sh"
args:
- -c
- |
Expand Down
2 changes: 1 addition & 1 deletion node-pack-extract/test/trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestApply(t *testing.T) {
require.NoError(t, err)
h := j.GetHttpTarget()
require.NotNil(t, h)
assert.Contains(t, h.GetUri(), "/comfy-nodes/backfill")
assert.Equal(t, h.GetUri(), "https://stagingapi.comfy.org/comfy-nodes/backfill?max_node=10")
assert.Equal(t, http.MethodPost, h.GetHttpMethod().String())
assert.Equal(t, "https://stagingapi.comfy.org", h.GetOidcToken().GetAudience())
assert.Equal(t, serviceAccount, h.GetOidcToken().GetServiceAccountEmail())
Expand Down
13 changes: 7 additions & 6 deletions openapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1803,12 +1803,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
'409':
description: Existing Comfy Nodes exists
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
'500':
description: Internal server error
content:
Expand Down Expand Up @@ -1870,6 +1864,13 @@ paths:
operationId: ComfyNodesBackfill
tags:
- ComfyNodes
parameters:
- in: query
name: max_node
required: false
schema:
type: integer
default: 10
responses:
'204':
description: Backfill triggered
Expand Down
38 changes: 15 additions & 23 deletions server/implementation/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package implementation

import (
"context"
"errors"
"registry-backend/drip"
"registry-backend/ent"
"registry-backend/ent/publisher"
Expand Down Expand Up @@ -727,12 +726,6 @@ func (s *DripStrictServerImplementation) InstallNode(

// Install node version
if request.Params.Version == nil {
s.MixpanelService.Track(ctx, []*mixpanel.Event{
s.MixpanelService.NewEvent("Install Node", "", map[string]any{
"Node ID": request.NodeId,
"Version": "latest",
}),
})
nodeVersion, err := s.RegistryService.GetLatestNodeVersion(ctx, s.Client, request.NodeId)
if err == nil && nodeVersion == nil {
log.Ctx(ctx).Error().Msgf("Latest node version not found")
Expand All @@ -743,19 +736,22 @@ func (s *DripStrictServerImplementation) InstallNode(
log.Ctx(ctx).Error().Msgf("Error retrieving latest node version w/ err: %v", err)
return drip.InstallNode500JSONResponse{Message: errMessage}, err
}

_, err = s.RegistryService.RecordNodeInstallation(ctx, s.Client, node)
_, err = s.RegistryService.RecordNodeInstalation(ctx, s.Client, node)
if err != nil {
errMessage := "Failed to get increment number of node version install: " + err.Error()
log.Ctx(ctx).Error().Msgf("Error incrementing number of latest node version install w/ err: %v", err)
return drip.InstallNode500JSONResponse{Message: errMessage}, err
}

s.MixpanelService.Track(ctx, []*mixpanel.Event{
s.MixpanelService.NewEvent("Install Node Latest", "", map[string]any{
"Node ID": request.NodeId,
"Version": nodeVersion.Version,
}),
})
return drip.InstallNode200JSONResponse(
*mapper.DbNodeVersionToApiNodeVersion(nodeVersion),
), nil
} else {

nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, *request.Params.Version)
if ent.IsNotFound(err) {
log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err)
Expand All @@ -766,18 +762,18 @@ func (s *DripStrictServerImplementation) InstallNode(
log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err)
return drip.InstallNode500JSONResponse{Message: errMessage}, err
}
s.MixpanelService.Track(ctx, []*mixpanel.Event{
s.MixpanelService.NewEvent("Install Node", "", map[string]any{
"Node ID": request.NodeId,
"Version": request.Params.Version,
}),
})
_, err = s.RegistryService.RecordNodeInstallation(ctx, s.Client, node)
_, err = s.RegistryService.RecordNodeInstalation(ctx, s.Client, node)
if err != nil {
errMessage := "Failed to get increment number of node version install: " + err.Error()
log.Ctx(ctx).Error().Msgf("Error incrementing number of latest node version install w/ err: %v", err)
return drip.InstallNode500JSONResponse{Message: errMessage}, err
}
s.MixpanelService.Track(ctx, []*mixpanel.Event{
s.MixpanelService.NewEvent("Install Node", "", map[string]any{
"Node ID": request.NodeId,
"Version": nodeVersion.Version,
}),
})
return drip.InstallNode200JSONResponse(
*mapper.DbNodeVersionToApiNodeVersion(nodeVersion),
), nil
Expand Down Expand Up @@ -1027,10 +1023,6 @@ func (impl *DripStrictServerImplementation) CreateComfyNodes(ctx context.Context
log.Ctx(ctx).Error().Msgf("Node or node version not found w/ err: %v", err)
return drip.CreateComfyNodes404JSONResponse{Message: "Node or node version not found", Error: err.Error()}, nil
}
if errors.Is(err, drip_services.ErrComfyNodesAlreadyExist) {
log.Ctx(ctx).Error().Msgf("Comfy nodes for %s %s exist", request.NodeId, request.Version)
return drip.CreateComfyNodes409JSONResponse{Message: "Comfy nodes already exist", Error: err.Error()}, nil
}
if err != nil {
log.Ctx(ctx).Error().Msgf("Failed to create comfy nodes w/ err: %v", err)
return drip.CreateComfyNodes500JSONResponse{Message: "Failed to create comfy nodes", Error: err.Error()}, nil
Expand Down Expand Up @@ -1063,7 +1055,7 @@ func (impl *DripStrictServerImplementation) GetComfyNode(ctx context.Context, re

func (impl *DripStrictServerImplementation) ComfyNodesBackfill(ctx context.Context, request drip.ComfyNodesBackfillRequestObject) (drip.ComfyNodesBackfillResponseObject, error) {
log.Ctx(ctx).Info().Msg("ComfyNodesBackfill request received")
err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client)
err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client, request.Params.MaxNode)
if err != nil {
log.Ctx(ctx).Error().Msgf("Failed to trigger comfy nodes backfill w/ err: %v", err)
return drip.ComfyNodesBackfill500JSONResponse{Message: "Failed to trigger comfy nodes backfill", Error: err.Error()}, nil
Expand Down
109 changes: 50 additions & 59 deletions services/registry/registry_svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Cli
})
}

func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) {
func (s *RegistryService) RecordNodeInstalation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) {
var n *ent.Node
err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) {
node, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx)
Expand Down Expand Up @@ -626,64 +626,52 @@ func (s *RegistryService) GetLatestNodeVersion(ctx context.Context, client *ent.
return nodeVersion, nil
}

var ErrComfyNodesAlreadyExist = errors.New("comfy nodes already exist")

func (s *RegistryService) CreateComfyNodes(ctx context.Context, client *ent.Client, nodeID string, nodeVersion string, comfyNodes map[string]drip.ComfyNode) (err error) {
return db.WithTx(ctx, client, func(tx *ent.Tx) error {
nv, err := client.NodeVersion.Query().
Where(nodeversion.VersionEQ(nodeVersion)).
Where(nodeversion.NodeIDEQ(nodeID)).
WithComfyNodes().
ForUpdate().
Only(ctx)
if err != nil {
return err
}

if len(nv.Edges.ComfyNodes) > 0 {
return ErrComfyNodesAlreadyExist
}
nv, err := client.NodeVersion.Query().
Where(nodeversion.VersionEQ(nodeVersion)).
Where(nodeversion.NodeIDEQ(nodeID)).Only(ctx)
if err != nil {
return err
}

comfyNodesCreates := make([]*ent.ComfyNodeCreate, 0, len(comfyNodes))
for k, n := range comfyNodes {
comfyNodeCreate := client.ComfyNode.Create().
SetID(k).
SetNodeVersionID(nv.ID)
comfyNodesCreates := make([]*ent.ComfyNodeCreate, 0, len(comfyNodes))
for k, n := range comfyNodes {
comfyNodeCreate := client.ComfyNode.Create().
SetID(k).
SetNodeVersionID(nv.ID)

if n.Category != nil {
comfyNodeCreate.SetCategory(*n.Category)
}
if n.Description != nil {
comfyNodeCreate.SetDescription(*n.Description)
}
if n.InputTypes != nil {
comfyNodeCreate.SetInputTypes(*n.InputTypes)
}
if n.Deprecated != nil {
comfyNodeCreate.SetDeprecated(*n.Deprecated)
}
if n.Experimental != nil {
comfyNodeCreate.SetExperimental(*n.Experimental)
}
if n.OutputIsList != nil {
comfyNodeCreate.SetOutputIsList(*n.OutputIsList)
}
if n.ReturnNames != nil {
comfyNodeCreate.SetReturnNames(*n.ReturnNames)
}
if n.ReturnTypes != nil {
comfyNodeCreate.SetReturnTypes(*n.ReturnTypes)
}
if n.Function != nil {
comfyNodeCreate.SetFunction(*n.Function)
}
comfyNodesCreates = append(comfyNodesCreates, comfyNodeCreate)
if n.Category != nil {
comfyNodeCreate.SetCategory(*n.Category)
}
return client.ComfyNode.
CreateBulk(comfyNodesCreates...).
Exec(ctx)
})

if n.Description != nil {
comfyNodeCreate.SetDescription(*n.Description)
}
if n.InputTypes != nil {
comfyNodeCreate.SetInputTypes(*n.InputTypes)
}
if n.Deprecated != nil {
comfyNodeCreate.SetDeprecated(*n.Deprecated)
}
if n.Experimental != nil {
comfyNodeCreate.SetExperimental(*n.Experimental)
}
if n.OutputIsList != nil {
comfyNodeCreate.SetOutputIsList(*n.OutputIsList)
}
if n.ReturnNames != nil {
comfyNodeCreate.SetReturnNames(*n.ReturnNames)
}
if n.ReturnTypes != nil {
comfyNodeCreate.SetReturnTypes(*n.ReturnTypes)
}
if n.Function != nil {
comfyNodeCreate.SetFunction(*n.Function)
}
comfyNodesCreates = append(comfyNodesCreates, comfyNodeCreate)
}
return client.ComfyNode.
CreateBulk(comfyNodesCreates...).
Exec(ctx)
}

func (s *RegistryService) GetComfyNode(ctx context.Context, client *ent.Client, nodeID string, nodeVersion string, comfyNodeID string) (*ent.ComfyNode, error) {
Expand All @@ -700,12 +688,15 @@ func (s *RegistryService) GetComfyNode(ctx context.Context, client *ent.Client,
return nv.Edges.ComfyNodes[0], nil
}

func (s *RegistryService) TriggerComfyNodesBackfill(ctx context.Context, client *ent.Client) error {
nvs, err := client.NodeVersion.
func (s *RegistryService) TriggerComfyNodesBackfill(ctx context.Context, client *ent.Client, max *int) error {
q := client.NodeVersion.
Query().
WithStorageFile().
Where(nodeversion.Not(nodeversion.HasComfyNodes())).
All(ctx)
Where(nodeversion.Not(nodeversion.HasComfyNodes()))
if max != nil {
q.Limit(*max)
}
nvs, err := q.All(ctx)
if err != nil {
return fmt.Errorf("failed to query node versions: %w", err)
}
Expand Down

0 comments on commit 7063521

Please sign in to comment.