Skip to content

Commit

Permalink
Bringing up image tags functions
Browse files Browse the repository at this point in the history
  • Loading branch information
beijbom committed Apr 18, 2024
1 parent 0ddb42b commit 7d78a64
Show file tree
Hide file tree
Showing 14 changed files with 516 additions and 9 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"ncols",
"nyckel",
"pytest",
"Resizer",
"tqdm"
],
}
14 changes: 14 additions & 0 deletions docs/data_classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,22 @@

::: nyckel.LabelName

::: nyckel.ImageSampleData

::: nyckel.TextSampleData

::: nyckel.TabularSampleData

::: nyckel.TabularFieldKey

::: nyckel.TabularFieldValue

::: nyckel.ClassificationLabel

::: nyckel.ClassificationPrediction

::: nyckel.ClassificationAnnotation

::: nyckel.TagsAnnotation

::: nyckel.TagsPrediction
5 changes: 5 additions & 0 deletions docs/image_tags.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
::: nyckel.ImageTagsFunction

::: nyckel.ImageTagsSample

::: nyckel.ImageSampleData
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ nav:
- Text Classification: text_classification.md
- Tabular Classification: tabular_classification.md
- Text Tags: text_tags.md
- Image Tags: image_tags.md
- Credentials: credentials.md
- Data classes: data_classes.md
theme:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ packages = ["src/nyckel"]

[project]
name = "nyckel"
version = "0.4.0"
version = "0.4.1"
authors = [{ name = "Oscar Beijbom", email = "[email protected]" }]
description = "Python package for the Nyckel API"
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions src/nyckel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
from .functions.classification.tabular_classification import TabularClassificationFunction # noqa: F401
from .functions.classification.text_classification import TextClassificationFunction # noqa: F401
from .functions.tags.text_tags import TextTagsFunction # noqa: F401
from .functions.tags.image_tags import ImageTagsFunction # noqa: F401
310 changes: 310 additions & 0 deletions src/nyckel/functions/tags/image_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
import abc
from typing import Dict, List, Sequence, Union

from PIL import Image

from nyckel import (
ClassificationLabel,
ClassificationPrediction,
Credentials,
ImageSampleData,
ImageTagsSample,
NyckelId,
TagsAnnotation,
TagsPrediction,
)
from nyckel.functions.classification.image_classification import ImageDecoder, ImageEncoder
from nyckel.functions.classification.label_handler import ClassificationLabelHandler
from nyckel.functions.tags import tags_function_factory
from nyckel.functions.tags.tags import TagsFunctionURLHandler
from nyckel.functions.tags.tags_function_handler import TagsFunctionHandler
from nyckel.functions.tags.tags_sample_handler import TagsSampleHandler
from nyckel.functions.utils import strip_nyckel_prefix
from nyckel.image_processing import ImageResizer


