Skip to content

Commit

Permalink
Improve error message for aggregation input's validation. (#563)
Browse files Browse the repository at this point in the history
# Description

Improve error message for aggregation input's validation.
Also move this validation code to flow_validator.py.
All of these errors are system errors.

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

Co-authored-by: Robben Wang <[email protected]>
  • Loading branch information
huaiyan and Robben Wang authored Sep 22, 2023
1 parent 0edcd7d commit be4c4bb
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 72 deletions.
32 changes: 1 addition & 31 deletions src/promptflow/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from promptflow.executor import _input_assignment_parser
from promptflow.executor._errors import (
InputMappingError,
InvalidAggregationInput,
NodeOutputNotFound,
OutputReferenceBypassed,
OutputReferenceNotExist,
Expand Down Expand Up @@ -554,43 +553,14 @@ def exec_aggregation(
self._node_concurrency = node_concurrency
aggregated_flow_inputs = dict(inputs or {})
aggregation_inputs = dict(aggregation_inputs or {})
self._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
FlowValidator._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
aggregated_flow_inputs = self._apply_default_value_for_aggregation_input(
self._flow.inputs, aggregated_flow_inputs, aggregation_inputs
)

with self._run_tracker.node_log_manager:
return self._exec_aggregation(aggregated_flow_inputs, aggregation_inputs, run_id)

@staticmethod
def _validate_aggregation_inputs(aggregated_flow_inputs: Mapping[str, Any], aggregation_inputs: Mapping[str, Any]):
    """Reject aggregation inputs that cannot be processed together.

    A key may not be present in both mappings, every value in either mapping
    has to be a list, and all of those lists have to share a single length.

    :param aggregated_flow_inputs: Aggregated flow inputs, keyed by input name.
    :param aggregation_inputs: Aggregation inputs, keyed by input name.
    :raises InvalidAggregationInput: For the first rule that is found broken.
    """
    for name, candidate in aggregated_flow_inputs.items():
        # A name shared with the aggregation inputs is reported before the
        # type of its value is looked at.
        if name in aggregation_inputs:
            raise InvalidAggregationInput(
                message_format="Input '{input_key}' appear in both flow aggregation input and aggregation input.",
                input_key=name,
            )
        if isinstance(candidate, list):
            continue
        raise InvalidAggregationInput(
            message_format="Flow aggregation input {input_key} should be one list.", input_key=name
        )

    for name, candidate in aggregation_inputs.items():
        if isinstance(candidate, list):
            continue
        raise InvalidAggregationInput(
            message_format="Aggregation input {input_key} should be one list.", input_key=name
        )

    # Record the length of every input (flow inputs first) so that the error
    # can show exactly which keys disagree.
    lengths_by_key = {}
    for source in (aggregated_flow_inputs, aggregation_inputs):
        for name, candidate in source.items():
            lengths_by_key[name] = len(candidate)

    distinct_lengths = set(lengths_by_key.values())
    if len(distinct_lengths) > 1:
        raise InvalidAggregationInput(
            message_format="Whole aggregation inputs should have the same length. "
            "Current key length mapping are: {key_len}",
            key_len=lengths_by_key,
        )

@staticmethod
def _apply_default_value_for_aggregation_input(
inputs: Dict[str, FlowInputDefinition],
Expand Down
49 changes: 49 additions & 0 deletions src/promptflow/promptflow/executor/flow_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InputParseError,
InputReferenceNotFound,
InputTypeError,
InvalidAggregationInput,
NodeCircularDependency,
NodeReferenceNotFound,
OutputReferenceNotFound,
Expand Down Expand Up @@ -238,6 +239,54 @@ def convert_flow_inputs_for_node(flow: Flow, node: Node, inputs: Mapping[str, An
) from e
return updated_inputs

@staticmethod
def _validate_aggregation_inputs(aggregated_flow_inputs: Mapping[str, Any], aggregation_inputs: Mapping[str, Any]):
    """Check that the aggregated flow inputs and aggregated reference inputs fit together.

    No key may occur in both mappings, every value in either mapping must be a
    list, and all of the lists must have the same length.

    :param aggregated_flow_inputs: Aggregated flow inputs, keyed by input name.
    :param aggregation_inputs: Aggregated reference inputs, keyed by input name.
    :raises InvalidAggregationInput: For the first violation that is found.
    """
    duplicated_key_format = (
        "The input for aggregation is incorrect. The input '{input_key}' appears in both "
        "aggregated flow input and aggregated reference input. "
        "Please remove one of them and try the operation again."
    )
    flow_not_list_format = (
        "The input for aggregation is incorrect. "
        "The value for aggregated flow input '{input_key}' should be a list, "
        "but received {value_type}. Please adjust the input value to match the expected format."
    )
    reference_not_list_format = (
        "The input for aggregation is incorrect. "
        "The value for aggregated reference input '{input_key}' should be a list, "
        "but received {value_type}. Please adjust the input value to match the expected format."
    )
    length_mismatch_format = (
        "The input for aggregation is incorrect. "
        "The length of all aggregated inputs should be the same. Current input lengths are: "
        "{key_len}. Please adjust the input value in your input data."
    )

    for name, candidate in aggregated_flow_inputs.items():
        # For a given key the duplicate check comes first, so a key that is
        # both duplicated and not a list is reported as duplicated.
        if name in aggregation_inputs:
            raise InvalidAggregationInput(message_format=duplicated_key_format, input_key=name)
        if isinstance(candidate, list):
            continue
        raise InvalidAggregationInput(
            message_format=flow_not_list_format,
            input_key=name,
            value_type=type(candidate).__name__,
        )

    for name, candidate in aggregation_inputs.items():
        if isinstance(candidate, list):
            continue
        raise InvalidAggregationInput(
            message_format=reference_not_list_format,
            input_key=name,
            value_type=type(candidate).__name__,
        )

    # Record the length of every input (flow inputs first) so that the error
    # message can show exactly which keys disagree.
    lengths_by_key = {}
    for source in (aggregated_flow_inputs, aggregation_inputs):
        for name, candidate in source.items():
            lengths_by_key[name] = len(candidate)

    distinct_lengths = set(lengths_by_key.values())
    if len(distinct_lengths) > 1:
        raise InvalidAggregationInput(message_format=length_mismatch_format, key_len=lengths_by_key)

@staticmethod
def _ensure_outputs_valid(flow: Flow):
updated_outputs = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from promptflow._core._errors import UnexpectedError
from promptflow.contracts.flow import Flow, FlowInputDefinition
from promptflow.contracts.tool import ValueType
from promptflow.executor._errors import InvalidAggregationInput
from promptflow.executor._line_execution_process_pool import get_available_max_worker_count
from promptflow.executor.flow_executor import (
FlowExecutor,
Expand Down Expand Up @@ -408,45 +407,6 @@ def test_apply_default_value_for_aggregation_input(
)
assert result == expected_inputs

@pytest.mark.parametrize(
    "aggregated_flow_inputs, aggregation_inputs, error_message",
    [
        # An aggregation input whose value is not a list.
        ({}, {"input1": "value1"}, "Aggregation input input1 should be one list."),
        # A flow aggregation input whose value is not a list.
        ({"input1": "value1"}, {}, "Flow aggregation input input1 should be one list."),
        # Lists of different lengths.
        (
            {"input1": ["value1_1", "value1_2"]},
            {"input_2": ["value2_1"]},
            (
                "Whole aggregation inputs should have the same length. "
                "Current key length mapping are: {'input1': 2, 'input_2': 1}"
            ),
        ),
        # The same key supplied in both mappings.
        (
            {"input1": "value1"},
            {"input1": "value1"},
            "Input 'input1' appear in both flow aggregation input and aggregation input.",
        ),
    ],
)
def test_validate_aggregation_inputs_error(self, aggregated_flow_inputs, aggregation_inputs, error_message):
    """Each invalid input pair must raise InvalidAggregationInput with the exact expected message."""
    with pytest.raises(InvalidAggregationInput) as exc_info:
        FlowExecutor._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
    # Checked after the context manager exits; inside it the line could never run.
    assert str(exc_info.value) == error_message
assert str(e.value) == error_message


def func_with_stream_parameter(a: int, b: int, stream=False):
return a + b, stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import yaml

from promptflow.contracts.flow import Flow
from promptflow.executor._errors import InputParseError, InvalidFlowRequest
from promptflow.executor._errors import InputParseError, InvalidAggregationInput, InvalidFlowRequest
from promptflow.executor.flow_validator import FlowValidator

from ...utils import WRONG_FLOW_ROOT, get_yaml_file
Expand Down Expand Up @@ -71,6 +71,52 @@ def test_ensure_nodes_order_with_exception(self, flow_folder, error_message):
FlowValidator._ensure_nodes_order(flow)
assert str(e.value) == error_message, "Expected: {}, Actual: {}".format(error_message, str(e.value))

@pytest.mark.parametrize(
    "aggregated_flow_inputs, aggregation_inputs, error_message",
    [
        # An aggregated reference input whose value is not a list.
        (
            {},
            {"input1": "value1"},
            (
                "The input for aggregation is incorrect. "
                "The value for aggregated reference input 'input1' should be a list, "
                "but received str. Please adjust the input value to match the expected format."
            ),
        ),
        # An aggregated flow input whose value is not a list.
        (
            {"input1": "value1"},
            {},
            (
                "The input for aggregation is incorrect. "
                "The value for aggregated flow input 'input1' should be a list, "
                "but received str. Please adjust the input value to match the expected format."
            ),
        ),
        # Lists of different lengths.
        (
            {"input1": ["value1_1", "value1_2"]},
            {"input_2": ["value2_1"]},
            (
                "The input for aggregation is incorrect. The length of all aggregated inputs should be the same. "
                "Current input lengths are: {'input1': 2, 'input_2': 1}. "
                "Please adjust the input value in your input data."
            ),
        ),
        # The same key supplied in both mappings.
        (
            {"input1": "value1"},
            {"input1": "value1"},
            (
                "The input for aggregation is incorrect. "
                "The input 'input1' appears in both aggregated flow input and aggregated reference input. "
                "Please remove one of them and try the operation again."
            ),
        ),
    ],
)
def test_validate_aggregation_inputs_error(self, aggregated_flow_inputs, aggregation_inputs, error_message):
    """Each invalid input pair must raise InvalidAggregationInput with the exact expected message."""
    with pytest.raises(InvalidAggregationInput) as exc_info:
        FlowValidator._validate_aggregation_inputs(aggregated_flow_inputs, aggregation_inputs)
    # Checked after the context manager exits; inside it the line could never run.
    assert str(exc_info.value) == error_message

@pytest.mark.parametrize(
"flow_folder",
["simple_flow_with_python_tool_and_aggregate"],
Expand Down

0 comments on commit be4c4bb

Please sign in to comment.