Skip to content

Commit

Permalink
[trace][refactor] Move tracing related functions to a better place (#…
Browse files Browse the repository at this point in the history
…2990)

# Description

**Move tracing related functions**

- `promptflow._sdk._tracing`: the import location for tracing-related
functions, covered and guarded by unit tests
- for `promptflow-tracing`: `start_trace_with_devkit`,
`setup_exporter_to_pfs`
- for the OTLP collector, runtime and others: `process_otlp_trace_request`,
which parses spans from Protocol Buffer messages
- `promptflow._sdk._utils.tracing`: utilities for tracing

Remove previous tracing utilities file `_tracing_utils.py`.

**Pass function that gets credential**

For `process_otlp_trace_request`, the caller should now pass a function
that returns a credential, instead of the credential object itself.
However, because the environment may not have the Azure extension
installed, we cannot directly pass `AzureCliCredential` from outside;
instead, default logic is added inside the function to construct one.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
  • Loading branch information
zhengfeiwang authored Apr 25, 2024
1 parent 4c00cdb commit 9f2bc05
Show file tree
Hide file tree
Showing 11 changed files with 363 additions and 350 deletions.
17 changes: 5 additions & 12 deletions src/promptflow-azure/promptflow/azure/_storage/blob/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import threading
import traceback
from typing import Optional, Tuple
from typing import Callable, Tuple

from azure.ai.ml import MLClient
from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
Expand All @@ -25,17 +25,10 @@ def get_datastore_container_client(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object] = None,
get_credential: Callable,
) -> Tuple[ContainerClient, str]:
try:
if credential is None:
# in cloud scenario, runtime will pass in credential
# so this is local to cloud only code, happens in prompt flow service
# which should rely on Azure CLI credential only
from azure.identity import AzureCliCredential

credential = AzureCliCredential()

