diff --git a/src/promptflow/promptflow/_cli/_pf/_flow.py b/src/promptflow/promptflow/_cli/_pf/_flow.py index 43bccc8bb86..bd6c75373de 100644 --- a/src/promptflow/promptflow/_cli/_pf/_flow.py +++ b/src/promptflow/promptflow/_cli/_pf/_flow.py @@ -6,6 +6,7 @@ import importlib import json import os +import shutil import subprocess import sys import tempfile @@ -40,6 +41,7 @@ from promptflow._constants import LANGUAGE_KEY, FlowLanguage from promptflow._sdk._constants import PROMPT_FLOW_DIR_NAME, ConnectionProvider from promptflow._sdk._pf_client import PFClient +from promptflow._sdk.operations._flow_operations import FlowOperations from promptflow._utils.logger_utils import get_cli_sdk_logger DEFAULT_CONNECTION = "open_ai_connection" @@ -467,6 +469,23 @@ def serve_flow_csharp(args, source): pass +def _resolve_python_flow_additional_includes(source) -> Path: + # Resolve flow additional includes + from promptflow import load_flow + + flow = load_flow(source) + with FlowOperations._resolve_additional_includes(flow.path) as resolved_flow_path: + if resolved_flow_path == flow.path: + return source + # Copy resolved flow to temp folder if additional includes exists + # Note: DO NOT use resolved flow path directly, as when inner logic raise exception, + # temp dir will fail due to file occupied by other process. + temp_flow_path = Path(tempfile.TemporaryDirectory().name) + shutil.copytree(src=resolved_flow_path.parent, dst=temp_flow_path, dirs_exist_ok=True) + + return temp_flow_path + + def serve_flow_python(args, source): from promptflow._sdk._serving.app import create_app @@ -474,7 +493,8 @@ def serve_flow_python(args, source): if static_folder: static_folder = Path(static_folder).absolute().as_posix() config = list_of_dict_to_dict(args.config) - # Change working directory to model dir + source = _resolve_python_flow_additional_includes(source) + os.environ["PROMPTFLOW_PROJECT_PATH"] = source.absolute().as_posix() logger.info(f"Change working directory to model dir {source}") os.chdir(source) app = create_app( diff --git a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py index 17acff8cf63..358e0422192 100644 --- a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py +++ b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py @@ -611,12 +611,18 @@ def build( env_var_names=env_var_names, ) + @classmethod @contextlib.contextmanager def _resolve_additional_includes(cls, flow_dag_path: Path) -> Iterable[Path]: # TODO: confirm if we need to import this from promptflow._sdk._submitter import remove_additional_includes - if _get_additional_includes(flow_dag_path): + # Eager flow may not contain a yaml file, skip resolving additional includes + def is_yaml_file(file_path): + _, file_extension = os.path.splitext(file_path) + return file_extension.lower() in (".yaml", ".yml") + + if is_yaml_file(flow_dag_path) and _get_additional_includes(flow_dag_path): # Merge the flow folder and additional includes to temp folder. # TODO: support a flow_dag_path with a name different from flow.dag.yaml with _merge_local_code_and_additional_includes(code_path=flow_dag_path.parent) as temp_dir: diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_flow_serve.py b/src/promptflow/tests/sdk_cli_test/unittests/test_flow_serve.py new file mode 100644 index 00000000000..e7b8ad5f479 --- /dev/null +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_flow_serve.py @@ -0,0 +1,22 @@ +from pathlib import Path + +import pytest +from sdk_cli_test.conftest import MODEL_ROOT + +from promptflow._cli._pf._flow import _resolve_python_flow_additional_includes + + +@pytest.mark.unittest +def test_flow_serve_resolve_additional_includes(): + # Assert flow path not changed if no additional includes + flow_path = (Path(MODEL_ROOT) / "web_classification").resolve().absolute().as_posix() + resolved_flow_path = _resolve_python_flow_additional_includes(flow_path) + assert flow_path == resolved_flow_path + + # Assert additional includes are resolved correctly + flow_path = (Path(MODEL_ROOT) / "web_classification_with_additional_include").resolve().absolute().as_posix() + resolved_flow_path = _resolve_python_flow_additional_includes(flow_path) + + assert (Path(resolved_flow_path) / "convert_to_dict.py").exists() + assert (Path(resolved_flow_path) / "fetch_text_content_from_url.py").exists() + assert (Path(resolved_flow_path) / "summarize_text_content.jinja2").exists()