Skip to content

Commit

Permalink
[Internal] Refine connection usage (#2802)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.
Related: #2723

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <[email protected]>
  • Loading branch information
brynn-code authored Apr 15, 2024
1 parent bd8e41c commit 9f51888
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
_ScopeDependentOperations,
)

from promptflow._sdk._errors import ConnectionClassNotFoundError
from promptflow._sdk.entities._connection import CustomConnection, _Connection
from promptflow._sdk.entities._connection import _Connection
from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller
from promptflow.core._connection_provider._workspace_connection_provider import WorkspaceConnectionProvider
from promptflow.core._errors import OpenURLFailedUserError
Expand Down Expand Up @@ -45,29 +44,8 @@ def __init__(
self._credential,
)

@classmethod
def _convert_core_connection_to_sdk_connection(cls, core_conn):
# TODO: Refine this and connection operation ones to (devkit) _Connection._from_core_object
sdk_conn_mapping = _Connection.SUPPORTED_TYPES
sdk_conn_cls = sdk_conn_mapping.get(core_conn.type)
if sdk_conn_cls is None:
raise ConnectionClassNotFoundError(
f"Correspond sdk connection type not found for core connection type: {core_conn.type!r}, "
f"please re-install the 'promptflow' package."
)
common_args = {
"name": core_conn.name,
"module": core_conn.module,
"expiry_time": core_conn.expiry_time,
"created_date": core_conn.created_date,
"last_modified_date": core_conn.last_modified_date,
}
if sdk_conn_cls is CustomConnection:
return sdk_conn_cls(configs=core_conn.configs, secrets=core_conn.secrets, **common_args)
return sdk_conn_cls(**dict(core_conn), **common_args)

def get(self, name, **kwargs):
return self._convert_core_connection_to_sdk_connection(self._provider.get(name))
return _Connection._from_core_connection(self._provider.get(name))

