Skip to content

Commit

Permalink
Support default value for flex flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Apr 22, 2024
1 parent 6d97a0d commit d5cc9c0
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from promptflow._core.tool_meta_generator import PythonLoadError
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.execution_utils import apply_default_value_for_input
from promptflow._utils.logger_utils import logger
from promptflow._utils.multimedia_utils import BasicMultimediaProcessor
from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import Flow
from promptflow.contracts.flow import Flow, FlowInputDefinition
from promptflow.contracts.tool import ConnectionType
from promptflow.core import log_metric
from promptflow.core._model_configuration import (
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
else:
self._working_dir = working_dir or Path.cwd()
self._initialize_function()
self._initialize_signature()
self._connections = connections
self._storage = storage or DefaultRunStorage()
self._flow_id = "default_flow_id"
Expand All @@ -90,6 +92,7 @@ def exec_line(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._yaml_inputs, inputs)
with self._exec_line_context(run_id, index):
return self._exec_line(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -212,6 +215,7 @@ async def exec_line_async(
**kwargs,
) -> LineResult:
run_id = run_id or str(uuid.uuid4())
inputs = apply_default_value_for_input(self._yaml_inputs, inputs)
with self._exec_line_context(run_id, index):
return await self._exec_line_async(inputs, index, run_id, allow_generator_output=allow_generator_output)

Expand Down Expand Up @@ -274,6 +278,7 @@ def get_inputs_definition(self):
def _resolve_init_kwargs(self, c: type, init_kwargs: dict):
"""Resolve init kwargs, the connection names will be resolved to connection objects."""
logger.debug(f"Resolving init kwargs: {init_kwargs.keys()}.")
init_kwargs = apply_default_value_for_input(self._yaml_init, init_kwargs)
sig = inspect.signature(c.__init__)
connection_params = []
model_config_param_name_2_cls = {}
Expand Down Expand Up @@ -436,3 +441,9 @@ def _parse_flow_file(self):
target=ErrorTarget.EXECUTOR,
) from e
return module_name, func_name

def _initialize_signature(self):
with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin:

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
flow_dag = load_yaml(fin)
self._yaml_inputs = {k: FlowInputDefinition.deserialize(v) for k, v in flow_dag.get("inputs", {}).items()}
self._yaml_init = {k: FlowInputDefinition.deserialize(v) for k, v in flow_dag.get("init", {}).items()}

0 comments on commit d5cc9c0

Please sign in to comment.