diff --git a/python/pyproject.toml b/python/pyproject.toml index f8c55e8f02..e2c5b96361 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -116,7 +116,8 @@ mlflow = [ "mlflow-skinny" ] image = [ - "Pillow" + "Pillow", + "numpy" ] fugue = [ "fugue", diff --git a/python/tests/extras/test_image_metric.py b/python/tests/extras/test_image_metric.py index ecf5e2fe5b..1d1db36b9f 100644 --- a/python/tests/extras/test_image_metric.py +++ b/python/tests/extras/test_image_metric.py @@ -10,11 +10,17 @@ from whylogs.core.preprocessing import ListView, PreprocessedColumn from whylogs.core.resolvers import Resolver from whylogs.core.schema import ColumnSchema, DatasetSchema -from whylogs.extras.image_metric import ImageMetric, ImageMetricConfig, log_image +from whylogs.extras.image_metric import ( + ImageMetric, + ImageMetricConfig, + init_image_schema, + log_image, +) logger = logging.getLogger(__name__) try: + import numpy as np from PIL.Image import Image as ImageType except ImportError as e: ImageType = None @@ -80,6 +86,20 @@ def test_image_metric() -> None: assert "ints" not in metric.submetrics["Software"] +def test_log_np_image() -> None: + image_path = os.path.join(TEST_DATA_PATH, "images", "flower2.jpg") + img = np.array(image_loader(image_path)) + + schema = init_image_schema() + profile = why.log({"image": img}, schema=schema) + df = profile.profile().view().to_pandas() + + # Ensure a few columns are in the data frame from the image metric + assert "image/Brightness.mean:cardinality/est" in df.columns + assert "image/Brightness.mean:cardinality/lower_1" in df.columns + assert "image/entropy:types/tensor" in df.columns + + def test_allowed_exif_tags() -> None: image_path = os.path.join(TEST_DATA_PATH, "images", "flower2.jpg") img = image_loader(image_path) diff --git a/python/whylogs/extras/image_metric.py b/python/whylogs/extras/image_metric.py index 8dfceed174..58e918750d 100644 --- a/python/whylogs/extras/image_metric.py +++ b/python/whylogs/extras/image_metric.py @@ -29,6 +29,8 @@ logger = logging.getLogger(__name__) try: + import numpy as np # type: ignore + from PIL import Image from PIL.Image import Image as ImageType # type: ignore from PIL.ImageStat import Stat # type: ignore from PIL.TiffImagePlugin import IFDRational # type: ignore @@ -211,6 +213,9 @@ def _update_relevant_submetrics(self, name: str, data: PreprocessedColumn) -> No def columnar_update(self, view: PreprocessedColumn) -> OperationResult: count = 0 for image in list(chain.from_iterable(view.raw_iterator())): + if isinstance(image, np.ndarray): + image = Image.fromarray(image.astype(np.uint8)) + if isinstance(image, ImageType): metadata = get_pil_exif_metadata(image) for name, value in metadata.items(): @@ -244,6 +249,24 @@ def zero(cls, config: Optional[MetricConfig] = None) -> "ImageMetric": ) +def init_image_schema(column_prefix: str = "image") -> DatasetSchema: + """ + Initialize a DatasetSchema for logging images. This can be passed into a logger or why.log. + + Args: + column_prefix (str): The prefix that appears in the dataset profiles along with all of the + image features. If the prefix is "image", then you'll log image with why.log({image: image_data}). + """ + + class ImageResolver(Resolver): + def resolve(self, name: str, why_type: DataType, column_schema: ColumnSchema) -> Dict[str, Metric]: + return {ImageMetric.get_namespace(): ImageMetric.zero(column_schema.cfg)} + + return DatasetSchema( + types={column_prefix: Image.Image}, default_configs=ImageMetricConfig(), resolvers=ImageResolver() + ) + + def log_image( images: Union[ImageType, List[ImageType], Dict[str, ImageType]], default_column_prefix: str = "image",