"""Deterministic tf.data input pipeline for image classification."""

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

# Padding added around the eval crop: the center crop covers
# image_size / (image_size + CROP_PADDING) of the shorter image side.
CROP_PADDING = 32


def preprocess_for_train(image, seed, seed2, image_size, dtype):
  """Randomly crops and resizes an image for training, deterministic given the seeds."""
  # Sample a random crop window; with no bounding boxes supplied, the whole
  # image is used as a fallback.
  begin, size, _ = tf.image.stateless_sample_distorted_bounding_box(
      tf.shape(image),
      tf.zeros([0, 0, 4], tf.float32),
      area_range=(0.05, 1.0),
      min_object_covered=0,
      use_image_if_no_bounding_boxes=True,
      seed=(seed, seed2))
  image = tf.slice(image, begin, size)
  image.set_shape([None, None, 3])
  image = tf.image.resize(image, [image_size, image_size])
  image = tf.cast(image, dtype)
  return image
def preprocess_for_eval(image, image_size, dtype):
  """Center-crops and resizes an image for evaluation."""
  shape = tf.shape(image)
  height, width = shape[0], shape[1]
  # Crop a centered square whose side is a fixed fraction of the shorter edge.
  ratio = image_size / (image_size + CROP_PADDING)
  crop_size = tf.cast(
      ratio * tf.cast(tf.minimum(height, width), tf.float32), tf.int32)
  y, x = (height - crop_size) // 2, (width - crop_size) // 2
  image = tf.image.crop_to_bounding_box(image, y, x, crop_size, crop_size)
  image.set_shape([None, None, 3])
  image = tf.image.resize(image, [image_size, image_size])
  image = tf.cast(image, dtype)
  return image
def create_split(dataset_builder, batch_size, image_size, train, dtype, seed,
                 cache=False):
  """Builds the per-process tf.data pipeline for the train or validation split."""
  if train:
    tf.random.set_seed(seed)
    # Shard the split evenly across JAX processes.
    train_examples = dataset_builder.info.splits['train'].num_examples
    split_size = train_examples // jax.process_count()
    start = jax.process_index() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    validate_examples = dataset_builder.info.splits['validation'].num_examples
    split_size = validate_examples // jax.process_count()
    start = jax.process_index() * split_size
    split = 'validation[{}:{}]'.format(start, start + split_size)

  def preprocess_example(example, seed2=None):
    image = example['image']
    if train:
      image = preprocess_for_train(image, seed, seed2, image_size, dtype)
    else:
      image = preprocess_for_eval(image, image_size, dtype)
    # Rescale pixel values from [0, 255] to [-1, 1].
    image = image / 127.5 - 1
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(split=split)
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  ds = ds.with_options(options)
  if cache:
    ds = ds.cache()
  if train:
    ds = ds.repeat()
    # Pair every example with a fresh random value that acts as the
    # per-example seed for the stateless crop.
    ds = tf.data.Dataset.zip((ds, tf.data.Dataset.random(seed)))
    ds = ds.shuffle(16 * batch_size, seed=seed)
  ds = ds.map(preprocess_example, num_parallel_calls=tf.data.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)
  if not train:
    ds = ds.repeat()
  ds = ds.prefetch(10)
  return ds
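

# A minimal usage sketch, not part of the original module: it builds train and
# eval pipelines from a TFDS builder. The dataset name ('imagenet2012'), batch
# size, and image size below are illustrative assumptions; 'imagenet2012'
# requires a manual download before the builder can be prepared.
if __name__ == '__main__':
  builder = tfds.builder('imagenet2012')
  builder.download_and_prepare()
  train_ds = create_split(
      builder, batch_size=128, image_size=224, train=True,
      dtype=tf.float32, seed=0)
  eval_ds = create_split(
      builder, batch_size=128, image_size=224, train=False,
      dtype=tf.float32, seed=0)
  batch = next(iter(train_ds))
  # Expected shapes: image (128, 224, 224, 3), label (128,).
  print(batch['image'].shape, batch['label'].shape)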