model.py
# TODO add torchscript jit optimizations
from typing import List, Optional, Tuple

import torch
from torch import nn

from lstm.cell import LayerNormLSTMCell


class LayerNormLSTM(torch.jit.ScriptModule):
    """Multi-layer LSTM built from LayerNormLSTMCell, compiled with TorchScript."""

    def __init__(self, input_size: int, hidden_sizes: List[int]) -> None:
        super(LayerNormLSTM, self).__init__()
        self.input_size = input_size
        self.num_layers = len(hidden_sizes)
        self.hidden_sizes = hidden_sizes
        self.cells = nn.ModuleList()
        for l in range(self.num_layers):
            # The first layer consumes the raw input; every deeper layer consumes
            # the hidden state of the layer below it.
            hidden_size_before = input_size if l == 0 else hidden_sizes[l - 1]
            self.cells.append(LayerNormLSTMCell(hidden_size_before, hidden_sizes[l]))

    # TODO maybe add a version where an initial state can be passed in and the final state is returned
    @torch.jit.script_method
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, input_size = input.size()
        # Zero-initialize the hidden and cell states of every layer.
        h_prev = [torch.zeros(batch_size, hidden_size, device=input.device) for hidden_size in self.hidden_sizes]
        c_prev = [torch.zeros(batch_size, hidden_size, device=input.device) for hidden_size in self.hidden_sizes]
        hidden_states = []
        cell_states = []
        for t in range(seq_len):
            h = input[:, t]
            for l, cell in enumerate(self.cells):
                state = (h_prev[l], c_prev[l])
                h, c = cell(h, state)
                h_prev[l] = h
                c_prev[l] = c
            # Collect the top layer's states for this time step.
            hidden_states += [h]
            cell_states += [c]
        return torch.stack(hidden_states, dim=1), torch.stack(cell_states, dim=1)
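

# A minimal usage sketch (assumed sizes, not part of the original module): because
# LayerNormLSTM is a ScriptModule, the compiled graph can be serialized with
# ScriptModule.save and reloaded via torch.jit.load without the Python class.
#
#     model = LayerNormLSTM(input_size=16, hidden_sizes=[32, 32])
#     h, c = model(torch.randn(8, 50, 16))      # (batch, seq_len, input_size)
#     model.save("layer_norm_lstm.pt")          # hypothetical file name
#     restored = torch.jit.load("layer_norm_lstm.pt")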


class LayerLSTM(nn.Module):
    """Multi-layer LSTM built from nn.LSTMCell, with an optional initial state."""

    def __init__(self, input_size: int, hidden_sizes: List[int]) -> None:
        super(LayerLSTM, self).__init__()
        self.input_size = input_size
        self.num_layers = len(hidden_sizes)
        self.hidden_sizes = hidden_sizes
        self.cells = nn.ModuleList()
        for l in range(self.num_layers):
            hidden_size_before = input_size if l == 0 else hidden_sizes[l - 1]
            self.cells.append(nn.LSTMCell(hidden_size_before, hidden_sizes[l]))

    # TODO test state parameter usage
    def forward(self, input: torch.Tensor,
                state: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, input_size = input.size()
        if state is None:
            # No initial state given: start every layer from zeros.
            h_prev = [torch.zeros(batch_size, hidden_size, device=input.device) for hidden_size in self.hidden_sizes]
            c_prev = [torch.zeros(batch_size, hidden_size, device=input.device) for hidden_size in self.hidden_sizes]
        else:
            h_prev, c_prev = state
        # Pre-allocate the top layer's outputs for all time steps.
        hidden_states = torch.zeros(batch_size, seq_len, self.hidden_sizes[-1], device=input.device)
        cell_states = torch.zeros(batch_size, seq_len, self.hidden_sizes[-1], device=input.device)
        for t in range(seq_len):
            h = input[:, t]
            for l, cell in enumerate(self.cells):
                s = (h_prev[l], c_prev[l])
                h, c = cell(h, s)
                h_prev[l] = h
                c_prev[l] = c
            hidden_states[:, t] = h
            cell_states[:, t] = c
        return hidden_states, cell_states
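

if __name__ == "__main__":
    # Minimal smoke test, a sketch with assumed sizes (not part of the original file):
    # run both variants on a random batch and check the output shapes.
    batch_size, seq_len, input_size = 4, 10, 8
    hidden_sizes = [16, 32]
    x = torch.randn(batch_size, seq_len, input_size)

    plain = LayerLSTM(input_size, hidden_sizes)
    h, c = plain(x)
    assert h.shape == (batch_size, seq_len, hidden_sizes[-1])
    assert c.shape == (batch_size, seq_len, hidden_sizes[-1])

    scripted = LayerNormLSTM(input_size, hidden_sizes)
    h_ln, c_ln = scripted(x)
    assert h_ln.shape == (batch_size, seq_len, hidden_sizes[-1])
    print("forward passes OK:", tuple(h.shape), tuple(h_ln.shape))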