Skip to content

Commit

Permalink
Merge branch 'main' into yigao/check_cli
Browse files Browse the repository at this point in the history
  • Loading branch information
crazygao committed Mar 29, 2024
2 parents 0408819 + af44a55 commit 9fe7599
Show file tree
Hide file tree
Showing 87 changed files with 2,892 additions and 1,234 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/promptflow-import-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ jobs:
- name: Install all packages
run: |
cd ${{ env.WORKING_DIRECTORY }}/src/promptflow-tracing
touch promptflow/__init__.py
poetry install --with dev
cd ${{ env.WORKING_DIRECTORY }}/src/promptflow-core
touch promptflow/__init__.py
poetry install --with dev
cd ${{ env.WORKING_DIRECTORY }}/src/promptflow-devkit
touch promptflow/__init__.py
poetry install --with dev
cd ${{ env.WORKING_DIRECTORY }}/src/promptflow-azure
touch promptflow/__init__.py
poetry install --with dev
working-directory: ${{ env.WORKING_DIRECTORY }}
- name: import lint
Expand Down
16 changes: 2 additions & 14 deletions .github/workflows/promptflow-sdk-cli-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
pip install ${{ github.workspace }}/src/promptflow-tracing
pip install ${{ github.workspace }}/src/promptflow-core
pip install ${{ github.workspace }}/src/promptflow-devkit[pyarrow]
gci ./promptflow -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install "$($_.FullName)"}}
pip install ${{ github.workspace }}/src/promptflow
gci ./promptflow-tools -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install $_.FullName}}
pip freeze
- name: install recording
Expand Down Expand Up @@ -124,9 +124,7 @@ jobs:
working-directory: artifacts
run: |
Set-PSDebug -Trace 1
pip uninstall -y promptflow promptflow-sdk promptflow-tools
gci ./promptflow -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install "$($_.FullName)[executable]"}}
gci ./promptflow-tools -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install $_.FullName}}
pip install ${{ github.workspace }}/src/promptflow-devkit[pyarrow,executable]
pip freeze
- name: Run SDK CLI Executable Test
shell: pwsh
Expand All @@ -138,16 +136,6 @@ jobs:
-l eastus `
-m "unittest or e2etest" `
-o "${{ env.testWorkingDirectory }}/test-results-sdk-cli-executable.xml"
- name: Install pfs
shell: pwsh
working-directory: artifacts
run: |
Set-PSDebug -Trace 1
pip uninstall -y promptflow promptflow-sdk promptflow-tools
pip install ${{ github.workspace }}/src/promptflow-devkit
gci ./promptflow -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install "$($_.FullName)[azure]"}}
gci ./promptflow-tools -Recurse | % {if ($_.Name.Contains('.whl')) {python -m pip install $_.FullName}}
pip freeze
- name: Run PFS Test
shell: pwsh
working-directory: ${{ env.testWorkingDirectory }}
Expand Down
12 changes: 12 additions & 0 deletions docs/reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
- [promptflow](https://pypi.org/project/promptflow):
[![PyPI version](https://badge.fury.io/py/promptflow.svg)](https://badge.fury.io/py/promptflow)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow)](https://pypi.org/project/promptflow/)
- [promptflow-tracing](https://pypi.org/project/promptflow-tracing):
[![PyPI version](https://badge.fury.io/py/promptflow-tracing.svg)](https://badge.fury.io/py/promptflow-tracing)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow-tracing)](https://pypi.org/project/promptflow-tracing/)
- [promptflow-core](https://pypi.org/project/promptflow-core):
[![PyPI version](https://badge.fury.io/py/promptflow-core.svg)](https://badge.fury.io/py/promptflow-core)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow-core)](https://pypi.org/project/promptflow-core/)
- [promptflow-devkit](https://pypi.org/project/promptflow-devkit):
[![PyPI version](https://badge.fury.io/py/promptflow-devkit.svg)](https://badge.fury.io/py/promptflow-devkit)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow-devkit)](https://pypi.org/project/promptflow-devkit/)
- [promptflow-azure](https://pypi.org/project/promptflow-azure):
[![PyPI version](https://badge.fury.io/py/promptflow-azure.svg)](https://badge.fury.io/py/promptflow-azure)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow-azure)](https://pypi.org/project/promptflow-azure/)
- [promptflow-tools](https://pypi.org/project/promptflow-tools/):
[![PyPI version](https://badge.fury.io/py/promptflow-tools.svg)](https://badge.fury.io/py/promptflow-tools)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow-tools)](https://pypi.org/project/promptflow-tools/)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from streamlit_quill import st_quill # noqa: F401

from promptflow._sdk._utils import print_yellow_warning
from promptflow._utils.multimedia_utils import MIME_PATTERN, is_multimedia_dict
from promptflow._utils.multimedia_utils import MIME_PATTERN, BasicMultimediaProcessor
from promptflow.core._serving.flow_invoker import FlowInvoker

invoker = None
Expand Down Expand Up @@ -83,7 +83,7 @@ def list_iter_render_message(message_items):
st.markdown(f"`{json_dumps(message_items)},`")

def dict_iter_render_message(message_items):
if is_multimedia_dict(message_items):
if BasicMultimediaProcessor.is_multimedia_dict(message_items):
key = list(message_items.keys())[0]
value = message_items[key]
show_image(value, key)
Expand Down
1 change: 1 addition & 0 deletions scripts/docs/doc_generation.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ if (-not $SkipInstall){
pip install myst-parser==0.18.1
pip install matplotlib==3.4.3
pip install jinja2==3.0.1
pip install sqlalchemy>=2.0.0
Write-Host "===============Finished install requirements==============="
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
113 changes: 113 additions & 0 deletions src/promptflow-azure/promptflow/azure/_storage/blob/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import datetime
import logging
import threading
import traceback
from typing import Optional, Tuple

from azure.ai.ml import MLClient
from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
from azure.ai.ml.constants._common import LONG_URI_FORMAT, STORAGE_ACCOUNT_URLS
from azure.ai.ml.entities._datastore.datastore import Datastore
from azure.storage.blob import ContainerClient

from promptflow.exceptions import UserErrorException

_datastore_cache = {}
_thread_lock = threading.Lock()
_cache_timeout = 60 * 4 # Align the cache ttl with cosmosdb client.


def get_datastore_container_client(
    logger: logging.Logger,
    subscription_id: str,
    resource_group_name: str,
    workspace_name: str,
    credential: Optional[object] = None,
) -> Tuple[ContainerClient, str]:
    """Return a blob container client for the workspace's default datastore.

    Resolves the workspace default datastore (cached via ``get_default_datastore``),
    builds a ``ContainerClient`` against its storage account, and computes the
    azureml long-form datastore URI prefix callers append blob paths to.

    Returns a ``(container_client, blob_base_uri)`` tuple; ``blob_base_uri``
    always ends with ``/``. Any failure is logged with its stack trace and
    re-raised unchanged.
    """
    try:
        # To write data to blob, user should have "Storage Blob Data Contributor" to the storage account.
        if credential is None:
            from azure.identity import DefaultAzureCredential

            credential = DefaultAzureCredential()

        datastore = get_default_datastore(subscription_id, resource_group_name, workspace_name, credential)

        endpoint = _get_storage_endpoint_from_metadata()
        account_url = STORAGE_ACCOUNT_URLS[DatastoreType.AZURE_BLOB].format(datastore.account_name, endpoint)

        # "Datastore" is an AzureML concept, not a Blob Storage one, so the
        # datastore name cannot be recovered from a blob client. The azureml
        # long URI embeds that name, so the prefix is built here and handed to
        # the db client alongside the container client.
        client = ContainerClient(
            account_url=account_url, container_name=datastore.container_name, credential=credential
        )
        base_uri = LONG_URI_FORMAT.format(
            subscription_id, resource_group_name, workspace_name, datastore.name, ""
        )
        if not base_uri.endswith("/"):
            base_uri = base_uri + "/"

        logger.info(f"Get blob base url for {base_uri}")

        return client, base_uri

    except Exception as e:
        stack_trace = traceback.format_exc()
        logger.error(f"Failed to get blob client: {e}, stack trace is {stack_trace}")
        raise


def get_default_datastore(
    subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
) -> Datastore:
    """Return the workspace's default datastore, using a short-TTL module cache.

    Double-checked locking: the fast path reads the cache without the lock; on
    a miss the lock is taken, the cache re-checked, and only then is the
    service queried and the result stored together with its expiry time.
    """
    cache_key = _get_datastore_client_key(subscription_id, resource_group_name, workspace_name)
    cached = _get_datastore_from_cache(datastore_key=cache_key)
    if cached is not None:
        return cached

    with _thread_lock:
        # Another thread may have refreshed the entry while we waited.
        cached = _get_datastore_from_cache(datastore_key=cache_key)
        if cached is not None:
            return cached
        fetched = _get_default_datastore(subscription_id, resource_group_name, workspace_name, credential)
        _datastore_cache[cache_key] = {
            "expire_at": datetime.datetime.now() + datetime.timedelta(seconds=_cache_timeout),
            "datastore": fetched,
        }
        return fetched


def _get_datastore_from_cache(datastore_key: str):
    """Return the cached datastore for *datastore_key*, or None if absent or expired."""
    entry = _datastore_cache.get(datastore_key)
    if not entry:
        return None
    # Entries carry their own expiry; stale ones are simply ignored here and
    # overwritten by the next refresh under the lock.
    return entry["datastore"] if entry["expire_at"] > datetime.datetime.now() else None


def _get_datastore_client_key(subscription_id: str, resource_group_name: str, workspace_name: str) -> str:
# Azure name allow hyphens and underscores. User @ to avoid possible conflict.
return f"{subscription_id}@{resource_group_name}@{workspace_name}"


def _get_default_datastore(
    subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
) -> Datastore:
    """Query the workspace for its default datastore.

    Raises UserErrorException when the default datastore is not an Azure Blob
    datastore, since blob upload is required by the callers.
    """
    client = MLClient(
        credential=credential,
        subscription_id=subscription_id,
        resource_group_name=resource_group_name,
        workspace_name=workspace_name,
    )

    datastore = client.datastores.get_default()
    if datastore.type == DatastoreType.AZURE_BLOB:
        return datastore

    raise UserErrorException(
        message=f"Default datastore {datastore.name} is {datastore.type}, not AzureBlob."
    )
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,5 @@ def _init_container_client(endpoint: str, database_name: str, container_name: st


def _get_db_client_key(container_name: str, subscription_id: str, resource_group_name: str, workspace_name: str) -> str:
return f"{subscription_id}_{resource_group_name}_{workspace_name}_{container_name}"
# Azure name allow hyphens and underscores. User @ to avoid possible conflict.
return f"{subscription_id}@{resource_group_name}@{workspace_name}@{container_name}"
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from azure.cosmos import ContainerProxy

from promptflow._constants import SpanAttributeFieldName, SpanFieldName, SpanResourceAttributesFieldName
from promptflow._sdk._constants import CreatedByFieldName
from promptflow._constants import SpanAttributeFieldName, SpanResourceAttributesFieldName, SpanResourceFieldName
from promptflow._sdk._constants import TRACE_DEFAULT_COLLECTION, CreatedByFieldName
from promptflow._sdk.entities._trace import Span
from promptflow.azure._storage.cosmosdb.cosmosdb_utils import safe_create_cosmosdb_item

Expand Down Expand Up @@ -39,17 +39,19 @@ class CollectionCosmosDB:
def __init__(self, span: Span, is_cloud_trace: bool, created_by: Dict[str, Any]):
self.span = span
self.created_by = created_by
self.collection_name = span.session_id
self.location = LocationType.CLOUD if is_cloud_trace else LocationType.LOCAL
resource_attributes = span._content.get(SpanFieldName.RESOURCE, None)
resource_attributes = span.resource.get(SpanResourceFieldName.ATTRIBUTES, {})
self.collection_name = resource_attributes.get(
SpanResourceAttributesFieldName.COLLECTION, TRACE_DEFAULT_COLLECTION
)
self.collection_id = (
resource_attributes[SpanResourceAttributesFieldName.COLLECTION_ID]
if is_cloud_trace
else generate_collection_id_by_name_and_created_by(self.collection_name, created_by)
)

def create_collection_if_not_exist(self, client: ContainerProxy):
span_attributes = self.span._content[SpanFieldName.ATTRIBUTES]
span_attributes = self.span.attributes
# For batch run, ignore collection operation
if SpanAttributeFieldName.BATCH_RUN_ID in span_attributes:
return
Expand All @@ -71,7 +73,7 @@ def create_collection_if_not_exist(self, client: ContainerProxy):
)

def update_collection_updated_at_info(self, client: ContainerProxy):
span_attributes = self.span._content[SpanFieldName.ATTRIBUTES]
span_attributes = self.span.attributes
# For batch run, ignore collection operation
if SpanAttributeFieldName.BATCH_RUN_ID in span_attributes:
return
Expand Down
55 changes: 39 additions & 16 deletions src/promptflow-azure/promptflow/azure/_storage/cosmosdb/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
from typing import Any, Dict

from promptflow._constants import SpanFieldName
from azure.cosmos.container import ContainerProxy
from azure.storage.blob import ContainerClient

from promptflow._constants import SpanContextFieldName, SpanEventFieldName, SpanFieldName
from promptflow._sdk.entities._trace import Span as SpanEntity


class Span:

name: str = None
context: dict = None
kind: str = None
Expand All @@ -25,38 +28,58 @@ class Span:
partition_key: str = None
collection_id: str = None
created_by: dict = None
external_event_data_uris: list = None

def __init__(self, span: SpanEntity, collection_id: str, created_by: dict) -> None:
self.name = span.name
self.context = span._content[SpanFieldName.CONTEXT]
self.kind = span._content[SpanFieldName.KIND]
self.parent_id = span.parent_span_id
self.start_time = span._content[SpanFieldName.START_TIME]
self.end_time = span._content[SpanFieldName.END_TIME]
self.status = span._content[SpanFieldName.STATUS]
self.attributes = span._content[SpanFieldName.ATTRIBUTES]
self.events = span._content[SpanFieldName.EVENTS]
self.links = span._content[SpanFieldName.LINKS]
self.resource = span._content[SpanFieldName.RESOURCE]
self.partition_key = span.session_id
self.context = span.context
self.kind = span.kind
self.parent_id = span.parent_id
self.start_time = span.start_time.isoformat()
self.end_time = span.end_time.isoformat()
self.status = span.status
self.attributes = span.attributes
self.events = span.events
self.links = span.links
self.resource = span.resource
self.partition_key = collection_id
self.collection_id = collection_id
self.id = span.span_id
self.created_by = created_by
self.external_event_data_uris = []

def persist(self, client):
def persist(self, cosmos_client: ContainerProxy, blob_container_client: ContainerClient, blob_base_uri: str):
if self.id is None or self.partition_key is None or self.resource is None:
return

resource_attributes = self.resource.get(SpanFieldName.ATTRIBUTES, None)
if resource_attributes is None:
return

if self.events and blob_container_client is not None and blob_base_uri is not None:
self._persist_events(blob_container_client, blob_base_uri)

from azure.cosmos.exceptions import CosmosResourceExistsError

try:
return client.create_item(body=self.to_dict())
return cosmos_client.create_item(body=self.to_dict())
except CosmosResourceExistsError:
return None
return

def to_dict(self) -> Dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if v}

def _persist_events(self, blob_container_client: ContainerClient, blob_base_uri: str):
for idx, event in enumerate(self.events):
event_data = json.dumps(event)
blob_client = blob_container_client.get_blob_client(self._event_path(idx))
blob_client.upload_blob(event_data)

event[SpanEventFieldName.ATTRIBUTES] = {}
self.external_event_data_uris.append(f"{blob_base_uri}{self._event_path(idx)}")

EVENT_PATH_PREFIX = ".promptflow/.trace"

def _event_path(self, idx: int) -> str:
trace_id = self.context[SpanContextFieldName.TRACE_ID]
return f"{self.EVENT_PATH_PREFIX}/{self.collection_id}/{trace_id}/{self.id}/{idx}"
Loading

0 comments on commit 9fe7599

Please sign in to comment.