Skip to content

Commit

Permalink
black formatting notebooks.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jun 30, 2022
1 parent e628240 commit 7f16f1d
Show file tree
Hide file tree
Showing 6 changed files with 1,308 additions and 606 deletions.
123 changes: 77 additions & 46 deletions notebooks/MNLE-Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@
"import sbibm\n",
"\n",
"# Plotting settings\n",
"plt.style.use('plotting_settings.mplstyle')\n",
"plt.style.use(\"plotting_settings.mplstyle\")\n",
"# Colorblind color palette\n",
"colors = ['#377eb8', '#ff7f00', '#4daf4a',\n",
" '#f781bf', '#a65628', '#984ea3',\n",
" '#999999', '#e41a1c', '#dede00']"
"colors = [\n",
" \"#377eb8\",\n",
" \"#ff7f00\",\n",
" \"#4daf4a\",\n",
" \"#f781bf\",\n",
" \"#a65628\",\n",
" \"#984ea3\",\n",
" \"#999999\",\n",
" \"#e41a1c\",\n",
" \"#dede00\",\n",
"]"
]
},
{
Expand Down Expand Up @@ -88,7 +96,7 @@
" # decode and put them in separate columns.\n",
" x = torch.zeros(x1d.shape[0], 2)\n",
" x[:, 0] = abs(x1d.squeeze())\n",
" x[x1d.squeeze()>0, 1] = 1\n",
" x[x1d.squeeze() > 0, 1] = 1\n",
"\n",
"theta = theta[:N]\n",
"x = x[:N]"
Expand Down Expand Up @@ -215,7 +223,9 @@
"# To be able to plot the data in 1D we again encode choices as sign.\n",
"x_syn_1d = x_syn[:, :1]\n",
"# Set reaction times negative for zero-choices.\n",
"x_syn_1d[x_syn[:, 1:]==0, ] *= -1"
"x_syn_1d[\n",
" x_syn[:, 1:] == 0,\n",
"] *= -1"
]
},
{
Expand Down Expand Up @@ -270,12 +280,26 @@
"# Plot the real simulated data, the synthetic data, and the learned synthetic likelihoods.\n",
"plt.figure(figsize=(12, 6))\n",
"density = True\n",
"_, bins, _ = plt.hist(x_syn_1d.numpy(), bins=\"auto\", color=\"C0\", \n",
" label=\"synthetic data via MNLE\", density=density);\n",
"_, bins, _ = plt.hist(\n",
" x_syn_1d.numpy(),\n",
" bins=\"auto\",\n",
" color=\"C0\",\n",
" label=\"synthetic data via MNLE\",\n",
" density=density,\n",
")\n",
"# Plot real data on top,\n",
"plt.hist(xos.numpy(), bins=bins, alpha=0.5, color=\"C1\", label=\"real data via DDM\", density=density)\n",
"plt.hist(\n",
" xos.numpy(),\n",
" bins=bins,\n",
" alpha=0.5,\n",
" color=\"C1\",\n",
" label=\"real data via DDM\",\n",
" density=density,\n",
")\n",
"# and the MNLE likelihoods\n",
"plt.plot(test_data, mnle_likelihoods.exp(), label=\"MNLE likelihood\", ls=\"-\", c=colors[2]);\n",
"plt.plot(\n",
" test_data, mnle_likelihoods.exp(), label=\"MNLE likelihood\", ls=\"-\", c=colors[2]\n",
")\n",
"plt.legend();"
]
},
Expand Down Expand Up @@ -322,10 +346,10 @@
"outputs": [],
"source": [
"# Using the trainer object we can obtain the posterior object with the desired MCMC options for inference.\n",
"posterior = trainer.build_posterior(mcmc_method=\"slice_np_vectorized\", \n",
" mcmc_parameters=dict(init_strategy=\"sir\", \n",
" num_chains=10, \n",
" warmup_steps=100))"
"posterior = trainer.build_posterior(\n",
" mcmc_method=\"slice_np_vectorized\",\n",
" mcmc_parameters=dict(init_strategy=\"sir\", num_chains=10, warmup_steps=100),\n",
")"
]
},
{
Expand Down Expand Up @@ -402,25 +426,27 @@
}
],
"source": [
"limits = [[-2, 2], [0.5, 2.0], [.3, .7], [.2, 1.8]]\n",
"limits = [[-2, 2], [0.5, 2.0], [0.3, 0.7], [0.2, 1.8]]\n",
"\n",
"fig, ax1 = pairplot([reference_posterior_samples, posterior_samples],\n",
" points=sbibm.get_task(\"ddm\").get_true_parameters(obs), \n",
" limits=limits, \n",
" ticks=limits, \n",
" samples_colors=colors[:2], \n",
" diag=\"kde\",\n",
" upper=\"contour\",\n",
" kde_offdiag=dict(bw_method=\"scott\", bins=50),\n",
" contour_offdiag=dict(levels=[0.01], percentile=False),\n",
" points_offdiag=dict(marker=\"+\", markersize=10), \n",
" points_colors=[\"k\"], \n",
" labels=[r\"$v$\", r\"$a$\", r\"$w$\", r\"$\\tau$\"])\n",
"fig, ax1 = pairplot(\n",
" [reference_posterior_samples, posterior_samples],\n",
" points=sbibm.get_task(\"ddm\").get_true_parameters(obs),\n",
" limits=limits,\n",
" ticks=limits,\n",
" samples_colors=colors[:2],\n",
" diag=\"kde\",\n",
" upper=\"contour\",\n",
" kde_offdiag=dict(bw_method=\"scott\", bins=50),\n",
" contour_offdiag=dict(levels=[0.01], percentile=False),\n",
" points_offdiag=dict(marker=\"+\", markersize=10),\n",
" points_colors=[\"k\"],\n",
" labels=[r\"$v$\", r\"$a$\", r\"$w$\", r\"$\\tau$\"],\n",
")\n",
"\n",
"plt.sca(ax1[0, 0])\n",
"plt.legend([\"Reference\", \"MNLE\", r\"Ground truth $\\theta$\"], \n",
" bbox_to_anchor=(-.1, -2.9), \n",
" loc=2);"
"plt.legend(\n",
" [\"Reference\", \"MNLE\", r\"Ground truth $\\theta$\"], bbox_to_anchor=(-0.1, -2.9), loc=2\n",
");"
]
},
{
Expand Down Expand Up @@ -484,7 +510,10 @@
],
"source": [
"# Infer with MCMC\n",
"posterior_samples = posterior.sample((1000,), x=x_o, )"
"posterior_samples = posterior.sample(\n",
" (1000,),\n",
" x=x_o,\n",
")"
]
},
{
Expand All @@ -509,23 +538,25 @@
"source": [
"reference_posterior_samples = task.get_reference_posterior_samples(obs)[:1000]\n",
"\n",
"fig, ax1 = pairplot([reference_posterior_samples, posterior_samples],\n",
" points=sbibm.get_task(\"ddm\").get_true_parameters(obs), \n",
" limits=limits, \n",
" ticks=limits, \n",
" samples_colors=colors[:2], \n",
" diag=\"kde\",\n",
" upper=\"contour\",\n",
" kde_offdiag=dict(bw_method=\"scott\", bins=50),\n",
" contour_offdiag=dict(levels=[0.01], percentile=False),\n",
" points_offdiag=dict(marker=\"+\", markersize=10), \n",
" points_colors=[\"k\"], \n",
" labels=[r\"$v$\", r\"$a$\", r\"$w$\", r\"$\\tau$\"])\n",
"fig, ax1 = pairplot(\n",
" [reference_posterior_samples, posterior_samples],\n",
" points=sbibm.get_task(\"ddm\").get_true_parameters(obs),\n",
" limits=limits,\n",
" ticks=limits,\n",
" samples_colors=colors[:2],\n",
" diag=\"kde\",\n",
" upper=\"contour\",\n",
" kde_offdiag=dict(bw_method=\"scott\", bins=50),\n",
" contour_offdiag=dict(levels=[0.01], percentile=False),\n",
" points_offdiag=dict(marker=\"+\", markersize=10),\n",
" points_colors=[\"k\"],\n",
" labels=[r\"$v$\", r\"$a$\", r\"$w$\", r\"$\\tau$\"],\n",
")\n",
"\n",
"plt.sca(ax1[0, 0])\n",
"plt.legend([\"Reference\", \"MNLE\", r\"Ground truth $\\theta$\"], \n",
" bbox_to_anchor=(-.1, -2.9), \n",
" loc=2);"
"plt.legend(\n",
" [\"Reference\", \"MNLE\", r\"Ground truth $\\theta$\"], bbox_to_anchor=(-0.1, -2.9), loc=2\n",
");"
]
},
{
Expand Down
Loading

0 comments on commit 7f16f1d

Please sign in to comment.