
Commit 0a416f8

glemaitre and ArturoAmorQ committed
[ci skip] ENH Improve wording in group-aware cross-validation notebook (#776)
Co-authored-by: ArturoAmorQ <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 1531294 commit 0a416f8

95 files changed: +864 −829 lines changed
28 binary files changed (not shown).

_sources/python_scripts/cross_validation_grouping.py

+47 −29 lines changed
@@ -7,9 +7,8 @@
 
 # %% [markdown]
 # # Sample grouping
-# We are going to linger into the concept of sample groups. As in the previous
-# section, we will give an example to highlight some surprising results. This
-# time, we will use the handwritten digits dataset.
+# In this notebook we present the concept of **sample groups**. We use the
+# handwritten digits dataset to highlight some surprising results.
 
 # %%
 from sklearn.datasets import load_digits
@@ -18,8 +17,17 @@
 data, target = digits.data, digits.target
 
 # %% [markdown]
-# We will recreate the same model used in the previous notebook: a logistic
-# regression classifier with a preprocessor to scale the data.
+# We create a model consisting of a logistic regression classifier with a
+# preprocessor to scale the data.
+#
+# ```{note}
+# Here we use a `MinMaxScaler` as we know that each pixel's gray-scale value is
+# strictly bounded between 0 (white) and 16 (black). This makes `MinMaxScaler`
+# more suited in this case than `StandardScaler`, as some pixels consistently
+# have low variance (pixels at the borders might almost always be zero if most
+# digits are centered in the image). Using `StandardScaler` can then result in
+# very high scaled values due to division by a small number.
+# ```
 
 # %%
 from sklearn.preprocessing import MinMaxScaler
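
To make the note above concrete, here is a minimal sketch (not part of the commit; the toy `train`/`test` arrays are invented for illustration). A border pixel that is almost always zero in the training images has a tiny standard deviation, so `StandardScaler` maps an inked test pixel to a very large value, while `MinMaxScaler` keeps it on the original bounded scale:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# Toy border pixel: faintly inked in only 1 of 100 hypothetical training images.
train = np.zeros((100, 1))
train[0] = 1.0
test = np.array([[16.0]])  # fully inked pixel in a new image

print(MinMaxScaler().fit(train).transform(test))    # [[16.]]  -> stays on the pixel scale
print(StandardScaler().fit(train).transform(test))  # ~[[160.7]] -> blown up by the tiny std
```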
@@ -29,8 +37,10 @@
 model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))
 
 # %% [markdown]
-# We will use the same baseline model. We will use a `KFold` cross-validation
-# without shuffling the data at first.
+# The idea is to compare the estimated generalization performance using
+# different cross-validation techniques and see how such estimates are
+# impacted by underlying data structures. We first use a `KFold`
+# cross-validation without shuffling the data.
 
 # %%
 from sklearn.model_selection import cross_val_score, KFold
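
The evaluation cells themselves fall outside this hunk. For context, a sketch of what they presumably look like, assuming the variable name `test_score_with_shuffling` for the shuffled variant and the exact arguments as guesses (only `test_score_no_shuffling` is visible in this diff):

```python
# KFold without shuffling: folds follow the original sample order.
cv = KFold(shuffle=False)
test_score_no_shuffling = cross_val_score(model, data, target, cv=cv)

# KFold with shuffling: samples are permuted before splitting.
cv = KFold(shuffle=True, random_state=0)
test_score_with_shuffling = cross_val_score(model, data, target, cv=cv)

print(
    "Accuracy without shuffling: "
    f"{test_score_no_shuffling.mean():.3f} ± {test_score_no_shuffling.std():.3f}"
)
```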
@@ -59,9 +69,9 @@
 )
 
 # %% [markdown]
-# We observe that shuffling the data improves the mean accuracy. We could go a
-# little further and plot the distribution of the testing score. We can first
-# concatenate the test scores.
+# We observe that shuffling the data improves the mean accuracy. We can go a
+# little further and plot the distribution of the testing score. For this
+# purpose, we concatenate the test scores.
 
 # %%
 import pandas as pd
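
The construction of `all_scores` is likewise elided; the `).T` context line in the next hunk suggests a transposed DataFrame built from the two score arrays. A sketch, with hypothetical row labels:

```python
# Collect both cross-validation runs into one DataFrame, one column per strategy.
all_scores = pd.DataFrame(
    [test_score_no_shuffling, test_score_with_shuffling],
    index=["KFold without shuffling", "KFold with shuffling"],  # hypothetical labels
).T
```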
@@ -72,29 +82,29 @@
 ).T
 
 # %% [markdown]
-# Let's plot the distribution now.
+# Let's now plot the score distributions.
 
 # %%
 import matplotlib.pyplot as plt
 
-all_scores.plot.hist(bins=10, edgecolor="black", alpha=0.7)
+all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
 plt.xlim([0.8, 1.0])
 plt.xlabel("Accuracy score")
 plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
 _ = plt.title("Distribution of the test scores")
 
 # %% [markdown]
-# The cross-validation testing error that uses the shuffling has less variance
-# than the one that does not impose any shuffling. It means that some specific
-# fold leads to a low score in this case.
+# Shuffling the data results in a higher cross-validated test accuracy with
+# less variance compared to when the data is not shuffled. It means that,
+# without shuffling, some specific folds lead to a low score.
 
 # %%
 print(test_score_no_shuffling)
 
 # %% [markdown]
-# Thus, there is an underlying structure in the data that shuffling will break
-# and get better results. To get a better understanding, we should read the
-# documentation shipped with the dataset.
+# Thus, shuffling the data breaks the underlying structure and makes the
+# classification task easier for our model. To get a better understanding, we
+# can read the dataset description in more detail:
 
 # %%
 print(digits.DESCR)
@@ -165,7 +175,7 @@
 groups[lb:up] = group_id
 
 # %% [markdown]
-# We can check the grouping by plotting the indices linked to writer ids.
+# We can check the grouping by plotting the indices linked to writers' ids.
 
 # %%
 plt.plot(groups)
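
The `groups[lb:up] = group_id` context line above belongs to an elided cell that assigns a writer id to each contiguous block of samples. A sketch of that pattern, with hypothetical boundary values (the notebook derives the real ones from the dataset documentation printed earlier):

```python
from itertools import count
import numpy as np

# Hypothetical writer boundaries, for illustration only; in a real run the
# last boundary should equal len(target) so every sample gets a group.
writer_boundaries = [0, 130, 256, 386]
groups = np.zeros_like(target)
for group_id, lb, up in zip(count(), writer_boundaries[:-1], writer_boundaries[1:]):
    groups[lb:up] = group_id  # samples lb..up-1 come from the same writer
```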
@@ -176,8 +186,9 @@
 _ = plt.title("Underlying writer groups existing in the target")
 
 # %% [markdown]
-# Once we group the digits by writer, we can use cross-validation to take this
-# information into account: the class containing `Group` should be used.
+# Once we group the digits by writer, we can incorporate this information into
+# the cross-validation process by using group-aware variations of the strategies
+# we have explored in this course, for example, the `GroupKFold` strategy.
 
 # %%
 from sklearn.model_selection import GroupKFold
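
The `GroupKFold` evaluation cell is also outside the hunk; a sketch, assuming the same `cross_val_score` pattern as the earlier evaluations (the scores variable name is hypothetical):

```python
# GroupKFold keeps all samples of a given writer in the same fold, so the
# model is always tested on writers it has never seen during training.
cv = GroupKFold()
test_score_group_kfold = cross_val_score(
    model, data, target, groups=groups, cv=cv
)
print(
    "Accuracy with GroupKFold: "
    f"{test_score_group_kfold.mean():.3f} ± {test_score_group_kfold.std():.3f}"
)
```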
@@ -191,10 +202,12 @@
 )
 
 # %% [markdown]
-# We see that this strategy is less optimistic regarding the model
-# generalization performance. However, this is the most reliable if our goal is
-# to make handwritten digits recognition writers independent. Besides, we can as
-# well see that the standard deviation was reduced.
+# We see that this strategy leads to a lower generalization performance than the
+# other two techniques. However, this is the most reliable estimate if our goal
+# is to evaluate the capabilities of the model to generalize to new unseen
+# writers. In this sense, shuffling the dataset (or alternatively using the
+# writers' ids as a new feature) would lead the model to memorize the different
+# writers' particular handwriting.
 
 # %%
 all_scores = pd.DataFrame(
@@ -207,13 +220,18 @@
 ).T
 
 # %%
-all_scores.plot.hist(bins=10, edgecolor="black", alpha=0.7)
+all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
 plt.xlim([0.8, 1.0])
 plt.xlabel("Accuracy score")
 plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
 _ = plt.title("Distribution of the test scores")
 
 # %% [markdown]
-# As a conclusion, it is really important to take any sample grouping pattern
-# into account when evaluating a model. Otherwise, the results obtained will be
-# over-optimistic in regards with reality.
+# In conclusion, accounting for any sample grouping patterns is crucial when
+# assessing a model’s ability to generalize to new groups. Without this
+# consideration, the results may appear overly optimistic compared to the actual
+# performance.
+#
+# The interested reader can learn about other group-aware cross-validation
+# techniques in the [scikit-learn user
+# guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data).
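
Not part of the commit: for readers following the user-guide link above, a minimal sketch of two other group-aware splitters described there, reusing the `model`, `data`, `target`, and `groups` names from this notebook:

```python
from sklearn.model_selection import GroupShuffleSplit, LeaveOneGroupOut

# One fold per writer group: train on all other writers, test on the held-out one.
cv = LeaveOneGroupOut()
scores = cross_val_score(model, data, target, groups=groups, cv=cv)

# Randomized splits that keep each writer entirely in train or entirely in test.
cv = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
scores = cross_val_score(model, data, target, groups=groups, cv=cv)
```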
