Skip to content

Commit

Permalink
Support default value and type ensure for flex flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Apr 22, 2024
1 parent 6d97a0d commit 319351e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/promptflow-core/promptflow/contracts/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ class FlexFlow(FlowBase):
:type message_format: str
"""

init: Dict[str, FlowInputDefinition] = None
program_language: str = FlowLanguage.Python
environment_variables: Dict[str, object] = None
# eager flow does not support multimedia contract currently, it is set to basic by default.
Expand All @@ -962,11 +963,13 @@ def deserialize(data: dict) -> "FlexFlow":

inputs = data.get("inputs") or {}
outputs = data.get("outputs") or {}
init = data.get("init") or {}
return FlexFlow(
id=data.get("id", "default_flow_id"),
name=data.get("name", "default_flow"),
inputs={name: FlowInputDefinition.deserialize(i) for name, i in inputs.items()},
outputs={name: FlowOutputDefinition.deserialize(o) for name, o in outputs.items()},
init={name: FlowInputDefinition.deserialize(i) for name, i in init.items()},
program_language=data.get(LANGUAGE_KEY, FlowLanguage.Python),
environment_variables=data.get("environment_variables") or {},
)
Expand Down
14 changes: 13 additions & 1 deletion src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from promptflow._core.tool_meta_generator import PythonLoadError
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.execution_utils import apply_default_value_for_input
from promptflow._utils.logger_utils import logger
from promptflow._utils.multimedia_utils import BasicMultimediaProcessor
from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import Flow
from promptflow.contracts.flow import FlexFlow, Flow
from promptflow.contracts.tool import ConnectionType
from promptflow.core import log_metric
from promptflow.core._model_configuration import (
Expand All @@ -40,6 +41,7 @@

from ._errors import FlowEntryInitializationError, InvalidAggregationFunction, ScriptExecutionError
from .flow_executor import FlowExecutor
from .flow_validator import FlowValidator


class ScriptExecutor(FlowExecutor):
Expand All @@ -64,6 +66,7 @@ def __init__(
else:
self._working_dir = working_dir or Path.cwd()
self._initialize_function()
self._initialize_signature()
self._connections = connections
self._storage = storage or DefaultRunStorage()
self._flow_id = "default_flow_id"
Expand All @@ -90,6 +93,7 @@ def exec_line(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
with self._exec_line_context(run_id, index):
return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand All @@ -115,6 +119,7 @@ def _exec_line_preprocess(
# Executor will add line_number to batch inputs if there is no line_number in the original inputs,
# which should be removed, so, we only preserve the inputs that are contained in self._inputs.
inputs = {k: inputs[k] for k in self._inputs if k in inputs}
FlowValidator.ensure_flow_inputs_type(self._flow, inputs)
return run_info, inputs, run_tracker, None, []

def _exec_line(
Expand Down Expand Up @@ -212,6 +217,7 @@ async def exec_line_async(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._flow.inputs, inputs)
with self._exec_line_context(run_id, index):
return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -274,6 +280,7 @@ def get_inputs_definition(self):
def _resolve_init_kwargs(self, c: type, init_kwargs: dict):
"""Resolve init kwargs, the connection names will be resolved to connection objects."""
logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.")
init_kwargs = apply_default_value_for_input(self._flow.init, init_kwargs)
sig = inspect.signature(c.__init__)
connection_params = []
model_config_param_name_2_cls = {}
Expand Down Expand Up @@ -436,3 +443,8 @@ def _parse_flow_file(self):
target=ErrorTarget.EXECUTOR,
) from e
return module_name, func_name

def _initialize_flow(self):
with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin:
flow_dag = load_yaml(fin)
self._flow = FlexFlow.deserialize(flow_dag)

0 comments on commit 319351e

Please sign in to comment.