Skip to content

Commit

Permalink
[Executor] Support applying default value and ensuring type for flex …
Browse files Browse the repository at this point in the history
…flow (#2923)

# Description

In this PR, we support applying default value and ensuring type provided
for flex flow. If the input is not provided then we will apply the
default value defined in the yaml, also we will check whether the input
type is consistent with the type defined in yaml. An example yaml file:
```
entry: my_flow:MyClass
init:
  input_init:
    type: string
    default: input_init
inputs:
  input_1:
    type: string
  input_2:
    type: string
    default: input_2
```

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
2 people authored and crazygao committed May 6, 2024
1 parent c976236 commit 3a0ee93
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/promptflow-core/promptflow/contracts/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ class FlexFlow(FlowBase):
:type message_format: str
"""

init: Dict[str, FlowInputDefinition] = None
program_language: str = FlowLanguage.Python
environment_variables: Dict[str, object] = None
# eager flow does not support multimedia contract currently, it is set to basic by default.
Expand All @@ -962,11 +963,13 @@ def deserialize(data: dict) -> "FlexFlow":

inputs = data.get("inputs") or {}
outputs = data.get("outputs") or {}
init = data.get("init") or {}
return FlexFlow(
id=data.get("id", "default_flow_id"),
name=data.get("name", "default_flow"),
inputs={name: FlowInputDefinition.deserialize(i) for name, i in inputs.items()},
outputs={name: FlowOutputDefinition.deserialize(o) for name, o in outputs.items()},
init={name: FlowInputDefinition.deserialize(i) for name, i in init.items()},
program_language=data.get(LANGUAGE_KEY, FlowLanguage.Python),
environment_variables=data.get("environment_variables") or {},
)
Expand Down
8 changes: 8 additions & 0 deletions src/promptflow-core/promptflow/executor/_prompty_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import PromptyFlow
from promptflow.contracts.tool import InputDefinition
from promptflow.core._flow import Prompty
from promptflow.storage import AbstractRunStorage
Expand Down Expand Up @@ -50,3 +51,10 @@ def _initialize_function(self):
self._inputs = {k: v.to_flow_input_definition() for k, v in inputs.items()}
self._is_async = False
return self._func

def _init_input_sign(self):
configs, _ = Prompty._parse_prompty(self._working_dir / self._flow_file)
flow = PromptyFlow.deserialize(configs)
self._inputs_sign = flow.inputs
# The init signature only used for flex flow, so we set the _init_sign to empty dict for prompty flow.
self._init_sign = {}
24 changes: 23 additions & 1 deletion src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from promptflow._utils.async_utils import async_to_sync, sync_to_async
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.exception_utils import ExceptionPresenter
from promptflow._utils.execution_utils import apply_default_value_for_input
from promptflow._utils.logger_utils import logger
from promptflow._utils.multimedia_utils import BasicMultimediaProcessor
from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import Flow
from promptflow.contracts.flow import FlexFlow, Flow
from promptflow.contracts.tool import ConnectionType
from promptflow.core import log_metric
from promptflow.core._model_configuration import (
Expand All @@ -40,6 +41,7 @@

from ._errors import FlowEntryInitializationError, InvalidAggregationFunction, ScriptExecutionError
from .flow_executor import FlowExecutor
from .flow_validator import FlowValidator


class ScriptExecutor(FlowExecutor):
Expand All @@ -63,6 +65,7 @@ def __init__(
self._working_dir = Flow._resolve_working_dir(entry, working_dir)
else:
self._working_dir = working_dir or Path.cwd()
self._init_input_sign()
self._initialize_function()
self._connections = connections
self._storage = storage or DefaultRunStorage()
Expand Down Expand Up @@ -100,6 +103,7 @@ def exec_line(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand All @@ -125,6 +129,7 @@ def _exec_line_preprocess(
# Executor will add line_number to batch inputs if there is no line_number in the original inputs,
# which should be removed, so, we only preserve the inputs that are contained in self._inputs.
inputs = {k: inputs[k] for k in self._inputs if k in inputs}
FlowValidator._ensure_flow_inputs_type_inner(self._inputs_sign, inputs)
return run_info, inputs, run_tracker, None, []

def _exec_line(
Expand Down Expand Up @@ -263,6 +268,7 @@ async def exec_line_async(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._inputs_sign, inputs)
with self._exec_line_context(run_id, index):
return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -321,6 +327,7 @@ def get_inputs_definition(self):
def _resolve_init_kwargs(self, c: type, init_kwargs: dict):
"""Resolve init kwargs, the connection names will be resolved to connection objects."""
logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.")
init_kwargs = apply_default_value_for_input(self._init_sign, init_kwargs)
sig = inspect.signature(c.__init__)
connection_params = []
model_config_param_name_2_cls = {}
Expand Down Expand Up @@ -495,3 +502,18 @@ def _parse_flow_file(self):
target=ErrorTarget.EXECUTOR,
) from e
return module_name, func_name

def _init_input_sign(self):
if not self.is_function_entry:
with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin:
flow_dag = load_yaml(fin)
flow = FlexFlow.deserialize(flow_dag)
# In the yaml file, user can define the inputs and init signature for the flow, also SDK may create
# the signature and add them to the yaml file. We need to get the signature from the yaml file and
# used for applying default value and ensuring input type.
self._inputs_sign = flow.inputs
self._init_sign = flow.init
else:
# Since there is no yaml file for function entry, we set the inputs and init signature to empty dict.
self._inputs_sign = {}
self._init_sign = {}
20 changes: 16 additions & 4 deletions src/promptflow-core/promptflow/executor/flow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, List, Mapping, Optional

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import Flow, InputValueType, Node
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType, Node
from promptflow.contracts.tool import ValueType
from promptflow.executor._errors import (
DuplicateNodeName,
Expand Down Expand Up @@ -197,8 +197,14 @@ def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optiona
in the `flow` object.
:rtype: Mapping[str, Any]
"""
return FlowValidator._resolve_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _resolve_flow_inputs_type_inner(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
updated_inputs = {k: v for k, v in inputs.items()}
for k, v in flow.inputs.items():
for k, v in flow_inputs.items():
if k in inputs:
updated_inputs[k] = FlowValidator._parse_input_value(k, inputs[k], v.type, idx)
return updated_inputs
Expand All @@ -219,7 +225,13 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
type specified in the `flow` object.
:rtype: Mapping[str, Any]
"""
for k, v in flow.inputs.items():
return FlowValidator._ensure_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _ensure_flow_inputs_type_inner(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
for k, _ in flow_inputs.items():
if k not in inputs:
line_info = "in input data" if idx is None else f"in line {idx} of input data"
msg_format = (
Expand All @@ -228,7 +240,7 @@ def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional
"if it's no longer needed."
)
raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info)
return FlowValidator.resolve_flow_inputs_type(flow, inputs, idx)
return FlowValidator._resolve_flow_inputs_type_inner(flow_inputs, inputs, idx)

@staticmethod
def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
Expand Down
28 changes: 27 additions & 1 deletion src/promptflow-core/tests/core/e2etests/test_eager_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from promptflow._core.tool_meta_generator import PythonLoadError
from promptflow.contracts.run_info import Status
from promptflow.core import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
from promptflow.executor._errors import FlowEntryInitializationError, InvalidFlexFlowEntry
from promptflow.executor._errors import (
FlowEntryInitializationError,
InputNotFound,
InputTypeError,
InvalidFlexFlowEntry,
)
from promptflow.executor._result import LineResult
from promptflow.executor._script_executor import ScriptExecutor
from promptflow.executor.flow_executor import FlowExecutor
Expand Down Expand Up @@ -93,6 +98,12 @@ class TestEagerFlow:
"open_ai_model_config": OpenAIModelConfiguration(model="my_model", base_url="fake_base_url"),
},
),
(
"flow_with_signature",
{"input_1": "input_1"},
lambda x: x["output"] == "input_2",
None,
),
],
)
def test_flow_run(self, flow_folder, inputs, ensure_output, init_kwargs):
Expand Down Expand Up @@ -150,6 +161,21 @@ async def test_flow_run_with_function_entry_async(self, entry, inputs, expected_
msg = f"The two tasks should run concurrently, but got {delta_desc}"
assert 0 <= delta_sec < 0.1, msg

def test_flow_run_with_invalid_inputs(self):
# Case 1: input not found
flow_file = get_yaml_file("flow_with_signature", root=EAGER_FLOW_ROOT)
executor = FlowExecutor.create(flow_file=flow_file, connections={}, init_kwargs=None)
with pytest.raises(InputNotFound) as e:
executor.exec_line(inputs={}, index=0)
assert "The input for flow is incorrect." in str(e.value)

# Case 2: input type mismatch
flow_file = get_yaml_file("flow_with_wrong_type", root=EAGER_FLOW_ROOT)
executor = FlowExecutor.create(flow_file=flow_file, connections={}, init_kwargs=None)
with pytest.raises(InputTypeError) as e:
executor.exec_line(inputs={"input_1": 1}, index=0)
assert "does not match the expected type" in str(e.value)

def test_flow_run_with_invalid_case(self):
flow_folder = "dummy_flow_with_exception"
flow_file = get_yaml_file(flow_folder, root=EAGER_FLOW_ROOT)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
entry: my_flow:MyClass
init:
input_init:
type: string
default: input_init
inputs:
input_1:
type: string
input_2:
type: string
default: input_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"input_1": "input_1"}
{"input_1": "input_1"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class MyClass:
def __init__(self, input_init: str = "default_input_init"):
pass

def __call__(self, input_1, input_2: str = "default_input_2"):
return {"output": input_2}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
entry: my_flow:MyClass
init:
input_init:
type: string
default: input_init
inputs:
input_1:
type: string
input_2:
type: int
default: input_2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class MyClass:
def __init__(self, input_init: str = "default_input_init"):
pass

def __call__(self, input_1, input_2: str = "default_input_2"):
return {"output": input_2}

0 comments on commit 3a0ee93

Please sign in to comment.