credential = get_credential()
datastore_definition, datastore_credential = _get_default_datastore(
subscription_id, resource_group_name, workspace_name, credential
)
Expand Down Expand Up @@ -68,7 +61,7 @@ def get_datastore_container_client(


def _get_default_datastore(
subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
subscription_id: str, resource_group_name: str, workspace_name: str, credential
) -> Tuple[Datastore, str]:

datastore_key = _get_datastore_client_key(subscription_id, resource_group_name, workspace_name)
Expand Down Expand Up @@ -103,7 +96,7 @@ def _get_datastore_client_key(subscription_id: str, resource_group_name: str, wo


def _get_aml_default_datastore(
subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
subscription_id: str, resource_group_name: str, workspace_name: str, credential
) -> Tuple[Datastore, str]:

ml_client = MLClient(
Expand Down
14 changes: 4 additions & 10 deletions src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ast
import datetime
import threading
from typing import Optional
from typing import Callable

client_map = {}
_thread_lock = threading.Lock()
Expand All @@ -18,7 +18,7 @@ def get_client(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object] = None,
get_credential: Callable,
):
client_key = _get_db_client_key(container_name, subscription_id, resource_group_name, workspace_name)
container_client = _get_client_from_map(client_key)
Expand All @@ -28,13 +28,7 @@ def get_client(
with container_lock:
container_client = _get_client_from_map(client_key)
if container_client is None:
if credential is None:
# in cloud scenario, runtime will pass in credential
# so this is local to cloud only code, happens in prompt flow service
# which should rely on Azure CLI credential only
from azure.identity import AzureCliCredential

credential = AzureCliCredential()
credential = get_credential()
token = _get_resource_token(
container_name, subscription_id, resource_group_name, workspace_name, credential
)
Expand Down Expand Up @@ -77,7 +71,7 @@ def _get_resource_token(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object],
credential,
) -> object:
from promptflow.azure import PFClient

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sdk_cli_azure_test.conftest import DATAS_DIR, EAGER_FLOWS_DIR, FLOWS_DIR

from promptflow._sdk._errors import RunOperationParameterError, UploadUserError, UserAuthenticationError
from promptflow._sdk._utils import parse_otel_span_status_code
from promptflow._sdk._utils.tracing import _parse_otel_span_status_code
from promptflow._sdk.entities import Run
from promptflow._sdk.operations._run_operations import RunOperations
from promptflow._utils.async_utils import async_run_allowing_running_loop
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_flex_flow_with_imported_func(self, pf: PFClient):
# TODO(3017093): won't support this for now
with pytest.raises(UserErrorException) as e:
pf.run(
flow=parse_otel_span_status_code,
flow=_parse_otel_span_status_code,
data=f"{DATAS_DIR}/simple_eager_flow_data.jsonl",
# set code folder to avoid snapshot too big
code=f"{EAGER_FLOWS_DIR}/multiple_entries",
Expand Down
37 changes: 29 additions & 8 deletions src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from flask import request
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest

from promptflow._sdk._tracing import process_otlp_trace_request
from promptflow._sdk._errors import MissingAzurePackage
from promptflow._sdk._tracing import _is_azure_ext_installed, process_otlp_trace_request


def trace_collector(
Expand Down Expand Up @@ -41,13 +42,33 @@ def trace_collector(
if "application/x-protobuf" in content_type:
trace_request = ExportTraceServiceRequest()
trace_request.ParseFromString(request.data)
process_otlp_trace_request(
trace_request=trace_request,
get_created_by_info_with_cache=get_created_by_info_with_cache,
logger=logger,
cloud_trace_only=cloud_trace_only,
credential=credential,
)
# this function will be called in some old runtime versions
# where runtime will pass either credential object, or the function to get credential
# as we need to be compatible with this, need to handle both cases
if credential is not None:
# local prompt flow service will not pass credential, so this is runtime scenario
get_credential = credential if callable(credential) else lambda: credential # noqa: F841
process_otlp_trace_request(
trace_request=trace_request,
get_created_by_info_with_cache=get_created_by_info_with_cache,
logger=logger,
get_credential=get_credential,
cloud_trace_only=cloud_trace_only,
)
else:
# if `promptflow-azure` is not installed, pass an exception class to the function
get_credential = MissingAzurePackage
if _is_azure_ext_installed():
from azure.identity import AzureCliCredential

get_credential = AzureCliCredential
process_otlp_trace_request(
trace_request=trace_request,
get_created_by_info_with_cache=get_created_by_info_with_cache,
logger=logger,
get_credential=get_credential,
cloud_trace_only=cloud_trace_only,
)
return "Traces received", 200

# JSON protobuf encoding
Expand Down
51 changes: 30 additions & 21 deletions src/promptflow-devkit/promptflow/_sdk/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,8 @@
is_port_in_use,
is_run_from_built_binary,
)
from promptflow._sdk._tracing_utils import get_workspace_kind
from promptflow._sdk._utils import (
add_executable_script_to_env_path,
extract_workspace_triad_from_trace_provider,
parse_kv_from_pb_attribute,
)
from promptflow._sdk._utils import add_executable_script_to_env_path, extract_workspace_triad_from_trace_provider
from promptflow._sdk._utils.tracing import get_workspace_kind, parse_kv_from_pb_attribute, parse_protobuf_span
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.thread_utils import ThreadWithContextVars
from promptflow.tracing._integrations._openai_injector import inject_openai_api
Expand Down Expand Up @@ -559,8 +555,8 @@ def process_otlp_trace_request(
trace_request: ExportTraceServiceRequest,
get_created_by_info_with_cache: typing.Callable,
logger: logging.Logger,
get_credential: typing.Callable,
cloud_trace_only: bool = False,
credential: typing.Optional[object] = None,
):
"""Process ExportTraceServiceRequest and write data to local/remote storage.
Expand All @@ -572,13 +568,12 @@ def process_otlp_trace_request(
:type get_created_by_info_with_cache: Callable
:param logger: The logger object used for logging.
:type logger: logging.Logger
:param get_credential: A function that gets credential for Cosmos DB operation.
:type get_credential: Callable
:param cloud_trace_only: If True, only write trace to cosmosdb and skip local trace. Default is False.
:type cloud_trace_only: bool
:param credential: The credential object used to authenticate with cosmosdb. Default is None.
:type credential: Optional[object]
"""
from promptflow._sdk.entities._trace import Span
from promptflow._sdk.operations._trace_operations import TraceOperations

all_spans = []
for resource_span in trace_request.resource_spans:
Expand All @@ -596,7 +591,7 @@ def process_otlp_trace_request(
for scope_span in resource_span.scope_spans:
for span in scope_span.spans:
# TODO: persist with batch
span: Span = TraceOperations._parse_protobuf_span(span, resource=resource, logger=logger)
span: Span = parse_protobuf_span(span, resource=resource, logger=logger)
if not cloud_trace_only:
all_spans.append(copy.deepcopy(span))
span._persist()
Expand All @@ -606,12 +601,14 @@ def process_otlp_trace_request(

if cloud_trace_only:
# If we only trace to cloud, we should make sure the data writing is success before return.
_try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache, logger, credential, is_cloud_trace=True)
_try_write_trace_to_cosmosdb(
all_spans, get_created_by_info_with_cache, logger, get_credential, is_cloud_trace=True
)
else:
# Create a new thread to write trace to cosmosdb to avoid blocking the main thread
ThreadWithContextVars(
target=_try_write_trace_to_cosmosdb,
args=(all_spans, get_created_by_info_with_cache, logger, credential, False),
args=(all_spans, get_created_by_info_with_cache, logger, get_credential, False),
).start()

return
Expand All @@ -621,7 +618,7 @@ def _try_write_trace_to_cosmosdb(
all_spans: typing.List,
get_created_by_info_with_cache: typing.Callable,
logger: logging.Logger,
credential: typing.Optional[object] = None,
get_credential: typing.Callable,
is_cloud_trace: bool = False,
):
if not all_spans:
Expand Down Expand Up @@ -649,19 +646,31 @@ def _try_write_trace_to_cosmosdb(
# So, we load clients in parallel for warm up.
span_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential),
args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential),
)
span_client_thread.start()

collection_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential),
args=(
CosmosDBContainerName.COLLECTION,
subscription_id,
resource_group_name,
workspace_name,
get_credential,
),
)
collection_client_thread.start()

line_summary_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name, credential),
args=(
CosmosDBContainerName.LINE_SUMMARY,
subscription_id,
resource_group_name,
workspace_name,
get_credential,
),
)
line_summary_client_thread.start()

Expand All @@ -677,7 +686,7 @@ def _try_write_trace_to_cosmosdb(
subscription_id=subscription_id,
resource_group_name=resource_group_name,
workspace_name=workspace_name,
credential=credential,
get_credential=get_credential,
)

span_client_thread.join()
Expand All @@ -687,7 +696,7 @@ def _try_write_trace_to_cosmosdb(

created_by = get_created_by_info_with_cache()
collection_client = get_client(
CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential
CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, get_credential
)

collection_db = CollectionCosmosDB(first_span, is_cloud_trace, created_by)
Expand All @@ -701,7 +710,7 @@ def _try_write_trace_to_cosmosdb(
for span in all_spans:
try:
span_client = get_client(
CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential
CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential
)
result = SpanCosmosDB(span, collection_id, created_by).persist(
span_client, blob_container_client, blob_base_uri
Expand All @@ -713,7 +722,7 @@ def _try_write_trace_to_cosmosdb(
subscription_id,
resource_group_name,
workspace_name,
credential,
get_credential,
)
Summary(span, collection_id, created_by, logger).persist(line_summary_client)
except Exception as e:
Expand Down
Loading

0 comments on commit 9f2bc05

Please sign in to comment.