Skip to content

Commit

Permalink
improved plot quality
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jul 10, 2023
1 parent ca1ccfb commit c8b0ee4
Showing 1 changed file with 49 additions and 44 deletions.
93 changes: 49 additions & 44 deletions visualize_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket
import datetime
import os
import argparse
from glob import glob
from typing import Final, List

Expand All @@ -30,13 +31,13 @@
MODEL_OUTPUT_PATH: Final[str] = "./models/"


def main() -> None:
def main(has_cancer: bool) -> None:
"""
Load all pre-trained models under {MODEL_OUTPUT_PATH} and evaluate them on the test set to
visualize the Layer Relevance Propagation (LRP).
Parameters:
None
has_cancer: bool - type of cancer example to visualize (positive or negative)
Returns:
None
Expand Down Expand Up @@ -127,7 +128,14 @@ def main() -> None:
# ---------------- Evaluate Models ----------------
lrp_model_normal = LRP(model_normal)
lrp_model_depth = LRP(model_depth)
img, _ = next(iter(test_loader))
img: torch.Tensor = None
# search for an example of the given type
for _, (curr_img, curr_label) in enumerate(test_loader):
if bool(curr_label) == has_cancer:
img = curr_img
break

assert img is not None, f"Could not find an example with cancer={has_cancer}"
img.requires_grad = True

print(f"[ analyze model: {model_file} ]")
Expand Down Expand Up @@ -173,54 +181,51 @@ def main() -> None:
"RdWhGn", ["red", "white", "green"]
)
norm = mpl.colors.Normalize(vmin=-1, vmax=1)
mpl.rcParams.update({"font.size": 20})
mpl.rcParams.update({"axes.titlesize": 20})
mpl.rcParams.update({"axes.labelsize": 15})

# create and save the normal LRP analysis
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
fig.suptitle(f"LRP Analysis - {model_type}")
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle(f"LRP Analysis - Model: {model_type} - Cancer: {has_cancer}")
plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
orientation="vertical", ax=axes,
label="Attribution Values")
axes[0].grid()
axes[0].imshow(img)
axes[0].set_title("Input Image")
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[1].grid()
axes[1].imshow(attribution_normal, cmap=cmap, vmin=-1, vmax=1)
axes[1].set_title("Attribution Normal")
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[2].grid()
axes[2].imshow(attribution_depth, cmap=cmap, vmin=-1, vmax=1)
axes[2].set_title("Attribution Depthwise")
axes[2].set_xticks([])
axes[2].set_yticks([])

plt.savefig(f"./plots/{model_file}_lrp.png")
plt.close()

# create and save the LRP analysis for the specral channels
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
fig.suptitle(f"Spectral LRP Analysis - {model_type}")
axes[0].grid()
axes[0].imshow(img)
axes[0].set_title("Input Image")
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[1].grid()
axes[1].plot(input_spectral)
axes[1].set_title("Spectral Input")
axes[2].grid()
axes[2].plot(attribution_spectral_normal)
axes[2].set_title("Spectral Attribution Normal")
axes[3].grid()
axes[3].plot(attribution_spectral_depth)
axes[3].set_title("Spectral Attribution Depthwise")

plt.savefig(f"./plots/{model_file}_spectral_lrp.png")
plt.rc("font", size=10) # controls default text sizes
axes[0][0].grid()
axes[0][0].imshow(img)
axes[0][0].set_title("Input Image")
axes[0][0].set_yticks([])
axes[0][0].set_xticks([])
axes[0][1].grid()
axes[0][1].imshow(attribution_normal, cmap=cmap, vmin=-1, vmax=1)
axes[0][1].set_title("Attribution Normal")
axes[0][1].set_xticks([])
axes[0][1].set_yticks([])
axes[0][2].grid()
axes[0][2].imshow(attribution_depth, cmap=cmap, vmin=-1, vmax=1)
axes[0][2].set_title("Attribution Depthwise")
axes[0][2].set_xticks([])
axes[0][2].set_yticks([])

axes[1][0].grid()
axes[1][0].plot(input_spectral)
axes[1][0].set_title("Spectral Input")
axes[1][1].grid()
axes[1][1].plot(attribution_spectral_normal)
axes[1][1].set_title("Spectral Attribution Normal")
axes[1][2].grid()
axes[1][2].plot(attribution_spectral_depth)
axes[1][2].set_title("Spectral Attribution Depthwise")

plt.savefig(f"./plots/{model_file}_lrp_C_{has_cancer}.png")
plt.close()

print("[ finished LRP analysis ]")

if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument("--has_cancer", "-hc", type=bool, default=True,
help="choose between negative/positive cancer examples")
args = parser.parse_args()
main(**vars(args))

0 comments on commit c8b0ee4

Please sign in to comment.