From 8907e964832ce9714877077625d063404fc0b1fb Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Thu, 25 Apr 2024 17:14:05 +0800 Subject: [PATCH 1/7] Support flexflow with global config Signed-off-by: Brynn Yin --- .../unittests/test_pf_client.py | 4 + .../_connection_provider.py | 10 +- .../_dict_connection_provider.py | 6 +- .../core/_connection_provider/_utils.py | 6 +- .../_workspace_connection_provider.py | 8 +- .../promptflow/core/_errors.py | 4 + .../promptflow/core/_serving/flow_invoker.py | 8 +- .../promptflow/executor/_errors.py | 2 +- .../promptflow/executor/_tool_resolver.py | 33 +++--- .../promptflow/_sdk/_pf_client.py | 6 +- .../_local_azure_connection_operations.py | 10 +- .../e2etests/test_global_config.py | 19 +++- .../e2etests/test_csharp_executor_proxy.py | 4 +- .../e2etests/test_executor_happypath.py | 4 +- .../e2etests/test_executor_validation.py | 4 +- .../batch/test_base_executor_proxy.py | 4 +- .../unittests/batch/test_batch_engine.py | 6 +- .../unittests/executor/test_tool_resolver.py | 105 ++++++++++++------ 18 files changed, 161 insertions(+), 82 deletions(-) diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_pf_client.py b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_pf_client.py index 1a8e0fd3925..191f5861592 100644 --- a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_pf_client.py +++ b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_pf_client.py @@ -1,6 +1,8 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import os + import mock import pytest @@ -20,6 +22,8 @@ class TestPFClient: # Test pf client when connection provider is azureml. # This tests suites need azure dependencies. + # Mock os.environ to avoid this test affecting other tests + @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.skipif(condition=not pytest.is_live, reason="This test requires an actual PFClient") def test_connection_provider(self, subscription_id: str, resource_group_name: str, workspace_name: str): target = "promptflow._sdk._pf_client.Configuration" diff --git a/src/promptflow-core/promptflow/core/_connection_provider/_connection_provider.py b/src/promptflow-core/promptflow/core/_connection_provider/_connection_provider.py index 05a8a624ad5..81eb79a8d5d 100644 --- a/src/promptflow-core/promptflow/core/_connection_provider/_connection_provider.py +++ b/src/promptflow-core/promptflow/core/_connection_provider/_connection_provider.py @@ -14,7 +14,7 @@ class ConnectionProvider(ABC): """The connection provider interface to list/get connections in the current environment.""" - PROVIDER_CONFIG_KEY = "CONNECTION_PROVIDER_CONFIG" + PROVIDER_CONFIG_KEY = "PF_CONNECTION_PROVIDER" _instance = None @abstractmethod @@ -28,12 +28,12 @@ def list(self, **kwargs) -> List[_Connection]: raise NotImplementedError("Method 'list' is not implemented.") @classmethod - def get_instance(cls) -> "ConnectionProvider": + def get_instance(cls, **kwargs) -> "ConnectionProvider": """Get the connection provider instance in the current environment. It will return different implementations based on the current environment. """ if not cls._instance: - cls._instance = cls._init_from_env() + cls._instance = cls._init_from_env(**kwargs) return cls._instance @classmethod @@ -63,7 +63,7 @@ def init_from_provider_config(cls, provider_config: str, credential=None): ) @classmethod - def _init_from_env(cls) -> "ConnectionProvider": + def _init_from_env(cls, **kwargs) -> "ConnectionProvider": """Initialize the connection provider from environment variables.""" from ._http_connection_provider import HttpConnectionProvider @@ -71,4 +71,4 @@ def _init_from_env(cls) -> "ConnectionProvider": if endpoint: return HttpConnectionProvider(endpoint) provider_config = os.getenv(cls.PROVIDER_CONFIG_KEY, "") - return ConnectionProvider.init_from_provider_config(provider_config) + return ConnectionProvider.init_from_provider_config(provider_config, credential=kwargs.get("credential")) diff --git a/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py b/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py index fb05ffee1bc..0ef360321a4 100644 --- a/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py +++ b/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py @@ -9,6 +9,7 @@ from promptflow.contracts.tool import ConnectionType from promptflow.contracts.types import Secret +from .._errors import ConnectionNotFound from ._connection_provider import ConnectionProvider @@ -102,4 +103,7 @@ def get(self, name: str) -> Any: return self._connections.get(name) elif ConnectionType.is_connection_value(name): return name - return None + raise ConnectionNotFound( + f"Connection {name!r} not found in dict connection provider." + f"Available keys are {list(self._connections.keys())}." + ) diff --git a/src/promptflow-core/promptflow/core/_connection_provider/_utils.py b/src/promptflow-core/promptflow/core/_connection_provider/_utils.py index 97dd00029b6..3fa7a401f1e 100644 --- a/src/promptflow-core/promptflow/core/_connection_provider/_utils.py +++ b/src/promptflow-core/promptflow/core/_connection_provider/_utils.py @@ -59,9 +59,9 @@ def is_github_codespaces(): return os.environ.get("CODESPACES", None) == "true" -def interactive_credential_disabled(): - """Check if interactive login is disabled.""" - return os.environ.get(PF_NO_INTERACTIVE_LOGIN, "false").lower() == "true" +def interactive_credential_enabled(): + """Check if interactive login is enabled.""" + return os.environ.get(PF_NO_INTERACTIVE_LOGIN, "true").lower() == "false" def is_from_cli(): diff --git a/src/promptflow-core/promptflow/core/_connection_provider/_workspace_connection_provider.py b/src/promptflow-core/promptflow/core/_connection_provider/_workspace_connection_provider.py index 4e038f93819..b97f048616a 100644 --- a/src/promptflow-core/promptflow/core/_connection_provider/_workspace_connection_provider.py +++ b/src/promptflow-core/promptflow/core/_connection_provider/_workspace_connection_provider.py @@ -26,7 +26,7 @@ from ..._utils.credential_utils import get_default_azure_credential from ._connection_provider import ConnectionProvider -from ._utils import interactive_credential_disabled, is_from_cli, is_github_codespaces +from ._utils import interactive_credential_enabled, is_from_cli, is_github_codespaces GET_CONNECTION_URL = ( "/subscriptions/{sub}/resourcegroups/{rg}/providers/Microsoft.MachineLearningServices" @@ -111,14 +111,14 @@ def _get_credential(cls): get_arm_token(credential=credential) except Exception: raise AccountNotSetUp() - if interactive_credential_disabled(): - return DefaultAzureCredential(exclude_interactive_browser_credential=True) + if interactive_credential_enabled(): + return DefaultAzureCredential(exclude_interactive_browser_credential=False) if is_github_codespaces(): # For code spaces, append device code credential as the fallback option. credential = DefaultAzureCredential() credential.credentials = (*credential.credentials, DeviceCodeCredential()) return credential - return DefaultAzureCredential(exclude_interactive_browser_credential=False) + return DefaultAzureCredential(exclude_interactive_browser_credential=True) @classmethod def open_url(cls, token, url, action, host="management.azure.com", method="GET", model=None) -> Union[Any, dict]: diff --git a/src/promptflow-core/promptflow/core/_errors.py b/src/promptflow-core/promptflow/core/_errors.py index c43e0f07c55..a611049fa98 100644 --- a/src/promptflow-core/promptflow/core/_errors.py +++ b/src/promptflow-core/promptflow/core/_errors.py @@ -102,6 +102,10 @@ class InvalidSampleError(CoreError): pass +class ConnectionNotFound(CoreError): + pass + + class OpenURLUserAuthenticationError(UserAuthenticationError): def __init__(self, **kwargs): super().__init__(target=ErrorTarget.CORE, **kwargs) diff --git a/src/promptflow-core/promptflow/core/_serving/flow_invoker.py b/src/promptflow-core/promptflow/core/_serving/flow_invoker.py index d3241495733..f7221ea026b 100644 --- a/src/promptflow-core/promptflow/core/_serving/flow_invoker.py +++ b/src/promptflow-core/promptflow/core/_serving/flow_invoker.py @@ -116,6 +116,12 @@ def _init_connections(self, connection_provider): connections_to_ignore.extend(self.connections_name_overrides.keys()) self.logger.debug(f"Flow invoker connections name overrides: {self.connections_name_overrides.keys()}") self.logger.debug(f"Ignoring connections: {connections_to_ignore}") + if not connection_provider: + connection_provider = ConnectionProvider.get_instance(credential=self._credential) + else: + connection_provider = ConnectionProvider.init_from_provider_config( + connection_provider, credential=self._credential + ) # Note: The connection here could be local or workspace, depends on the connection.provider in pf.yaml. connections = self.resolve_connections( # use os.environ to override flow definition's connection since @@ -123,7 +129,7 @@ def _init_connections(self, connection_provider): connection_names=self.flow.get_connection_names( environment_variables_overrides=os.environ, ), - provider=ConnectionProvider.init_from_provider_config(connection_provider, credential=self._credential), + provider=connection_provider, connections_to_ignore=connections_to_ignore, # fetch connections with name override connections_to_add=list(self.connections_name_overrides.values()), diff --git a/src/promptflow-core/promptflow/executor/_errors.py b/src/promptflow-core/promptflow/executor/_errors.py index 83bcd6d822f..2143511d9e6 100644 --- a/src/promptflow-core/promptflow/executor/_errors.py +++ b/src/promptflow-core/promptflow/executor/_errors.py @@ -54,7 +54,7 @@ def __init__( ) -class ConnectionNotFound(InvalidRequest): +class GetConnectionError(InvalidRequest): pass diff --git a/src/promptflow-core/promptflow/executor/_tool_resolver.py b/src/promptflow-core/promptflow/executor/_tool_resolver.py index 612db104061..5fd39aa7f73 100644 --- a/src/promptflow-core/promptflow/executor/_tool_resolver.py +++ b/src/promptflow-core/promptflow/executor/_tool_resolver.py @@ -13,7 +13,7 @@ from promptflow._constants import MessageFormatType from promptflow._core._errors import InvalidSource -from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR, INPUTS_TO_ESCAPE_PARAM_KEY, TOOL_TYPE_TO_ESCAPE +from promptflow._core.tool import INPUTS_TO_ESCAPE_PARAM_KEY, STREAMING_OPTION_PARAMETER_ATTR, TOOL_TYPE_TO_ESCAPE from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping from promptflow._utils.multimedia_utils import MultimediaProcessor from promptflow._utils.tool_utils import ( @@ -36,9 +36,9 @@ ) from promptflow.executor._docstring_parser import DocstringParser from promptflow.executor._errors import ( - ConnectionNotFound, EmptyLLMApiMapping, FailedToGenerateToolDefinition, + GetConnectionError, InvalidAssistantTool, InvalidConnectionType, InvalidCustomLLMTool, @@ -83,9 +83,11 @@ def start_resolver( return resolver def _convert_to_connection_value(self, k: str, v: InputAssignment, node_name: str, conn_types: List[ValueType]): - connection_value = self._connection_provider.get(v.value) - if not connection_value: - raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.") + try: + connection_value = self._connection_provider.get(v.value) + except Exception as e: # Cache all exception as different provider raises different exceptions + # Raise new error with node details + raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e # Check if type matched if not any(type(connection_value).__name__ == typ for typ in conn_types): msg = ( @@ -108,9 +110,11 @@ def _convert_to_custom_strong_type_connection_value( if not conn_types: msg = f"Input '{k}' for node '{node_name}' has invalid types: {conn_types}." raise NodeInputValidationError(message=msg) - connection_value = self._connection_provider.get(v.value) - if not connection_value: - raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.") + try: + connection_value = self._connection_provider.get(v.value) + except Exception as e: # Cache all exception as different provider raises different exceptions + # Raise new error with node details + raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e custom_defined_connection_class_name = conn_types[0] source_type = getattr(source, "type", None) @@ -478,14 +482,11 @@ def _remove_init_args(node_inputs: dict, init_args: dict): del node_inputs[k] def _get_llm_node_connection(self, node: Node): - connection = self._connection_provider.get(node.connection) - if connection is None: - raise ConnectionNotFound( - message_format="Connection '{connection}' of LLM node '{node_name}' is not found.", - connection=node.connection, - node_name=node.name, - target=ErrorTarget.EXECUTOR, - ) + try: + connection = self._connection_provider.get(node.connection) + except Exception as e: # Cache all exception as different provider raises different exceptions + # Raise new error with node details + raise GetConnectionError(f"Connection {node.connection} for node {node.name!r} error: {str(e)}.") from e return connection @staticmethod diff --git a/src/promptflow-devkit/promptflow/_sdk/_pf_client.py b/src/promptflow-devkit/promptflow/_sdk/_pf_client.py index a5c7c56ec09..99d9675e7a4 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_pf_client.py +++ b/src/promptflow-devkit/promptflow/_sdk/_pf_client.py @@ -366,7 +366,11 @@ def _ensure_connection_provider(self) -> str: if not self._connection_provider: # Get a copy with config override instead of the config instance self._connection_provider = Configuration(overrides=self._config).get_connection_provider() - logger.debug("PFClient connection provider: %s", self._connection_provider) + logger.debug("PFClient connection provider: %s, setting to env.", self._connection_provider) + from promptflow.core._connection_provider._connection_provider import ConnectionProvider + + # Set to os.environ for connection provider to use + os.environ[ConnectionProvider.PROVIDER_CONFIG_KEY] = self._connection_provider return self._connection_provider @property diff --git a/src/promptflow-devkit/promptflow/_sdk/operations/_local_azure_connection_operations.py b/src/promptflow-devkit/promptflow/_sdk/operations/_local_azure_connection_operations.py index d676d2a9661..431719ac538 100644 --- a/src/promptflow-devkit/promptflow/_sdk/operations/_local_azure_connection_operations.py +++ b/src/promptflow-devkit/promptflow/_sdk/operations/_local_azure_connection_operations.py @@ -12,7 +12,7 @@ from promptflow._utils.credential_utils import get_default_azure_credential from promptflow._utils.logger_utils import get_cli_sdk_logger from promptflow.core._connection_provider._utils import ( - interactive_credential_disabled, + interactive_credential_enabled, is_from_cli, is_github_codespaces, ) @@ -73,14 +73,14 @@ def _get_credential(cls): "See https://docs.microsoft.com/cli/azure/authenticate-azure-cli for more details." ) sys.exit(1) - if interactive_credential_disabled(): - return DefaultAzureCredential(exclude_interactive_browser_credential=True) + if interactive_credential_enabled(): + return DefaultAzureCredential(exclude_interactive_browser_credential=False) if is_github_codespaces(): # For code spaces, append device code credential as the fallback option. - credential = DefaultAzureCredential() + credential = DefaultAzureCredential(exclude_interactive_browser_credential=True) credential.credentials = (*credential.credentials, DeviceCodeCredential()) return credential - return DefaultAzureCredential(exclude_interactive_browser_credential=False) + return DefaultAzureCredential(exclude_interactive_browser_credential=True) @monitor_operation(activity_name="pf.connections.azure.list", activity_type=ActivityType.PUBLICAPI) def list( diff --git a/src/promptflow-devkit/tests/sdk_cli_global_config_test/e2etests/test_global_config.py b/src/promptflow-devkit/tests/sdk_cli_global_config_test/e2etests/test_global_config.py index b7553962944..752668aafcf 100644 --- a/src/promptflow-devkit/tests/sdk_cli_global_config_test/e2etests/test_global_config.py +++ b/src/promptflow-devkit/tests/sdk_cli_global_config_test/e2etests/test_global_config.py @@ -4,10 +4,14 @@ from promptflow._sdk._load_functions import load_flow from promptflow._sdk.entities._flows._flow_context_resolver import FlowContextResolver +from promptflow.contracts.run_info import Status from promptflow.core._connection_provider._workspace_connection_provider import WorkspaceConnectionProvider +from promptflow.executor._script_executor import ScriptExecutor -FLOWS_DIR = PROMPTFLOW_ROOT / "tests" / "test_configs" / "flows" -DATAS_DIR = PROMPTFLOW_ROOT / "tests" / "test_configs" / "datas" +TEST_CONFIG_DIR = PROMPTFLOW_ROOT / "tests" / "test_configs" +FLOWS_DIR = TEST_CONFIG_DIR / "flows" +DATAS_DIR = TEST_CONFIG_DIR / "datas" +EAGER_FLOW_ROOT = TEST_CONFIG_DIR / "eager_flows" @pytest.mark.usefixtures("global_config") @@ -45,3 +49,14 @@ def assert_client(mock_self, provider, **kwargs): flow = load_flow(source=f"{FLOWS_DIR}/web_classification") with mock.patch("promptflow.core._serving.flow_invoker.FlowInvoker.resolve_connections", assert_client): FlowContextResolver.resolve(flow=flow) + + def test_flex_flow_run_with_openai_chat(self, pf): + # Test flex flow run successfully with global config ws connection + flow_file = EAGER_FLOW_ROOT / "callable_class_with_openai" / "flow.flex.yaml" + pf._ensure_connection_provider() + executor = ScriptExecutor(flow_file=flow_file, init_kwargs={"connection": "azure_open_ai_connection"}) + line_result = executor.exec_line(inputs={"question": "Hello", "stream": False}, index=0) + assert line_result.run_info.status == Status.Completed, line_result.run_info.error + token_names = ["prompt_tokens", "completion_tokens", "total_tokens"] + for token_name in token_names: + assert token_name in line_result.run_info.api_calls[0]["children"][0]["system_metrics"] diff --git a/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py b/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py index 61ebc547bf9..b1b6d03d5d9 100644 --- a/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py +++ b/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py @@ -15,7 +15,7 @@ from promptflow.batch._result import BatchResult from promptflow.contracts.run_info import Status from promptflow.exceptions import ErrorTarget, ValidationException -from promptflow.executor._errors import ConnectionNotFound +from promptflow.executor._errors import GetConnectionError from promptflow.storage._run_storage import AbstractRunStorage from ..mock_execution_server import run_executor_server @@ -46,7 +46,7 @@ def test_batch_execution_error(self): def test_batch_validation_error(self): # prepare the init error file to mock the validation error error_message = "'test_connection' not found." - test_exception = ConnectionNotFound(message=error_message) + test_exception = GetConnectionError(message=error_message) error_dict = ExceptionPresenter.create(test_exception).to_dict() init_error_file = Path(mkdtemp()) / "init_error.json" with open(init_error_file, "w") as file: diff --git a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py index 87a411bae98..59fd93a8500 100644 --- a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py +++ b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py @@ -12,7 +12,7 @@ from promptflow.contracts.run_info import Status from promptflow.exceptions import UserErrorException from promptflow.executor import FlowExecutor -from promptflow.executor._errors import ConnectionNotFound, InputTypeError, ResolveToolError +from promptflow.executor._errors import GetConnectionError, InputTypeError, ResolveToolError from promptflow.executor.flow_executor import execute_flow from promptflow.storage._run_storage import DefaultRunStorage @@ -190,7 +190,7 @@ def test_executor_node_overrides(self, dev_connections): node_override={"classify_with_llm.connection": "dummy_connection"}, raise_ex=True, ) - assert isinstance(e.value.inner_exception, ConnectionNotFound) + assert isinstance(e.value.inner_exception, GetConnectionError) assert "Connection 'dummy_connection' of LLM node 'classify_with_llm' is not found." in str(e.value) @pytest.mark.parametrize( diff --git a/src/promptflow/tests/executor/e2etests/test_executor_validation.py b/src/promptflow/tests/executor/e2etests/test_executor_validation.py index d77ec93e400..321bfee671c 100644 --- a/src/promptflow/tests/executor/e2etests/test_executor_validation.py +++ b/src/promptflow/tests/executor/e2etests/test_executor_validation.py @@ -12,9 +12,9 @@ from promptflow.contracts._errors import FailedToImportModule from promptflow.executor import FlowExecutor from promptflow.executor._errors import ( - ConnectionNotFound, DuplicateNodeName, EmptyOutputReference, + GetConnectionError, InputNotFound, InputReferenceNotFound, InputTypeError, @@ -177,7 +177,7 @@ def test_node_topology_in_order(self, ordered_flow_folder, unordered_flow_folder @pytest.mark.parametrize( "flow_folder, error_class, inner_class", [ - ("invalid_connection", ResolveToolError, ConnectionNotFound), + ("invalid_connection", ResolveToolError, GetConnectionError), ("tool_type_missing", ResolveToolError, NotImplementedError), ("wrong_module", FailedToImportModule, None), ("wrong_api", ResolveToolError, APINotFound), diff --git a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py index 5d4d2a06e3a..24a2b6f2195 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py +++ b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py @@ -12,7 +12,7 @@ from promptflow._utils.exception_utils import ExceptionPresenter from promptflow.contracts.run_info import Status from promptflow.exceptions import ErrorTarget, ValidationException -from promptflow.executor._errors import ConnectionNotFound +from promptflow.executor._errors import GetConnectionError from promptflow.storage._run_storage import AbstractRunStorage from ...mock_execution_server import _get_aggr_result_dict, _get_line_result_dict @@ -89,7 +89,7 @@ async def test_ensure_executor_startup_when_existing_validation_error(self): # prepare the error file error_file = Path(mkdtemp()) / "error.json" error_message = "Connection 'aoai_conn' not found" - error_dict = ExceptionPresenter.create(ConnectionNotFound(message=error_message)).to_dict() + error_dict = ExceptionPresenter.create(GetConnectionError(message=error_message)).to_dict() with open(error_file, "w") as file: json.dump(error_dict, file, indent=4) diff --git a/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py b/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py index 6ae3820e03c..7690497bcd7 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py +++ b/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py @@ -12,7 +12,7 @@ from promptflow.batch import BatchEngine from promptflow.contracts.run_info import Status from promptflow.exceptions import ErrorTarget -from promptflow.executor._errors import ConnectionNotFound +from promptflow.executor._errors import GetConnectionError from promptflow.executor._result import AggregationResult from ...utils import MemoryRunStorage, get_yaml_file, load_jsonl @@ -32,8 +32,8 @@ class TestBatchEngine: "Unexpected error occurred while executing the batch run. Error: (Exception) test error.", ), ( - ConnectionNotFound(message="Connection 'aoai_conn' not found"), - ConnectionNotFound, + GetConnectionError(message="Connection 'aoai_conn' not found"), + GetConnectionError, ErrorTarget.EXECUTOR, ["UserError", "ValidationError", "InvalidRequest", "ConnectionNotFound"], "Connection 'aoai_conn' not found", diff --git a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py index df5a09a1fe0..8856c6e89b8 100644 --- a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py +++ b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py @@ -8,8 +8,8 @@ from jinja2 import TemplateSyntaxError from promptflow._core._errors import InvalidSource -from promptflow._core.tools_manager import ToolLoader from promptflow._core.tool import INPUTS_TO_ESCAPE_PARAM_KEY +from promptflow._core.tools_manager import ToolLoader from promptflow._internal import tool from promptflow.connections import AzureOpenAIConnection, CustomConnection, CustomStrongTypeConnection from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType @@ -19,7 +19,7 @@ from promptflow.exceptions import UserErrorException from promptflow.executor._assistant_tool_invoker import ResolvedAssistantTool from promptflow.executor._errors import ( - ConnectionNotFound, + GetConnectionError, InvalidConnectionType, NodeInputValidationError, ResolveToolError, @@ -189,7 +189,7 @@ def test_convert_node_literal_input_types_with_invalid_case(self): tool=tool, inputs={"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, ) - with pytest.raises(ConnectionNotFound): + with pytest.raises(GetConnectionError): tool_resolver = ToolResolver(working_dir=None, connection_provider=DictConnectionProvider({})) tool_resolver._convert_node_literal_input_types(node, tool) @@ -269,7 +269,7 @@ def test_resolve_llm_connection_to_inputs(self): tool=tool, inputs={"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, ) - with pytest.raises(ConnectionNotFound): + with pytest.raises(GetConnectionError): tool_resolver = ToolResolver(working_dir=None, connection_provider=connection_provider) tool_resolver._resolve_llm_connection_to_inputs(node, tool) @@ -281,7 +281,7 @@ def test_resolve_llm_connection_to_inputs(self): inputs={"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, connection="conn_name1", ) - with pytest.raises(ConnectionNotFound): + with pytest.raises(GetConnectionError): tool_resolver = ToolResolver(working_dir=None, connection_provider=DictConnectionProvider({})) tool_resolver._resolve_llm_connection_to_inputs(node, tool) @@ -792,33 +792,74 @@ def test_invalid_assistant_definition_path(self, path): "value 'assistant_definition_non_existing.yaml' is not a valid path." ) - @pytest.mark.parametrize("tool_type, node_inputs, expected_inputs", [ - (ToolType.PYTHON, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}), - (ToolType.PYTHON, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT)}), - (ToolType.PROMPT, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), - INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL)}), - (ToolType.LLM, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), - INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL)}), - (ToolType.CUSTOM_LLM, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), - INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL)}), - (ToolType.LLM, {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.LITERAL)}, - {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), - "text": InputAssignment(value="Hello World!", value_type=InputValueType.LITERAL)}), - ]) + @pytest.mark.parametrize( + "tool_type, node_inputs, expected_inputs", + [ + ( + ToolType.PYTHON, + {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, + {"conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL)}, + ), + ( + ToolType.PYTHON, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + }, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + }, + ), + ( + ToolType.PROMPT, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + }, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL), + }, + ), + ( + ToolType.LLM, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + }, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL), + }, + ), + ( + ToolType.CUSTOM_LLM, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + }, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.FLOW_INPUT), + INPUTS_TO_ESCAPE_PARAM_KEY: InputAssignment(value=["text"], value_type=InputValueType.LITERAL), + }, + ), + ( + ToolType.LLM, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.LITERAL), + }, + { + "conn": InputAssignment(value="conn_name", value_type=InputValueType.LITERAL), + "text": InputAssignment(value="Hello World!", value_type=InputValueType.LITERAL), + }, + ), + ], + ) def test_update_inputs_to_escape(self, tool_type, node_inputs, expected_inputs): node = Node( name="mock", From 475578c1d59ae19595f3b3346df9b22f707c1dcd Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Thu, 25 Apr 2024 17:52:20 +0800 Subject: [PATCH 2/7] Add changelog Signed-off-by: Brynn Yin --- src/promptflow-devkit/CHANGELOG.md | 5 +++++ src/promptflow/CHANGELOG.md | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/promptflow-devkit/CHANGELOG.md b/src/promptflow-devkit/CHANGELOG.md index 465076bc662..d789473f9e4 100644 --- a/src/promptflow-devkit/CHANGELOG.md +++ b/src/promptflow-devkit/CHANGELOG.md @@ -1,5 +1,10 @@ # promptflow-devkit package +## v1.11.0 (Upcoming) + +### Improvements +- Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it. + ## v1.10.0 (Upcoming) ### Features Added diff --git a/src/promptflow/CHANGELOG.md b/src/promptflow/CHANGELOG.md index 293a0f50504..9123d85589e 100644 --- a/src/promptflow/CHANGELOG.md +++ b/src/promptflow/CHANGELOG.md @@ -1,5 +1,10 @@ # Release History +## v1.11.0 (Upcoming) + +### Improvements +- [promptflow-devkit]: Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it. + ## v1.10.0 (Upcoming) ### Features Added - [promptflow-devkit]: Expose --ui to trigger a chat window, reach [here](https://microsoft.github.io/promptflow/reference/pf-command-reference.html#pf-flow-test) for more details. From 00df20eda3b7c5951f92f3d847d275eaf09f154f Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Thu, 25 Apr 2024 19:20:42 +0800 Subject: [PATCH 3/7] Fix test Signed-off-by: Brynn Yin --- .../_dict_connection_provider.py | 2 +- .../promptflow/executor/_errors.py | 15 ++++++++++++++- .../promptflow/executor/_tool_resolver.py | 6 +++--- .../executor/e2etests/test_executor_happypath.py | 5 ++++- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py b/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py index 407c80547b1..34877a35a76 100644 --- a/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py +++ b/src/promptflow-core/promptflow/core/_connection_provider/_dict_connection_provider.py @@ -106,7 +106,7 @@ def get(self, name: str) -> Any: connection = self._connections.get(name) if not connection: raise ConnectionNotFound( - f"Connection {name!r} not found in dict connection provider." + f"Connection {name!r} not found in dict connection provider. " f"Available keys are {list(self._connections.keys())}." ) return connection diff --git a/src/promptflow-core/promptflow/executor/_errors.py b/src/promptflow-core/promptflow/executor/_errors.py index 2143511d9e6..ffd29f4389b 100644 --- a/src/promptflow-core/promptflow/executor/_errors.py +++ b/src/promptflow-core/promptflow/executor/_errors.py @@ -55,7 +55,20 @@ def __init__( class GetConnectionError(InvalidRequest): - pass + def __init__( + self, + connection: str, + node_name: str, + error: Exception, + **kwargs, + ): + super().__init__( + message_format="Get connection '{connection}' for node '{node_name}' error: {error}", + connection=connection, + node_name=node_name, + error=str(error), + target=ErrorTarget.EXECUTOR, + ) class InvalidBulkTestRequest(ValidationException): diff --git a/src/promptflow-core/promptflow/executor/_tool_resolver.py b/src/promptflow-core/promptflow/executor/_tool_resolver.py index 5fd39aa7f73..fec9d90f3dd 100644 --- a/src/promptflow-core/promptflow/executor/_tool_resolver.py +++ b/src/promptflow-core/promptflow/executor/_tool_resolver.py @@ -87,7 +87,7 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node_name: st connection_value = self._connection_provider.get(v.value) except Exception as e: # Cache all exception as different provider raises different exceptions # Raise new error with node details - raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e + raise GetConnectionError(v.value, node_name, e) from e # Check if type matched if not any(type(connection_value).__name__ == typ for typ in conn_types): msg = ( @@ -114,7 +114,7 @@ def _convert_to_custom_strong_type_connection_value( connection_value = self._connection_provider.get(v.value) except Exception as e: # Cache all exception as different provider raises different exceptions # Raise new error with node details - raise GetConnectionError(f"Connection {v.value} for node {node_name!r} input {k!r} error: {str(e)}.") from e + raise GetConnectionError(v.value, node_name, e) from e custom_defined_connection_class_name = conn_types[0] source_type = getattr(source, "type", None) @@ -486,7 +486,7 @@ def _get_llm_node_connection(self, node: Node): connection = self._connection_provider.get(node.connection) except Exception as e: # Cache all exception as different provider raises different exceptions # Raise new error with node details - raise GetConnectionError(f"Connection {node.connection} for node {node.name!r} error: {str(e)}.") from e + raise GetConnectionError(node.connection, node.name, e) from e return connection @staticmethod diff --git a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py index 59fd93a8500..481856bad31 100644 --- a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py +++ b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py @@ -191,7 +191,10 @@ def test_executor_node_overrides(self, dev_connections): raise_ex=True, ) assert isinstance(e.value.inner_exception, GetConnectionError) - assert "Connection 'dummy_connection' of LLM node 'classify_with_llm' is not found." in str(e.value) + assert ( + "Get connection 'dummy_connection' for node 'classify_with_llm' " + "error: Connection 'dummy_connection' not found" in str(e.value) + ) @pytest.mark.parametrize( "flow_folder", From f70ddabf63347d25e1ce7b8d52516815208e85b8 Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Thu, 25 Apr 2024 19:22:04 +0800 Subject: [PATCH 4/7] Add comments Signed-off-by: Brynn Yin --- src/promptflow-core/promptflow/core/_serving/flow_invoker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/promptflow-core/promptflow/core/_serving/flow_invoker.py b/src/promptflow-core/promptflow/core/_serving/flow_invoker.py index f7221ea026b..ce4b7dd6d0b 100644 --- a/src/promptflow-core/promptflow/core/_serving/flow_invoker.py +++ b/src/promptflow-core/promptflow/core/_serving/flow_invoker.py @@ -117,8 +117,10 @@ def _init_connections(self, connection_provider): self.logger.debug(f"Flow invoker connections name overrides: {self.connections_name_overrides.keys()}") self.logger.debug(f"Ignoring connections: {connections_to_ignore}") if not connection_provider: + # If user not pass in connection provider string, get from environment variable. connection_provider = ConnectionProvider.get_instance(credential=self._credential) else: + # Else, init from the string to parse the provider config. connection_provider = ConnectionProvider.init_from_provider_config( connection_provider, credential=self._credential ) From 63a859c96ebaef179d785f4fe670df4efbde147b Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 26 Apr 2024 11:51:05 +0800 Subject: [PATCH 5/7] Fix test Signed-off-by: Brynn Yin --- .../executor/unittests/batch/test_base_executor_proxy.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py index 24a2b6f2195..d9e0e6f8189 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py +++ b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py @@ -88,8 +88,9 @@ async def test_ensure_executor_startup_when_not_healthy(self): async def test_ensure_executor_startup_when_existing_validation_error(self): # prepare the error file error_file = Path(mkdtemp()) / "error.json" - error_message = "Connection 'aoai_conn' not found" - error_dict = ExceptionPresenter.create(GetConnectionError(message=error_message)).to_dict() + error_dict = ExceptionPresenter.create( + GetConnectionError(connection="aoai_conn", node_name="mock", error=Exception("mock")) + ).to_dict() with open(error_file, "w") as file: json.dump(error_dict, file, indent=4) @@ -98,7 +99,7 @@ async def test_ensure_executor_startup_when_existing_validation_error(self): mock.side_effect = ExecutorServiceUnhealthy("executor unhealthy") with pytest.raises(ValidationException) as ex: await mock_executor_proxy.ensure_executor_startup(error_file) - assert ex.value.message == error_message + assert "Get connection '{aoai_conn}' for node '{mock}' error" in ex.value.message assert ex.value.target == ErrorTarget.BATCH @pytest.mark.asyncio From 20bee0a7617875220af7f356761714c144929a5c Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 26 Apr 2024 12:09:29 +0800 Subject: [PATCH 6/7] Update Signed-off-by: Brynn Yin --- .../tests/executor/unittests/batch/test_base_executor_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py index d9e0e6f8189..14b25444feb 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py +++ b/src/promptflow/tests/executor/unittests/batch/test_base_executor_proxy.py @@ -99,7 +99,7 @@ async def test_ensure_executor_startup_when_existing_validation_error(self): mock.side_effect = ExecutorServiceUnhealthy("executor unhealthy") with pytest.raises(ValidationException) as ex: await mock_executor_proxy.ensure_executor_startup(error_file) - assert "Get connection '{aoai_conn}' for node '{mock}' error" in ex.value.message + assert "Get connection 'aoai_conn' for node 'mock' error: mock" in ex.value.message assert ex.value.target == ErrorTarget.BATCH @pytest.mark.asyncio From e856bcb85bf884e036e0409793b9c0343be642c2 Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 26 Apr 2024 13:52:49 +0800 Subject: [PATCH 7/7] Fix test Signed-off-by: Brynn Yin --- .../tests/executor/e2etests/test_csharp_executor_proxy.py | 5 ++--- .../tests/executor/unittests/batch/test_batch_engine.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py b/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py index b1b6d03d5d9..40a61f16ab3 100644 --- a/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py +++ b/src/promptflow/tests/executor/e2etests/test_csharp_executor_proxy.py @@ -45,8 +45,7 @@ def test_batch_execution_error(self): def test_batch_validation_error(self): # prepare the init error file to mock the validation error - error_message = "'test_connection' not found." - test_exception = GetConnectionError(message=error_message) + test_exception = GetConnectionError(connection="test_connection", node_name="mock", error=Exception("mock")) error_dict = ExceptionPresenter.create(test_exception).to_dict() init_error_file = Path(mkdtemp()) / "init_error.json" with open(init_error_file, "w") as file: @@ -54,7 +53,7 @@ def test_batch_validation_error(self): # submit a batch run with pytest.raises(ValidationException) as e: self._submit_batch_run(init_error_file=init_error_file) - assert error_message in e.value.message + assert "Get connection 'test_connection' for node 'mock' error: mock" in e.value.message assert e.value.error_codes == ["UserError", "ValidationError"] assert e.value.target == ErrorTarget.BATCH diff --git a/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py b/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py index 097dfa92089..dbcc57c1d29 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py +++ b/src/promptflow/tests/executor/unittests/batch/test_batch_engine.py @@ -32,11 +32,11 @@ class TestBatchEngine: "Unexpected error occurred while executing the batch run. Error: (Exception) test error.", ), ( - GetConnectionError(message="Connection 'aoai_conn' not found"), + GetConnectionError(connection="aoai_conn", node_name="mock", error=Exception("mock")), GetConnectionError, ErrorTarget.EXECUTOR, ["UserError", "ValidationError", "InvalidRequest", "GetConnectionError"], - "Connection 'aoai_conn' not found", + "Get connection 'aoai_conn' for node 'mock' error: mock", ), ], )