-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
36 lines (29 loc) · 1.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Training script for GAN."""
import tensorflow as tf
from gan import GanBuilder
from serialization import register_defaults
def train_pose_gan(gan_id, config=None, **train_args):
    """
    Train the specified model.

    Args:
        gan_id: id of a GAN spec defined in gan_params. It is used to build
            the `GanBuilder` that supplies both the estimator and the input
            data.
        config: accepted for backward compatibility but currently unused.
            NOTE(review): it is not forwarded anywhere because the signature
            of `GanBuilder.gan_estimator` is not visible here -- confirm
            whether it should be passed through.
        **train_args: forwarded unchanged to the estimator's `train` call.
            `train_args` should have `steps` or `max_steps`.

    Returns:
        Whatever the estimator's `train` method returns.

    The `__main__` block below calls `register_defaults()` before training.
    Callers importing this function presumably need to do the same -- TODO
    confirm.
    """
    # Build from `gan_id` itself. The previous version captured a `builder`
    # from the enclosing script scope and silently ignored this argument.
    builder = GanBuilder(gan_id)

    def input_fn():
        # Random generator input is passed as features, real samples as labels.
        features = builder.get_random_generator_input()
        labels = builder.get_real_sample()
        return features, labels

    gan = builder.gan_estimator()
    return gan.train(input_fn, **train_args)


def build_arg_parser():
    """Return the command-line argument parser for this training script."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'gan_id',
        help='id of GAN spec defined in gan_params')
    # Parsed as float so scientific notation such as "1e7" is accepted on the
    # command line; it is converted to int before use as a step count.
    parser.add_argument(
        '-s', '--max_steps', type=float, default=1e7,
        help='maximum number of steps to train until')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    register_defaults()
    tf.logging.set_verbosity(tf.logging.INFO)
    train_pose_gan(args.gan_id, max_steps=int(args.max_steps))