PR welcome on stratified K-folds? #60
Comments
Yes, this would be a nice addition. You can probably adapt the MLDataPattern.jl code as a starting point. I would try and start with the simplest implementation first though, since part of the goal with this package is to cut down on the complexity of the MLDataPattern.jl codebase.
Feel like we should start with the stratified version of split observations first. This is the simplest/clearest implementation I could come up with that works for (i) target vector. Will need some coaching... can write a few tests for this though.

    using MLUtils  # assumed source of splitobs and shuffleobs

    function splitobs_stratified(; at, y::AbstractVector, shuffle::Bool=true)
        n_splits = length(at) + 1
        the_splits = [Int[] for s = 1:n_splits]
        for label in unique(y)
            # indices of all observations carrying this label
            ids_this_label = filter(i -> y[i] == label, 1:length(y))
            if shuffle
                ids_this_label = shuffleobs(ids_this_label)
            end
            # split this label's indices in the requested proportions
            split_this_label = splitobs(ids_this_label, at=at)
            for s = 1:n_splits
                the_splits[s] = vcat(the_splits[s], split_this_label[s])
            end
        end
        return the_splits
    end

    targets = vcat([-1 for i = 1:10], [1 for i = 1:100])
    splits = splitobs_stratified(at=(0.2, 0.5), y=targets, shuffle=true)
    for s in splits
        # ratio of positives to negatives within each split
        println(sum(targets[s] .== 1) / sum(targets[s] .== -1)) # 10.0 woo!
    end
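For the stratified K-folds case in the issue title, the same per-label idea could be sketched roughly as below. This is only an illustration under assumptions: `stratified_fold_ids` is a hypothetical helper name, not something from MLDataPattern.jl or this package, and it uses only Base and the `Random` standard library. It deals each label's shuffled indices round-robin into `k` folds, so each fold keeps roughly the overall class ratio.

```julia
using Random

# Hypothetical helper (not an existing package function): assign every
# observation index to one of k folds, spreading each label evenly.
function stratified_fold_ids(y::AbstractVector, k::Integer; rng=Random.default_rng())
    folds = [Int[] for _ in 1:k]
    for label in unique(y)
        # shuffled indices of the observations carrying this label
        ids = shuffle(rng, findall(==(label), y))
        for (j, i) in enumerate(ids)
            push!(folds[mod1(j, k)], i)  # deal indices round-robin over folds
        end
    end
    return folds
end

targets = vcat(fill(-1, 10), fill(1, 100))
folds = stratified_fold_ids(targets, 5)
for (f, val) in enumerate(folds)
    train = setdiff(1:length(targets), val)  # all indices outside this fold
    println("fold $f: ", length(train), " train, ",
            count(==(1), targets[val]), " positive / ",
            count(==(-1), targets[val]), " negative validation")
end
```

With 10 negatives and 100 positives dealt into 5 folds, every validation fold holds 2 negatives and 20 positives, so the 10:1 ratio from the example above is preserved.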
@SimonEnsemble your approach seems fine. You should open a PR, and we can discuss the details there.
With #195 we have stratified
I can make a first-round PR and am then willing to make whatever changes are necessary, if anyone is willing to coach me through it / look at my PR. Thx.