Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pygem/plot/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ def plot_mcmc_chain(
continue

# stack predictions first (shape: n_steps x ... x ...) - may end up being 2d or 3d
pred_primes = torch.stack(pred_primes[key]).numpy()
pred_chain = torch.stack(pred_chain[key]).numpy()
pred_primes_key = torch.stack(pred_primes[key]).numpy()
pred_chain_key = torch.stack(pred_chain[key]).numpy()

# flatten all axes except the first (n_steps) -> 2D array (n_steps, M)
pred_primes_flat = pred_primes.reshape(pred_primes.shape[0], -1)
pred_chain_flat = pred_chain.reshape(pred_chain.shape[0], -1)
pred_primes_flat = pred_primes_key.reshape(pred_primes_key.shape[0], -1)
pred_chain_flat = pred_chain_key.reshape(pred_chain_key.shape[0], -1)

# make obs array broadcastable (flatten if needed)
obs_vals_flat = np.ravel(np.array(obs[key][0]))
Expand All @@ -291,6 +291,7 @@ def plot_mcmc_chain(
axes[nparams].plot(mean_resid_primes, '.', ms=ms, c='tab:blue')
axes[nparams].plot(mean_resid_chain, '.', ms=ms, c='tab:orange')

axes[nparams].text(0.02, 0.02, key, transform=axes[nparams].transAxes, fontsize=fontsize, va='bottom', ha='left')
if key == 'elev_change_1d':
axes[nparams].set_ylabel(r'$\overline{\hat{dh} - dh}$', fontsize=fontsize)
else:
Expand Down
Loading