Feature: Celery performance #138

Draft: wants to merge 6 commits into develop
152 changes: 39 additions & 113 deletions src/hope_dedup_engine/apps/api/deduplication/process.py
@@ -1,135 +1,61 @@
 from dataclasses import asdict
 
-from django.db.models import F
-
-from celery import chord, shared_task
+from celery import Task
 
 from hope_dedup_engine.apps.api.deduplication.config import DeduplicationSetConfig
-
-# from hope_dedup_engine.apps.api.deduplication.registry import (  # DuplicateFinder,; DuplicateKeyPair,
-#     get_finders,
-# )
 from hope_dedup_engine.apps.api.models import DedupJob, DeduplicationSet, Finding
 from hope_dedup_engine.apps.api.utils.notification import send_notification
-
-# from hope_dedup_engine.apps.api.utils.progress import track_progress_multi
-from hope_dedup_engine.apps.faces.celery_tasks import (
-    callback_encodings,
-    encode_chunk,
-    get_chunks,
-    handle_error,
+from hope_dedup_engine.apps.faces.celery.pipeline import image_pipeline
+from hope_dedup_engine.config.celery import app
+from hope_dedup_engine.utils.celery.task_result import (
+    Result,
+    UnexpectedResultError,
+    is_error,
+    is_value,
 )
 
-# def _sort_keys(pair: DuplicateKeyPair) -> DuplicateKeyPair:
-#     first, second, score = pair
-#     return *sorted((first, second)), score
 
+@app.task
+def clear_findings(deduplication_set_id: str) -> None:
+    deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
+        id=deduplication_set_id
+    )
 
-# def _save_duplicates(
-#     finder: DuplicateFinder,
-#     deduplication_set: DeduplicationSet,
-#     tracker: Callable[[int], None],
-# ) -> None:
-#     reference_pk_to_filename_mapping = dict(
-#         deduplication_set.image_set.values_list("reference_pk", "filename")
-#     )
-#     ignored_filename_pairs = frozenset(
-#         map(
-#             tuple,
-#             map(
-#                 sorted,
-#                 deduplication_set.ignoredfilenamepair_set.values_list(
-#                     "first", "second"
-#                 ),
-#             ),
-#         )
-#     )
+    Finding.objects.filter(deduplication_set=deduplication_set).delete()
 
-#     ignored_reference_pk_pairs = frozenset(
-#         deduplication_set.ignoredreferencepkpair_set.values_list("first", "second")
-#     )
+    deduplication_set.state = DeduplicationSet.State.DIRTY
+    deduplication_set.save(update_fields=["state"])
+    send_notification(deduplication_set.notification_url)
 
-#     for first, second, score in map(_sort_keys, finder.run(tracker)):
-#         first_filename, second_filename = sorted(
-#             (
-#                 reference_pk_to_filename_mapping[first],
-#                 reference_pk_to_filename_mapping[second],
-#             )
-#         )
-#         ignored = (first, second) in ignored_reference_pk_pairs or (
-#             first_filename,
-#             second_filename,
-#         ) in ignored_filename_pairs
-#         if not ignored:
-#             duplicate, _ = Duplicate.objects.get_or_create(
-#                 deduplication_set=deduplication_set,
-#                 first_reference_pk=first,
-#                 second_reference_pk=second,
-#             )
-#             duplicate.score += score * finder.weight
-#             duplicate.save()
 
+@app.task
+def finish(result: Result, deduplication_set_id: str) -> None:
+    deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
+        id=deduplication_set_id
+    )
 
-HOUR = 60 * 60
-
+    if is_error(result):
+        deduplication_set.state = DeduplicationSet.State.DIRTY
+    elif is_value(result):
+        deduplication_set.state = DeduplicationSet.State.CLEAN
+    else:
+        raise UnexpectedResultError(result)
+    deduplication_set.save(update_fields=["state"])
 
-def update_job_progress(job: DedupJob, progress: int) -> None:
-    job.progress = progress
-    job.save(update_fields=["progress"])
+    send_notification(deduplication_set.notification_url)
 
 
-@shared_task(soft_time_limit=0.5 * HOUR, time_limit=1 * HOUR)
-def find_duplicates(dedup_job_id: int, version: int) -> None:
+@app.task(bind=True)
+def find_duplicates(self: Task, dedup_job_id: int, version: int) -> None:
    dedup_job: DedupJob = DedupJob.objects.get(pk=dedup_job_id, version=version)
    deduplication_set = dedup_job.deduplication_set
-    try:
 
-        deduplication_set.state = DeduplicationSet.State.DIRTY
-        deduplication_set.save(update_fields=["state"])
-        send_notification(deduplication_set.notification_url)
-
-        config = asdict(
-            DeduplicationSetConfig.from_deduplication_set(deduplication_set)
-        )
-
-        # clean results
-        Finding.objects.filter(deduplication_set=deduplication_set).delete()
-        dedup_job.progress = 0
-        dedup_job.save(update_fields=["progress"])
-
-        # weight_total = 0
-        # for finder, tracker in zip(
-        # for finder, _ in zip(
-        #     get_finders(deduplication_set),
-        #     track_progress_multi(partial(update_job_progress, dedup_job)),
-        # ):
-        #     # _save_duplicates(finder, deduplication_set, tracker)
-        #     weight_total += finder.weight
-
-        weight_total = 1
-        deduplication_set.finding_set.update(score=F("score") / weight_total)
-
-        files = deduplication_set.image_set.values_list("filename", flat=True)
-        chunks = get_chunks(files)
-        tasks = [encode_chunk.s(chunk, config) for chunk in chunks]
-        chord_id = chord(tasks)(callback_encodings.s(config=config))
-
-        # for finder, tracker in zip(
-        #     get_finders(deduplication_set),
-        #     track_progress_multi(partial(update_job_progress, dedup_job)),
-        # ):
-        #     for first, second, score in finder.run(tracker):
-        #         finding = (first, second, score * finder.weight)
-        #         deduplication_set.update_findings(finding)
+    config = asdict(DeduplicationSetConfig.from_deduplication_set(deduplication_set))
 
-        deduplication_set.state = deduplication_set.State.CLEAN
-        deduplication_set.save(update_fields=["state"])
+    pipeline = (
+        clear_findings.s(deduplication_set.id)
+        | image_pipeline(deduplication_set, config)
+        | finish.s(deduplication_set.id)
+    )
 
-        return {
-            "deduplication_set": str(deduplication_set),
-            "chord_id": str(chord_id),
-            "chunks": len(chunks),
-        }
-    except Exception:
-        handle_error(deduplication_set)
-        raise
+    return self.replace(pipeline)
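
The dispatcher is now a thin orchestrator: `self.replace(pipeline)` (hence `bind=True`) swaps the running task for the chain, so `find_duplicates` returns immediately and the terminal state is decided by `finish` once the whole pipeline resolves. `finish` consumes a `Result` envelope from the new `task_result` module, which is not part of this diff; the sketch below is a hypothetical reading of that module, inferred only from the imported names:

    # Hypothetical sketch of hope_dedup_engine/utils/celery/task_result.py;
    # the names come from the imports above, the shapes are assumptions.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class Value:
        value: Any  # payload of a successful pipeline run


    @dataclass
    class Error:
        message: str  # description of the failure that stopped the pipeline


    Result = Value | Error


    def is_value(result: Result) -> bool:
        return isinstance(result, Value)


    def is_error(result: Result) -> bool:
        return isinstance(result, Error)


    class UnexpectedResultError(Exception):
        """Raised when a task receives neither a Value nor an Error."""

Whatever the concrete shapes are, the pattern lets a failure travel down the chain as a value rather than an exception, so `finish` always runs and can mark the set DIRTY instead of leaving it stuck.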
@@ -0,0 +1,18 @@
# Generated by Django 5.1.4 on 2025-01-23 16:03

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("api", "0015_deduplicationset_encodings_finding_delete_duplicate"),
]

operations = [
migrations.AddField(
model_name="deduplicationset",
name="encoding_errors",
field=models.JSONField(blank=True, default=dict, null=True),
),
]
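
Because the new column declares `default=dict`, Django backfills existing rows with `{}` while adding the column (the database default is dropped again afterwards), so pre-existing deduplication sets can be read without a NULL check. A quick illustration (the primary key and filename are made up):

    ds = DeduplicationSet.objects.get(pk=some_pk)  # some_pk is illustrative
    error = ds.encoding_errors.get("photo_001.jpg")  # None when no error was recorded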
91 changes: 73 additions & 18 deletions src/hope_dedup_engine/apps/api/models/deduplication.py
@@ -1,11 +1,23 @@
+from itertools import chain
 from typing import Any, Final, override
 from uuid import uuid4
 
 from django.conf import settings
-from django.db import models
+from django.db import models, transaction
 
 from hope_dedup_engine.apps.security.models import ExternalSystem
-from hope_dedup_engine.types import EncodingType, FindingType, IgnoredPairType
+from hope_dedup_engine.types import (
+    Embedding,
+    EntityEmbedding,
+    EntityEmbeddingError,
+    Filename,
+    Finding,
+    IgnoredPair,
+    ImageEmbedding,
+    ImageEmbeddingError,
+    Score,
+    SortedTuple,
+)
 
 REFERENCE_PK_LENGTH: Final[int] = 100
 
@@ -54,43 +66,86 @@ class State(models.IntegerChoices):
     notification_url = models.CharField(max_length=255, null=True, blank=True)
     config = models.ForeignKey("Config", null=True, on_delete=models.SET_NULL)
 
+    # TODO: rename to embeddings as it is the more correct term
     encodings = models.JSONField(
         null=True, blank=True, default=dict
-    )  # {file1: encoding1, file2: encoding2, ...}
+    )  # {file1: embedding1, file2: embedding2, ...}
+
+    encoding_errors = models.JSONField(
+        null=True, blank=True, default=dict
+    )  # {file1: embedding_error1, file2: embedding_error2, ...}
 
     def __str__(self) -> str:
         return self.name or f"ID: {self.pk}"
 
-    def get_encodings(self) -> EncodingType:
+    def get_encodings(self) -> dict[Filename, Embedding]:
         return self.encodings
 
-    def get_findings(self) -> FindingType:
+    def get_findings(self) -> list[Finding]:
         return list(
             self.finding_set.values_list(
                 "first_reference_pk", "second_reference_pk", "score"
             )
         )
 
-    def get_ignored_pairs(self) -> IgnoredPairType:
-        return list(
-            self.ignoredreferencepkpair_set.values_list("first", "second")
-        ) + list(self.ignoredfilenamepair_set.values_list("first", "second"))
+    def get_ignored_pairs(self) -> set[IgnoredPair]:
+        return set(
+            chain(
+                map(
+                    SortedTuple,
+                    self.ignoredreferencepkpair_set.values_list("first", "second"),
+                ),
+                map(
+                    SortedTuple,
+                    list(self.ignoredfilenamepair_set.values_list("first", "second")),
+                ),
+            )
+        )
 
+    def update_encodings(self, encodings: list[ImageEmbedding]) -> None:
+        with transaction.atomic():
+            fresh_self: DeduplicationSet = (
+                DeduplicationSet.objects.select_for_update().get(pk=self.pk)
+            )
+            fresh_self.encodings.update(encodings)
+            fresh_self.save()
+
+    def update_encoding_errors(self, errors: list[ImageEmbeddingError]) -> None:
+        with transaction.atomic():
+            fresh_self: DeduplicationSet = (
+                DeduplicationSet.objects.select_for_update().get(pk=self.pk)
+            )
+            fresh_self.encoding_errors.update(errors)
+            fresh_self.save()
+
-    def update_encodings(self, encodings: EncodingType) -> None:
-        self.encodings.update(encodings)
-        self.save()
+    def update_findings(
+        self, findings: list[tuple[EntityEmbedding, EntityEmbedding, Score]]
+    ) -> None:
+        Finding.objects.bulk_create(
+            [
+                Finding(
+                    deduplication_set=self,
+                    first_reference_pk=first_reference_pk,
+                    second_reference_pk=second_reference_pk,
+                    score=score,
+                )
+                for (first_reference_pk, _), (second_reference_pk, _), score in findings
+            ],
+            ignore_conflicts=True,
+        )
 
-    def update_findings(self, findings: FindingType) -> None:
+    def update_finding_errors(
+        self, encoding_errors: list[EntityEmbeddingError]
+    ) -> None:
         Finding.objects.bulk_create(
             [
                 Finding(
                     deduplication_set=self,
-                    first_reference_pk=f[0],
-                    second_reference_pk=f[1],
-                    score=f[2],
-                    error=f[3],
+                    first_reference_pk=reference_pk,
+                    second_reference_pk=error.name,
+                    error=error.value,
                 )
-                for f in findings
+                for reference_pk, error in encoding_errors
             ],
             ignore_conflicts=True,
         )
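
`update_encodings` and `update_encoding_errors` now re-read the row under `select_for_update` inside `transaction.atomic()` instead of saving the instance the caller already holds. The reason is the new fan-out: many encoding batches finish concurrently, each merging a partial dict into the same JSON column, and a plain read-modify-write would lose batches. A minimal illustration of the hazard the lock removes (the worker function is hypothetical):

    # Without the locked re-read, two workers can interleave like this:
    #   A: reads row, encodings == {}
    #   B: reads row, encodings == {}
    #   A: merges batch_a, saves          -> encodings == batch_a
    #   B: merges batch_b, saves          -> batch_a silently lost
    # With select_for_update, B blocks until A's transaction commits, then
    # merges into the fresh row, so both batches survive.
    def store_batch(deduplication_set_id: str, batch: list) -> None:  # hypothetical
        ds = DeduplicationSet.objects.get(id=deduplication_set_id)
        ds.update_encodings(batch)  # locks the row, re-reads, merges, saves

Note that `encodings.update(batch)` works on a list because `dict.update()` accepts any iterable of key/value pairs, which suggests `ImageEmbedding` is a `(filename, embedding)` tuple.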
1 change: 1 addition & 0 deletions src/hope_dedup_engine/apps/api/views.py
@@ -137,6 +137,7 @@ def perform_create(self, serializer: Serializer) -> None:
         deduplication_set.save()
 
     def perform_destroy(self, instance: Image) -> None:
+        # TODO: remove encoding
         deduplication_set = instance.deduplication_set
         super().perform_destroy(instance)
         deduplication_set.state = DeduplicationSet.State.DIRTY
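
The TODO records that a deleted image's cached embedding currently stays behind in `DeduplicationSet.encodings`. A possible shape for that cleanup, reusing the model's row-locking pattern (this helper is hypothetical, not part of the PR):

    # Hypothetical cleanup for the TODO above; not part of this PR.
    from django.db import transaction

    def remove_encoding(deduplication_set: DeduplicationSet, filename: str) -> None:
        with transaction.atomic():
            fresh = DeduplicationSet.objects.select_for_update().get(
                pk=deduplication_set.pk
            )
            fresh.encodings.pop(filename, None)
            fresh.encoding_errors.pop(filename, None)
            fresh.save()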
2 changes: 1 addition & 1 deletion src/hope_dedup_engine/apps/faces/admin.py
@@ -4,7 +4,7 @@
 from admin_extra_buttons.mixins import ExtraButtonsMixin
 from celery import group
 
-from hope_dedup_engine.apps.faces.celery_tasks import sync_dnn_files
+from hope_dedup_engine.apps.faces.celery.tasks.dnn_files import sync_dnn_files
 from hope_dedup_engine.apps.faces.models import DummyModel
 from hope_dedup_engine.config.celery import app as celery_app
 
41 changes: 41 additions & 0 deletions src/hope_dedup_engine/apps/faces/celery/pipeline.py
@@ -0,0 +1,41 @@
from typing import Any

from celery.canvas import Signature

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.faces.celery.tasks.deduplication import (
encode_images,
find_duplicates,
get_deduplication_set_embedding_pairs,
get_deduplication_set_image_files,
save_encoding_errors_in_findings,
)
from hope_dedup_engine.utils.celery.utility_tasks import (
batched_compact_pairs,
parallelize,
)

IMAGE_ENCODING_BATCH_SIZE = 50
DUPLICATE_FINDING_BATCH_SIZE = 200


def image_pipeline(
deduplication_set: DeduplicationSet, config: dict[str, Any]
) -> Signature:
encode_images_pipeline = parallelize.si(
get_deduplication_set_image_files.s(deduplication_set.id),
encode_images.s(deduplication_set.id, config.get("encoding")),
IMAGE_ENCODING_BATCH_SIZE,
)
find_duplicates_pipeline = parallelize.si(
get_deduplication_set_embedding_pairs.s(deduplication_set.id),
find_duplicates.s(deduplication_set.id, config.get("deduplicate", {})),
DUPLICATE_FINDING_BATCH_SIZE,
splitter=batched_compact_pairs.s(),
)

return (
encode_images_pipeline
| find_duplicates_pipeline
| save_encoding_errors_in_findings.si(deduplication_set.id)
)
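
`parallelize` and `batched_compact_pairs` come from the new `utility_tasks` module, which this diff does not include. Judging from the call sites, `parallelize` takes a producer signature, a worker signature, and a batch size, splits the producer's output into batches, and runs the worker over every batch in parallel before the next link of the chain starts. The sketch below shows roughly that effect with stock Celery primitives; it is an illustration of the assumed behaviour, not the utility's actual code, and it relies on Celery prepending a signature's incoming argument to its stored ones:

    from celery import group

    def fan_out_encoding(filenames, deduplication_set_id, encoding_config):
        # One worker signature per batch of filenames.
        batches = [
            filenames[i : i + IMAGE_ENCODING_BATCH_SIZE]
            for i in range(0, len(filenames), IMAGE_ENCODING_BATCH_SIZE)
        ]
        header = [
            encode_images.s(batch, deduplication_set_id, encoding_config)
            for batch in batches
        ]
        # Chaining another signature after a group upgrades the group to a
        # chord, so the next stage starts only after every batch has finished.
        return group(header)

The `find_duplicates` stage additionally passes `batched_compact_pairs` as a splitter, presumably because its work items are embedding pairs: batching all-pairs comparisons needs a pair-aware chunker rather than a flat slice.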