Skip to content

Commit

Permalink
Add UT for multimedia_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Oct 23, 2023
1 parent 42be774 commit 830067b
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 7 deletions.
7 changes: 2 additions & 5 deletions src/promptflow/promptflow/_utils/multimedia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _create_image_from_dict(image_dict: dict):
for k, v in image_dict.items():
format, resource = _get_multimedia_info(k)
if resource == "path":
return _create_image_from_file(v, mime_type=f"image/{format}")
return _create_image_from_file(Path(v), mime_type=f"image/{format}")
elif resource == "base64":
return _create_image_from_base64(v, mime_type=f"image/{format}")
elif resource == "url":
Expand Down Expand Up @@ -176,10 +176,7 @@ def persist_multimedia_data(value: Any, base_dir: Path, sub_dir: Path = None):


def convert_multimedia_data_to_base64(value: Any, with_type=False):
func = (
lambda x: f"data:{x._mime_type};base64," + PFBytes.to_base64(x) if with_type else PFBytes.to_base64
) # noqa: E731
to_base64_funcs = {PFBytes: func}
to_base64_funcs = {PFBytes: partial(PFBytes.to_base64, **{"with_type": with_type})}
return recursive_process(value, process_funcs=to_base64_funcs)


Expand Down
5 changes: 3 additions & 2 deletions src/promptflow/promptflow/contracts/multimedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def __init__(self, data: bytes, mime_type: str):
self._hash = hashlib.sha1(data).hexdigest()[:8]
self._mime_type = mime_type.lower()

def to_base64(self):
def to_base64(self, with_type: bool = False):
"""Returns the base64 representation of the PFBytes."""

if with_type:
return f"data:{self._mime_type};base64," + base64.b64encode(self).decode("utf-8")
return base64.b64encode(self).decode("utf-8")


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import pytest
import re
from pathlib import Path
from unittest.mock import mock_open

from promptflow._utils.multimedia_utils import (
_create_image_from_file,
convert_multimedia_data_to_base64,
create_image,
load_multimedia_data,
persist_multimedia_data,
)
from promptflow.contracts._errors import InvalidImageInput
from promptflow.contracts.flow import FlowInputDefinition
from promptflow.contracts.tool import ValueType

from ...utils import DATA_ROOT

TEST_IMAGE_PATH = DATA_ROOT / "test_image.jpg"


@pytest.mark.unittest
class TestMultimediaUtils:
def test_create_image(self, mocker):
# Test create image from dict
## From path
image_dict = {"data:image/jpg;path": TEST_IMAGE_PATH}
image_from_path = create_image(image_dict)
assert image_from_path._mime_type == "image/jpg"

## From base64
image_dict = {"data:image/jpg;base64": image_from_path.to_base64()}
image_from_base64 = create_image(image_dict)
assert str(image_from_path) == str(image_from_base64)
assert image_from_base64._mime_type == "image/jpg"

## From url
mocker.patch("requests.get", return_value=mocker.Mock(content=image_from_path, status_code=200))
image_dict = {"data:image/jpg;url": ""}
image_from_url = create_image(image_dict)
assert str(image_from_path) == str(image_from_url)
assert image_from_url._mime_type == "image/jpg"

mocker.patch("requests.get", return_value=mocker.Mock(content=None, status_code=404))
with pytest.raises(InvalidImageInput) as ex:
create_image(image_dict)
assert "Error while fetching image from URL" in ex.value.message_format

# Test create image from string
## From path
image_from_path = create_image(str(TEST_IMAGE_PATH))
assert image_from_path._mime_type == "image/jpg"

# From base64
image_from_base64 = create_image(image_from_path.to_base64())
assert str(image_from_path) == str(image_from_base64)
assert image_from_base64._mime_type in ["image/jpg", "image/jpeg"]

## From url
mocker.patch("promptflow._utils.multimedia_utils._is_url", return_value=True)
mocker.patch("promptflow._utils.multimedia_utils._is_base64", return_value=False)
mocker.patch("requests.get", return_value=mocker.Mock(content=image_from_path, status_code=200))
image_from_url = create_image("")
assert str(image_from_path) == str(image_from_url)
assert image_from_url._mime_type in ["image/jpg", "image/jpeg"]

## From image
image_from_image = create_image(image_from_path)
assert str(image_from_path) == str(image_from_image)

# Test invalid input type
with pytest.raises(InvalidImageInput) as ex:
create_image(0)
assert "Unsupported image input type" in ex.value.message_format

# Test invalid image dict
with pytest.raises(InvalidImageInput) as ex:
invalid_image_dict = {"invalid_image": "invalid_image"}
create_image(invalid_image_dict)
assert "Invalid image input format" in ex.value.message_format

def test_persist_multimedia_date(self, mocker):
image = _create_image_from_file(TEST_IMAGE_PATH)
mocker.patch('builtins.open', mock_open())
data = {"image": image, "images": [image, image, "other_data"], "other_data": "other_data"}
persisted_data = persist_multimedia_data(data, base_dir=Path(__file__).parent)
file_name = re.compile(r"^[0-9a-z]{8}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{12}.jpg$")
assert re.match(file_name, persisted_data["image"]["data:image/jpg;path"])
assert re.match(file_name, persisted_data["images"][0]["data:image/jpg;path"])
assert re.match(file_name, persisted_data["images"][1]["data:image/jpg;path"])

def test_convert_multimedia_date_to_base64(self):
image = _create_image_from_file(TEST_IMAGE_PATH)
data = {"image": image, "images": [image, image, "other_data"], "other_data": "other_data"}
base64_data = convert_multimedia_data_to_base64(data)
assert base64_data == {
"image": image.to_base64(),
"images": [image.to_base64(), image.to_base64(), "other_data"],
"other_data": "other_data",
}

base64_data = convert_multimedia_data_to_base64(data, with_type=True)
prefix = f"data:{image._mime_type};base64,"
assert base64_data == {
"image": prefix + image.to_base64(),
"images": [prefix + image.to_base64(), prefix + image.to_base64(), "other_data"],
"other_data": "other_data",
}

def test_load_multimedia_data(self):
inputs = {
"image": FlowInputDefinition(type=ValueType.IMAGE),
"images": FlowInputDefinition(type=ValueType.LIST),
"object": FlowInputDefinition(type=ValueType.OBJECT),
}
line_inputs = {
"image": {"data:image/jpg;path": str(TEST_IMAGE_PATH)},
"images": [{"data:image/jpg;path": str(TEST_IMAGE_PATH)}, {"data:image/jpg;path": str(TEST_IMAGE_PATH)}],
"object": {"image": {"data:image/jpg;path": str(TEST_IMAGE_PATH)}, "other_data": "other_data"}
}
updated_inputs = load_multimedia_data(inputs, line_inputs)
image = _create_image_from_file(TEST_IMAGE_PATH)
assert updated_inputs == {
"image": image,
"images": [image, image],
"object": {"image": image, "other_data": "other_data"}
}

0 comments on commit 830067b

Please sign in to comment.