-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
57 lines (47 loc) · 2.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import opt
import torch
from torch.optim import Adam
import torch.nn.functional as F
from sklearn.cluster import KMeans
from utils import adjust_learning_rate
from utils import eva, target_distribution
# Per-epoch clustering metrics appended by Train(); kept at module level so
# callers can inspect them after training. Nothing here clears them, so
# repeated Train() calls keep accumulating into the same lists.
acc_reuslt = []  # historical misspelling, kept so existing imports keep working
acc_result = acc_reuslt  # correctly spelled alias bound to the very same list object
nmi_result = []
ari_result = []
f1_result = []
# Dataset names for which a learning-rate schedule (utils.adjust_learning_rate)
# is intended. No active code shown in this module reads it.
# NOTE(review): confirm how callers use this list.
use_adjust_lr = ['usps', 'hhar', 'reut', 'acm', 'dblp', 'cite']
def _run_kmeans(z, n_clusters):
    """Fit k-means (20 restarts) on a detached CPU copy of ``z`` and return the fitted estimator."""
    return KMeans(n_clusters=n_clusters, n_init=20).fit(z.detach().cpu().numpy())


def Train(epoch, model, data, adj, label, lr, pre_model_save_path, final_model_save_path, n_clusters,
          original_acc, gamma_value, lambda_value, device, adjust_lr=False):
    """Fine-tune ``model`` from pretrained weights with a joint reconstruction + clustering loss.

    Procedure:
      1. Load the state dict stored at ``pre_model_save_path`` into ``model``.
      2. Run one gradient-free forward pass. Cluster the last model output
         (``z_tilde``) with k-means, copy the centres into ``model.cluster_layer``,
         and report the initial clustering through ``eva``.
      3. For ``epoch`` epochs, take one Adam step on::

             mse(x_hat, data)
             + mse(z_hat, spmm(adj, data)) + gamma_value * mse(adj_hat, dense(adj))
             + lambda_value * kl_div(mean(log q, log q1, log q2), p)

         Here ``p = target_distribution(q)`` is computed from a detached ``q``, so
         it acts as a fixed target.
      4. After each step, re-cluster ``z_tilde`` (from that epoch's forward pass)
         with k-means and score it with ``eva`` against ``label``. Append the
         metrics to the module-level result lists. Save ``model.state_dict()`` to
         ``final_model_save_path`` whenever accuracy exceeds the best seen so far.

    Args:
        epoch: Number of training epochs.
        model: Module whose call ``model(data, adj)`` returns the 9-tuple
            ``(x_hat, z_hat, adj_hat, z_ae, z_igae, q, q1, q2, z_tilde)`` and which
            has a ``cluster_layer`` tensor/parameter.
        data: Input feature tensor (also the reconstruction target for ``x_hat``).
        adj: Adjacency tensor. It is used with ``torch.spmm`` and ``.to_dense()``,
            so it is presumably a torch sparse tensor — TODO confirm with callers.
        label: Ground-truth labels passed to ``eva``.
        lr: Adam learning rate.
        pre_model_save_path: Path of the pretrained state dict to load.
        final_model_save_path: Where the best-accuracy state dict is written.
        n_clusters: Number of k-means clusters.
        original_acc: Accuracy the model must beat before it is first saved.
        gamma_value: Weight of the adjacency-reconstruction term.
        lambda_value: Weight of the KL clustering term.
        device: Device on which the k-means centres are placed.
        adjust_lr: If True, call ``adjust_learning_rate(optimizer, epoch_idx)`` at
            the start of every epoch. Defaults to False, which matches the previous
            behaviour where this call was commented out. Callers can pass e.g.
            ``opt.args.name in use_adjust_lr``.

    Returns:
        None. Results are reported through ``eva``, ``print``, the module-level
        metric lists, and the saved checkpoint.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    model.load_state_dict(torch.load(pre_model_save_path, map_location='cpu'))

    with torch.no_grad():
        # Only the last output (z_tilde) is needed to seed the clustering layer.
        z_tilde = model(data, adj)[-1]
        # These reconstruction targets never change across epochs. Compute them
        # once, without building an autograd graph.
        adj_x = torch.spmm(adj, data)
        adj_dense = adj.to_dense()

    kmeans = _run_kmeans(z_tilde, n_clusters)
    # k-means centres are float64 numpy arrays. Cast them to the layer's own dtype
    # so the parameter is not silently promoted to float64.
    model.cluster_layer.data = torch.as_tensor(
        kmeans.cluster_centers_, dtype=model.cluster_layer.dtype, device=device)
    eva(label, kmeans.labels_, 'Initialization')

    best_acc = original_acc
    for epoch_idx in range(epoch):
        if adjust_lr:
            adjust_learning_rate(optimizer, epoch_idx)

        x_hat, z_hat, adj_hat, _, _, q, q1, q2, z_tilde = model(data, adj)
        # The target distribution must not receive gradients.
        p = target_distribution(q.detach())

        loss_ae = F.mse_loss(x_hat, data)
        loss_w = F.mse_loss(z_hat, adj_x)
        loss_a = F.mse_loss(adj_hat, adj_dense)
        loss_igae = loss_w + gamma_value * loss_a
        # kl_div expects log-probabilities as input. q, q1 and q2 must be strictly
        # positive for .log() to be finite.
        loss_kl = F.kl_div((q.log() + q1.log() + q2.log()) / 3, p, reduction='batchmean')
        loss = loss_ae + loss_igae + lambda_value * loss_kl
        print('{} loss: {}'.format(epoch_idx, loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # z_tilde comes from this epoch's forward pass, i.e. before the update above.
        kmeans = _run_kmeans(z_tilde, n_clusters)
        acc, nmi, ari, f1 = eva(label, kmeans.labels_, epoch_idx)
        acc_reuslt.append(acc)
        nmi_result.append(nmi)
        ari_result.append(ari)
        f1_result.append(f1)

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), final_model_save_path)