-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_utils.py
98 lines (83 loc) · 2.98 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch, dgl
import numpy as np
class MyDataset(torch.utils.data.Dataset):
    """Pairwise (user, item) dataset built from interaction matrices.

    Training split: every nonzero entry of ``dataset['train']`` becomes a
    positive (user, item) row.  Call :meth:`neg_sample` (e.g. once per epoch)
    before indexing; items are then ``(user, pos_item, neg_item)`` triples.

    Evaluation splits ('val' / 'test'): every nonzero entry of
    ``dataset[setType]`` is emitted as a (user, pos_item) row, immediately
    followed by one (user, neg) row for each ``neg`` in ``neg_data[user]``.
    Items are ``(user, item)`` pairs.

    Args:
        dataset: mapping with 'userCount', 'itemCount', 'categoryCount' and
            the split matrices.  Each split must support ``.nonzero()``
            returning (row_ids, col_ids) arrays — presumably scipy sparse
            matrices of shape (n_user, n_item); confirm against the loader.
        setType: one of 'train', 'val', 'test'.
        neg_data: required for 'val'/'test'; indexable by user id and giving
            that user's pre-drawn negative item ids.  Ignored for 'train'.

    Raises:
        ValueError: on an unknown ``setType``, or when ``neg_data`` is missing
            for an evaluation split.
    """

    def __init__(self, dataset, setType='train', neg_data=None):
        super().__init__()
        if setType not in ['train', 'val', 'test']:
            raise ValueError('Invalid setType {}'.format(setType))
        self.n_user = dataset['userCount']
        self.n_item = dataset['itemCount']
        self.n_category = dataset['categoryCount']
        self.training = setType == 'train'
        # Training negatives; populated by neg_sample(), unused for eval splits.
        self.neg_data = None
        if self.training:
            uids, iids = dataset['train'].nonzero()
            self.data = np.stack((uids, iids), axis=1).astype(np.int64)
            # Observed (user, item) pairs, for O(1) rejection checks in neg_sample().
            self.actSet = set(zip(uids, iids))
        else:
            if neg_data is None:
                raise ValueError(
                    'neg_data is required for setType {}'.format(setType))
            uids, iids = dataset[setType].nonzero()
            rows = []
            for uid, pos in zip(uids, iids):
                rows.append((uid, pos))
                rows.extend((uid, neg) for neg in neg_data[uid])
            # reshape keeps a (0, 2) shape when the split has no entries.
            self.data = np.array(rows, dtype=np.int64).reshape(-1, 2)

    def neg_sample(self):
        """Draw one negative item for every training row.

        Each negative is drawn uniformly from ``[0, n_item)`` with the global
        ``np.random`` state and redrawn until ``(user, item)`` is not an
        observed pair.  The result is stored in ``self.neg_data``.

        Raises:
            RuntimeError: if called on an evaluation split, or if some user
                has interacted with every item (rejection sampling could
                never terminate for that user).
        """
        if not self.training:
            raise RuntimeError('neg_sample() is only valid for the train split')
        if self.actSet:
            users = np.fromiter((u for u, _ in self.actSet), dtype=np.int64,
                                count=len(self.actSet))
            if np.bincount(users).max() >= self.n_item:
                raise RuntimeError(
                    'a user has interacted with every item; '
                    'no negative item can be sampled')
        self.neg_data = np.random.randint(low=0, high=self.n_item,
                                          size=len(self.data), dtype=np.int64)
        for i, uid in enumerate(self.data[:, 0]):
            iid = self.neg_data[i]
            while (uid, iid) in self.actSet:
                iid = np.random.randint(low=0, high=self.n_item)
            self.neg_data[i] = iid

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(user, item)`` for eval splits, ``(user, pos, neg)`` for train."""
        user, item = self.data[idx]
        if not self.training:
            return user, item
        if self.neg_data is None:
            raise RuntimeError(
                'call neg_sample() before indexing the training split')
        return user, item, self.neg_data[idx]
def prepare_dgl_graph(args, dataset):
    """Build one homogeneous DGL graph over users, items and categories.

    Node ids are laid out in contiguous ranges, indexed from 0:
        users       [base_u, base_u + n_user)          base_u = 0
        items       [base_i, base_i + n_item)          base_i = n_user
        categories  [base_c, base_c + n_category)      base_c = n_user + n_item

    Every relation is a directed edge; ``graph.edata['type']`` holds its id:
        0  u -> u   from dataset['trust']
        1  i -> u   from dataset['train']
        2  u -> i   from dataset['train']
        3  c -> i   from dataset['category']
        4  i -> c   from dataset['category']
    Edges appear in that order (all type-0 edges first, then type 1, ...).

    Args:
        args: unused; kept for call-site compatibility.
        dataset: mapping with 'userCount', 'itemCount', 'categoryCount' and
            the 'trust' (user x user), 'train' (user x item) and 'category'
            (item x category) matrices; each must support ``.nonzero()``
            (presumably scipy sparse — confirm against the loader).

    Returns:
        A ``dgl.graph`` with ``n_user + n_item + n_category`` nodes and an
        int64 ``edata['type']`` tensor with one entry per edge.
    """
    bu = 0
    bi = bu + dataset['userCount']
    bc = bi + dataset['itemCount']
    num_nodes = bc + dataset['categoryCount']

    src_parts, dst_parts, type_parts = [], [], []

    def nonzero64(key):
        # Promote to int64 before adding node offsets so ids cannot overflow
        # a narrower index dtype.
        rows, cols = dataset[key].nonzero()
        return np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)

    def add_edges(src, dst, etype):
        # Size the type column from the edges actually added: ``.nnz`` on a
        # scipy sparse matrix also counts explicitly stored zeros, which
        # ``.nonzero()`` skips, so using nnz could misalign types and edges.
        src_parts.append(src)
        dst_parts.append(dst)
        type_parts.append(np.full(len(src), etype, dtype=np.int64))

    # social network
    uids, fids = nonzero64('trust')
    add_edges(bu + uids, bu + fids, 0)
    # user-item interactions, in both directions
    uids, iids = nonzero64('train')
    add_edges(bi + iids, bu + uids, 1)
    add_edges(bu + uids, bi + iids, 2)
    # item-category relations, in both directions
    iids, cids = nonzero64('category')
    add_edges(bc + cids, bi + iids, 3)
    add_edges(bi + iids, bc + cids, 4)

    src = torch.from_numpy(np.concatenate(src_parts))
    dst = torch.from_numpy(np.concatenate(dst_parts))
    graph = dgl.graph((src, dst), num_nodes=num_nodes)
    graph.edata['type'] = torch.from_numpy(np.concatenate(type_parts))
    return graph