diff --git a/python/tests/api/logger/test_segments.py b/python/tests/api/logger/test_segments.py index 3b18eb1a5d..5f83ea17d9 100644 --- a/python/tests/api/logger/test_segments.py +++ b/python/tests/api/logger/test_segments.py @@ -210,6 +210,7 @@ def test_segment_write_roundtrip_versions(tmp_path: Any, v0) -> None: input_rows = 10 segment_column = "col3" number_of_segments = 2 + trace_id = "123-456" values_per_segment = input_rows / number_of_segments d = { "col1": [i for i in range(input_rows)], @@ -220,7 +221,7 @@ def test_segment_write_roundtrip_versions(tmp_path: Any, v0) -> None: df = pd.DataFrame(data=d) test_segments = segment_on_column(segment_column) - results: SegmentedResultSet = why.log(df, schema=DatasetSchema(segments=test_segments)) + results: SegmentedResultSet = why.log(df, trace_id=trace_id, schema=DatasetSchema(segments=test_segments)) assert results.count == number_of_segments partitions = results.partitions assert len(partitions) == 1 @@ -255,6 +256,11 @@ def test_segment_write_roundtrip_versions(tmp_path: Any, v0) -> None: post_deserialization_first_view = roundtrip_profiles[0] assert post_deserialization_first_view is not None assert isinstance(post_deserialization_first_view, DatasetProfileView) + + # check that trace_id is preserved round trip in metadata + assert post_deserialization_first_view.metadata + assert "whylabs.traceId" in post_deserialization_first_view.metadata + assert trace_id == post_deserialization_first_view.metadata["whylabs.traceId"] pre_serialization_first_view = first_segment_profile.view() pre_columns = pre_serialization_first_view.get_columns() post_columns = post_deserialization_first_view.get_columns() diff --git a/python/whylogs/api/writer/whylabs.py b/python/whylogs/api/writer/whylabs.py index c76fc44255..287d235fa4 100644 --- a/python/whylogs/api/writer/whylabs.py +++ b/python/whylogs/api/writer/whylabs.py @@ -525,7 +525,10 @@ def _write_segmented_reference_result_set(self, file: SegmentedResultSet, **kwar upload_statuses = list() for view, url in zip(files, upload_urls): with tempfile.NamedTemporaryFile() as tmp_file: - view.write(file=tmp_file) + if kwargs.get("use_v0") is None or kwargs.get("use_v0"): + view.write(file=tmp_file, use_v0=True) + else: + view.write(file=tmp_file) tmp_file.flush() tmp_file.seek(0) @@ -583,7 +586,13 @@ def write(self, file: Writable, **kwargs: Any) -> Tuple[bool, str]: self._dataset_id = kwargs.get("dataset_id") with tempfile.NamedTemporaryFile() as tmp_file: - view.write(file=tmp_file) + # currently whylabs is not ingesting the v1 format of segmented profiles as segmented + # so we default to sending them as v0 profiles if the override `use_v0` is not defined, + # if `use_v0` is defined then pass that through to control the serialization format. + if has_segments and (kwargs.get("use_v0") is None or kwargs.get("use_v0")): + view.write(file=tmp_file, use_v0=True) + else: + view.write(file=tmp_file) tmp_file.flush() tmp_file.seek(0) utc_now = datetime.datetime.now(datetime.timezone.utc) diff --git a/python/whylogs/core/view/segmented_dataset_profile_view.py b/python/whylogs/core/view/segmented_dataset_profile_view.py index 25f49315d7..d30a6048cb 100644 --- a/python/whylogs/core/view/segmented_dataset_profile_view.py +++ b/python/whylogs/core/view/segmented_dataset_profile_view.py @@ -215,8 +215,11 @@ def _write_v1(self, path: Optional[str] = None, **kwargs: Any) -> Tuple[bool, st return True, path def write(self, path: Optional[str] = None, **kwargs: Any) -> Tuple[bool, str]: - if kwargs.get("use_v0"): - logger.warning("writing segmented profile as v0 format, some info may be converted") + if kwargs.get("use_v0") or self.profile_view.model_performance_metrics: + if self.profile_view.model_performance_metrics: + logger.info("Converting segmented profile with performance metrics to v0 format before writing.") + else: + logger.info("writing segmented profile as v0 format.") return self._write_as_v0_message(path, **kwargs) else: return self._write_v1(path, **kwargs)