Skip to content

Commit 2101332

Browse files
author
ArturoAmorQ
committed
ENH Add Adult census dataset description
1 parent f39b464 commit 2101332

File tree

3 files changed

+148
-12
lines changed

3 files changed

+148
-12
lines changed

jupyter-book/_toc.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ parts:
213213
- file: appendix/datasets_intro
214214
sections:
215215
- file: python_scripts/trees_dataset
216-
- file: appendix/adult_census_description
216+
- file: python_scripts/datasets_adult_census
217217
- file: python_scripts/datasets_california_housing
218218
- file: python_scripts/datasets_ames_housing
219219
- file: python_scripts/datasets_blood_transfusion

jupyter-book/appendix/adult_census_description.md

-11
This file was deleted.
+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# ---
2+
# jupyter:
3+
# kernelspec:
4+
# display_name: Python 3
5+
# name: python3
6+
# ---
7+
8+
# %% [markdown]
9+
# # The Adult census dataset
10+
#
11+
# [This dataset](http://www.openml.org/d/1590) is a collection of demographic
12+
# information for the adult population as of 1994 in the USA. The prediction
13+
# task is to predict whether a person is earning a high or low revenue in
14+
# USD/year.
15+
#
16+
# The column named **class** is the target variable (i.e., the variable which we
17+
# want to predict). The two possible classes are `" <=50K"` (low-revenue) and
18+
# `" >50K"` (high-revenue).
19+
#
20+
# Before drawing any conclusions based on its statistics or the predictions of
21+
# models trained on it, remember that this dataset is not only outdated, but is
22+
# also not representative of the US population. In fact, the original data
23+
# contains a feature named `fnlwgt` that encodes the number of units in the
24+
# target population that the responding unit represents.
25+
#
26+
# First we load the dataset. We keep only some columns of interest to ease the
27+
# plotting.
28+
29+
# %%
30+
import pandas as pd
31+
32+
adult_census = pd.read_csv("../datasets/adult-census.csv")
33+
columns_to_plot = [
34+
"age",
35+
"education-num",
36+
"capital-loss",
37+
"capital-gain",
38+
"hours-per-week",
39+
"relationship",
40+
"class",
41+
]
42+
target_name = "class"
43+
target = adult_census[target_name]
44+
45+
# %% [markdown]
46+
# We explore this dataset in the first module's notebook "First look at our
47+
# dataset", where we provide a first intuition on how the data is structured.
48+
# There, we use a seaborn pairplot to visualize pairwise relationships between
49+
# the numerical variables in the dataset. This tool aligns scatter plots for every pair
50+
# of variables and histograms for the plots in the
51+
# diagonal of the array.
52+
#
53+
# This approach is limited:
54+
# - Pair plots can only deal with numerical features and;
55+
# - by observing pairwise interactions we end up with a two-dimensional
56+
# projection of a multi-dimensional feature space, which can lead to a wrong
57+
# interpretation of the individual impact of a feature.
58+
#
59+
# Here we explore with some more detail the relation between features using
60+
# plotly `Parcoords`.
61+
62+
# %%
63+
import plotly.graph_objects as go
64+
from sklearn.preprocessing import LabelEncoder
65+
66+
le = LabelEncoder()
67+
68+
69+
def generate_dict(col):
70+
"""Check if column is categorical and generate the appropriate dict"""
71+
if adult_census[col].dtype == "object": # Categorical column
72+
encoded = le.fit_transform(adult_census[col])
73+
return {
74+
"tickvals": list(range(len(le.classes_))),
75+
"ticktext": list(le.classes_),
76+
"label": col,
77+
"values": encoded,
78+
}
79+
else: # Numerical column
80+
return {"label": col, "values": adult_census[col]}
81+
82+
83+
plot_list = [generate_dict(col) for col in columns_to_plot]
84+
85+
fig = go.Figure(
86+
data=go.Parcoords(
87+
line=dict(
88+
color=le.fit_transform(target),
89+
colorscale="Viridis",
90+
),
91+
dimensions=plot_list,
92+
)
93+
)
94+
fig.show()
95+
96+
# %% [markdown]
97+
# The `Parcoords` plot is quite similar to the parallel coordinates plot that we
98+
# present in the module on hyperparameters tuning in this mooc. It display the
99+
# values of the features on different columns while the target class is color
100+
# coded. Thus, we are able to quickly inspect if there is a range of values for
101+
# a certain feature which is leading to a particular result.
102+
#
103+
# As in the parallel coordinates plot, it is possible to select one or more
104+
# ranges of values by clicking and holding on any axis of the plot. You can then
105+
# slide (move) the range selection and cross two selections to see the
106+
# intersections. You can undo a selection by clicking once again on the same
107+
# axis.
108+
#
109+
# In particular for this dataset we observe that values of `"age"` lower to 20
110+
# years are quite predictive of low-income, regardless of the value of other
111+
# features. Similarly, a `"capital-loss"` above `4000` seems to lead to
112+
# low-income.
113+
#
114+
# In this case we can additionaly observe that the variables `"age"` and
115+
# `"relationship"` are more correlated than the others:
116+
117+
# %%
118+
import matplotlib.pyplot as plt
119+
import numpy as np
120+
121+
from scipy.cluster import hierarchy
122+
from scipy.spatial.distance import squareform
123+
from scipy.stats import spearmanr
124+
125+
X = adult_census[columns_to_plot].drop(columns="class")
126+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
127+
corr = spearmanr(X).correlation
128+
129+
# Ensure the correlation matrix is symmetric
130+
corr = (corr + corr.T) / 2
131+
np.fill_diagonal(corr, 1)
132+
133+
# We convert the correlation matrix to a distance matrix before performing
134+
# hierarchical clustering using Ward's linkage.
135+
distance_matrix = 1 - np.abs(corr)
136+
dist_linkage = hierarchy.ward(squareform(distance_matrix))
137+
dendro = hierarchy.dendrogram(
138+
dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90
139+
)
140+
dendro_idx = np.arange(0, len(dendro["ivl"]))
141+
142+
ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]])
143+
ax2.set_xticks(dendro_idx)
144+
ax2.set_yticks(dendro_idx)
145+
ax2.set_xticklabels(dendro["ivl"], rotation="vertical")
146+
ax2.set_yticklabels(dendro["ivl"])
147+
_ = fig.tight_layout()

0 commit comments

Comments
 (0)