Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 2, 2025
1 parent d0241cb commit ca8e10a
Showing 1 changed file with 0 additions and 95 deletions.
95 changes: 0 additions & 95 deletions src/resample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,98 +211,3 @@ function undersample(rng::AbstractRNG, data::Tuple; kws...)
d, c = undersample(rng, data[1:end-1], data[end]; kws...)
return (d..., c)
end


"""
stratifiedobs([rng], data, p; [shuffle = true]) -> Tuple
Partition the dataset `data` into multiple disjoint subsets
with size proportional to the value(s) of `p`.
The observations are assignmed to a data subset using stratified sampling without replacement.
If `p` is a float between 0 and 1, then the return value
will be a tuple with two subsests in which the
first element contains the fraction of observations specified by
`p` and the second element contains the rest. In the following
code the first subset `train` will contain around 70% of the
observations and the second subset `test` the rest. The key
difference to [`splitobs`](@ref) is that the class distribution
in `y` will actively be preserved in `train` and `test`.
```julia
train_data, test_data = stratifiedobs(data, p = 0.7)
```
If `p` is a tuple of floats between 0 and 1, then additional subsets will be
created. In this example `train` will contain about 50% of the
observations, `val` will contain around 30%, and `test` the
remaining 20%.
```julia
train_data, val_data, test_data = stratifiedobs(y, p = (0.5, 0.3))
```
It is also possible to call `stratifiedobs` with multiple data
arguments as tuple, which all must have the same number of total
observations. Note that if `data` is a tuple, then it will be
assumed that the last element of the tuple contains the targets.
```julia
(X_train, y_train), (X_test, y_test) = stratifiedobs((X, y), p = 0.7)
```
The optional parameter `shuffle` determines if the resulting data
subsets should be shuffled. If `false`, then the observations in
the subsets will be grouped together according to their labels.
```julia
julia> y = ["a", "b", "b", "b", "b", "a"] # 2 imbalanced classes
6-element Array{String,1}:
"a"
"b"
"b"
"b"
"b"
"a"
julia> train, test = stratifiedobs(y, p = 0.5, shuffle = false)
(String["b","b","a"],String["b","b","a"])
```
The optional argument `rng` allows one to specify the
random number generator used for shuffling.
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref).
See also [`undersample`](@ref), [`oversample`](@ref), and [`splitobs`](@ref).
"""
function stratifiedobs(data; p = 0.7, shuffle = true, obsdim = default_obsdim(data), rng = Random.GLOBAL_RNG)
stratifiedobs(identity, data, p, shuffle, convert(ObsDimension, obsdim), rng)
end

function stratifiedobs(f, data; p = 0.7, shuffle = true, obsdim = default_obsdim(data), rng = Random.GLOBAL_RNG)
stratifiedobs(f, data, p, shuffle, convert(ObsDimension, obsdim), rng)
end

function stratifiedobs(data, p::AbstractFloat, args...)
stratifiedobs(identity, data, p, args...)
end

function stratifiedobs(data, p::NTuple{N,AbstractFloat}, args...) where N
stratifiedobs(identity, data, p, args...)
end

function stratifiedobs(rng, data, p::Union{NTuple,AbstractFloat}, stratified::AbstractVector)
# The given data is always shuffled to qualify as performing
# stratified sampling without replacement.
idxs_groups = group_indices(stratified)
idxs_splits = ntuple(i -> Int[], length(p)+1)
for (lbl, idxs) in idxs_groups
new_idxs_splits = splitobs(rng, idxs, at=p)
for i in 1:length(idxs_splits)
append!(idxs_splits[i], new_idxs_splits[i])
end
end
return map(idx -> obsview(data, idx), idxs_splits)
end

0 comments on commit ca8e10a

Please sign in to comment.