add tdnn_lstm model
YiwenShaoStephen committed Aug 20, 2020
1 parent 20c4fb2 commit 8684887
Showing 1 changed file with 100 additions and 0 deletions.

models/tdnn_lstm.py (+100, -0)
@@ -0,0 +1,100 @@
# Copyright (c) Yiwen Shao

# Apache 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F


class tdnn_bn_relu(nn.Module):
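    """A 1-D convolution (TDNN layer) followed by batch norm and ReLU.

    Padding is chosen so that, with stride 1 and an odd kernel size,
    the number of output frames equals the number of input frames.
    """
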
def __init__(self, in_dim, out_dim, kernel_size, stride=1, dilation=1):
super(tdnn_bn_relu, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = dilation * (kernel_size - 1) // 2
self.dilation = dilation
self.tdnn = nn.Conv1d(in_dim, out_dim, kernel_size,
stride=stride, padding=self.padding, dilation=dilation)
self.bn = nn.BatchNorm1d(out_dim)
self.relu = nn.ReLU(inplace=True)

def output_lengths(self, in_lengths):
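        # This is PyTorch's Conv1d output-length formula,
        # floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1,
        # rewritten in integer arithmetic so it can be applied elementwise
        # to a LongTensor of lengths.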
out_lengths = (
in_lengths + 2 * self.padding - self.dilation * (self.kernel_size - 1) +
self.stride - 1
) // self.stride
return out_lengths

def forward(self, x, x_lengths):
        assert len(x.size()) == 3  # x is of size (B, D, T): batch, feature dim, time
x = self.tdnn(x)
x = self.bn(x)
x = self.relu(x)
x_lengths = self.output_lengths(x_lengths)
return x, x_lengths
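
# Worked example of the length bookkeeping in output_lengths(): with
# kernel_size=3, dilation=1 and stride=2, padding is 1, so 30 input frames
# map to (30 + 2*1 - 1*2 + 2 - 1) // 2 = 15 output frames; forward() applies
# the same formula elementwise to x_lengths.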


class TDNNLSTM(nn.Module):
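    """A stack of TDNN layers followed by LSTM layers and a linear projection.

    num_layers is the total layer count: len(hidden_dims) of them are TDNN
    layers and the remaining num_layers - len(hidden_dims) are LSTM layers.
    kernel_sizes, strides and dilations must each have one entry per TDNN
    layer.
    """
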
def __init__(self, in_dim, out_dim, num_layers, hidden_dims, kernel_sizes, strides, dilations,
bidirectional=False, dropout=0):
super(TDNNLSTM, self).__init__()
self.num_tdnn_layers = len(hidden_dims)
self.num_lstm_layers = num_layers - len(hidden_dims)
assert len(kernel_sizes) == self.num_tdnn_layers
assert len(strides) == self.num_tdnn_layers
assert len(dilations) == self.num_tdnn_layers
self.in_dim = in_dim
self.out_dim = out_dim
self.dropout = dropout
# set lstm hidden_dim to the num_channels of the last cnn layer
self.lstm_dim = hidden_dims[-1]
self.tdnn = nn.ModuleList([
tdnn_bn_relu(
in_dim if layer == 0 else hidden_dims[layer - 1],
hidden_dims[layer], kernel_sizes[layer],
strides[layer], dilations[layer],
)
for layer in range(self.num_tdnn_layers)
])
self.num_directions = 2 if bidirectional else 1
self.lstm = nn.LSTM(self.lstm_dim, self.lstm_dim, self.num_lstm_layers,
batch_first=True, bidirectional=bidirectional,
dropout=dropout)
self.final_layer = nn.Linear(
self.lstm_dim * self.num_directions, out_dim)

def forward(self, x, x_lengths):
assert len(x.size()) == 3 # x is of size (B, T, D)
# turn x to (B, D, T) for tdnn/cnn input
x = x.transpose(1, 2).contiguous()
        for tdnn_layer in self.tdnn:
            # apply TDNN (conv + bn + relu) and update the frame lengths
            x, x_lengths = tdnn_layer(x, x_lengths)
            x = F.dropout(x, p=self.dropout, training=self.training)
x = x.transpose(2, 1).contiguous() # turn it back to (B, T, D)
bsz = x.size(0)
        state_size = (self.num_directions * self.num_lstm_layers,
                      bsz, self.lstm_dim)
h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size)
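        # Pack so the LSTM skips padded frames; with the default
        # enforce_sorted=True, pack_padded_sequence expects the batch to be
        # sorted by decreasing length.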
        x = torch.nn.utils.rnn.pack_padded_sequence(
            x, x_lengths, batch_first=True)  # (B, T, D) -> PackedSequence
x, _ = self.lstm(x, (h0, c0))
x, _ = torch.nn.utils.rnn.pad_packed_sequence(
x, batch_first=True) # (B, T, D)
x = self.final_layer(x)
return x, x_lengths


if __name__ == "__main__":
    num_layers = 2
    hidden_dims = [20]
    kernel_sizes = [3]
    strides = [1]
    dilations = [1]
    in_dim = 10
    out_dim = 5
    net = TDNNLSTM(in_dim, out_dim, num_layers, hidden_dims,
                   kernel_sizes, strides, dilations)
    print(net)
    x = torch.randn(2, 30, 10)  # (B, T, D)
    # lengths must be a LongTensor so output_lengths() can do elementwise math
    x_lengths = torch.tensor([30, 20])
    output, output_lengths = net(x, x_lengths)
    print(output.size())  # torch.Size([2, 30, 5])
    print(output_lengths)  # tensor([30, 20])
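    # A minimal extra check (not part of the original commit): the same model
    # in a bidirectional configuration; these hyperparameters are illustrative
    # assumptions, not values taken from the repository.
    bi_net = TDNNLSTM(in_dim, out_dim, num_layers=4, hidden_dims=[20, 20],
                      kernel_sizes=[3, 3], strides=[1, 1], dilations=[1, 2],
                      bidirectional=True, dropout=0.1)
    bi_output, bi_lengths = bi_net(x, x_lengths)
    print(bi_output.size())  # still (2, 30, 5): final_layer maps 2*lstm_dim -> out_dim
    print(bi_lengths)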
