Skip to content

Commit

Permalink
Fix some issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-misuk-valor committed Jan 24, 2025
1 parent 048f420 commit 068fab9
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 70 deletions.
22 changes: 14 additions & 8 deletions src/hope_dedup_engine/apps/api/deduplication/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from hope_dedup_engine.apps.api.utils.notification import send_notification
from hope_dedup_engine.apps.faces.celery.pipeline import image_pipeline
from hope_dedup_engine.config.celery import app
from hope_dedup_engine.utils.celery.task_result import Result, is_error, is_value
from hope_dedup_engine.utils.celery.task_result import (
Result,
UnexpectedResultError,
is_error,
is_value,
)


@app.task
Expand All @@ -24,18 +29,18 @@ def clear_findings(deduplication_set_id: str) -> None:


@app.task
def save_pipeline_error(result: Result, deduplication_set_id: str) -> None:
def finish(result: Result, deduplication_set_id: str) -> None:
deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
id=deduplication_set_id
)

if is_error(result):
# TODO: save error in job
pass

if is_value(result):
deduplication_set.state = DeduplicationSet.State.DIRTY
elif is_value(result):
deduplication_set.state = DeduplicationSet.State.CLEAN
deduplication_set.save(update_fields=["state"])
else:
raise UnexpectedResultError(result)
deduplication_set.save(update_fields=["state"])

send_notification(deduplication_set.notification_url)

Expand All @@ -46,10 +51,11 @@ def find_duplicates(self: Task, dedup_job_id: int, version: int) -> None:
deduplication_set = dedup_job.deduplication_set

config = asdict(DeduplicationSetConfig.from_deduplication_set(deduplication_set))

pipeline = (
clear_findings.s(deduplication_set.id)
| image_pipeline(deduplication_set, config)
| save_pipeline_error.s(deduplication_set.id)
| finish.s(deduplication_set.id)
)

return self.replace(pipeline)
24 changes: 13 additions & 11 deletions src/hope_dedup_engine/apps/api/models/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import uuid4

from django.conf import settings
from django.db import models
from django.db import models, transaction

from hope_dedup_engine.apps.security.models import ExternalSystem
from hope_dedup_engine.types import (
Expand Down Expand Up @@ -92,18 +92,20 @@ def get_ignored_pairs(self) -> list[IgnoredPair]:
) + list(self.ignoredfilenamepair_set.values_list("first", "second"))

def update_encodings(self, encodings: list[ImageEmbedding]) -> None:
fresh_self: DeduplicationSet = DeduplicationSet.objects.select_for_update().get(
pk=self.pk
)
fresh_self.encodings.update(encodings)
fresh_self.save()
with transaction.atomic():
fresh_self: DeduplicationSet = (
DeduplicationSet.objects.select_for_update().get(pk=self.pk)
)
fresh_self.encodings.update(encodings)
fresh_self.save()

def update_encoding_errors(self, errors: list[ImageEmbeddingError]) -> None:
fresh_self: DeduplicationSet = DeduplicationSet.objects.select_for_update().get(
pk=self.pk
)
fresh_self.encoding_errors.update(errors)
fresh_self.save()
with transaction.atomic():
fresh_self: DeduplicationSet = (
DeduplicationSet.objects.select_for_update().get(pk=self.pk)
)
fresh_self.encoding_errors.update(errors)
fresh_self.save()

