-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_light_gcn.py
122 lines (83 loc) · 4.04 KB
/
main_light_gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
# Pin the visible GPU BEFORE torch initializes CUDA; "cuda" below then maps to physical GPU 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
from torch_mgdcf.losses import compute_bpr_loss, compute_l2_loss
from torch_mgdcf.utils import create_tensor_dataloader
from torch_mgdcf.datasets import load_dataset
from torch_mgdcf.layers.light_gcn import LightGCN
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import time
from torch_mgdcf.evaluation.ranking import evaluate_mean_global_metrics
np.set_printoptions(precision=4)
# Hyper-parameters for LightGCN training on an implicit-feedback dataset.
parser = argparse.ArgumentParser(description='Argument parser for the program.')
parser.add_argument('--dataset', type=str, default='light_gcn_yelp', help='Dataset Name')
parser.add_argument('--embedding_size', type=int, default=64, help='Embedding size')
parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')
parser.add_argument('--l2_coef', type=float, default=1e-4, help='L2 regularization coefficient')
parser.add_argument('--lr_decay', type=float, default=0.995, help='Learning rate decay')
parser.add_argument('--k', type=int, default=4, help='Number of layers')
parser.add_argument('--edge_drop_rate', type=float, default=0.15, help='Edge drop rate')
parser.add_argument('--batch_size', type=int, default=8000, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3000, help='Number of epochs')
args = parser.parse_args()
print(args)
# Unpack parsed arguments into module-level names used throughout the script.
dataset_name = args.dataset
embedding_size = args.embedding_size
lr = args.lr
l2_coef = args.l2_coef
lr_decay = args.lr_decay
k = args.k
edge_drop_rate = args.edge_drop_rate
batch_size = args.batch_size
num_epochs = args.num_epochs
device = "cuda"
# load_dataset returns interaction edges plus train/test index splits and
# per-user item dicts for ranking evaluation (presumably user_id -> item ids; verify against torch_mgdcf).
num_users, num_items, user_item_edges, train_index, test_index, train_user_items_dict, test_user_items_dict = load_dataset(dataset_name)
train_user_item_edges = user_item_edges[train_index]
# Build one homogeneous user+item graph from the TRAINING edges only (test edges must not leak).
g = LightGCN.build_homo_graph(train_user_item_edges, num_users=num_users, num_items=num_items).to(device)
num_nodes = g.num_nodes()
# Free embedding table for all nodes, scaled by 1/sqrt(d) so initial vectors have ~unit norm.
embeddings = np.random.randn(num_nodes, embedding_size) / np.sqrt(embedding_size)
embeddings = torch.tensor(embeddings, dtype=torch.float32, requires_grad=True, device=device)
# LightGCN itself is parameter-free propagation; edge dropout is its only stochastic part.
model = LightGCN(k=k, edge_drop_rate=edge_drop_rate).to(device)
def forward():
    """Propagate the embeddings through LightGCN and split the result.

    Returns:
        A ``(user_h, item_h)`` pair: the first ``num_users`` rows of the
        propagated node embeddings, followed by the item rows. Relies on the
        homogeneous graph stacking user nodes before item nodes.
    """
    node_h = model(g, embeddings)
    return node_h[:num_users], node_h[num_users:]
def evaluate():
    """Evaluate ranking quality on the held-out test interactions.

    Switches the model to eval mode (disabling edge dropout) and computes
    mean precision/recall/NDCG at k=10 and k=20, ranking items for each test
    user while filtering out their training interactions.

    Returns:
        Dict mapping metric names to their mean values over test users.
    """
    model.eval()
    # Evaluation never backpropagates, so skip autograd bookkeeping entirely
    # (saves memory and time versus the detach-after-the-fact approach).
    with torch.no_grad():
        user_h, item_h = forward()
    user_h = user_h.cpu().numpy()
    item_h = item_h.cpu().numpy()
    mean_results_dict = evaluate_mean_global_metrics(
        test_user_items_dict, train_user_items_dict,
        user_h, item_h, k_list=[10, 20], metrics=["precision", "recall", "ndcg"])
    return mean_results_dict
def update_learning_rate(optimizer, lr_decay):
    """Multiply the learning rate of every parameter group by ``lr_decay``.

    Args:
        optimizer: A ``torch.optim`` optimizer whose ``param_groups`` hold
            the current ``'lr'`` values.
        lr_decay: Multiplicative decay factor (e.g. 0.995).
    """
    for group in optimizer.param_groups:
        group['lr'] *= lr_decay
# Mini-batches of (user, item) training edges, reshuffled every epoch.
train_edges_data_loader = create_tensor_dataloader(
    torch.tensor(train_user_item_edges), batch_size=batch_size, shuffle=True)
# Only the free embedding matrix is optimized; LightGCN propagation has no weights.
optimizer = torch.optim.Adam([embeddings], lr=lr)

for epoch in range(num_epochs):
    # Periodic ranking evaluation on the held-out test set.
    if epoch % 20 == 0:
        print("\nEvaluation before epoch {} ......".format(epoch))
        mean_results_dict = evaluate()
        print(mean_results_dict)

    epoch_start = time.time()
    for batch_edges in train_edges_data_loader:
        # Re-enable edge dropout in case evaluate() left the model in eval mode.
        model.train()
        user_h, item_h = forward()
        # using BPR as ranking loss
        mf_losses = compute_bpr_loss(user_h, item_h, batch_edges, reduction="none")
        # LightGCN applies L2 regularization on input embeddings
        l2_loss = compute_l2_loss([embeddings])
        loss = mf_losses.sum() + l2_loss * l2_coef

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Exponential learning-rate decay, applied once per epoch.
    update_learning_rate(optimizer, lr_decay)
    epoch_end = time.time()

    # NOTE: the loss values reported here come from the final batch of the epoch.
    print("epoch = {}\tloss = {:.4f}\tmf_loss = {:.4f}\tl2_loss = {:.4f}\tupdated_lr = {:.4f}\tepoch_time = {:.4f}s"
          .format(epoch, loss.item(), mf_losses.mean().item(), l2_loss.item(),
                  optimizer.param_groups[0]['lr'], epoch_end - epoch_start))