class ImageTagsFunctionInterface(abc.ABC):
    """Abstract contract for a tags function whose samples are images.

    The interface groups four concerns: function-level metadata and lifecycle,
    invocation, label management and sample management. Concrete subclasses
    must implement every member declared here.
    """

    @abc.abstractmethod
    def __init__(self, function_id: str, credentials: Credentials):
        """Bind the instance to the function `function_id` using `credentials`."""

    # --- Function metadata -------------------------------------------------

    @property
    @abc.abstractmethod
    def function_id(self) -> NyckelId:
        """Id of the function this instance is bound to."""

    @property
    @abc.abstractmethod
    def sample_count(self) -> int:
        """Number of samples held by the function."""

    @property
    @abc.abstractmethod
    def label_count(self) -> int:
        """Number of labels defined on the function."""

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Name of the function."""

    # --- Lifecycle and invocation ------------------------------------------

    @classmethod
    @abc.abstractmethod
    def create(cls, name: str, credentials: Credentials) -> "ImageTagsFunction":
        """Create a new function called `name` and return a handle to it."""

    @abc.abstractmethod
    def delete(self) -> None:
        """Delete the function."""

    @abc.abstractmethod
    def invoke(self, sample_data_list: List[ImageSampleData]) -> List[TagsPrediction]:
        """Invokes the trained function. Raises ValueError if function is not trained"""

    @abc.abstractmethod
    def has_trained_model(self) -> bool:
        """Tell whether the function has a trained model."""

    # --- Labels ------------------------------------------------------------

    @abc.abstractmethod
    def create_labels(self, labels: Sequence[Union[ClassificationLabel, str]]) -> List[NyckelId]:
        """Create labels, given either as label objects or as plain label names."""

    @abc.abstractmethod
    def list_labels(self) -> List[ClassificationLabel]:
        """Return the labels of the function."""

    @abc.abstractmethod
    def read_label(self, label_id: NyckelId) -> ClassificationLabel:
        """Return the label identified by `label_id`."""

    @abc.abstractmethod
    def update_label(self, label: ClassificationLabel) -> ClassificationLabel:
        """Update `label` and return the resulting label."""

    @abc.abstractmethod
    def delete_labels(self, label_ids: List[NyckelId]) -> None:
        """Delete the labels identified by `label_ids`."""

    # --- Samples -----------------------------------------------------------

    @abc.abstractmethod
    def create_samples(self, samples: Sequence[Union[ImageTagsSample, ImageSampleData, Image.Image]]) -> List[NyckelId]:
        """Create samples from sample objects, raw image sample data or PIL images."""

    @abc.abstractmethod
    def list_samples(self) -> List[ImageTagsSample]:
        """Return the samples of the function."""

    @abc.abstractmethod
    def read_sample(self, sample_id: NyckelId) -> ImageTagsSample:
        """Return the sample identified by `sample_id`."""

    @abc.abstractmethod
    def update_annotation(self, sample: ImageTagsSample) -> None:
        """Update the annotation of `sample`."""

    @abc.abstractmethod
    def delete_samples(self, sample_ids: List[NyckelId]) -> None:
        """Delete the samples identified by `sample_ids`."""


class ImageTagsFunction(ImageTagsFunctionInterface):
    """
    Tags function that takes images as input.

    Example:

    ```py
    from nyckel import Credentials, ImageTagsFunction, ImageTagsSample, TagsAnnotation

    credentials = Credentials(client_id="...", client_secret="...")
    func = ImageTagsFunction.create("ClothingColor", credentials)
    func.create_samples([
        ImageTagsSample(data="t-shirt1.jpg", annotation=[TagsAnnotation("White"), TagsAnnotation("Blue")]),
        ImageTagsSample(data="t-shirt2.jpg", annotation=[TagsAnnotation("Red"), TagsAnnotation("White")]),
        ImageTagsSample(data="jacket.jpg", annotation=[TagsAnnotation("Black")]),
        ImageTagsSample(data="jeans.jpg", annotation=[TagsAnnotation("Blue")]),
    ])
    predictions = func.invoke(["new-jacket.jpg"])
    ```
    """

    def __init__(self, function_id: NyckelId, credentials: Credentials):
        """Bind to the existing function `function_id`.

        Fails with an AssertionError if the function's input modality is not "Image".
        """
        self._function_id = function_id

        self._function_handler = TagsFunctionHandler(function_id, credentials)
        self._label_handler = ClassificationLabelHandler(function_id, credentials)
        self._url_handler = TagsFunctionURLHandler(function_id, credentials.server_url)
        self._sample_handler = TagsSampleHandler(function_id, credentials)
        self._decoder = ImageDecoder()
        self._encoder = ImageEncoder()

        # NOTE: kept as an assert so callers see the same exception type as before;
        # be aware that the check is skipped when Python runs with -O.
        assert self._function_handler.get_input_modality() == "Image"

    @property
    def function_id(self) -> NyckelId:
        """Id of the function this instance is bound to."""
        return self._function_id

    @property
    def sample_count(self) -> int:
        """Sample count as reported by the function handler."""
        return self._function_handler.sample_count

    @property
    def label_count(self) -> int:
        """Label count as reported by the function handler."""
        return self._function_handler.label_count

    @property
    def name(self) -> str:
        """Function name as reported by the function handler."""
        return self._function_handler.get_name()

    @classmethod
    def create(cls, name: str, credentials: Credentials) -> "ImageTagsFunction":
        """Create a new function named `name` with "Image" input through the tags function factory."""
        return tags_function_factory.TagsFunctionFactory().create(name, "Image", credentials)  # type:ignore

    def delete(self) -> None:
        """Delete the function."""
        self._function_handler.delete()

    def invoke(self, sample_data_list: List[ImageSampleData]) -> List[TagsPrediction]:
        """Invoke the function on `sample_data_list`; each entry is passed through ImageSampleBodyTransformer."""
        return self._sample_handler.invoke(sample_data_list, ImageSampleBodyTransformer())  # type: ignore

    def has_trained_model(self) -> bool:
        """Tell whether the function handler reports the function as trained."""
        return self._function_handler.is_trained

    def create_labels(self, labels: Sequence[Union[ClassificationLabel, str]]) -> List[NyckelId]:  # type:ignore
        """Create labels; plain strings are wrapped as `ClassificationLabel(name=...)` first."""
        typed_labels = [
            label if isinstance(label, ClassificationLabel) else ClassificationLabel(name=label)  # type:ignore
            for label in labels
        ]
        return self._label_handler.create_labels(typed_labels)

    def list_labels(self) -> List[ClassificationLabel]:
        """Return the labels of the function."""
        return self._label_handler.list_labels(self.label_count)

    def read_label(self, label_id: NyckelId) -> ClassificationLabel:
        """Return the label identified by `label_id`."""
        return self._label_handler.read_label(label_id)

    def update_label(self, label: ClassificationLabel) -> ClassificationLabel:
        """Update `label` and return the label handler's result."""
        return self._label_handler.update_label(label)

    def delete_labels(self, label_ids: List[NyckelId]) -> None:
        """Delete the labels identified by `label_ids`."""
        return self._label_handler.delete_labels(label_ids)

    def create_samples(self, samples: Sequence[Union[ImageTagsSample, ImageSampleData, Image.Image]]) -> List[NyckelId]:
        """Create samples, creating any labels their annotations reference that do not exist yet.

        Label names on the annotations are stripped of surrounding whitespace (in place)
        before labels are looked up or created.
        """
        typed_samples = self._wrangle_post_samples_input(samples)
        typed_samples = self._strip_label_names(typed_samples)
        self._create_labels_as_needed(typed_samples)
        return self._sample_handler.create_samples(typed_samples, ImageSampleBodyTransformer())

    def _wrangle_post_samples_input(
        self, samples: Sequence[Union[ImageTagsSample, ImageSampleData, Image.Image]]
    ) -> List[ImageTagsSample]:
        """Normalize every accepted input form to an `ImageTagsSample`.

        Strings become the sample's data unchanged, PIL images are base64-encoded with the
        encoder, and `ImageTagsSample` instances pass through untouched.

        Raises:
            ValueError: if an entry is of any other type.
        """
        typed_samples: List[ImageTagsSample] = []
        for sample in samples:
            if isinstance(sample, str):
                typed_samples.append(ImageTagsSample(data=sample))
            elif isinstance(sample, Image.Image):
                typed_samples.append(ImageTagsSample(data=self._encoder.to_base64(sample)))
            elif isinstance(sample, ImageTagsSample):
                typed_samples.append(sample)
            else:
                raise ValueError(f"Unknown sample type: {type(sample)}")
        return typed_samples

    def _strip_label_names(self, samples: List[ImageTagsSample]) -> List[ImageTagsSample]:
        """Strip surrounding whitespace from every annotation's label name, in place, and return `samples`."""
        for sample in samples:
            if sample.annotation:
                for entry in sample.annotation:
                    entry.label_name = entry.label_name.strip()
        return samples

    def _create_labels_as_needed(self, samples: List[ImageTagsSample]) -> None:
        """Create every label referenced by an annotation in `samples` that does not exist yet."""
        requested_label_names = {
            annotation.label_name for sample in samples if sample.annotation for annotation in sample.annotation
        }
        if not requested_label_names:
            # Nothing is annotated, so there is no need to ask the server for the existing labels.
            return

        existing_label_names = {label.name for label in self._label_handler.list_labels(None)}
        missing_label_names = requested_label_names - existing_label_names
        if missing_label_names:
            self._label_handler.create_labels(
                [ClassificationLabel(name=label_name) for label_name in missing_label_names]
            )

    def _label_name_by_id(self) -> Dict:
        """Map label id (Nyckel prefix stripped) to label name for all labels of the function.

        The keys are stripped because `_sample_from_dict` looks labels up by
        `strip_nyckel_prefix(entry["labelId"])`.
        """
        labels = self._label_handler.list_labels(None)
        return {strip_nyckel_prefix(label.id): label.name for label in labels}  # type: ignore

    def list_samples(self) -> List[ImageTagsSample]:
        """Return the samples of the function, with label ids resolved to label names."""
        samples_dict_list = self._sample_handler.list_samples(self.sample_count)
        label_name_by_id = self._label_name_by_id()

        return [self._sample_from_dict(entry, label_name_by_id) for entry in samples_dict_list]  # type: ignore

    def _sample_from_dict(self, sample_dict: Dict, label_name_by_id: Dict) -> ImageTagsSample:
        """Build an `ImageTagsSample` from a sample dict as returned by the sample handler.

        Args:
            sample_dict: dict with "id" and "data", and optionally "externalId",
                "annotation" and "prediction".
            label_name_by_id: label names keyed by prefix-stripped label id.
        """
        if "annotation" in sample_dict:
            annotation = [
                TagsAnnotation(
                    label_name=label_name_by_id[strip_nyckel_prefix(entry["labelId"])],
                    present=entry["present"],
                )
                for entry in sample_dict["annotation"]
            ]
        else:
            annotation = None

        if "prediction" in sample_dict:
            prediction = [
                ClassificationPrediction(
                    confidence=entry["confidence"],
                    label_name=label_name_by_id[strip_nyckel_prefix(entry["labelId"])],
                )
                for entry in sample_dict["prediction"]
            ]
        else:
            prediction = None

        return ImageTagsSample(
            id=strip_nyckel_prefix(sample_dict["id"]),
            data=sample_dict["data"],
            external_id=sample_dict.get("externalId"),
            annotation=annotation,
            prediction=prediction,
        )

    def read_sample(self, sample_id: NyckelId) -> ImageTagsSample:
        """Return the sample identified by `sample_id`, with label ids resolved to label names."""
        sample_dict = self._sample_handler.read_sample(sample_id)
        label_name_by_id = self._label_name_by_id()

        return self._sample_from_dict(sample_dict, label_name_by_id)  # type: ignore

    def update_annotation(self, sample: ImageTagsSample) -> None:
        """Update the annotation of `sample`."""
        self._sample_handler.update_annotation(sample)

    def delete_samples(self, sample_ids: List[NyckelId]) -> None:
        """Delete the samples identified by `sample_ids`."""
        self._sample_handler.delete_samples(sample_ids)


