From 1169dafcca0431fec58257624b44f087034e10e0 Mon Sep 17 00:00:00 2001 From: Lina Tang Date: Thu, 25 Apr 2024 15:29:58 +0800 Subject: [PATCH] Support passing connection provider to script executor --- .../promptflow/executor/_script_executor.py | 12 +++++---- .../promptflow/executor/flow_executor.py | 8 ++++-- src/promptflow-core/tests/conftest.py | 6 +++++ .../tests/core/e2etests/test_eager_flow.py | 25 ++++++++++++++++++- .../flow.flex.yaml | 1 + .../simple_callable_with_connection.py | 10 ++++++++ 6 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml create mode 100644 src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py diff --git a/src/promptflow-core/promptflow/executor/_script_executor.py b/src/promptflow-core/promptflow/executor/_script_executor.py index fbf1aecce86..f2716f7df45 100644 --- a/src/promptflow-core/promptflow/executor/_script_executor.py +++ b/src/promptflow-core/promptflow/executor/_script_executor.py @@ -25,6 +25,7 @@ from promptflow.contracts.flow import Flow from promptflow.contracts.tool import ConnectionType from promptflow.core import log_metric +from promptflow.core._connection_provider._dict_connection_provider import DictConnectionProvider from promptflow.core._model_configuration import ( MODEL_CONFIG_NAME_2_CLASS, AzureOpenAIModelConfiguration, @@ -47,7 +48,7 @@ class ScriptExecutor(FlowExecutor): def __init__( self, flow_file: Union[Path, str, Callable], - connections: Optional[dict] = None, + connections: Optional[Union[dict, ConnectionProvider]] = None, working_dir: Optional[Path] = None, *, storage: Optional[AbstractRunStorage] = None, @@ -57,6 +58,9 @@ def __init__( logger.debug(f"Init params for script executor: {init_kwargs}") self._flow_file = flow_file + if connections and isinstance(connections, dict): + connections = DictConnectionProvider(connections) + self._connections = connections entry = flow_file # Entry could be both a path or a callable self._entry = entry self._init_kwargs = init_kwargs or {} @@ -65,7 +69,6 @@ def __init__( else: self._working_dir = working_dir or Path.cwd() self._initialize_function() - self._connections = connections self._storage = storage or DefaultRunStorage() self._flow_id = "default_flow_id" self._log_interval = 60 @@ -308,9 +311,8 @@ def _resolve_init_kwargs(self, c: type, init_kwargs: dict): return resolved_init_kwargs - @classmethod - def _resolve_connection_params(cls, connection_params: list, init_kwargs: dict, resolved_init_kwargs: dict): - provider = ConnectionProvider.get_instance() + def _resolve_connection_params(self, connection_params: list, init_kwargs: dict, resolved_init_kwargs: dict): + provider = self._connections or ConnectionProvider.get_instance() # parse connection logger.debug(f"Resolving connection params: {connection_params}") for key in connection_params: diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index 1b84ca2a10a..88827e6e2bc 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -205,14 +205,18 @@ def create( if hasattr(flow_file, "__call__") or inspect.isfunction(flow_file): from ._script_executor import ScriptExecutor - return ScriptExecutor(flow_file, storage=storage) + return ScriptExecutor(flow_file, connections=connections, storage=storage) if not isinstance(flow_file, (Path, str)): raise NotImplementedError("Only support Path or str for flow_file.") if is_flex_flow(flow_path=flow_file, working_dir=working_dir): from ._script_executor import ScriptExecutor return ScriptExecutor( - flow_file=Path(flow_file), working_dir=working_dir, storage=storage, init_kwargs=init_kwargs + flow_file=Path(flow_file), + connections=connections, + working_dir=working_dir, + storage=storage, + init_kwargs=init_kwargs, ) elif is_prompty_flow(file_path=flow_file): from ._prompty_executor import PromptyExecutor diff --git a/src/promptflow-core/tests/conftest.py b/src/promptflow-core/tests/conftest.py index e692b29f45b..12dc15a8210 100644 --- a/src/promptflow-core/tests/conftest.py +++ b/src/promptflow-core/tests/conftest.py @@ -156,6 +156,12 @@ def setup_connection_provider(): yield +@pytest.fixture +def dev_connections() -> dict: + with open(CONNECTION_FILE, "r") as f: + return json.load(f) + + # ==================== serving fixtures ==================== diff --git a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py index 588a29a6169..a9d4e1a07c7 100644 --- a/src/promptflow-core/tests/core/e2etests/test_eager_flow.py +++ b/src/promptflow-core/tests/core/e2etests/test_eager_flow.py @@ -6,6 +6,7 @@ from promptflow._core.tool_meta_generator import PythonLoadError from promptflow.contracts.run_info import Status from promptflow.core import AzureOpenAIModelConfiguration, OpenAIModelConfiguration +from promptflow.core._connection_provider._dict_connection_provider import DictConnectionProvider from promptflow.executor._errors import FlowEntryInitializationError, InvalidFlexFlowEntry from promptflow.executor._result import LineResult from promptflow.executor._script_executor import ScriptExecutor @@ -39,7 +40,7 @@ async def func_entry_async(input_str: str) -> str: ] -@pytest.mark.usefixtures("recording_injection", "setup_connection_provider") +@pytest.mark.usefixtures("recording_injection", "setup_connection_provider", "dev_connections") @pytest.mark.e2etest class TestEagerFlow: @pytest.mark.parametrize( @@ -121,6 +122,28 @@ def test_flow_run_with_openai_chat(self): token_names = ["prompt_tokens", "completion_tokens", "total_tokens"] for token_name in token_names: assert token_name in line_result.run_info.api_calls[0]["children"][0]["system_metrics"] + assert line_result.run_info.api_calls[0]["children"][0]["system_metrics"][token_name] > 0 + + def test_flow_run_with_connection(self, dev_connections): + flow_file = get_yaml_file( + "dummy_callable_class_with_connection", root=EAGER_FLOW_ROOT, file_name="flow.flex.yaml" + ) + + # Test submitting eager flow to script executor with connection dictionary + executor = ScriptExecutor( + flow_file=flow_file, connections=dev_connections, init_kwargs={"connection": "azure_open_ai_connection"} + ) + line_result = executor.exec_line(inputs={}, index=0) + assert line_result.run_info.status == Status.Completed, line_result.run_info.error + + # Test submitting eager flow to script executor with connection provider + executor = ScriptExecutor( + flow_file=flow_file, + connections=DictConnectionProvider(dev_connections), + init_kwargs={"connection": "azure_open_ai_connection"}, + ) + line_result = executor.exec_line(inputs={}, index=0) + assert line_result.run_info.status == Status.Completed, line_result.run_info.error @pytest.mark.parametrize("entry, inputs, expected_output", function_entries) def test_flow_run_with_function_entry(self, entry, inputs, expected_output): diff --git a/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml new file mode 100644 index 00000000000..e977acc37eb --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/flow.flex.yaml @@ -0,0 +1 @@ +entry: simple_callable_with_connection:MyFlow \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py new file mode 100644 index 00000000000..b5122ca5d21 --- /dev/null +++ b/src/promptflow/tests/test_configs/eager_flows/dummy_callable_class_with_connection/simple_callable_with_connection.py @@ -0,0 +1,10 @@ +from promptflow.connections import AzureOpenAIConnection + + +class MyFlow: + def __init__(self, connection: AzureOpenAIConnection): + self._connection = connection + + def __call__(self): + assert isinstance(self._connection, AzureOpenAIConnection) + return "Dummy output"