
Commit 8cf24f4

Split iris_data.py from estimator examples. (tensorflow#2954)
* Split iris data
* move shared code to iris_data.py
* add minimal csv example
* remove unused pandas import
* Use sparse softmax loss to avoid warning
1 parent 5a5d330 commit 8cf24f4
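Taken together, the changes below reduce both estimator scripts to one shared import pattern. A minimal sketch of that pattern, assembled from the diffs that follow (batch_size=100 only mirrors the scripts' --batch_size default; the snippet is illustrative, not part of the commit):

    import iris_data

    # Shared data helpers now live in one module.
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

    # One-shot-iterator input pipeline for training, as both scripts use it.
    batch = iris_data.train_input_fn(train_x, train_y, batch_size=100)
    print(iris_data.SPECIES)  # label names shared by both scripts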

File tree

4 files changed: +121 -127 lines changed


samples/core/get_started/custom_estimator.py

Lines changed: 12 additions & 66 deletions
@@ -17,67 +17,15 @@
 from __future__ import print_function

 import argparse
-import pandas as pd
 import tensorflow as tf

+import iris_data
+
 parser = argparse.ArgumentParser()
 parser.add_argument('--batch_size', default=100, type=int, help='batch size')
 parser.add_argument('--train_steps', default=1000, type=int,
                     help='number of training steps')

-TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
-TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
-
-CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
-                    'PetalLength', 'PetalWidth', 'Species']
-SPECIES = ['Sentosa', 'Versicolor', 'Virginica']
-
-
-def load_data(y_name='Species'):
-    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
-    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
-    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
-    train_x, train_y = train, train.pop(y_name)
-
-    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
-    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
-    test_x, test_y = test, test.pop(y_name)
-
-    return (train_x, train_y), (test_x, test_y)
-
-
-def train_input_fn(features, labels, batch_size):
-    """An input function for training"""
-    # Convert the inputs to a Dataset.
-    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
-
-    # Shuffle, repeat, and batch the examples.
-    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-
-    # Return the read end of the pipeline.
-    return dataset.make_one_shot_iterator().get_next()
-
-
-def eval_input_fn(features, labels=None, batch_size=None):
-    """An input function for evaluation or prediction"""
-    if labels is None:
-        # No labels, use only features.
-        inputs = features
-    else:
-        inputs = (features, labels)
-
-    # Convert the inputs to a Dataset.
-    dataset = tf.data.Dataset.from_tensor_slices(inputs)
-
-    # Batch the examples
-    assert batch_size is not None, "batch_size must not be None"
-    dataset = dataset.batch(batch_size)
-
-    # Return the read end of the pipeline.
-    return dataset.make_one_shot_iterator().get_next()
-
-
 def my_model(features, labels, mode, params):
     """DNN with three hidden layers, and dropout of 0.1 probability."""
     # Create three fully connected layers each layer having a dropout
@@ -99,12 +47,8 @@ def my_model(features, labels, mode, params):
         }
         return tf.estimator.EstimatorSpec(mode, predictions=predictions)

-    # Convert the labels to a one-hot tensor of shape (length of features, 3)
-    # and with a on-value of 1 for each one-hot vector of length 3.
-    onehot_labels = tf.one_hot(labels, 3, 1, 0)
     # Compute loss.
-    loss = tf.losses.softmax_cross_entropy(
-        onehot_labels=onehot_labels, logits=logits)
+    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

     # Compute evaluation metrics.
     accuracy = tf.metrics.accuracy(labels=labels,
@@ -129,9 +73,7 @@ def main(argv):
     args = parser.parse_args(argv[1:])

     # Fetch the data
-    (train_x, train_y), (test_x, test_y) = load_data()
-    train_x = dict(train_x)
-    test_x = dict(test_x)
+    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

     # Feature columns describe how to use the input.
     my_feature_columns = []
@@ -151,12 +93,12 @@ def main(argv):

     # Train the Model.
     classifier.train(
-        input_fn=lambda:train_input_fn(train_x, train_y, args.batch_size),
+        input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
         steps=args.train_steps)

     # Evaluate the model.
     eval_result = classifier.evaluate(
-        input_fn=lambda:eval_input_fn(test_x, test_y, args.batch_size))
+        input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))

     print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

@@ -170,14 +112,18 @@ def main(argv):
     }

     predictions = classifier.predict(
-        input_fn=lambda:eval_input_fn(predict_x, batch_size=args.batch_size))
+        input_fn=lambda:iris_data.eval_input_fn(predict_x,
+                                                labels=None,
+                                                batch_size=args.batch_size))

     for pred_dict, expec in zip(predictions, expected):
         template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')

         class_id = pred_dict['class_ids'][0]
         probability = pred_dict['probabilities'][class_id]
-        print(template.format(SPECIES[class_id], 100 * probability, expec))
+
+        print(template.format(iris_data.SPECIES[class_id],
+                              100 * probability, expec))


 if __name__ == '__main__':
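The one behavioral change in this file is the loss: integer class ids now go straight into tf.losses.sparse_softmax_cross_entropy instead of being one-hot encoded first, which the commit message notes also avoids a warning. A small sketch (TF 1.x, with invented label and logit values) of why the two formulations agree:

    import tensorflow as tf

    labels = tf.constant([0, 2, 1])                 # integer class ids
    logits = tf.constant([[2.0, 1.0, 0.1],
                          [0.5, 0.3, 2.2],
                          [1.0, 3.0, 0.2]])

    # New form: integer labels, no explicit one-hot step.
    sparse_loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                         logits=logits)

    # Old form: one-hot first, then dense softmax cross entropy.
    onehot = tf.one_hot(labels, 3, 1, 0)
    dense_loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot,
                                                 logits=logits)

    with tf.Session() as sess:
        print(sess.run([sparse_loss, dense_loss]))  # the same scalar twice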

samples/core/get_started/estimator_test.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@

 from six.moves import StringIO

+import iris_data
 import custom_estimator
 import premade_estimator

@@ -35,7 +36,7 @@
 def four_lines_data():
     text = StringIO(FOUR_LINES)

-    df = pd.read_csv(text, names=premade_estimator.CSV_COLUMN_NAMES)
+    df = pd.read_csv(text, names=iris_data.CSV_COLUMN_NAMES)

     xy = (df, df.pop("Species"))
     return xy, xy
samples/core/get_started/iris_data.py (new file)

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+import pandas as pd
+import tensorflow as tf
+
+TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
+TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
+
+CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
+                    'PetalLength', 'PetalWidth', 'Species']
+SPECIES = ['Sentosa', 'Versicolor', 'Virginica']
+
+def maybe_download():
+    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
+    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
+
+    return train_path, test_path
+
+def load_data(y_name='Species'):
+    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
+    train_path, test_path = maybe_download()
+
+    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
+    train_x, train_y = train, train.pop(y_name)
+
+    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
+    test_x, test_y = test, test.pop(y_name)
+
+    return (train_x, train_y), (test_x, test_y)
+
+
+def train_input_fn(features, labels, batch_size):
+    """An input function for training"""
+    # Convert the inputs to a Dataset.
+    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
+
+    # Shuffle, repeat, and batch the examples.
+    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
+
+    # Return the read end of the pipeline.
+    return dataset.make_one_shot_iterator().get_next()
+
+
+def eval_input_fn(features, labels, batch_size):
+    """An input function for evaluation or prediction"""
+    features = dict(features)
+    if labels is None:
+        # No labels, use only features.
+        inputs = features
+    else:
+        inputs = (features, labels)
+
+    # Convert the inputs to a Dataset.
+    dataset = tf.data.Dataset.from_tensor_slices(inputs)
+
+    # Batch the examples
+    assert batch_size is not None, "batch_size must not be None"
+    dataset = dataset.batch(batch_size)
+
+    # Return the read end of the pipeline.
+    return dataset.make_one_shot_iterator().get_next()
+
+
+# The remainder of this file contains a simple example of a csv parser,
+# implemented using the `Dataset` class.
+
+# `tf.parse_csv` sets the types of the outputs to match the examples given in
+# the `record_defaults` argument.
+CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]
+
+def _parse_line(line):
+    # Decode the line into its fields
+    fields = tf.decode_csv(line, record_defaults=CSV_TYPES)
+
+    # Pack the result into a dictionary
+    features = dict(zip(CSV_COLUMN_NAMES, fields))
+
+    # Separate the label from the features
+    label = features.pop('Species')
+
+    return features, label
+
+
+def csv_input_fn(csv_path, batch_size):
+    # Create a dataset containing the text lines.
+    dataset = tf.data.TextLineDataset(csv_path).skip(1)
+
+    # Parse each line.
+    dataset = dataset.map(_parse_line)
+
+    # Shuffle, repeat, and batch the examples.
+    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
+
+    # Return the read end of the pipeline.
+    return dataset.make_one_shot_iterator().get_next()
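The new csv_input_fn is the "minimal csv example" from the commit message: it parses the raw csv lines with tf.decode_csv instead of going through pandas. A hedged sketch of how it might plug into a premade estimator; the feature-column and hidden-unit choices mirror premade_estimator.py, and steps=200 is just an illustrative value, none of this is part of the commit:

    import tensorflow as tf

    import iris_data

    # Paths come from the same download helper the rest of the module uses.
    train_path, test_path = iris_data.maybe_download()

    # Numeric feature columns for the four measurements ('Species' is the label).
    feature_columns = [tf.feature_column.numeric_column(key=name)
                       for name in iris_data.CSV_COLUMN_NAMES[:-1]]

    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 10],
        n_classes=3)

    # csv_input_fn feeds parsed lines directly, so no pandas round trip.
    classifier.train(
        input_fn=lambda: iris_data.csv_input_fn(train_path, batch_size=100),
        steps=200)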

samples/core/get_started/premade_estimator.py

Lines changed: 14 additions & 60 deletions
@@ -17,73 +17,21 @@
 from __future__ import print_function

 import argparse
-import pandas as pd
 import tensorflow as tf

+import iris_data
+
+
 parser = argparse.ArgumentParser()
 parser.add_argument('--batch_size', default=100, type=int, help='batch size')
 parser.add_argument('--train_steps', default=1000, type=int,
                     help='number of training steps')

-TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
-TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
-
-CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
-                    'PetalLength', 'PetalWidth', 'Species']
-SPECIES = ['Sentosa', 'Versicolor', 'Virginica']
-
-
-def load_data(y_name='Species'):
-    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
-    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
-    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
-    train_x, train_y = train, train.pop(y_name)
-
-    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
-    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
-    test_x, test_y = test, test.pop(y_name)
-
-    return (train_x, train_y), (test_x, test_y)
-
-
-def train_input_fn(features, labels, batch_size):
-    """An input function for training"""
-    # Convert the inputs to a Dataset.
-    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
-
-    # Shuffle, repeat, and batch the examples.
-    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-
-    # Return the read end of the pipeline.
-    return dataset.make_one_shot_iterator().get_next()
-
-
-def eval_input_fn(features, labels=None, batch_size=None):
-    """An input function for evaluation or prediction"""
-    if labels is None:
-        # No labels, use only features.
-        inputs = features
-    else:
-        inputs = (features, labels)
-
-    # Convert the inputs to a Dataset.
-    dataset = tf.data.Dataset.from_tensor_slices(inputs)
-
-    # Batch the examples
-    assert batch_size is not None, "batch_size must not be None"
-    dataset = dataset.batch(batch_size)
-
-    # Return the read end of the pipeline.
-    return dataset.make_one_shot_iterator().get_next()
-
-
 def main(argv):
     args = parser.parse_args(argv[1:])

     # Fetch the data
-    (train_x, train_y), (test_x, test_y) = load_data()
-    train_x = dict(train_x)
-    test_x = dict(test_x)
+    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

     # Feature columns describe how to use the input.
     my_feature_columns = []
@@ -100,12 +48,14 @@ def main(argv):

     # Train the Model.
     classifier.train(
-        input_fn=lambda:train_input_fn(train_x, train_y, args.batch_size),
+        input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
+                                                 args.batch_size),
         steps=args.train_steps)

     # Evaluate the model.
     eval_result = classifier.evaluate(
-        input_fn=lambda:eval_input_fn(test_x, test_y, args.batch_size))
+        input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
+                                                args.batch_size))

     print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

@@ -119,14 +69,18 @@ def main(argv):
     }

     predictions = classifier.predict(
-        input_fn=lambda:eval_input_fn(predict_x, batch_size=args.batch_size))
+        input_fn=lambda:iris_data.eval_input_fn(predict_x,
+                                                labels=None,
+                                                batch_size=args.batch_size))

     for pred_dict, expec in zip(predictions, expected):
         template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')

         class_id = pred_dict['class_ids'][0]
         probability = pred_dict['probabilities'][class_id]
-        print(template.format(SPECIES[class_id], 100 * probability, expec))
+
+        print(template.format(iris_data.SPECIES[class_id],
+                              100 * probability, expec))


 if __name__ == '__main__':
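One consequence of the refactor, visible in both estimator scripts: the shared eval_input_fn no longer defaults labels or batch_size, so prediction calls now spell out labels=None. A small sketch of that contract, with illustrative feature values in the spirit of the scripts' predict_x:

    import iris_data

    # Features-only input: eval_input_fn builds the Dataset from the dict alone.
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth': [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth': [0.5, 1.5, 2.1],
    }

    # labels must be passed explicitly (here None), and batch_size must be
    # set, otherwise the function's assert fires.
    features = iris_data.eval_input_fn(predict_x, labels=None, batch_size=3)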
