forked from microsoftarchive/BatchAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_mnist.py
124 lines (101 loc) · 4.64 KB
/
train_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python
from __future__ import print_function
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from mpi4py import MPI
import chainermn
class MLP(chainer.Chain):
    """Three-layer fully connected network with ReLU hidden activations.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units produced by the last layer.
    """

    def __init__(self, n_units, n_out):
        # The input width of the first layer is fixed at 784 here
        # (presumably flattened 28x28 MNIST images -- confirm against the
        # dataset loader); it is not inferred from the data.
        layers = dict(
            l1=L.Linear(784, n_units),  # 784 -> n_units
            l2=L.Linear(n_units, n_units),  # n_units -> n_units
            l3=L.Linear(n_units, n_out),  # n_units -> n_out
        )
        super(MLP, self).__init__(**layers)

    def __call__(self, x):
        """Apply l1 and l2 with ReLU, then return the raw output of l3."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        # No activation is applied to the final layer's output.
        return self.l3(hidden)
def main():
    """Train the MLP on MNIST across MPI workers using ChainerMN.

    Parses command-line options, creates a ChainerMN communicator, builds
    the classifier and a multi-node optimizer, scatters the dataset loaded
    on rank 0 to all workers, and runs a Chainer ``Trainer``.

    Raises:
        SystemExit: With status -1 when ``--gpu`` is combined with the
            'naive' communicator (argparse may also exit on bad arguments).
    """
    parser = argparse.ArgumentParser(description='ChainerMN example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    # Prepare ChainerMN communicator.
    # ``communicator_name`` records the communicator that is actually
    # created, which differs from the requested one on the CPU fallback path.
    if args.gpu:
        if args.communicator == 'naive':
            print("Error: 'naive' communicator does not support GPU.\n")
            # SystemExit is raised directly so the script does not depend on
            # the ``exit`` helper injected by the ``site`` module.
            raise SystemExit(-1)
        communicator_name = args.communicator
        comm = chainermn.create_communicator(communicator_name)
        # One GPU per process: the device id is the rank within the node.
        device = comm.intra_rank
    else:
        if args.communicator != 'naive':
            print('Warning: using naive communicator '
                  'because only naive supports CPU-only execution')
        communicator_name = 'naive'
        comm = chainermn.create_communicator(communicator_name)
        device = -1  # negative device id selects the CPU

    # Print the run configuration once (rank 0 only).
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(MPI.COMM_WORLD.Get_size()))
        if args.gpu:
            print('Using GPUs')
        # Report the communicator in use, not merely the one requested.
        print('Using {} communicator'.format(communicator_name))
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

    model = L.Classifier(MLP(args.unit, 10))
    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    # Create a multi node optimizer from a standard Chainer optimizer.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Split and distribute the dataset. Only worker 0 loads the whole dataset.
    # Datasets of worker 0 are evenly split and distributed to all workers.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    # The evaluation iterator makes a single, unshuffled pass.
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Create a multi node evaluator from a standard Chainer evaluator.
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    trainer.extend(evaluator)

    # Some display and output extensions are necessary only for one worker.
    # (Otherwise, there would just be repeated outputs.)
    if comm.rank == 0:
        trainer.extend(extensions.dump_graph('main/loss'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar())

    # Optionally restore trainer state from a snapshot before running.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()