-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathdataset.py
49 lines (34 loc) · 1.39 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import division
from future.utils import implements_iterator
import numpy as np
class Dataset(object):
    """In-memory collection of samples that can be split and batched.

    The samples live in the public ``features`` attribute and are always
    addressed along its first axis.
    """

    def __init__(self, features):
        """Store a reference to *features* (no copy is made).

        Args:
            features: Sample container. ``split`` takes its ``len`` and
                slices it along the first axis.
                NOTE(review): presumably a NumPy array with one sample per
                row -- confirm against callers.
        """
        self.features = features

    def split(self, proportion):
        """Split the samples into a leading and a trailing ``Dataset``.

        The first ``floor(len(features) * proportion)`` samples go to the
        first dataset and the remainder to the second. Order is preserved;
        nothing is shuffled here.

        Args:
            proportion: Fraction of samples for the first dataset; must be
                strictly between 0 and 1.

        Returns:
            A ``(first, second)`` tuple of ``Dataset`` objects built from
            slices of ``features``.

        Raises:
            AssertionError: If ``proportion`` is not strictly inside (0, 1).
        """
        # Raised explicitly rather than via an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type and
        # message are the same as before.
        if not 0 < proportion < 1:
            raise AssertionError("Proportion should be between 0 and 1.")
        limit = int(np.floor(len(self.features) * proportion))
        # Slice on the first axis only: same result as ``[:limit, :]`` for
        # data with two or more axes, and it also handles 1-D data.
        return Dataset(self.features[:limit]), Dataset(self.features[limit:])

    def batch_iterator(self, batch_size, shuffle=True):
        """Return an iterator over mini-batches of ``features``.

        Args:
            batch_size: Number of samples per batch.
            shuffle: When true (the default), samples are visited in a
                fresh ``np.random.permutation`` order; otherwise in their
                stored order.

        Returns:
            A ``DatasetIterator`` over ``features`` in the chosen order.
        """
        if shuffle:
            indices = np.random.permutation(len(self.features))
        else:
            indices = np.arange(len(self.features))
        return DatasetIterator(self.features, indices, batch_size)
# NOTE(review): ``implements_iterator`` comes from ``future.utils``;
# presumably it exposes ``__next__`` as ``next`` on Python 2 -- confirm.
@implements_iterator
class DatasetIterator(object):
    """Single-pass iterator yielding consecutive mini-batches.

    Each batch is ``features`` indexed by the next ``batch_size`` entries
    of ``indices``. The final batch is shorter when the number of indices
    is not a multiple of ``batch_size``.
    """

    def __init__(self, features, indices, batch_size):
        """Set up iteration state.

        Args:
            features: Sample container; it is indexed with a slice of
                ``indices`` on every step.
                NOTE(review): presumably a NumPy array -- confirm.
            indices: Sequence of sample positions giving the visiting
                order.
            batch_size: Number of indices consumed per batch. A value of
                zero fails in the division below; it is not validated here.
        """
        self.features = features
        self.indices = indices
        self.batch_size = batch_size
        self.batch_index = 0
        # Batches are cut from ``indices``, so the count must come from its
        # length; using ``len(features)`` would yield empty trailing batches
        # or drop indices whenever the two lengths differ.
        # The file-level ``from __future__ import division`` makes this a
        # true division on Python 2 as well.
        self.num_batches = int(np.ceil(len(indices) / batch_size))

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` when done."""
        if self.batch_index >= self.num_batches:
            raise StopIteration
        batch_start = self.batch_index * self.batch_size
        batch_end = batch_start + self.batch_size
        self.batch_index += 1
        # Slicing past the end of ``indices`` truncates, which produces the
        # shorter final batch.
        return self.features[self.indices[batch_start:batch_end]]