Skip to content

Commit

Permalink
saving images as PDFs; fixing mask plot
Browse files Browse the repository at this point in the history
  • Loading branch information
SiLiKhon committed Nov 23, 2020
1 parent efb8fd1 commit 3e395e7
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 24 deletions.
100 changes: 77 additions & 23 deletions metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .trends import make_trend_plot


def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy=False):
def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy=False, pdffile=None):
l = min(data_real.min(), data_gen.min())
r = max(data_real.max(), data_gen.max())
bins = np.linspace(l, r, n_bins + 1)
Expand All @@ -27,56 +27,76 @@ def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy

buf = io.BytesIO()
fig.savefig(buf, format='png')
if pdffile is not None: fig.savefig(pdffile, format='pdf')
plt.close(fig)
buf.seek(0)

img = PIL.Image.open(buf)
return np.array(img.getdata(), dtype=np.uint8).reshape(1, img.size[1], img.size[0], -1)


def make_metric_plots(images_real, images_gen, features=None, calc_chi2=False):
def make_metric_plots(images_real, images_gen, features=None, calc_chi2=False, make_pdfs=False):
plots = {}
if make_pdfs: pdf_plots = {}
if calc_chi2:
chi2 = 0

try:
metric_real = get_val_metric_v(images_real)
metric_gen = get_val_metric_v(images_gen )

plots.update({name : make_histograms(real, gen, name)
for name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T)})
for name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T):
pdffile = None
if make_pdfs:
pdffile = io.BytesIO()
pdf_plots[name] = pdffile
plots[name] = make_histograms(real, gen, name, pdffile=pdffile)


if features is not None:
for feature_name, (feature_real, feature_gen) in features.items():
for metric_name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T):
name = f'{metric_name} vs {feature_name}'
pdffile = None
if make_pdfs:
pdffile = io.BytesIO()
pdf_plots[name] = pdffile
if calc_chi2 and (metric_name != "Sum"):
plots[name], chi2_i = make_trend_plot(feature_real, real,
feature_gen, gen,
name, calc_chi2=True)
name, calc_chi2=True,
pdffile=pdffile)
chi2 += chi2_i
else:
plots[name] = make_trend_plot(feature_real, real,
feature_gen, gen, name)
feature_gen, gen, name, pdffile=pdffile)

except AssertionError as e:
print(f"WARNING! Assertion error ({e})")

result = {'plots' : plots}
if calc_chi2:
return plots, chi2
result['chi2'] = chi2
if make_pdfs:
result['pdf_plots'] = pdf_plots

return plots
return result


def make_images_for_model(model,
sample,
return_raw_data=False,
calc_chi2=False,
gen_more=None,
batch_size=128):
batch_size=128,
pdf_outputs=None):
X, Y = sample
assert X.ndim == 2
assert X.shape[1] == 4
make_pdfs = (pdf_outputs is not None)
if make_pdfs:
assert isinstance(pdf_outputs, list)
assert len(pdf_outputs) == 0

if gen_more is None:
gen_features = X
Expand All @@ -102,16 +122,36 @@ def make_images_for_model(model,
'pad_coord_fraction' : (X[:, 3] % 1, gen_features[:,3] % 1)
}

images = make_metric_plots(real, gen, features=features, calc_chi2=calc_chi2)
metric_plot_results = make_metric_plots(real, gen, features=features,
calc_chi2=calc_chi2, make_pdfs=make_pdfs)
images = metric_plot_results['plots']
if calc_chi2:
images, chi2 = images

images1 = make_metric_plots(real, gen1, features=features)

img_amplitude = make_histograms(Y.flatten(), gen_scaled.flatten(), 'log10(amplitude + 1)', logy=True)

images['examples'] = plot_individual_images(Y, gen_scaled)
images['examples_mask'] = plot_images_mask(Y, gen_scaled)
chi2 = metric_plot_results['chi2']
if make_pdfs:
images_pdf = metric_plot_results['pdf_plots']
pdf_outputs.append(images_pdf)

metric_plot_results1 = make_metric_plots(real, gen1, features=features, make_pdfs=make_pdfs)
images1 = metric_plot_results1['plots']
if make_pdfs:
pdf_outputs.append(metric_plot_results1['pdf_plots'])

pdffile = None
if make_pdfs:
pdffile = io.BytesIO()
pdf_outputs.append(pdffile)
img_amplitude = make_histograms(Y.flatten(), gen_scaled.flatten(), 'log10(amplitude + 1)', logy=True,
pdffile=pdffile)

