Skip to content

Commit

Permalink
Merge pull request #31 from fehiepsi/networkx
Browse files Browse the repository at this point in the history
Use networkx instead of causalgraphicalmodels
  • Loading branch information
fehiepsi authored Feb 16, 2024
2 parents fd76466 + 7de912d commit be837c0
Show file tree
Hide file tree
Showing 23 changed files with 344 additions and 299 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,3 @@ jobs:
- name: Test with nbval
run: |
find notebooks -maxdepth 1 -name "[01][456789]*.ipynb" | sort -n | xargs pytest -vx --nbval-lax --durations=0
26 changes: 26 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
on:
push:
branches: [ master ]

jobs:
nikola_build:

name: 'Deploy Nikola to GitHub Pages'
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Build and Deploy Nikola
run: |
python -m pip install --upgrade pip
pip install "Nikola[extras]"
git fetch origin gh-pages
cd site
nikola github_deploy
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ output
.doit.db*
site/pages/*.ipynb
site/listings/*.py

4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ I am a fan of the book [*Statistical Rethinking*](https://xcelab.net/rm/statisti

## Installation

The following tools are used for some analysis and visualizations: [arviz](https://arviz-devs.github.io/arviz/) for [posteriors](https://en.wikipedia.org/wiki/Posterior_probability), [causalgraphicalmodels](https://github.com/ijmbarr/causalgraphicalmodels) and [daft](https://docs.daft-pgm.org/en/latest/) for [causal graphs](https://en.wikipedia.org/wiki/Causal_graph), and (optional) [ete3](http://etetoolkit.org/) for [phylogenetic trees](https://en.wikipedia.org/wiki/Phylogenetic_tree).
The following tools are used for some analysis and visualizations: [arviz](https://arviz-devs.github.io/arviz/) for [posteriors](https://en.wikipedia.org/wiki/Posterior_probability), [networkx](https://networkx.org/) and [daft](https://docs.daft-pgm.org/en/latest/) for [causal graphs](https://en.wikipedia.org/wiki/Causal_graph), and (optional) [ete3](http://etetoolkit.org/) for [phylogenetic trees](https://en.wikipedia.org/wiki/Phylogenetic_tree).

```sh
pip install numpyro arviz causalgraphicalmodels daft
pip install numpyro arviz daft networkx
```

## Excercises
Expand Down
8 changes: 4 additions & 4 deletions notebooks/00_preface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q numpyro arviz causalgraphicalmodels daft"
"!pip install -q numpyro arviz"
]
},
{
Expand Down Expand Up @@ -231,14 +231,14 @@
},
"source": [
"```sh\n",
"pip install numpyro arviz causalgraphicalmodels daft\n",
"pip install numpyro arviz daft networkx\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -252,7 +252,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.11.6"
},
"varInspector": {
"cols": {
Expand Down
4 changes: 2 additions & 2 deletions notebooks/01_the_golem_of_prague.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -31,7 +31,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.11.6"
},
"varInspector": {
"cols": {
Expand Down
8 changes: 4 additions & 4 deletions notebooks/02_small_worlds_and_large_worlds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q numpyro arviz causalgraphicalmodels daft"
"!pip install -q numpyro arviz"
]
},
{
Expand Down Expand Up @@ -224,7 +224,7 @@
"params = svi_result.params\n",
"\n",
"# display summary of quadratic approximation\n",
"samples = guide.sample_posterior(random.PRNGKey(1), params, (1000,))\n",
"samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))\n",
"numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)"
]
},
Expand Down Expand Up @@ -323,7 +323,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -337,7 +337,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions notebooks/03_sampling_the_imaginary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q numpyro arviz causalgraphicalmodels daft"
"!pip install -q numpyro arviz"
]
},
{
Expand Down Expand Up @@ -816,7 +816,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
"version": "3.11.6"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
36 changes: 20 additions & 16 deletions notebooks/04_geocentric_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q numpyro arviz causalgraphicalmodels daft"
"!pip install -q numpyro arviz"
]
},
{
Expand Down Expand Up @@ -904,7 +904,7 @@
}
],
"source": [
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))\n",
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))\n",
"print_summary(samples, 0.89, False)"
]
},
Expand Down Expand Up @@ -978,7 +978,7 @@
"svi = SVI(model, m4_2, optim.Adam(1), Trace_ELBO(), height=d2.height.values)\n",
"svi_result = svi.run(random.PRNGKey(0), 2000)\n",
"p4_2 = svi_result.params\n",
"samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, (1000,))\n",
"samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, sample_shape=(1000,))\n",
"print_summary(samples, 0.89, False)"
]
},
Expand Down Expand Up @@ -1007,7 +1007,7 @@
}
],
"source": [
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))\n",
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))\n",
"vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))\n",
"vcov"
]
Expand Down Expand Up @@ -1064,7 +1064,7 @@
}
],
"source": [
"post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (int(1e4),))\n",
"post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(int(1e4),))\n",
"{latent: list(post[latent][:6]) for latent in post}"
]
},
Expand Down Expand Up @@ -1369,7 +1369,7 @@
}
],
"source": [
"samples = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"samples = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"samples.pop(\"mu\")\n",
"print_summary(samples, 0.89, False)"
]
Expand Down Expand Up @@ -1429,7 +1429,7 @@
],
"source": [
"az.plot_pair(d2[[\"weight\", \"height\"]].to_dict(orient=\"list\"))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"a_map = jnp.mean(post[\"a\"])\n",
"b_map = jnp.mean(post[\"b\"])\n",
"x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)\n",
Expand Down Expand Up @@ -1464,7 +1464,7 @@
}
],
"source": [
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"{latent: list(post[latent].reshape(-1)[:5]) for latent in post}"
]
},
Expand Down Expand Up @@ -1539,7 +1539,7 @@
],
"source": [
"# extract 20 samples from the posterior\n",
"post = mN.sample_posterior(random.PRNGKey(1), pN, (20,))\n",
"post = mN.sample_posterior(random.PRNGKey(1), pN, sample_shape=(20,))\n",
"\n",
"# display raw data and sample size\n",
"ax = az.plot_pair(dN[[\"weight\", \"height\"]].to_dict(orient=\"list\"))\n",
Expand Down Expand Up @@ -1568,7 +1568,8 @@
"metadata": {},
"outputs": [],
"source": [
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"post.pop(\"mu\")\n",
"mu_at_50 = post[\"a\"] + post[\"b\"] * (50 - xbar)"
]
},
Expand Down Expand Up @@ -1797,7 +1798,7 @@
"metadata": {},
"outputs": [],
"source": [
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"mu_link = lambda weight: post[\"a\"] + post[\"b\"] * (weight - xbar)\n",
"weight_seq = jnp.arange(start=25, stop=71, step=1)\n",
"mu = vmap(mu_link)(weight_seq).T\n",
Expand Down Expand Up @@ -1829,6 +1830,7 @@
}
],
"source": [
"post.pop(\"mu\")\n",
"sim_height = Predictive(m4_3.model, post, return_sites=[\"height\"])(\n",
" random.PRNGKey(2), weight_seq, None\n",
")[\"height\"]\n",
Expand Down Expand Up @@ -1905,6 +1907,7 @@
"outputs": [],
"source": [
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(int(1e4),))\n",
"post.pop(\"mu\")\n",
"sim_height = Predictive(m4_3.model, post, return_sites=[\"height\"])(\n",
" random.PRNGKey(2), weight_seq, None\n",
")[\"height\"]\n",
Expand All @@ -1924,7 +1927,7 @@
"metadata": {},
"outputs": [],
"source": [
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
"weight_seq = jnp.arange(25, 71)\n",
"sim_height = vmap(\n",
" lambda i, weight: dist.Normal(\n",
Expand Down Expand Up @@ -2126,7 +2129,7 @@
}
],
"source": [
"samples = m4_5.sample_posterior(random.PRNGKey(1), p4_5, (1000,))\n",
"samples = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))\n",
"print_summary({k: v for k, v in samples.items() if k != \"mu\"}, 0.89, False)"
]
},
Expand All @@ -2145,7 +2148,8 @@
"source": [
"weight_seq = jnp.linspace(start=-2.2, stop=2, num=30)\n",
"pred_dat = {\"weight_s\": weight_seq, \"weight_s2\": weight_seq**2}\n",
"post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, (1000,))\n",
"post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))\n",
"post.pop(\"mu\")\n",
"predictive = Predictive(m4_5.model, post)\n",
"mu = predictive(random.PRNGKey(2), **pred_dat)[\"mu\"]\n",
"mu_mean = jnp.mean(mu, 0)\n",
Expand Down Expand Up @@ -2479,7 +2483,7 @@
}
],
"source": [
"post = m4_7.sample_posterior(random.PRNGKey(1), p4_7, (1000,))\n",
"post = m4_7.sample_posterior(random.PRNGKey(1), p4_7, sample_shape=(1000,))\n",
"w = jnp.mean(post[\"w\"], 0)\n",
"plt.subplot(\n",
" xlim=(d2.year.min(), d2.year.max()),\n",
Expand Down Expand Up @@ -2578,7 +2582,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
"version": "3.11.6"
},
"toc": {
"base_numbering": 1,
Expand Down
182 changes: 71 additions & 111 deletions notebooks/05_the_many_variables_and_the_spurious_waffles.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit be837c0

Please sign in to comment.