From 871720b250e892d1177ffd012eb2a7a27cf673d9 Mon Sep 17 00:00:00 2001 From: Lina Tang Date: Thu, 25 Apr 2024 18:45:00 +0800 Subject: [PATCH] Apply suggestions from CR --- .../promptflow/executor/_script_executor.py | 7 ++++-- .../promptflow/executor/flow_executor.py | 8 ++----- .../promptflow/executor/flow_validator.py | 22 +++++++++++++------ .../promptflow/batch/_batch_engine.py | 3 +-- .../unittests/executor/test_flow_validator.py | 2 +- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index acbfb4d540f..aecbbd82ff6 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -120,7 +120,7 @@ def _exec_line_preprocess( # Executor will add line_number to batch inputs if there is no line_number in the original inputs, # which should be removed, so, we only preserve the inputs that are contained in self._inputs. inputs = {k: inputs[k] for k in self._inputs if k in inputs} - FlowValidator.ensure_flow_inputs_type(self._inputs_sign, inputs) + FlowValidator._ensure_flow_inputs_type_inner(self._inputs_sign, inputs) return run_info, inputs, run_tracker, None, [] def _exec_line( @@ -458,9 +458,12 @@ def _init_input_sign(self): with open(self._working_dir / self._flow_file, "r", encoding="utf-8") as fin: flow_dag = load_yaml(fin) flow = FlexFlow.deserialize(flow_dag) + # In the yaml file, users can define the inputs and init signatures for the flow; the SDK may also create + # the signatures and add them to the yaml file. We need to get the signatures from the yaml file and + # use them for applying default values and ensuring input types. self._inputs_sign = flow.inputs self._init_sign = flow.init else: - # For function entry there is no yaml file to get the inputs and init signature.
+ # Since there is no yaml file for function entry, we set the inputs and init signature to empty dict. self._inputs_sign = {} self._init_sign = {} diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index 7d329a57bd7..b0fd30fd52f 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -969,9 +969,7 @@ def _exec( aggregation_inputs = {} try: if validate_inputs: - inputs = FlowValidator.ensure_flow_inputs_type( - flow_inputs=self._flow.inputs, inputs=inputs, idx=run_info.index - ) + inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index) inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs) # Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run. # This way, if validation or multimedia data loading fails, we avoid persisting invalid inputs. @@ -1056,9 +1054,7 @@ async def _exec_async( aggregation_inputs = {} try: if validate_inputs: - inputs = FlowValidator.ensure_flow_inputs_type( - flow_inputs=self._flow.inputs, inputs=inputs, idx=run_info.index - ) + inputs = FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=inputs, idx=run_info.index) # TODO: Consider async implementation for load_multimedia_data inputs = self._multimedia_processor.load_multimedia_data(self._flow.inputs, inputs) # Inputs are assigned after validation and multimedia data loading, instead of at the start of the flow run. 
diff --git a/src/promptflow-core/promptflow/executor/flow_validator.py b/src/promptflow-core/promptflow/executor/flow_validator.py index 5bbe7bf1145..7b827590e73 100644 --- a/src/promptflow-core/promptflow/executor/flow_validator.py +++ b/src/promptflow-core/promptflow/executor/flow_validator.py @@ -182,9 +182,7 @@ def resolve_aggregated_flow_inputs_type(flow: Flow, inputs: Mapping[str, List[An return updated_inputs @staticmethod - def resolve_flow_inputs_type( - flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None - ) -> Mapping[str, Any]: + def resolve_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]: """Resolve inputs by type if existing. Ignore missing inputs. :param flow: The `flow` parameter is of type `Flow` and represents a flow object @@ -199,6 +197,12 @@ def resolve_flow_inputs_type( in the `flow` object. :rtype: Mapping[str, Any] """ + return FlowValidator._resolve_flow_inputs_type_inner(flow.inputs, inputs, idx) + + @staticmethod + def _resolve_flow_inputs_type_inner( + flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None + ) -> Mapping[str, Any]: updated_inputs = {k: v for k, v in inputs.items()} for k, v in flow_inputs.items(): if k in inputs: @@ -206,9 +210,7 @@ def resolve_flow_inputs_type( return updated_inputs @staticmethod - def ensure_flow_inputs_type( - flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None - ) -> Mapping[str, Any]: + def ensure_flow_inputs_type(flow: Flow, inputs: Mapping[str, Any], idx: Optional[int] = None) -> Mapping[str, Any]: """Make sure the inputs are completed and in the correct type. Raise Exception if not valid. :param flow: The `flow` parameter is of type `Flow` and represents a flow object @@ -223,6 +225,12 @@ def ensure_flow_inputs_type( type specified in the `flow` object. 
:rtype: Mapping[str, Any] """ + return FlowValidator._ensure_flow_inputs_type_inner(flow.inputs, inputs, idx) + + @staticmethod + def _ensure_flow_inputs_type_inner( + flow_inputs: FlowInputDefinition, inputs: Mapping[str, Any], idx: Optional[int] = None + ) -> Mapping[str, Any]: for k, _ in flow_inputs.items(): if k not in inputs: line_info = "in input data" if idx is None else f"in line {idx} of input data" @@ -232,7 +240,7 @@ def ensure_flow_inputs_type( "if it's no longer needed." ) raise InputNotFound(message_format=msg_format, input_name=k, line_info=line_info) - return FlowValidator.resolve_flow_inputs_type(flow_inputs, inputs, idx) + return FlowValidator._resolve_flow_inputs_type_inner(flow_inputs, inputs, idx) @staticmethod def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, Any]) -> Mapping[str, Any]: diff --git a/src/promptflow-devkit/promptflow/batch/_batch_engine.py b/src/promptflow-devkit/promptflow/batch/_batch_engine.py index 5f49530ab65..65bdc35dcdc 100644 --- a/src/promptflow-devkit/promptflow/batch/_batch_engine.py +++ b/src/promptflow-devkit/promptflow/batch/_batch_engine.py @@ -575,8 +575,7 @@ def _get_aggregation_inputs(self, batch_inputs, line_results: List[LineResult]): succeeded_batch_inputs = [batch_inputs[i] for i in succeeded] resolved_succeeded_batch_inputs = [ - FlowValidator.ensure_flow_inputs_type(flow_inputs=self._flow.inputs, inputs=input) - for input in succeeded_batch_inputs + FlowValidator.ensure_flow_inputs_type(flow=self._flow, inputs=input) for input in succeeded_batch_inputs ] succeeded_inputs = transpose(resolved_succeeded_batch_inputs, keys=list(self._flow.inputs.keys())) aggregation_inputs = transpose( diff --git a/src/promptflow/tests/executor/unittests/executor/test_flow_validator.py b/src/promptflow/tests/executor/unittests/executor/test_flow_validator.py index 7a632b7b4a2..a864a626cf3 100644 --- a/src/promptflow/tests/executor/unittests/executor/test_flow_validator.py +++ 
b/src/promptflow/tests/executor/unittests/executor/test_flow_validator.py @@ -168,7 +168,7 @@ def test_resolve_flow_inputs_type_json_error_for_list_type( ): flow = get_flow_from_folder(flow_folder) with pytest.raises(error_type) as exe_info: - FlowValidator.resolve_flow_inputs_type(flow.inputs, inputs, idx=index) + FlowValidator.resolve_flow_inputs_type(flow, inputs, idx=index) assert error_message == exe_info.value.message @pytest.mark.parametrize(