Skip to content

Commit

Permalink
Merge pull request #680 from dianna-ai/visualization_tabular
Browse files Browse the repository at this point in the history
662 Visualization module tabular
  • Loading branch information
Yang authored Jan 10, 2024
2 parents 8b48a6e + 4ffd0cc commit cac13ed
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 68 deletions.
1 change: 1 addition & 0 deletions dianna/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa: F401
"""Tools for visualization of model explanations."""
from .image import plot_image
from .tabular import plot_tabular
from .text import highlight_text
from .timeseries import plot_timeseries
52 changes: 52 additions & 0 deletions dianna/visualization/tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Visualization module for tabular data."""
from typing import List
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np


def plot_tabular(
x: np.ndarray,
y: List[str],
x_label: str = "Importance score",
y_label: str = "Features",
num_features: Optional[int] = None,
show_plot: Optional[bool] = True,
output_filename: Optional[str] = None,
) -> plt.Figure:
"""Plot feature importance with segments highlighted.
Args:
x (np.ndarray): Array of feature importance scores
y (List[str]): List of feature names
x_label (str): Label for the x-axis
y_label (str): Label or list of labels for the y-axis
num_features (Optional[int]): Number of most salient features to display
show_plot (bool, optional): Shows plot if true (for testing or writing
plots to disk instead).
output_filename (str, optional): Name of the file to save
the plot to (optional).
Returns:
plt.Figure
"""
if not num_features:
num_features = len(x)
fig, ax = plt.subplots()
abs_values = [abs(i) for i in x]
top_values = [x for _, x in sorted(zip(abs_values, x), reverse=True)][:num_features]
top_features = [x for _, x in sorted(zip(abs_values, y), reverse=True)][
:num_features
]

colors = ["r" if x >= 0 else "b" for x in top_values]
ax.barh(top_features, top_values, color=colors)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

if show_plot:
plt.show()
if output_filename:
plt.savefig(output_filename)

return fig
2 changes: 1 addition & 1 deletion dianna/visualization/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def plot_timeseries(
x_label: str = 't',
y_label: Union[str, Iterable[str]] = None,
cmap: Optional[str] = None,
show_plot: bool = False,
show_plot: Optional[bool] = False,
output_filename: Optional[str] = None,
) -> plt.Figure:
"""Plot timeseries with segments highlighted.
Expand Down
27 changes: 20 additions & 7 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
"""Unit tests for visualization modules."""
from pathlib import Path
import numpy as np
import pytest
from dianna.visualization import plot_tabular
from dianna.visualization import plot_timeseries


def test_plot_tabular(tmpdir):
"""Test plot tabular data."""
x = np.linspace(-5, 5, 3)
y = [f"Feature {i}" for i in range(len(x))]
output_path = Path(tmpdir) / "temp_visualization_test_tabular.png"

plot_tabular(x=x, y=y, show_plot=False, output_filename=output_path)

assert output_path.exists()


def test_plot_timeseries_univariate(tmpdir, random):
"""Test plot univariate time series."""
x = np.linspace(0, 10, 20)
y = np.sin(x)
segments = get_test_segments(data=np.expand_dims(y, 0))

output_path = Path(tmpdir) / 'temp_visualization_test_univariate.png'
output_path = Path(tmpdir) / "temp_visualization_test_univariate.png"

plot_timeseries(x=x,
y=y,
Expand All @@ -26,7 +39,7 @@ def test_plot_timeseries_multivariate(tmpdir, random):
x = np.linspace(start=0, stop=10, num=20)
ys = np.stack((np.sin(x), np.cos(x), np.tan(0.4 * x)))
segments = get_test_segments(data=ys)
output_path = Path(tmpdir) / 'temp_visualization_test_multivariate.png'
output_path = Path(tmpdir) / "temp_visualization_test_multivariate.png"

plot_timeseries(x=x,
y=ys.T,
Expand All @@ -48,13 +61,13 @@ def get_test_segments(data):
for i_segment in range(n_segments):
for i_channel in range(n_channels):
segment = {
'index': i_segment + i_channel * n_segments,
'start': i_segment,
'stop': i_segment + 1,
'weight': data[i_channel, factor * i_segment],
"index": i_segment + i_channel * n_segments,
"start": i_segment,
"stop": i_segment + 1,
"weight": data[i_channel, factor * i_segment],
}
if n_channels > 1:
segment['channel'] = i_channel
segment["channel"] = i_channel
segments.append(segment)

return segments
Expand Down
33 changes: 7 additions & 26 deletions tutorials/lime_tabular_penguin.ipynb

Large diffs are not rendered by default.

49 changes: 15 additions & 34 deletions tutorials/lime_tabular_weather.ipynb

Large diffs are not rendered by default.

0 comments on commit cac13ed

Please sign in to comment.