Skip to content

Commit

Permalink
Create a new Parameter for batching the backfill operation (#113)
Browse files Browse the repository at this point in the history
* Parameterize backfill

* Resolve merge conflict

---------

Co-authored-by: James Kwon <[email protected]>
  • Loading branch information
james03160927 and james03160927 authored Jan 3, 2025
1 parent aa7a13c commit c46dd20
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 115 deletions.
219 changes: 118 additions & 101 deletions drip/api.gen.go

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion infrastructure/modules/node-pack-extract-trigger/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ resource "google_cloud_scheduler_job" "backfill" {

http_target {
http_method = "POST"
uri = "${var.registry_backend_url}/comfy-nodes/backfill"
uri = "${var.registry_backend_url}/comfy-nodes/backfill?max_node=${var.backfill_job_max_node}"


oidc_token {
service_account_email = data.google_service_account.cloudbuild_service_account.email
Expand Down
6 changes: 6 additions & 0 deletions infrastructure/modules/node-pack-extract-trigger/variable.tf
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ variable "backfill_job_schedule" {
default = "30 3 29 2 *"
}

variable "backfill_job_max_node" {
type = number
default = 10
description = "maximum number of nodes to be backfilled"
}

variable "git_repo_uri" {
type = string
description = "Connected git repo containing the cloud build pipeline. See https://cloud.google.com/build/docs/repositories"
Expand Down
28 changes: 23 additions & 5 deletions integration-tests/registry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ func TestRegistryComfyNode(t *testing.T) {
nodeVersionToBeBackfill := []*drip.NodeVersion{
randomNodeVersion(1),
randomNodeVersion(2),
randomNodeVersion(3),
randomNodeVersion(4),
randomNodeVersion(5),
}
for _, nv := range nodeVersionToBeBackfill {
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
Expand Down Expand Up @@ -1044,10 +1047,25 @@ func TestRegistryComfyNode(t *testing.T) {
})

t.Run("TriggerBackfill", func(t *testing.T) {
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill))
mockCalled := 0
t.Run("Unlimited", func(t *testing.T) {
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill)+mockCalled)
mockCalled += len(nodeVersionToBeBackfill)
})

t.Run("Limited", func(t *testing.T) {
limit := 2
impl.mockPubsubService.On("PublishNodePack", mock.Anything, mock.Anything).Return(nil)
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{Params: drip.ComfyNodesBackfillParams{MaxNode: &limit}})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", limit+mockCalled)
mockCalled += limit
})
})

}
4 changes: 2 additions & 2 deletions node-pack-extract/cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ steps:
- -c
- gcloud auth print-identity-token --audiences="$_REGISTRY_BACKEND_URL" | tee /workspace/token

- name: "gcr.io/cloud-builders/curl"
entrypoint: "bash"
- name: "curlimages/curl"
entrypoint: "sh"
args:
- -c
- |
Expand Down
2 changes: 1 addition & 1 deletion node-pack-extract/test/trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestApply(t *testing.T) {
require.NoError(t, err)
h := j.GetHttpTarget()
require.NotNil(t, h)
assert.Contains(t, h.GetUri(), "/comfy-nodes/backfill")
assert.Equal(t, h.GetUri(), "https://stagingapi.comfy.org/comfy-nodes/backfill?max_node=10")
assert.Equal(t, http.MethodPost, h.GetHttpMethod().String())
assert.Equal(t, "https://stagingapi.comfy.org", h.GetOidcToken().GetAudience())
assert.Equal(t, serviceAccount, h.GetOidcToken().GetServiceAccountEmail())
Expand Down
7 changes: 7 additions & 0 deletions openapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,13 @@ paths:
operationId: ComfyNodesBackfill
tags:
- ComfyNodes
parameters:
- in: query
name: max_node
required: false
schema:
type: integer
default: 10
responses:
'204':
description: Backfill triggered
Expand Down
2 changes: 1 addition & 1 deletion server/implementation/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ func (impl *DripStrictServerImplementation) GetComfyNode(ctx context.Context, re

func (impl *DripStrictServerImplementation) ComfyNodesBackfill(ctx context.Context, request drip.ComfyNodesBackfillRequestObject) (drip.ComfyNodesBackfillResponseObject, error) {
log.Ctx(ctx).Info().Msg("ComfyNodesBackfill request received")
err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client)
err := impl.RegistryService.TriggerComfyNodesBackfill(ctx, impl.Client, request.Params.MaxNode)
if err != nil {
log.Ctx(ctx).Error().Msgf("Failed to trigger comfy nodes backfill w/ err: %v", err)
return drip.ComfyNodesBackfill500JSONResponse{Message: "Failed to trigger comfy nodes backfill", Error: err.Error()}, nil
Expand Down
11 changes: 7 additions & 4 deletions services/registry/registry_svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,12 +700,15 @@ func (s *RegistryService) GetComfyNode(ctx context.Context, client *ent.Client,
return nv.Edges.ComfyNodes[0], nil
}

func (s *RegistryService) TriggerComfyNodesBackfill(ctx context.Context, client *ent.Client) error {
nvs, err := client.NodeVersion.
func (s *RegistryService) TriggerComfyNodesBackfill(ctx context.Context, client *ent.Client, max *int) error {
q := client.NodeVersion.
Query().
WithStorageFile().
Where(nodeversion.Not(nodeversion.HasComfyNodes())).
All(ctx)
Where(nodeversion.Not(nodeversion.HasComfyNodes()))
if max != nil {
q.Limit(*max)
}
nvs, err := q.All(ctx)
if err != nil {
return fmt.Errorf("failed to query node versions: %w", err)
}
Expand Down

0 comments on commit c46dd20

Please sign in to comment.