Skip to content

Commit

Permalink
added spectral LRP analysis implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jul 10, 2023
1 parent 994c855 commit 6bff4d7
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions visualize_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
"/prodi/hpcmem/spots_ftir/CO1801a/",
"/prodi/hpcmem/spots_ftir/CO722/"]

DATA_OUT: Final[str] = "/prodi/hpcmem/spots_ftir/data_out/"

MODEL_OUTPUT_PATH: Final[str] = "./models/"


Expand Down Expand Up @@ -67,7 +65,7 @@ def main() -> None:

# load a single image to get the input shape
# train data has the shape (batch_size, channels, width, height) -> (BATCH_SIZE, 442, 400, 400)
print("[ creating model ]")
print("[ creating models ]")
tmp_data, _ = next(iter(test_loader))
in_channels = tmp_data.shape[1]

Expand Down Expand Up @@ -137,8 +135,27 @@ def main() -> None:
attribution_depth = lrp_model_depth.attribute(img, target=None).cpu().detach().numpy()[0]

img = img.cpu().detach().numpy()[0]
input_spectral = np.copy(img)
img = np.mean(img, axis=0)

# specral data
input_spectral = input_spectral.mean(axis=(1, 2))
attribution_spectral_normal = attribution_normal.mean(axis=(1, 2))
attribution_spectral_depth = attribution_depth.mean(axis=(1, 2))

# normalize the spectra data
peak_interval = input_spectral[339:379]
peak_point = np.max(peak_interval)
input_spectral = input_spectral / peak_point

peak_interval = attribution_spectral_normal[339:379]
peak_point = np.max(peak_interval)
attribution_spectral_normal = attribution_spectral_normal / peak_point

peak_interval = attribution_spectral_depth[339:379]
peak_point = np.max(peak_interval)
attribution_spectral_depth = attribution_spectral_depth / peak_point

# add the attributions of every channel dimension up
attribution_normal = np.sum(attribution_normal, axis=0)
attribution_normal = np.expand_dims(attribution_normal, axis=0)
Expand All @@ -155,15 +172,13 @@ def main() -> None:
cmap = LinearSegmentedColormap.from_list(
"RdWhGn", ["red", "white", "green"]
)
norm = mpl.colors.Normalize(vmin=-1, vmax=1)

# create and save the normal LRP analysis
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
#axis_separator = make_axes_locatable(axes[2])
#colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
norm = mpl.colors.Normalize(vmin=-1, vmax=1)
plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
orientation="vertical", ax=axes,
label="Attribution Values")

axes[0].grid()
axes[0].imshow(img)
axes[0].set_title("Input Image")
Expand All @@ -183,6 +198,26 @@ def main() -> None:
fig.savefig(f"./plots/{model_file}_lrp.png")
plt.close()

# create and save the LRP analysis for the specral channels
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
axes[0].grid()
axes[0].imshow(img)
axes[0].set_title("Input Image")
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[1].grid()
axes[1].plot(input_spectral)
axes[1].set_title("Spectral Input")
axes[2].grid()
axes[2].plot(attribution_spectral_normal)
axes[2].set_title("Spectral Attribution Normal")
axes[3].grid()
axes[3].plot(attribution_spectral_depth)
axes[3].set_title("Spectral Attribution Depthwise")

fig.savefig(f"./plots/{model_file}_spectral_lrp.png")
plt.close()

print("[ finished LRP analysis ]")

if __name__ == "__main__":
Expand Down

0 comments on commit 6bff4d7

Please sign in to comment.