Skip to content

Commit

Permalink
[pfs] support connection provider (#1388)
Browse files Browse the repository at this point in the history
# Description
1. Add working_directory to get/get_with_secret/list connection to
support connection provider in pfs.
2. Add timestamp to pfs log

![image](https://github.com/microsoft/promptflow/assets/17938940/90b01f73-eeff-41c8-b7d3-750369755671)
3. Refine error handler


Please add an informative description that covers the changes made by
the pull request, and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.



Please add an informative description that covers the changes made by
the pull request, and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Dec 7, 2023
1 parent ec8512a commit 8414d1c
Show file tree
Hide file tree
Showing 12 changed files with 284 additions and 108 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/promptflow-sdk-pfs-e2e-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@ env:
IS_IN_CI_PIPELINE: "true"

jobs:
authorize:
environment:
# forked prs from pull_request_target will be run in external environment, domain prs will be run in internal environment
${{ github.event_name == 'pull_request_target' &&
github.event.pull_request.head.repo.full_name != github.repository &&
'external' || 'internal' }}
runs-on: ubuntu-latest
steps:
- run: true
sdk_pfs_e2e_test:
needs: authorize
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -55,6 +65,11 @@ jobs:
uses: SimenB/github-actions-cpu-cores@v1
id: cpu-cores

- name: Azure Login
uses: azure/login@v1
with:
creds: ${{ secrets.AZURE_CREDENTIALS }}

- name: Run Test
shell: pwsh
working-directory: ${{ env.testWorkingDirectory }}
Expand Down
8 changes: 4 additions & 4 deletions src/promptflow/promptflow/_sdk/_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,18 @@ def _get_workspace_from_config(
)
return RESOURCE_ID_FORMAT.format(subscription_id, resource_group, AZUREML_RESOURCE_PROVIDER, workspace_name)

def get_connection_provider(self) -> Optional[str]:
def get_connection_provider(self, path=None) -> Optional[str]:
"""Get the current connection provider. Default to local if not configured."""
provider = self.get_config(key=self.CONNECTION_PROVIDER)
return self.resolve_connection_provider(provider)
return self.resolve_connection_provider(provider, path=path)

@classmethod
def resolve_connection_provider(cls, provider) -> Optional[str]:
def resolve_connection_provider(cls, provider, path=None) -> Optional[str]:
if provider is None:
return ConnectionProvider.LOCAL
if provider == ConnectionProvider.AZUREML.value:
# Note: The below function has azure-ai-ml dependency.
return "azureml:" + cls._get_workspace_from_config()
return "azureml:" + cls._get_workspace_from_config(path=path)
# If provider not None and not Azure, return it directly.
# It can be the full path of a workspace.
return provider
Expand Down
14 changes: 3 additions & 11 deletions src/promptflow/promptflow/_sdk/_pf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

from .._utils.logger_utils import LoggerFactory
from ._configuration import Configuration
from ._constants import LOGGER_NAME, MAX_SHOW_DETAILS_RESULTS, ConnectionProvider
from ._constants import LOGGER_NAME, MAX_SHOW_DETAILS_RESULTS
from ._user_agent import USER_AGENT
from ._utils import setup_user_agent_to_operation_context
from ._utils import get_connection_operation, setup_user_agent_to_operation_context
from .entities import Run
from .operations import RunOperations
from .operations._connection_operations import ConnectionOperations
from .operations._flow_operations import FlowOperations
from .operations._local_azure_connection_operations import LocalAzureConnectionOperations
from .operations._tool_operations import ToolOperations

logger = LoggerFactory.get_logger(name=LOGGER_NAME, verbosity=logging.WARNING)
Expand Down Expand Up @@ -195,14 +194,7 @@ def connections(self) -> ConnectionOperations:
"""Connection operations that can manage connections."""
if not self._connections:
self._ensure_connection_provider()
if self._connection_provider == ConnectionProvider.LOCAL.value:
logger.debug("PFClient using local connection operations.")
self._connections = ConnectionOperations()
elif self._connection_provider.startswith(ConnectionProvider.AZUREML.value):
logger.debug("PFClient using local azure connection operations.")
self._connections = LocalAzureConnectionOperations(self._connection_provider)
else:
raise ValueError(f"Unsupported connection provider: {self._connection_provider}")
self._connections = get_connection_operation(self._connection_provider)
return self._connections

@property
Expand Down
40 changes: 24 additions & 16 deletions src/promptflow/promptflow/_sdk/_service/apis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,29 @@

from flask import jsonify, request

from promptflow._sdk._errors import ConnectionNotFoundError
import promptflow._sdk.schemas._connection as connection
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._service import Namespace, Resource, fields
from promptflow._sdk._service.utils.utils import local_user_only
from promptflow._sdk.entities._connection import _Connection
from promptflow._sdk.operations._connection_operations import ConnectionOperations
import promptflow._sdk.schemas._connection as connection


api = Namespace("Connections", description="Connections Management")

# azure connection
working_directory_parser = api.parser()
working_directory_parser.add_argument("working_directory", type=str, location="form", required=False)

# Response model of list connections
list_connection_field = api.model(
"Connection",
{
"name": fields.String,
"type": fields.String,
"module": fields.String,
"expiry_time": fields.DateTime(),
"created_date": fields.DateTime(),
"last_modified_date": fields.DateTime(),
"expiry_time": fields.String,
"created_date": fields.String,
"last_modified_date": fields.String,
},
)
# Response model of connection operation
Expand All @@ -48,19 +51,22 @@
)


@api.errorhandler(ConnectionNotFoundError)
def handle_connection_not_found_exception(error):
api.logger.warning(f"Raise ConnectionNotFoundError, {error.message}")
return {"error_message": error.message}, 404
def _get_connection_operation(working_directory=None):
from promptflow._sdk._utils import get_connection_operation

connection_provider = Configuration().get_connection_provider(path=working_directory)
connection_operation = get_connection_operation(connection_provider)
return connection_operation


@api.route("/")
class ConnectionList(Resource):
@api.doc(description="List all connection")
@api.doc(parser=working_directory_parser, description="List all connection")
@api.marshal_with(list_connection_field, skip_none=True, as_list=True)
@local_user_only
def get(self):
connection_op = ConnectionOperations()
args = working_directory_parser.parse_args()
connection_op = _get_connection_operation(args.working_directory)
# parse query parameters
max_results = request.args.get("max_results", default=50, type=int)
all_results = request.args.get("all_results", default=False, type=bool)
Expand All @@ -73,11 +79,12 @@ def get(self):
@api.route("/<string:name>")
@api.param("name", "The connection name.")
class Connection(Resource):
@api.doc(description="Get connection")
@api.doc(parser=working_directory_parser, description="Get connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def get(self, name: str):
connection_op = ConnectionOperations()
args = working_directory_parser.parse_args()
connection_op = _get_connection_operation(args.working_directory)
connection = connection_op.get(name=name, raise_error=True)
connection_dict = connection._to_dict()
return jsonify(connection_dict)
Expand Down Expand Up @@ -115,11 +122,12 @@ def delete(self, name: str):

@api.route("/<string:name>/listsecrets")
class ConnectionWithSecret(Resource):
@api.doc(description="Get connection with secret")
@api.doc(parser=working_directory_parser, description="Get connection with secret")
@api.response(code=200, description="Connection details with secret", model=dict_field)
@local_user_only
def get(self, name: str):
connection_op = ConnectionOperations()
args = working_directory_parser.parse_args()
connection_op = _get_connection_operation(args.working_directory)
connection = connection_op.get(name=name, with_secrets=True, raise_error=True)
connection_dict = connection._to_dict()
return jsonify(connection_dict)
Expand Down
20 changes: 7 additions & 13 deletions src/promptflow/promptflow/_sdk/_service/apis/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from dataclasses import asdict
from pathlib import Path

import yaml
from flask import Response, jsonify, make_response, request

from promptflow._sdk._constants import FlowRunProperties, get_list_view_type
from promptflow._sdk._errors import RunNotFoundError
from promptflow._sdk._service import Namespace, Resource, fields
from promptflow._sdk.entities import Run as RunEntity
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
Expand All @@ -34,12 +34,6 @@
list_field = api.schema_model("RunList", {"type": "array", "items": {"$ref": "#/definitions/RunDict"}})


@api.errorhandler(RunNotFoundError)
def handle_run_not_found_exception(error):
api.logger.warning(f"Raise RunNotFoundError, {error.message}")
return {"error_message": error.message}, 404


@api.route("/")
class RunList(Resource):
@api.response(code=200, description="Runs", model=list_field)
Expand Down Expand Up @@ -73,18 +67,18 @@ def post(self):
run_name = run._generate_run_name()
run_dict["name"] = run_name
with tempfile.TemporaryDirectory() as temp_dir:
run_file = Path(temp_dir) / "batch_run.json"
with open(run_file, "w") as f:
json.dump(run_dict, f)
run_file = Path(temp_dir) / "batch_run.yaml"
with open(run_file, "w", encoding="utf-8") as f:
yaml.safe_dump(run_dict, f)
cmd = f"pf run create --file {run_file}"
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
_, stderr = process.communicate()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, _ = process.communicate()
if process.returncode == 0:
run_op = RunOperations()
run = run_op.get(name=run_name)
return jsonify(run._to_dict())
else:
raise Exception(f"Create batch run failed: {stderr}")
raise Exception(f"Create batch run failed: {stdout.decode('utf-8')}")


@api.route("/<string:name>")
Expand Down
26 changes: 16 additions & 10 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from logging.handlers import RotatingFileHandler

from flask import Blueprint, Flask, jsonify
from werkzeug.exceptions import HTTPException
Expand All @@ -10,12 +11,12 @@
from promptflow._sdk._service import Api
from promptflow._sdk._service.apis.connection import api as connection_api
from promptflow._sdk._service.apis.run import api as run_api
from promptflow._sdk._service.utils.utils import FormattedException
from promptflow._sdk._utils import get_promptflow_sdk_version, read_write_by_user
from promptflow.exceptions import UserErrorException


def heartbeat():
response = {"sdk_version": get_promptflow_sdk_version()}
response = {"promptflow": get_promptflow_sdk_version()}
return jsonify(response)


Expand All @@ -38,19 +39,24 @@ def create_app():
app.logger.setLevel(logging.INFO)
log_file = HOME_PROMPT_FLOW_DIR / PF_SERVICE_LOG_FILE
log_file.touch(mode=read_write_by_user(), exist_ok=True)
handler = logging.FileHandler(filename=log_file)
# Create a rotating file handler with a max size of 1 MB, keeping up to 1 backup file
handler = RotatingFileHandler(filename=log_file, maxBytes=1_000_000, backupCount=1)
formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] - %(message)s")
handler.setFormatter(formatter)
app.logger.addHandler(handler)

# Basic error handler
@app.errorhandler(Exception)
@api.errorhandler(Exception)
def handle_exception(e):
from dataclasses import asdict

if isinstance(e, HTTPException):
return e
return asdict(FormattedException(e), dict_factory=lambda x: {k: v for (k, v) in x if v}), e.code
app.logger.error(e, exc_info=True, stack_info=True)
if isinstance(e, UserErrorException):
error_info = e.message
else:
error_info = str(e)
return jsonify({"error_message": f"Internal Server Error, {error_info}"}), 500
formatted_exception = FormattedException(e)
return (
asdict(formatted_exception, dict_factory=lambda x: {k: v for (k, v) in x if v}),
formatted_exception.status_code,
)

return app, api
Loading

0 comments on commit 8414d1c

Please sign in to comment.