From 5901354212a1fcaffef2e7e704c35f73f0e4f549 Mon Sep 17 00:00:00 2001
From: Robben Wang <350053002@qq.com>
Date: Wed, 17 Apr 2024 18:39:48 +0800
Subject: [PATCH] Use same structure for main flow and evaluation value in LineSummary (#2845)

# Description

Keep the same behavior as the local trace: use the same structure for the main flow and the evaluation values in LineSummary. This may free us from having to add new fields one by one.

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

Co-authored-by: robbenwang
---
 .../azure/_storage/cosmosdb/summary.py        |  55 ++---
 .../unittests/test_summary.py                 | 216 ++++++++----------
 2 files changed, 105 insertions(+), 166 deletions(-)

diff --git a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/summary.py b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/summary.py
index 63c6e4ab9e6..41cd398bb9b 100644
--- a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/summary.py
+++ b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/summary.py
@@ -27,7 +27,7 @@
 @dataclass
 class SummaryLine:
     """
-    This class represents an Item in Summary container
+    This class represents an Item in LineSummary container and each value for evaluations dict.
     """
 
     id: str
@@ -54,27 +54,6 @@ class SummaryLine:
     line_run_id: str = None
 
 
-@dataclass
-class LineEvaluation:
-    """
-    This class represents an evaluation value in Summary container item.
-
-    """
-
-    outputs: typing.Dict
-    trace_id: str
-    root_span_id: str
-    name: str
-    created_by: typing.Dict
-    collection_id: str
-    flow_id: str = None
-    # Only for batch run
-    batch_run_id: str = None
-    line_number: str = None
-    # Only for line run
-    line_run_id: str = None
-
-
 class Summary:
     def __init__(self, span: Span, collection_id: str, created_by: typing.Dict, logger: logging.Logger) -> None:
         self.span = span
@@ -93,7 +72,7 @@ def persist(self, client: ContainerProxy):
             # For non root span, write a placeholder item to LineSummary table.
             self._persist_running_item(client)
             return
-        self._parse_inputs_outputs_from_events()
+        self._prepare_db_item()
         # Persist root span as a line run.
         self._persist_line_run(client)
 
@@ -191,9 +170,8 @@ def _process_value(value):
         else:
             return _process_value(content)
 
-    def _persist_line_run(self, client: ContainerProxy):
-        attributes: dict = self.span.attributes
-
+    def _prepare_db_item(self):
+        self._parse_inputs_outputs_from_events()
         session_id = self.session_id
         start_time = self.span.start_time.isoformat()
         end_time = self.span.end_time.isoformat()
@@ -202,6 +180,7 @@ def _persist_line_run(self, client: ContainerProxy):
         # Convert ISO 8601 formatted strings to datetime objects
         latency = (self.span.end_time - self.span.start_time).total_seconds()
         # calculate `cumulative_token_count`
+        attributes: dict = self.span.attributes
         completion_token_count = int(attributes.get(SpanAttributeFieldName.COMPLETION_TOKEN_COUNT, 0))
         prompt_token_count = int(attributes.get(SpanAttributeFieldName.PROMPT_TOKEN_COUNT, 0))
         total_token_count = int(attributes.get(SpanAttributeFieldName.TOTAL_TOKEN_COUNT, 0))
@@ -237,10 +216,13 @@ def _persist_line_run(self, client: ContainerProxy):
         elif SpanAttributeFieldName.BATCH_RUN_ID in attributes and SpanAttributeFieldName.LINE_NUMBER in attributes:
             item.batch_run_id = attributes[SpanAttributeFieldName.BATCH_RUN_ID]
             item.line_number = attributes[SpanAttributeFieldName.LINE_NUMBER]
+        self.item = item
 
-        self.logger.info(f"Persist main run for LineSummary id: {item.id}")
+    def _persist_line_run(self, client: ContainerProxy):
+
+        self.logger.info(f"Persist main run for LineSummary id: {self.item.id}")
         # Use upsert because we may create running item in advance.
-        return client.upsert_item(body=asdict(item))
+        return client.upsert_item(body=asdict(self.item))
 
     def _insert_evaluation_with_retry(self, client: ContainerProxy):
         for attempt in range(3):
@@ -258,15 +240,6 @@ def _insert_evaluation_with_retry(self, client: ContainerProxy):
 
     def _insert_evaluation(self, client: ContainerProxy):
         attributes: dict = self.span.attributes
-        item = LineEvaluation(
-            trace_id=self.span.trace_id,
-            root_span_id=self.span.span_id,
-            collection_id=self.collection_id,
-            outputs=self.outputs,
-            name=self.span.name,
-            created_by=self.created_by,
-        )
-
         # None is the default value for the field.
         referenced_line_run_id = attributes.get(SpanAttributeFieldName.REFERENCED_LINE_RUN_ID, None)
         referenced_batch_run_id = attributes.get(SpanAttributeFieldName.REFERENCED_BATCH_RUN_ID, None)
@@ -296,18 +269,18 @@ def _insert_evaluation(self, client: ContainerProxy):
             raise InsertEvaluationsRetriableException(f"Cannot find main run by parameter {parameters}.")
 
         if SpanAttributeFieldName.LINE_RUN_ID in attributes:
-            item.line_run_id = attributes[SpanAttributeFieldName.LINE_RUN_ID]
             key = self.span.name
         else:
             batch_run_id = attributes[SpanAttributeFieldName.BATCH_RUN_ID]
-            item.batch_run_id = batch_run_id
-            item.line_number = line_number
             # Use the batch run id, instead of the name, as the key in the evaluations dictionary.
             # Customers may execute the same evaluation flow multiple times for a batch run.
             # We should be able to save all evaluations, as customers use batch runs in a critical manner.
             key = batch_run_id
 
-        patch_operations = [{"op": "add", "path": f"/evaluations/{key}", "value": asdict(item)}]
+        item_dict = asdict(self.item)
+        # Remove unnecessary fields from the item
+        del item_dict["evaluations"]
+        patch_operations = [{"op": "add", "path": f"/evaluations/{key}", "value": item_dict}]
 
         self.logger.info(f"Insert evaluation for LineSummary main_id: {main_id}")
         return client.patch_item(item=main_id, partition_key=main_partition_key, patch_operations=patch_operations)
 
diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_summary.py b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_summary.py
index 0e1e63ebfa5..c56684d7aa8 100644
--- a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_summary.py
+++ b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_summary.py
@@ -4,14 +4,9 @@
 
 import pytest
 
-from promptflow._constants import OK_LINE_RUN_STATUS, SpanAttributeFieldName
+from promptflow._constants import OK_LINE_RUN_STATUS, SpanAttributeFieldName, SpanStatusFieldName
 from promptflow._sdk.entities._trace import Span
-from promptflow.azure._storage.cosmosdb.summary import (
-    InsertEvaluationsRetriableException,
-    LineEvaluation,
-    Summary,
-    SummaryLine,
-)
+from promptflow.azure._storage.cosmosdb.summary import InsertEvaluationsRetriableException, Summary, SummaryLine
 
 
 @pytest.mark.unittest
@@ -48,6 +43,16 @@ def setup_data(self):
             },
         ]
         self.summary = Summary(test_span, self.FAKE_COLLECTION_ID, self.FAKE_CREATED_BY, self.FAKE_LOGGER)
