Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[trace][refactor] Move tracing related functions to a better place #2990

Merged — 15 commits merged on Apr 25, 2024 (source and target branch names not captured in this extract)
17 changes: 5 additions & 12 deletions src/promptflow-azure/promptflow/azure/_storage/blob/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import threading
import traceback
from typing import Optional, Tuple
from typing import Callable, Tuple

from azure.ai.ml import MLClient
from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
Expand All @@ -25,17 +25,10 @@ def get_datastore_container_client(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object] = None,
get_credential: Callable,
) -> Tuple[ContainerClient, str]:
try:
if credential is None:
# in cloud scenario, runtime will pass in credential
# so this is local to cloud only code, happens in prompt flow service
# which should rely on Azure CLI credential only
from azure.identity import AzureCliCredential

credential = AzureCliCredential()

credential = get_credential()
datastore_definition, datastore_credential = _get_default_datastore(
subscription_id, resource_group_name, workspace_name, credential
)
Expand Down Expand Up @@ -68,7 +61,7 @@ def get_datastore_container_client(


def _get_default_datastore(
subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
subscription_id: str, resource_group_name: str, workspace_name: str, credential
) -> Tuple[Datastore, str]:

datastore_key = _get_datastore_client_key(subscription_id, resource_group_name, workspace_name)
Expand Down Expand Up @@ -103,7 +96,7 @@ def _get_datastore_client_key(subscription_id: str, resource_group_name: str, wo


def _get_aml_default_datastore(
subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
subscription_id: str, resource_group_name: str, workspace_name: str, credential
) -> Tuple[Datastore, str]:

ml_client = MLClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ast
import datetime
import threading
from typing import Optional
from typing import Callable

client_map = {}
_thread_lock = threading.Lock()
Expand All @@ -18,7 +18,7 @@ def get_client(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object] = None,
get_credential: Callable,
):
client_key = _get_db_client_key(container_name, subscription_id, resource_group_name, workspace_name)
container_client = _get_client_from_map(client_key)
Expand All @@ -28,13 +28,7 @@ def get_client(
with container_lock:
container_client = _get_client_from_map(client_key)
if container_client is None:
if credential is None:
# in cloud scenario, runtime will pass in credential
# so this is local to cloud only code, happens in prompt flow service
# which should rely on Azure CLI credential only
from azure.identity import AzureCliCredential

credential = AzureCliCredential()
credential = get_credential()
token = _get_resource_token(
container_name, subscription_id, resource_group_name, workspace_name, credential
)
Expand Down Expand Up @@ -77,7 +71,7 @@ def _get_resource_token(
subscription_id: str,
resource_group_name: str,
workspace_name: str,
credential: Optional[object],
credential,
) -> object:
from promptflow.azure import PFClient

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sdk_cli_azure_test.conftest import DATAS_DIR, EAGER_FLOWS_DIR, FLOWS_DIR

from promptflow._sdk._errors import RunOperationParameterError, UploadUserError, UserAuthenticationError
from promptflow._sdk._utils import parse_otel_span_status_code
from promptflow._sdk._utils.tracing import _parse_otel_span_status_code
from promptflow._sdk.entities import Run
from promptflow._sdk.operations._run_operations import RunOperations
from promptflow._utils.async_utils import async_run_allowing_running_loop
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_flex_flow_with_imported_func(self, pf: PFClient):
# TODO(3017093): won't support this for now
with pytest.raises(UserErrorException) as e:
pf.run(
flow=parse_otel_span_status_code,
flow=_parse_otel_span_status_code,
data=f"{DATAS_DIR}/simple_eager_flow_data.jsonl",
# set code folder to avoid snapshot too big
code=f"{EAGER_FLOWS_DIR}/multiple_entries",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def trace_collector(
get_created_by_info_with_cache=get_created_by_info_with_cache,
logger=logger,
cloud_trace_only=cloud_trace_only,
credential=credential,
)
return "Traces received", 200

Expand Down
60 changes: 39 additions & 21 deletions src/promptflow-devkit/promptflow/_sdk/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,8 @@
is_port_in_use,
is_run_from_built_binary,
)
from promptflow._sdk._tracing_utils import get_workspace_kind
from promptflow._sdk._utils import (
add_executable_script_to_env_path,
extract_workspace_triad_from_trace_provider,
parse_kv_from_pb_attribute,
)
from promptflow._sdk._utils import add_executable_script_to_env_path, extract_workspace_triad_from_trace_provider
from promptflow._sdk._utils.tracing import get_workspace_kind, parse_kv_from_pb_attribute, parse_protobuf_span
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.thread_utils import ThreadWithContextVars
from promptflow.tracing._integrations._openai_injector import inject_openai_api
Expand Down Expand Up @@ -560,7 +556,7 @@ def process_otlp_trace_request(
get_created_by_info_with_cache: typing.Callable,
logger: logging.Logger,
cloud_trace_only: bool = False,
credential: typing.Optional[object] = None,
get_credential: typing.Optional[typing.Callable] = None,
):
"""Process ExportTraceServiceRequest and write data to local/remote storage.

Expand All @@ -574,11 +570,10 @@ def process_otlp_trace_request(
:type logger: logging.Logger
:param cloud_trace_only: If True, only write trace to cosmosdb and skip local trace. Default is False.
:type cloud_trace_only: bool
:param credential: The credential object used to authenticate with cosmosdb. Default is None.
:type credential: Optional[object]
:param get_credential: A function that gets credential for Cosmos DB operation. Default is None.
:type get_credential: Optional[Callable]
"""
from promptflow._sdk.entities._trace import Span
from promptflow._sdk.operations._trace_operations import TraceOperations

all_spans = []
for resource_span in trace_request.resource_spans:
Expand All @@ -596,7 +591,7 @@ def process_otlp_trace_request(
for scope_span in resource_span.scope_spans:
for span in scope_span.spans:
# TODO: persist with batch
span: Span = TraceOperations._parse_protobuf_span(span, resource=resource, logger=logger)
span: Span = parse_protobuf_span(span, resource=resource, logger=logger)
if not cloud_trace_only:
all_spans.append(copy.deepcopy(span))
span._persist()
Expand All @@ -606,12 +601,14 @@ def process_otlp_trace_request(

if cloud_trace_only:
# If we only trace to cloud, we should make sure the data writing is success before return.
_try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache, logger, credential, is_cloud_trace=True)
_try_write_trace_to_cosmosdb(
all_spans, get_created_by_info_with_cache, logger, get_credential, is_cloud_trace=True
)
else:
# Create a new thread to write trace to cosmosdb to avoid blocking the main thread
ThreadWithContextVars(
target=_try_write_trace_to_cosmosdb,
args=(all_spans, get_created_by_info_with_cache, logger, credential, False),
args=(all_spans, get_created_by_info_with_cache, logger, get_credential, False),
).start()

return
Expand All @@ -621,7 +618,7 @@ def _try_write_trace_to_cosmosdb(
all_spans: typing.List,
get_created_by_info_with_cache: typing.Callable,
logger: logging.Logger,
credential: typing.Optional[object] = None,
get_credential: typing.Optional[typing.Callable] = None,
is_cloud_trace: bool = False,
):
if not all_spans:
Expand All @@ -640,6 +637,15 @@ def _try_write_trace_to_cosmosdb(
logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()

# use Azure CLI credential init as default get credential function
# outside this function, or say before this line, we might be in an environment with only promptflow-devkit
# we cannot import `AzureCliCredential` from `azure-identity`, so it might be None
# for other usages like runtime, or PRS, they should pass the function to get credential
if get_credential is None:
from azure.identity import AzureCliCredential

get_credential = AzureCliCredential
[review thread: marked as resolved by huaiyan]

from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.collection import CollectionCosmosDB
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
Expand All @@ -649,19 +655,31 @@ def _try_write_trace_to_cosmosdb(
# So, we load clients in parallel for warm up.
span_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential),
args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential),
)
span_client_thread.start()

collection_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential),
args=(
CosmosDBContainerName.COLLECTION,
subscription_id,
resource_group_name,
workspace_name,
get_credential,
),
)
collection_client_thread.start()

