Skip to content

Commit

Permalink
formulae clarified with small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
PoorvaGarg committed Aug 29, 2024
1 parent ff0739e commit ca71be5
Showing 1 changed file with 73 additions and 48 deletions.
121 changes: 73 additions & 48 deletions docs/source/explainable_sir.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -129,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -169,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -238,7 +238,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -279,7 +279,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -302,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -360,7 +360,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have our full-fledged model of SIR dynamics along with interventions, we have a complete list of random variables in question. In our explanation we will abbreviate them as follows. `S` - susceptible, `I` - infected, `R` - recovered, `ld` - lockdown, `m` - masking, `le` - lockdown efficiency, `me` - mask efficiency, `je` - joint efficiency, `os` - overshoot, and `oth` - overshoot is too high. We use these notations in the rest of the notebook to describe the probabilities we are computing."
"Now that we have our full-fledged model of SIR dynamics along with interventions, we have a complete list of random variables in question. In our explanation we will abbreviate them as follows. `S` - susceptible, `I` - infected, `R` - recovered, `l` - the effect of intervention, `beta`, `gamma` - the parameters of the SIR dynamics model, `ld` - lockdown, `m` - masking, `le` - lockdown efficiency, `me` - mask efficiency, `je` - joint efficiency, `os` - overshoot, and `oth` - overshoot is too high. We use these notations in the rest of the notebook to describe the probabilities we are computing."
]
},
{
Expand Down Expand Up @@ -389,9 +389,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variables in the model: dict_keys(['lockdown', 'mask', 'beta', 'gamma', 'lockdown_efficiency', 'mask_efficiency', 'joint_efficiency', 'S', 'I', 'R', 'l', 'overshoot', 'os_too_high'])\n"
]
}
],
"source": [
"# conditioning (as opposed to intervening) is sufficient for\n",
"# propagating the changes, as the decisions are upstream from ds\n",
Expand Down Expand Up @@ -429,7 +437,16 @@
"lockdown_samples = lockdown_predictive()\n",
"\n",
"predictive = Predictive(policy_model, num_samples=num_samples, parallel=True)\n",
"samples = predictive()"
"samples = predictive()\n",
"\n",
"print(\"Variables in the model:\", samples.keys())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the above list of variables match our list of variables earlier when we constructed the full-fledged SIR model."
]
},
{
Expand Down Expand Up @@ -581,6 +598,17 @@
"## Causal Explanations using SearchForExplanation\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we dive into the code below, let us first define some notation. We use small case abbreviations to refer to the value of the variables under consideration. For example, $\\mathit{ld}$ refers to `lockdown=1` and $\\mathit{ld}'$ refers to `lockdown=0`. We place interventions in the subscripts, i.e. $\\mathit{os}_{\\mathit{ld}}$ refers to the `overshoot` under the intervention that `lockdown=1`. Later on in the notebook, we also employ contexts that are kept fixed in the intervened worlds. We place these contexts in the superscript. For example, $\\mathit{os}_{\\mathit{ld}}^{\\mathit{me}}$ refers to the variable `overshoot` when `lockdown` was intervened to be 1 and `mask_efficiency` was kept fixed at its factual value. \n",
"\n",
"We use $P(.)$ to denote the distribution described by the model (`policy_model` in this notebook). We also induce a distribution over the sets of interventions and the sets of context nodes kept fixed. We denote these distributions by $P_a(.)$ and $P_w(.)$ respectively. As an example, $P_a(\\{ld\\})$ refers to the probability that the set of interventions under consideration is $\\{ld\\}$. These distributions are determined using the parameters `antecedent_bias` and `witness_bias` given to the handler `SearchForExplanation`. For more details, please refer to the [documentation](https://basisresearch.github.io/chirho/explainable.html#chirho.explainable.handlers.explanation.SearchForExplanation)\n",
"\n",
"Now let's dive into the code and we use this notation to describe the quantities we are computing. "
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -679,7 +707,7 @@
" witness_bias=0.2,\n",
")(policy_model)\n",
"\n",
"logp, importance_tr, mwc_imp, log_weights = importance_infer(num_samples=10000)(query)()\n",
"logp, importance_tr, mwc_imp, log_weights = importance_infer(num_samples=num_samples)(query)()\n",
"print(torch.exp(logp))"
]
},
Expand All @@ -698,7 +726,7 @@
"metadata": {},
"outputs": [],
"source": [
"def compute_prob(trace, log_weights, mask, mwc):\n",
"def compute_prob(trace, log_weights, mask):\n",
" mask_intervened = torch.ones(\n",
" trace.nodes[\"__cause____antecedent_lockdown\"][\"value\"].shape\n",
" ).bool()\n",
Expand All @@ -718,12 +746,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We specifically compute the following four probabilities. In each of the computations, we condition on lockdown and masking actually being implemented in the factual workd. Then we take an interventional setting and compute the probability that this setting has a causal power over the outcome. For instance, in 1., we assume lockdown (`ld`) and masking (`m`) have been implemented, and we ask about the joint prbability that both (a) removing both interventions, i.e. intervening for both `ld` and `m` to not happen - which we mark by the apostrophe - would lead to `oth` not happening, $\\mathit{oth}'_{\\mathit{ld}', m'}$, and (b) intervening for both to happend would lead to `oth`, $\\mathit{oth}_{\\mathit{ld}, m}$ (which, given the stochasticity between these interventions and the outcome, might be non-trivial).\n",
"We specifically compute the following four probabilities. In each of the computations, we condition on lockdown and masking actually being implemented in the factual world. Then we take an interventional setting and compute the probability that this setting has a causal power over the outcome. For instance, in 1., we assume lockdown (`ld`) and masking (`m`) have been implemented, and we ask about the joint prbability that both (a) removing both interventions, i.e. intervening for both `ld` and `m` to not happen - which we mark by the apostrophe - would lead to `oth` not happening, $\\mathit{oth}'_{\\mathit{ld}', m'}$, and (b) intervening for both to happend would lead to `oth`, $\\mathit{oth}_{\\mathit{ld}, m}$ (which, given the stochasticity between these interventions and the outcome, is non-trivial). Note that in computing these probabilities, we also marginalize over all the possible contexts to be kept fixed, i.e. all possible subsets of $W = \\{\\mathit{le}, \\mathit{me}\\}$\n",
"\n",
"1. $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{oth}^w_{\\mathit{ld}, m}, \\mathit{oth}'^w_{\\mathit{ld}', m'} | \\mathit{ld}, m)$\n",
"\n",
"2. $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{oth}^w_{\\mathit{ld}}, \\mathit{oth}'^w_{\\mathit{ld}'} | \\mathit{ld}, m)$\n",
"\n",
"3. $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{oth}^w_{m}, \\mathit{oth}'^w_{m'} | \\mathit{ld}, m)$\n",
"\n",
"1. $P(\\mathit{oth}_{\\mathit{ld}, m}, \\mathit{oth}'_{\\mathit{ld}', m'} | \\mathit{ld}, m)$\n",
"2. $P(\\mathit{oth}_{\\mathit{ld}}, \\mathit{oth}'_{\\mathit{ld}'} | \\mathit{ld}, m)$\n",
"3. $P(\\mathit{oth}_{m}, \\mathit{oth}'_{m'} | \\mathit{ld}, m)$\n",
"4. $P(\\mathit{oth}, \\mathit{oth}' | \\mathit{ld}, m)$"
"4. $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{oth}^w, \\mathit{oth}'^w | \\mathit{ld}, m)$"
]
},
{
Expand All @@ -748,31 +779,27 @@
" importance_tr,\n",
" log_weights,\n",
" {\"__cause____antecedent_lockdown\": 0, \"__cause____antecedent_mask\": 0, \"mask\": 1, \"lockdown\": 1},\n",
" mwc_imp\n",
")\n",
"\n",
"# only lockdown executed, masking preempted\n",
"compute_prob(\n",
" importance_tr,\n",
" log_weights,\n",
" {\"__cause____antecedent_lockdown\": 0, \"__cause____antecedent_mask\": 1, \"mask\": 1, \"lockdown\": 1},\n",
" mwc_imp\n",
")\n",
"\n",
"# only masking executed, lockdown preempted\n",
"compute_prob(\n",
" importance_tr,\n",
" log_weights,\n",
" {\"__cause____antecedent_lockdown\": 1, \"__cause____antecedent_mask\": 0, \"mask\": 1, \"lockdown\": 1},\n",
" mwc_imp\n",
")\n",
"\n",
"# no interventions executed\n",
"compute_prob(\n",
" importance_tr,\n",
" log_weights,\n",
" {\"__cause____antecedent_lockdown\": 1, \"__cause____antecedent_mask\": 1, \"mask\": 1, \"lockdown\": 1},\n",
" mwc_imp\n",
")"
]
},
Expand All @@ -796,9 +823,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also compute degree of responsibilities assigned to both lockdown and mask as follows. To compute degree of responsibility of lockdown and mask, we specifically compute the probability that these factors were a part of the cause of the outcome. Mathematically, we compute the following:\n",
"1. Degree of responsibility of lockdown: $\\sum_{\\mathit{ld} \\in C} P(\\mathit{oth}_{C}, \\mathit{oth}'_{C'} | \\mathit{ld}, m)$\n",
"2. Degree of responsibility of mask: $\\sum_{\\mathit{m} \\in C} P(\\mathit{oth}_{C}, \\mathit{oth}'_{C'} | \\mathit{ld}, m)$"
"We can also compute degree of responsibilities assigned to both lockdown and mask as follows. To compute degree of responsibility of lockdown and mask, we specifically compute the probability that these factors were a part of the cause of the outcome. Mathematically, we compute the following where $W = \\{\\mathit{le}, \\mathit{me}\\}$ and $C = \\{\\mathit{ld}, m\\}$:\n",
"\n",
"1. Degree of responsibility of lockdown: $\\sum_{w \\subseteq W} \\sum_{\\mathit{ld} \\in C} P_w(w) P_a(C | \\mathit{ld} \\in C) \\cdot P(\\mathit{oth}^w_{C}, \\mathit{oth}'^w_{C'} | \\mathit{ld}, m)$\n",
"\n",
"2. Degree of responsibility of mask: $\\sum_{w \\subseteq W} \\sum_{\\mathit{m} \\in C} P_w(w) P_a(C | \\mathit{m} \\in C) \\cdot P(\\mathit{oth}^w_{C}, \\mathit{oth}'^w_{C'} | \\mathit{ld}, m)$"
]
},
{
Expand All @@ -820,11 +849,11 @@
],
"source": [
"print(\"Degree of responsibility for lockdown: \")\n",
"compute_prob(importance_tr, log_weights, {\"__cause____antecedent_lockdown\": 0, \"mask\": 1, \"lockdown\": 1}, mwc_imp)\n",
"compute_prob(importance_tr, log_weights, {\"__cause____antecedent_lockdown\": 0, \"mask\": 1, \"lockdown\": 1})\n",
"print()\n",
"\n",
"print(\"Degree of responsibility for mask: \")\n",
"compute_prob(importance_tr, log_weights, {\"__cause____antecedent_mask\": 0, \"mask\": 1, \"lockdown\": 1}, mwc_imp)"
"compute_prob(importance_tr, log_weights, {\"__cause____antecedent_mask\": 0, \"mask\": 1, \"lockdown\": 1})"
]
},
{
Expand Down Expand Up @@ -1010,7 +1039,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The above histogram also takes into account the context that is being kept fixed. If `lockdown` is being intervened on, keeping `lockdown_efficiency` fixed would hinder the effect of intervention. Thus to obtain the relevant samples, we also filter for the appropriate context. Once we have filtered for the context, we take the samples and plot them as density above. The histogram above plots three quantities. It plots $P(\\mathit{os} | \\mathit{ld}, m)$ as the factual distribution of overshoot, $P(\\mathit{os}_{\\mathit{ld}'} | \\mathit{ld}, m)$ as `counterfactual_lockdown` and $P(\\mathit{os}_{\\mathit{m}'} | \\mathit{ld}, m)$ as `counterfactual_mask`. These distributions help in comparing how necessity interventions for the two antecedents affect the overshoot."
"The above histogram also takes into account the context that is being kept fixed. If `lockdown` is being intervened on, keeping `lockdown_efficiency` fixed would hinder the effect of intervention. Thus to obtain the relevant samples, we also filter for the appropriate context. Once we have filtered for the context, we take the samples and plot them as density above. The histogram above plots three quantities. It plots $P(\\mathit{os} | \\mathit{ld}, m)$ as the factual distribution of overshoot, $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{ld}'} | \\mathit{ld}, m)$ as `counterfactual_lockdown` where $W = \\{\\mathit{me}\\}$ and $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{m}'} | \\mathit{ld}, m)$ as `counterfactual_mask` where $W = \\{\\mathit{le}\\}$. These distributions help in comparing how necessity interventions for the two antecedents affect the overshoot."
]
},
{
Expand Down Expand Up @@ -1137,7 +1166,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The histogram plots three quantities. It plots $P(\\mathit{os} | \\mathit{ld}, m)$ as the factual distribution of overshoot, $P(\\mathit{os}_{\\mathit{ld}} | \\mathit{ld}, m)$ as `counterfactual_lockdown` and $P(\\mathit{os}_{\\mathit{m}} | \\mathit{ld}, m)$ as `counterfactual_mask`. Again, these distributions help in comparing how sufficiency interventions for the two antecedents affect the overshoot."
"The histogram plots three quantities. It plots $P(\\mathit{os} | \\mathit{ld}, m)$ as the factual distribution of overshoot, $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{ld}} | \\mathit{ld}, m)$ as `counterfactual_lockdown` where $W = \\{\\mathit{me}\\}$ and $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{m}} | \\mathit{ld}, m)$ as `counterfactual_mask` where $W = \\{\\mathit{le}\\}$. Again, these distributions help in comparing how sufficiency interventions for the two antecedents affect the overshoot."
]
},
{
Expand All @@ -1161,11 +1190,12 @@
}
],
"source": [
"# Collecting samples for joint distribution of overshoot under necessity and sufficiency interventions on lockdown\n",
"masks = {\n",
" \"__cause____antecedent_mask\": 1,\n",
" \"__cause____antecedent_lockdown\": 0,\n",
" \"__cause____witness_lockdown_efficiency\": 0,\n",
" \"lockdown\": 1, \"mask\": 1\n",
" \"__cause____antecedent_lockdown\": 0, # Intervening only on lockdown\n",
" \"__cause____witness_lockdown_efficiency\": 0, # Excluding lockdown efficiency fron the context candidates\n",
" \"lockdown\": 1, \"mask\": 1 # Conditioning on lockdown and masking being imposed in factual world\n",
" }\n",
"with mwc_imp:\n",
" data_nec = gather(\n",
Expand All @@ -1188,16 +1218,17 @@
" data_suff = data_suff.squeeze()[torch.nonzero(mask_tensor.squeeze())]\n",
" data_nec = data_nec.squeeze()[torch.nonzero(mask_tensor.squeeze())]\n",
"\n",
"\n",
"a = torch.transpose(torch.vstack((data_nec.squeeze(), data_suff.squeeze())), 0, 1)\n",
"a = torch.transpose(torch.vstack((data_nec.squeeze(), data_suff.squeeze())), 0, 1) # Joint distribution\n",
"hist_lockdown_2d, rough = torch.histogramdd(a, bins=[36, 36], density=True, range=[0.0, 45.0, 0.0, 45.0])\n",
"pr_lockdown = (hist_lockdown_2d[:16, 16:].sum()/hist_lockdown_2d.sum())\n",
"\n",
"\n",
"# Collecting samples for joint distribution of overshoot under necessity and sufficiency interventions on mask\n",
"masks = {\n",
" \"__cause____antecedent_mask\": 0,\n",
" \"__cause____antecedent_lockdown\": 1,\n",
" \"__cause____witness_mask_efficiency\": 0,\n",
" \"lockdown\": 1, \"mask\": 1\n",
" \"__cause____antecedent_lockdown\": 1, # Intervening only on mask\n",
" \"__cause____witness_mask_efficiency\": 0, # Excluding mask efficiency fron the context candidates\n",
" \"lockdown\": 1, \"mask\": 1 # Conditioning on lockdown and masking being imposed in factual world\n",
" }\n",
"with mwc_imp:\n",
" data_nec = gather(\n",
Expand All @@ -1219,15 +1250,7 @@
" data_suff = data_suff.squeeze()[torch.nonzero(mask_tensor.squeeze())]\n",
" data_nec = data_nec.squeeze()[torch.nonzero(mask_tensor.squeeze())]\n",
"\n",
" data_nec = data_nec.squeeze()\n",
" data_suff = data_suff.squeeze()\n",
"\n",
" sum = 0\n",
" for i in range(len(data_nec)):\n",
" if (data_nec[i] < overshoot_threshold) & (data_suff[i] > overshoot_threshold):\n",
" sum += 1\n",
"\n",
"a = torch.transpose(torch.vstack((data_nec.squeeze(), data_suff.squeeze())), 0, 1)\n",
"a = torch.transpose(torch.vstack((data_nec.squeeze(), data_suff.squeeze())), 0, 1) # Joint distribution\n",
"hist_mask_2d, _ = torch.histogramdd(a, bins = [36, 36], density=True, range=[0.0, 45.0, 0.0, 45.0])\n",
"pr_mask = (hist_mask_2d[:16, 16:].sum()/hist_mask_2d.sum())"
]
Expand Down Expand Up @@ -1273,7 +1296,7 @@
"ax.axhline(y=(os_lockdown_nec) * 36 / 45, color=\"white\", linestyle=\"--\")\n",
"\n",
"ax.legend(loc=\"upper left\")\n",
"ax.text(13, 2, 'pr(lockdown caused overshoot): %.4f' % pr_lockdown.item(), color=\"white\")\n",
"ax.text(13, 2, 'pr(lockdown has causal role over high overshoot): %.4f' % pr_lockdown.item(), color=\"white\")\n",
"\n",
"ax = axs[1]\n",
"hist_mask = hist_mask_nec.unsqueeze(1) * hist_mask_suff.unsqueeze(0)\n",
Expand All @@ -1295,7 +1318,7 @@
" label=\"Mean Overshoot\",\n",
")\n",
"ax.axhline(y=(os_mask_nec) * 36 / 45, color=\"white\", linestyle=\"--\")\n",
"ax.text(13, 2, 'pr(masking is a cause ): %.4f' % pr_mask.item(), color=\"white\")\n",
"ax.text(13, 2, 'pr(masking has causal role over high overshoot): %.4f' % pr_mask.item(), color=\"white\")\n",
"\n",
"ax.legend(loc=\"upper left\")\n",
"\n",
Expand All @@ -1306,7 +1329,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The above heatmaps plot the distributions $P(\\mathit{os}_{\\mathit{ld}}, \\mathit{os}_{\\mathit{ld}'}|\\mathit{ld, m})$ and $P(\\mathit{os}_{\\mathit{m}}, \\mathit{os}_{\\mathit{m}'}|\\mathit{ld, m})$ respectively. It is evident from the plot above that counterfactual for lockdown has more probability mass in the top right quadrant (low overshoot in the necessity world and high overshoot in the sufficient world). This gives us a more clear picture into why lockdown has more causal role in overshoot being too high as compared to masking."
"The above heatmaps plot the joint distributions arising from necessity and sufficient interventions, particularly $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{ld}}, \\mathit{os}^w_{\\mathit{ld}'}|\\mathit{ld, m})$ where $W = \\{\\mathit{me}\\}$ and $\\sum_{w \\subseteq W} P_w(w) \\cdot P(\\mathit{os}^w_{\\mathit{m}}, \\mathit{os}^w_{\\mathit{m}'}|\\mathit{ld, m})$ where $W = \\{\\mathit{le}\\}$.\n",
"\n",
"It is evident from the plot above that counterfactual for lockdown has more probability mass in the top right quadrant (low overshoot in the necessity world and high overshoot in the sufficient world). This gives us a more clear picture into why lockdown has more causal role in overshoot being too high as compared to masking."
]
},
{
Expand Down

0 comments on commit ca71be5

Please sign in to comment.