forked from ananyahjha93/multi-level-vae
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathalternate_data_loader.py
31 lines (24 loc) · 1.02 KB
/
alternate_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import random
import pdb
from torchvision import datasets
from utils import transform_config
from torch.utils.data import Dataset
class MNIST_Paired(Dataset):
    """MNIST wrapper that pairs each sample with a random image of the same class.

    Each item is ``(image, paired_image, label)``, where ``paired_image`` is
    drawn at random from all images of the split that share ``label``. This
    simulates pair-wise labeling of data.

    Args:
        root: Directory passed to ``torchvision.datasets.MNIST``.
        download: Forwarded to ``datasets.MNIST`` (download data if missing).
        train: Forwarded to ``datasets.MNIST`` (training split when True).
        transform: Transform ``datasets.MNIST`` applies to each image.
    """

    def __init__(self, root='mnist', download=True, train=True, transform=transform_config):
        self.mnist = datasets.MNIST(root=root, download=download, train=train, transform=transform)

        # label -> list of (already transformed) images carrying that label.
        # NOTE(review): this holds every transformed image of the split in
        # memory, and any randomness in `transform` is fixed here for the
        # paired copy — consider storing indices instead if that matters.
        self.data_dict = {}
        for index in range(len(self)):
            image, label = self.mnist[index]
            self.data_dict.setdefault(label, []).append(image)

    def __len__(self):
        """Return the number of samples in the underlying MNIST split."""
        return len(self.mnist)

    def __getitem__(self, index):
        """Return ``(image, paired_image, label)`` for sample ``index``.

        ``paired_image`` is picked at random from ``self.data_dict[label]``;
        it may originate from the same sample as ``image``.
        """
        image, label = self.mnist[index]
        # Use the seedable module-level RNG instead of a fresh SystemRandom per
        # call: SystemRandom pulls OS entropy and ignores random.seed(), making
        # the sampled pairs irreproducible. torch's DataLoader seeds `random`
        # in each worker process, so workers still draw distinct streams.
        paired_image = random.choice(self.data_dict[label])
        return image, paired_image, label