forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
269 lines (229 loc) · 9.84 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""Utilities for data processing."""
from typing import Dict, Iterable, Optional, Tuple
import jax
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler
from torch.utils.data import Sampler
from algoperf import spec
def shard_and_maybe_pad_np(
    batch: Dict[str, spec.Tensor],
    padding_value: int = 0,
    global_batch_size: Optional[int] = None) -> Dict[str, spec.Tensor]:
  """Prepare tf data for JAX or PyTorch DDP.

  Convert an input batch from tf Tensors to numpy arrays, pad it with
  padding_value if the batch size is not divisible by the number of devices,
  create the corresponding mask, and reshape it to be sharded across devices.

  Args:
    batch: Mapping that contains at least an 'inputs' entry. When padding is
      required, 'targets' is read as well and an optional 'weights' entry is
      reused if present. The caller's mapping is not modified.
    padding_value: Value used to fill the padded examples.
    global_batch_size: If given, the batch is padded up to exactly this many
      examples. Must be >= the current batch size.

  Returns:
    A pytree with the same structure as `batch` (plus a 'weights' entry when
    padding was applied) whose leaves are numpy arrays reshaped to
    (local_device_count, per_device_batch_size, ...).
  """
  local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
  inputs = batch['inputs']
  current_batch_size = (
      inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0])

  if global_batch_size is not None:
    assert global_batch_size >= current_batch_size, (
        'global_batch_size must be larger than or equal to current_batch_size.')
    # Always pad to global_batch_size if it is provided.
    pad_to_global_batch_size = global_batch_size > current_batch_size
  else:
    pad_to_global_batch_size = False

  remainder_size = current_batch_size % local_device_count
  # Padding should only be necessary during evaluation.
  needs_padding = remainder_size != 0 or pad_to_global_batch_size

  if needs_padding:
    if global_batch_size is not None:
      pad_size = global_batch_size - current_batch_size
    else:
      pad_size = local_device_count - remainder_size

    targets = batch['targets']
    targets_shape = tuple(
        targets[0].shape if isinstance(targets, tuple) else targets.shape)
    # We need a 2d mask for WMT.
    mask_shape = targets_shape if len(targets_shape) < 3 else targets_shape[0]
    # Get weights from batch if there are any.
    weights = batch.get('weights')
    # The weights will also be padded. Work on a shallow copy so the caller's
    # batch does not silently gain a 'weights' key.
    batch = {
        **batch,
        'weights': np.ones(mask_shape) if weights is None else weights,
    }

  def _prepare(x):
    # Use _numpy() for zero-copy conversion between TF and NumPy.
    if not isinstance(x, np.ndarray):
      x = x._numpy()  # pylint: disable=protected-access

    if needs_padding:
      x = pad(x, pad_size, padding_value=padding_value)

    # Reshape (global_batch_size, ...) to
    # (local_device_count, per_device_batch_size, ...).
    # Assumes that `global_batch_size % local_device_count == 0`.
    return x.reshape((local_device_count, -1, *x.shape[1:]))

  # `jax.tree_map` is deprecated and has been removed from recent JAX
  # releases; `jax.tree_util.tree_map` is available in old and new versions.
  return jax.tree_util.tree_map(_prepare, batch)
def pad(tensor: np.ndarray,
        pad_size: int,
        padding_value: int = 0) -> np.ndarray:
  """Append `pad_size` filler entries along the leading axis of `tensor`.

  Args:
    tensor: Array to extend along axis 0.
    pad_size: Number of entries to append along axis 0.
    padding_value: Value every appended element is set to.

  Returns:
    A new array consisting of `tensor` followed by the filler block, which
    shares `tensor`'s dtype and trailing dimensions.
  """
  if tensor.ndim > 1:
    # The filler block must match every trailing dimension of `tensor`.
    filler_shape = (pad_size,) + tuple(tensor.shape[1:])
  else:
    filler_shape = pad_size
  filler = np.full(filler_shape, padding_value, dtype=tensor.dtype)
  return np.concatenate([tensor, filler], axis=0)
def mixup_pytorch(batch: Tuple[spec.Tensor, spec.Tensor],
                  alpha: float = 0.2,
                  num_classes: int = 1000) -> Tuple[spec.Tensor, spec.Tensor]:
  """Apply mixup to a batch of (inputs, integer class targets).

  Each example is blended with its neighbour in the batch (the batch rolled by
  one position along dim 0) using a single mixing weight drawn from
  Beta(alpha, alpha).

  Args:
    batch: Tuple `(inputs, targets)`; `targets` holds integer class indices
      accepted by `torch.nn.functional.one_hot`.
    alpha: Concentration parameter of the symmetric Beta distribution.
    num_classes: Number of classes used for the one-hot encoding of the
      targets. Defaults to 1000, the value that was previously hard-coded.

  Returns:
    Tuple `(mixed_inputs, mixed_targets)` where `mixed_targets` are the
    blended one-hot targets.
  """
  inputs, targets = batch
  # Transform to one-hot targets.
  targets = F.one_hot(targets, num_classes=num_classes)
  # Compute weight for convex combination by sampling from Beta distribution.
  beta_dist = torch.distributions.beta.Beta(alpha, alpha)
  weight = beta_dist.sample()
  # Return convex combination of original and shifted inputs and targets.
  inputs = weight * inputs + (1.0 - weight) * torch.roll(inputs, 1, dims=0)
  targets = weight * targets + (1.0 - weight) * torch.roll(targets, 1, dims=0)
  return (inputs, targets)
# github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py
class DistributedEvalSampler(Sampler):
  r"""DistributedEvalSampler is different from DistributedSampler.

  It does NOT add extra samples to make it evenly divisible.
  DistributedEvalSampler should NOT be used for training. The distributed
  processes could hang forever.
  See this issue for details: https://github.com/pytorch/pytorch/issues/22584
  shuffle is disabled by default

  DistributedEvalSampler is for evaluation purpose where synchronization does
  not happen every epoch.
  Synchronization should be done outside the dataloader loop.

  Sampler that restricts data loading to a subset of the dataset.

  It is especially useful in conjunction with
  :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
  process can pass a :class:`~DistributedEvalSampler` instance as
  a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
  original dataset that is exclusive to it.

  .. note::
    Dataset is assumed to be of constant size.

  Arguments:
    dataset: Dataset used for sampling.
    num_replicas (int, optional): Number of processes participating in
      distributed evaluation. By default, the world size is retrieved from
      the current distributed group.
    rank (int, optional): Rank of the current process within
      :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the
      current distributed group.
    shuffle (bool, optional): If ``True``, sampler will shuffle the
      indices. Default: ``False``
    seed (int, optional): random seed used to shuffle the sampler if
      :attr:`shuffle=True`. This number should be identical across all
      processes in the distributed group. Default: ``0``.

  .. warning::
    In distributed mode, calling the :meth:`set_epoch(epoch) <set_epoch>`
    method at the beginning of each epoch **before** creating the
    :class:`DataLoader` iterator is necessary to make shuffling work
    properly across multiple epochs. Otherwise, the same ordering will be
    always used.

  Example::

    >>> sampler = DistributedEvalSampler(dataset) if is_distributed else None
    >>> loader = DataLoader(dataset, sampler=sampler)
    >>> for epoch in range(start_epoch, n_epochs):
    ...     if is_distributed:
    ...         sampler.set_epoch(epoch)
    ...     evaluate(loader)
  """

  def __init__(self,
               dataset: torch.utils.data.Dataset,
               num_replicas: Optional[int] = None,
               rank: Optional[int] = None,
               shuffle: bool = False,
               seed: int = 0) -> None:
    if num_replicas is None:
      if not dist.is_available():
        raise RuntimeError('Requires distributed package to be available.')
      num_replicas = dist.get_world_size()
    if rank is None:
      if not dist.is_available():
        raise RuntimeError('Requires distributed package to be available.')
      rank = dist.get_rank()
    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    # True value without extra samples.
    self.total_size = len(self.dataset)
    # Slicing a `range` follows the same rules as slicing a list but does not
    # materialize one index per dataset element just to count them.
    self.num_samples = len(
        range(self.total_size)[self.rank:self.total_size:self.num_replicas])
    self.shuffle = shuffle
    self.seed = seed

  def __iter__(self) -> Iterable[int]:
    """Returns an iterator over the dataset indices owned by this rank."""
    if self.shuffle:
      # Deterministically shuffle based on epoch and seed.
      g = torch.Generator()
      g.manual_seed(self.seed + self.epoch)
      permutation = torch.randperm(len(self.dataset), generator=g)
      # Subsample before converting so only this rank's share becomes a list.
      indices = permutation[self.rank:self.total_size:self.num_replicas]
      indices = indices.tolist()
    else:
      # Subsample lazily; a sliced range yields the same ints as a sliced list.
      indices = range(
          len(self.dataset))[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples
    return iter(indices)

  def __len__(self) -> int:
    """Returns the number of samples assigned to this rank."""
    return self.num_samples

  def set_epoch(self, epoch: int) -> None:
    r"""Sets the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different
    random ordering for each epoch. Otherwise, the next iteration of this
    sampler will yield the same ordering.

    Args:
      epoch: An int indicating epoch number.
    """
    self.epoch = epoch
# Modified from github.com/pytorch/pytorch/issues/23900#issuecomment-518858050.
def cycle(iterable: Iterable,
          keys: Tuple[str, ...] = ('inputs', 'targets'),
          custom_sampler: bool = False,
          use_mixup: bool = False,
          mixup_alpha: float = 0.2) -> Iterable:
  """Endlessly iterate over `iterable`, yielding each batch as a dict.

  Every batch drawn from `iterable` is zipped with `keys` into a dict. When
  `iterable` is exhausted it is iterated again from the start.

  Args:
    iterable: Source of batches; each batch must have `len(keys)` elements.
    keys: Names assigned, in order, to the elements of each batch.
    custom_sampler: If True and `iterable` is a `DataLoader`, an epoch counter
      is incremented and passed to `iterable.sampler.set_epoch` before each
      restart.
    use_mixup: If True, `mixup_pytorch` is applied to every batch; requires
      `keys == ('inputs', 'targets')`.
    mixup_alpha: `alpha` forwarded to `mixup_pytorch`.

  Yields:
    A dict mapping each key to the corresponding batch element.

  Raises:
    ValueError: If a full pass over `iterable` produces no batch at all
      (empty or one-shot exhausted iterable), which would otherwise make this
      generator spin forever without yielding.
  """
  iterator = iter(iterable)
  epoch = 0
  produced_batch = False
  while True:
    # Keep the `try` limited to the only call expected to signal exhaustion.
    try:
      batch = next(iterator)
    except StopIteration:
      if not produced_batch:
        raise ValueError(
            'cycle() received an iterable that produced no batches; cannot '
            'cycle over an empty or already exhausted iterable.') from None
      if custom_sampler and isinstance(iterable, DataLoader):
        epoch += 1
        iterable.sampler.set_epoch(epoch)
      iterator = iter(iterable)
      produced_batch = False
      continue

    produced_batch = True
    if use_mixup:
      assert keys == ('inputs', 'targets')
      batch = mixup_pytorch(batch, alpha=mixup_alpha)
    assert len(keys) == len(batch)
    yield dict(zip(keys, batch))
# Inspired by
# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
# ConvNets/image_classification/dataloaders.py
class PrefetchedWrapper:
  """Wraps a DataLoader and moves batches to `device` on a side CUDA stream.

  While the consumer works on one batch, the host-to-device copy of the next
  batch is issued on a separate CUDA stream.
  """

  def __init__(self,
               dataloader: DataLoader,
               device: torch.device,
               start_epoch: int = 0) -> None:
    """Stores the wrapped loader, the target device and the starting epoch."""
    self.dataloader = dataloader
    self.epoch = start_epoch
    self.device = device

  def __len__(self) -> int:
    """Returns the number of batches of the wrapped dataloader."""
    return len(self.dataloader)

  def __iter__(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
    """Starts a new pass over the data and advances the epoch counter."""
    if isinstance(self.dataloader.sampler, DistributedSampler):
      # Lets the sampler derive its ordering from the current epoch.
      self.dataloader.sampler.set_epoch(self.epoch)
    self.epoch += 1
    return self.prefetched_loader()

  def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
    """Yields `(inputs, targets)` batches that have been moved to the device.

    Inputs are cast to `torch.float` during the transfer. Nothing is yielded
    when the wrapped dataloader is empty.
    """
    stream = torch.cuda.Stream()
    inputs = None
    targets = None
    has_pending_batch = False

    for next_inputs, next_targets in self.dataloader:
      # Issue the copy of the upcoming batch on the side stream.
      with torch.cuda.stream(stream):
        next_inputs = next_inputs.to(
            self.device, dtype=torch.float, non_blocking=True)
        next_targets = next_targets.to(self.device, non_blocking=True)

      # Hand out the previously transferred batch, if there is one.
      if has_pending_batch:
        yield inputs, targets

      # Make the current stream wait for the side-stream copy before the
      # freshly transferred batch is handed out.
      torch.cuda.current_stream().wait_stream(stream)
      inputs = next_inputs
      targets = next_targets
      has_pending_batch = True

    # Emit the last batch; skipped for an empty dataloader, where `inputs`
    # and `targets` were never assigned a batch.
    if has_pending_batch:
      yield inputs, targets