Skip to content

Commit

Permalink
fixed LRP implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jul 3, 2023
1 parent 4bd4aca commit ad40d5d
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions visualize_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from kernel_eval.models import vgg11, vgg13, vgg16, vgg19, resnet34
from kernel_eval.datasets import SingleFileDataset
from kernel_eval.datasets import SingleFileDatasetLoadingOptions
from kernel_eval.utils import load_model, augment_images, normalize_spectral_data
from kernel_eval.utils import load_model

DATA_PATHS: Final[List[str]] = ["/prodi/hpcmem/spots_ftir/LC704/",
"/prodi/hpcmem/spots_ftir/BC051111/",
Expand Down Expand Up @@ -69,7 +69,7 @@ def main() -> None:
in_channels = tmp_data.shape[1]

# ---------------- Load and Train Models ---------------
model_files = glob(MODEL_OUTPUT_PATH+"*") # load all models saved under the path
model_files: List[str] = glob(MODEL_OUTPUT_PATH+"*") # load all models saved under the path

for model_file in model_files:
model_file = model_file.split("/")[-1]
Expand All @@ -87,20 +87,32 @@ def main() -> None:

case "resnet34": model = resnet34(in_channels=in_channels,
depthwise=depthwise, num_classes=1)
case _: raise ValueError(f"Model {model} not supported")
case _: raise ValueError(f"Model {model_type} not supported")

model = model.to(device)

model = load_model(MODEL_OUTPUT_PATH, model_type, depthwise,
batch_size, learning_rate, epochs, model)
model.eval()
model.zero_grad()

# ---------------- Evaluate Models ----------------
lrp_model = LRP(model)
img, label = next(iter(test_loader))
attribution = lrp_model.attribute(img, target=label).cpu().detach().numpy()[0, 0:3, :, :]
img, _ = next(iter(test_loader))
img.requires_grad = True

print(f"[ analyze model: {model_file} ]")
attribution = lrp_model.attribute(img, target=None).cpu().detach().numpy()[0]
img = img.cpu().detach().numpy()[0]
img = np.mean(img, axis=0)
img = np.expand_dims(img, axis=0)

# add the attributions of every channel dimension up
attribution = np.sum(attribution, axis=0)
attribution = np.expand_dims(attribution, axis=0)

# move the channel dimension to the last dimension for numpy (C,H,W) -> (H,W,C)
img = img.cpu().detach().numpy()[0, 0:3, :, :]
# shape is (442, 244, 244) -> (244, 244, 442)
img = np.moveaxis(img, 0, -1)
attribution = np.moveaxis(attribution, 0, -1)

Expand All @@ -111,6 +123,11 @@ def main() -> None:

fig, _ = viz.visualize_image_attr_multiple(attribution, img, vis_types, vis_signs,
["Attribution", "Image"], show_colorbar = True)

if not os.path.exists("plots/"):
os.mkdir("plots/")

plt.title(f"LRP: {model_file}")
fig.savefig(f"./plots/{model_file}_lrp.png")
plt.close()

Expand Down

0 comments on commit ad40d5d

Please sign in to comment.