-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataloader.py
37 lines (23 loc) · 1.07 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
from torch.utils import data
# cat_embeds.append(Cat_emb) # cat_paths.append(Cat_path) # cat_labels.append(Class_Label)
class Dev_Embd_Dataset(data.Dataset):
def __init__(self, cat_embeds, cat_labels):
self.cat_embeds = cat_embeds
self.cat_labels = cat_labels
def __len__(self): # ``__len__``, 数据集的大小,cat_embeds训练集有10W个,测试集1W个
return len(self.cat_embeds)
def __getitem__(self, index): # ``__getitem__``, 支持范围从 0 到 len(self) 的整数索引。
X = self.cat_embeds[index]
X = X.astype(np.float32)
y = self.cat_labels[index] # 标签 0-1210
return X, y
class Val_Embd_Dataset(data.Dataset):
def __init__(self, cat_embeds):
self.cat_embeds = cat_embeds
def __len__(self): # ``__len__``, 数据集的大小,cat_embeds训练集有10W个,测试集1W个
return len(self.cat_embeds)
def __getitem__(self, index): # ``__getitem__``, 支持范围从 0 到 len(self) 的整数索引。
X = self.cat_embeds[index]
X = X.astype(np.float32)
return X