-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathMIRT.py
39 lines (29 loc) · 1 KB
/
MIRT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# coding: utf-8
# 2021/3/23 @ tongshiwei
import logging
from EduCDM import MIRT
import torch
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
# Load the train / validation / test splits of the a0910 dataset.
# The paths are relative, so they resolve against the current working directory.
train_data, valid_data, test_data = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../../data/a0910/train.csv",
        "../../data/a0910/valid.csv",
        "../../data/a0910/test.csv",
    )
)
# Rows per mini-batch for every DataLoader built from these splits.
batch_size = 256
def transform(x, y, z, batch_size, **params):
    """Wrap three aligned 1-D columns in a DataLoader of (x, y, z) batches.

    In this script ``x``, ``y`` and ``z`` are the ``user_id``, ``item_id``
    and ``score`` columns of a pandas DataFrame. Any 1-D array-like of equal
    length (list, NumPy array, Series) is accepted.

    Args:
        x: Values converted to an ``int64`` tensor.
        y: Values converted to an ``int64`` tensor.
        z: Values converted to a ``float32`` tensor.
        batch_size: Number of rows per batch.
        **params: Extra keyword arguments forwarded unchanged to
            ``torch.utils.data.DataLoader``.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of the three tensors.
    """
    import numpy as np  # local import: numpy is not imported at file level

    # Convert via np.asarray so a pandas Series is read positionally in one
    # vectorized copy. Passed directly, torch.tensor falls back to the generic
    # sequence protocol, which indexes the Series with Series.__getitem__.
    # That lookup is label-based, so it can raise KeyError when the index
    # does not start at 0 (e.g. after filtering rows), and it is per-element.
    dataset = TensorDataset(
        torch.tensor(np.asarray(x), dtype=torch.int64),
        torch.tensor(np.asarray(y), dtype=torch.int64),
        torch.tensor(np.asarray(z), dtype=torch.float32),
    )
    return DataLoader(dataset, batch_size=batch_size, **params)
# One DataLoader per split, in the same order the CSVs were loaded.
train, valid, test = (
    transform(frame["user_id"], frame["item_id"], frame["score"], batch_size)
    for frame in (train_data, valid_data, test_data)
)

logging.getLogger().setLevel(logging.INFO)

# NOTE(review): the positional arguments presumably are (number of users,
# number of items, latent dimension) for a0910 — confirm against EduCDM's
# MIRT constructor.
cdm = MIRT(4164, 17747, 123)
cdm.train(train, valid, epoch=2)

# Round-trip the learned parameters through disk, then evaluate on the test split.
params_path = "mirt.params"
cdm.save(params_path)
cdm.load(params_path)

auc, accuracy = cdm.eval(test)
print("auc: %.6f, accuracy: %.6f" % (auc, accuracy))