class ImageSampleBodyTransformer:
    """Callable that turns image sample data into the string sent in a request body."""

    def __init__(self):
        self._decoder = ImageDecoder()
        self._encoder = ImageEncoder()
        self._resizer = ImageResizer()

    def __call__(self, sample_data: ImageSampleData) -> str:
        """Resizes if needed and encodes the sample data as a URL or dataURI.

        Raises:
            ValueError: if `sample_data` is not recognized as a URL, a local file path
                or a data URI.
        """
        # Images already stored in a Nyckel S3 bucket are processed and verified,
        # so the URL itself is handed back untouched.
        if self._is_nyckel_owned_url(sample_data):
            return sample_data

        decoder = self._decoder
        is_decodable = (
            decoder.looks_like_url(sample_data)
            or decoder.looks_like_local_filepath(sample_data)
            or decoder.looks_like_data_uri(sample_data)
        )
        if not is_decodable:
            raise ValueError(f"Can't parse input sample.data={sample_data}")

        # All recognized input forms share the same decode -> resize -> encode pipeline.
        image = decoder.to_image(sample_data)
        resized = self._resizer(image)
        return self._encoder.to_base64(resized)

    def _is_nyckel_owned_url(self, sample_data: str) -> bool:
        """Tell whether `sample_data` starts with the Nyckel S3 bucket URL prefix."""
        nyckel_bucket_prefix = "https://s3.us-west-2.amazonaws.com/nyckel.server."
        return sample_data.startswith(nyckel_bucket_prefix)
4 changes: 2 additions & 2 deletions src/nyckel/functions/tags/tags_function_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

from nyckel import Credentials
from nyckel.functions.tags import text_tags
from nyckel.functions.tags import image_tags, text_tags
from nyckel.functions.tags.tags_function_handler import TagsFunctionHandler
from nyckel.functions.utils import strip_nyckel_prefix

Expand All @@ -11,7 +11,7 @@ class TagsFunctionFactory:
def __init__(self) -> None:
self.function_type_by_input = {
"Text": text_tags.TextTagsFunction,
# "Image": image_classification.ImageClassificationFunction,
"Image": image_tags.ImageTagsFunction,
# "Tabular": tabular_classification.TabularClassificationFunction,
}

Expand Down
Loading

0 comments on commit 7d78a64

Please sign in to comment.