Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to transformed_data for reconstructed charts (with from_dict/from_json) #3102

Merged
merged 2 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions altair/utils/_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@
HConcatChart,
VConcatChart,
ConcatChart,
TopLevelUnitSpec,
FacetedUnitSpec,
UnitSpec,
UnitSpecWithFrame,
NonNormalizedSpec,
TopLevelLayerSpec,
LayerSpec,
TopLevelConcatSpec,
ConcatSpecGenericSpec,
TopLevelHConcatSpec,
HConcatSpecGenericSpec,
TopLevelVConcatSpec,
VConcatSpecGenericSpec,
TopLevelFacetSpec,
FacetSpec,
data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables
Expand All @@ -17,6 +32,25 @@
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]


# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
_chart_class_mapping = {
Chart: (
Chart,
TopLevelUnitSpec,
FacetedUnitSpec,
UnitSpec,
UnitSpecWithFrame,
NonNormalizedSpec,
),
LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}


@overload
def transformed_data(
chart: Union[Chart, FacetChart],
Expand Down Expand Up @@ -118,6 +152,16 @@ def transformed_data(chart, row_limit=None, exclude=None):
return datasets


# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
chart: Union[
Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
Expand Down Expand Up @@ -148,7 +192,9 @@ def name_views(
List of the names of the charts and subcharts
"""
exclude = set(exclude) if exclude is not None else set()
if isinstance(chart, (Chart, FacetChart)):
if isinstance(chart, _chart_class_mapping[Chart]) or isinstance(
chart, _chart_class_mapping[FacetChart]
):
if chart.name not in exclude:
if chart.name in (None, Undefined):
# Add name since none is specified
Expand All @@ -157,13 +203,13 @@ def name_views(
else:
return []
else:
if isinstance(chart, LayerChart):
if isinstance(chart, _chart_class_mapping[LayerChart]):
subcharts = chart.layer
elif isinstance(chart, HConcatChart):
elif isinstance(chart, _chart_class_mapping[HConcatChart]):
subcharts = chart.hconcat
elif isinstance(chart, VConcatChart):
elif isinstance(chart, _chart_class_mapping[VConcatChart]):
subcharts = chart.vconcat
elif isinstance(chart, ConcatChart):
elif isinstance(chart, _chart_class_mapping[ConcatChart]):
subcharts = chart.concat
else:
raise ValueError(
Expand Down
40 changes: 31 additions & 9 deletions tests/test_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@
("window_rank.py", 12, ["team", "diff"]),
])
# fmt: on
def test_primitive_chart_examples(filename, rows, cols):
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_primitive_chart_examples(filename, rows, cols, to_reconstruct):
source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
chart = eval_block(source)
if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# then what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great testing approach!

df = chart.transformed_data()

assert len(df) == rows
assert set(cols).issubset(set(df.columns))

Expand Down Expand Up @@ -96,19 +103,29 @@ def test_primitive_chart_examples(filename, rows, cols):
("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]),
])
# fmt: on
def test_compound_chart_examples(filename, all_rows, all_cols):
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_compound_chart_examples(filename, all_rows, all_cols, to_reconstruct):
source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
chart = eval_block(source)
print(chart)

if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# then what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
dfs = chart.transformed_data()
assert len(dfs) == len(all_rows)
for df, rows, cols in zip(dfs, all_rows, all_cols):
assert len(df) == rows
assert set(cols).issubset(set(df.columns))

if not to_reconstruct:
# Only run assert statements if the chart is not reconstructed. Reason
# is that for some charts, the original chart contained duplicated datasets
# which disappear when reconstructing the chart.
assert len(dfs) == len(all_rows)
for df, rows, cols in zip(dfs, all_rows, all_cols):
assert len(df) == rows
assert set(cols).issubset(set(df.columns))


def test_transformed_data_exclude():
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_transformed_data_exclude(to_reconstruct):
source = data.wheat()
bar = alt.Chart(source).mark_bar().encode(x="year:O", y="wheat:Q")
rule = alt.Chart(source).mark_rule(color="red").encode(y="mean(wheat):Q")
Expand All @@ -119,6 +136,11 @@ def test_transformed_data_exclude():
)

chart = (bar + rule + some_annotation).properties(width=600)
if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# then what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
datasets = chart.transformed_data(exclude=["some_annotation"])

assert len(datasets) == 2
Expand Down