@classmethod
def _direct_get(cls, name, subscription_id, resource_group_name, workspace_name, credential):
Expand All @@ -76,7 +54,7 @@ def _direct_get(cls, name, subscription_id, resource_group_name, workspace_name,
permission(workspace/list secrets). As create azure pf_client requires workspace read permission.
"""
provider = WorkspaceConnectionProvider(subscription_id, resource_group_name, workspace_name, credential)
return provider.get(name=name)
return _Connection._from_core_connection(provider.get(name=name))

# Keep this as promptflow tools is using this method
_build_connection_dict = WorkspaceConnectionProvider._build_connection_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@ def test_build_azure_openai_connection_from_rest_object(self):
}
build_from_data_and_assert(data, expected)

def test_build_legacy_openai_connection_from_rest_object(self):
# Legacy OpenAI connection with type in metadata
# Test this not convert to CustomConnection
data = {
"id": "mock_id",
"name": "legacy_open_ai",
"type": "Microsoft.MachineLearningServices/workspaces/connections",
"properties": {
"authType": "CustomKeys",
"credentials": {"keys": {"api_key": "***"}},
"category": "CustomKeys",
"target": "<api-base>",
"metadata": {
"azureml.flow.connection_type": "OpenAI",
"azureml.flow.module": "promptflow.connections",
"organization": "mock",
},
},
}
expected = {
"type": "OpenAIConnection",
"module": "promptflow.connections",
"name": "legacy_open_ai",
"value": {"api_key": "***", "organization": "mock"},
}
build_from_data_and_assert(data, expected)

def test_build_strong_type_openai_connection_from_rest_object(self):
data = {
"id": "mock_id",
Expand Down
25 changes: 24 additions & 1 deletion src/promptflow-devkit/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SCRUBBED_VALUE_USER_INPUT,
ConfigValueType,
)
from promptflow._sdk._errors import SDKError, UnsecureConnectionError
from promptflow._sdk._errors import ConnectionClassNotFoundError, SDKError, UnsecureConnectionError
from promptflow._sdk._orm.connection import Connection as ORMConnection
from promptflow._sdk._utils import (
decrypt_secret_value,
Expand Down Expand Up @@ -143,6 +143,29 @@ def _from_mt_rest_object(cls, mt_rest_obj) -> "_Connection":
obj = type_cls._from_mt_rest_object(mt_rest_obj)
return obj

@classmethod
def _from_core_connection(cls, core_conn) -> "_Connection":
if isinstance(core_conn, _Connection):
# Already a sdk connection, return.
return core_conn
sdk_conn_mapping = _Connection.SUPPORTED_TYPES
sdk_conn_cls = sdk_conn_mapping.get(core_conn.type)
if sdk_conn_cls is None:
raise ConnectionClassNotFoundError(
f"Correspond sdk connection type not found for core connection type: {core_conn.type!r}, "
f"please re-install the 'promptflow' package."
)
common_args = {
"name": core_conn.name,
"module": core_conn.module,
"expiry_time": core_conn.expiry_time,
"created_date": core_conn.created_date,
"last_modified_date": core_conn.last_modified_date,
}
if sdk_conn_cls is CustomConnection:
return sdk_conn_cls(configs=core_conn.configs, secrets=core_conn.secrets, **common_args)
return sdk_conn_cls(**dict(core_conn), **common_args)

@classmethod
def _from_orm_object_with_secrets(cls, orm_object: ORMConnection):
# !!! Attention !!!: Do not use this function to user facing api, use _from_orm_object to remove secrets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import List, Type, TypeVar

from promptflow._sdk._constants import MAX_LIST_CLI_RESULTS
from promptflow._sdk._errors import ConnectionClassNotFoundError, ConnectionNameNotSetError
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._orm import Connection as ORMConnection
from promptflow._sdk._telemetry import ActivityType, TelemetryMixin, monitor_operation
from promptflow._sdk._utils import safe_parse_object_list
from promptflow._sdk.entities._connection import CustomConnection, _Connection
from promptflow._sdk.entities._connection import _Connection
from promptflow.connections import _Connection as _CoreConnection

T = TypeVar("T", bound="_Connection")
Expand Down Expand Up @@ -73,26 +73,6 @@ def delete(self, name: str) -> None:
"""
ORMConnection.delete(name)

@classmethod
def _convert_core_connection_to_sdk_connection(cls, core_conn):
sdk_conn_mapping = _Connection.SUPPORTED_TYPES
sdk_conn_cls = sdk_conn_mapping.get(core_conn.type)
if sdk_conn_cls is None:
raise ConnectionClassNotFoundError(
f"Correspond sdk connection type not found for core connection type: {core_conn.type!r}, "
f"please re-install the 'promptflow' package."
)
common_args = {
"name": core_conn.name,
"module": core_conn.module,
"expiry_time": core_conn.expiry_time,
"created_date": core_conn.created_date,
"last_modified_date": core_conn.last_modified_date,
}
if sdk_conn_cls is CustomConnection:
return sdk_conn_cls(configs=core_conn.configs, secrets=core_conn.secrets, **common_args)
return sdk_conn_cls(**dict(core_conn), **common_args)

@monitor_operation(activity_name="pf.connections.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, connection: Type[_Connection], **kwargs):
"""Create or update a connection.
Expand All @@ -103,7 +83,7 @@ def create_or_update(self, connection: Type[_Connection], **kwargs):
if not connection.name:
raise ConnectionNameNotSetError("Name is required to create or update connection.")
if isinstance(connection, _CoreConnection) and not isinstance(connection, _Connection):
connection = self._convert_core_connection_to_sdk_connection(connection)
connection = _Connection._from_core_connection(connection)
orm_object = connection._to_orm_object()
now = datetime.now().isoformat()
if orm_object.createdDate is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
WeaviateConnection,
_Connection,
)
from promptflow._sdk.operations._connection_operations import ConnectionOperations
from promptflow._utils.yaml_utils import load_yaml
from promptflow.core._connection import RequiredEnvironmentVariablesNotSetError
from promptflow.exceptions import UserErrorException
Expand Down Expand Up @@ -484,7 +483,7 @@ def test_convert_core_connection_to_sdk_connection(self):
"api_version": "2023-07-01-preview",
}
connection = CoreAzureOpenAIConnection(**connection_args)
sdk_connection = ConnectionOperations._convert_core_connection_to_sdk_connection(connection)
sdk_connection = _Connection._from_core_connection(connection)
assert isinstance(sdk_connection, AzureOpenAIConnection)
assert sdk_connection._to_dict() == {
"module": "promptflow.connections",
Expand All @@ -501,12 +500,12 @@ def test_convert_core_connection_to_sdk_connection(self):
"secrets": {"b": "2"},
}
connection = CoreCustomConnection(**connection_args)
sdk_connection = ConnectionOperations._convert_core_connection_to_sdk_connection(connection)
sdk_connection = _Connection._from_core_connection(connection)
assert isinstance(sdk_connection, CustomConnection)
assert sdk_connection._to_dict() == {"module": "promptflow.connections", "type": "custom", **connection_args}

# Bad case
connection = CoreCustomConnection(**connection_args)
connection.type = "unknown"
with pytest.raises(ConnectionClassNotFoundError):
ConnectionOperations._convert_core_connection_to_sdk_connection(connection)
_Connection._from_core_connection(connection)

0 comments on commit 9f51888

Please sign in to comment.