# dataset.py -- forked from tysam-code/hlb-CIFAR10
import os
import logging
from functools import partial

import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F

import utils

config = {
    'data': {
        'cutout_size': 0,
        'cache_filename': 'data_cache.pt',
    }
}
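
# A hypothetical tweak (illustration only; the default above is what this file
# actually uses): cutout could be switched on by overriding the config before
# batches are built, e.g.
#
#   config['data']['cutout_size'] = 8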


def get_dataset(download_path, cache_path, device, dtype, pad_amount):
    data = None
    cache_path = utils.full_path(cache_path, create=True)
    cache_filepath = os.path.join(cache_path, config['data']['cache_filename'])
    if os.path.exists(cache_filepath):
        logging.info("Loading cached dataset from %s", cache_filepath)
        # Loading the cache is effectively instantaneous and takes us practically
        # straight to where the dataloader-built dataset would be. :) So as long as
        # the generation pass below runs once and the cache file stays on disk at the
        # location set in the config dictionary above, we should be good. :)
        data = torch.load(cache_filepath)
        if data['train']['images'].dtype != dtype:
            logging.info("Cached dataset has dtype %s but %s was requested; regenerating cache",
                         data['train']['images'].dtype, dtype)
            data = None
    else:
        logging.info("No cached dataset found at %s", cache_filepath)
    if data is None:
        logging.info("Generating dataset cache...")
        cifar10_mean, cifar10_std = [
            torch.tensor([0.4913997551666284, 0.48215855929893703,
                          0.4465309133731618], device=device),
            torch.tensor([0.24703225141799082, 0.24348516474564,
                          0.26158783926049628], device=device)
        ]
        transform = transforms.Compose([
            transforms.ToTensor()])
        download_path = utils.full_path(download_path, create=True)
        cifar10 = torchvision.datasets.CIFAR10(
            download_path, download=True, train=True, transform=transform)
        cifar10_eval = torchvision.datasets.CIFAR10(
            download_path, download=False, train=False, transform=transform)

        # Use the dataloader to get a single batch of all of the dataset items at once.
        train_dataset_gpu_loader = torch.utils.data.DataLoader(
            cifar10, batch_size=len(cifar10), drop_last=True, shuffle=True)
        eval_dataset_gpu_loader = torch.utils.data.DataLoader(
            cifar10_eval, batch_size=len(cifar10_eval), drop_last=True, shuffle=False)

        train_dataset_gpu = {}
        eval_dataset_gpu = {}
        train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(
            device=device, non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]
        eval_dataset_gpu['images'], eval_dataset_gpu['targets'] = [item.to(
            device=device, non_blocking=True) for item in next(iter(eval_dataset_gpu_loader))]

        def batch_normalize_images(input_images, mean, std):
            return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

        # Preload with our mean and std.
        batch_normalize_images = partial(
            batch_normalize_images, mean=cifar10_mean, std=cifar10_std)

        # Batch normalize datasets, now. Wowie. We did it! We should take a break and make some tea now.
        train_dataset_gpu['images'] = batch_normalize_images(
            train_dataset_gpu['images'])
        eval_dataset_gpu['images'] = batch_normalize_images(
            eval_dataset_gpu['images'])

        data = {
            'train': train_dataset_gpu,
            'eval': eval_dataset_gpu
        }
        # Convert the dataset to the requested dtype for the rest of the process....
        data['train']['images'] = data['train']['images'].to(dtype=dtype)
        data['eval']['images'] = data['eval']['images'].to(dtype=dtype)
        torch.save(data, cache_filepath)

    # As you'll note above and below, one difference is that we don't count loading the raw data
    # to GPU, since it's such a variable operation and can sort of get in the way of measuring
    # other things. That said, measuring the preprocessing (outside of the padding) is still
    # important to us.
    # Pad the GPU training dataset.
    if pad_amount > 0:
        # Uncomfortable shorthand, but basically we pad evenly on all _4_ sides with the
        # pad_amount passed in by the caller.
        data['train']['images'] = F.pad(
            data['train']['images'], (pad_amount,) * 4, 'reflect')
    return data
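
# A minimal usage sketch (hypothetical paths and settings, not part of the
# original pipeline): build/cache the dataset once, then inspect the shapes.
#
#   data = get_dataset(download_path='./data', cache_path='./cache',
#                      device='cuda', dtype=torch.float16, pad_amount=2)
#   data['train']['images'].shape  # (50000, 3, 36, 36) after reflect-padding by 2
#   data['eval']['images'].shape   # (10000, 3, 32, 32); the eval set is never padded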


## This is actually (I believe) a pretty clean implementation of how to do something like this,
## since shifted-square masks unique to each depth-channel can actually be rather tricky in
## practice. That said, if there's a better way, please do feel free to submit it! This can be
## one of the harder parts of the code to understand (though I personally get stuck on the
## fold/unfold process for the lower-level convolution calculations).
def make_random_square_masks(inputs, mask_size):
    ##### TODO: Double check that this properly covers the whole range of values. :'( :')
    if mask_size == 0:
        return None  # no need to cutout or do anything like that since the patch_size is set to 0
    is_even = int(mask_size % 2 == 0)
    in_shape = inputs.shape

    # Seed the centers of the squares to cut out, one coordinate per dimension.
    mask_center_y = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(
        mask_size // 2 - is_even, in_shape[-2] - mask_size // 2 - is_even)
    mask_center_x = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(
        mask_size // 2 - is_even, in_shape[-1] - mask_size // 2 - is_even)

    # Measure the distance of every row/column index from its mask's center.
    to_mask_y_dists = torch.arange(in_shape[-2], device=inputs.device).view(1, 1, in_shape[-2], 1) - mask_center_y.view(-1, 1, 1, 1)
    to_mask_x_dists = torch.arange(in_shape[-1], device=inputs.device).view(1, 1, 1, in_shape[-1]) - mask_center_x.view(-1, 1, 1, 1)
    to_mask_y = (to_mask_y_dists >= (-(mask_size // 2) + is_even)) * (to_mask_y_dists <= mask_size // 2)
    to_mask_x = (to_mask_x_dists >= (-(mask_size // 2) + is_even)) * (to_mask_x_dists <= mask_size // 2)

    final_mask = to_mask_y * to_mask_x  ## Turn the (y by 1) and (x by 1) boolean masks into (y by x) masks through multiplication. Their intersection is square, hurray! :D
    return final_mask
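
# Shape sanity-check sketch (assumed values, just to illustrate the contract):
# for (N, C, H, W) inputs the mask broadcasts as (N, 1, H, W), with exactly one
# mask_size x mask_size square of Trues per image.
#
#   x = torch.randn(4, 3, 32, 32)
#   m = make_random_square_masks(x, 8)
#   assert m.shape == (4, 1, 32, 32) and int(m[0].sum()) == 64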


def batch_cutout(inputs, patch_size):
    with torch.no_grad():
        cutout_batch_mask = make_random_square_masks(inputs, patch_size)
        if cutout_batch_mask is None:
            return inputs  # a None mask means the patch size was set to 0, so we will not be using cutout today
        # TODO: Could be fused with the crop operation for sheer speeeeeds. :D <3 :))))
        cutout_batch = torch.where(cutout_batch_mask, torch.zeros_like(inputs), inputs)
        return cutout_batch
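
# Cutout sketch (assumed values): with patch_size=4, one random 4x4 square is
# zeroed per image, broadcast across all channels.
#
#   x = torch.ones(2, 3, 32, 32)
#   y = batch_cutout(x, patch_size=4)
#   assert int((y == 0).sum()) == 2 * 3 * 16  # one 4x4 patch per image, per channel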


def batch_crop(inputs, crop_size):
    with torch.no_grad():
        crop_mask_batch = make_random_square_masks(inputs, crop_size)
        cropped_batch = torch.masked_select(inputs, crop_mask_batch).view(
            inputs.shape[0], inputs.shape[1], crop_size, crop_size)
        return cropped_batch


def batch_flip_lr(batch_images, flip_chance=.5):
    with torch.no_grad():
        # TODO: Is there a more elegant way to do this? :') :'((((
        return torch.where(
            torch.rand_like(batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)) < flip_chance,
            torch.flip(batch_images, (-1,)),
            batch_images)
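
# The three augmentations chain exactly as get_batches does below; a sketch on
# a dummy reflect-padded batch (assumed shapes only):
#
#   padded = torch.randn(8, 3, 36, 36)  # 32x32 images padded by 2 on each side
#   out = batch_cutout(batch_flip_lr(batch_crop(padded, 32)), patch_size=4)
#   assert out.shape == (8, 3, 32, 32)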


# TODO: Could we jit this in the (more distant) future? :)
@torch.no_grad()
def get_batches(data_dict, key, batchsize, memory_format, device, cutout_size=config['data']['cutout_size']):
    num_epoch_examples = len(data_dict[key]['images'])
    shuffled = torch.randperm(num_epoch_examples, device=device)
    crop_size = 32
    ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before
    ## each epoch, then we return an iterator below that walks over the individual examples in chunks
    ## of a random permutation (i.e. shuffled indices). So we get perfectly-shuffled batches (skipping
    ## the last batch if it's not a full one), and everything seems to be (and hopefully is! :D)
    ## properly shuffled. :)
    if key == 'train':
        images = batch_crop(data_dict[key]['images'], crop_size)  # TODO: hardcoded image size for now?
        images = batch_flip_lr(images)
        images = batch_cutout(images, patch_size=cutout_size)
    else:
        images = data_dict[key]['images']

    # Send the images to the (in beta) channels_last memory format to help improve tensor core
    # occupancy (and reduce NCHW <-> NHWC thrash) during training.
    images = images.to(memory_format=memory_format)
    for idx in range(num_epoch_examples // batchsize):
        if not (idx + 1) * batchsize > num_epoch_examples:  ## Use the shuffled randperm to assemble individual items into a minibatch
            yield images.index_select(0, shuffled[idx * batchsize:(idx + 1) * batchsize]), \
                  data_dict[key]['targets'].index_select(0, shuffled[idx * batchsize:(idx + 1) * batchsize])  ## Each item is only used/accessed by the network once per epoch. :D
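
if __name__ == '__main__':
    # A minimal smoke test (a sketch, assuming the hypothetical './data' and './cache'
    # paths are writable; this downloads CIFAR-10 on the first run).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32  # fp16 preprocessing assumes a GPU
    data = get_dataset('./data', './cache', device=device, dtype=dtype, pad_amount=2)
    images, targets = next(get_batches(
        data, 'train', batchsize=512, memory_format=torch.channels_last, device=device))
    print(images.shape, targets.shape)  # torch.Size([512, 3, 32, 32]) torch.Size([512])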