|
| 1 | +import os |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | +from promptflow._utils.multimedia_utils import _create_image_from_file, _is_multimedia_dict |
| 7 | +from promptflow.contracts.multimedia import Image |
| 8 | +from promptflow.contracts.run_info import Status |
| 9 | +from promptflow.executor import FlowExecutor |
| 10 | +from promptflow.storage._run_storage import DefaultRunStorage |
| 11 | + |
| 12 | +from ..utils import FLOW_ROOT, get_yaml_file, get_yaml_working_dir |
| 13 | + |
| 14 | +SIMPLE_IMAGE_FLOW = "python_tool_with_simple_image" |
| 15 | +SIMPLE_IMAGE_FLOW_PATH = FLOW_ROOT / "python_tool_with_simple_image" |
| 16 | +COMPOSITE_IMAGE_FLOW = "python_tool_with_composite_image" |
| 17 | +COMPOSITE_IMAGE_FLOW_PATH = FLOW_ROOT / "python_tool_with_composite_image" |
| 18 | +IMAGE_URL = ( |
| 19 | + "https://github.com/microsoft/promptflow/blob/93776a0631abf991896ab07d294f62082d5df3f3/src" |
| 20 | + "/promptflow/tests/test_configs/datas/test_image.jpg?raw=true" |
| 21 | +) |
| 22 | + |
| 23 | + |
| 24 | +def get_test_cases_for_simple_input(): |
| 25 | + image = _create_image_from_file(SIMPLE_IMAGE_FLOW_PATH / "logo.jpg") |
| 26 | + inputs = [ |
| 27 | + {"data:image/jpg;path": str(SIMPLE_IMAGE_FLOW_PATH / "logo.jpg")}, |
| 28 | + {"data:image/jpg;base64": image.to_base64()}, |
| 29 | + {"data:image/jpg;url": IMAGE_URL}, |
| 30 | + str(SIMPLE_IMAGE_FLOW_PATH / "logo.jpg"), |
| 31 | + image.to_base64(), |
| 32 | + IMAGE_URL, |
| 33 | + ] |
| 34 | + return [(SIMPLE_IMAGE_FLOW, {"image": input}) for input in inputs] |
| 35 | + |
| 36 | + |
| 37 | +def get_test_cases_for_composite_input(): |
| 38 | + image_1 = _create_image_from_file(COMPOSITE_IMAGE_FLOW_PATH / "logo.jpg") |
| 39 | + image_2 = _create_image_from_file(COMPOSITE_IMAGE_FLOW_PATH / "logo_2.png") |
| 40 | + inputs = [ |
| 41 | + [ |
| 42 | + {"data:image/jpg;path": str(COMPOSITE_IMAGE_FLOW_PATH / "logo.jpg")}, |
| 43 | + {"data:image/png;path": str(COMPOSITE_IMAGE_FLOW_PATH / "logo_2.png")} |
| 44 | + ], |
| 45 | + [{"data:image/jpg;base64": image_1.to_base64()}, {"data:image/png;base64": image_2.to_base64()}], |
| 46 | + [{"data:image/jpg;url": IMAGE_URL}, {"data:image/png;url": IMAGE_URL}], |
| 47 | + ] |
| 48 | + return [ |
| 49 | + (COMPOSITE_IMAGE_FLOW, {"image_list": input, "image_dict": {"image_1": input[0], "image_2": input[1]}}) |
| 50 | + for input in inputs |
| 51 | + ] |
| 52 | + |
| 53 | + |
| 54 | +def get_test_cases_for_node_run(): |
| 55 | + image = {"data:image/jpg;path": str(SIMPLE_IMAGE_FLOW_PATH / "logo.jpg")} |
| 56 | + simple_image_input = {"image": image} |
| 57 | + image_list = [{"data:image/jpg;path": "logo.jpg"}, {"data:image/png;path": "logo_2.png"}] |
| 58 | + image_dict = { |
| 59 | + "image_dict": { |
| 60 | + "image_1": {"data:image/jpg;path": "logo.jpg"}, |
| 61 | + "image_2": {"data:image/png;path": "logo_2.png"}, |
| 62 | + } |
| 63 | + } |
| 64 | + composite_image_input = {"image_list": image_list, "image_dcit": image_dict} |
| 65 | + |
| 66 | + return [ |
| 67 | + (SIMPLE_IMAGE_FLOW, "python_node", simple_image_input, None), |
| 68 | + (SIMPLE_IMAGE_FLOW, "python_node_2", simple_image_input, {"python_node": image}), |
| 69 | + (COMPOSITE_IMAGE_FLOW, "python_node", composite_image_input, None), |
| 70 | + (COMPOSITE_IMAGE_FLOW, "python_node_2", composite_image_input, None), |
| 71 | + ( |
| 72 | + COMPOSITE_IMAGE_FLOW, "python_node_3", composite_image_input, |
| 73 | + {"python_node": image_list, "python_node_2": image_dict} |
| 74 | + ), |
| 75 | + ] |
| 76 | + |
| 77 | + |
| 78 | +def assert_contain_image_reference(value): |
| 79 | + assert not isinstance(value, Image) |
| 80 | + if isinstance(value, list): |
| 81 | + for item in value: |
| 82 | + assert_contain_image_reference(item) |
| 83 | + elif isinstance(value, dict): |
| 84 | + if _is_multimedia_dict(value): |
| 85 | + path = list(value.values())[0] |
| 86 | + assert isinstance(path, str) |
| 87 | + assert path.endswith(".jpg") or path.endswith(".jpeg") or path.endswith(".png") |
| 88 | + else: |
| 89 | + for _, v in value.items(): |
| 90 | + assert_contain_image_reference(v) |
| 91 | + |
| 92 | + |
| 93 | +def assert_contain_image_object(value): |
| 94 | + if isinstance(value, list): |
| 95 | + for item in value: |
| 96 | + assert_contain_image_object(item) |
| 97 | + elif isinstance(value, dict): |
| 98 | + assert not _is_multimedia_dict(value) |
| 99 | + for _, v in value.items(): |
| 100 | + assert_contain_image_object(v) |
| 101 | + else: |
| 102 | + assert isinstance(value, Image) |
| 103 | + |
| 104 | + |
| 105 | +@pytest.mark.usefixtures("dev_connections") |
| 106 | +@pytest.mark.e2etest |
| 107 | +class TestExecutorWithImage: |
| 108 | + @pytest.mark.parametrize( |
| 109 | + "flow_folder, inputs", get_test_cases_for_simple_input() + get_test_cases_for_composite_input() |
| 110 | + ) |
| 111 | + def test_executor_exec_line_with_image(self, flow_folder, inputs, dev_connections): |
| 112 | + working_dir = get_yaml_working_dir(flow_folder) |
| 113 | + os.chdir(working_dir) |
| 114 | + storage = DefaultRunStorage(base_dir=working_dir, sub_dir=Path("./temp")) |
| 115 | + executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections, storage=storage) |
| 116 | + flow_result = executor.exec_line(inputs) |
| 117 | + assert isinstance(flow_result.output, dict) |
| 118 | + assert_contain_image_object(flow_result.output) |
| 119 | + assert flow_result.run_info.status == Status.Completed |
| 120 | + assert_contain_image_reference(flow_result.run_info) |
| 121 | + for _, node_run_info in flow_result.node_run_infos.items(): |
| 122 | + assert node_run_info.status == Status.Completed |
| 123 | + assert_contain_image_reference(node_run_info) |
| 124 | + |
| 125 | + @pytest.mark.parametrize( |
| 126 | + "flow_folder, node_name, flow_inputs, dependency_nodes_outputs", get_test_cases_for_node_run() |
| 127 | + ) |
| 128 | + def test_executor_exec_node_with_image(self, flow_folder, node_name, flow_inputs, dependency_nodes_outputs, |
| 129 | + dev_connections): |
| 130 | + working_dir = get_yaml_working_dir(flow_folder) |
| 131 | + os.chdir(working_dir) |
| 132 | + run_info = FlowExecutor.load_and_exec_node( |
| 133 | + get_yaml_file(flow_folder), |
| 134 | + node_name, |
| 135 | + flow_inputs=flow_inputs, |
| 136 | + dependency_nodes_outputs=dependency_nodes_outputs, |
| 137 | + connections=dev_connections, |
| 138 | + output_sub_dir=("./temp"), |
| 139 | + raise_ex=True, |
| 140 | + ) |
| 141 | + assert run_info.status == Status.Completed |
| 142 | + assert_contain_image_reference(run_info) |
0 commit comments