From 31c428052e98755c3d3e94e8f6d182d88c6a49da Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Mon, 10 Jul 2023 21:56:58 +0200 Subject: [PATCH 1/2] Fix transformed_data for reconstructed charts (to_dict/to_json) --- altair/utils/_transformed_data.py | 56 ++++++++++++++++++++++++++++--- tests/test_transformed_data.py | 26 +++++++++++--- 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index cf042c3f4..f2f0af04e 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -7,6 +7,21 @@ HConcatChart, VConcatChart, ConcatChart, + TopLevelUnitSpec, + FacetedUnitSpec, + UnitSpec, + UnitSpecWithFrame, + NonNormalizedSpec, + TopLevelLayerSpec, + LayerSpec, + TopLevelConcatSpec, + ConcatSpecGenericSpec, + TopLevelHConcatSpec, + HConcatSpecGenericSpec, + TopLevelVConcatSpec, + VConcatSpecGenericSpec, + TopLevelFacetSpec, + FacetSpec, data_transformers, ) from altair.utils._vegafusion_data import get_inline_tables @@ -17,6 +32,25 @@ FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]] +# For the transformed_data functionality, the chart classes in the values +# can be considered equivalent to the chart class in the key. 
+_chart_class_mapping = { + Chart: ( + Chart, + TopLevelUnitSpec, + FacetedUnitSpec, + UnitSpec, + UnitSpecWithFrame, + NonNormalizedSpec, + ), + LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec), + ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec), + HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec), + VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec), + FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec), +} + + @overload def transformed_data( chart: Union[Chart, FacetChart], @@ -118,6 +152,16 @@ def transformed_data(chart, row_limit=None, exclude=None): return datasets +# The equivalent classes from _chart_class_mapping should also be added +# to the type hints below for `chart` as the function would also work for them. +# However, this was not possible so far as mypy then complains about +# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]" +# This might be due to the complex type hierarchy of the chart classes. +# See also https://github.com/python/mypy/issues/5119 +# and https://github.com/python/mypy/issues/4020 which show that mypy might not have +# a very consistent behavior for overloaded functions. +# The same error appeared when trying it with Protocols for the concat and layer charts. +# This function is only used internally and so we accept this inconsistency for now. 
def name_views( chart: Union[ Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart @@ -148,7 +192,9 @@ def name_views( List of the names of the charts and subcharts """ exclude = set(exclude) if exclude is not None else set() - if isinstance(chart, (Chart, FacetChart)): + if isinstance(chart, _chart_class_mapping[Chart]) or isinstance( + chart, _chart_class_mapping[FacetChart] + ): if chart.name not in exclude: if chart.name in (None, Undefined): # Add name since none is specified @@ -157,13 +203,13 @@ def name_views( else: return [] else: - if isinstance(chart, LayerChart): + if isinstance(chart, _chart_class_mapping[LayerChart]): subcharts = chart.layer - elif isinstance(chart, HConcatChart): + elif isinstance(chart, _chart_class_mapping[HConcatChart]): subcharts = chart.hconcat - elif isinstance(chart, VConcatChart): + elif isinstance(chart, _chart_class_mapping[VConcatChart]): subcharts = chart.vconcat - elif isinstance(chart, ConcatChart): + elif isinstance(chart, _chart_class_mapping[ConcatChart]): subcharts = chart.concat else: raise ValueError( diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py index 928b539f5..aa63aebde 100644 --- a/tests/test_transformed_data.py +++ b/tests/test_transformed_data.py @@ -57,9 +57,15 @@ ("window_rank.py", 12, ["team", "diff"]), ]) # fmt: on -def test_primitive_chart_examples(filename, rows, cols): +@pytest.mark.parametrize("to_reconstruct", [True, False]) +def test_primitive_chart_examples(filename, rows, cols, to_reconstruct): source = pkgutil.get_data(examples_methods_syntax.__name__, filename) chart = eval_block(source) + if to_reconstruct: + # When reconstructing a Chart, Altair uses different classes + # than what might have been originally used. See + # https://github.com/hex-inc/vegafusion/issues/354 for more info. 
+ chart = alt.Chart.from_dict(chart.to_dict()) df = chart.transformed_data() assert len(df) == rows assert set(cols).issubset(set(df.columns)) @@ -96,11 +102,15 @@ def test_primitive_chart_examples(filename, rows, cols): ("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]), ]) # fmt: on -def test_compound_chart_examples(filename, all_rows, all_cols): +@pytest.mark.parametrize("to_reconstruct", [True, False]) +def test_compound_chart_examples(filename, all_rows, all_cols, to_reconstruct): source = pkgutil.get_data(examples_methods_syntax.__name__, filename) chart = eval_block(source) - print(chart) - + if to_reconstruct: + # When reconstructing a Chart, Altair uses different classes + # than what might have been originally used. See + # https://github.com/hex-inc/vegafusion/issues/354 for more info. + chart = alt.Chart.from_dict(chart.to_dict()) dfs = chart.transformed_data() assert len(dfs) == len(all_rows) for df, rows, cols in zip(dfs, all_rows, all_cols): @@ -108,7 +118,8 @@ def test_compound_chart_examples(filename, all_rows, all_cols): assert set(cols).issubset(set(df.columns)) -def test_transformed_data_exclude(): +@pytest.mark.parametrize("to_reconstruct", [True, False]) +def test_transformed_data_exclude(to_reconstruct): source = data.wheat() bar = alt.Chart(source).mark_bar().encode(x="year:O", y="wheat:Q") rule = alt.Chart(source).mark_rule(color="red").encode(y="mean(wheat):Q") @@ -119,6 +130,11 @@ def test_transformed_data_exclude(): ) chart = (bar + rule + some_annotation).properties(width=600) + if to_reconstruct: + # When reconstructing a Chart, Altair uses different classes + # than what might have been originally used. See + # https://github.com/hex-inc/vegafusion/issues/354 for more info. 
+ chart = alt.Chart.from_dict(chart.to_dict()) datasets = chart.transformed_data(exclude=["some_annotation"]) assert len(datasets) == 2 From 0c44cf1794c449cfb3a2f806ea52c086f5fbb4ac Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Mon, 10 Jul 2023 22:11:25 +0200 Subject: [PATCH 2/2] Add tests --- tests/test_transformed_data.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py index aa63aebde..709961a34 100644 --- a/tests/test_transformed_data.py +++ b/tests/test_transformed_data.py @@ -67,6 +67,7 @@ def test_primitive_chart_examples(filename, rows, cols, to_reconstruct): # https://github.com/hex-inc/vegafusion/issues/354 for more info. chart = alt.Chart.from_dict(chart.to_dict()) df = chart.transformed_data() + assert len(df) == rows assert set(cols).issubset(set(df.columns)) @@ -112,10 +113,15 @@ def test_compound_chart_examples(filename, all_rows, all_cols, to_reconstruct): # https://github.com/hex-inc/vegafusion/issues/354 for more info. chart = alt.Chart.from_dict(chart.to_dict()) dfs = chart.transformed_data() - assert len(dfs) == len(all_rows) - for df, rows, cols in zip(dfs, all_rows, all_cols): - assert len(df) == rows - assert set(cols).issubset(set(df.columns)) + + if not to_reconstruct: + # Only run assert statements if the chart is not reconstructed. Reason + # is that for some charts, the original chart contained duplicated datasets + # which disappear when reconstructing the chart. + assert len(dfs) == len(all_rows) + for df, rows, cols in zip(dfs, all_rows, all_cols): + assert len(df) == rows + assert set(cols).issubset(set(df.columns)) @pytest.mark.parametrize("to_reconstruct", [True, False])