Skip to content

Commit

Permalink
Fold splits.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcamino committed Aug 17, 2018
1 parent 7bfe704 commit 2ae2a83
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 35 deletions.
72 changes: 72 additions & 0 deletions multi_categorical_gans/datasets/fold_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import print_function

import argparse
import os

from multi_categorical_gans.datasets.formats import data_formats, loaders, savers
from sklearn.model_selection import KFold


def _parse_options():
    """Build the command line parser and parse ``sys.argv``.

    :return: tuple ``(parser, options)``; the parser is returned as well so
        the caller can report cross-argument errors through ``parser.error``.
    """
    options_parser = argparse.ArgumentParser(description="Split features file into train and test files.")

    options_parser.add_argument("features", type=str, help="Input features file.")
    options_parser.add_argument("folds", type=int, help="Number of folds.")
    options_parser.add_argument("output_directory", type=str, help="Output directory path for the folds.")
    options_parser.add_argument("features_template", type=str, help="Output features file name template. ")

    options_parser.add_argument(
        "--features_format",
        type=str,
        default="sparse",
        choices=data_formats,
        help="Either a dense numpy array or a sparse csr matrix."
    )

    options_parser.add_argument("--labels", type=str, help="Input labels file.")
    options_parser.add_argument("--labels_template", type=str, help="Output labels file name template. ")

    options_parser.add_argument(
        "--labels_format",
        type=str,
        default="sparse",
        choices=data_formats,
        help="Either a dense numpy array or a sparse csr matrix."
    )

    options_parser.add_argument("--shuffle", default=False, action="store_true",
                                help="Shuffle the dataset before the split.")

    return options_parser, options_parser.parse_args()


def main():
    """Split a features file (and optionally a labels file) into K-fold train/test files.

    For every fold produced by ``KFold`` a train file and a test file are
    written inside ``output_directory``. File names are obtained by calling
    ``str.format`` on the templates with the keyword arguments ``name``
    (``"train"`` or ``"test"``), ``number`` (zero-based fold index) and
    ``total`` (number of folds).

    Exits through ``ArgumentParser.error`` when ``--labels`` is given without
    ``--labels_template``, or when labels and features differ in row count.
    """
    options_parser, options = _parse_options()

    has_labels = options.labels is not None

    # Fail before any file is written: without this check the missing template
    # only surfaces as a TypeError from os.path.join inside the fold loop,
    # after the feature files of the first fold were already saved.
    if has_labels and options.labels_template is None:
        options_parser.error("--labels_template is required when --labels is given.")

    features_loader = loaders[options.features_format]
    features_saver = savers[options.features_format]
    features = features_loader(options.features, transform=False)

    if has_labels:
        labels_loader = loaders[options.labels_format]
        labels_saver = savers[options.labels_format]
        labels = labels_loader(options.labels, transform=False)

        # The fold indices are computed from the features only, so a labels
        # file with a different row count would be split misaligned.
        if labels.shape[0] != features.shape[0]:
            options_parser.error(
                "features and labels must have the same number of rows "
                "({} != {}).".format(features.shape[0], labels.shape[0])
            )

    # The output path templates do not depend on the fold, build them once.
    features_path = os.path.join(options.output_directory, options.features_template)
    if has_labels:
        labels_path = os.path.join(options.output_directory, options.labels_template)

    k_fold = KFold(n_splits=options.folds, shuffle=options.shuffle)
    for fold_number, (train_index, test_index) in enumerate(k_fold.split(features)):
        train_features, test_features = features[train_index, :], features[test_index, :]

        features_saver(features_path.format(name="train", number=fold_number, total=options.folds), train_features)
        features_saver(features_path.format(name="test", number=fold_number, total=options.folds), test_features)

        print("Train features:", train_features.shape, "Test features:", test_features.shape)

        if has_labels:
            train_labels, test_labels = labels[train_index], labels[test_index]

            labels_saver(labels_path.format(name="train", number=fold_number, total=options.folds), train_labels)
            labels_saver(labels_path.format(name="test", number=fold_number, total=options.folds), test_labels)

            print("Train labels:", train_labels.shape, "Test labels:", test_labels.shape)


if __name__ == "__main__":
    main()
74 changes: 39 additions & 35 deletions multi_categorical_gans/datasets/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,71 @@

import argparse

import numpy as np

from multi_categorical_gans.datasets.formats import data_formats, loaders, savers


