Skip to content

Commit

Permalink
Support images in node input
Browse files: browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Oct 20, 2023
1 parent 70e65be commit 5f14fb9
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 45 deletions.
67 changes: 32 additions & 35 deletions src/promptflow/promptflow/_utils/multimedia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
MIME_PATTERN = re.compile(r"^data:image/(.*);(path|base64|url)$")


def get_mime_type_from_path(path: Path):
def _get_mime_type_from_path(path: Path):
ext = path.suffix[1:]
return f"image/{ext}" if ext else "image/*"


def get_extension_from_mime_type(mime_type: str):
def _get_extension_from_mime_type(mime_type: str):
ext = mime_type.split("/")[-1]
if ext == "*":
return None
return ext


def is_multimedia_dict(multimedia_dict: dict):
def _is_multimedia_dict(multimedia_dict: dict):
if len(multimedia_dict) != 1:
return False
key = list(multimedia_dict.keys())[0]
Expand All @@ -40,44 +40,44 @@ def is_multimedia_dict(multimedia_dict: dict):
return False


def _get_multimedia_info(key: str):
    """Split a multimedia dictionary key into image format and resource kind.

    :param key: Key expected to look like ``data:image/<format>;<resource>``,
        as described by the module-level ``MIME_PATTERN``.
    :return: ``(format, resource)`` taken from the first and second pattern
        groups, or ``(None, None)`` when the key does not match.
    """
    match = re.match(MIME_PATTERN, key)
    if match:
        return match.group(1), match.group(2)
    return None, None


def is_url(value: str):
def _is_url(value: str):
try:
result = urlparse(value)
return all([result.scheme, result.netloc])
except ValueError:
return False


def is_base64(value: str):
def _is_base64(value: str):
base64_regex = re.compile(r"^([A-Za-z0-9+/]{4})*(([A-Za-z0-9+/]{2})*(==|[A-Za-z0-9+/]=)?)?$")
if re.match(base64_regex, value):
return True
return False


def _create_image_from_file(f: Path, mime_type: str = None):
    """Read a file from disk and wrap its bytes in an ``Image``.

    :param f: Path of the file to read. When ``mime_type`` is omitted the
        type is derived from ``f.suffix``, so ``f`` must be a ``Path`` then.
    :param mime_type: Explicit MIME type; when falsy it is inferred from the
        file extension.
    :return: An ``Image`` holding the file's raw bytes and the MIME type.
    """
    if not mime_type:
        mime_type = _get_mime_type_from_path(f)
    with open(f, "rb") as fin:
        return Image(fin.read(), mime_type=mime_type)


def _create_image_from_base64(base64_str: str, mime_type: str = None):
    """Decode a base64 string and wrap the resulting bytes in an ``Image``.

    :param base64_str: Base64-encoded image data.
    :param mime_type: Explicit MIME type; when falsy the type is sniffed
        from the decoded bytes, falling back to ``image/*``.
    :return: An ``Image`` holding the decoded bytes and the MIME type.
    """
    image_bytes = base64.b64decode(base64_str)
    if not mime_type:
        # NOTE(review): imghdr is deprecated since Python 3.11 and removed in
        # 3.13; this sniffing needs a replacement before upgrading.
        image_format = imghdr.what(None, image_bytes)
        mime_type = f"image/{image_format}" if image_format else "image/*"
    return Image(image_bytes, mime_type=mime_type)


def create_image_from_url(url: str, mime_type: str = None):
def _create_image_from_url(url: str, mime_type: str = None):
response = requests.get(url)
if response.status_code == 200:
if not mime_type:
Expand All @@ -92,15 +92,15 @@ def create_image_from_url(url: str, mime_type: str = None):
)


