Commit 7f4c1a4

statified
1 parent 3aa4372 commit 7f4c1a4

2 files changed, +11 -1 lines changed

Diff for: src/splitobs.jl

+4 -1
@@ -53,7 +53,7 @@ If `shuffle=true`, randomly permute the observations before splitting.
 A random number generator `rng` can be optionally passed as the first argument.

 If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
-The observations will be split in a way that the proportion of each label is preserved in each subset.
+The observations will be split in such a way that the proportion of each label is preserved in each subset.

 Supports any datatype implementing [`numobs`](@ref).

@@ -75,6 +75,9 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff

 julia> vec(test[1]) .+ 100 == test[2]
 true
+
+julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
+([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
 ```
 """
 splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)

Diff for: test/splitobs.jl

+7
@@ -99,4 +99,11 @@ end
 d1, d2 = splitobs(data, at=0.25, stratified=data.b)
 @test d1.b == [0,1,1]
 @test d2.b == [0,0,0,1,1,1,1]
+
+d1, d2 = splitobs(data, at=0., stratified=data.b)
+@test d1.b == []
+@test d2.b == [0,0,0,0,1,1,1,1,1,1]
+d1, d2 = splitobs(data, at=1., stratified=data.b)
+@test d1.b == [0,0,0,0,1,1,1,1,1,1]
+@test d2.b == []
 end
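
For readers unfamiliar with stratified splitting, the behaviour described in the docstring can be sketched as a per-label split of the observation indices. This is a minimal illustration under stated assumptions, not the code in `src/splitobs.jl`: the helper name `stratified_split_indices` is hypothetical, the rounding rule (`round(Int, at * n)` per label group) is an assumption, and shuffling is omitted.

```julia
# Hypothetical helper, NOT part of MLUtils: illustrates the idea of a stratified split.
# Assumption: each label group contributes round(at * group_size) observations
# to the first subset; the library's actual rounding rule may differ.
function stratified_split_indices(labels::AbstractVector, at::Real)
    0 <= at <= 1 || throw(ArgumentError("`at` must lie in [0, 1]"))
    first_idx, second_idx = Int[], Int[]
    for label in unique(labels)                # labels in order of first appearance
        idx = findall(isequal(label), labels)  # positions holding this label
        k = round(Int, at * length(idx))       # share of this group for the first subset
        append!(first_idx, idx[1:k])
        append!(second_idx, idx[k+1:end])
    end
    return first_idx, second_idx
end

# Same labels as the docstring example: 4 zeros followed by 6 ones.
stratified_split_indices([0,0,0,0,1,1,1,1,1,1], 0.5)
# → ([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
```

With `at=0.0` every group sends zero indices to the first subset, and with `at=1.0` every group sends all of them, which mirrors the boundary cases the new tests cover.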
