merge main and resolve conflicts
Yang committed Jan 11, 2024
2 parents 8fac781 + cac13ed commit a223af2
Showing 15 changed files with 321 additions and 250 deletions.
21 changes: 18 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -231,6 +231,13 @@ DIANNA comes with simple datasets. Their main goal is to provide intuitive insig
| [Coffee dataset](https://timeseriesclassification.com/description.php?Dataset=Coffee) <img width="25" alt="Coffee Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectrograph time series dataset for a two-class problem: distinguishing between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75)"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
| [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations from 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d)"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |

### Tabular

| Dataset | Description | Examples | Generation |
| :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------ |
| [Penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris) <img width="75" alt="Penguins Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/c7716ad3-f992-4557-80d9-1d8178c7ed57)"> | The Palmer Archipelago (Antarctica) penguin dataset is a good introductory dataset for data exploration and visualization, similar to the well-known Iris dataset. | <img width="500" alt="example image" src="https://github.com/allisonhorst/palmerpenguins/blob/main/man/figures/README-mass-flipper-1.png"> | [data source](https://github.com/allisonhorst/palmerpenguins) |
| [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations from 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d)"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |

## ONNX models

<!-- TODO: Add all links, see issue https://github.com/dianna-ai/dianna/issues/135 -->
@@ -267,6 +274,14 @@ And here are links to notebooks showing how we created our models on the benchma
| Coffee model | [Coffee model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/coffee/generate_model.ipynb) |
| [Season prediction model](https://zenodo.org/record/7543883) | [Season prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/season_prediction/generate_model.ipynb) |

### Tabular

| Models | Generation |
| :-------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Penguin model (classification) | [Penguin model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/penguin_species/generate_model.ipynb) |
| Sunshine hours prediction model (regression) | [Sunshine hours prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb) |


**_We envision the birth of the ONNX Scientific models zoo soon..._**

## Tutorials
@@ -279,9 +294,9 @@ DIANNA supports different data modalities and XAI methods. The table contains li
| :--------- | :------------------------------------------------ | :----------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------- |
| Images ||||
| Text ||| |
| Timeseries ||| |
| Embedding | planned | planned | planned |
| Tabular | planned | planned | planned |
| Timeseries ||| |
| Tabular | planned | | planned |
| Embedding | planned | planned | planned |
| Graphs* | work in progress | work in progress | work in progress |

[LRP](https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0130140&type=printable) and [PatternAttribution](https://arxiv.org/pdf/1705.05598.pdf) also feature in the top 5 of our thoroughly evaluated XAI methods using objective criteria (details in coming blog-post). **Contributing by adding these and more (new) post-hoc explainability methods on ONNX models is very welcome!**
1 change: 1 addition & 0 deletions dianna/visualization/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa: F401
"""Tools for visualization of model explanations."""
from .image import plot_image
from .tabular import plot_tabular
from .text import highlight_text
from .timeseries import plot_timeseries
52 changes: 52 additions & 0 deletions dianna/visualization/tabular.py
@@ -0,0 +1,52 @@
"""Visualization module for tabular data."""
from typing import List
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np


def plot_tabular(
    x: np.ndarray,
    y: List[str],
    x_label: str = "Importance score",
    y_label: str = "Features",
    num_features: Optional[int] = None,
    show_plot: Optional[bool] = True,
    output_filename: Optional[str] = None,
) -> plt.Figure:
    """Plot feature importance scores as a horizontal bar chart.

    Args:
        x (np.ndarray): Array of feature importance scores
        y (List[str]): List of feature names
        x_label (str): Label for the x-axis
        y_label (str): Label for the y-axis
        num_features (Optional[int]): Number of most salient features to
                                      display (all features if not set)
        show_plot (bool, optional): Shows plot if true (for testing or writing
                                    plots to disk instead).
        output_filename (str, optional): Name of the file to save
                                         the plot to (optional).

    Returns:
        plt.Figure
    """
    if not num_features:
        num_features = len(x)
    fig, ax = plt.subplots()
    # Rank by absolute score using indices, so that each feature name stays
    # paired with its own score even when two scores tie in magnitude.
    order = np.argsort(-np.abs(np.asarray(x)), kind="stable")[:num_features]
    top_values = [x[i] for i in order]
    top_features = [y[i] for i in order]

    colors = ["r" if value >= 0 else "b" for value in top_values]
    ax.barh(top_features, top_values, color=colors)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Save before showing: with interactive backends the figure may be
    # closed by plt.show(), which would leave an empty file on disk.
    if output_filename:
        fig.savefig(output_filename)
    if show_plot:
        plt.show()

    return fig
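
`plot_tabular` ranks features by the absolute value of their importance score before plotting. The ranking step can be sketched in isolation, without matplotlib, as below. The helper name `top_features_by_magnitude` and the sample scores and feature names are hypothetical and not part of DIANNA; the sketch sorts indices rather than values, which is one way to keep each name attached to its score when two scores tie in magnitude.

```python
from typing import List, Optional, Sequence, Tuple


def top_features_by_magnitude(
    scores: Sequence[float],
    names: Sequence[str],
    num_features: Optional[int] = None,
) -> List[Tuple[str, float]]:
    """Return (name, score) pairs ordered by descending absolute score.

    Hypothetical helper for illustration only; not part of DIANNA.
    """
    if len(scores) != len(names):
        raise ValueError("scores and names must have the same length")
    if num_features is None:
        num_features = len(scores)
    # Sort the indices, not the values: names[i] and scores[i] are then
    # looked up together, so a tie in |score| cannot swap the pairing.
    order = sorted(range(len(scores)),
                   key=lambda i: abs(scores[i]),
                   reverse=True)
    return [(names[i], scores[i]) for i in order[:num_features]]


# Made-up importance scores for three made-up feature names.
ranked = top_features_by_magnitude(
    [0.1, -0.7, 0.3],
    ["bill_length", "flipper_length", "body_mass"],
    num_features=2,
)
print(ranked)
```

The negative score ranks first here because only its magnitude is compared; the sign is kept so that a plotting step can still colour positive and negative contributions differently.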
2 changes: 1 addition & 1 deletion dianna/visualization/timeseries.py
@@ -16,7 +16,7 @@ def plot_timeseries(
     x_label: str = 't',
     y_label: Union[str, Iterable[str]] = None,
     cmap: Optional[str] = None,
-    show_plot: bool = False,
+    show_plot: Optional[bool] = False,
     output_filename: Optional[str] = None,
 ) -> plt.Figure:
     """Plot timeseries with segments highlighted.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Unit tests for KernelSHAP image."""
from unittest import TestCase
import numpy as np
from dianna.methods.kernelshap_image import KERNELSHAPImage


class ShapOnImages(TestCase):
    """Suite of Kernelshap tests for the image case."""

    def test_shap_segment_image(self):
        """Test if the segmentation of images is correct given some data."""
        input_data = np.random.random((28, 28, 1))
@@ -41,7 +43,10 @@ def test_shap_mask_image(self):
             sigma,
         )
         masked_image = explainer._mask_image(
-            np.zeros((1, n_segments)), segments_slic, input_data, background,
+            np.zeros((1, n_segments)),
+            segments_slic,
+            input_data,
+            background,
         )
         # check if all points are masked
         assert np.array_equal(masked_image[0], np.zeros(input_data.shape))
67 changes: 67 additions & 0 deletions tests/methods/test_lime_image.py
@@ -0,0 +1,67 @@
"""Unit tests for LIME image."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.lime_image import LIMEImage
from tests.methods.test_onnx_runner import generate_data
from tests.utils import run_model


class LimeOnImages(TestCase):
    """Suite of Lime tests for the image case."""

    @staticmethod
    def test_lime_function():
        """Test if lime runs and outputs are correct given some data and a model function."""
        input_data = np.random.random((224, 224, 3))
        heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')
        labels = [1]

        explainer = LIMEImage(random_state=42)
        heatmap = explainer.explain(run_model,
                                    input_data,
                                    labels,
                                    num_samples=100)

        assert heatmap[0].shape == input_data.shape[:2]
        assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

    @staticmethod
    def test_lime_filename():
        """Test if lime runs and outputs are correct given some data and a model file."""
        model_filename = 'tests/test_data/mnist_model.onnx'
        input_data = generate_data(batch_size=1)[0].astype(np.float32)
        axis_labels = ('channels', 'y', 'x')
        labels = [1]

        heatmap = dianna.explain_image(model_filename,
                                       input_data,
                                       method='LIME',
                                       labels=labels,
                                       random_state=42,
                                       axis_labels=axis_labels)

        heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
        assert heatmap[0].shape == input_data[0].shape
        assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

    @staticmethod
    def test_lime_values():
        """Test if get_explanation_values function works correctly."""
        input_data = np.random.random((224, 224, 3))
        heatmap_expected = np.load('tests/test_data/heatmap_lime_values.npy')
        labels = [1]

        explainer = LIMEImage(random_state=42)
        heatmap = explainer.explain(run_model,
                                    input_data,
                                    labels,
                                    return_masks=False,
                                    num_samples=100)

        assert heatmap[0].shape == input_data.shape[:2]
        assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

    def setUp(self) -> None:
        """Set seed."""
        np.random.seed(42)
86 changes: 13 additions & 73 deletions tests/test_lime.py → tests/methods/test_lime_text.py
@@ -1,73 +1,10 @@
"""Unit tests for LIME text."""
from unittest import TestCase
import numpy as np
import pytest
import dianna
import dianna.visualization
from dianna.methods.lime_image import LIMEImage
from tests.test_onnx_runner import generate_data
from tests.utils import assert_explanation_satisfies_expectations
from tests.utils import load_movie_review_model
from tests.utils import run_model


class LimeOnImages(TestCase):
"""Suite of Lime tests for the image case."""

@staticmethod
def test_lime_function():
"""Test if lime runs and outputs are correct given some data and a model function."""
input_data = np.random.random((224, 224, 3))
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')
labels = [1]

explainer = LIMEImage(random_state=42)
heatmap = explainer.explain(run_model,
input_data,
labels,
num_samples=100)

assert heatmap[0].shape == input_data.shape[:2]
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

@staticmethod
def test_lime_filename():
"""Test if lime runs and outputs are correct given some data and a model file."""
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1)[0].astype(np.float32)
axis_labels = ('channels', 'y', 'x')
labels = [1]

heatmap = dianna.explain_image(model_filename,
input_data,
method='LIME',
labels=labels,
random_state=42,
axis_labels=axis_labels)

heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
assert heatmap[0].shape == input_data[0].shape
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

@staticmethod
def test_lime_values():
"""Test if get_explanation_values function works correctly."""
input_data = np.random.random((224, 224, 3))
heatmap_expected = np.load('tests/test_data/heatmap_lime_values.npy')
labels = [1]

explainer = LIMEImage(random_state=42)
heatmap = explainer.explain(run_model,
input_data,
labels,
return_masks=False,
num_samples=100)

assert heatmap[0].shape == input_data.shape[:2]
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)

def setUp(self) -> None:
"""Set seed."""
np.random.seed(42)


class LimeOnText(TestCase):
@@ -165,21 +102,24 @@ def tokenizer():
     ('UNKWORDZ a bad UNKWORDZ UNKWORDZ!?\'"', 9),
     ('such UNKWORDZ UNKWORDZ movie "UNKWORDZUNKWORDZ\'UNKWORDZ', 9),
     ('such a bad UNKWORDZ UNKWORDZ!UNKWORDZ\'UNKWORDZ', 9),
-    pytest.param('its own self-UNKWORDZ universe.', 7,
+    pytest.param('its own self-UNKWORDZ universe.',
+                 7,
                  marks=pytest.mark.xfail(reason='poor handling of -')),
-    pytest.param('its own UNKWORDZ-contained universe.', 7,
+    pytest.param('its own UNKWORDZ-contained universe.',
+                 7,
                  marks=pytest.mark.xfail(reason='poor handling of -')),
-    pytest.param('Backslashes are UNKWORDZ/cool.', 6,
+    pytest.param('Backslashes are UNKWORDZ/cool.',
+                 6,
                  marks=pytest.mark.xfail(reason='/ poor handling of /')),
-    pytest.param('Backslashes are fun/UNKWORDZ.', 6,
+    pytest.param('Backslashes are fun/UNKWORDZ.',
+                 6,
                  marks=pytest.mark.xfail(reason='poor handling of /')),
-    pytest.param(' ', 0,
-                 marks=pytest.mark.xfail(reason='Repeated whitespaces')),
-    pytest.param('I like whitespaces.', 4,
+    pytest.param(
+        ' ', 0, marks=pytest.mark.xfail(reason='Repeated whitespaces')),
+    pytest.param('I like whitespaces.',
+                 4,
                  marks=pytest.mark.xfail(reason='Repeated whitespaces')),
 ])
-
-
 def test_spacytokenizer_length(text, length, tokenizer):
     """Test that tokenizer returns strings of the correct length."""
     tokens = tokenizer.tokenize(text)
Original file line number Diff line number Diff line change
@@ -4,7 +4,8 @@

 def generate_data(batch_size):
     """Generate a batch of random data."""
-    return np.random.randint(0, 256, size=(batch_size, 1, 28, 28))  # MNIST shape
+    return np.random.randint(0, 256,
+                             size=(batch_size, 1, 28, 28))  # MNIST shape


 def test_onnx_runner():