@@ -58,8 +58,12 @@ def plot_cv_indices(cv, X, y, ax, lw=50):
58
58
Patch (color = cmap_cv (0.5 )),
59
59
Patch (color = cmap_cv (0.02 )),
60
60
],
61
- ["Testing samples" , "Training samples" , "Validation samples" ],
62
- loc = (1.02 , 0.7 ),
61
+ [
62
+ "Testing samples\n (reserved for\n final evaluation)" ,
63
+ "Training samples" ,
64
+ "Validation samples" ,
65
+ ],
66
+ loc = (1.02 , 0.5 ),
63
67
)
64
68
return ax
65
69
@@ -82,7 +86,6 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
82
86
83
87
# Generate the training/testing visualizations for each CV split
84
88
for ii , (train_outer , test_outer ) in enumerate (splits_outer ):
85
-
86
89
splits_inner = list (cv_inner .split (train_outer ))
87
90
n_splits_inner = len (splits_inner )
88
91
@@ -116,7 +119,7 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
116
119
)
117
120
yticklabels = list (range (n_splits_outer ))
118
121
ax .set (
119
- yticks = n_splits_inner * np .arange (n_splits_outer ) + 0.5 ,
122
+ yticks = n_splits_inner * np .arange (n_splits_outer ) + 0.5 ,
120
123
yticklabels = yticklabels ,
121
124
xlabel = "Sample index" ,
122
125
ylabel = "CV outer iteration" ,
@@ -129,8 +132,12 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
129
132
Patch (color = cmap_cv (0.5 )),
130
133
Patch (color = cmap_cv (0.02 )),
131
134
],
132
- ["Testing samples" , "Training samples" , "Validation samples" ],
133
- loc = (1.06 , .93 ),
135
+ [
136
+ "Testing samples\n (reserved for\n outer evaluation)" ,
137
+ "Training samples" ,
138
+ "Validation samples" ,
139
+ ],
140
+ loc = (1.06 , 0.85 ),
134
141
)
135
142
return ax
136
143
0 commit comments