import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNModel(nn.Module):
    """GRU-based classifier: embedding -> optional self-attention -> GRU -> MLP head."""

    def __init__(
        self,
        embedding_model: nn.Module,
        embedding_dim: int,
        rnn_hidden_size: int,
        rnn_num_layers: int,
        rnn_dropout: float,
        linear_sizes: list[int],
        attention_heads: int | None = None,
    ):
        super().__init__()
        self.emb = embedding_model
        # Note: nn.GRU applies dropout only between stacked layers, so it is
        # ignored (with a warning) when rnn_num_layers == 1.
        self.rnn = nn.GRU(
            embedding_dim,
            rnn_hidden_size,
            rnn_num_layers,
            batch_first=True,
            dropout=rnn_dropout,
        )
        # Chain the layer widths: rnn_hidden_size -> linear_sizes... -> 1,
        # ending in a single output unit for binary classification.
        self.linears = nn.ModuleList(
            [
                nn.Linear(in_size, out_size)
                for in_size, out_size in zip(
                    [rnn_hidden_size] + linear_sizes, linear_sizes + [1]
                )
            ]
        )
        # Optional self-attention over the embedded sequence before the GRU;
        # embedding_dim must be divisible by attention_heads.
        self.attention = (
            nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
            if attention_heads
            else None
        )

    def forward(self, x):
        y = self.emb(x)  # (batch, seq_len, embedding_dim)
        y = torch.tanh(y)  # torch.tanh: F.tanh is deprecated
        if self.attention is not None:
            # Self-attention with the embeddings as queries, keys, and values.
            y, _ = self.attention(y, y, y, need_weights=False)
        y, _ = self.rnn(y)
        y = y[:, -1, :]  # last timestep's output: (batch, rnn_hidden_size)
        for linear in self.linears[:-1]:
            y = F.gelu(linear(y))
        y = self.linears[-1](y)
        y = torch.sigmoid(y)  # probability in (0, 1); torch.sigmoid replaces deprecated F.sigmoid
        return y
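

# A minimal usage sketch, not part of the original file: the embedding model,
# vocabulary size, and all dimensions below are hypothetical, chosen only to
# show the expected input/output shapes.
if __name__ == "__main__":
    # Hypothetical token embedding; any nn.Module mapping (batch, seq_len)
    # token IDs to (batch, seq_len, embedding_dim) vectors would work here.
    embedding = nn.Embedding(num_embeddings=10_000, embedding_dim=64)
    model = RNNModel(
        embedding_model=embedding,
        embedding_dim=64,
        rnn_hidden_size=128,
        rnn_num_layers=2,
        rnn_dropout=0.2,
        linear_sizes=[64, 32],
        attention_heads=4,  # 64 is divisible by 4, as MultiheadAttention requires
    )
    tokens = torch.randint(0, 10_000, (8, 20))  # batch of 8 sequences, length 20
    probs = model(tokens)
    print(probs.shape)  # torch.Size([8, 1])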