Skip to content

Commit

Permalink
fixed image shaoes
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jun 27, 2023
1 parent e32e9bb commit d8f2391
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions visualize_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,20 @@ def main() -> None:
# ---------------- Evaluate Models ----------------
lrp_model = LRP(model)
img, label = next(iter(test_loader))
attribution = lrp_model.attribute(img, target=label).cpu().detach().numpy()
attribution = lrp_model.attribute(img, target=label).cpu().detach().numpy()[0, 0:3, :, :]
# move the channel dimension to the last dimension for numpy (C,H,W) -> (H,W,C)
img = np.moveaxis(img.cpu().detach().numpy(), 1, -1)
attribution = np.moveaxis(attribution, -1, 0)
img = img.cpu().detach().numpy()[0, 0:3, :, :]
img = np.moveaxis(img, 0, -1)
attribution = np.moveaxis(attribution, 0, -1)

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"] # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the pred. score
# negative attribution indicates distractor areas whose absence increases the pred. score

_ = viz.visualize_image_attr_multiple(attribution, img, vis_types, vis_signs,
["Attribution", "Image"], show_colorbar = True)
plt.savefig(f"./plots/{model_file}_lrp.png")
fig, _ = viz.visualize_image_attr_multiple(attribution, img, vis_types, vis_signs,
["Attribution", "Image"], show_colorbar = True)
fig.savefig(f"./plots/{model_file}_lrp.png")
plt.close()


Expand Down

0 comments on commit d8f2391

Please sign in to comment.