def update_findings(
self, findings: list[tuple[EntityEmbedding, EntityEmbedding, Score]]
Expand Down
2 changes: 1 addition & 1 deletion src/hope_dedup_engine/apps/faces/celery/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def image_pipeline(
return (
encode_images_pipeline
| find_duplicates_pipeline
| save_encoding_errors_in_findings.s(deduplication_set.id)
| save_encoding_errors_in_findings.si(deduplication_set.id)
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
find_similar_faces,
)
from hope_dedup_engine.config.celery import app
from hope_dedup_engine.constants import FacialError
from hope_dedup_engine.types import EntityEmbedding, Filename
from hope_dedup_engine.utils.celery.task_result import wrapped

Expand Down Expand Up @@ -108,10 +109,10 @@ def save_encoding_errors_in_findings(deduplication_set_id: str) -> None:
pk=deduplication_set_id
)
embedding_errors = [
(reference_pk, deduplication_set.encoding_errors[filename])
(reference_pk, FacialError(deduplication_set.encoding_errors[filename]))
for reference_pk, filename in deduplication_set.image_set.values_list(
"reference_pk", "filename"
)
if filename in deduplication_set.encoding_errors
]
deduplication_set.update_encoding_errors(embedding_errors)
deduplication_set.update_finding_errors(embedding_errors)
16 changes: 2 additions & 14 deletions src/hope_dedup_engine/constants.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
from enum import Enum
from enum import IntEnum


class FacialError(Enum):
class FacialError(IntEnum):
GENERIC_ERROR = 999
NO_FACE_DETECTED = 998
MULTIPLE_FACES_DETECTED = 997
NO_FILE_FOUND = 996

@property
def code(self) -> int:
return self.value


def is_facial_error(value):
if isinstance(value, str):
return value in FacialError.__members__
if isinstance(value, int):
return value in FacialError._value2member_map_
return False
35 changes: 29 additions & 6 deletions src/hope_dedup_engine/utils/celery/task_result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from collections.abc import Callable
from functools import wraps
from typing import Any, Literal, TypedDict
from typing import Any, Literal, NotRequired, TypedDict

import celery
from celery import exceptions as celery_exceptions
Expand All @@ -18,6 +19,7 @@
IS_WRAPPED: Literal["is_wrapped"] = "is_wrapped"
MESSAGE: Literal["message"] = "message"
UNKNOWN_RESULT: Literal["Unknown result"] = "Unknown result"
PROPAGATED: Literal["propagated"] = "propagated"


# Because of Celery default JSON serializer we cannot use classes or dataclasses
Expand All @@ -32,6 +34,7 @@ class Value(Result):

class Error(Result):
    """Wrapped task result that describes a failure instead of carrying data."""

    # Text describing the failure; is_error() recognises an error result by
    # the presence of this key.
    message: str
    # Optional flag, set to True by mark_propagated(); absent on results that
    # were never marked.
    propagated: NotRequired[bool]


class UnexpectedResultError(Exception):
Expand All @@ -56,6 +59,10 @@ def is_error(a: Any) -> bool:
return is_result(a) and MESSAGE in a


def mark_propagated(error: Error) -> None:
    """Flag ``error`` as propagated.

    The dictionary is modified in place: its ``PROPAGATED`` key is set to
    ``True``. Nothing is returned.
    """
    error.update({PROPAGATED: True})


def make_value(v: Any) -> Value:
    """Wrap ``v`` in a result dictionary.

    The returned dictionary has the ``IS_WRAPPED`` key set to ``True`` and
    stores ``v`` under the ``DATA`` key.
    """
    payload: Value = {IS_WRAPPED: True, DATA: v}
    return payload

Expand Down Expand Up @@ -104,6 +111,7 @@ def inner(*args: Any, **kwargs: Any) -> Result:

if is_result(first_arg):
if is_error(first_arg):
mark_propagated(first_arg)
return first_arg

if is_value(first_arg):
Expand All @@ -130,14 +138,29 @@ def inner(*args: Any, **kwargs: Any) -> Result:
return inner


def unwrap_result(result: Any) -> Any:
    """Return the payload carried by a wrapped task result.

    Anything that is not a wrapped result is passed through unchanged.

    Args:
        result: Either a plain value or a wrapped result dictionary.

    Returns:
        ``result`` itself when it is not wrapped, otherwise the data stored
        under the ``DATA`` key of a wrapped value.

    Raises:
        Exception: If ``result`` is a wrapped error; the exception carries the
            error's stored message.
        UnexpectedResultError: If ``result`` is wrapped but is neither a value
            nor an error.
    """
    if not is_result(result):
        return result

    if is_value(result):
        return result[DATA]

    if is_error(result):
        # Report the stored error text. Raising Exception(MESSAGE) would only
        # surface the dictionary key name ("message"), hiding the real cause.
        raise Exception(result[MESSAGE])

    raise UnexpectedResultError(result)


@celery_signals.task_postrun.connect
def unwrap_results(sender=None, headers=None, body=None, **kwargs) -> None:
def fix_results_in_db(sender=None, headers=None, body=None, **kwargs) -> None:
if (task_id := kwargs.get("task_id")) and (result := kwargs.get("retval")):
if is_result(result):
result_model = TaskResult.objects.get(task_id=task_id)
if is_value(result):
result_model.result = result[DATA]
result_model.result = json.dumps(result[DATA])
elif is_error(result):
result_model.result = result[MESSAGE]
result_model.status = FAILURE
result_model.save()
if result.get(PROPAGATED):
result_model.delete()
else:
result_model.result = result[MESSAGE]
result_model.status = FAILURE
result_model.save()
44 changes: 16 additions & 28 deletions src/hope_dedup_engine/utils/celery/utility_tasks.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,37 @@
from collections.abc import Callable
from itertools import batched
from typing import Any, Mapping
from typing import Any, NoReturn

import celery
from celery import canvas

from hope_dedup_engine.config.celery import app
from hope_dedup_engine.utils.celery.task_result import wrapped
from hope_dedup_engine.utils.celery.task_result import unwrap_result, wrapped

SerializedTask = dict[str, Any]


@app.task(bind=True)
@wrapped
def map_[
T
](self: celery.Task, results: list[T], serialize_task: SerializedTask) -> list[T]:
T, P
](self: celery.Task, results: list[T], serialize_task: SerializedTask) -> list[P]:
"""Celery map/starmap/xmap cannot be used in chain"""
signature = self.app.signature(serialize_task)
signature: Callable[[T], P] = self.app.signature(serialize_task)
return list(map(signature, results))


@app.task
@wrapped
def noop(*_: Any, **__: Any) -> None:
pass


@app.task
@wrapped
def concat_lists[T](items: list[list[T]]) -> list[T]:
return sum(items, start=[])


NOOP: Mapping = dict(noop.s())


@app.task(bind=True)
@wrapped
def parallelize[
T
](
def parallelize(
self: celery.Task,
producer: SerializedTask,
task: SerializedTask,
size: int,
end_task: SerializedTask = NOOP,
) -> list[T]:
data = self.app.signature(producer)()
end_task: SerializedTask | None = None,
) -> NoReturn:
producer_signature = self.app.signature(producer)
data = unwrap_result(producer_signature())

signature: canvas.Signature = self.app.signature(task)

Expand All @@ -61,6 +46,9 @@ def parallelize[
clone = signature.clone(args)
signatures.append(clone)

chord = celery.chord(signatures, self.app.signature(end_task))
group = celery.group(signatures)

if end_task:
return self.replace(group | self.app.signature(end_task))

return self.replace(chord)
return self.replace(group)

0 comments on commit 068fab9

Please sign in to comment.