PR welcome on stratified K-folds? #60

Open
SimonEnsemble opened this issue Feb 22, 2022 · 4 comments
Labels
enhancement New feature or request

Comments

@SimonEnsemble

I can make a first-round PR and am willing to make whatever changes are necessary, if anyone is willing to coach me through it or look at my PR. Thanks.

@darsnack
Member

Yes, this would be a nice addition. You can probably adapt the MLDataPattern.jl code as a starting point. I would try and start with the simplest implementation first though, since part of the goal with this package is to cut down on the complexity of the MLDataPattern.jl codebase.

@darsnack added the enhancement label Feb 22, 2022
@SimonEnsemble
Author

I feel like we should start with a stratified version of splitobs (split observations) first.

This is the simplest/clearest implementation I could come up with that works for (i) a target vector y with any number of classes and (ii) an at argument that could be a tuple. Is there a cleaner way to do this? Also, these are not views like in the rest of the package...

I will need some coaching... I can write a few tests for this, though.

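using MLUtils # assumed source of splitobs and shuffleobs

# Split observation indices so that every split keeps the class proportions of y.
function splitobs_stratified(;at, y::Array, shuffle::Bool=true)
	n_splits = length(at) + 1
	the_splits = [Int[] for s = 1:n_splits]
	for label in unique(y)
		# indices of the observations that carry this label
		ids_this_label = filter(i -> y[i] == label, 1:length(y))
		if shuffle
			ids_this_label = shuffleobs(ids_this_label)
		end
		# split this label's indices at the same fractions, then merge them into the overall splits
		split_this_label = splitobs(ids_this_label, at=at)
		for s = 1:n_splits
			the_splits[s] = vcat(the_splits[s], split_this_label[s])
		end
	end
	return the_splits
end

# 10 observations of class -1 and 100 of class 1
targets = vcat([-1 for i = 1:10], [1 for i = 1:100])
splits = splitobs_stratified(at=(0.2, 0.5), y=targets, shuffle=true)
for s in splits
	println(sum(targets[s] .== 1) / sum(targets[s] .== -1)) # 10.0 woo!
end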
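
On the point about views: one possible way to bridge the gap, sketched here under the assumption that observations are stored as the columns of a matrix, is to keep the index vectors produced by the stratified split and wrap the data with Base's view instead of copying it. The names X, split_ids and X_split are hypothetical, and this is not necessarily how the package handles views internally.

```julia
# Hypothetical data: 110 observations stored as the columns of a matrix.
X = rand(3, 110)

# Stand-in for one of the index vectors returned by a stratified split.
split_ids = [1, 11, 12, 13]

# `view` wraps the parent array without copying the selected columns.
X_split = view(X, :, split_ids)

# Writing to the parent is visible through the view, which shows that no copy was made.
X[1, 11] = 42.0
println(X_split[1, 2])  # prints 42.0
```

Keeping plain index vectors and deferring the view to the last step leaves the splitting logic independent of how the data is laid out.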

@CarloLucibello
Member

@SimonEnsemble your approach seems fine. You should open a PR, and we can discuss the details there.

@CarloLucibello
Member

With #195 we have stratified splitobs. It should be easy to extend kfold in a similar way.
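
As a rough illustration of what extending the idea to k-folds could look like, here is a minimal sketch that uses only the Julia standard library. The function name stratified_kfolds, its signature and its return format are hypothetical and are not the package's API; it deals each class's shuffled indices round-robin across the k folds, which is one simple way to keep class proportions roughly equal, not necessarily the approach an eventual PR would take.

```julia
using Random

# Hypothetical helper, not part of any package: returns k (train, val) index pairs
# in which each validation fold has roughly the same class proportions as y.
function stratified_kfolds(y::AbstractVector, k::Integer; rng=Random.default_rng())
    k >= 2 || throw(ArgumentError("k must be at least 2"))
    val_folds = [Int[] for _ in 1:k]
    for label in unique(y)
        # Shuffle the indices of this class, then deal them round-robin over the folds.
        ids = shuffle(rng, findall(==(label), y))
        for (j, i) in enumerate(ids)
            push!(val_folds[mod1(j, k)], i)
        end
    end
    all_ids = collect(eachindex(y))
    # The training indices of a fold are all indices outside its validation set.
    return [(setdiff(all_ids, val), sort(val)) for val in val_folds]
end

# Same imbalanced targets as in the example above: 10 of class -1, 100 of class 1.
targets = vcat(fill(-1, 10), fill(1, 100))
folds = stratified_kfolds(targets, 5; rng=MersenneTwister(0))
for (train, val) in folds
    # Each validation fold holds 20 ones and 2 minus-ones, so this prints 10.0.
    println(count(==(1), targets[val]) / count(==(-1), targets[val]))
end
```

Because every class starts dealing at the first fold, classes whose size is not a multiple of k make the early folds slightly larger; a fuller implementation might rotate the starting fold per class.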
