-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataInput.py
More file actions
66 lines (56 loc) · 2.42 KB
/
DataInput.py
File metadata and controls
66 lines (56 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
import random
class DataInput(object):
def __init__(self, dataset_path, train_labels_file, batch_size, num_examples, image_width, image_height, num_channels, seed, dataset):
self.dataset_path = dataset_path
self.train_labels_file = train_labels_file
self.num_examples = num_examples
self.image_height = image_height
self.image_width = image_width
self.num_channels = num_channels
self.seed = seed
self.dataset = dataset
self.check_data(train_labels_file)
# Create the File Name queue
self.filename_queue = tf.train.string_input_producer([self.dataset_path + self.train_labels_file], num_epochs=None)
print(self.filename_queue)
# Reading the file line by line
self.reader = tf.TextLineReader()
print(self.reader)
# Parse the line of CSV
self.key_temp, self.value_temp = self.reader.read(self.filename_queue)
print(self.key_temp, self.value_temp)
self.record_defaults = [[1], ['']]
self.col1, self.col2 = tf.decode_csv(self.value_temp, record_defaults=self.record_defaults)
print(self.col1, self.col2)
# Decode the data into JPEG
self.decode_jpeg()
# setup the input pipeline
self.input_pipeline(batch_size)
def input_pipeline(self, batch_size,num_epochs=None):
self.min_after_dequeue = 10000
self.capacity = self.min_after_dequeue + 3 * batch_size
self.example_batch, self.label_batch = tf.train.shuffle_batch (
[self.train_image, self.col1], batch_size=batch_size, capacity=self.capacity,
min_after_dequeue=self.min_after_dequeue, seed=self.seed)
return self.example_batch, self.label_batch
def decode_jpeg(self):
file_content = tf.read_file(self.col2)
self.train_image = tf.image.decode_png(file_content, channels=self.num_channels)
distorted_image = tf.image.random_flip_left_right(self.train_image)
self.train_image = tf.image.per_image_standardization(distorted_image)
self.train_image = tf.image.resize_images(self.train_image, [self.image_width, self.image_height])
def check_data(self, train_labels_file):
f = open(train_labels_file)
lines = f.readlines()
f.close()
labels = []
for line in lines:
label = int(line.split(",")[0])
if label not in labels:
labels.append(label)
print("Read data from file: "+str(train_labels_file))
print("The number of data is: " + str(len(lines)))
print("The labels are: " + str(sorted(labels)))