def create_image_from_dict(image_dict: dict):
def _create_image_from_dict(image_dict: dict):
for k, v in image_dict.items():
format, resource = get_multimedia_info(k)
format, resource = _get_multimedia_info(k)
if resource == "path":
return create_image_from_file(v, mime_type=f"image/{format}")
return _create_image_from_file(v, mime_type=f"image/{format}")
elif resource == "base64":
return create_image_from_base64(v, mime_type=f"image/{format}")
return _create_image_from_base64(v, mime_type=f"image/{format}")
elif resource == "url":
return create_image_from_url(v, mime_type=f"image/{format}")
return _create_image_from_url(v, mime_type=f"image/{format}")
else:
raise InvalidImageInput(
message_format=f"Unsupported image resource: {resource}. "
Expand All @@ -109,32 +109,29 @@ def create_image_from_dict(image_dict: dict):
)


def _create_image_from_string(value: str):
    """Build an image from a string holding base64 data, a URL or a file path.

    The checks run in that order: base64 shape first, then URL, and anything
    else is treated as a file path and passed on as ``Path(value)``.

    :param value: Base64 data, URL, or file path.
    :return: Whatever the selected ``_create_image_from_*`` helper returns.
    """
    if _is_base64(value):
        return _create_image_from_base64(value)
    elif _is_url(value):
        return _create_image_from_url(value)
    else:
        return _create_image_from_file(Path(value))


def create_image(value: any, base_dir: Path = None):
def create_image(value: any):
if isinstance(value, PFBytes):
return value
elif isinstance(value, dict):
if is_multimedia_dict(value):
return create_image_from_dict(value)
if _is_multimedia_dict(value):
return _create_image_from_dict(value)
else:
raise InvalidImageInput(
message_format="Invalid image input format. The image input should be a dictionary like: "
"{data:image/<image_type>;[path|base64|url]: <image_data>}.",
target=ErrorTarget.EXECUTOR,
)
elif isinstance(value, str):
return create_image_from_string(value, base_dir)
return _create_image_from_string(value)
else:
raise InvalidImageInput(
message_format=f"Unsupported image input type: {type(value)}. "
Expand All @@ -143,8 +140,8 @@ def create_image(value: any, base_dir: Path = None):
)


