From 0bc314c323462281685a96295b0a23bd2ceda143 Mon Sep 17 00:00:00 2001 From: Lina Tang Date: Tue, 30 Apr 2024 14:44:17 +0800 Subject: [PATCH] [Executor] Support passing connection provider to script executor (#3004) # Description Currently we use `ConnectionProvider.get_instance` to get connections in script executor, and it doesn't support passing connections to it directly, while the flow executor can accept connections, which is inconsistent. For flow serving, we want the behaviors to be consistent to refine the connection resolving logic, so in this PR, we support directly passing connections to script executor. # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [x] Pull request includes test coverage for the included changes. 
Co-authored-by: Lina Tang --- .../promptflow/executor/_script_executor.py | 12 +++++---- .../promptflow/executor/flow_executor.py | 8 ++++-- src/promptflow-core/tests/conftest.py | 6 +++++ .../tests/core/e2etests/test_eager_flow.py | 25 ++++++++++++++++++- .../flow.flex.yaml | 1 + .../simple_callable_with_connection.py | 10 ++++++++ 6 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml create mode 100644 src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index ac8a2d639cc..36011040829 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -25,6 +25,7 @@ from promptflow.contracts.flow import FlexFlow, Flow from promptflow.contracts.tool import ConnectionType from promptflow.core import log_metric +from promptflow.core._connection_provider._dict_connection_provider import DictConnectionProvider from promptflow.core._model_configuration import ( MODEL_CONFIG_NAME_2_CLASS, AzureOpenAIModelConfiguration, @@ -48,7 +49,7 @@ class ScriptExecutor(FlowExecutor): def __init__( self, flow_file: Union[Path, str, Callable], - connections: Optional[dict] = None, + connections: Optional[Union[dict, ConnectionProvider]] = None, working_dir: Optional[Path] = None, *, storage: Optional[AbstractRunStorage] = None, @@ -58,6 +59,9 @@ def __init__( logger.debug(f"Init params for script executor: {init_kwargs}") self._flow_file = flow_file + if connections and isinstance(connections, dict): + connections = DictConnectionProvider(connections) + self._connections = connections entry = flow_file # Entry could be both a path or a callable self._entry = entry self._init_kwargs = init_kwargs or {} @@ 
-67,7 +71,6 @@ def __init__( self._working_dir = working_dir or Path.cwd() self._init_input_sign() self._initialize_function() - self._connections = connections self._storage = storage or DefaultRunStorage() self._flow_id = "default_flow_id" self._log_interval = 60 @@ -353,9 +356,8 @@ def _resolve_init_kwargs(self, c: type, init_kwargs: dict): return resolved_init_kwargs - @classmethod - def _resolve_connection_params(cls, connection_params: list, init_kwargs: dict, resolved_init_kwargs: dict): - provider = ConnectionProvider.get_instance() + def _resolve_connection_params(self, connection_params: list, init_kwargs: dict, resolved_init_kwargs: dict): + provider = self._connections or ConnectionProvider.get_instance() # parse connection logger.debug(f"Resolving connection params: {connection_params}") for key in connection_params: diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index fbb71cf7094..a5a1374b081 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -205,14 +205,18 @@ def create( if hasattr(flow_file, "__call__") or inspect.isfunction(flow_file): from ._script_executor import ScriptExecutor - return ScriptExecutor(flow_file, storage=storage) + return ScriptExecutor(flow_file, connections=connections, storage=storage) if not isinstance(flow_file, (Path, str)): raise NotImplementedError("Only support Path or str for flow_file.") if is_flex_flow(flow_path=flow_file, working_dir=working_dir): from ._script_executor import ScriptExecutor return ScriptExecutor( - flow_file=Path(flow_file), working_dir=working_dir, storage=storage, init_kwargs=init_kwargs + flow_file=Path(flow_file), + connections=connections, + working_dir=working_dir, + storage=storage, + init_kwargs=init_kwargs, ) elif is_prompty_flow(file_path=flow_file): from ._prompty_executor import PromptyExecutor diff --git 
a/src/promptflow-core/tests/conftest.py b/src/promptflow-core/tests/conftest.py index e692b29f45b..12dc15a8210 100644 --- a/src/promptflow-core/tests/conftest.py +++ b/src/promptflow-core/tests/conftest.py @@ -156,6 +156,12 @@ def setup_connection_provider(): yield +@pytest.fixture +def dev_connections() -> dict: + with open(CONNECTION_FILE, "r") as f: + return json.load(f) + + # ==================== serving fixtures ==================== diff --git a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py index 79a4b9169f6..ca1a83e7a1b 100644 --- a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py +++ b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py @@ -6,6 +6,7 @@ from promptflow._core.tool_meta_generator import PythonLoadError from promptflow.contracts.run_info import Status from promptflow.core import AzureOpenAIModelConfiguration, OpenAIModelConfiguration +from promptflow.core._connection_provider._dict_connection_provider import DictConnectionProvider from promptflow.executor._errors import ( FlowEntryInitializationError, InputNotFound, @@ -44,7 +45,7 @@ async def func_entry_async(input_str: str) -> str: ] -@pytest.mark.usefixtures("recording_injection", "setup_connection_provider") +@pytest.mark.usefixtures("recording_injection", "setup_connection_provider", "dev_connections") @pytest.mark.e2etest class TestEagerFlow: @pytest.mark.parametrize( @@ -138,6 +139,28 @@ def test_flow_run_with_openai_chat(self): token_names = ["prompt_tokens", "completion_tokens", "total_tokens"] for token_name in token_names: assert token_name in line_result.run_info.api_calls[0]["children"][0]["system_metrics"] + assert line_result.run_info.api_calls[0]["children"][0]["system_metrics"][token_name] > 0 + + def test_flow_run_with_connection(self, dev_connections): + flow_file = get_yaml_file( + "dummy_callable_class_with_connection", root=EAGER_FLOW_ROOT, file_name="flow.flex.yaml" + ) + + # Test 
submitting eager flow to script executor with connection dictionary + executor = ScriptExecutor( + flow_file=flow_file, connections=dev_connections, init_kwargs={"connection": "azure_open_ai_connection"} + ) + line_result = executor.exec_line(inputs={}, index=0) + assert line_result.run_info.status == Status.Completed, line_result.run_info.error + + # Test submitting eager flow to script executor with connection provider + executor = ScriptExecutor( + flow_file=flow_file, + connections=DictConnectionProvider(dev_connections), + init_kwargs={"connection": "azure_open_ai_connection"}, + ) + line_result = executor.exec_line(inputs={}, index=0) + assert line_result.run_info.status == Status.Completed, line_result.run_info.error @pytest.mark.parametrize("entry, inputs, expected_output", function_entries) def test_flow_run_with_function_entry(self, entry, inputs, expected_output): diff --git a/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml new file mode 100644 index 00000000000..e977acc37eb --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml @@ -0,0 +1 @@ +entry: simple_callable_with_connection:MyFlow \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py new file mode 100644 index 00000000000..b5122ca5d21 --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py @@ -0,0 +1,10 @@ +from promptflow.connections import AzureOpenAIConnection + + +class MyFlow: + def __init__(self, connection: AzureOpenAIConnection): + self._connection = connection + + def 
__call__(self): + assert isinstance(self._connection, AzureOpenAIConnection) + return "Dummy output"