Skip to content

Commit 900c6f7

Browse files
author
ArturoAmorQ
committed
FIX Make cross-validation figures less inexact
1 parent 8124c5b commit 900c6f7

3 files changed

+13
-6
lines changed
21.9 KB
Loading
36.5 KB
Loading

figures/plot_parameter_tuning_cv.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,12 @@ def plot_cv_indices(cv, X, y, ax, lw=50):
5858
Patch(color=cmap_cv(0.5)),
5959
Patch(color=cmap_cv(0.02)),
6060
],
61-
["Testing samples", "Training samples", "Validation samples"],
62-
loc=(1.02, 0.7),
61+
[
62+
"Testing samples\n(reserved for\nfinal evaluation)",
63+
"Training samples",
64+
"Validation samples",
65+
],
66+
loc=(1.02, 0.5),
6367
)
6468
return ax
6569

@@ -82,7 +86,6 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
8286

8387
# Generate the training/testing visualizations for each CV split
8488
for ii, (train_outer, test_outer) in enumerate(splits_outer):
85-
8689
splits_inner = list(cv_inner.split(train_outer))
8790
n_splits_inner = len(splits_inner)
8891

@@ -116,7 +119,7 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
116119
)
117120
yticklabels = list(range(n_splits_outer))
118121
ax.set(
119-
yticks=n_splits_inner*np.arange(n_splits_outer) + 0.5,
122+
yticks=n_splits_inner * np.arange(n_splits_outer) + 0.5,
120123
yticklabels=yticklabels,
121124
xlabel="Sample index",
122125
ylabel="CV outer iteration",
@@ -129,8 +132,12 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
129132
Patch(color=cmap_cv(0.5)),
130133
Patch(color=cmap_cv(0.02)),
131134
],
132-
["Testing samples", "Training samples", "Validation samples"],
133-
loc=(1.06, .93),
135+
[
136+
"Testing samples\n(reserved for\nouter evaluation)",
137+
"Training samples",
138+
"Validation samples",
139+
],
140+
loc=(1.06, 0.85),
134141
)
135142
return ax
136143

0 commit comments

Comments
 (0)