Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UT for multimedia_utils #849

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/promptflow/promptflow/_utils/multimedia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _create_image_from_dict(image_dict: dict):
for k, v in image_dict.items():
format, resource = _get_multimedia_info(k)
if resource == "path":
return _create_image_from_file(v, mime_type=f"image/{format}")
return _create_image_from_file(Path(v), mime_type=f"image/{format}")
elif resource == "base64":
return _create_image_from_base64(v, mime_type=f"image/{format}")
elif resource == "url":
Expand Down Expand Up @@ -176,10 +176,7 @@ def persist_multimedia_data(value: Any, base_dir: Path, sub_dir: Path = None):


def convert_multimedia_data_to_base64(value: Any, with_type=False):
    """Replace every PFBytes found in ``value`` with its base64 string.

    ``value`` is handed to ``recursive_process`` together with a handler that
    is applied to ``PFBytes`` instances; per the unit tests for this helper,
    nested dicts/lists are traversed and non-PFBytes entries come back
    unchanged.

    :param value: A PFBytes, or a container that may hold PFBytes values.
    :param with_type: When true, each encoded string is prefixed with
        ``data:<mime_type>;base64,``; otherwise only the raw base64 text is
        produced.
    :return: Whatever ``recursive_process`` returns for ``value``.
    """
    # Bind ``with_type`` once so the handler keeps the one-argument shape that
    # ``recursive_process`` is given for each PFBytes. The previous lambda
    # returned the ``PFBytes.to_base64`` function object itself (instead of
    # calling it) when ``with_type`` was false, and its result was overwritten
    # anyway, so it has been dropped.
    to_base64_funcs = {PFBytes: partial(PFBytes.to_base64, with_type=with_type)}
    return recursive_process(value, process_funcs=to_base64_funcs)


Expand Down
5 changes: 3 additions & 2 deletions src/promptflow/promptflow/contracts/multimedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def __init__(self, data: bytes, mime_type: str):
self._hash = hashlib.sha1(data).hexdigest()[:8]
self._mime_type = mime_type.lower()

def to_base64(self, with_type: bool = False):
    """Returns the base64 representation of the PFBytes.

    :param with_type: When true, prefix the encoded text with
        ``data:<mime_type>;base64,`` built from ``self._mime_type``.
        Defaults to False, which matches the previous no-argument behaviour.
    :return: The base64 text decoded as UTF-8, optionally with the prefix.
    """
    # Encode once; both return paths share the same payload.
    encoded = base64.b64encode(self).decode("utf-8")
    if with_type:
        return f"data:{self._mime_type};base64,{encoded}"
    return encoded


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import pytest
import re
from pathlib import Path
from unittest.mock import mock_open

from promptflow._utils.multimedia_utils import (
_create_image_from_file,
convert_multimedia_data_to_base64,
create_image,
load_multimedia_data,
persist_multimedia_data,
)
from promptflow.contracts._errors import InvalidImageInput
from promptflow.contracts.flow import FlowInputDefinition
from promptflow.contracts.tool import ValueType

from ...utils import DATA_ROOT

# Sample image (test_image.jpg under DATA_ROOT) shared by the tests in this module.
TEST_IMAGE_PATH = DATA_ROOT / "test_image.jpg"