+        # Just for assert purpose
+        self.summary.item = SummaryLine(
+            id="test_trace_id",
+            partition_key=self.FAKE_COLLECTION_ID,
+            collection_id=self.FAKE_COLLECTION_ID,
+            session_id="test_session_id",
+            line_run_id="line_run_id",
+            trace_id=self.summary.span.trace_id,
+            root_span_id=self.summary.span.span_id,
+        )
 
     def test_aggregate_node_span_does_not_persist(self):
         mock_client = mock.Mock()
@@ -56,30 +61,26 @@ def test_aggregate_node_span_does_not_persist(self):
         with mock.patch.multiple(
             self.summary,
             _persist_running_item=mock.DEFAULT,
-            _parse_inputs_outputs_from_events=mock.DEFAULT,
             _persist_line_run=mock.DEFAULT,
             _insert_evaluation_with_retry=mock.DEFAULT,
         ) as values:
             self.summary.persist(mock_client)
             values["_persist_running_item"].assert_not_called()
-            values["_parse_inputs_outputs_from_events"].assert_not_called()
             values["_persist_line_run"].assert_not_called()
             values["_insert_evaluation_with_retry"].assert_not_called()
 
-    def test_non_root_span_does_not_persist(self):
+    def test_non_root_span_persist_running_node(self):
         mock_client = mock.Mock()
         self.summary.span.parent_id = "parent_span_id"
         with mock.patch.multiple(
             self.summary,
             _persist_running_item=mock.DEFAULT,
-            _parse_inputs_outputs_from_events=mock.DEFAULT,
             _persist_line_run=mock.DEFAULT,
             _insert_evaluation_with_retry=mock.DEFAULT,
         ) as values:
             self.summary.persist(mock_client)
             values["_persist_running_item"].assert_called_once()
