diff --git a/src/promptflow/promptflow/_sdk/_tracing.py b/src/promptflow/promptflow/_sdk/_tracing.py index 186d5ebcf65..c34a0935bb3 100644 --- a/src/promptflow/promptflow/_sdk/_tracing.py +++ b/src/promptflow/promptflow/_sdk/_tracing.py @@ -177,7 +177,7 @@ def start_trace_with_devkit( ref_line_run_id = env_attrs.get(ContextAttributeKey.REFERENCED_LINE_RUN_ID, None) op_ctx = OperationContext.get_instance() # remove `referenced.line_run_id` from context to avoid stale value set by previous node - if ref_line_run_id: + if ref_line_run_id is None: op_ctx._remove_otel_attributes(SpanAttributeFieldName.REFERENCED_LINE_RUN_ID) else: op_ctx._add_otel_attributes(SpanAttributeFieldName.REFERENCED_LINE_RUN_ID, ref_line_run_id) diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py index 1cdfab7b211..6af218f14da 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py @@ -204,7 +204,6 @@ def test_cancel_experiment(self): exp = client._experiments.get(exp.name) assert exp.status == ExperimentStatus.TERMINATED - @pytest.mark.skip("This test is not working currently") @pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection") def test_flow_test_with_experiment(self, monkeypatch): # set queue size to 1 to make collection faster @@ -253,7 +252,7 @@ def _assert_result(result): time.sleep(10) # TODO fix this line_runs = client._traces.list_line_runs(session_id=session) if len(line_runs) > 0: - assert len(line_runs) > 1 + assert len(line_runs) == 1 line_run = line_runs[0] assert len(line_run.evaluations) == 1, "line run evaluation not exists!" assert "eval_classification_accuracy" == list(line_run.evaluations.values())[0].display_name diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_trace.py b/src/promptflow/tests/sdk_cli_test/unittests/test_trace.py index 58d9fe73d79..d1ba717ccea 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_trace.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_trace.py @@ -3,18 +3,28 @@ # --------------------------------------------------------- import base64 +import json import os import uuid from typing import Dict from unittest.mock import patch import pytest +from mock import mock from opentelemetry import trace from opentelemetry.proto.trace.v1.trace_pb2 import Span as PBSpan from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_ENDPOINT from opentelemetry.sdk.trace import TracerProvider -from promptflow._constants import SpanResourceAttributesFieldName, SpanResourceFieldName, TraceEnvironmentVariableName +from promptflow._constants import ( + SpanAttributeFieldName, + SpanResourceAttributesFieldName, + SpanResourceFieldName, + TraceEnvironmentVariableName, +) +from promptflow._core.operation_context import OperationContext +from promptflow._sdk._constants import PF_TRACE_CONTEXT, PF_TRACE_CONTEXT_ATTR, ContextAttributeKey +from promptflow._sdk._tracing import start_trace_with_devkit from promptflow._sdk.entities._trace import Span from promptflow.tracing._start_trace import _is_tracer_provider_set, setup_exporter_from_environ @@ -40,6 +50,14 @@ def mock_resource() -> Dict: } +@pytest.fixture +def mock_promptflow_service_invocation(): + """Mock `_invoke_pf_svc` as we don't expect to invoke PFS during unit test.""" + with mock.patch("promptflow._sdk._tracing._invoke_pf_svc") as mock_func: + mock_func.return_value = "23333" + yield + + @pytest.mark.sdk_test @pytest.mark.unittest class TestStartTrace: @@ -103,3 +121,28 @@ def test_trace_without_attributes_collection(self, mock_resource: Dict) -> None: attributes = span._content["attributes"] assert isinstance(attributes, dict) assert len(attributes) == 0 + + def test_experiment_test_lineage(self, monkeypatch: pytest.MonkeyPatch, mock_promptflow_service_invocation) -> None: + # experiment orchestrator will help set this context in environment + referenced_line_run_id = str(uuid.uuid4()) + ctx = {PF_TRACE_CONTEXT_ATTR: {ContextAttributeKey.REFERENCED_LINE_RUN_ID: referenced_line_run_id}} + with monkeypatch.context() as m: + m.setenv(PF_TRACE_CONTEXT, json.dumps(ctx)) + start_trace_with_devkit(session_id=None) + # lineage is stored in context + op_ctx = OperationContext.get_instance() + otel_attrs = op_ctx._get_otel_attributes() + assert otel_attrs[SpanAttributeFieldName.REFERENCED_LINE_RUN_ID] == referenced_line_run_id + + def test_experiment_test_lineage_cleanup( + self, monkeypatch: pytest.MonkeyPatch, mock_promptflow_service_invocation + ) -> None: + # in previous code, context may be set with lineage + op_ctx = OperationContext.get_instance() + op_ctx._add_otel_attributes(SpanAttributeFieldName.REFERENCED_LINE_RUN_ID, str(uuid.uuid4())) + with monkeypatch.context() as m: + m.setenv(PF_TRACE_CONTEXT, json.dumps({PF_TRACE_CONTEXT_ATTR: dict()})) + start_trace_with_devkit(session_id=None) + # lineage will be reset + otel_attrs = op_ctx._get_otel_attributes() + assert SpanAttributeFieldName.REFERENCED_LINE_RUN_ID not in otel_attrs