Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions pygem/bin/run/run_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,21 +2219,12 @@ def rho_constraints(**kwargs):
f'Chain {n_chain}: failed to produce an unstuck result after {attempts_per_chain} initial guesses.'
)

# concatenate mass balance
m_chain = torch.cat(
(m_chain, torch.tensor(pred_chain['glacierwide_mb_mwea']).reshape(-1, 1)), dim=1
)
m_primes = torch.cat(
(m_primes, torch.tensor(pred_primes['glacierwide_mb_mwea']).reshape(-1, 1)),
dim=1,
)

if debug:
print(
'mb_mwea_mean:',
np.round(torch.mean(m_chain[:, -1]).item(), 3),
np.round(torch.mean(torch.stack(pred_chain['glacierwide_mb_mwea'])).item(), 3),
'mb_mwea_std:',
np.round(torch.std(m_chain[:, -1]).item(), 3),
np.round(torch.std(torch.stack(pred_chain['glacierwide_mb_mwea'])).item(), 3),
'\nmb_obs_mean:',
np.round(mb_obs_mwea, 3),
'mb_obs_std:',
Expand All @@ -2257,6 +2248,8 @@ def rho_constraints(**kwargs):
graphics.plot_mcmc_chain(
m_primes,
m_chain,
pred_primes,
pred_chain,
obs,
ar,
glacier_str,
Expand Down
72 changes: 55 additions & 17 deletions pygem/plot/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import binned_statistic

from pygem.utils.stats import effective_n
Expand Down Expand Up @@ -113,10 +114,15 @@ def plot_modeloutput_section(model=None, ax=None, title='', **kwargs):
ax.set_title(title, loc='left')


def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=False, fpath=None):
def plot_mcmc_chain(
m_primes, m_chain, pred_primes, pred_chain, obs, ar, title, ms=1, fontsize=8, show=False, fpath=None
):
# Plot the trace of the parameters
n = m_primes.shape[1]
fig, axes = plt.subplots(n + 1, 1, figsize=(6, n * 1.5), sharex=True)
nparams = m_primes.shape[1]
npreds = len(pred_chain.keys())
N = nparams + npreds + 1
fig, axes = plt.subplots(N, 1, figsize=(6, N * 1), sharex=True)
# convert torch objects to numpy
m_chain = m_chain.detach().numpy()
m_primes = m_primes.detach().numpy()

Expand All @@ -125,6 +131,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
# instantiate list to hold legend objs
legs = []

# axes[0] will always be tbias
axes[0].plot(
[],
[],
Expand All @@ -139,6 +146,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
# axes[0].add_artist(leg)
axes[0].set_ylabel(r'$T_{bias}$', fontsize=fontsize)

# axes[1] will always be kp
axes[1].plot(m_primes[:, 1], '.', ms=ms, c='tab:blue')
axes[1].plot(m_chain[:, 1], '.', ms=ms, c='tab:orange')
axes[1].plot(
Expand All @@ -150,6 +158,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
legs.append(l1)
axes[1].set_ylabel(r'$K_p$', fontsize=fontsize)

# axes[2] will always be ddfsnow
axes[2].plot(m_primes[:, 2], '.', ms=ms, c='tab:blue')
axes[2].plot(m_chain[:, 2], '.', ms=ms, c='tab:orange')
axes[2].plot(
Expand All @@ -161,7 +170,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
legs.append(l2)
axes[2].set_ylabel(r'$fsnow$', fontsize=fontsize)

if n > 4:
if nparams > 3:
# axes[3] will be rho_ablation if more than 3 model params
m_chain[:, 3] = m_chain[:, 3]
m_primes[:, 3] = m_primes[:, 3]
axes[3].plot(m_primes[:, 3], '.', ms=ms, c='tab:blue')
Expand All @@ -175,6 +185,7 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
legs.append(l3)
axes[3].set_ylabel(r'$\rho_{abl}$', fontsize=fontsize)

# axes[4] will be rho_accumulation if more than 3 model params
m_chain[:, 4] = m_chain[:, 4]
m_primes[:, 4] = m_primes[:, 4]
axes[4].plot(m_primes[:, 4], '.', ms=ms, c='tab:blue')
Expand All @@ -188,33 +199,60 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
legs.append(l4)
axes[4].set_ylabel(r'$\rho_{acc}$', fontsize=fontsize)

if 'glacierwide_mb_mwea' in obs.keys():
# plot predictions
if 'glacierwide_mb_mwea' in pred_primes.keys():
mb_obs = obs['glacierwide_mb_mwea']
axes[-2].fill_between(
axes[nparams].fill_between(
np.arange(len(ar)),
mb_obs[0] - (2 * mb_obs[1]),
mb_obs[0] + (2 * mb_obs[1]),
color='grey',
alpha=0.3,
)
axes[-2].fill_between(
axes[nparams].fill_between(
np.arange(len(ar)),
mb_obs[0] - mb_obs[1],
mb_obs[0] + mb_obs[1],
color='grey',
alpha=0.3,
)
axes[-2].plot(m_primes[:, -1], '.', ms=ms, c='tab:blue')
axes[-2].plot(m_chain[:, -1], '.', ms=ms, c='tab:orange')
axes[-2].plot(

mb_primes = torch.stack(pred_primes['glacierwide_mb_mwea']).numpy()
mb_chain = torch.stack(pred_chain['glacierwide_mb_mwea']).numpy()
axes[nparams].plot(mb_primes, '.', ms=ms, c='tab:blue')
axes[nparams].plot(mb_chain, '.', ms=ms, c='tab:orange')
axes[nparams].plot(
[],
[],
label=f'median={np.median(m_chain[:, -1]):.3f}\niqr={np.subtract(*np.percentile(m_chain[:, -1], [75, 25])):.3f}',
label=f'median={np.median(mb_chain):.3f}\niqr={np.subtract(*np.percentile(mb_chain, [75, 25])):.3f}',
)
ln2 = axes[-2].legend(loc='upper right', handlelength=0, borderaxespad=0, fontsize=fontsize)
ln2 = axes[nparams].legend(loc='upper right', handlelength=0, borderaxespad=0, fontsize=fontsize)
legs.append(ln2)
axes[-2].set_ylabel(r'$\dot{{b}}$', fontsize=fontsize)
axes[nparams].set_ylabel(r'$\dot{{b}}$', fontsize=fontsize)
nparams += 1

# plot MAE for all other prediction keys
for key in pred_primes.keys():
if key == 'glacierwide_mb_mwea':
continue
pred_primes = torch.stack(pred_primes[key]).numpy()
pred_chain = torch.stack(pred_chain[key]).numpy()
obs_vals = np.array(obs[key][0])

mae_primes = np.mean(pred_primes - obs_vals, axis=(1, 2))
mae_chain = np.mean(pred_chain - obs_vals, axis=(1, 2))

axes[nparams].plot(mae_primes, '.', ms=ms, c='tab:blue')
axes[nparams].plot(mae_chain, '.', ms=ms, c='tab:orange')

if key == 'elev_change_1d':
axes[nparams].set_ylabel(r'$\overline{\hat{dh} - dh}$', fontsize=fontsize)
else:
axes[nparams].set_ylabel(r'$\overline{\mathrm{pred} - \mathrm{obs}}$', fontsize=fontsize)
legs.append(None)
nparams += 1

# axes[-1] will always be acceptance rate
axes[-1].plot(ar, 'tab:orange', lw=1)
axes[-1].plot(
np.convolve(ar, np.ones(100) / 100, mode='valid'),
Expand All @@ -230,7 +268,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
ax.xaxis.set_ticks_position('both')
ax.yaxis.set_ticks_position('both')
ax.tick_params(axis='both', direction='inout')
if i == n:
ax.yaxis.set_label_coords(-0.1, 0.5)
if i > m_primes.shape[1] - 1:
continue
ax.plot([], [], label=f'n_eff={neff[i]}')
hands, ls = ax.get_legend_handles_labels()
Expand All @@ -252,9 +291,8 @@ def plot_mcmc_chain(m_primes, m_chain, obs, ar, title, ms=1, fontsize=8, show=Fa
handlelength=0,
fontsize=fontsize,
)

for i, ax in enumerate(axes):
ax.add_artist(legs[i])
if legs[i] is not None:
ax.add_artist(legs[i])

axes[0].set_xlim([0, m_chain.shape[0]])
axes[0].set_title(title, fontsize=fontsize)
Expand Down
Loading