-            values["_parse_inputs_outputs_from_events"].assert_not_called()
             values["_persist_line_run"].assert_not_called()
             values["_insert_evaluation_with_retry"].assert_not_called()
 
@@ -92,17 +93,15 @@ def test_root_span_persist_main_line(self):
         with mock.patch.multiple(
             self.summary,
             _persist_running_item=mock.DEFAULT,
-            _parse_inputs_outputs_from_events=mock.DEFAULT,
             _persist_line_run=mock.DEFAULT,
             _insert_evaluation_with_retry=mock.DEFAULT,
         ) as values:
             self.summary.persist(mock_client)
             values["_persist_running_item"].assert_not_called()
-            values["_parse_inputs_outputs_from_events"].assert_called_once()
values["_persist_line_run"].assert_called_once() values["_insert_evaluation_with_retry"].assert_not_called() - def test_root_evaluation_span_insert(self): + def test_root_eval_span_persist_eval(self): mock_client = mock.Mock() self.summary.span.parent_id = None self.summary.span.attributes[SpanAttributeFieldName.LINE_RUN_ID] = "line_run_id" @@ -110,43 +109,97 @@ def test_root_evaluation_span_insert(self): with mock.patch.multiple( self.summary, _persist_running_item=mock.DEFAULT, - _parse_inputs_outputs_from_events=mock.DEFAULT, _persist_line_run=mock.DEFAULT, _insert_evaluation_with_retry=mock.DEFAULT, ) as values: self.summary.persist(mock_client) values["_persist_running_item"].assert_not_called() - values["_parse_inputs_outputs_from_events"].assert_called_once() values["_persist_line_run"].assert_called_once() values["_insert_evaluation_with_retry"].assert_called_once() - def test_insert_evaluation_not_found(self): - client = mock.Mock() + @pytest.mark.parametrize( + "run_id_dict, expected_line_run_id, expected_batch_run_id, expected_line_number", + [ + [{}, None, None, None], + [ + { + SpanAttributeFieldName.LINE_NUMBER: "1", + }, + None, + None, + None, + ], + [{SpanAttributeFieldName.BATCH_RUN_ID: "batch_run_id"}, None, None, None], + [ + { + SpanAttributeFieldName.BATCH_RUN_ID: "batch_run_id", + SpanAttributeFieldName.LINE_NUMBER: "1", + }, + None, + "batch_run_id", + "1", + ], + [{SpanAttributeFieldName.LINE_RUN_ID: "line_run_id"}, "line_run_id", None, None], + ], + ) + def test_prepare_db_item(self, run_id_dict, expected_line_run_id, expected_batch_run_id, expected_line_number): + self.summary.span.start_time = datetime.datetime.fromisoformat("2022-01-01T00:00:00") + self.summary.span.end_time = datetime.datetime.fromisoformat("2022-01-01T00:01:00") self.summary.span.attributes = { - SpanAttributeFieldName.REFERENCED_LINE_RUN_ID: "referenced_line_run_id", - SpanAttributeFieldName.LINE_RUN_ID: "line_run_id", + SpanAttributeFieldName.COMPLETION_TOKEN_COUNT: 10, + SpanAttributeFieldName.PROMPT_TOKEN_COUNT: 5, + SpanAttributeFieldName.TOTAL_TOKEN_COUNT: 15, + SpanAttributeFieldName.SPAN_TYPE: "span_type", } + self.summary.span.attributes.update(run_id_dict) + + self.summary._prepare_db_item() + + assert self.summary.item.id == self.summary.span.trace_id + assert self.summary.item.partition_key == self.summary.collection_id + assert self.summary.item.session_id == self.summary.session_id + assert self.summary.item.trace_id == self.summary.span.trace_id + assert self.summary.item.collection_id == self.summary.collection_id + assert self.summary.item.root_span_id == self.summary.span.span_id + assert self.summary.item.inputs == self.summary.inputs + assert self.summary.item.outputs == self.summary.outputs + assert self.summary.item.start_time == "2022-01-01T00:00:00" + assert self.summary.item.end_time == "2022-01-01T00:01:00" + assert self.summary.item.status == self.summary.span.status[SpanStatusFieldName.STATUS_CODE] + assert self.summary.item.latency == 60.0 + assert self.summary.item.name == self.summary.span.name + assert self.summary.item.kind == "span_type" + assert self.summary.item.cumulative_token_count == { + "completion": 10, + "prompt": 5, + "total": 15, + } + assert self.summary.item.created_by == self.summary.created_by + assert self.summary.item.line_run_id == expected_line_run_id + assert self.summary.item.batch_run_id == expected_batch_run_id + assert self.summary.item.line_number == expected_line_number - client.query_items.return_value = [] - with 
pytest.raises(InsertEvaluationsRetriableException): - self.summary._insert_evaluation(client) - client.query_items.assert_called_once() - client.patch_item.assert_not_called() - - def test_insert_evaluation_not_finished(self): + @pytest.mark.parametrize( + "return_value", + [ + [], # No item found + [{"id": "main_id"}], # Not finished + ], + ) + def test_insert_evaluation_no_action(self, return_value): client = mock.Mock() self.summary.span.attributes = { SpanAttributeFieldName.REFERENCED_LINE_RUN_ID: "referenced_line_run_id", SpanAttributeFieldName.LINE_RUN_ID: "line_run_id", } - client.query_items.return_value = [{"id": "main_id"}] + client.query_items.return_value = [] with pytest.raises(InsertEvaluationsRetriableException): self.summary._insert_evaluation(client) client.query_items.assert_called_once() client.patch_item.assert_not_called() - def test_insert_evaluation_query_line(self): + def test_insert_evaluation_query_line_run(self): client = mock.Mock() self.summary.span.attributes = { SpanAttributeFieldName.REFERENCED_LINE_RUN_ID: "referenced_line_run_id", @@ -168,18 +221,11 @@ def test_insert_evaluation_query_line(self): ], enable_cross_partition_query=True, ) + item_dict = asdict(self.summary.item) + del item_dict["evaluations"] - expected_item = LineEvaluation( - line_run_id="line_run_id", - collection_id=self.FAKE_COLLECTION_ID, - trace_id=self.summary.span.trace_id, - root_span_id=self.summary.span.span_id, - outputs=None, - name=self.summary.span.name, - created_by=self.FAKE_CREATED_BY, - ) expected_patch_operations = [ - {"op": "add", "path": f"/evaluations/{self.summary.span.name}", "value": asdict(expected_item)} + {"op": "add", "path": f"/evaluations/{self.summary.span.name}", "value": item_dict} ] client.patch_item.assert_called_once_with( item="main_id", @@ -212,17 +258,10 @@ def test_insert_evaluation_query_batch_run(self): enable_cross_partition_query=True, ) - expected_item = LineEvaluation( - batch_run_id="batch_run_id", - collection_id=self.FAKE_COLLECTION_ID, - line_number=1, - trace_id=self.summary.span.trace_id, - root_span_id=self.summary.span.span_id, - outputs=None, - name=self.summary.span.name, - created_by=self.FAKE_CREATED_BY, - ) - expected_patch_operations = [{"op": "add", "path": "/evaluations/batch_run_id", "value": asdict(expected_item)}] + item_dict = asdict(self.summary.item) + del item_dict["evaluations"] + + expected_patch_operations = [{"op": "add", "path": "/evaluations/batch_run_id", "value": item_dict}] client.patch_item.assert_called_once_with( item="main_id", partition_key="test_main_partition_key", @@ -231,81 +270,8 @@ def test_insert_evaluation_query_batch_run(self): def test_persist_line_run(self): client = mock.Mock() - self.summary.span.attributes.update( - { - SpanAttributeFieldName.LINE_RUN_ID: "line_run_id", - SpanAttributeFieldName.SPAN_TYPE: "promptflow.TraceType.Flow", - SpanAttributeFieldName.COMPLETION_TOKEN_COUNT: 10, - SpanAttributeFieldName.PROMPT_TOKEN_COUNT: 5, - SpanAttributeFieldName.TOTAL_TOKEN_COUNT: 15, - } - ) - expected_item = SummaryLine( - id="test_trace_id", - partition_key=self.FAKE_COLLECTION_ID, - collection_id=self.FAKE_COLLECTION_ID, - session_id="test_session_id", - line_run_id="line_run_id", - trace_id=self.summary.span.trace_id, - root_span_id=self.summary.span.span_id, - inputs=None, - outputs=None, - start_time="2022-01-01T00:00:00", - end_time="2022-01-01T00:01:00", - status=OK_LINE_RUN_STATUS, - latency=60.0, - name=self.summary.span.name, - kind="promptflow.TraceType.Flow", - 
created_by=self.FAKE_CREATED_BY, - cumulative_token_count={ - "completion": 10, - "prompt": 5, - "total": 15, - }, - ) - - self.summary._persist_line_run(client) - client.upsert_item.assert_called_once_with(body=asdict(expected_item)) - - def test_persist_batch_run(self): - client = mock.Mock() - self.summary.span.attributes.update( - { - SpanAttributeFieldName.BATCH_RUN_ID: "batch_run_id", - SpanAttributeFieldName.LINE_NUMBER: "1", - SpanAttributeFieldName.SPAN_TYPE: "promptflow.TraceType.Flow", - SpanAttributeFieldName.COMPLETION_TOKEN_COUNT: 10, - SpanAttributeFieldName.PROMPT_TOKEN_COUNT: 5, - SpanAttributeFieldName.TOTAL_TOKEN_COUNT: 15, - }, - ) - expected_item = SummaryLine( - id="test_trace_id", - partition_key=self.FAKE_COLLECTION_ID, - session_id="test_session_id", - collection_id=self.FAKE_COLLECTION_ID, - batch_run_id="batch_run_id", - line_number="1", - trace_id=self.summary.span.trace_id, - root_span_id=self.summary.span.span_id, - inputs=None, - outputs=None, - start_time="2022-01-01T00:00:00", - end_time="2022-01-01T00:01:00", - status=OK_LINE_RUN_STATUS, - latency=60.0, - name=self.summary.span.name, - created_by=self.FAKE_CREATED_BY, - kind="promptflow.TraceType.Flow", - cumulative_token_count={ - "completion": 10, - "prompt": 5, - "total": 15, - }, - ) - self.summary._persist_line_run(client) - client.upsert_item.assert_called_once_with(body=asdict(expected_item)) + client.upsert_item.assert_called_once_with(body=asdict(self.summary.item)) def test_insert_evaluation_with_retry_success(self): client = mock.Mock()