@pytest.mark.unittest
class TestMultimediaUtils:
    """Unit tests for the image helpers in ``promptflow._utils.multimedia_utils``.

    The tests load the sample image at ``TEST_IMAGE_PATH`` and exercise image
    creation from the supported input shapes, persistence, base64 conversion
    and loading of flow inputs.
    """

    def test_create_image_with_dict(self, mocker):
        """Create images from ``{"data:image/jpg;<resource>": value}`` dicts."""
        # From path
        image_dict = {"data:image/jpg;path": TEST_IMAGE_PATH}
        image_from_path = create_image(image_dict)
        assert image_from_path._mime_type == "image/jpg"

        # From base64
        image_dict = {"data:image/jpg;base64": image_from_path.to_base64()}
        image_from_base64 = create_image(image_dict)
        assert str(image_from_path) == str(image_from_base64)
        assert image_from_base64._mime_type == "image/jpg"

        # From url: requests.get is mocked, so no network access happens.
        mocker.patch("requests.get", return_value=mocker.Mock(content=image_from_path, status_code=200))
        image_dict = {"data:image/jpg;url": ""}
        image_from_url = create_image(image_dict)
        assert str(image_from_path) == str(image_from_url)
        assert image_from_url._mime_type == "image/jpg"

        # A non-200 response must surface as InvalidImageInput.
        mocker.patch("requests.get", return_value=mocker.Mock(content=None, status_code=404))
        with pytest.raises(InvalidImageInput) as ex:
            create_image(image_dict)
        # Checked after the block: code following the raising call inside
        # ``pytest.raises`` would never execute.
        assert "Error while fetching image from URL" in ex.value.message_format

    def test_create_image_with_string(self, mocker):
        """Create images from a path string, a base64 string, a URL and an image."""
        # From path
        image_from_path = create_image(str(TEST_IMAGE_PATH))
        assert image_from_path._mime_type == "image/jpg"

        # From base64
        image_from_base64 = create_image(image_from_path.to_base64())
        assert str(image_from_path) == str(image_from_base64)
        assert image_from_base64._mime_type in ["image/jpg", "image/jpeg"]

        # From url: force the URL branch and stub the HTTP call.
        mocker.patch("promptflow._utils.multimedia_utils._is_url", return_value=True)
        mocker.patch("promptflow._utils.multimedia_utils._is_base64", return_value=False)
        mocker.patch("requests.get", return_value=mocker.Mock(content=image_from_path, status_code=200))
        image_from_url = create_image("")
        assert str(image_from_path) == str(image_from_url)
        assert image_from_url._mime_type in ["image/jpg", "image/jpeg"]

        # From image
        image_from_image = create_image(image_from_path)
        assert str(image_from_path) == str(image_from_image)

    def test_create_image_with_invalid_cases(self):
        """Unsupported input types and malformed dicts raise InvalidImageInput."""
        # Test invalid input type
        with pytest.raises(InvalidImageInput) as ex:
            create_image(0)
        assert "Unsupported image input type" in ex.value.message_format

        # Test invalid image dict
        invalid_image_dict = {"invalid_image": "invalid_image"}
        with pytest.raises(InvalidImageInput) as ex:
            create_image(invalid_image_dict)
        assert "Invalid image input format" in ex.value.message_format

    def test_persist_multimedia_date(self, mocker):
        """Persisted images are replaced by ``data:image/jpg;path`` references."""
        # Load the image before patching ``open``; presumably the loader reads
        # the file through it, so the order matters.
        image = _create_image_from_file(TEST_IMAGE_PATH)
        mocker.patch("builtins.open", mock_open())
        data = {"image": image, "images": [image, image, "other_data"], "other_data": "other_data"}
        persisted_data = persist_multimedia_data(data, base_dir=Path(__file__).parent)
        # The dot before the extension is escaped so that only a literal
        # ".jpg" suffix is accepted (a bare "." would match any character).
        file_name = re.compile(r"^[0-9a-z]{8}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{4}-[0-9a-z]{12}\.jpg$")
        assert file_name.match(persisted_data["image"]["data:image/jpg;path"])
        assert file_name.match(persisted_data["images"][0]["data:image/jpg;path"])
        assert file_name.match(persisted_data["images"][1]["data:image/jpg;path"])

    def test_convert_multimedia_date_to_base64(self):
        """Images nested in dicts/lists become base64 strings; other data is kept."""
        image = _create_image_from_file(TEST_IMAGE_PATH)
        data = {"image": image, "images": [image, image, "other_data"], "other_data": "other_data"}

        # Without the type prefix.
        base64_data = convert_multimedia_data_to_base64(data)
        assert base64_data == {
            "image": image.to_base64(),
            "images": [image.to_base64(), image.to_base64(), "other_data"],
            "other_data": "other_data",
        }

        # With the "data:<mime>;base64," prefix.
        base64_data = convert_multimedia_data_to_base64(data, with_type=True)
        prefix = f"data:{image._mime_type};base64,"
        assert base64_data == {
            "image": prefix + image.to_base64(),
            "images": [prefix + image.to_base64(), prefix + image.to_base64(), "other_data"],
            "other_data": "other_data",
        }

    def test_load_multimedia_data(self):
        """Image dicts in IMAGE, LIST and OBJECT inputs are loaded as images."""
        inputs = {
            "image": FlowInputDefinition(type=ValueType.IMAGE),
            "images": FlowInputDefinition(type=ValueType.LIST),
            "object": FlowInputDefinition(type=ValueType.OBJECT),
        }
        line_inputs = {
            "image": {"data:image/jpg;path": str(TEST_IMAGE_PATH)},
            "images": [{"data:image/jpg;path": str(TEST_IMAGE_PATH)}, {"data:image/jpg;path": str(TEST_IMAGE_PATH)}],
            "object": {"image": {"data:image/jpg;path": str(TEST_IMAGE_PATH)}, "other_data": "other_data"},
        }
        updated_inputs = load_multimedia_data(inputs, line_inputs)
        image = _create_image_from_file(TEST_IMAGE_PATH)
        assert updated_inputs == {
            "image": image,
            "images": [image, image],
            "object": {"image": image, "other_data": "other_data"},
        }
Loading