Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Connection] Support flexflow with connection config #3012

Merged
merged 10 commits into from
Apr 26, 2024
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os

import mock
import pytest

Expand All @@ -20,6 +22,8 @@
class TestPFClient:
# Test pf client when connection provider is azureml.
# This tests suites need azure dependencies.
# Mock os.environ to avoid this test affecting other tests
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.skipif(condition=not pytest.is_live, reason="This test requires an actual PFClient")
def test_connection_provider(self, subscription_id: str, resource_group_name: str, workspace_name: str):
target = "promptflow._sdk._pf_client.Configuration"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class ConnectionProvider(ABC):
"""The connection provider interface to list/get connections in the current environment."""

PROVIDER_CONFIG_KEY = "CONNECTION_PROVIDER_CONFIG"
PROVIDER_CONFIG_KEY = "PF_CONNECTION_PROVIDER"
_instance = None

@abstractmethod
Expand All @@ -28,12 +28,12 @@ def list(self, **kwargs) -> List[_Connection]:
raise NotImplementedError("Method 'list' is not implemented.")

@classmethod
def get_instance(cls) -> "ConnectionProvider":
def get_instance(cls, **kwargs) -> "ConnectionProvider":
"""Get the connection provider instance in the current environment.
It will return different implementations based on the current environment.
"""
if not cls._instance:
cls._instance = cls._init_from_env()
cls._instance = cls._init_from_env(**kwargs)
return cls._instance

@classmethod
Expand Down Expand Up @@ -63,12 +63,12 @@ def init_from_provider_config(cls, provider_config: str, credential=None):
)

@classmethod
def _init_from_env(cls) -> "ConnectionProvider":
def _init_from_env(cls, **kwargs) -> "ConnectionProvider":
"""Initialize the connection provider from environment variables."""
from ._http_connection_provider import HttpConnectionProvider

endpoint = os.getenv(HttpConnectionProvider.ENDPOINT_KEY)
if endpoint:
return HttpConnectionProvider(endpoint)
provider_config = os.getenv(cls.PROVIDER_CONFIG_KEY, "")
return ConnectionProvider.init_from_provider_config(provider_config)
return ConnectionProvider.init_from_provider_config(provider_config, credential=kwargs.get("credential"))
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from promptflow.contracts.tool import ConnectionType
from promptflow.contracts.types import Secret

from .._errors import ConnectionNotFound
from ._connection_provider import ConnectionProvider


Expand Down Expand Up @@ -98,8 +99,14 @@ def list(self):
return [c for c in self._connections.values()]

def get(self, name: str) -> Any:
if isinstance(name, str):
return self._connections.get(name)
elif ConnectionType.is_connection_value(name):
if ConnectionType.is_connection_value(name):
return name
return None
connection = None
if isinstance(name, str):
connection = self._connections.get(name)
if not connection:
raise ConnectionNotFound(
f"Connection {name!r} not found in dict connection provider. "
f"Available keys are {list(self._connections.keys())}."
)
return connection
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def is_github_codespaces():
return os.environ.get("CODESPACES", None) == "true"


def interactive_credential_disabled():
"""Check if interactive login is disabled."""
return os.environ.get(PF_NO_INTERACTIVE_LOGIN, "false").lower() == "true"
def interactive_credential_enabled():
"""Check if interactive login is enabled."""
return os.environ.get(PF_NO_INTERACTIVE_LOGIN, "true").lower() == "false"


def is_from_cli():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ..._utils.credential_utils import get_default_azure_credential
from ._connection_provider import ConnectionProvider
from ._utils import interactive_credential_disabled, is_from_cli, is_github_codespaces
from ._utils import interactive_credential_enabled, is_from_cli, is_github_codespaces

