diff --git a/src/promptflow/promptflow/_utils/multimedia_utils.py b/src/promptflow/promptflow/_utils/multimedia_utils.py index fd059b07345..21ee4ebe1de 100644 --- a/src/promptflow/promptflow/_utils/multimedia_utils.py +++ b/src/promptflow/promptflow/_utils/multimedia_utils.py @@ -19,19 +19,19 @@ MIME_PATTERN = re.compile(r"^data:image/(.*);(path|base64|url)$") -def get_mime_type_from_path(path: Path): +def _get_mime_type_from_path(path: Path): ext = path.suffix[1:] return f"image/{ext}" if ext else "image/*" -def get_extension_from_mime_type(mime_type: str): +def _get_extension_from_mime_type(mime_type: str): ext = mime_type.split("/")[-1] if ext == "*": return None return ext -def is_multimedia_dict(multimedia_dict: dict): +def _is_multimedia_dict(multimedia_dict: dict): if len(multimedia_dict) != 1: return False key = list(multimedia_dict.keys())[0] @@ -40,14 +40,14 @@ def is_multimedia_dict(multimedia_dict: dict): return False -def get_multimedia_info(key: str): +def _get_multimedia_info(key: str): match = re.match(MIME_PATTERN, key) if match: return match.group(1), match.group(2) return None, None -def is_url(value: str): +def _is_url(value: str): try: result = urlparse(value) return all([result.scheme, result.netloc]) @@ -55,21 +55,21 @@ def is_url(value: str): return False -def is_base64(value: str): +def _is_base64(value: str): base64_regex = re.compile(r"^([A-Za-z0-9+/]{4})*(([A-Za-z0-9+/]{2})*(==|[A-Za-z0-9+/]=)?)?$") if re.match(base64_regex, value): return True return False -def create_image_from_file(f: Path, mime_type: str = None): +def _create_image_from_file(f: Path, mime_type: str = None): if not mime_type: - mime_type = get_mime_type_from_path(f) + mime_type = _get_mime_type_from_path(f) with open(f, "rb") as fin: return Image(fin.read(), mime_type=mime_type) -def create_image_from_base64(base64_str: str, mime_type: str = None): +def _create_image_from_base64(base64_str: str, mime_type: str = None): image_bytes = base64.b64decode(base64_str) if not mime_type: format = imghdr.what(None, image_bytes) @@ -77,7 +77,7 @@ def create_image_from_base64(base64_str: str, mime_type: str = None): return Image(image_bytes, mime_type=mime_type) -def create_image_from_url(url: str, mime_type: str = None): +def _create_image_from_url(url: str, mime_type: str = None): response = requests.get(url) if response.status_code == 200: if not mime_type: @@ -92,15 +92,15 @@ def create_image_from_url(url: str, mime_type: str = None): ) -def create_image_from_dict(image_dict: dict): +def _create_image_from_dict(image_dict: dict): for k, v in image_dict.items(): - format, resource = get_multimedia_info(k) + format, resource = _get_multimedia_info(k) if resource == "path": - return create_image_from_file(v, mime_type=f"image/{format}") + return _create_image_from_file(v, mime_type=f"image/{format}") elif resource == "base64": - return create_image_from_base64(v, mime_type=f"image/{format}") + return _create_image_from_base64(v, mime_type=f"image/{format}") elif resource == "url": - return create_image_from_url(v, mime_type=f"image/{format}") + return _create_image_from_url(v, mime_type=f"image/{format}") else: raise InvalidImageInput( message_format=f"Unsupported image resource: {resource}. " @@ -109,24 +109,21 @@ def create_image_from_dict(image_dict: dict): ) -def create_image_from_string(value: str, base_dir: Path = None): - if is_base64(value): - return create_image_from_base64(value) - elif is_url(value): - return create_image_from_url(value) +def _create_image_from_string(value: str): + if _is_base64(value): + return _create_image_from_base64(value) + elif _is_url(value): + return _create_image_from_url(value) else: - path = Path(value) - if base_dir and not path.is_absolute(): - path = Path.joinpath(base_dir, path) - return create_image_from_file(path) + return _create_image_from_file(Path(value)) -def create_image(value: any, base_dir: Path = None): +def create_image(value: any): if isinstance(value, PFBytes): return value elif isinstance(value, dict): - if is_multimedia_dict(value): - return create_image_from_dict(value) + if _is_multimedia_dict(value): + return _create_image_from_dict(value) else: raise InvalidImageInput( message_format="Invalid image input format. The image input should be a dictionary like: " @@ -134,7 +131,7 @@ def create_image(value: any, base_dir: Path = None): target=ErrorTarget.EXECUTOR, ) elif isinstance(value, str): - return create_image_from_string(value, base_dir) + return _create_image_from_string(value) else: raise InvalidImageInput( message_format=f"Unsupported image input type: {type(value)}. " @@ -143,8 +140,8 @@ def create_image(value: any, base_dir: Path = None): ) -def save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None): - ext = get_extension_from_mime_type(image._mime_type) +def _save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None): + ext = _get_extension_from_mime_type(image._mime_type) file_name = f"{file_name}.{ext}" if ext else file_name image_reference = { f"data:{image._mime_type};path": str(relative_path / file_name) if relative_path else file_name @@ -161,7 +158,7 @@ def pfbytes_file_reference_encoder(obj): """Dumps PFBytes to a file and returns its reference.""" if isinstance(obj, PFBytes): file_name = str(uuid.uuid4()) - return save_image_to_file(obj, file_name, folder_path, relative_path) + return _save_image_to_file(obj, file_name, folder_path, relative_path) raise TypeError(f"Not supported to dump type '{type(obj).__name__}'.") return pfbytes_file_reference_encoder @@ -197,11 +194,11 @@ def recursive_process(value: Any, process_funcs: dict[type, Callable] = None) -> return value -def load_multimedia_data(inputs: dict[str, FlowInputDefinition], line_inputs: dict, base_dir: Path): +def load_multimedia_data(inputs: dict[str, FlowInputDefinition], line_inputs: dict): updated_inputs = dict(line_inputs or {}) for key, value in inputs.items(): if value.type == ValueType.IMAGE: - updated_inputs[key] = create_image(updated_inputs[key], base_dir) + updated_inputs[key] = create_image(updated_inputs[key]) elif value.type == ValueType.LIST or value.type == ValueType.OBJECT: updated_inputs[key] = load_multimedia_data_recursively(updated_inputs[key]) return updated_inputs @@ -211,8 +208,8 @@ def load_multimedia_data_recursively(value: Any): if isinstance(value, list): return [load_multimedia_data_recursively(item) for item in value] elif isinstance(value, dict): - if is_multimedia_dict(value): - return create_image_from_dict(value) + if _is_multimedia_dict(value): + return _create_image_from_dict(value) else: return {k: load_multimedia_data_recursively(v) for k, v in value.items()} else: diff --git a/src/promptflow/promptflow/executor/_tool_resolver.py b/src/promptflow/promptflow/executor/_tool_resolver.py index 978589a0e02..0a360b02ac6 100644 --- a/src/promptflow/promptflow/executor/_tool_resolver.py +++ b/src/promptflow/promptflow/executor/_tool_resolver.py @@ -13,7 +13,7 @@ from promptflow._core.connection_manager import ConnectionManager from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func -from promptflow._utils.multimedia_utils import create_image +from promptflow._utils.multimedia_utils import create_image, load_multimedia_data_recursively from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType from promptflow.contracts.types import PromptTemplate @@ -102,10 +102,11 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type else: updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type) elif value_type == ValueType.IMAGE: - updated_inputs[k].value = create_image(v.value, self._working_dir) + updated_inputs[k].value = create_image(v.value) elif isinstance(value_type, ValueType): try: updated_inputs[k].value = value_type.parse(v.value) + updated_inputs[k].value = load_multimedia_data_recursively(updated_inputs[k].value) except Exception as e: msg = f"Input '{k}' for node '{node.name}' of value {v.value} is not type {value_type}." raise NodeInputValidationError(message=msg) from e @@ -173,7 +174,7 @@ def _load_images_for_prompt_tpl(self, prompt_tpl_inputs_mapping: dict, node_inpu for input_name, input in prompt_tpl_inputs_mapping.items(): if ValueType.IMAGE in input.type and input_name in node_inputs: if node_inputs[input_name].value_type == InputValueType.LITERAL: - node_inputs[input_name].value = create_image(node_inputs[input_name].value, self._working_dir) + node_inputs[input_name].value = create_image(node_inputs[input_name].value) return node_inputs def _resolve_prompt_node(self, node: Node) -> ResolvedTool: diff --git a/src/promptflow/promptflow/executor/flow_executor.py b/src/promptflow/promptflow/executor/flow_executor.py index a4be3377b7a..1d524806901 100644 --- a/src/promptflow/promptflow/executor/flow_executor.py +++ b/src/promptflow/promptflow/executor/flow_executor.py @@ -696,7 +696,7 @@ def exec_line( """ self._node_concurrency = node_concurrency inputs_with_default_value = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs) - inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value, self._working_dir) + inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value) # For flow run, validate inputs as default with self._run_tracker.node_log_manager: # exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when diff --git a/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/flow.dag.yaml index 8bb57626e55..5ab71cb5a9c 100644 --- a/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/flow.dag.yaml +++ b/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/flow.dag.yaml @@ -23,7 +23,11 @@ nodes: type: code path: pick_images_from_list.py inputs: - image_list: ${inputs.image_list} + image_list: + - data:image/jpg;path: logo.jpg + - data:image/jpg;path: logo_2.jpg + - data:image/jpg;path: logo_3.jpg + image_list_2: ${inputs.image_list} image_dict: ${inputs.image_dict} - idx_l: 1 - idx_r: 2 + idx_1: 1 + idx_2: 2 diff --git a/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/pick_images_from_list.py b/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/pick_images_from_list.py index d6f03b39ed4..1b439593974 100644 --- a/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/pick_images_from_list.py +++ b/src/promptflow/tests/test_configs/flows/python_tool_with_image_list/pick_images_from_list.py @@ -3,8 +3,14 @@ @tool -def pick_images_from_list(image_list: list[Image], image_dict: dict, idx_l: int, idx_r: int) -> list[Image]: - if idx_l <= idx_r and idx_l >= 0 and idx_r < len(image_list): - return {"Image list": image_list[idx_l:idx_r + 1], "Image dict": image_dict} +def pick_images_from_list( + image_list: list[Image], + image_list_2: list[Image], + image_dict: dict, + idx_1: int, + idx_2: int +) -> list[Image]: + if idx_1 >= 0 and idx_1 < len(image_list) and idx_2 >= 0 and idx_2 < len(image_list_2): + return {"Image list": [image_list[idx_1], image_list_2[idx_2]], "Image dict": image_dict} else: raise Exception(f"Invalid index.")