Skip to content

Commit

Permalink
[Bugfix] Add validation for wrong connection type for LLM node. (#1685)
Browse files Browse the repository at this point in the history
# Description

Add validation for wrong connection type for LLM node.
This pull request includes changes related to testing and validation of
connections in the `promptflow` codebase. The most important changes
include adding a new test case to validate the behavior when a tool load
fails due to an invalid connection type, and adding a check to validate
the connection type of a node in the `_resolve_llm_node` method.

Testing and validation changes:

* `src/promptflow/tests/executor/e2etests/test_executor_validation.py`:
Added a new test case to validate the behavior when a tool load fails
due to an invalid connection type. (test_executor_validation.py)
* `src/promptflow/promptflow/executor/_tool_resolver.py`: Added a check
to validate the connection type of a node in the `_resolve_llm_node`
method. (src/promptflow/promptflow/executor/_tool_resolver.py)

Additional changes:

*
`src/promptflow/tests/test_configs/wrong_flows/flow_llm_with_wrong_conn/flow.dag.yaml`:
Added a new YAML file that defines a node with a custom connection.
(src/promptflow/tests/test_configs/wrong_flows/flow_llm_with_wrong_conn/flow.dag.yaml)
*
`src/promptflow/tests/test_configs/wrong_flows/flow_llm_with_wrong_conn/wrong_llm.jinja2`:
Added a new Jinja2 template file containing a system message and a user
message that uses a Jinja2 variable.
(src/promptflow/tests/test_configs/wrong_flows/flow_llm_with_wrong_conn/wrong_llm.jinja2)

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

Co-authored-by: Heyi Tang <[email protected]>
  • Loading branch information
thy09 and Heyi Tang authored Jan 9, 2024
1 parent 11a5ec9 commit 3e91ce0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,13 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTo
if not connection_type_to_api_mapping:
raise EmptyLLMApiMapping()
# If provider is not specified, try to resolve it from connection type
node.provider = connection_type_to_api_mapping.get(type(connection).__name__)
connection_type = type(connection).__name__
if connection_type not in connection_type_to_api_mapping:
raise InvalidConnectionType(
message_format="Connection type {conn_type} is not supported for LLM.",
conn_type=connection_type,
)
node.provider = connection_type_to_api_mapping[connection_type]
tool: Tool = self._tool_loader.load_tool_for_llm_node(node)
key, connection = self._resolve_llm_connection_to_inputs(node, tool)
updated_node = copy.deepcopy(node)
Expand Down
11 changes: 11 additions & 0 deletions src/promptflow/tests/executor/e2etests/test_executor_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
InputNotFound,
InputReferenceNotFound,
InputTypeError,
InvalidConnectionType,
InvalidSource,
NodeCircularDependency,
NodeInputValidationError,
Expand All @@ -37,6 +38,16 @@ class TestValidation:
@pytest.mark.parametrize(
"flow_folder, yml_file, error_class, inner_class, error_msg",
[
(
"flow_llm_with_wrong_conn",
"flow.dag.yaml",
ResolveToolError,
InvalidConnectionType,
(
"Tool load failed in 'wrong_llm': "
"(InvalidConnectionType) Connection type CustomConnection is not supported for LLM."
)
),
(
"nodes_names_duplicated",
"flow.dag.yaml",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
inputs: {}
outputs: {}
nodes:
- name: wrong_llm
type: llm
source:
type: code
path: wrong_llm.jinja2
inputs: {}
connection: custom_connection
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
system:
You are a helpful assistant.

user:
{{question}}

0 comments on commit 3e91ce0

Please sign in to comment.