ENH Improve wording in group-aware cross-validation notebook #776

Merged · 5 commits · May 17, 2024
`python_scripts/cross_validation_grouping.py` (47 additions, 29 deletions)

# %% [markdown]
# # Sample grouping
# In this notebook we present the concept of **sample groups**. We use the
# handwritten digits dataset to highlight some surprising results.

# %%
from sklearn.datasets import load_digits

digits = load_digits()
data, target = digits.data, digits.target

# %% [markdown]
# We create a model consisting of a logistic regression classifier with a
# preprocessor to scale the data.
#
# ```{note}
# Here we use a `MinMaxScaler` as we know that each pixel's gray-scale value
# is strictly bounded between 0 (white) and 16 (black). This makes
# `MinMaxScaler` more suited in this case than `StandardScaler`, as some
# pixels consistently have low variance (pixels at the borders might almost
# always be zero if most digits are centered in the image). For such pixels,
# `StandardScaler` can produce very large scaled values, since it divides by a
# small standard deviation.
# ```

# %%
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))
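
# %% [markdown]
# As a quick illustration of the note above, we can scale a synthetic,
# border-like pixel column that is almost always zero. This is a minimal
# sketch: the data below is made up for illustration and is not taken from
# the digits dataset.

# %%
import numpy as np
from sklearn.preprocessing import StandardScaler

# 99 zeros and a single 1: a low-variance column on the 0-16 gray-scale
pixel = np.array([[0.0]] * 99 + [[1.0]])
print("MinMaxScaler max:", MinMaxScaler().fit_transform(pixel).max())
print("StandardScaler max:", StandardScaler().fit_transform(pixel).max())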

# %% [markdown]
# The idea is to compare the estimated generalization performance using
# different cross-validation techniques, and to see how such estimates are
# impacted by the underlying data structure. We first use a `KFold`
# cross-validation without shuffling the data.

# %%
from sklearn.model_selection import cross_val_score, KFold
cv = KFold(shuffle=False)
test_score_no_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)
print(
    "The average accuracy is "
    f"{test_score_no_shuffling.mean():.3f} ± {test_score_no_shuffling.std():.3f}"
)

# %% [markdown]
# Now, let's repeat the experiment by shuffling the data within the
# cross-validation.

# %%
cv = KFold(shuffle=True)
test_score_with_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)
print(
    "The average accuracy is "
    f"{test_score_with_shuffling.mean():.3f} ± {test_score_with_shuffling.std():.3f}"
)

# %% [markdown]
# We observe that shuffling the data improves the mean accuracy. We can go a
# little further and plot the distribution of the testing score. For this
# purpose we concatenate the test scores.

# %%
import pandas as pd

all_scores = pd.DataFrame(
    [test_score_no_shuffling, test_score_with_shuffling],
    index=["KFold without shuffling", "KFold with shuffling"],
).T

# %% [markdown]
# Let's now plot the score distributions.

# %%
import matplotlib.pyplot as plt

all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
plt.xlim([0.8, 1.0])
plt.xlabel("Accuracy score")
plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
_ = plt.title("Distribution of the test scores")

# %% [markdown]
# Shuffling the data results in a higher cross-validated test accuracy with
# less variance compared to when the data is not shuffled. It means that,
# without shuffling, some specific folds lead to a low score:

# %%
print(test_score_no_shuffling)

# %% [markdown]
# Thus, shuffling the data breaks the underlying structure and makes the
# classification task easier for our model. To get a better understanding, we
# can read the dataset description in more detail:

# %%
print(digits.DESCR)
groups[lb:up] = group_id

# %% [markdown]
# We can check the grouping by plotting the indices linked to writers' ids.

# %%
plt.plot(groups)
plt.xlabel("Target index")
plt.ylabel("Writer index")
_ = plt.title("Underlying writer groups existing in the target")

# %% [markdown]
# Once we group the digits by writer, we can incorporate this information into
# the cross-validation process by using group-aware variations of the strategies
# we have explored in this course, for example, the `GroupKFold` strategy.

# %%
from sklearn.model_selection import GroupKFold

cv = GroupKFold()
test_score = cross_val_score(model, data, target, groups=groups, cv=cv, n_jobs=2)
print(
    f"The average accuracy is {test_score.mean():.3f} ± {test_score.std():.3f}"
)
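
# %% [markdown]
# To see what group awareness means in practice, we can inspect a single
# `GroupKFold` split and check that the writers present in the train set and
# in the test set are disjoint. This is a minimal sketch and the variable
# names below are ours:

# %%
import numpy as np

train_indices, test_indices = next(GroupKFold().split(data, target, groups=groups))
print("Writers in the train set:", np.unique(groups[train_indices]))
print("Writers in the test set:", np.unique(groups[test_indices]))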

# %% [markdown]
# We see that this strategy leads to a lower estimate of the generalization
# performance than the other two techniques. However, this is the most
# reliable estimate if our goal is to evaluate the capabilities of the model
# to generalize to new, unseen writers. In this sense, shuffling the dataset
# (or alternatively using the writers' ids as a new feature) would lead the
# model to memorize each writer's particular handwriting.

# %%
all_scores = pd.DataFrame(
    [test_score_no_shuffling, test_score_with_shuffling, test_score],
    index=[
        "KFold without shuffling",
        "KFold with shuffling",
        "KFold with groups",
    ],
).T

# %%
all_scores.plot.hist(bins=16, edgecolor="black", alpha=0.7)
plt.xlim([0.8, 1.0])
plt.xlabel("Accuracy score")
plt.legend(bbox_to_anchor=(1.05, 0.8), loc="upper left")
_ = plt.title("Distribution of the test scores")

# %% [markdown]
# In conclusion, accounting for any sample grouping patterns is crucial when
# assessing a model’s ability to generalize to new groups. Without this
# consideration, the results may appear overly optimistic compared to the actual
# performance.
#
# The interested reader can learn about other group-aware cross-validation
# techniques in the [scikit-learn user
# guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data).
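#
# For instance, `LeaveOneGroupOut` holds out all the samples of one writer at a
# time, yielding one score per writer. A minimal sketch, reusing the model and
# `groups` defined above:

# %%
from sklearn.model_selection import LeaveOneGroupOut

scores = cross_val_score(
    model, data, target, groups=groups, cv=LeaveOneGroupOut(), n_jobs=2
)
print(
    "The average accuracy across writers is "
    f"{scores.mean():.3f} ± {scores.std():.3f}"
)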