def save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None):
ext = get_extension_from_mime_type(image._mime_type)
def _save_image_to_file(image: Image, file_name: str, folder_path: Path, relative_path: Path = None):
ext = _get_extension_from_mime_type(image._mime_type)
file_name = f"{file_name}.{ext}" if ext else file_name
image_reference = {
f"data:{image._mime_type};path": str(relative_path / file_name) if relative_path else file_name
Expand All @@ -161,7 +158,7 @@ def pfbytes_file_reference_encoder(obj):
"""Dumps PFBytes to a file and returns its reference."""
if isinstance(obj, PFBytes):
file_name = str(uuid.uuid4())
return save_image_to_file(obj, file_name, folder_path, relative_path)
return _save_image_to_file(obj, file_name, folder_path, relative_path)
raise TypeError(f"Not supported to dump type '{type(obj).__name__}'.")
return pfbytes_file_reference_encoder

Expand Down Expand Up @@ -197,11 +194,11 @@ def recursive_process(value: Any, process_funcs: dict[type, Callable] = None) ->
return value


def load_multimedia_data(inputs: dict[str, FlowInputDefinition], line_inputs: dict):
    """Convert multimedia values in one line of flow inputs.

    Works on a shallow copy, so ``line_inputs`` itself is not modified.

    :param inputs: Flow input definitions keyed by input name; each
        definition's ``type`` selects the conversion.
    :param line_inputs: Raw input values for one line; ``None`` or an empty
        mapping is treated as ``{}``.
    :return: A new dict where IMAGE inputs went through ``create_image`` and
        LIST/OBJECT inputs went through ``load_multimedia_data_recursively``.
    :raises KeyError: If an IMAGE, LIST or OBJECT input declared in
        ``inputs`` has no entry in ``line_inputs``.
    """
    updated_inputs = dict(line_inputs or {})
    for key, value in inputs.items():
        if value.type == ValueType.IMAGE:
            updated_inputs[key] = create_image(updated_inputs[key])
        elif value.type == ValueType.LIST or value.type == ValueType.OBJECT:
            # Containers may hold multimedia dicts at any depth.
            updated_inputs[key] = load_multimedia_data_recursively(updated_inputs[key])
    return updated_inputs
Expand All @@ -211,8 +208,8 @@ def load_multimedia_data_recursively(value: Any):
if isinstance(value, list):
return [load_multimedia_data_recursively(item) for item in value]
elif isinstance(value, dict):
if is_multimedia_dict(value):
return create_image_from_dict(value)
if _is_multimedia_dict(value):
return _create_image_from_dict(value)
else:
return {k: load_multimedia_data_recursively(v) for k, v in value.items()}
else:
Expand Down
7 changes: 4 additions & 3 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow._utils.multimedia_utils import create_image
from promptflow._utils.multimedia_utils import create_image, load_multimedia_data_recursively
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
from promptflow.contracts.types import PromptTemplate
Expand Down Expand Up @@ -102,10 +102,11 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif value_type == ValueType.IMAGE:
updated_inputs[k].value = create_image(v.value, self._working_dir)
updated_inputs[k].value = create_image(v.value)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
updated_inputs[k].value = load_multimedia_data_recursively(updated_inputs[k].value)
except Exception as e:
msg = f"Input '{k}' for node '{node.name}' of value {v.value} is not type {value_type}."
raise NodeInputValidationError(message=msg) from e
Expand Down Expand Up @@ -173,7 +174,7 @@ def _load_images_for_prompt_tpl(self, prompt_tpl_inputs_mapping: dict, node_inpu
for input_name, input in prompt_tpl_inputs_mapping.items():
if ValueType.IMAGE in input.type and input_name in node_inputs:
if node_inputs[input_name].value_type == InputValueType.LITERAL:
node_inputs[input_name].value = create_image(node_inputs[input_name].value, self._working_dir)
node_inputs[input_name].value = create_image(node_inputs[input_name].value)
return node_inputs

def _resolve_prompt_node(self, node: Node) -> ResolvedTool:
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def exec_line(
"""
self._node_concurrency = node_concurrency
inputs_with_default_value = FlowExecutor._apply_default_value_for_input(self._flow.inputs, inputs)
inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value, self._working_dir)
inputs = load_multimedia_data(self._flow.inputs, inputs_with_default_value)
# For flow run, validate inputs as default
with self._run_tracker.node_log_manager:
# exec_line interface may be called by exec_bulk, so we only set run_mode as flow run when
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ nodes:
type: code
path: pick_images_from_list.py
inputs:
image_list: ${inputs.image_list}
image_list:
- data:image/jpg;path: logo.jpg
- data:image/jpg;path: logo_2.jpg
- data:image/jpg;path: logo_3.jpg
image_list_2: ${inputs.image_list}
image_dict: ${inputs.image_dict}
idx_l: 1
idx_r: 2
idx_1: 1
idx_2: 2
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@


@tool
def pick_images_from_list(
    image_list: list[Image],
    image_list_2: list[Image],
    image_dict: dict,
    idx_1: int,
    idx_2: int,
) -> dict:
    """Pick one image from each list and return them with the image dict.

    :param image_list: First list of images, indexed by ``idx_1``.
    :param image_list_2: Second list of images, indexed by ``idx_2``.
    :param image_dict: Dictionary passed through unchanged.
    :param idx_1: Index into ``image_list``; must be within range.
    :param idx_2: Index into ``image_list_2``; must be within range.
    :return: ``{"Image list": [<picked 1>, <picked 2>], "Image dict": image_dict}``.
    :raises Exception: If either index is negative or past the end of its list.
    """
    if 0 <= idx_1 < len(image_list) and 0 <= idx_2 < len(image_list_2):
        return {"Image list": [image_list[idx_1], image_list_2[idx_2]], "Image dict": image_dict}
    else:
        raise Exception("Invalid index.")

0 comments on commit 5f14fb9

Please sign in to comment.