-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
27 lines (25 loc) · 1.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.nn.functional as func
def train_unit(config, model, device, federated_train_loader, optimizer, epoch):
    """Run one epoch of federated training over ``federated_train_loader``.

    For every ``(data, target)`` batch the model is sent to the location that
    holds the batch (``data.location``). One optimisation step is taken with
    the negative log-likelihood loss (``nll_loss``), and then the model is
    retrieved with ``model.get()``. On every ``config.log_interval``-th batch
    (including the first), the loss is fetched with ``.get()`` and a progress
    line is printed.

    Args:
        config: settings object; this function reads ``log_interval`` and
            ``batch_size``.
        model: model exposing ``train()``, ``send(location)`` and ``get()``.
            This presumably is a PySyft-hooked ``torch.nn.Module``; confirm
            against the caller.
        device: passed to ``.to()`` on each batch's data and target.
        federated_train_loader: iterable of ``(data, target)`` pairs that also
            supports ``len()``. ``data`` must expose ``.location``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number; used only in the printed progress message.
    """
    model.train()
    for step, (inputs, labels) in enumerate(federated_train_loader):
        # Ship the model to wherever this batch lives.
        model.send(inputs.location)
        inputs, labels = inputs.to(device), labels.to(device)

        # One optimisation step: clear old gradients, forward, loss,
        # backward, update.
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = func.nll_loss(predictions, labels)
        batch_loss.backward()
        optimizer.step()

        # Pull the updated model back from the worker.
        model.get()

        if step % config.log_interval != 0:
            continue
        # NOTE(review): the loss is presumably a remote pointer, hence the
        # .get() before .item(); confirm against the PySyft version in use.
        local_loss = batch_loss.get()
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch,
            step * config.batch_size,
            len(federated_train_loader) * config.batch_size,
            100. * step / len(federated_train_loader),
            local_loss.item(),
        ))
def train(config, model, device, federated_train_loader, optimizer):
    """Train ``model`` for ``config.epochs`` epochs.

    Epochs are numbered from 1 to ``config.epochs`` inclusive. Each one is a
    full pass over ``federated_train_loader``, delegated to ``train_unit``.
    All other arguments are forwarded unchanged.
    """
    total_epochs = config.epochs
    for current_epoch in range(1, total_epochs + 1):
        train_unit(
            config,
            model,
            device,
            federated_train_loader,
            optimizer,
            epoch=current_epoch,
        )