def train_test_split(features, train_size, percent=False, shuffle=False):
    """Split the rows of ``features`` into a leading train part and a trailing test part.

    :param features: object exposing ``shape`` and supporting two-dimensional
        indexing of the form ``features[rows, :]``.
    :param train_size: number of rows for the train part, or a percentage of
        the rows when ``percent`` is set.
    :param percent: interpret ``train_size`` as a percentage, strictly between
        0 and 100; the resulting row count is truncated with ``int``.
    :param shuffle: permute the rows with ``np.random.shuffle`` before cutting.
    :return: tuple ``(train_features, test_features)``.
    :raises AssertionError: if ``train_size`` is outside the valid open range.
    """
    num_samples = features.shape[0]

    # The range checks are raised explicitly instead of using ``assert``
    # statements, which are stripped when Python runs with ``-O``. The
    # exception type and messages are kept so existing callers are unaffected.
    if percent:
        if not 0 < train_size < 100:
            raise AssertionError("Invalid percent value.")
        limit = int(num_samples * (train_size / 100.0))
    else:
        if not 0 < train_size < num_samples:
            raise AssertionError("Invalid size.")
        limit = train_size

    if shuffle:
        index = np.arange(num_samples)
        np.random.shuffle(index)
        features = features[index, :]

    return features[:limit, :], features[limit:, :]
from sklearn.model_selection import train_test_split


# NOTE(review): this listing looks like a diff rendered without +/- markers
# (the header above reads "39 additions & 35 deletions"), so removed and added
# lines appear interleaved and the text is not runnable as shown. The comments
# below mark the presumed old/new sides; confirm against the real file.
def main():
options_parser = argparse.ArgumentParser(description="Split features file into train and test files.")

options_parser.add_argument("features", type=str, help="Input features file.")
# Presumably the first "train_size" line is the removed one (type=int) and the
# second is its replacement (type=float).
# NOTE(review): with type=float an absolute count such as "1000" is parsed as
# 1000.0; verify that the split function called below accepts a float >= 1.
options_parser.add_argument("train_size", type=int, help="Number of samples for the train part.")
options_parser.add_argument("train_size", type=float, help="Number or proportion of samples for the train part.")
options_parser.add_argument("train_features", type=str, help="Output train features file.")
options_parser.add_argument("test_features", type=str, help="Output test features file.")

# Presumably "--data_format" was renamed to "--features_format"; both option
# strings are shown because of the interleaved diff.
options_parser.add_argument(
"--data_format",
"--features_format",
type=str,
default="sparse",
choices=data_formats,
help="Either a dense numpy array or a sparse csr matrix."
)

# Presumably "--percent" belongs to the removed side; the labels options that
# follow are not read by any of the old-style lines below.
options_parser.add_argument("--percent", default=False, action="store_true",
help="Interpret the train size as a percentage.")
options_parser.add_argument("--labels", type=str, help="Input labels file.")
options_parser.add_argument("--train_labels", type=str, help="Output train labels file.")
options_parser.add_argument("--test_labels", type=str, help="Output test labels file.")

options_parser.add_argument(
"--labels_format",
type=str,
default="sparse",
choices=data_formats,
help="Either a dense numpy array or a sparse csr matrix."
)

options_parser.add_argument("--shuffle", default=False, action="store_true",
help="Shuffle the dataset before the split.")

options = options_parser.parse_args()

# Old side (presumably): a single loader/saver pair selected by --data_format.
loader = loaders[options.data_format]
saver = savers[options.data_format]
# New side (presumably): the loader/saver pair is selected by --features_format.
features_loader = loaders[options.features_format]
features_saver = savers[options.features_format]
features = features_loader(options.features, transform=False)

# Old side (presumably): call with the positional size plus the percent flag.
train_features, test_features = train_test_split(
loader(options.features, transform=False),
options.train_size,
percent=options.percent,
shuffle=options.shuffle
)
# New side (presumably): features are split alone, or together with the labels
# in one call so both receive the same partition.
if options.labels is None:
train_features, test_features = train_test_split(features,
train_size=options.train_size,
shuffle=options.shuffle)
else:
labels_loader = loaders[options.labels_format]
labels_saver = savers[options.labels_format]
labels = labels_loader(options.labels, transform=False)

train_features, test_features, train_labels, test_labels = train_test_split(features,
labels,
train_size=options.train_size,
shuffle=options.shuffle)

features_saver(options.train_features, train_features)
features_saver(options.test_features, test_features)

print("Train features:", train_features.shape, "Test features:", test_features.shape)

# Old side (presumably): saves through the single saver.
saver(options.train_features, train_features)
saver(options.test_features, test_features)
# NOTE(review): --train_labels and --test_labels have no default and are not
# checked, so they may be None here when only --labels was passed.
if options.labels is not None:
labels_saver(options.train_labels, train_labels)
labels_saver(options.test_labels, test_labels)

# The two single-shape prints are presumably the removed side; the labels print
# is presumably part of the labels branch above.
print("Train features:", train_features.shape)
print("Test features:", test_features.shape)
print("Train labels:", train_labels.shape, "Test labels:", test_labels.shape)


if __name__ == "__main__":
Expand Down

0 comments on commit 2ae2a83

Please sign in to comment.