# %% [markdown]
# # Sample grouping
#
# In this notebook we present the concept of **sample groups**. We use the
# handwritten digits dataset to highlight some surprising results.

# %%
from sklearn.datasets import load_digits

digits = load_digits()
data, target = digits.data, digits.target

# %% [markdown]
# We create a model consisting of a logistic regression classifier with a
# preprocessor to scale the data.
#
# ```{note}
# Here we use a `MinMaxScaler` as we know that each pixel's gray-scale is
# strictly bounded between 0 (white) and 16 (black). This makes `MinMaxScaler`
# more suited in this case than `StandardScaler`, as some pixels consistently
# have low variance (pixels at the borders might almost always be zero if most
# digits are centered in the image). Using `StandardScaler` would then divide
# by a very small number, resulting in very large scaled values.
# ```

# %%
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))

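# %% [markdown]
# As a quick sanity check of the remark above, we can inspect the spread of
# each pixel across the dataset: some pixels barely vary at all, so
# standardizing them would divide by a near-zero number.

# %%
# per-pixel standard deviation over all images
pixel_std = data.std(axis=0)
print(f"Smallest per-pixel standard deviation: {pixel_std.min():.3f}")
print(f"Largest per-pixel standard deviation: {pixel_std.max():.3f}")
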
# %% [markdown]
# The idea is to compare the estimated generalization performance using
# different cross-validation techniques and see how such estimations are
# impacted by underlying data structures. We first use a `KFold`
# cross-validation without shuffling the data.

# %%
from sklearn.model_selection import cross_val_score, KFold

cv = KFold(shuffle=False)
test_score_no_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)
print(
    "The average accuracy is "
    f"{test_score_no_shuffling.mean():.3f} ± "
    f"{test_score_no_shuffling.std():.3f}"
)

# %% [markdown]
# Now, let's repeat the experiment by shuffling the data within the
# cross-validation.

# %%
cv = KFold(shuffle=True)
test_score_with_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)
print(
    "The average accuracy is "
    f"{test_score_with_shuffling.mean():.3f} ± "
    f"{test_score_with_shuffling.std():.3f}"
)

# %% [markdown]
# We observe that shuffling the data improves the mean accuracy. We can go a
# little further and plot the distribution of the testing scores. To do so, we
# first concatenate the test scores.

# %%
import pandas as pd

all_scores = pd.DataFrame(
    [test_score_no_shuffling, test_score_with_shuffling],
    index=["KFold without shuffling", "KFold with shuffling"],
).T

# %% [markdown]
# Let's now plot the score distributions.

# %%
import matplotlib.pyplot as plt

all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
plt.xlim([0.8, 1.0])
plt.xlabel("Accuracy score")
plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
_ = plt.title("Distribution of the test scores")

# %% [markdown]
# Shuffling the data results in a higher cross-validated test accuracy with
# less variance compared to when the data is not shuffled. It means that in
# the non-shuffled case, some specific folds lead to a low score:

# %%
print(test_score_no_shuffling)

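# %% [markdown]
# As a quick check, we can also locate the weakest fold programmatically:

# %%
# index and value of the lowest-scoring fold in the non-shuffled run
worst_fold = test_score_no_shuffling.argmin()
print(f"Fold {worst_fold} yields the lowest accuracy: "
      f"{test_score_no_shuffling[worst_fold]:.3f}")
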
# %% [markdown]
# Thus, shuffling the data breaks the underlying structure, making the
# classification task easier for our model. To get a better understanding, we
# can read the dataset description in more detail:

# %%
print(digits.DESCR)

# %% [markdown]
# The description mentions that 13 writers produced the 1797 samples. Looking
# at the target vector, we can indeed spot a repeating pattern: series of
# digits ordered from 0 to 9, written by the same writer, in blocks of roughly
# 130 samples. Samples from the same writer are therefore stored contiguously.
# We use the boundaries of these blocks (14 of them, so presumably one writer
# contributed two series) to assign a group id to each sample, so that all
# samples from a writer share the same group.

# %%
from itertools import count

import numpy as np

# lower and upper bounds of the sample indices for each writer
writer_boundaries = [
    0, 130, 256, 386, 516, 646, 776, 915, 1029,
    1157, 1287, 1415, 1545, 1667, 1797,
]
groups = np.zeros_like(target)
lower_bounds = writer_boundaries[:-1]
upper_bounds = writer_boundaries[1:]

for group_id, lb, up in zip(count(), lower_bounds, upper_bounds):
    groups[lb:up] = group_id

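# %% [markdown]
# As a quick check, we can count how many samples fall into each group; every
# writer should account for roughly 130 digits:

# %%
print(np.bincount(groups))
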
# %% [markdown]
# We can check the grouping by plotting the writer id assigned to each sample
# index.

# %%
plt.plot(groups)
plt.yticks(np.unique(groups))
plt.xticks(writer_boundaries, rotation=90)
plt.xlabel("Target index")
plt.ylabel("Writer index")
_ = plt.title("Underlying writer groups existing in the target")

# %% [markdown]
# Once we group the digits by writer, we can incorporate this information into
# the cross-validation process by using group-aware variations of the
# strategies we have explored in this course, for example, the `GroupKFold`
# strategy.

# %%
from sklearn.model_selection import GroupKFold

cv = GroupKFold()
test_score = cross_val_score(
    model, data, target, groups=groups, cv=cv, n_jobs=2
)
print(
    "The average accuracy is "
    f"{test_score.mean():.3f} ± "
    f"{test_score.std():.3f}"
)

# %% [markdown]
# We see that this strategy leads to a lower estimate of the generalization
# performance than the other two techniques. However, this is the most
# reliable estimate if our goal is to evaluate the capabilities of the model
# to generalize to new, unseen writers. In this sense, shuffling the dataset
# (or alternatively using the writers' ids as a new feature) would lead the
# model to memorize each writer's particular handwriting.

# %%
all_scores = pd.DataFrame(
    [test_score_no_shuffling, test_score_with_shuffling, test_score],
    index=[
        "KFold without shuffling",
        "KFold with shuffling",
        "KFold with groups",
    ],
).T

# %%
all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
plt.xlim([0.8, 1.0])
plt.xlabel("Accuracy score")
plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
_ = plt.title("Distribution of the test scores")

# %% [markdown]
# In conclusion, accounting for any sample grouping patterns is crucial when
# assessing a model's ability to generalize to new groups. Without this
# consideration, the results may appear overly optimistic compared to the
# actual performance.
#
# The interested reader can learn about other group-aware cross-validation
# techniques in the [scikit-learn user
# guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data).
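
# %% [markdown]
# As a quick illustration of such a variant, `LeaveOneGroupOut` holds out one
# writer entirely at each iteration, yielding one score per writer:

# %%
from sklearn.model_selection import LeaveOneGroupOut

cv = LeaveOneGroupOut()
scores = cross_val_score(model, data, target, groups=groups, cv=cv, n_jobs=2)
print(
    "The average accuracy is "
    f"{scores.mean():.3f} ± {scores.std():.3f}"
)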