Skip to content

Commit

Permalink
Support including images in complicated structure as node input (#837)
Browse files Browse the repository at this point in the history
# Description

We only support single image as node input before, in this PR, we
support including images in complicated structure as node input.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
lumoslnt and Lina Tang authored Oct 20, 2023
1 parent 089f18d commit cae6dfd
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 45 deletions.
67 changes: 32 additions & 35 deletions src/promptflow/promptflow/_utils/multimedia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
MIME_PATTERN = re.compile(r"^data:image/(.*);(path|base64|url)$")


def get_mime_type_from_path(path: Path):
def _get_mime_type_from_path(path: Path):
ext = path.suffix[1:]
return f"image/{ext}" if ext else "image/*"


def get_extension_from_mime_type(mime_type: str):
def _get_extension_from_mime_type(mime_type: str):
ext = mime_type.split("/")[-1]
if ext == "*":
return None
return ext


def is_multimedia_dict(multimedia_dict: dict):
def _is_multimedia_dict(multimedia_dict: dict):
if len(multimedia_dict) != 1:
return False
key = list(multimedia_dict.keys())[0]
Expand All @@ -40,44 +40,44 @@ def is_multimedia_dict(multimedia_dict: dict):
return False


def get_multimedia_info(key: str):
def _get_multimedia_info(key: str):
match = re.match(MIME_PATTERN, key)
if match:
return match.group(1), match.group(2)
return None, None


def is_url(value: str):
def _is_url(value: str):
try:
result = urlparse(value)
return all([result.scheme, result.netloc])
except ValueError:
return False


def is_base64(value: str):
def _is_base64(value: str):
base64_regex = re.compile(r"^([A-Za-z0-9+/]{4})*(([A-Za-z0-9+/]{2})*(==|[A-Za-z0-9+/]=)?)?$")
if re.match(base64_regex, value):
return True
return False


def create_image_from_file(f: Path, mime_type: str = None):
def _create_image_from_file(f: Path, mime_type: str = None):
if not mime_type:
mime_type = get_mime_type_from_path(f)
mime_type = _get_mime_type_from_path(f)
with open(f, "rb") as fin:
return Image(fin.read(), mime_type=mime_type)


def create_image_from_base64(base64_str: str, mime_type: str = None):
def _create_image_from_base64(base64_str: str, mime_type: str = None):
image_bytes = base64.b64decode(base64_str)
if not mime_type:
format = imghdr.what(None, image_bytes)
mime_type = f"image/{format}" if format else "image/*"
return Image(image_bytes, mime_type=mime_type)


def create_image_from_url(url: str, mime_type: str = None):
def _create_image_from_url(url: str, mime_type: str = None):
response = requests.get(url)
if response.status_code == 200:
if not mime_type:
Expand All @@ -92,15 +92,15 @@ def create_image_from_url(url: str, mime_type: str = None):
)


def create_image_from_dict(image_dict: dict):
def _create_image_from_dict(image_dict: dict):
for k, v in image_dict.items():
format, resource = get_multimedia_info(k)
format, resource = _get_multimedia_info(k)
if resource == "path":
return create_image_from_file(v, mime_type=f"image/{format}")
return _create_image_from_file(v, mime_type=f"image/{format}")
elif resource == "base64":
return create_image_from_base64(v, mime_type=f"image/{format}")
return _create_image_from_base64(v, mime_type=f"image/{format}")
elif resource == "url":
return create_image_from_url(v, mime_type=f"image/{format}")
return _create_image_from_url(v, mime_type=f"image/{format}")
else:
raise InvalidImageInput(
message_format=f"Unsupported image resource: {resource}. "
Expand All @@ -109,32 +109,29 @@ def create_image_from_dict(image_dict: dict):
)


def create_image_from_string(value: str, base_dir: Path = None):
if is_base64(value):
return create_image_from_base64(value)
elif is_url(value):
return create_image_from_url(value)
def _create_image_from_string(value: str):
if _is_base64(value):
return _create_image_from_base64(value)
elif _is_url(value):
return _create_image_from_url(value)
else:
path = Path(value)
if base_dir and not path.is_absolute():
path = Path.joinpath(base_dir, path)
return create_image_from_file(path)
return _create_image_from_file(Path(value))


def create_image(value: any, base_dir: Path = None):
def create_image(value: any):
if isinstance(value, PFBytes):
return value
elif isinstance(value, dict):
if is_multimedia_dict(value):
return create_image_from_dict(value)
if _is_multimedia_dict(value):
return _create_image_from_dict(value)
else:
raise InvalidImageInput(
message_format="Invalid image input format. The image input should be a dictionary like: "
"{data:image/<image_type>;[path|base64|url]: <image_data>}.",
target=ErrorTarget.EXECUTOR,
)
elif isinstance(value, str):
return create_image_from_string(value, base_dir)
return _create_image_from_string(value)
else:
raise InvalidImageInput(
message_format=f"Unsupported image input type: {type(value)}. "
Expand All @@ -143,8 +140,8 @@ def create_image(value: any, base_dir: Path = None):
)


