From 3e395e7d0a2015be77bc3552f43f44dc81d33eef Mon Sep 17 00:00:00 2001 From: Artem Maevskiy Date: Mon, 23 Nov 2020 22:05:56 +0000 Subject: [PATCH] saving images as PDFs; fixing mask plot --- metrics/__init__.py | 100 ++++++++++++++++++++++++++++++++++---------- metrics/trends.py | 3 +- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/metrics/__init__.py b/metrics/__init__.py index b64643d..6d9771d 100644 --- a/metrics/__init__.py +++ b/metrics/__init__.py @@ -12,7 +12,7 @@ from .trends import make_trend_plot -def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy=False): +def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy=False, pdffile=None): l = min(data_real.min(), data_gen.min()) r = max(data_real.max(), data_gen.max()) bins = np.linspace(l, r, n_bins + 1) @@ -27,6 +27,7 @@ def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy buf = io.BytesIO() fig.savefig(buf, format='png') + if pdffile is not None: fig.savefig(pdffile, format='pdf') plt.close(fig) buf.seek(0) @@ -34,8 +35,9 @@ def make_histograms(data_real, data_gen, title, figsize=(8, 8), n_bins=100, logy return np.array(img.getdata(), dtype=np.uint8).reshape(1, img.size[1], img.size[0], -1) -def make_metric_plots(images_real, images_gen, features=None, calc_chi2=False): +def make_metric_plots(images_real, images_gen, features=None, calc_chi2=False, make_pdfs=False): plots = {} + if make_pdfs: pdf_plots = {} if calc_chi2: chi2 = 0 @@ -43,29 +45,42 @@ def make_metric_plots(images_real, images_gen, features=None, calc_chi2=False): metric_real = get_val_metric_v(images_real) metric_gen = get_val_metric_v(images_gen ) - plots.update({name : make_histograms(real, gen, name) - for name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T)}) + for name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T): + pdffile = None + if make_pdfs: + pdffile = io.BytesIO() + pdf_plots[name] = pdffile + plots[name] = make_histograms(real, gen, name, pdffile=pdffile) + if features is not None: for feature_name, (feature_real, feature_gen) in features.items(): for metric_name, real, gen in zip(_METRIC_NAMES, metric_real.T, metric_gen.T): name = f'{metric_name} vs {feature_name}' + pdffile = None + if make_pdfs: + pdffile = io.BytesIO() + pdf_plots[name] = pdffile if calc_chi2 and (metric_name != "Sum"): plots[name], chi2_i = make_trend_plot(feature_real, real, feature_gen, gen, - name, calc_chi2=True) + name, calc_chi2=True, + pdffile=pdffile) chi2 += chi2_i else: plots[name] = make_trend_plot(feature_real, real, - feature_gen, gen, name) + feature_gen, gen, name, pdffile=pdffile) except AssertionError as e: print(f"WARNING! Assertion error ({e})") + result = {'plots' : plots} if calc_chi2: - return plots, chi2 + result['chi2'] = chi2 + if make_pdfs: + result['pdf_plots'] = pdf_plots - return plots + return result def make_images_for_model(model, @@ -73,10 +88,15 @@ def make_images_for_model(model, return_raw_data=False, calc_chi2=False, gen_more=None, - batch_size=128): + batch_size=128, + pdf_outputs=None): X, Y = sample assert X.ndim == 2 assert X.shape[1] == 4 + make_pdfs = (pdf_outputs is not None) + if make_pdfs: + assert isinstance(pdf_outputs, list) + assert len(pdf_outputs) == 0 if gen_more is None: gen_features = X @@ -102,16 +122,36 @@ def make_images_for_model(model, 'pad_coord_fraction' : (X[:, 3] % 1, gen_features[:,3] % 1) } - images = make_metric_plots(real, gen, features=features, calc_chi2=calc_chi2) + metric_plot_results = make_metric_plots(real, gen, features=features, + calc_chi2=calc_chi2, make_pdfs=make_pdfs) + images = metric_plot_results['plots'] if calc_chi2: - images, chi2 = images - - images1 = make_metric_plots(real, gen1, features=features) - - img_amplitude = make_histograms(Y.flatten(), gen_scaled.flatten(), 'log10(amplitude + 1)', logy=True) - - images['examples'] = plot_individual_images(Y, gen_scaled) - images['examples_mask'] = plot_images_mask(Y, gen_scaled) + chi2 = metric_plot_results['chi2'] + if make_pdfs: + images_pdf = metric_plot_results['pdf_plots'] + pdf_outputs.append(images_pdf) + + metric_plot_results1 = make_metric_plots(real, gen1, features=features, make_pdfs=make_pdfs) + images1 = metric_plot_results1['plots'] + if make_pdfs: + pdf_outputs.append(metric_plot_results1['pdf_plots']) + + pdffile = None + if make_pdfs: + pdffile = io.BytesIO() + pdf_outputs.append(pdffile) + img_amplitude = make_histograms(Y.flatten(), gen_scaled.flatten(), 'log10(amplitude + 1)', logy=True, + pdffile=pdffile) + + pdffile_examples = None + pdffile_examples_mask = None + if make_pdfs: + pdffile_examples = io.BytesIO() + pdffile_examples_mask = io.BytesIO() + images_pdf['examples'] = pdffile_examples + images_pdf['examples_mask'] = pdffile_examples_mask + images['examples'] = plot_individual_images(Y, gen_scaled, pdffile=pdffile_examples) + images['examples_mask'] = plot_images_mask(Y, gen_scaled, pdffile=pdffile_examples_mask) result = [images, images1, img_amplitude] @@ -126,11 +166,13 @@ def make_images_for_model(model, def evaluate_model(model, path, sample, gen_sample_name=None): path.mkdir() + pdf_outputs = [] ( images, images1, img_amplitude, gen_dataset, chi2 ) = make_images_for_model(model, sample=sample, - calc_chi2=True, return_raw_data=True, gen_more=10) + calc_chi2=True, return_raw_data=True, gen_more=10, pdf_outputs=pdf_outputs) + images_pdf, images1_pdf, img_amplitude_pdf = pdf_outputs array_to_img = lambda arr: PIL.Image.fromarray(arr.reshape(arr.shape[1:])) @@ -140,6 +182,16 @@ def evaluate_model(model, path, sample, gen_sample_name=None): array_to_img(img).save(str(path / f"{k}_amp_gt_1.png")) array_to_img(img_amplitude).save(str(path / "log10_amp_p_1.png")) + def buf_to_file(buf, filename): + with open(filename, 'wb') as f: + f.write(buf.getbuffer()) + + for k, img in images_pdf.items(): + buf_to_file(img, str(path / f"{k}.pdf")) + for k, img in images1_pdf.items(): + buf_to_file(img, str(path / f"{k}_amp_gt_1.pdf")) + buf_to_file(img_amplitude_pdf, str(path / "log10_amp_p_1.pdf")) + if gen_sample_name is not None: with open(str(path / gen_sample_name), 'w') as f: for event_X, event_Y in zip(*gen_dataset): @@ -155,7 +207,7 @@ def evaluate_model(model, path, sample, gen_sample_name=None): f.write(f"{chi2:.2f}\n") -def plot_individual_images(real, gen, n=10): +def plot_individual_images(real, gen, n=10, pdffile=None): assert real.ndim == 3 == gen.ndim assert real.shape[1:] == gen.shape[1:] N_max = min(len(real), len(gen)) @@ -182,6 +234,7 @@ def plot_individual_images(real, gen, n=10): buf = io.BytesIO() fig.savefig(buf, format='png') + if pdffile is not None: fig.savefig(pdffile, format='pdf') plt.close(fig) buf.seek(0) @@ -189,7 +242,7 @@ def plot_individual_images(real, gen, n=10): return np.array(img.getdata(), dtype=np.uint8).reshape(1, img.size[1], img.size[0], -1) -def plot_images_mask(real, gen): +def plot_images_mask(real, gen, pdffile=None): assert real.ndim == 3 == gen.ndim assert real.shape[1:] == gen.shape[1:] @@ -197,13 +250,14 @@ def plot_images_mask(real, gen): size_y = size_x / real.shape[2] * real.shape[1] * 2.4 fig, [ax0, ax1] = plt.subplots(2, 1, figsize=(size_x, size_y)) - ax0.imshow(real.any(axis=0), aspect='auto') + ax0.imshow((real >= 1.).any(axis=0), aspect='auto') ax0.set_title("real") - ax1.imshow(gen.any(axis=0), aspect='auto') + ax1.imshow((gen >= 1.).any(axis=0), aspect='auto') ax1.set_title("generated") buf = io.BytesIO() fig.savefig(buf, format='png') + if pdffile is not None: fig.savefig(pdffile, format='pdf') plt.close(fig) buf.seek(0) diff --git a/metrics/trends.py b/metrics/trends.py index de6ef42..2ea7d5c 100644 --- a/metrics/trends.py +++ b/metrics/trends.py @@ -51,7 +51,7 @@ def stats(arr): return (mean, std), (mean_err, std_err) -def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False, figsize=(8, 8)): +def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False, figsize=(8, 8), pdffile=None): feature_real = feature_real.squeeze() feature_gen = feature_gen.squeeze() real = real.squeeze() @@ -71,6 +71,7 @@ def make_trend_plot(feature_real, real, feature_gen, gen, name, calc_chi2=False, buf = io.BytesIO() fig.savefig(buf, format='png') + if pdffile is not None: fig.savefig(pdffile, format='pdf') plt.close(fig) buf.seek(0)