word2vec.py
import itertools

import numpy as np
import torch
import torch.nn as nn
import tqdm


def get_skips_and_contexts(tokens_list: list[list[int]], context_size: int):
    """Collect (center word, context words) pairs from sequences of token ids.

    context_size must be even: half the window falls on each side of the center word.
    """
    assert context_size > 0 and context_size % 2 == 0
    context_size_half = context_size // 2
    contexts, skips = [], []
    for tokens in tqdm.tqdm(tokens_list, desc="Building word2vec dataset"):
        for i, token in enumerate(tokens):
            start = i - context_size_half
            end = i + context_size_half
            # Only keep positions with a full half-window on both sides.
            if start >= 0 and end < len(tokens):
                context = (
                    tokens[start : start + context_size_half]
                    + tokens[i + 1 : i + 1 + context_size_half]
                )
                contexts.append(context)
                skips.append(token)
    return skips, contexts
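
# Worked example: with context_size = 2 and a single token-id sequence [5, 6, 7, 8],
# get_skips_and_contexts returns
#   skips    = [6, 7]
#   contexts = [[5, 7], [6, 8]]
# i.e. one (center word, context words) pair per position with a full half-window
# on both sides.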


def build_cbow_dataset(
    tokens_list: list[list[int]],
    context_size: int,
    batch_size: int,
    device: str = "cpu",
) -> torch.utils.data.DataLoader:
    # CBOW: the context words are the inputs, the center word is the target.
    ys, xs = get_skips_and_contexts(tokens_list, context_size)
    tensor_dataset = torch.utils.data.TensorDataset(
        torch.tensor(xs, dtype=torch.int, device=device),
        torch.tensor(ys, dtype=torch.long, device=device),
    )
    data_loader = torch.utils.data.DataLoader(
        tensor_dataset, batch_size=batch_size, shuffle=True
    )
    return data_loader


def build_skip_gram_dataset(
    tokens_list: list[list[int]],
    context_size: int,
    batch_size: int,
    device: str = "cpu",
) -> torch.utils.data.DataLoader:
    # Skip-gram: the center word is the input and each context word becomes its
    # own target, so every window yields context_size training pairs.
    skips, contexts = get_skips_and_contexts(tokens_list, context_size)
    xs = [[skip] for skip in skips for _ in range(context_size)]
    ys = list(itertools.chain(*contexts))
    tensor_dataset = torch.utils.data.TensorDataset(
        torch.tensor(xs, dtype=torch.int, device=device),
        torch.tensor(ys, dtype=torch.long, device=device),
    )
    data_loader = torch.utils.data.DataLoader(
        tensor_dataset, batch_size=batch_size, shuffle=True
    )
    return data_loader
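
# Note on shapes: the CBOW loader yields inputs of shape (batch, context_size)
# with one center-word target per row, while the skip-gram loader flattens each
# window into context_size separate (center word, context word) pairs, so its
# inputs have shape (batch, 1).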


class CBOWModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        # x: (batch, context_size) token ids -> summed context embeddings -> vocab logits.
        y = self.emb(x)
        y = y.sum(dim=1)
        y = self.linear(y)
        return y


class SkipGramModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        # x: (batch, 1) center-word ids -> embedding -> vocab logits for the context word.
        y = self.emb(x)
        y = self.linear(y)
        return y


class Word2VecTrainer:
    def __init__(
        self,
        train_data_loader: torch.utils.data.DataLoader,
        val_data_loader: torch.utils.data.DataLoader,
        model_name: str,
        vocab_size: int,
        embedding_dim: int,
        context_size: int,
        device: str = "cpu",
    ):
        self.train_data_loader, self.val_data_loader = (
            train_data_loader,
            val_data_loader,
        )
        self.device = device
        self.valid_model_names = ["CBOW", "skip-gram"]
        self.model_name = model_name
        if model_name not in self.valid_model_names:
            raise ValueError(
                f"model_name must be one of {self.valid_model_names}, but got model_name='{model_name}'"
            )
        if model_name == "CBOW":
            self.model = CBOWModel(vocab_size, embedding_dim, context_size).to(device)
        elif model_name == "skip-gram":
            self.model = SkipGramModel(vocab_size, embedding_dim).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, max_epochs):
        for epoch in range(max_epochs):
            self.model.train()
            loss_epoch = []
            for inputs, labels in tqdm.tqdm(
                self.train_data_loader, desc=f"Training {self.model_name} epoch #{epoch}"
            ):
                xs = inputs.to(self.device)
                ys = labels.to(self.device)
                out = self.model(xs)
                # squeeze(1) collapses the singleton dimension of the skip-gram
                # model's (batch, 1, vocab_size) output; it is a no-op for CBOW.
                loss = self.criterion(out.squeeze(1), ys)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss_epoch.append(loss.item())
            loss_epoch_mean = np.mean(loss_epoch)
            print(f"epoch = {epoch:02d}; loss = {loss_epoch_mean:.10f}")
        print("Finished")