added final plotting implementation for LRP
LostOxygen committed Jul 3, 2023
1 parent ad40d5d commit 994c855
Showing 2 changed files with 145 additions and 35 deletions.
59 changes: 58 additions & 1 deletion kernel_eval/utils.py
@@ -1,15 +1,23 @@
"""utility library for various functions"""
import os
import random
from enum import Enum
from glob import glob
from datetime import datetime
from typing import List, Tuple
from typing import List, Tuple, Optional, Union
import torch
from torch import nn
import numpy as np
from matplotlib import pyplot as plt


class VisualizeSign(Enum):
POSITIVE = 1
ABSOLUTE = 2
NEGATIVE = 3
ALL = 4


def save_model(model_path: str, model_name: str, depthwise: bool,
batch_size: int, lr: float, epochs: int, model: nn.Module) -> None:
"""
@@ -209,3 +217,52 @@ def normalize_spectral_data(img: torch.Tensor) -> torch.Tensor:
img -= mean

return img


def normalize_attribute(
attr: np.ndarray,
sign: str,
outlier_perc: Union[int, float] = 2,
reduction_axis: Optional[int] = None,
):
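    """
    Rescale an attribution map to [-1, 1] for visualization.

    The map is optionally summed over ``reduction_axis``; values are then
    selected according to ``sign`` ("POSITIVE", "NEGATIVE", "ABSOLUTE" or
    "ALL") and divided by a cumulative-sum threshold that treats the top
    ``outlier_perc`` percent of (absolute) values as outliers.
    """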
attr_combined = attr
if reduction_axis is not None:
attr_combined = np.sum(attr, axis=reduction_axis)

# Choose appropriate signed values and rescale, removing given outlier percentage.
if VisualizeSign[sign] == VisualizeSign.ALL:
threshold = cumulative_sum_threshold(
np.abs(attr_combined), 100 - outlier_perc)
elif VisualizeSign[sign] == VisualizeSign.POSITIVE:
attr_combined = (attr_combined > 0) * attr_combined
threshold = cumulative_sum_threshold(
attr_combined, 100 - outlier_perc)
elif VisualizeSign[sign] == VisualizeSign.NEGATIVE:
attr_combined = (attr_combined < 0) * attr_combined
threshold = -1 * cumulative_sum_threshold(
np.abs(attr_combined), 100 - outlier_perc
)
elif VisualizeSign[sign] == VisualizeSign.ABSOLUTE:
attr_combined = np.abs(attr_combined)
threshold = cumulative_sum_threshold(
attr_combined, 100 - outlier_perc)
else:
raise AssertionError("Visualize Sign type is not valid.")
return normalize_scale(attr_combined, threshold)


def normalize_scale(attr: np.ndarray, scale_factor: float):
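    """Divide ``attr`` by ``scale_factor`` and clip the result to [-1, 1]."""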
assert scale_factor != 0, "Cannot normalize by scale factor = 0"
attr_norm = attr / scale_factor
return np.clip(attr_norm, -1, 1)


def cumulative_sum_threshold(values: np.ndarray, percentile: Union[int, float]):
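    """Return the smallest sorted value at which the cumulative sum reaches
    ``percentile`` percent of the total sum of ``values``."""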
# given values should be non-negative
    assert 0 <= percentile <= 100, (
        "Percentile for thresholding must be between 0 and 100 inclusive."
    )
sorted_vals = np.sort(values.flatten())
cum_sums = np.cumsum(sorted_vals)
threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
return sorted_vals[threshold_id]
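
A minimal usage sketch for the new helper (the array shape and values below are hypothetical and only illustrate the call signature):

import numpy as np
from kernel_eval.utils import normalize_attribute

# hypothetical attribution map of shape (H, W, C)
attr = np.random.randn(8, 8, 4).astype(np.float32)

# sum over the channel axis (axis 2), keep both signs ("ALL") and rescale to
# [-1, 1], treating the top 2 percent of absolute values as outliers
attr_norm = normalize_attribute(attr, "ALL", outlier_perc=2, reduction_axis=2)

print(attr_norm.shape)                    # (8, 8)
print(attr_norm.min(), attr_norm.max())   # both within [-1, 1]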
121 changes: 87 additions & 34 deletions visualize_lrp.py
@@ -10,14 +10,15 @@
import torch
from torch.utils.data import DataLoader
from captum.attr import LRP
from captum.attr import visualization as viz
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from kernel_eval.models import vgg11, vgg13, vgg16, vgg19, resnet34
from kernel_eval.datasets import SingleFileDataset
from kernel_eval.datasets import SingleFileDatasetLoadingOptions
from kernel_eval.utils import load_model
from kernel_eval.utils import load_model, normalize_attribute

DATA_PATHS: Final[List[str]] = ["/prodi/hpcmem/spots_ftir/LC704/",
"/prodi/hpcmem/spots_ftir/BC051111/",
@@ -43,6 +44,8 @@ def main() -> None:
None
"""
device = "cpu"
if not os.path.exists("plots/"):
os.mkdir("plots/")

print("\n\n\n"+"#"*60)
print("## " + str(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")))
@@ -60,7 +63,7 @@
augment=True,
normalize=True)

test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=True, num_workers=1)

# load a single image to get the input shape
# train data has the shape (batch_size, channels, width, height) -> (BATCH_SIZE, 442, 400, 400)
@@ -70,67 +73,117 @@

# ---------------- Load and Train Models ---------------
model_files: List[str] = glob(MODEL_OUTPUT_PATH+"*") # load all models saved under the path
    # filter depthwise models out since they will be compared separately
model_files = list(filter(lambda key: "depthwise" not in key, model_files))

for model_file in model_files:
# obtain the model type and hyperparameters from the model file name
model_file = model_file.split("/")[-1]
depthwise = "depthwise" in model_file
model_type = model_file.split("_")[0]
batch_size = int(model_file.split("_")[1][:-2])
learning_rate = float(model_file.split("_")[2][:-2])
epochs = int(model_file.split("_")[3][:-2])
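        # e.g. a hypothetical file name "vgg16_64bs_0.001lr_100ep" (assumed naming
        # convention, not taken from the repository) would yield model_type="vgg16",
        # batch_size=64, learning_rate=0.001 and epochs=100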

# create the normal model
match model_type:
case "vgg11": model = vgg11(in_channels=in_channels, depthwise=depthwise, num_classes=1)
case "vgg13": model = vgg13(in_channels=in_channels, depthwise=depthwise, num_classes=1)
case "vgg16": model = vgg16(in_channels=in_channels, depthwise=depthwise, num_classes=1)
case "vgg19": model = vgg19(in_channels=in_channels, depthwise=depthwise, num_classes=1)
case "vgg11": model_normal = vgg11(in_channels=in_channels,
depthwise=False, num_classes=1)
case "vgg13": model_normal = vgg13(in_channels=in_channels,
depthwise=False, num_classes=1)
case "vgg16": model_normal = vgg16(in_channels=in_channels,
depthwise=False, num_classes=1)
case "vgg19": model_normal = vgg19(in_channels=in_channels,
depthwise=False, num_classes=1)
case "resnet34": model_normal = resnet34(in_channels=in_channels,
depthwise=False, num_classes=1)
case _: raise ValueError(f"Model {model_type} not supported")

case "resnet34": model = resnet34(in_channels=in_channels,
depthwise=depthwise, num_classes=1)
# create the depthwise counterpart
match model_type:
case "vgg11": model_depth = vgg11(in_channels=in_channels,
depthwise=True, num_classes=1)
case "vgg13": model_depth = vgg13(in_channels=in_channels,
depthwise=True, num_classes=1)
case "vgg16": model_depth = vgg16(in_channels=in_channels,
depthwise=True, num_classes=1)
case "vgg19": model_depth = vgg19(in_channels=in_channels,
depthwise=True, num_classes=1)
case "resnet34": model_depth = resnet34(in_channels=in_channels,
depthwise=True, num_classes=1)
case _: raise ValueError(f"Model {model_type} not supported")

model = model.to(device)
# load the models
model_normal = model_normal.to(device)
model_normal = load_model(MODEL_OUTPUT_PATH, model_type, False,
batch_size, learning_rate, epochs, model_normal)
model_normal.eval()
model_normal.zero_grad()

model = load_model(MODEL_OUTPUT_PATH, model_type, depthwise,
batch_size, learning_rate, epochs, model)
model.eval()
model.zero_grad()

model_depth = model_depth.to(device)
model_depth = load_model(MODEL_OUTPUT_PATH, model_type, True,
batch_size, learning_rate, epochs, model_depth)
model_depth.eval()
model_depth.zero_grad()

# ---------------- Evaluate Models ----------------
lrp_model = LRP(model)
lrp_model_normal = LRP(model_normal)
lrp_model_depth = LRP(model_depth)
img, _ = next(iter(test_loader))
img.requires_grad = True

print(f"[ analyze model: {model_file} ]")
attribution = lrp_model.attribute(img, target=None).cpu().detach().numpy()[0]
attribution_normal = lrp_model_normal.attribute(img, target=None).cpu().detach().numpy()[0]
attribution_depth = lrp_model_depth.attribute(img, target=None).cpu().detach().numpy()[0]

img = img.cpu().detach().numpy()[0]
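        # collapse the 442 spectral channels of the input to a single channel so
        # the image can be shown as one 2D plot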
img = np.mean(img, axis=0)
img = np.expand_dims(img, axis=0)

# add the attributions of every channel dimension up
attribution = np.sum(attribution, axis=0)
attribution = np.expand_dims(attribution, axis=0)
attribution_normal = np.sum(attribution_normal, axis=0)
attribution_normal = np.expand_dims(attribution_normal, axis=0)
attribution_depth = np.sum(attribution_depth, axis=0)
attribution_depth = np.expand_dims(attribution_depth, axis=0)

# move the channel dimension to the last dimension for numpy (C,H,W) -> (H,W,C)
        # after summing over channels the shape is (1, 400, 400) -> (400, 400, 1)
img = np.moveaxis(img, 0, -1)
attribution = np.moveaxis(attribution, 0, -1)

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"] # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the pred. score
# negative attribution indicates distractor areas whose absence increases the pred. score

fig, _ = viz.visualize_image_attr_multiple(attribution, img, vis_types, vis_signs,
["Attribution", "Image"], show_colorbar = True)

if not os.path.exists("plots/"):
os.mkdir("plots/")
attribution_normal = np.moveaxis(attribution_normal, 0, -1)
attribution_depth = np.moveaxis(attribution_depth, 0, -1)
attribution_normal = normalize_attribute(attribution_normal, "ALL", 2, 2)
attribution_depth = normalize_attribute(attribution_depth, "ALL", 2, 2)

cmap = LinearSegmentedColormap.from_list(
"RdWhGn", ["red", "white", "green"]
)

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
#axis_separator = make_axes_locatable(axes[2])
#colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
norm = mpl.colors.Normalize(vmin=-1, vmax=1)
plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
orientation="vertical", ax=axes,
label="Attribution Values")

axes[0].grid()
axes[0].imshow(img)
axes[0].set_title("Input Image")
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[1].grid()
axes[1].imshow(attribution_normal, cmap=cmap, vmin=-1, vmax=1)
axes[1].set_title("Attribution Normal")
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[2].grid()
axes[2].imshow(attribution_depth, cmap=cmap, vmin=-1, vmax=1)
axes[2].set_title("Attribution Depthwise")
axes[2].set_xticks([])
axes[2].set_yticks([])

plt.title(f"LRP: {model_file}")
fig.savefig(f"./plots/{model_file}_lrp.png")
plt.close()

print("[ finished LRP analysis ]")

if __name__ == "__main__":
main()
