#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 5
"""
import torch
import torch.nn as nn


class CharDecoder(nn.Module):
    def __init__(self, hidden_size, char_embedding_size=50, target_vocab=None):
        """ Init Character Decoder.

        @param hidden_size (int): Hidden size of the decoder LSTM
        @param char_embedding_size (int): dimensionality of character embeddings
        @param target_vocab (VocabEntry): vocabulary for the target language. See vocab.py for documentation.
        """
        ### YOUR CODE HERE for part 2a
        ### TODO - Initialize as an nn.Module.
        ###      - Initialize the following variables:
        ###        self.charDecoder: LSTM. Please use nn.LSTM() to construct this.
        ###        self.char_output_projection: Linear layer, called W_{dec} and b_{dec} in the PDF
        ###        self.decoderCharEmb: Embedding matrix of character embeddings
        ###        self.target_vocab: vocabulary for the target language
        ###
        ### Hint: - Use target_vocab.char2id to access the character vocabulary for the target language.
        ###       - Set the padding_idx argument of the embedding matrix.
        ###       - Create a new Embedding layer. Do not reuse embeddings created in Part 1 of this assignment.
        super(CharDecoder, self).__init__()
        self.target_vocab = target_vocab
        char_vocab_size = len(self.target_vocab.char2id)
        padding_idx = self.target_vocab.char2id['<pad>']
        self.decoderCharEmb = nn.Embedding(char_vocab_size, char_embedding_size, padding_idx=padding_idx)
        self.charDecoder = nn.LSTM(input_size=char_embedding_size, hidden_size=hidden_size)
        self.char_output_projection = nn.Linear(hidden_size, char_vocab_size)
        self.loss = nn.CrossEntropyLoss(ignore_index=padding_idx, reduction="sum")

        ### END YOUR CODE
    def forward(self, input, dec_hidden=None):
        """ Forward pass of character decoder.
        @param input: tensor of integers, shape (length, batch)
        @param dec_hidden: internal state of the LSTM before reading the input characters. A tuple of two tensors of shape (1, batch, hidden_size)

        @returns scores: called s_t in the PDF, shape (length, batch, self.vocab_size)
        @returns dec_hidden: internal state of the LSTM after reading the input characters. A tuple of two tensors of shape (1, batch, hidden_size)
        """
        ### YOUR CODE HERE for part 2b
        ### TODO - Implement the forward pass of the character decoder.
        input_embed = self.decoderCharEmb(input)                      # (length, batch, char_embedding_size)
        h_t, dec_hidden = self.charDecoder(input_embed, dec_hidden)   # h_t: (length, batch, hidden_size)
        # nn.Linear operates on the last dimension, so no permuting is needed.
        scores = self.char_output_projection(h_t)                     # (length, batch, vocab_size)
        return scores, dec_hidden

        ### END YOUR CODE
    def train_forward(self, char_sequence, dec_hidden=None):
        """ Forward computation during training.
        @param char_sequence: tensor of integers, shape (length, batch). Note that "length" here and in forward() need not be the same.
        @param dec_hidden: initial internal state of the LSTM, obtained from the output of the word-level decoder. A tuple of two tensors of shape (1, batch, hidden_size)

        @returns The cross-entropy loss, computed as the *sum* of cross-entropy losses of all the words in the batch.
        """
        ### YOUR CODE HERE for part 2c
        ### TODO - Implement training forward pass.
        ###
        ### Hint: - Make sure padding characters do not contribute to the cross-entropy loss.
        ###       - char_sequence corresponds to the sequence x_1 ... x_{n+1} from the handout (e.g., <START>,m,u,s,i,c,<END>).
        input_seq = char_sequence[:-1]    # x_1 ... x_n   (drop the final character)
        target_seq = char_sequence[1:]    # x_2 ... x_{n+1} (drop the initial <START> character)
        scores, _ = self.forward(input_seq, dec_hidden)   # (length - 1, batch, vocab_size)
        # CrossEntropyLoss expects scores of shape (batch, vocab_size, length - 1) and targets of shape
        # (batch, length - 1); padding positions are excluded via ignore_index set in __init__.
        return self.loss(scores.permute(1, 2, 0), target_seq.permute(1, 0))

        ### END YOUR CODE
    def decode_greedy(self, initialStates, device, max_length=21):
        """ Greedy decoding
        @param initialStates: initial internal state of the LSTM, a tuple of two tensors of size (1, batch, hidden_size)
        @param device: torch.device (indicates whether the model is on CPU or GPU)
        @param max_length: maximum length of words to decode

        @returns decodedWords: a list (of length batch) of strings, each of which has length <= max_length.
                              The decoded strings should NOT contain the start-of-word and end-of-word characters.
        """
        ### YOUR CODE HERE for part 2d
        ### TODO - Implement greedy decoding.
        ### Hints:
        ###      - Use target_vocab.char2id and target_vocab.id2char to convert between integers and characters
        ###      - Use torch.tensor(..., device=device) to turn a list of character indices into a tensor.
        ###      - We use curly brackets as start-of-word and end-of-word characters. That is, use the character '{' for <START> and '}' for <END>.
        ###        Their indices are self.target_vocab.start_of_word and self.target_vocab.end_of_word, respectively.
        batch_size = initialStates[0].shape[1]
        dec_hidden = initialStates

        # Start every word in the batch with the start-of-word character.
        current_chars = torch.tensor([[self.target_vocab.start_of_word] * batch_size],
                                     dtype=torch.long, device=device)       # (1, batch)

        # Greedily pick the most likely next character at each time step.
        decoded_ids = []                                                     # list of (batch,) index tensors
        for _ in range(max_length):
            scores, dec_hidden = self.forward(current_chars, dec_hidden)    # scores: (1, batch, vocab_size)
            current_chars = torch.argmax(scores, dim=2)                     # (1, batch)
            decoded_ids.append(current_chars.squeeze(dim=0))
        decoded_ids = torch.stack(decoded_ids, dim=1).tolist()              # (batch, max_length) lists of ints

        # Convert character indices to strings, truncating at the first end-of-word character.
        decodedWords = []
        for char_ids in decoded_ids:
            word = ""
            for idx in char_ids:
                if idx == self.target_vocab.end_of_word:
                    break
                word += self.target_vocab.id2char[idx]
            decodedWords.append(word)
        return decodedWords

        ### END YOUR CODE
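

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the assignment scaffold). The
# _DummyVocab below is a hypothetical stand-in for vocab.VocabEntry and only
# mimics the attributes CharDecoder actually touches: char2id, id2char,
# start_of_word, and end_of_word. Shapes follow the docstrings above:
# character tensors are (length, batch) and LSTM states are (1, batch, hidden).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _DummyVocab:
        def __init__(self):
            chars = ['<pad>', '{', '}'] + list('abcdefghijklmnopqrstuvwxyz')
            self.char2id = {c: i for i, c in enumerate(chars)}
            self.id2char = {i: c for c, i in self.char2id.items()}
            self.start_of_word = self.char2id['{']
            self.end_of_word = self.char2id['}']

    device = torch.device('cpu')
    vocab = _DummyVocab()
    decoder = CharDecoder(hidden_size=8, char_embedding_size=4, target_vocab=vocab)

    # Fake batch of 3 words, 6 characters each (random non-pad indices, for shape checking only).
    length, batch = 6, 3
    char_sequence = torch.randint(1, len(vocab.char2id), (length, batch), dtype=torch.long)

    scores, dec_hidden = decoder.forward(char_sequence)
    print('scores:', tuple(scores.shape))        # (6, 3, vocab_size)
    print('h_n:', tuple(dec_hidden[0].shape))    # (1, 3, 8)

    loss = decoder.train_forward(char_sequence)
    print('summed char-level loss:', float(loss))

    init = (torch.zeros(1, batch, 8), torch.zeros(1, batch, 8))
    print('greedy decode:', decoder.decode_greedy(init, device, max_length=5))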