Skip to content

Commit

Permalink
fixed some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jun 27, 2023
1 parent 7fb8eec commit e32e9bb
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions visualize_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.data import DataLoader
from captum.attr import LRP
from captum.attr import visualization as viz
# import numpy as np
import numpy as np
from matplotlib import pyplot as plt

from kernel_eval.models import vgg11, vgg13, vgg16, vgg19, resnet34
Expand Down Expand Up @@ -66,13 +66,13 @@ def main() -> None:
# train data has the shape (batch_size, channels, width, height) -> (BATCH_SIZE, 442, 400, 400)
print("[ creating model ]")
tmp_data, _ = next(iter(test_loader))
tmp_data = augment_images(tmp_data, size=224)
in_channels = tmp_data.shape[1]

# ---------------- Load and Train Models ---------------
model_files = glob(MODEL_OUTPUT_PATH)
model_files = glob(MODEL_OUTPUT_PATH+"*") # load all models saved under the path

for model_file in model_files:
model_file = model_file.split("/")[-1]
depthwise = "depthwise" in model_file
model_type = model_file.split("_")[0]
batch_size = int(model_file.split("_")[1][:-2])
Expand All @@ -98,17 +98,18 @@ def main() -> None:
# ---------------- Evaluate Models ----------------
lrp_model = LRP(model)
img, label = next(iter(test_loader))
img = augment_images(img, size=224)
img = normalize_spectral_data(img)
attribution = lrp_model.attribute(img, target=label).cpu().detach().numpy()
# move the channel dimension to the last dimension for numpy (C,H,W) -> (H,W,C)
img = np.moveaxis(img.cpu().detach().numpy(), 1, -1)
attribution = np.moveaxis(attribution, -1, 0)

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"] # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the pred. score
# negative attribution indicates distractor areas whose absence increases the pred. score

_ = viz.visualize_image_attr_multiple(attribution, img, vis_types, vis_signs,
["Attribution", "Image"], show_colorbar = True)
["Attribution", "Image"], show_colorbar = True)
plt.savefig(f"./plots/{model_file}_lrp.png")
plt.close()

Expand Down

0 comments on commit e32e9bb

Please sign in to comment.