2
2
splitobs(n::Int; at) -> Tuple
3
3
4
4
Compute the indices for two or more disjoint subsets of
5
- the range `1:n` with splits given by `at`.
5
+ the range `1:n` with split sizes determined by `at`.
6
6
7
7
# Examples
8
8
@@ -18,13 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at)
18
18
19
19
_splitobs (n:: Int , at:: Integer ) = _splitobs (n:: Int , at / n)
20
20
_splitobs (n:: Int , at:: NTuple{N, <:Integer} ) where {N} = _splitobs (n:: Int , at ./ n)
21
-
22
21
_splitobs (n:: Int , at:: Tuple{} ) = (1 : n,)
23
22
24
23
function _splitobs (n:: Int , at:: AbstractFloat )
25
24
0 <= at <= 1 || throw (ArgumentError (" the parameter \" at\" must be in interval (0, 1)" ))
26
- n1 = clamp ( round (Int, at * n), 0 , n )
27
- (1 : n1, n1+ 1 : n)
25
+ n1 = round (Int, n * at )
26
+ return (1 : n1, n1+ 1 : n)
28
27
end
29
28
30
29
function _splitobs (n:: Int , at:: NTuple{N,<:AbstractFloat} ) where N
@@ -37,22 +36,24 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
37
36
return (a, rest... )
38
37
end
39
38
39
+
40
40
"""
41
- splitobs([rng], data; at, shuffle=false) -> Tuple
41
+ splitobs([rng,] data; at, shuffle=false, stratified=nothing ) -> Tuple
42
42
43
43
Partition the `data` into two or more subsets.
44
44
45
- When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
46
-
47
- When `at` is an integer, it specifies the number of observations in the first subset.
48
-
49
- When `at` is a tuple, entries specifies the number or proportion in each subset, except
45
+ The argument `at` specifies how to split the data:
46
+ - When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
47
+ - When `at` is an integer, it specifies the number of observations in the first subset.
48
+ - When `at` is a tuple, entries specifies the number or proportion in each subset, except
50
49
for the last which will contain the remaning observations.
51
50
The number of returned subsets is `length(at)+1`.
52
51
53
52
If `shuffle=true`, randomly permute the observations before splitting.
54
53
A random number generator `rng` can be optionally passed as the first argument.
55
54
55
+ If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
56
+ The observations will be split in such a way that the proportion of each label is preserved in each subset.
56
57
57
58
Supports any datatype implementing [`numobs`](@ref).
58
59
@@ -74,14 +75,41 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff
74
75
75
76
julia> vec(test[1]) .+ 100 == test[2]
76
77
true
78
+
79
+ julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
80
+ ([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
77
81
```
78
82
"""
79
83
splitobs (data; kws... ) = splitobs (Random. default_rng (), data; kws... )
80
84
81
- function splitobs (rng:: AbstractRNG , data; at, shuffle:: Bool = false )
85
+ function splitobs (rng:: AbstractRNG , data; at,
86
+ shuffle:: Bool = false ,
87
+ stratified:: Union{Nothing,AbstractVector} = nothing )
88
+ n = numobs (data)
89
+ at = _normalize_at (n, at)
82
90
if shuffle
83
- data = shuffleobs (rng, data)
91
+ perm = randperm (rng, n)
92
+ data = obsview (data, perm) # same as shuffleobs(rng, data), but make it explicit to keep perm
84
93
end
85
- n = numobs (data)
86
- return map (idx -> obsview (data, idx), splitobs (n; at))
94
+ if stratified != = nothing
95
+ @assert length (stratified) == n
96
+ if shuffle
97
+ stratified = stratified[perm]
98
+ end
99
+ idxs_groups = group_indices (stratified)
100
+ idxs_splits = ntuple (i -> Int[], length (at)+ 1 )
101
+ for (lbl, idxs) in idxs_groups
102
+ new_idxs_splits = splitobs (idxs; at, shuffle= false )
103
+ for i in 1 : length (idxs_splits)
104
+ append! (idxs_splits[i], new_idxs_splits[i])
105
+ end
106
+ end
107
+ else
108
+ idxs_splits = splitobs (n; at)
109
+ end
110
+ return map (idxs -> obsview (data, idxs), idxs_splits)
87
111
end
112
+
113
+ _normalize_at (n, at:: Integer ) = at / n
114
+ _normalize_at (n, at:: NTuple{N, <:Integer} ) where N = at ./ n
115
+ _normalize_at (n, at) = at
0 commit comments