forked from VeritasXu/Ternary-Federated
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Ternary_Fed.py
95 lines (80 loc) · 2.9 KB
/
Ternary_Fed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import sys
import copy
import torch
import numpy as np
from utils.config import Args
from utils.Evaluate import evaluate
import utils.data_utils as data_utils
from tools.Fed_Operator import ServerUpdate, LocalUpdate
# Select the network architecture named by Args.model and bind it to the
# module-level name Fed_Model, which the rest of this script instantiates.
if Args.model == 'MLP':
    from model.MLP import MLP as Fed_Model
elif Args.model == 'CNN':
    from model.CNN import CNN as Fed_Model
elif Args.model == 'ResNet':
    from model.resnet import ResNet18 as Fed_Model
else:
    # Fail fast: without this branch Fed_Model stays undefined and the script
    # only crashes later with an unhelpful NameError.
    raise ValueError(
        "Unsupported Args.model {!r}; expected 'MLP', 'CNN' or 'ResNet'".format(Args.model)
    )
def choose_model(f_dict, ter_dict, loss_fun=None, data_iter=None, tolerance=0.03):
    """Pick which aggregated global weights to broadcast this round.

    Both candidate state dicts are loaded into fresh ``Fed_Model`` instances
    and scored with ``evaluate``. The ternary weights are chosen unless they
    lose ``tolerance`` or more accuracy relative to the full-precision weights.

    Args:
        f_dict: full-precision global state dict.
        ter_dict: ternary global state dict.
        loss_fun: loss passed to ``evaluate``. Defaults to the module-level
            ``G_loss_fun`` created in the ``__main__`` block.
        data_iter: evaluation data passed to ``evaluate``. Defaults to the
            module-level ``test_iter`` created in the ``__main__`` block.
        tolerance: largest accuracy drop (full minus ternary) for which the
            ternary model is still accepted. It uses the same units as the
            accuracy returned by ``evaluate``. That is presumably a fraction
            in [0, 1], so the 0.03 default is about 3 points — confirm.

    Returns:
        A ``(state_dict, used_ternary)`` tuple, where ``used_ternary`` is True
        when ``ter_dict`` was selected.
    """
    # Fall back to the script-level globals so existing callers keep working.
    if loss_fun is None:
        loss_fun = G_loss_fun
    if data_iter is None:
        data_iter = test_iter

    # Create models based on both the full and the ternary weights.
    full_net = Fed_Model()
    ter_net = Fed_Model()
    full_net.load_state_dict(f_dict)
    ter_net.load_state_dict(ter_dict)

    # Evaluate both networks on the same data.
    _, acc_full, _ = evaluate(full_net, loss_fun, data_iter, Args)
    _, acc_ter, _ = evaluate(ter_net, loss_fun, data_iter, Args)
    print('F: %.3f' % acc_full, 'TF: %.3f' % acc_ter)

    # Only an accuracy *loss* disqualifies the ternary model. The previous
    # abs() check also rejected a ternary model that was better by
    # >= tolerance, contradicting the stated intent.
    use_ternary = bool(acc_full - acc_ter < tolerance)
    return (ter_dict if use_ternary else f_dict), use_ternary
if __name__ == '__main__':
    print(Args)
    # Seed torch *and* numpy. The per-round client sampling below uses
    # np.random, which was previously left unseeded, so two runs with the
    # same Args.seed could select different clients.
    torch.manual_seed(Args.seed)
    np.random.seed(Args.seed)
    # C_iter holds per-client training data (indexed by client id below);
    # test_iter is the shared evaluation data. train_iter/stats are unused here.
    C_iter, train_iter, test_iter, stats = data_utils.get_dataset(args=Args)

    # Build the global network. G_loss_fun and test_iter must stay
    # module-level names: choose_model() reads them as globals.
    G_net = Fed_Model()
    print(G_net)
    G_net.train()
    G_loss_fun = torch.nn.CrossEntropyLoss()

    # Number of clients sampled per round (at least one).
    m = max(int(Args.frac * Args.num_C), 1)
    # Global test accuracy recorded after every round.
    gv_acc = []
    # Number of rounds in which the ternary global model was kept ('S2').
    num_s2 = 0
    # Per-client state (wp_lists) returned by TFed_train and fed back next time.
    c_lists = [[] for _ in range(Args.num_C)]

    # training
    for rounds in range(Args.rounds):
        w_locals = []
        num_samp = []
        client_id = np.random.choice(Args.num_C, m, replace=False)
        print('Round {:d} start'.format(rounds))
        for idx in client_id:
            print("Client ID: ", idx)
            local = LocalUpdate(client_name=idx, c_round=rounds, train_iter=C_iter[idx], test_iter=test_iter,
                                wp_lists=c_lists[idx], args=Args)
            # Each client trains on its own copy of the current global model.
            w, wp_lists = local.TFed_train(net=copy.deepcopy(G_net).to(Args.device))
            c_lists[idx] = wp_lists
            w_locals.append(copy.deepcopy(w))
            # Client dataset size; presumably ServerUpdate uses it as the
            # aggregation weight — confirm.
            num_samp.append(len(C_iter[idx].dataset))

        # Aggregate into full-precision and ternary global weights, then keep
        # whichever one choose_model() selects.
        w_glob, ter_glob = ServerUpdate(w_locals, num_samp)
        w_glob, tmp_flag = choose_model(w_glob, ter_glob)
        if tmp_flag:
            num_s2 += 1
            print('S2')
        else:
            print('S1')

        # Reload the global network weights and verify accuracy on the test set.
        G_net.load_state_dict(w_glob)
        g_loss, g_acc, g_acc5 = evaluate(G_net, G_loss_fun, test_iter, Args)
        gv_acc.append(g_acc)
        print('Round {:3d}, Global loss {:.3f}, Global Acc {:.3f}'.format(rounds, g_loss, g_acc))

    # num_s2 was previously counted but never reported.
    print('Ternary (S2) global model used in {}/{} rounds'.format(num_s2, Args.rounds))