-
Notifications
You must be signed in to change notification settings - Fork 9
/
unet.py
executable file
·64 lines (49 loc) · 1.88 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register_model
@register_model("unet")
class UNet(nn.Module):
    """UNet as defined in https://arxiv.org/abs/1805.07709

    Single-channel encoder/decoder:

    * two full-resolution convolutions (conv1, conv2);
    * a stride-2 downsampling convolution (conv3);
    * three 64-channel convolutions at half resolution (conv4-conv6, the last
      two dilated);
    * a transposed convolution that upsamples back to full resolution (conv7);
    * a skip connection that concatenates the conv2 features before the two
      output convolutions (conv8, conv9).

    Args:
        bias: whether conv1-conv8 use additive bias terms. conv9 never has a
            bias.
        residual_connection: if True, ``forward`` returns ``x - net(x)``
            (the network predicts a residual); otherwise it returns
            ``net(x)``.
    """

    def __init__(self, bias, residual_connection=False):
        super().__init__()
        # Full-resolution encoder.
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2, bias=bias)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
        # Downsample by 2. Only exactly inverted by conv7 when H and W are even,
        # which is why forward() pads odd sizes.
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(64, 64, 3, dilation=2, padding=2, bias=bias)
        self.conv6 = nn.Conv2d(64, 64, 3, dilation=4, padding=4, bias=bias)
        # Upsample by 2: (H/2 - 1) * 2 - 2 * 1 + 4 == H for even H.
        self.conv7 = nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1, bias=bias)
        # 96 input channels = 64 upsampled channels + 32 skip channels (conv2).
        self.conv8 = nn.Conv2d(96, 32, 3, padding=1, bias=bias)
        self.conv9 = nn.Conv2d(32, 1, 5, padding=2, bias=False)
        self.residual_connection = residual_connection

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument("--bias", action='store_true', help="use bias terms in the convolution layers")
        parser.add_argument("--residual", action='store_true', help="use residual connection")

    @classmethod
    def build_model(cls, args):
        """Construct the model from parsed ``--bias`` / ``--residual`` flags."""
        return cls(args.bias, args.residual)

    def forward(self, x):
        """Run the network on a batch of images.

        The last two dimensions of ``x`` are treated as height and width. An
        odd height (width) is zero-padded by one row at the bottom (one column
        on the right). This lets the stride-2 downsampling followed by the
        transposed-convolution upsampling produce a feature map of the same
        size as the skip connection. The padding is cropped from the result, so
        the output has the same spatial size as ``x``.
        """
        # assumes x is (batch, 1, H, W): conv1 takes exactly one input channel
        # and the skip concatenation uses dim=1 as the channel axis.
        h, w = x.shape[-2], x.shape[-1]
        pad_h = h % 2  # one extra row at the bottom if the height is odd
        pad_w = w % 2  # one extra column on the right if the width is odd
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        x = F.pad(x, (0, pad_w, 0, pad_h))

        out = F.relu(self.conv1(x))
        skip = F.relu(self.conv2(out))
        out = F.relu(self.conv3(skip))
        out = F.relu(self.conv4(out))
        out = F.relu(self.conv5(out))
        out = F.relu(self.conv6(out))
        out = F.relu(self.conv7(out))
        out = torch.cat([out, skip], dim=1)
        out = F.relu(self.conv8(out))
        out = self.conv9(out)

        if self.residual_connection:
            # Uses the padded x so the shapes match; cropped together below.
            out = x - out

        # Remove the padding so the output matches the original input size.
        return out[:, :, :h, :w]