def save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None):
ext = get_extension_from_mime_type(image._mime_type)
def _save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None):
ext = _get_extension_from_mime_type(image._mime_type)
file_name = f"{file_name}.{ext}" if ext else file_name
image_reference = {f"data:{image._mime_type};path": str(relative_path / file_name) if relative_path else file_name}
path = folder_path / relative_path if relative_path else folder_path
Expand All @@ -159,7 +156,7 @@ def pfbytes_file_reference_encoder(obj):
"""Dumps PFBytes to a file and returns its reference."""
if isinstance(obj, PFBytes):
file_name = str(uuid.uuid4())
return save_image_to_file(obj, file_name, folder_path, relative_path)
return _save_image_to_file(obj, file_name, folder_path, relative_path)
raise TypeError(f"Not supported to dump type '{type(obj).__name__}'.")

return pfbytes_file_reference_encoder
Expand Down Expand Up @@ -196,11 +193,11 @@ def recursive_process(value: Any, process_funcs: Dict[type, Callable] = None) ->
return value


def load_multimedia_data(inputs: Dict[str, FlowInputDefinition], line_inputs: dict, base_dir: Path):
def load_multimedia_data(inputs: Dict[str, FlowInputDefinition], line_inputs: dict):
updated_inputs = dict(line_inputs or {})
for key, value in inputs.items():
if value.type == ValueType.IMAGE:
updated_inputs[key] = create_image(updated_inputs[key], base_dir)
updated_inputs[key] = create_image(updated_inputs[key])
elif value.type == ValueType.LIST or value.type == ValueType.OBJECT:
updated_inputs[key] = load_multimedia_data_recursively(updated_inputs[key])
return updated_inputs
Expand All @@ -210,8 +207,8 @@ def load_multimedia_data_recursively(value: Any):
if isinstance(value, list):
return [load_multimedia_data_recursively(item) for item in value]
elif isinstance(value, dict):
if is_multimedia_dict(value):
return create_image_from_dict(value)
if _is_multimedia_dict(value):
return _create_image_from_dict(value)
else:
return {k: load_multimedia_data_recursively(v) for k, v in value.items()}
else:
Expand Down
7 changes: 4 additions & 3 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow._utils.multimedia_utils import create_image
from promptflow._utils.multimedia_utils import create_image, load_multimedia_data_recursively
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
from promptflow.contracts.types import PromptTemplate
Expand Down Expand Up @@ -102,10 +102,11 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif value_type == ValueType.IMAGE:
updated_inputs[k].value = create_image(v.value, self._working_dir)
updated_inputs[k].value = create_image(v.value)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
updated_inputs[k].value = load_multimedia_data_recursively(updated_inputs[k].value)
except Exception as e:
msg = f"Input '{k}' for node '{node.name}' of value {v.value} is not type {value_type}."
raise NodeInputValidationError(message=msg) from e
Expand Down Expand Up @@ -173,7 +174,7 @@ def _load_images_for_prompt_tpl(self, prompt_tpl_inputs_mapping: dict, node_inpu
for input_name, input in prompt_tpl_inputs_mapping.items():
if ValueType.IMAGE in input.type and input_name in node_inputs:
if node_inputs[input_name].value_type == InputValueType.LITERAL:
node_inputs[input_name].value = create_image(node_inputs[input_name].value, self._working_dir)
node_inputs[input_name].value = create_image(node_inputs[input_name].value)
return node_inputs

def _resolve_prompt_node(self, node: Node) -> ResolvedTool:
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def exec_line(
"""
self._node_concurrency = node_concurrency
inputs_with_default_value = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value, self._working_dir)
inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
# exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ nodes:
type: code
path: pick_images_from_list.py
inputs:
image_list: ${inputs.image_list}
image_list:
- data:image/jpg;path: logo.jpg
- data:image/jpg;path: logo_2.jpg
- data:image/jpg;path: logo_3.jpg
image_list_2: ${inputs.image_list}
image_dict: ${inputs.image_dict}
idx_l: 1
idx_r: 2
idx_1: 1
idx_2: 2
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@


@tool
def pick_images_from_list(image_list: list[Image], image_dict: dict, idx_l: int, idx_r: int) -> list[Image]:
if idx_l <= idx_r and idx_l >= 0 and idx_r < len(image_list):
return {"Image list": image_list[idx_l:idx_r + 1], "Image dict": image_dict}
def pick_images_from_list(
image_list: list[Image],
image_list_2: list[Image],
image_dict: dict,
idx_1: int,
idx_2: int
) -> list[Image]:
if idx_1 >= 0 and idx_1 < len(image_list) and idx_2 >= 0 and idx_2 < len(image_list_2):
return {"Image list": [image_list[idx_1], image_list_2[idx_2]], "Image dict": image_dict}
else:
raise Exception(f"Invalid index.")

0 comments on commit cae6dfd

Please sign in to comment.