Unable to create a small sample of 1000 train and 100 using MultilabelStratifiedShuffleSplit #15

@meltedhead

Hi trent-b:

Thanks for this repository; I hope you can help with my issue. I have a large JSON dataset, and I want to use MultilabelStratifiedShuffleSplit to create a smaller sample set from it.

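import warnings

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

def mlb_train_test_split(labels, test_size, train_size, random_state=0):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        msss = MultilabelStratifiedShuffleSplit(
            test_size=test_size, train_size=train_size, random_state=random_state
        )
    # X is passed as a placeholder with the same shape as labels; only labels drive the stratification.
    train_idx, test_idx = next(msss.split(np.ones_like(labels), labels))
    return train_idx, test_idx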

I then call the function as:

train_idx, test_idx = mlb_train_test_split(labels, test_size=1000, train_size=200, random_state=0)

When I look at the numbers, I'm seeing way more than 200 rows. Is there a limitation? The labels array has a length of approximately 500,000.
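
To make the mismatch concrete, the returned index counts can be compared against the requested sizes. The sketch below is only a rough check, assuming the splitter follows the scikit-learn convention in which an int size is an absolute row count and a float size is a fraction of the dataset (test fractions rounded up, train fractions rounded down). `expected_count` and `check_split` are hypothetical helper names, not part of iterstrat.

```python
import math


def expected_count(size, n_samples, is_test):
    """Resolve a requested split size into an absolute number of rows.

    Assumption: int means an absolute count, float means a fraction of
    n_samples (scikit-learn convention).
    """
    if isinstance(size, float):
        scaled = size * n_samples
        return math.ceil(scaled) if is_test else math.floor(scaled)
    return int(size)


def check_split(train_idx, test_idx, n_samples, train_size, test_size):
    """Report (actual, expected) row counts and any train/test overlap."""
    want_train = expected_count(train_size, n_samples, is_test=False)
    want_test = expected_count(test_size, n_samples, is_test=True)
    # Convert to plain ints so lists and NumPy index arrays behave the same.
    overlap = {int(i) for i in train_idx} & {int(i) for i in test_idx}
    return {
        "train": (len(train_idx), want_train),
        "test": (len(test_idx), want_test),
        "overlap": len(overlap),
    }
```

Calling `check_split(train_idx, test_idx, len(labels), train_size=200, test_size=1000)` on the indices returned above would show whether either count differs from what was requested, and whether any row landed in both sets.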
