Skip to content

Commit da07fd6

Browse files
add stratified options to splitobs (#195)
* stratobs * cleanup * cleanup * statified
1 parent 0031c03 commit da07fd6

File tree

4 files changed

+62
-21
lines changed

4 files changed

+62
-21
lines changed

src/obstransform.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,7 @@ accomplish that, which means that the return value is likely of a
213213
different type than `data`.
214214
215215
Optionally, a random number generator `rng` can be passed as the
216-
first argument.
217-
218-
The optional parameter `rng` allows one to specify the
219-
random number generator used for shuffling. This is useful when
220-
reproducible results are desired.
216+
first argument.
221217
222218
For this function to work, the type of `data` must implement
223219
[`numobs`](@ref) and [`getobs`](@ref).

src/splitobs.jl

+42-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
splitobs(n::Int; at) -> Tuple
33
44
Compute the indices for two or more disjoint subsets of
5-
the range `1:n` with splits given by `at`.
5+
the range `1:n` with split sizes determined by `at`.
66
77
# Examples
88
@@ -18,13 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at)
1818

1919
_splitobs(n::Int, at::Integer) = _splitobs(n::Int, at / n)
2020
_splitobs(n::Int, at::NTuple{N, <:Integer}) where {N} = _splitobs(n::Int, at ./ n)
21-
2221
_splitobs(n::Int, at::Tuple{}) = (1:n,)
2322

2423
function _splitobs(n::Int, at::AbstractFloat)
2524
0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
26-
n1 = clamp(round(Int, at*n), 0, n)
27-
(1:n1, n1+1:n)
25+
n1 = round(Int, n * at)
26+
return (1:n1, n1+1:n)
2827
end
2928

3029
function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
@@ -37,22 +36,24 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
3736
return (a, rest...)
3837
end
3938

39+
4040
"""
41-
splitobs([rng], data; at, shuffle=false) -> Tuple
41+
splitobs([rng,] data; at, shuffle=false, stratified=nothing) -> Tuple
4242
4343
Partition the `data` into two or more subsets.
4444
45-
When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
46-
47-
When `at` is an integer, it specifies the number of observations in the first subset.
48-
49-
When `at` is a tuple, entries specifies the number or proportion in each subset, except
45+
The argument `at` specifies how to split the data:
46+
- When `at` is a number between 0 and 1, it specifies the proportion in the first subset.
47+
- When `at` is an integer, it specifies the number of observations in the first subset.
48+
- When `at` is a tuple, its entries specify the number or proportion in each subset, except
5049
for the last, which will contain the remaining observations.
5150
The number of returned subsets is `length(at)+1`.
5251
5352
If `shuffle=true`, randomly permute the observations before splitting.
5453
A random number generator `rng` can be optionally passed as the first argument.
5554
55+
If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
56+
The observations will be split in such a way that the proportion of each label is preserved in each subset.
5657
5758
Supports any datatype implementing [`numobs`](@ref).
5859
@@ -74,14 +75,41 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff
7475
7576
julia> vec(test[1]) .+ 100 == test[2]
7677
true
78+
79+
julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
80+
([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
7781
```
7882
"""
7983
splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)
8084

81-
function splitobs(rng::AbstractRNG, data; at, shuffle::Bool=false)
85+
function splitobs(rng::AbstractRNG, data; at,
86+
shuffle::Bool=false,
87+
stratified::Union{Nothing,AbstractVector}=nothing)
88+
n = numobs(data)
89+
at = _normalize_at(n, at)
8290
if shuffle
83-
data = shuffleobs(rng, data)
91+
perm = randperm(rng, n)
92+
data = obsview(data, perm) # same as shuffleobs(rng, data), but make it explicit to keep perm
8493
end
85-
n = numobs(data)
86-
return map(idx -> obsview(data, idx), splitobs(n; at))
94+
if stratified !== nothing
95+
@assert length(stratified) == n
96+
if shuffle
97+
stratified = stratified[perm]
98+
end
99+
idxs_groups = group_indices(stratified)
100+
idxs_splits = ntuple(i -> Int[], length(at)+1)
101+
for (lbl, idxs) in idxs_groups
102+
new_idxs_splits = splitobs(idxs; at, shuffle=false)
103+
for i in 1:length(idxs_splits)
104+
append!(idxs_splits[i], new_idxs_splits[i])
105+
end
106+
end
107+
else
108+
idxs_splits = splitobs(n; at)
109+
end
110+
return map(idxs -> obsview(data, idxs), idxs_splits)
87111
end
112+
113+
_normalize_at(n, at::Integer) = at / n
114+
_normalize_at(n, at::NTuple{N, <:Integer}) where N = at ./ n
115+
_normalize_at(n, at) = at

src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,15 @@ function batch(xs::Vector{<:NamedTuple})
380380
all_keys = [sort(collect(keys(x))) for x in xs]
381381
ks = all_keys[1]
382382
@assert all(==(ks), all_keys) "Cannot batch named tuples with different keys"
383-
NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
383+
return NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
384384
end
385385

386386
function batch(xs::Vector{<:Dict})
387387
@assert length(xs) > 0 "Input should be non-empty"
388388
all_keys = [sort(collect(keys(x))) for x in xs]
389389
ks = all_keys[1]
390390
@assert all(==(ks), all_keys) "cannot batch dicts with different keys"
391-
Dict(k => batch([x[k] for x in xs]) for k in ks)
391+
return Dict(k => batch([x[k] for x in xs]) for k in ks)
392392
end
393393

394394
"""

test/splitobs.jl

+17
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,20 @@ end
9090
p2, _ = splitobs(rng, data, at=3, shuffle=true)
9191
@test p1 == p2
9292
end
93+
94+
@testset "stratified" begin
95+
data = (a=zeros(Float32, 2, 10), b=[0,0,0,0,1,1,1,1,1,1])
96+
d1, d2 = splitobs(data, at=0.5, stratified=data.b)
97+
@test d1.b == [0,0,1,1,1]
98+
@test d2.b == [0,0,1,1,1]
99+
d1, d2 = splitobs(data, at=0.25, stratified=data.b)
100+
@test d1.b == [0,1,1]
101+
@test d2.b == [0,0,0,1,1,1,1]
102+
103+
d1, d2 = splitobs(data, at=0., stratified=data.b)
104+
@test d1.b == []
105+
@test d2.b == [0,0,0,0,1,1,1,1,1,1]
106+
d1, d2 = splitobs(data, at=1., stratified=data.b)
107+
@test d1.b == [0,0,0,0,1,1,1,1,1,1]
108+
@test d2.b == []
109+
end

0 commit comments

Comments
 (0)