line_summary_client_thread = ThreadWithContextVars(
target=get_client,
args=(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name, credential),
args=(
CosmosDBContainerName.LINE_SUMMARY,
subscription_id,
resource_group_name,
workspace_name,
get_credential,
),
)
line_summary_client_thread.start()

Expand All @@ -677,7 +695,7 @@ def _try_write_trace_to_cosmosdb(
subscription_id=subscription_id,
resource_group_name=resource_group_name,
workspace_name=workspace_name,
credential=credential,
get_credential=get_credential,
)

span_client_thread.join()
Expand All @@ -687,7 +705,7 @@ def _try_write_trace_to_cosmosdb(

created_by = get_created_by_info_with_cache()
collection_client = get_client(
CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential
CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, get_credential
)

collection_db = CollectionCosmosDB(first_span, is_cloud_trace, created_by)
Expand All @@ -701,7 +719,7 @@ def _try_write_trace_to_cosmosdb(
for span in all_spans:
try:
span_client = get_client(
CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential
CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential
)
result = SpanCosmosDB(span, collection_id, created_by).persist(
span_client, blob_container_client, blob_base_uri
Expand All @@ -713,7 +731,7 @@ def _try_write_trace_to_cosmosdb(
subscription_id,
resource_group_name,
workspace_name,
credential,
get_credential,
)
Summary(span, collection_id, created_by, logger).persist(line_summary_client)
except Exception as e:
Expand Down
Loading
Loading