Skip to content

Commit

Permalink
Apply suggestions from CR
Browse files Browse the repository at this point in the history
  • Loading branch information
Lina Tang committed Apr 25, 2024
1 parent fb8915f commit 35e20ab
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 18 deletions.
7 changes: 5 additions & 2 deletions src/promptflow-core/promptflow/executor/_script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _exec_line_preprocess(
# Executor will add line_number to batch inputs if there is no line_number in the original inputs,
# which should be removed, so, we only preserve the inputs that are contained in self._inputs.
inputs = {k: inputs[k] for k in self._inputs if k in inputs}
FlowValidator.ensure_flow_inputs_type(self._inputs_sign, inputs)
FlowValidator._ensure_flow_inputs_type_inner(self._inputs_sign, inputs)
return run_info, inputs, run_tracker, None, []

def _exec_line(
Expand Down Expand Up @@ -458,9 +458,12 @@ def _init_input_sign(self):
with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin:
flow_dag = load_yaml(fin)
flow = FlexFlow.deserialize(flow_dag)
# In the yaml file, user can define the inputs and init signature for the flow, and SDK may create
# the signature and add them to the yaml file. We need to get the signature from the yaml file and
# used for applying default value and ensuring input type.
self._inputs_sign = flow.inputs
self._init_sign = flow.init
else:
# For function entry there is no yaml file to get the inputs and init signature.
# Since there is no yaml file for function entry, we set the inputs and init signature to empty dict.
self._inputs_sign = {}
self._init_sign = {}
8 changes: 2 additions & 6 deletions src/promptflow-core/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,7 @@ def _exec(
aggregation_inputs = {}
try:
if validate_inputs:
inputs = FlowValidator.ensure_flow_inputs_type(
flow_inputs=self._flow.inputs, inputs=inputs, idx=run_info.index
)
inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index)
inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs)
# Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run.
# This way, if validation or multimedia data loading fails, we avoid persisting invalid inputs.
Expand Down Expand Up @@ -1056,9 +1054,7 @@ async def _exec_async(
aggregation_inputs = {}
try:
if validate_inputs:
inputs = FlowValidator.ensure_flow_inputs_type(
flow_inputs=self._flow.inputs, inputs=inputs, idx=run_info.index
)
inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index)
# TODO: Consider async implementation for load_multimedia_data
inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs)
# Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run.
Expand Down
22 changes: 15 additions & 7 deletions src/promptflow-core/promptflow/executor/flow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,7 @@ def resolve_aggregated_flow_inputs_type(flow: Flow, inputs: Mapping[str, List[An
return updated_inputs

@staticmethod
def resolve_flow_inputs_type(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]:
"""Resolve inputs by type if existing. Ignore missing inputs.
:param flow: The `flow` parameter is of type `Flow` and represents a flow object
Expand All @@ -199,16 +197,20 @@ def resolve_flow_inputs_type(
in the `flow` object.
:rtype: Mapping[str, Any]
"""
return FlowValidator._resolve_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _resolve_flow_inputs_type_inner(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
updated_inputs = {k: v for k, v in inputs.items()}
for k, v in flow_inputs.items():
if k in inputs:
updated_inputs[k] = FlowValidator._parse_input_value(k, inputs[k], v.type, idx)
return updated_inputs

@staticmethod
def ensure_flow_inputs_type(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]:
"""Make sure the inputs are completed and in the correct type. Raise Exception if not valid.
:param flow: The `flow` parameter is of type `Flow` and represents a flow object
Expand All @@ -223,6 +225,12 @@ def ensure_flow_inputs_type(
type specified in the `flow` object.
:rtype: Mapping[str, Any]
"""
return FlowValidator._ensure_flow_inputs_type_inner(flow.inputs, inputs, idx)

@staticmethod
def _ensure_flow_inputs_type_inner(
flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None
) -> Mapping[str, Any]:
for k, _ in flow_inputs.items():
if k not in inputs:
line_info = "in input data" if idx is None else f"in line {idx} of input data"
Expand All @@ -232,7 +240,7 @@ def ensure_flow_inputs_type(
"if it's no longer needed."
)
raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info)
return FlowValidator.resolve_flow_inputs_type(flow_inputs, inputs, idx)
return FlowValidator._resolve_flow_inputs_type_inner(flow_inputs, inputs, idx)

@staticmethod
def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
Expand Down
3 changes: 1 addition & 2 deletions src/promptflow-devkit/promptflow/batch/_batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,7 @@ def _get_aggregation_inputs(self, batch_inputs, line_results: List[LineResult]):

succeeded_batch_inputs = [batch_inputs[i] for i in succeeded]
resolved_succeeded_batch_inputs = [
FlowValidator.ensure_flow_inputs_type(flow_inputs=self._flow.inputs, inputs=input)
for input in succeeded_batch_inputs
FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=input) for input in succeeded_batch_inputs
]
succeeded_inputs = transpose(resolved_succeeded_batch_inputs, keys=list(self._flow.inputs.keys()))
aggregation_inputs = transpose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_resolve_flow_inputs_type_json_error_for_list_type(
):
flow = get_flow_from_folder(flow_folder)
with pytest.raises(error_type) as exe_info:
FlowValidator.resolve_flow_inputs_type(flow.inputs, inputs, idx=index)
FlowValidator.resolve_flow_inputs_type(flow, inputs, idx=index)
assert error_message == exe_info.value.message

@pytest.mark.parametrize(
Expand Down

0 comments on commit 35e20ab

Please sign in to comment.