-
Notifications
You must be signed in to change notification settings - Fork 10
/
model.py
39 lines (34 loc) · 1.72 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
import torch.nn.functional as F
from rnn import CustomRNN
from gru import CustomGRUCell
class POSTagger(nn.Module):
    """Sequence tagger: embedding -> recurrent layer -> linear -> log-softmax.

    The recurrent layer is a ``CustomRNN`` wrapped around a cell class that
    is selected by the ``rnn_class`` string.
    """

    def __init__(self, rnn_class, embedding_dim, hidden_dim, vocab_size,
                 target_size, use_gpu=True):
        """Build the embedding, recurrent and output layers.

        Args:
            rnn_class: Cell selector. ``'lstm'``, ``'gru'`` and ``'rnn'`` pick
                the matching ``torch.nn`` cell; any other value falls back to
                ``CustomGRUCell``.
            embedding_dim: Size of each word embedding vector.
            hidden_dim: Size of the recurrent hidden state.
            vocab_size: Number of rows in the embedding table.
            target_size: Number of output tags.
            use_gpu: When true, the embedding layer is moved to CUDA here.
        """
        super(POSTagger, self).__init__()
        self.rnn_class = rnn_class
        self.hidden_dim = hidden_dim
        # Index 1 is the padding token: its embedding receives no gradient.
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=1)
        self.use_gpu = use_gpu
        if use_gpu:
            # Only the embedding is moved here.
            # NOTE(review): the remaining layers are presumably moved by the
            # caller (e.g. ``model.cuda()``) -- confirm.
            self.word_embeddings.cuda()
        self.num_layers = 1

        # The recurrent layer takes word embeddings as inputs and outputs
        # hidden states with dimensionality hidden_dim.
        builtin_cells = {
            'lstm': nn.LSTMCell,
            'gru': nn.GRUCell,
            'rnn': nn.RNNCell,
        }
        # Unrecognised selectors deliberately fall back to the custom GRU cell.
        cell_class = builtin_cells.get(self.rnn_class, CustomGRUCell)
        self.rnn = CustomRNN(cell_class, embedding_dim, hidden_dim, batch_first=False)

        # The linear layer that maps from hidden state space to tag space.
        self.hidden2tag = nn.Linear(hidden_dim, target_size)

    def forward(self, sentences, ranges, lengths):
        """Return per-token log-probabilities over the tag set.

        Args:
            sentences: Integer tensor of word indices fed to the embedding.
            ranges: Passed through unchanged to the recurrent layer.
            lengths: Passed through unchanged to the recurrent layer.
                NOTE(review): presumably per-sequence lengths for padded
                batches -- confirm against ``CustomRNN``.

        Returns:
            Tensor with the same leading dimensions as the recurrent output
            and a last dimension of ``target_size`` holding log-probabilities.
        """
        embeds = self.word_embeddings(sentences)
        rnn_out, _ = self.rnn(embeds, ranges, lengths)
        tag_space = self.hidden2tag(rnn_out)
        # Normalise explicitly over the tag axis. Without ``dim`` the call is
        # deprecated and picks dim 0 for 3-D input, which is the wrong axis
        # (see https://github.com/pytorch/pytorch/issues/1020).
        tag_scores = F.log_softmax(tag_space, dim=-1)
        return tag_scores