|
| 1 | +# --- |
| 2 | +# jupyter: |
| 3 | +# kernelspec: |
| 4 | +# display_name: Python 3 |
| 5 | +# name: python3 |
| 6 | +# --- |
| 7 | + |
| 8 | +# %% [markdown] |
| 9 | +# # The Adult census dataset |
| 10 | +# |
| 11 | +# [This dataset](http://www.openml.org/d/1590) is a collection of demographic |
| 12 | +# information for the adult population as of 1994 in the USA. The prediction |
| 13 | +# task is to predict whether a person is earning a high or low revenue in |
| 14 | +# USD/year. |
| 15 | +# |
| 16 | +# The column named **class** is the target variable (i.e., the variable which we |
| 17 | +# want to predict). The two possible classes are `" <=50K"` (low-revenue) and |
| 18 | +# `" >50K"` (high-revenue). |
| 19 | +# |
| 20 | +# Before drawing any conclusions based on its statistics or the predictions of |
| 21 | +# models trained on it, remember that this dataset is not only outdated, but is |
| 22 | +# also not representative of the US population. In fact, the original data |
| 23 | +# contains a feature named `fnlwgt` that encodes the number of units in the |
| 24 | +# target population that the responding unit represents. |
| 25 | +# |
| 26 | +# First we load the dataset. We keep only some columns of interest to ease the |
| 27 | +# plotting. |
| 28 | + |
| 29 | +# %% |
| 30 | +import pandas as pd |
| 31 | + |
| 32 | +adult_census = pd.read_csv("../datasets/adult-census.csv") |
| 33 | +columns_to_plot = [ |
| 34 | + "age", |
| 35 | + "education-num", |
| 36 | + "capital-loss", |
| 37 | + "capital-gain", |
| 38 | + "hours-per-week", |
| 39 | + "relationship", |
| 40 | + "class", |
| 41 | +] |
| 42 | +target_name = "class" |
| 43 | +target = adult_census[target_name] |
| 44 | + |
| 45 | +# %% [markdown] |
| 46 | +# We explore this dataset in the first module's notebook "First look at our |
| 47 | +# dataset", where we provide a first intuition on how the data is structured. |
| 48 | +# There, we use a seaborn pairplot to visualize pairwise relationships between |
| 49 | +# the numerical variables in the dataset. This tool aligns scatter plots for every pair |
| 50 | +# of variables and histograms for the plots in the |
| 51 | +# diagonal of the array. |
| 52 | +# |
| 53 | +# This approach is limited: |
| 54 | +# - Pair plots can only deal with numerical features and; |
| 55 | +# - by observing pairwise interactions we end up with a two-dimensional |
| 56 | +# projection of a multi-dimensional feature space, which can lead to a wrong |
| 57 | +# interpretation of the individual impact of a feature. |
| 58 | +# |
| 59 | +# Here we explore with some more detail the relation between features using |
| 60 | +# plotly `Parcoords`. |
| 61 | + |
| 62 | +# %% |
| 63 | +import plotly.graph_objects as go |
| 64 | +from sklearn.preprocessing import LabelEncoder |
| 65 | + |
| 66 | +le = LabelEncoder() |
| 67 | + |
| 68 | + |
| 69 | +def generate_dict(col): |
| 70 | + """Check if column is categorical and generate the appropriate dict""" |
| 71 | + if adult_census[col].dtype == "object": # Categorical column |
| 72 | + encoded = le.fit_transform(adult_census[col]) |
| 73 | + return { |
| 74 | + "tickvals": list(range(len(le.classes_))), |
| 75 | + "ticktext": list(le.classes_), |
| 76 | + "label": col, |
| 77 | + "values": encoded, |
| 78 | + } |
| 79 | + else: # Numerical column |
| 80 | + return {"label": col, "values": adult_census[col]} |
| 81 | + |
| 82 | + |
| 83 | +plot_list = [generate_dict(col) for col in columns_to_plot] |
| 84 | + |
| 85 | +fig = go.Figure( |
| 86 | + data=go.Parcoords( |
| 87 | + line=dict( |
| 88 | + color=le.fit_transform(target), |
| 89 | + colorscale="Viridis", |
| 90 | + ), |
| 91 | + dimensions=plot_list, |
| 92 | + ) |
| 93 | +) |
| 94 | +fig.show() |
| 95 | + |
| 96 | +# %% [markdown] |
| 97 | +# The `Parcoords` plot is quite similar to the parallel coordinates plot that we |
| 98 | +# present in the module on hyperparameters tuning in this mooc. It display the |
| 99 | +# values of the features on different columns while the target class is color |
| 100 | +# coded. Thus, we are able to quickly inspect if there is a range of values for |
| 101 | +# a certain feature which is leading to a particular result. |
| 102 | +# |
| 103 | +# As in the parallel coordinates plot, it is possible to select one or more |
| 104 | +# ranges of values by clicking and holding on any axis of the plot. You can then |
| 105 | +# slide (move) the range selection and cross two selections to see the |
| 106 | +# intersections. You can undo a selection by clicking once again on the same |
| 107 | +# axis. |
| 108 | +# |
| 109 | +# In particular for this dataset we observe that values of `"age"` lower to 20 |
| 110 | +# years are quite predictive of low-income, regardless of the value of other |
| 111 | +# features. Similarly, a `"capital-loss"` above `4000` seems to lead to |
| 112 | +# low-income. |
| 113 | +# |
| 114 | +# In this case we can additionaly observe that the variables `"age"` and |
| 115 | +# `"relationship"` are more correlated than the others: |
| 116 | + |
| 117 | +# %% |
| 118 | +import matplotlib.pyplot as plt |
| 119 | +import numpy as np |
| 120 | + |
| 121 | +from scipy.cluster import hierarchy |
| 122 | +from scipy.spatial.distance import squareform |
| 123 | +from scipy.stats import spearmanr |
| 124 | + |
| 125 | +X = adult_census[columns_to_plot].drop(columns="class") |
| 126 | +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) |
| 127 | +corr = spearmanr(X).correlation |
| 128 | + |
| 129 | +# Ensure the correlation matrix is symmetric |
| 130 | +corr = (corr + corr.T) / 2 |
| 131 | +np.fill_diagonal(corr, 1) |
| 132 | + |
| 133 | +# We convert the correlation matrix to a distance matrix before performing |
| 134 | +# hierarchical clustering using Ward's linkage. |
| 135 | +distance_matrix = 1 - np.abs(corr) |
| 136 | +dist_linkage = hierarchy.ward(squareform(distance_matrix)) |
| 137 | +dendro = hierarchy.dendrogram( |
| 138 | + dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90 |
| 139 | +) |
| 140 | +dendro_idx = np.arange(0, len(dendro["ivl"])) |
| 141 | + |
| 142 | +ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]]) |
| 143 | +ax2.set_xticks(dendro_idx) |
| 144 | +ax2.set_yticks(dendro_idx) |
| 145 | +ax2.set_xticklabels(dendro["ivl"], rotation="vertical") |
| 146 | +ax2.set_yticklabels(dendro["ivl"]) |
| 147 | +_ = fig.tight_layout() |
0 commit comments