"""1D ResNet for unidimensional signals, as defined in "Deep neural
network-estimated electrocardiographic age as a mortality predictor".
"""
# Imports (Adam, DataLoader, and TensorDataset are used in the training
# sketch at the bottom of the file)
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
def _padding(downsample, kernel_size):
"""Compute required padding"""
padding = max(0, int(np.floor((kernel_size - downsample + 1) / 2)))
return padding
def _downsample(n_samples_in, n_samples_out):
"""Compute downsample rate"""
downsample = int(n_samples_in // n_samples_out)
if downsample < 1:
raise ValueError("Number of samples should always decrease")
if n_samples_in % n_samples_out != 0:
raise ValueError("Number of samples for two consecutive blocks "
"should always decrease by an integer factor.")
return downsample
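
# A quick numeric illustration (values chosen to match the model built at the
# bottom of this file): _downsample(1120, 560) returns 2 and _padding(2, 17)
# returns 8, so a Conv1d with kernel_size=17, stride=2, padding=8 maps a
# length-1120 signal to floor((1120 + 2*8 - 17) / 2) + 1 = 560 samples.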
class ResBlock1d(nn.Module):
"""Residual network unit for unidimensional signals."""
def __init__(self, n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate):
if kernel_size % 2 == 0:
raise ValueError("The current implementation only support odd values for `kernel_size`.")
        super().__init__()
# Forward path
padding = _padding(1, kernel_size)
self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, padding=padding, bias=False)
self.bn1 = nn.BatchNorm1d(n_filters_out)
self.relu = nn.ReLU()
self.dropout1 = nn.Dropout(dropout_rate)
padding = _padding(downsample, kernel_size)
self.conv2 = nn.Conv1d(n_filters_out, n_filters_out, kernel_size,
stride=downsample, padding=padding, bias=False)
self.bn2 = nn.BatchNorm1d(n_filters_out)
self.dropout2 = nn.Dropout(dropout_rate)
# Skip connection
skip_connection_layers = []
# Deal with downsampling
if downsample > 1:
maxpool = nn.MaxPool1d(downsample, stride=downsample)
skip_connection_layers += [maxpool]
# Deal with n_filters dimension increase
if n_filters_in != n_filters_out:
conv1x1 = nn.Conv1d(n_filters_in, n_filters_out, 1, bias=False)
skip_connection_layers += [conv1x1]
        # Build the skip connection layer
if skip_connection_layers:
self.skip_connection = nn.Sequential(*skip_connection_layers)
else:
self.skip_connection = None
def forward(self, x, y):
"""Residual unit."""
        if self.skip_connection is not None:
            y = self.skip_connection(y)
# 1st layer
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.dropout1(x)
# 2nd layer
x = self.conv2(x)
x += y # Sum skip connection and main connection
y = x
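        # The un-normalized sum is carried forward as the next block's skip
        # input, in the spirit of the identity mappings of He et al. [1]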
x = self.bn2(x)
x = self.relu(x)
x = self.dropout2(x)
return x, y
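
# A minimal shape check for a single block (the dimensions below are
# illustrative, not taken from the model defined later): downsampling by 2
# while growing from 64 to 128 filters should halve the sample axis.
_blk = ResBlock1d(n_filters_in=64, n_filters_out=128, downsample=2, kernel_size=17, dropout_rate=0.8)
_x = torch.randn(4, 64, 1120)
_out, _skip = _blk(_x, _x)
assert _out.shape == _skip.shape == (4, 128, 560)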
class ResNet1d(nn.Module):
"""Residual network for unidimensional signals.
Parameters
----------
input_dim : tuple
Input dimensions. Tuple containing dimensions for the neural network
input tensor. Should be like: ``(n_filters, n_samples)``.
    blocks_dim : list of tuples
        Dimensions of the residual blocks. The i-th tuple gives the output
        dimensions of the i-th residual block (and hence the input dimensions
        of the following block). Each tuple should be like:
        ``(n_filters, n_samples)``. ``n_samples`` for two consecutive blocks
        should always decrease by an integer factor.
    n_classes : int
        Number of classes (output dimension of the final linear layer).
    dropout_rate : float [0, 1), optional
        Dropout rate used in all Dropout layers. Default is 0.8.
    kernel_size : int, optional
        Kernel size for convolutional layers. The current implementation
        only supports odd kernel sizes. Default is 17.
References
----------
.. [1] K. He, X. Zhang, S. Ren, and J. Sun, "Identity Mappings in Deep Residual Networks,"
arXiv:1603.05027, Mar. 2016. https://arxiv.org/pdf/1603.05027.pdf.
.. [2] K. He, X. Zhang, S. Ren, and J. Sun, "Deep Residual Learning for Image Recognition," in 2016 IEEE Conference
on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778. https://arxiv.org/pdf/1512.03385.pdf
"""
def __init__(self, input_dim, blocks_dim, n_classes, kernel_size=17, dropout_rate=0.8):
        super().__init__()
# First layers
n_filters_in, n_filters_out = input_dim[0], blocks_dim[0][0]
n_samples_in, n_samples_out = input_dim[1], blocks_dim[0][1]
downsample = _downsample(n_samples_in, n_samples_out)
padding = _padding(downsample, kernel_size)
self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, bias=False,
stride=downsample, padding=padding)
self.bn1 = nn.BatchNorm1d(n_filters_out)
# Residual block layers
self.res_blocks = []
for i, (n_filters, n_samples) in enumerate(blocks_dim):
n_filters_in, n_filters_out = n_filters_out, n_filters
n_samples_in, n_samples_out = n_samples_out, n_samples
downsample = _downsample(n_samples_in, n_samples_out)
resblk1d = ResBlock1d(n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate)
self.add_module('resblock1d_{0}'.format(i), resblk1d)
self.res_blocks += [resblk1d]
# Linear layer
n_filters_last, n_samples_last = blocks_dim[-1]
last_layer_dim = n_filters_last * n_samples_last
self.lin = nn.Linear(last_layer_dim, n_classes)
self.n_blk = len(blocks_dim)
def forward(self, x):
"""Implement ResNet1d forward propagation"""
# First layers
x = self.conv1(x)
x = self.bn1(x)
# Residual blocks
y = x
for blk in self.res_blocks:
x, y = blk(x, y)
        # Flatten the feature maps
        x = x.view(x.size(0), -1)
        # Fully connected layer
        x = self.lin(x)
return x
res_model = ResNet1d(input_dim=(12, 1120),
blocks_dim=list(zip([64, 128, 196, 256, 320], [1120, 560, 280, 140, 70])),
n_classes=1,
kernel_size=17,
dropout_rate=0.8)
# Test the network on a random batch: 20 examples, 12 channels, 1120 samples
random_data = torch.randn(20, 12, 1120)
print(res_model)
output = res_model(random_data)
print(output.shape)  # expected: torch.Size([20, 1])
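
# A minimal training-loop sketch on synthetic data. The targets, batch size,
# learning rate, loss, and epoch count below are illustrative assumptions, not
# part of the original model definition; the sketch just shows how the Adam,
# DataLoader, and TensorDataset imports above are typically wired up.
synthetic_targets = torch.randn(20, 1)                 # fake regression targets
loader = DataLoader(TensorDataset(random_data, synthetic_targets),
                    batch_size=4, shuffle=True)
optimizer = Adam(res_model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
res_model.train()
for epoch in range(2):                                 # tiny run, just to exercise the graph
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(res_model(xb), yb)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")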