diff --git a/pygem/bin/run/run_calibration.py b/pygem/bin/run/run_calibration.py index 8a3b2f32..e6587bdd 100755 --- a/pygem/bin/run/run_calibration.py +++ b/pygem/bin/run/run_calibration.py @@ -2219,21 +2219,12 @@ def rho_constraints(**kwargs): f'Chain {n_chain}: failed to produce an unstuck result after {attempts_per_chain} initial guesses.' ) - # concatenate mass balance - m_chain = torch.cat( - (m_chain, torch.tensor(pred_chain['glacierwide_mb_mwea']).reshape(-1, 1)), dim=1 - ) - m_primes = torch.cat( - (m_primes, torch.tensor(pred_primes['glacierwide_mb_mwea']).reshape(-1, 1)), - dim=1, - ) - if debug: print( 'mb_mwea_mean:', - np.round(torch.mean(m_chain[:, -1]).item(), 3), + np.round(torch.mean(torch.stack(pred_chain['glacierwide_mb_mwea'])).item(), 3), 'mb_mwea_std:', - np.round(torch.std(m_chain[:, -1]).item(), 3), + np.round(torch.std(torch.stack(pred_chain['glacierwide_mb_mwea'])).item(), 3), '\nmb_obs_mean:', np.round(mb_obs_mwea, 3), 'mb_obs_std:', @@ -2257,6 +2248,8 @@ def rho_constraints(**kwargs): graphics.plot_mcmc_chain( m_primes, m_chain, + pred_primes, + pred_chain, obs, ar, glacier_str, diff --git a/pygem/plot/graphics.py b/pygem/plot/graphics.py index 20b49951..b0d37bb0 100644 --- a/pygem/plot/graphics.py +++ b/pygem/plot/graphics.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import numpy as np +import torch from scipy.stats import binned_statistic from pygem.utils.stats import effective_n @@ -113,10 +114,15 @@ def plot_modeloutput_section(model=None, ax=None, title='', **kwargs): ax.set_title(title, loc='left') -def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=False, fpath=None): +def plot_mcmc_chain( + m_primes, m_chain, pred_primes, pred_chain, obs, ar, title, ms=1, fontsize=8, show=False, fpath=None +): # Plot the trace of the parameters - n = m_primes.shape[1] - fig, axes = plt.subplots(n + 1, 1, figsize=(6, n * 1.5), sharex=True) + nparams = m_primes.shape[1] + npreds = len(pred_chain.keys()) + N = nparams + npreds + 1 + fig, axes = plt.subplots(N, 1, figsize=(6, N * 1), sharex=True) + # convert torch objects to numpy m_chain = m_chain.detach().numpy() m_primes = m_primes.detach().numpy() @@ -125,6 +131,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa # instantiate list to hold legend objs legs = [] + # axes[0] will always be tbias axes[0].plot( [], [], @@ -139,6 +146,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa # axes[0].add_artist(leg) axes[0].set_ylabel(r'$T_{bias}$', fontsize=fontsize) + # axes[1] will always be kp axes[1].plot(m_primes[:, 1], '.', ms=ms, c='tab:blue') axes[1].plot(m_chain[:, 1], '.', ms=ms, c='tab:orange') axes[1].plot( @@ -150,6 +158,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa legs.append(l1) axes[1].set_ylabel(r'$K_p$', fontsize=fontsize) + # axes[2] will always be ddfsnow axes[2].plot(m_primes[:, 2], '.', ms=ms, c='tab:blue') axes[2].plot(m_chain[:, 2], '.', ms=ms, c='tab:orange') axes[2].plot( @@ -161,7 +170,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa legs.append(l2) axes[2].set_ylabel(r'$fsnow$', fontsize=fontsize) - if n > 4: + if nparams > 3: + # axes[3] will be rho_ablation if more than 3 model params m_chain[:, 3] = m_chain[:, 3] m_primes[:, 3] = m_primes[:, 3] axes[3].plot(m_primes[:, 3], '.', ms=ms, c='tab:blue') @@ -175,6 +185,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa legs.append(l3) axes[3].set_ylabel(r'$\rho_{abl}$', fontsize=fontsize) + # axes[4] will be rho_accumulation if more than 3 model params m_chain[:, 4] = m_chain[:, 4] m_primes[:, 4] = m_primes[:, 4] axes[4].plot(m_primes[:, 4], '.', ms=ms, c='tab:blue') @@ -188,33 +199,60 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa legs.append(l4) axes[4].set_ylabel(r'$\rho_{acc}$', fontsize=fontsize) - if 'glacierwide_mb_mwea' in obs.keys(): + # plot predictions + if 'glacierwide_mb_mwea' in pred_primes.keys(): mb_obs = obs['glacierwide_mb_mwea'] - axes[-2].fill_between( + axes[nparams].fill_between( np.arange(len(ar)), mb_obs[0] - (2 * mb_obs[1]), mb_obs[0] + (2 * mb_obs[1]), color='grey', alpha=0.3, ) - axes[-2].fill_between( + axes[nparams].fill_between( np.arange(len(ar)), mb_obs[0] - mb_obs[1], mb_obs[0] + mb_obs[1], color='grey', alpha=0.3, ) - axes[-2].plot(m_primes[:, -1], '.', ms=ms, c='tab:blue') - axes[-2].plot(m_chain[:, -1], '.', ms=ms, c='tab:orange') - axes[-2].plot( + + mb_primes = torch.stack(pred_primes['glacierwide_mb_mwea']).numpy() + mb_chain = torch.stack(pred_chain['glacierwide_mb_mwea']).numpy() + axes[nparams].plot(mb_primes, '.', ms=ms, c='tab:blue') + axes[nparams].plot(mb_chain, '.', ms=ms, c='tab:orange') + axes[nparams].plot( [], [], - label=f'median={np.median(m_chain[:, -1]):.3f}\niqr={np.subtract(*np.percentile(m_chain[:, -1], [75, 25])):.3f}', + label=f'median={np.median(mb_chain):.3f}\niqr={np.subtract(*np.percentile(mb_chain, [75, 25])):.3f}', ) - ln2 = axes[-2].legend(loc='upper right', handlelength=0, borderaxespad=0, fontsize=fontsize) + ln2 = axes[nparams].legend(loc='upper right', handlelength=0, borderaxespad=0, fontsize=fontsize) legs.append(ln2) - axes[-2].set_ylabel(r'$\dot{{b}}$', fontsize=fontsize) + axes[nparams].set_ylabel(r'$\dot{{b}}$', fontsize=fontsize) + nparams += 1 + + # plot MAE for all other prediction keys + for key in pred_primes.keys(): + if key == 'glacierwide_mb_mwea': + continue + pred_primes = torch.stack(pred_primes[key]).numpy() + pred_chain = torch.stack(pred_chain[key]).numpy() + obs_vals = np.array(obs[key][0]) + + mae_primes = np.mean(pred_primes - obs_vals, axis=(1, 2)) + mae_chain = np.mean(pred_chain - obs_vals, axis=(1, 2)) + axes[nparams].plot(mae_primes, '.', ms=ms, c='tab:blue') + axes[nparams].plot(mae_chain, '.', ms=ms, c='tab:orange') + + if key == 'elev_change_1d': + axes[nparams].set_ylabel(r'$\overline{\hat{dh} - dh}$', fontsize=fontsize) + else: + axes[nparams].set_ylabel(r'$\overline{\mathrm{pred} - \mathrm{obs}}$', fontsize=fontsize) + legs.append(None) + nparams += 1 + + # axes[-1] will always be acceptance rate axes[-1].plot(ar, 'tab:orange', lw=1) axes[-1].plot( np.convolve(ar, np.ones(100) / 100, mode='valid'), @@ -230,7 +268,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa ax.xaxis.set_ticks_position('both') ax.yaxis.set_ticks_position('both') ax.tick_params(axis='both', direction='inout') - if i == n: + ax.yaxis.set_label_coords(-0.1, 0.5) + if i > m_primes.shape[1] - 1: continue ax.plot([], [], label=f'n_eff={neff[i]}') hands, ls = ax.get_legend_handles_labels() @@ -252,9 +291,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa handlelength=0, fontsize=fontsize, ) - - for i, ax in enumerate(axes): - ax.add_artist(legs[i]) + if legs[i] is not None: + ax.add_artist(legs[i]) axes[0].set_xlim([0, m_chain.shape[0]]) axes[0].set_title(title, fontsize=fontsize)