-
Notifications
You must be signed in to change notification settings - Fork 0
/
tinyimagenet.py
83 lines (69 loc) · 2.85 KB
/
tinyimagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import glob
import os
from zipfile import ZipFile
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.io.image import ImageReadMode, decode_image, read_image
# Map every WordNet id listed in wnids.txt (one per line) to the zero-based
# index of its line; the dataset classes in this module use that index as the
# integer class label. The file is read once at import time and closed
# immediately instead of being left open for the garbage collector.
with open('data/TinyImageNet/wnids.txt', 'r') as wnids_file:
    id_dict = {
        line.rstrip('\n'): index for index, line in enumerate(wnids_file)
    }
class TrainTinyImageNetDataset(Dataset):
    """Training split of Tiny ImageNet, read from ``data/TinyImageNet/train``.

    Every file matching ``<wnid>/images/*.JPEG`` is one sample. Its label is
    the integer that the module-level ``id_dict`` assigns to the ``<wnid>``
    directory name.

    Args:
        transform: Optional callable. When given, the decoded image is cast
            to a float tensor and passed through it; when ``None`` the image
            is returned exactly as ``read_image`` produced it.
    """

    def __init__(self, transform=None):
        # glob gives no ordering guarantee; sorting makes an index refer to
        # the same image on every run and every machine.
        self.filenames = sorted(
            glob.glob("data/TinyImageNet/train/*/images/*.JPEG"))
        self.transform = transform
        self.id_dict = id_dict
        # Labels for all samples, aligned index-for-index with `filenames`.
        self.targets = [
            self.id_dict[self._wnid_from_path(img_path)]
            for img_path in self.filenames
        ]

    @staticmethod
    def _wnid_from_path(img_path):
        """Return ``<wnid>`` from a ``.../<wnid>/images/<file>`` path.

        The name is taken relative to the file itself (two directories up)
        using ``os.path``, so it does not depend on how deep the dataset root
        is or on which path separator the platform uses.
        """
        images_dir = os.path.dirname(img_path)
        return os.path.basename(os.path.dirname(images_dir))

    def __len__(self):
        """Return the number of training images found on disk."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the list of files.
            KeyError: If the image's class directory is not in ``id_dict``.
        """
        img_path = self.filenames[idx]
        image = read_image(img_path, ImageReadMode.RGB)
        # Derived from the path (not from `targets`) so the label always
        # belongs to the file that was actually read.
        label = self.id_dict[self._wnid_from_path(img_path)]
        if self.transform is not None:
            image = self.transform(image.type(torch.FloatTensor))
        return image, label
class TestTinyImageNetDataset(Dataset):
    """Validation split of Tiny ImageNet, read from ``data/TinyImageNet/val``.

    Samples are the files matching ``images/*.JPEG``. Labels come from
    ``val_annotations.txt``, a tab-separated file whose first column is the
    image file name and whose second column is the WordNet id; the id is
    translated to an integer through the module-level ``id_dict``.

    Args:
        transform: Optional callable. When given, the decoded image is cast
            to a float tensor and passed through it; when ``None`` the image
            is returned exactly as ``read_image`` produced it.
    """

    def __init__(self, transform=None):
        # glob gives no ordering guarantee; sorting makes an index refer to
        # the same image on every run and every machine.
        self.filenames = sorted(
            glob.glob("data/TinyImageNet/val/images/*.JPEG"))
        self.transform = transform
        self.id_dict = id_dict
        # Image file name -> integer label. The annotation file is closed as
        # soon as it has been parsed.
        with open('data/TinyImageNet/val/val_annotations.txt', 'r') as ann_file:
            self.cls_dic = self._parse_annotations(ann_file, self.id_dict)

    @staticmethod
    def _parse_annotations(lines, wnid_to_label):
        """Build the ``file name -> label`` mapping from annotation lines.

        Args:
            lines: Iterable of tab-separated text lines.
            wnid_to_label: Mapping from WordNet id to integer label.

        Returns:
            A dict keyed by the first column, with the label of the second
            column as value. Any further columns are ignored.

        Raises:
            IndexError: If a non-blank line has fewer than two columns.
            KeyError: If a WordNet id is missing from ``wnid_to_label``.
        """
        labels = {}
        for line in lines:
            if not line.strip():
                continue  # tolerate blank lines, e.g. at the end of the file
            # Drop the newline first so a two-column line still yields a
            # clean WordNet id.
            fields = line.rstrip('\n').split('\t')
            labels[fields[0]] = wnid_to_label[fields[1]]
        return labels

    def __len__(self):
        """Return the number of validation images found on disk."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the list of files.
            KeyError: If the image has no entry in the annotation file.
        """
        img_path = self.filenames[idx]
        image = read_image(img_path, ImageReadMode.RGB)
        # basename() copes with either path separator, unlike split('/').
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform is not None:
            image = self.transform(image.type(torch.FloatTensor))
        return image, label
class CorruptTinyImageNetDataset(Dataset):
    """Corrupted Tiny ImageNet images stored as one zip archive per class.

    Archives are looked up under
    ``data/TinyImageNet/corrupt/<corruption>/<intensity>/<wnid>.zip``. The
    archive's base name is the WordNet id, which the module-level ``id_dict``
    translates to the integer label.

    Args:
        intensity: Name of the intensity sub-directory to read.
        transform: Optional callable. When given, the decoded image is cast
            to a float tensor and passed through it; when ``None`` the image
            is returned exactly as ``decode_image`` produced it.
        corruption_name: Corruption sub-directory to read, or ``None`` to
            read every corruption.
    """

    # Number of images each archive is assumed to contain; indexing and
    # __len__ both rely on it.
    IMAGES_PER_ZIP = 50

    def __init__(self, intensity, transform=None, corruption_name=None):
        corruption = '*' if corruption_name is None else corruption_name
        # glob gives no ordering guarantee; sorting makes an index refer to
        # the same archive on every run and every machine.
        self.zipfiles = sorted(glob.glob(
            f'data/TinyImageNet/corrupt/{corruption}/{intensity}/*.zip'))
        self.transform = transform
        self.id_dict = id_dict

    @staticmethod
    def _wnid_from_zip(zip_path):
        """Return the archive's base name without its extension."""
        return os.path.splitext(os.path.basename(zip_path))[0]

    def __len__(self):
        """Return the sample count, assuming ``IMAGES_PER_ZIP`` per archive."""
        return len(self.zipfiles) * self.IMAGES_PER_ZIP

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        ``idx`` is split into an archive number and a position inside that
        archive's ``namelist()``.

        Raises:
            IndexError: If ``idx`` addresses an archive or member that does
                not exist.
            KeyError: If the archive's name is not in ``id_dict``.
        """
        archive_index, member_index = divmod(idx, self.IMAGES_PER_ZIP)
        archive_path = self.zipfiles[archive_index]
        # The archive is opened per item and closed right away, so no open
        # handle is kept on the dataset object.
        with ZipFile(archive_path, 'r') as archive:
            member = archive.namelist()[member_index]
            encoded = archive.read(member)
        # bytearray() makes a writable copy; presumably this avoids handing
        # torch.from_numpy a read-only array backed by immutable bytes.
        buffer = np.frombuffer(bytearray(encoded), dtype=np.uint8)
        image = decode_image(torch.from_numpy(buffer), ImageReadMode.RGB)
        label = self.id_dict[self._wnid_from_zip(archive_path)]
        if self.transform is not None:
            image = self.transform(image.type(torch.FloatTensor))
        return image, label