pdffile_examples = None
pdffile_examples_mask = None
if make_pdfs:
pdffile_examples = io.BytesIO()
pdffile_examples_mask = io.BytesIO()
images_pdf['examples'] = pdffile_examples
images_pdf['examples_mask'] = pdffile_examples_mask
images['examples'] = plot_individual_images(Y, gen_scaled, pdffile=pdffile_examples)
images['examples_mask'] = plot_images_mask(Y, gen_scaled, pdffile=pdffile_examples_mask)

result = [images, images1, img_amplitude]

Expand All @@ -126,11 +166,13 @@ def make_images_for_model(model,

def evaluate_model(model, path, sample, gen_sample_name=None):
path.mkdir()
pdf_outputs = []
(
images, images1, img_amplitude,
gen_dataset, chi2
) = make_images_for_model(model, sample=sample,
calc_chi2=True, return_raw_data=True, gen_more=10)
calc_chi2=True, return_raw_data=True, gen_more=10, pdf_outputs=pdf_outputs)
images_pdf, images1_pdf, img_amplitude_pdf = pdf_outputs

array_to_img = lambda arr: PIL.Image.fromarray(arr.reshape(arr.shape[1:]))

Expand All @@ -140,6 +182,16 @@ def evaluate_model(model, path, sample, gen_sample_name=None):
array_to_img(img).save(str(path / f"{k}_amp_gt_1.png"))
array_to_img(img_amplitude).save(str(path / "log10_amp_p_1.png"))

def buf_to_file(buf, filename):
with open(filename, 'wb') as f:
f.write(buf.getbuffer())

for k, img in images_pdf.items():
buf_to_file(img, str(path / f"{k}.pdf"))
for k, img in images1_pdf.items():
buf_to_file(img, str(path / f"{k}_amp_gt_1.pdf"))
buf_to_file(img_amplitude_pdf, str(path / "log10_amp_p_1.pdf"))

if gen_sample_name is not None:
with open(str(path / gen_sample_name), 'w') as f:
for event_X, event_Y in zip(*gen_dataset):
Expand All @@ -155,7 +207,7 @@ def evaluate_model(model, path, sample, gen_sample_name=None):
f.write(f"{chi2:.2f}\n")


def plot_individual_images(real, gen, n=10):
def plot_individual_images(real, gen, n=10, pdffile=None):
assert real.ndim == 3 == gen.ndim
assert real.shape[1:] == gen.shape[1:]
N_max = min(len(real), len(gen))
Expand All @@ -182,28 +234,30 @@ def plot_individual_images(real, gen, n=10):

buf = io.BytesIO()
fig.savefig(buf, format='png')
if pdffile is not None: fig.savefig(pdffile, format='pdf')
plt.close(fig)
buf.seek(0)

img = PIL.Image.open(buf)
return np.array(img.getdata(), dtype=np.uint8).reshape(1, img.size[1], img.size[0], -1)


def plot_images_mask(real, gen):
def plot_images_mask(real, gen, pdffile=None):
assert real.ndim == 3 == gen.ndim
assert real.shape[1:] == gen.shape[1:]

size_x = 6
size_y = size_x / real.shape[2] * real.shape[1] * 2.4

fig, [ax0, ax1] = plt.subplots(2, 1, figsize=(size_x, size_y))
ax0.imshow(real.any(axis=0), aspect='auto')
ax0.imshow((real >= 1.).any(axis=0), aspect='auto')
ax0.set_title("real")
ax1.imshow(gen.any(axis=0), aspect='auto')
ax1.imshow((gen >= 1.).any(axis=0), aspect='auto')
ax1.set_title("generated")

buf = io.BytesIO()
fig.savefig(buf, format='png')
if pdffile is not None: fig.savefig(pdffile, format='pdf')
plt.close(fig)
buf.seek(0)

Expand Down
3 changes: 2 additions & 1 deletion metrics/trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def stats(arr):
return (mean, std), (mean_err, std_err)


def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False, figsize=(8, 8)):
def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False, figsize=(8, 8), pdffile=None):
feature_real = feature_real.squeeze()
feature_gen = feature_gen.squeeze()
real = real.squeeze()
Expand All @@ -71,6 +71,7 @@ def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False,

buf = io.BytesIO()
fig.savefig(buf, format='png')
if pdffile is not None: fig.savefig(pdffile, format='pdf')
plt.close(fig)
buf.seek(0)

Expand Down

0 comments on commit 3e395e7

Please sign in to comment.