forked from blt2114/CDE_with_BNF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mog_network_test.py
50 lines (47 loc) · 2.1 KB
/
mog_network_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import network, mog_network
import argparse
def _build_parser():
    """Return the command-line argument parser for this training script.

    ``-log_fn`` and ``-n_comps`` are required; every other option has a
    default.
    """
    ### Establish command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-lr", dest="lr", action="store",
                        help="The learning rate", type=float, default=0.05)
    parser.add_argument("-epochs", dest="epochs", action="store",
                        help="Number of epochs to run", type=int, default=1000)
    # NOTE(review): the 10**10 default is far larger than the n_pts=5000
    # dataset requested in main(); presumably this means full-batch
    # training -- confirm in mog_network.
    parser.add_argument("-batch_size", dest="batch_size", action="store",
                        help="Number of data points per training batch",
                        type=int, default=10**10)
    parser.add_argument("-log_base_dir", dest="log_base_dir", action="store",
                        type=str, default='logs/',
                        help="directory for tensorboard logs")
    parser.add_argument("-log_fn", dest="log_fn", action="store", type=str,
                        help="root filename for tb logs", required=True)
    parser.add_argument("-n_comps", dest="n_comps", action="store", type=int,
                        help="number of mixing components", required=True)
    parser.add_argument("-display_freq", dest="display_freq", action="store",
                        type=int, default=100,
                        help="log state every this many epochs")
    # store_true so that passing the flag actually sets it to True;
    # store_false combined with default=False made the flag a no-op.
    parser.add_argument("-input_independent", dest="input_independent",
                        action="store_true", default=False,
                        help="if parameters of the flow are to be input independent")
    parser.add_argument("-r_mag_W", dest="r_mag_W", action="store",
                        type=float, default=0.,
                        help="magnitude of L2 regularizer on W")
    return parser


def main():
    """Parse command-line options, then build and train a mixture density network.

    The network is created with ``mog_network.mixture_density_network`` on the
    ``"toy"`` dataset (``n_pts=5000``, ``n_hidden_units=[50]``), configured
    from the parsed options, and trained with ``net.train()``.
    """
    parser = _build_parser()
    ### Parse args
    try:
        args = parser.parse_args()
    except IOError as e:
        # parser.error() expects a message string; it prints usage and exits.
        parser.error(str(e))
    print(args)

    net = mog_network.mixture_density_network(
        summary_fn=args.log_fn,
        n_hidden_units=[50],
        n_epochs=args.epochs,
        n_components=args.n_comps,
        # The CLI flag is phrased negatively; the network takes the positive form.
        input_dependent=not args.input_independent,
        lr=args.lr,
        dataset="toy",
        n_pts=5000,
        log_base_dir=args.log_base_dir,
        r_mag_W=args.r_mag_W,
        display_freq=args.display_freq,
        batch_size=args.batch_size,
    )
    # NOTE(review): split(0) is defined in mog_network and not visible here;
    # presumably it selects a train/test split -- confirm the meaning of 0.
    net.split(0)
    net.train()
# Script entry point: train only when executed directly, not when imported.
if __name__ == "__main__":
    main()