Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
Signed-off-by: Brynn Yin <[email protected]>
  • Loading branch information
brynn-code committed Apr 25, 2024
1 parent 475578c commit 00df20e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get(self, name: str) -> Any:
connection = self._connections.get(name)
if not connection:
raise ConnectionNotFound(
f"Connection {name!r} not found in dict connection provider."
f"Connection {name!r} not found in dict connection provider. "
f"Available keys are {list(self._connections.keys())}."
)
return connection
15 changes: 14 additions & 1 deletion src/promptflow-core/promptflow/executor/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,20 @@ def __init__(


class GetConnectionError(InvalidRequest):
pass
def __init__(
self,
connection: str,
node_name: str,
error: Exception,
**kwargs,
):
super().__init__(
message_format="Get connection '{connection}' for node '{node_name}' error: {error}",
connection=connection,
node_name=node_name,
error=str(error),
target=ErrorTarget.EXECUTOR,
)


class InvalidBulkTestRequest(ValidationException):
Expand Down
6 changes: 3 additions & 3 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node_name: st
connection_value = self._connection_provider.get(v.value)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e
raise GetConnectionError(v.value, node_name, e) from e
# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
msg = (
Expand All @@ -114,7 +114,7 @@ def _convert_to_custom_strong_type_connection_value(
connection_value = self._connection_provider.get(v.value)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e
raise GetConnectionError(v.value, node_name, e) from e

custom_defined_connection_class_name = conn_types[0]
source_type = getattr(source, "type", None)
Expand Down Expand Up @@ -486,7 +486,7 @@ def _get_llm_node_connection(self, node: Node):
connection = self._connection_provider.get(node.connection)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(f"Connection {node.connection} for node {node.name!r} error: {str(e)}.") from e
raise GetConnectionError(node.connection, node.name, e) from e
return connection

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ def test_executor_node_overrides(self, dev_connections):
raise_ex=True,
)
assert isinstance(e.value.inner_exception, GetConnectionError)
assert "Connection 'dummy_connection' of LLM node 'classify_with_llm' is not found." in str(e.value)
assert (
"Get connection 'dummy_connection' for node 'classify_with_llm' "
"error: Connection 'dummy_connection' not found" in str(e.value)
)

@pytest.mark.parametrize(
"flow_folder",
Expand Down

0 comments on commit 00df20e

Please sign in to comment.