-
Notifications
You must be signed in to change notification settings - Fork 18
/
plot_utils.py
18 lines (16 loc) · 875 Bytes
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from pathlib import Path

import numpy as np
from matplotlib import pyplot as plt
def make_squares(images, nr_images_per_side, image_shape=(28, 28)):
    """Tile the first ``nr_images_per_side ** 2`` images into one square grid.

    Images are placed in row-major order: grid row ``j``, column ``i`` holds
    ``images[j * nr_images_per_side + i]``.

    Args:
        images: An indexable collection (list or array) of images. Each
            element must contain exactly ``image_shape[0] * image_shape[1]``
            values; it is reshaped, so flat vectors and already-2-D arrays
            both work.
        nr_images_per_side: Number of tiles along each side of the grid.
        image_shape: ``(height, width)`` of a single tile. Defaults to
            ``(28, 28)``, the size this helper was originally hard-coded for.

    Returns:
        A 2-D array of shape
        ``(nr_images_per_side * height, nr_images_per_side * width)``.

    Raises:
        ValueError: If ``nr_images_per_side`` is not positive, or if an image
            cannot be reshaped to ``image_shape``.
        IndexError: If ``images`` has fewer than ``nr_images_per_side ** 2``
            entries (raised by indexing ``images``).
    """
    n = nr_images_per_side
    if n < 1:
        raise ValueError("nr_images_per_side must be a positive integer")
    height, width = image_shape

    # (n*n, h, w): one 2-D tile per image, in the order they will be laid out.
    tiles = np.stack([np.reshape(images[k], (height, width)) for k in range(n * n)])

    # (n*n, h, w) -> (row, col, h, w) -> (row, h, col, w) -> (n*h, n*w).
    # Moving the tile-column axis next to the pixel-column axis before the
    # final reshape places the tiles of one grid row side by side.
    grid = tiles.reshape(n, n, height, width).transpose(0, 2, 1, 3)
    return grid.reshape(n * height, n * width)
def plot_squares(originals, reconstructs, nr_images_per_side,
                 output_dir='./results', cmap='viridis'):
    """Save image grids of originals, reconstructions, and both side by side.

    Writes three PNG files into ``output_dir``:

    * ``original.png``: grid built from ``originals`` by ``make_squares``.
    * ``recons.png``: grid built from ``reconstructs`` by ``make_squares``.
    * ``combined.png``: the originals grid on the left and the
      reconstructions grid on the right (concatenated along columns).

    Args:
        originals: Images passed to ``make_squares`` for the left grid.
        reconstructs: Images passed to ``make_squares`` for the right grid.
        nr_images_per_side: Grid side length, forwarded to ``make_squares``.
        output_dir: Directory to write into. It is created if missing.
            Defaults to ``'./results'``, the originally hard-coded location.
        cmap: Matplotlib colormap name used for all three files. Defaults
            to ``'viridis'``, as before.
    """
    out_dir = Path(output_dir)
    # Previously this raised FileNotFoundError when the directory was absent.
    out_dir.mkdir(parents=True, exist_ok=True)

    originals_square = make_squares(originals, nr_images_per_side)
    plt.imsave(out_dir / 'original.png', originals_square, cmap=cmap)

    reconstructs_square = make_squares(reconstructs, nr_images_per_side)
    plt.imsave(out_dir / 'recons.png', reconstructs_square, cmap=cmap)

    combined = np.concatenate([originals_square, reconstructs_square], axis=1)
    plt.imsave(out_dir / 'combined.png', combined, cmap=cmap)