Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 2, 2025
1 parent ca8e10a commit 3aa4372
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
splitobs(n::Int; at) -> Tuple
Compute the indices for two or more disjoint subsets of
the range `1:n` with splits given by `at`.
the range `1:n` with split sizes determined by `at`.
# Examples
Expand All @@ -18,16 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at)

_splitobs(n::Int, at::Integer) = _splitobs(n::Int, at / n)
_splitobs(n::Int, at::NTuple{N, <:Integer}) where {N} = _splitobs(n::Int, at ./ n)

_splitobs(n::Int, at::Tuple{}) = (1:n,)


function _splitobs(n::Int, at::AbstractFloat)
0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
n1 = floor(Int, n * at)
delta = n*at - n1
# TODO add random rounding
(1:n1, n1+1:n)
n1 = round(Int, n * at)
return (1:n1, n1+1:n)
end

function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
Expand All @@ -40,8 +36,9 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
return (a, rest...)
end


"""
splitobs([rng], data; at, shuffle=false, stratified=nothing) -> Tuple
splitobs([rng,] data; at, shuffle=false, stratified=nothing) -> Tuple
Partition the `data` into two or more subsets.
Expand Down
10 changes: 10 additions & 0 deletions test/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,13 @@ end
p2, _ = splitobs(rng, data, at=3, shuffle=true)
@test p1 == p2
end

@testset "stratified" begin
data = (a=zeros(Float32, 2, 10), b=[0,0,0,0,1,1,1,1,1,1])
d1, d2 = splitobs(data, at=0.5, stratified=data.b)
@test d1.b == [0,0,1,1,1]
@test d2.b == [0,0,1,1,1]
d1, d2 = splitobs(data, at=0.25, stratified=data.b)
@test d1.b == [0,1,1]
@test d2.b == [0,0,0,1,1,1,1]
end

0 comments on commit 3aa4372

Please sign in to comment.