Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Executor] Support applying default value and ensuring type for flex flow #2923

Merged
merged 24 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/promptflow-core/promptflow/contracts/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ class FlexFlow(FlowBase):
:type message_format: str
"""

init: Dict[str, FlowInputDefinition] = None
program_language: str = FlowLanguage.Python
environment_variables: Dict[str, object] = None
# Eager flow does not currently support the multimedia contract, so it is set to basic by default.
Expand All @@ -962,11 +963,13 @@ def deserialize(data: dict) -> "FlexFlow":

inputs = data.get("inputs") or {}
outputs = data.get("outputs") or {}
init = data.get("init") or {}
return FlexFlow(
id=data.get("id", "default_flow_id"),
name=data.get("name", "default_flow"),
inputs={name: FlowInputDefinition.deserialize(i) for name, i in inputs.items()},
outputs={name: FlowOutputDefinition.deserialize(o) for name, o in outputs.items()},
init={name: FlowInputDefinition.deserialize(i) for name, i in init.items()},
program_language=data.get(LANGUAGE_KEY, FlowLanguage.Python),
environment_variables=data.get("environment_variables") or {},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import PromptyFlow
from promptflow.contracts.tool import InputDefinition
from promptflow.core._flow import Prompty
from promptflow.storage import AbstractRunStorage
Expand Down Expand Up @@ -50,3 +51,10 @@ def _initialize_function(self):
self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()}
self._is_async = False
return self._func

def _init_input_sign(self):
    """Load the input signature declared by the prompty file.

    Parses the prompty file located under the working directory, deserializes
    the parsed configs into a ``PromptyFlow`` and stores that flow's inputs on
    ``self._inputs_sign``. ``self._init_sign`` is always set to an empty dict.
    """
    prompty_path = self._working_dir / self._flow_file
    configs, _ = Prompty._parse_prompty(prompty_path)
    self._inputs_sign = PromptyFlow.deserialize(configs).inputs
    # The init signature is only used for flex flow, so it stays an empty dict for prompty flow.
    self._init_sign = {}
24 changes: 23 additions & 1 deletion src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from promptflow._utils.async_utils import async_to_sync, sync_to_async
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.exception_utils import ExceptionPresenter
from promptflow._utils.execution_utils import apply_default_value_for_input
from promptflow._utils.logger_utils import logger
from promptflow._utils.multimedia_utils import BasicMultimediaProcessor
from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import Flow
from promptflow.contracts.flow import FlexFlow, Flow
from promptflow.contracts.tool import ConnectionType
from promptflow.core import log_metric
from promptflow.core._model_configuration import (
Expand All @@ -40,6 +41,7 @@

from ._errors import FlowEntryInitializationError, InvalidAggregationFunction, ScriptExecutionError
from .flow_executor import FlowExecutor
from .flow_validator import FlowValidator


class ScriptExecutor(FlowExecutor):
Expand All @@ -63,6 +65,7 @@ def __init__(
self._working_dir = Flow._resolve_working_dir(entry, working_dir)
else:
self._working_dir = working_dir or Path.cwd()
self._init_input_sign()
self._initialize_function()
self._connections = connections
self._storage = storage or DefaultRunStorage()
Expand Down Expand Up @@ -100,6 +103,7 @@ def exec_line(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand All @@ -125,6 +129,7 @@ def _exec_line_preprocess(
# The executor adds line_number to batch inputs when the original inputs do not contain one; that extra
# key should be removed, so we only preserve the inputs that are contained in self._inputs.
inputs = {k: inputs[k] for k in self._inputs if k in inputs}
FlowValidator._ensure_flow_inputs_type_inner(self._inputs_sign, inputs)
return run_info, inputs, run_tracker, None, []

def _exec_line(
Expand Down Expand Up @@ -263,6 +268,7 @@ async def exec_line_async(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -321,6 +327,7 @@ def get_inputs_definition(self):
def _resolve_init_kwargs(self, c: type, init_kwargs: dict):
"""Resolve init kwargs, the connection names will be resolved to connection objects."""
logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.")
init_kwargs = apply_default_value_for_input(self._init_sign, init_kwargs)
sig = inspect.signature(c.__init__)
connection_params = []
model_config_param_name_2_cls = {}
Expand Down Expand Up @@ -495,3 +502,18 @@ def _parse_flow_file(self):
target=ErrorTarget.EXECUTOR,
) from e
return module_name, func_name

def _init_input_sign(self):
    """Populate ``self._inputs_sign`` and ``self._init_sign`` for this executor.

    For a function entry both signatures are empty dicts. Otherwise the flow
    yaml file under the working directory is loaded and deserialized into a
    ``FlexFlow``, whose ``inputs`` and ``init`` become the two signatures.
    """
    if self.is_function_entry:
        # Since there is no yaml file for function entry, we set the inputs and init signature to empty dict.
        self._inputs_sign = {}
        self._init_sign = {}
        return

    flow_path = self._working_dir / self._flow_file
    with open(flow_path, "r", encoding="utf-8") as fin:
        flow_dag = load_yaml(fin)
    flex_flow = FlexFlow.deserialize(flow_dag)
    # The yaml file may define the inputs and init signature for the flow (written by the user or
    # added by the SDK). They are read here and later used for applying default values and
    # ensuring input types.
    self._inputs_sign = flex_flow.inputs
    self._init_sign = flex_flow.init
20 changes: 16 additions & 4 deletions src/promptflow-core/promptflow/executor/flow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, List, Mapping, Optional

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import Flow, InputValueType, Node
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType, Node
from promptflow.contracts.tool import ValueType
from promptflow.executor._errors import (
DuplicateNodeName,
Expand Down Expand Up @@ -197,8 +197,14 @@ def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optiona
in the `flow` object.
:rtype: Mapping[str, Any]
"""
return FlowValidator._resolve_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _resolve_flow_inputs_type_inner(
    flow_inputs: Mapping[str, FlowInputDefinition], inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
    """Parse the provided input values according to the given input definitions.

    :param flow_inputs: Input definitions keyed by input name; each definition's ``type`` is used
        to parse the matching value.
    :type flow_inputs: Mapping[str, FlowInputDefinition]
    :param inputs: The input values to resolve.
    :type inputs: Mapping[str, Any]
    :param idx: Optional line index forwarded to ``FlowValidator._parse_input_value``.
    :type idx: Optional[int]
    :return: A new dict with every entry of ``inputs``; values whose name appears in
        ``flow_inputs`` are replaced by their parsed form, all others are kept as-is.
    :rtype: Mapping[str, Any]
    """
    # Work on a shallow copy so the caller's mapping is never mutated.
    updated_inputs = dict(inputs)
    for name, definition in flow_inputs.items():
        # Definitions without a provided value are skipped here, not treated as an error.
        if name in inputs:
            updated_inputs[name] = FlowValidator._parse_input_value(name, inputs[name], definition.type, idx)
    return updated_inputs
Expand All @@ -219,7 +225,13 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
type specified in the `flow` object.
:rtype: Mapping[str, Any]
"""
for k, v in flow.inputs.items():
return FlowValidator._ensure_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _ensure_flow_inputs_type_inner(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
for k, _ in flow_inputs.items():
if k not in inputs:
line_info = "in input data" if idx is None else f"in line {idx} of input data"
msg_format = (
Expand All @@ -228,7 +240,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
"if it's no longer needed."
)
raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info)
return FlowValidator.resolve_flow_inputs_type(flow, inputs, idx)
return FlowValidator._resolve_flow_inputs_type_inner(flow_inputs, inputs, idx)

@staticmethod
def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
Expand Down
28 changes: 27 additions & 1 deletion src/promptflow-core/tests/core/e2etests/test_eager_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from promptflow._core.tool_meta_generator import PythonLoadError
from promptflow.contracts.run_info import Status
from promptflow.core import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
from promptflow.executor._errors import FlowEntryInitializationError, InvalidFlexFlowEntry
from promptflow.executor._errors import (
FlowEntryInitializationError,
InputNotFound,
InputTypeError,
InvalidFlexFlowEntry,
)
from promptflow.executor._result import LineResult
from promptflow.executor._script_executor import ScriptExecutor
from promptflow.executor.flow_executor import FlowExecutor
Expand Down Expand Up @@ -93,6 +98,12 @@ class TestEagerFlow:
"open_ai_model_config": OpenAIModelConfiguration(model="my_model", base_url="fake_base_url"),
},
),
(
"flow_with_signature",
{"input_1": "input_1"},
lambda x: x["output"] == "input_2",
None,
),
],
)
def test_flow_run(self, flow_folder, inputs, ensure_output, init_kwargs):
Expand Down Expand Up @@ -150,6 +161,21 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_
msg = f"The two tasks should run concurrently, but got {delta_desc}"
assert 0 <= delta_sec < 0.1, msg

def test_flow_run_with_invalid_inputs(self):
    """Check that executing a line with invalid inputs raises the matching executor error."""
    # Case 1: input not found
    signature_flow = get_yaml_file("flow_with_signature", root=EAGER_FLOW_ROOT)
    executor = FlowExecutor.create(flow_file=signature_flow, connections={}, init_kwargs=None)
    with pytest.raises(InputNotFound) as not_found:
        executor.exec_line(inputs={}, index=0)
    assert "The input for flow is incorrect." in str(not_found.value)

    # Case 2: input type mismatch
    wrong_type_flow = get_yaml_file("flow_with_wrong_type", root=EAGER_FLOW_ROOT)
    executor = FlowExecutor.create(flow_file=wrong_type_flow, connections={}, init_kwargs=None)
    with pytest.raises(InputTypeError) as type_error:
        executor.exec_line(inputs={"input_1": 1}, index=0)
    assert "does not match the expected type" in str(type_error.value)

def test_flow_run_with_invalid_case(self):
flow_folder = "dummy_flow_with_exception"
flow_file = get_yaml_file(flow_folder, root=EAGER_FLOW_ROOT)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
entry: my_flow:MyClass
init:
input_init:
type: string
default: input_init
inputs:
input_1:
type: string
input_2:
type: string
default: input_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"input_1": "input_1"}
{"input_1": "input_1"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class MyClass:
    """Callable flow entry whose output echoes the ``input_2`` argument."""

    def __init__(self, input_init: str = "default_input_init"):
        """Accept ``input_init``; the value is not stored or used."""

    def __call__(self, input_1, input_2: str = "default_input_2"):
        """Return a dict whose ``"output"`` key holds ``input_2``; ``input_1`` is ignored."""
        result = {"output": input_2}
        return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
entry: my_flow:MyClass
init:
input_init:
type: string
default: input_init
inputs:
input_1:
type: string
input_2:
type: int
default: input_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class MyClass:
    """Callable flow entry whose output echoes the ``input_2`` argument."""

    def __init__(self, input_init: str = "default_input_init"):
        """Accept ``input_init``; the value is not stored or used."""

    def __call__(self, input_1, input_2: str = "default_input_2"):
        """Return a dict whose ``"output"`` key holds ``input_2``; ``input_1`` is ignored."""
        result = {"output": input_2}
        return result
Loading