Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Promptflow] Support image in SingleNodeRun #811

Merged
merged 19 commits into from
Oct 23, 2023
Merged
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,4 @@
NotFoundException,
SqliteClient,
)
from promptflow.storage._run_storage import DefaultRunStorage
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def node_test(
dependency_nodes_outputs=dependency_nodes_outputs,
connections=connections,
working_dir=self.flow.code,
output_sub_dir=".promptflow/intermediate",
)
return result

Expand Down
15 changes: 10 additions & 5 deletions src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from promptflow._core.tools_manager import ToolsManager
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.logger_utils import logger
from promptflow._utils.multimedia_utils import load_multimedia_data
from promptflow._utils.multimedia_utils import load_multimedia_data, load_multimedia_data_recursively
from promptflow._utils.utils import transpose
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType, Node
from promptflow.contracts.run_info import FlowRunInfo, Status
Expand All @@ -48,7 +48,7 @@
from promptflow.executor._tool_resolver import ToolResolver
from promptflow.executor.flow_validator import FlowValidator
from promptflow.storage import AbstractRunStorage
from promptflow.storage._run_storage import DefaultRunStorage, DummyRunStorage
from promptflow.storage._run_storage import DefaultRunStorage

LINE_NUMBER_KEY = "line_number" # Using the same key with portal.
LINE_TIMEOUT_SEC = 600
Expand Down Expand Up @@ -242,6 +242,7 @@ def load_and_exec_node(
flow_file: Path,
node_name: str,
*,
output_sub_dir: Optional[str] = None,
flow_inputs: Optional[Mapping[str, Any]] = None,
dependency_nodes_outputs: Optional[Mapping[str, Any]] = None,
connections: Optional[dict] = None,
Expand Down Expand Up @@ -294,8 +295,10 @@ def load_and_exec_node(
flow_file=flow_file,
)

flow_inputs = FlowExecutor._apply_default_value_for_input(flow.inputs, flow_inputs)
converted_flow_inputs_for_node = FlowValidator.convert_flow_inputs_for_node(flow, node, flow_inputs)
inputs_with_default_value = FlowExecutor._apply_default_value_for_input(flow.inputs, flow_inputs)
inputs = load_multimedia_data(flow.inputs, inputs_with_default_value, working_dir)
dependency_nodes_outputs = load_multimedia_data_recursively(dependency_nodes_outputs)
converted_flow_inputs_for_node = FlowValidator.convert_flow_inputs_for_node(flow, node, inputs)
package_tool_keys = [node.source.tool] if node.source and node.source.tool else []
tool_resolver = ToolResolver(working_dir, connections, package_tool_keys)
resolved_node = tool_resolver.resolve_tool_by_node(node)
Expand All @@ -320,7 +323,9 @@ def load_and_exec_node(
resolved_inputs = {k: v for k, v in resolved_inputs.items() if k not in resolved_node.init_args}

# TODO: Simplify the logic here
run_tracker = RunTracker(DummyRunStorage())
sub_dir = "." if output_sub_dir is None else output_sub_dir
storage = DefaultRunStorage(base_dir=working_dir, sub_dir=Path(sub_dir))
run_tracker = RunTracker(storage)
with run_tracker.node_log_manager:
ToolInvoker.activate(DefaultToolInvoker())

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import uuid
import os
from types import GeneratorType
from pathlib import Path

import pytest

Expand All @@ -11,6 +13,7 @@
from promptflow.executor import FlowExecutor
from promptflow.executor._errors import ConnectionNotFound, InputTypeError, ResolveToolError
from promptflow.executor.flow_executor import BulkResult, LineResult
from promptflow.storage._run_storage import DefaultRunStorage
from promptflow.storage import AbstractRunStorage

from ..utils import (
Expand All @@ -19,6 +22,7 @@
get_flow_expected_status_summary,
get_flow_sample_inputs,
get_yaml_file,
get_yaml_working_dir
)

SAMPLE_FLOW = "web_classification_no_variants"
Expand All @@ -27,6 +31,11 @@
SAMPLE_FLOW_WITH_LANGCHAIN_TRACES = "flow_with_langchain_traces"


def assert_contains_substrings(s, substrings):
    """Assert that every entry of ``substrings`` occurs in ``s``.

    All missing entries are collected before asserting, so a failure reports
    the complete list of misses together with the searched string instead of
    stopping silently at the first one.

    :param s: The string to search in.
    :param substrings: Iterable of strings expected to appear in ``s``.
    :raises AssertionError: If at least one entry is not found in ``s``.
    """
    missing = [substring for substring in substrings if substring not in s]
    assert not missing, f"Substrings {missing!r} not found in {s!r}"


class MemoryRunStorage(AbstractRunStorage):
def __init__(self):
self._node_runs = {}
Expand Down Expand Up @@ -221,6 +230,38 @@ def test_executor_exec_line(self, flow_folder, dev_connections):
assert node_run_info.node == node
assert isinstance(node_run_info.api_calls, list) # api calls is set

@pytest.mark.parametrize(
    "flow_folder",
    [
        "python_tool_with_multiple_image_nodes"
    ],
)
def test_executor_exec_line_with_image(self, flow_folder, dev_connections):
    """Execute one line of an image flow and check image references are recorded.

    The flow is run with an empty input dict against a ``DefaultRunStorage``
    rooted at the flow folder with ``./temp`` as sub directory. The test then
    verifies that the flow-level api calls and every node run's inputs,
    output, result and first api call mention an image path entry.
    """
    self.skip_serp(flow_folder, dev_connections)
    working_dir = get_yaml_working_dir(flow_folder)
    # The cwd is process-wide state: remember it and restore it afterwards so
    # this test does not change the directory later tests start in.
    original_cwd = os.getcwd()
    os.chdir(working_dir)
    try:
        storage = DefaultRunStorage(base_dir=working_dir, sub_dir=Path("./temp"))
        executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections, storage=storage)
        flow_result = executor.exec_line({})
        assert not executor._run_tracker._flow_runs, "Flow runs in run tracker should be empty."
        assert not executor._run_tracker._node_runs, "Node runs in run tracker should be empty."
        assert isinstance(flow_result.output, dict)
        assert flow_result.run_info.status == Status.Completed
        node_count = len(executor._flow.nodes)
        assert isinstance(flow_result.run_info.api_calls, list) and len(flow_result.run_info.api_calls) == node_count
        substrings = ["data:image/jpg;path", ".jpg"]
        # The length was asserted equal to node_count above, so iterating the
        # list directly visits the same entries as indexing 0..node_count-1.
        for api_call in flow_result.run_info.api_calls:
            assert_contains_substrings(str(api_call), substrings)
        assert len(flow_result.node_run_infos) == node_count
        for node, node_run_info in flow_result.node_run_infos.items():
            assert node_run_info.status == Status.Completed
            assert node_run_info.node == node
            assert isinstance(node_run_info.api_calls, list)  # api calls is set
            assert_contains_substrings(str(node_run_info.inputs), substrings)
            assert_contains_substrings(str(node_run_info.output), substrings)
            assert_contains_substrings(str(node_run_info.result), substrings)
            assert_contains_substrings(str(node_run_info.api_calls[0]), substrings)
    finally:
        os.chdir(original_cwd)

@pytest.mark.parametrize(
"flow_folder, node_name, flow_inputs, dependency_nodes_outputs",
[
Expand Down Expand Up @@ -252,6 +293,41 @@ def test_executor_exec_node(self, flow_folder, node_name, flow_inputs, dependenc
assert run_info.node == node_name
assert run_info.system_metrics["duration"] >= 0

@pytest.mark.parametrize(
    "flow_folder, node_name, flow_inputs, dependency_nodes_outputs",
    [
        (
            "python_tool_with_multiple_image_nodes",
            "python_node_2",
            {"logo_content": "Microsoft and four squares"},
            {
                "python_node": {
                    "image": {"data:image/jpg;path": "logo.jpg"},
                    "image_name": "Microsoft's logo",
                    "image_list": [{"data:image/jpg;path": "logo.jpg"}],
                }
            },
        ),
        (
            "python_tool_with_multiple_image_nodes",
            "python_node",
            {"image": "logo.jpg", "image_name": "Microsoft's logo"},
            {},
        ),
    ],
)
def test_executor_exec_node_with_image(self, flow_folder, node_name, flow_inputs, dependency_nodes_outputs,
                                       dev_connections):
    """Run a single node of an image flow and check image references are recorded.

    The node is executed through ``FlowExecutor.load_and_exec_node`` with
    ``./temp`` as output sub directory. The resulting run info must be
    completed and its inputs, output, result and first api call must mention
    an image path entry located under ``temp``.
    """
    self.skip_serp(flow_folder, dev_connections)
    yaml_file = get_yaml_file(flow_folder)
    working_dir = get_yaml_working_dir(flow_folder)
    # The cwd is process-wide state: remember it and restore it afterwards so
    # this test does not change the directory later tests start in.
    original_cwd = os.getcwd()
    os.chdir(working_dir)
    try:
        run_info = FlowExecutor.load_and_exec_node(
            yaml_file,
            node_name,
            flow_inputs=flow_inputs,
            dependency_nodes_outputs=dependency_nodes_outputs,
            connections=dev_connections,
            output_sub_dir="./temp",
            raise_ex=True,
        )
        # Check status and container type first so a failed run is reported
        # as such rather than as a missing substring or an indexing error.
        assert run_info.status == Status.Completed
        assert isinstance(run_info.api_calls, list)
        assert run_info.node == node_name
        assert run_info.system_metrics["duration"] >= 0
        substrings = ["data:image/jpg;path", "temp", ".jpg"]
        assert_contains_substrings(str(run_info.inputs), substrings)
        assert_contains_substrings(str(run_info.output), substrings)
        assert_contains_substrings(str(run_info.result), substrings)
        assert_contains_substrings(str(run_info.api_calls[0]), substrings)
    finally:
        os.chdir(original_cwd)

def test_executor_node_overrides(self, dev_connections):
inputs = self.get_line_inputs()
executor = FlowExecutor.create(
Expand Down
5 changes: 5 additions & 0 deletions src/promptflow/tests/executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def get_yaml_file(folder_name, root: str = FLOW_ROOT, file_name: str = "flow.dag
return yaml_file


def get_yaml_working_dir(folder_name, root: str = FLOW_ROOT):
    """Return the path of the flow folder ``folder_name`` located under ``root``.

    :param folder_name: Name of the flow folder.
    :param root: Directory that contains the flow folders.
    :return: ``Path(root) / folder_name``.
    """
    return Path(root) / folder_name


def get_flow_inputs(folder_name, root: str = FLOW_ROOT):
flow_folder_path = Path(root) / folder_name
inputs = load_json(flow_folder_path / "inputs.json")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ inputs:
image:
type: image
default: logo.jpg
image_name:
type: string
default: Microsoft's logo
logo_content:
type: string
default: Microsoft and four squares
outputs:
output:
type: image
Expand All @@ -14,10 +20,12 @@ nodes:
path: python_with_image.py
inputs:
image: ${inputs.image}
image_name: ${inputs.image_name}
- name: python_node_2
type: python
source:
type: code
path: python_with_image.py
path: python_node_2.py
inputs:
image: ${python_node.output}
image_dict: ${python_node.output}
logo_content: ${inputs.logo_content}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from promptflow.contracts.multimedia import Image
from promptflow import tool


@tool
def python_with_image(image_dict: dict, logo_content: str) -> Image:
    """Extend ``image_dict`` in place with a duplicated image list and the logo text.

    Adds ``image_list2`` (the ``image`` entry repeated twice) and
    ``logo_content`` to the incoming dict and returns that same dict.
    """
    # NOTE(review): the annotation says ``Image`` but the value returned is the
    # mutated input dict - confirm whether the annotation should be ``dict``.
    picture = image_dict["image"]
    image_dict.update(image_list2=[picture, picture], logo_content=logo_content)
    return image_dict
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@


@tool
def python_with_image(image: Image) -> Image:
return image
def python_with_image(image: Image, image_name: str) -> Image:
return {"image": image, "image_name": image_name, "image_list": [image, image]}
Loading