-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
40 lines (37 loc) · 1.23 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import ast
import csv
import os
class Dataset:
    """In-memory list of ``(key_list, num, rate)`` samples read from CSV files.

    Each regular file in the data directory is named ``<literal>.<ext>``: the
    part before the first ``.`` is parsed as a Python literal (e.g. ``16`` or
    ``[1, 2]``) and wrapped in a list if it is not one already. Every non-empty
    CSV row of that file contributes one ``(key_list, int(row[0]),
    float(row[1]))`` tuple to ``self.data``.
    """

    def __init__(self, dir_path='tenenbaum_data', data=None, one_file=False,
                 one_file_id=16):
        """Load samples from ``dir_path``, or wrap an existing sample list.

        Args:
            dir_path: Directory scanned for data files. Ignored when ``data``
                is given.
            data: Pre-built list of samples. When not ``None`` it is stored
                as-is and nothing is read from disk.
            one_file: If true, load only the file whose parsed name equals
                ``one_file_id``. The comparison happens before list-wrapping.
            one_file_id: Parsed file-name value kept when ``one_file`` is true.
                Defaults to 16, the value this filter originally hard-coded.

        Raises:
            ValueError: if a data file's name prefix is not a Python literal,
                or if a CSV row's first two fields are not an int and a float.
        """
        if data is not None:
            self.data = data
            return
        self.data = []
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order (and therefore split()) is reproducible across machines.
        for file in sorted(os.listdir(dir_path)):
            file_path = os.path.join(dir_path, file)
            # Hidden entries (e.g. .DS_Store, whose prefix is empty) and
            # sub-directories cannot be parsed or read as data files.
            if file.startswith('.') or not os.path.isfile(file_path):
                continue
            prefix = file.split('.')[0]
            # literal_eval instead of eval: file names come from the
            # filesystem and must never be executed as arbitrary code.
            try:
                key = ast.literal_eval(prefix)
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    f"Cannot parse data file name {file!r}: "
                    f"{prefix!r} is not a Python literal"
                ) from err
            if one_file and key != one_file_id:
                continue
            if not isinstance(key, list):
                key = [key]
            print(f"The given list of file {file}: {key}")
            # newline='' is what the csv module requires for correct parsing;
            # `with` guarantees the handle is closed.
            with open(file_path, 'r', newline='', encoding='utf-8') as fh:
                for row in csv.reader(fh):
                    if not row:  # blank line in the CSV
                        continue
                    self.data.append((key, int(row[0]), float(row[1])))
        print(f"Loaded {len(self.data)} data points")

    def get_length(self):
        """Return the number of samples."""
        return len(self.data)

    def get_data(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Support ``len(dataset)``; same as ``get_length()``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Support ``dataset[idx]``; same as ``get_data(idx)``."""
        return self.data[idx]

    def split(self, ratio=0.9):
        """Split into two datasets at index ``int(len(self) * ratio)``.

        No shuffling is done: the first dataset gets the leading samples and
        the second gets the rest.

        Args:
            ratio: Fraction of samples placed in the first dataset.

        Returns:
            A ``(first, second)`` tuple of new ``Dataset`` instances.
        """
        split_idx = int(len(self.data) * ratio)
        return (Dataset(data=self.data[:split_idx]),
                Dataset(data=self.data[split_idx:]))