GET_CONNECTION_URL = (
"/subscriptions/{sub}/resourcegroups/{rg}/providers/Microsoft.MachineLearningServices"
Expand Down Expand Up @@ -111,14 +111,14 @@ def _get_credential(cls):
get_arm_token(credential=credential)
except Exception:
raise AccountNotSetUp()
if interactive_credential_disabled():
return DefaultAzureCredential(exclude_interactive_browser_credential=True)
if interactive_credential_enabled():
return DefaultAzureCredential(exclude_interactive_browser_credential=False)
if is_github_codespaces():
# For code spaces, append device code credential as the fallback option.
credential = DefaultAzureCredential()
credential.credentials = (*credential.credentials, DeviceCodeCredential())
return credential
return DefaultAzureCredential(exclude_interactive_browser_credential=False)
return DefaultAzureCredential(exclude_interactive_browser_credential=True)

@classmethod
def open_url(cls, token, url, action, host="management.azure.com", method="GET", model=None) -> Union[Any, dict]:
Expand Down
4 changes: 4 additions & 0 deletions src/promptflow-core/promptflow/core/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class InvalidSampleError(CoreError):
pass


class ConnectionNotFound(CoreError):
pass


class OpenURLUserAuthenticationError(UserAuthenticationError):
def __init__(self, **kwargs):
super().__init__(target=ErrorTarget.CORE, **kwargs)
Expand Down
10 changes: 9 additions & 1 deletion src/promptflow-core/promptflow/core/_serving/flow_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,22 @@ def _init_connections(self, connection_provider):
connections_to_ignore.extend(self.connections_name_overrides.keys())
self.logger.debug(f"Flow invoker connections name overrides: {self.connections_name_overrides.keys()}")
self.logger.debug(f"Ignoring connections: {connections_to_ignore}")
if not connection_provider:
# If user not pass in connection provider string, get from environment variable.
connection_provider = ConnectionProvider.get_instance(credential=self._credential)
else:
# Else, init from the string to parse the provider config.
connection_provider = ConnectionProvider.init_from_provider_config(
brynn-code marked this conversation as resolved.
Show resolved Hide resolved
connection_provider, credential=self._credential
)
# Note: The connection here could be local or workspace, depends on the connection.provider in pf.yaml.
connections = self.resolve_connections(
# use os.environ to override flow definition's connection since
# os.environ is resolved to user's setting now
connection_names=self.flow.get_connection_names(
environment_variables_overrides=os.environ,
),
provider=ConnectionProvider.init_from_provider_config(connection_provider, credential=self._credential),
provider=connection_provider,
connections_to_ignore=connections_to_ignore,
# fetch connections with name override
connections_to_add=list(self.connections_name_overrides.values()),
Expand Down
17 changes: 15 additions & 2 deletions src/promptflow-core/promptflow/executor/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,21 @@ def __init__(
)


class ConnectionNotFound(InvalidRequest):
pass
class GetConnectionError(InvalidRequest):
def __init__(
self,
connection: str,
node_name: str,
error: Exception,
**kwargs,
):
super().__init__(
message_format="Get connection '{connection}' for node '{node_name}' error: {error}",
connection=connection,
node_name=node_name,
error=str(error),
target=ErrorTarget.EXECUTOR,
)


class InvalidBulkTestRequest(ValidationException):
Expand Down
33 changes: 17 additions & 16 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from promptflow._constants import MessageFormatType
from promptflow._core._errors import InvalidSource
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR, INPUTS_TO_ESCAPE_PARAM_KEY, TOOL_TYPE_TO_ESCAPE
from promptflow._core.tool import INPUTS_TO_ESCAPE_PARAM_KEY, STREAMING_OPTION_PARAMETER_ATTR, TOOL_TYPE_TO_ESCAPE
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._utils.multimedia_utils import MultimediaProcessor
from promptflow._utils.tool_utils import (
Expand All @@ -36,9 +36,9 @@
)
from promptflow.executor._docstring_parser import DocstringParser
from promptflow.executor._errors import (
ConnectionNotFound,
EmptyLLMApiMapping,
FailedToGenerateToolDefinition,
GetConnectionError,
InvalidAssistantTool,
InvalidConnectionType,
InvalidCustomLLMTool,
Expand Down Expand Up @@ -83,9 +83,11 @@ def start_resolver(
return resolver

def _convert_to_connection_value(self, k: str, v: InputAssignment, node_name: str, conn_types: List[ValueType]):
connection_value = self._connection_provider.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.")
try:
connection_value = self._connection_provider.get(v.value)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(v.value, node_name, e) from e
# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
msg = (
Expand All @@ -108,9 +110,11 @@ def _convert_to_custom_strong_type_connection_value(
if not conn_types:
msg = f"Input '{k}' for node '{node_name}' has invalid types: {conn_types}."
raise NodeInputValidationError(message=msg)
connection_value = self._connection_provider.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.")
try:
connection_value = self._connection_provider.get(v.value)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(v.value, node_name, e) from e

custom_defined_connection_class_name = conn_types[0]
source_type = getattr(source, "type", None)
Expand Down Expand Up @@ -478,14 +482,11 @@ def _remove_init_args(node_inputs: dict, init_args: dict):
del node_inputs[k]

def _get_llm_node_connection(self, node: Node):
connection = self._connection_provider.get(node.connection)
if connection is None:
raise ConnectionNotFound(
message_format="Connection '{connection}' of LLM node '{node_name}' is not found.",
connection=node.connection,
node_name=node.name,
target=ErrorTarget.EXECUTOR,
)
try:
connection = self._connection_provider.get(node.connection)
except Exception as e: # Cache all exception as different provider raises different exceptions
# Raise new error with node details
raise GetConnectionError(node.connection, node.name, e) from e
return connection

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions src/promptflow-devkit/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# promptflow-devkit package

## v1.11.0 (Upcoming)

### Improvements
- Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it.

## v1.10.0 (Upcoming)

### Features Added
Expand Down
6 changes: 5 additions & 1 deletion src/promptflow-devkit/promptflow/_sdk/_pf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,11 @@ def _ensure_connection_provider(self) -> str:
if not self._connection_provider:
# Get a copy with config override instead of the config instance
self._connection_provider = Configuration(overrides=self._config).get_connection_provider()
logger.debug("PFClient connection provider: %s", self._connection_provider)
logger.debug("PFClient connection provider: %s, setting to env.", self._connection_provider)
from promptflow.core._connection_provider._connection_provider import ConnectionProvider

# Set to os.environ for connection provider to use
os.environ[ConnectionProvider.PROVIDER_CONFIG_KEY] = self._connection_provider
return self._connection_provider

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from promptflow._utils.credential_utils import get_default_azure_credential
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.core._connection_provider._utils import (
interactive_credential_disabled,
interactive_credential_enabled,
is_from_cli,
is_github_codespaces,
)
Expand Down Expand Up @@ -73,14 +73,14 @@ def _get_credential(cls):
"See https://docs.microsoft.com/cli/azure/authenticate-azure-cli for more details."
)
sys.exit(1)
if interactive_credential_disabled():
return DefaultAzureCredential(exclude_interactive_browser_credential=True)
if interactive_credential_enabled():
return DefaultAzureCredential(exclude_interactive_browser_credential=False)
if is_github_codespaces():
# For code spaces, append device code credential as the fallback option.
credential = DefaultAzureCredential()
credential = DefaultAzureCredential(exclude_interactive_browser_credential=True)
credential.credentials = (*credential.credentials, DeviceCodeCredential())
return credential
return DefaultAzureCredential(exclude_interactive_browser_credential=False)
return DefaultAzureCredential(exclude_interactive_browser_credential=True)

@monitor_operation(activity_name="pf.connections.azure.list", activity_type=ActivityType.PUBLICAPI)
def list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from promptflow._sdk._load_functions import load_flow
from promptflow._sdk.entities._flows._flow_context_resolver import FlowContextResolver
from promptflow.contracts.run_info import Status
from promptflow.core import Prompty
from promptflow.core._connection_provider._workspace_connection_provider import WorkspaceConnectionProvider
from promptflow.executor._script_executor import ScriptExecutor

TEST_CONFIG_DIR = PROMPTFLOW_ROOT / "tests" / "test_configs"
FLOWS_DIR = TEST_CONFIG_DIR / "flows"
DATAS_DIR = TEST_CONFIG_DIR / "datas"
PROMPTY_DIR = TEST_CONFIG_DIR / "prompty"
EAGER_FLOW_ROOT = TEST_CONFIG_DIR / "eager_flows"


@pytest.mark.usefixtures("global_config")
Expand Down Expand Up @@ -54,3 +57,14 @@ def test_prompty_callable(self, pf):
prompty = Prompty.load(source=f"{PROMPTY_DIR}/prompty_example.prompty")
result = prompty(question="what is the result of 1+1?")
assert "2" in result

def test_flex_flow_run_with_openai_chat(self, pf):
# Test flex flow run successfully with global config ws connection
flow_file = EAGER_FLOW_ROOT / "callable_class_with_openai" / "flow.flex.yaml"
pf._ensure_connection_provider()
executor = ScriptExecutor(flow_file=flow_file, init_kwargs={"connection": "azure_open_ai_connection"})
line_result = executor.exec_line(inputs={"question": "Hello", "stream": False}, index=0)
assert line_result.run_info.status == Status.Completed, line_result.run_info.error
token_names = ["prompt_tokens", "completion_tokens", "total_tokens"]
for token_name in token_names:
assert token_name in line_result.run_info.api_calls[0]["children"][0]["system_metrics"]
5 changes: 5 additions & 0 deletions src/promptflow/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Release History

## v1.11.0 (Upcoming)

### Improvements
- [promptflow-devkit]: Interactive browser credential is excluded by default when using Azure AI connections, user could set `PF_NO_INTERACTIVE_LOGIN=False` to enable it.

## v1.10.0 (Upcoming)
### Features Added
- [promptflow-devkit]: Expose --ui to trigger a chat window, reach [here](https://microsoft.github.io/promptflow/reference/pf-command-reference.html#pf-flow-test) for more details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from promptflow.batch._result import BatchResult
from promptflow.contracts.run_info import Status
from promptflow.exceptions import ErrorTarget, ValidationException
from promptflow.executor._errors import ConnectionNotFound
from promptflow.executor._errors import GetConnectionError
from promptflow.storage._run_storage import AbstractRunStorage

from ..mock_execution_server import run_executor_server
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_batch_execution_error(self):
def test_batch_validation_error(self):
# prepare the init error file to mock the validation error
error_message = "'test_connection' not found."
test_exception = ConnectionNotFound(message=error_message)
test_exception = GetConnectionError(message=error_message)
error_dict = ExceptionPresenter.create(test_exception).to_dict()
init_error_file = Path(mkdtemp()) / "init_error.json"
with open(init_error_file, "w") as file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from promptflow.contracts.run_info import Status
from promptflow.exceptions import UserErrorException
from promptflow.executor import FlowExecutor
from promptflow.executor._errors import ConnectionNotFound, InputTypeError, ResolveToolError
from promptflow.executor._errors import GetConnectionError, InputTypeError, ResolveToolError
from promptflow.executor.flow_executor import execute_flow
from promptflow.storage._run_storage import DefaultRunStorage

Expand Down Expand Up @@ -190,8 +190,11 @@ def test_executor_node_overrides(self, dev_connections):
node_override={"classify_with_llm.connection": "dummy_connection"},
raise_ex=True,
)
assert isinstance(e.value.inner_exception, ConnectionNotFound)
assert "Connection 'dummy_connection' of LLM node 'classify_with_llm' is not found." in str(e.value)
assert isinstance(e.value.inner_exception, GetConnectionError)
assert (
"Get connection 'dummy_connection' for node 'classify_with_llm' "
"error: Connection 'dummy_connection' not found" in str(e.value)
)

@pytest.mark.parametrize(
"flow_folder",
Expand Down
Loading
Loading