-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunet3d.py
97 lines (71 loc) · 3.01 KB
/
unet3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
def conv_block(in_dim, out_dim, act_fn):
    """Build one size-preserving conv -> InstanceNorm -> activation stage.

    The 3x3x3 convolution uses stride 1 and padding 1, so the spatial extent
    of the input is unchanged; only the channel count goes from ``in_dim`` to
    ``out_dim``.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended as the last layer. It is used
            as-is (not copied), so one instance may be shared by many blocks.

    Returns:
        ``nn.Sequential`` with children 0 = Conv3d, 1 = InstanceNorm3d,
        2 = ``act_fn``.
    """
    conv = nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)
    norm = nn.InstanceNorm3d(out_dim)
    return nn.Sequential(conv, norm, act_fn)
def up_conv(in_dim, out_dim, act_fn):
    """Build a 2x upsampling stage: transposed conv -> InstanceNorm -> activation.

    With kernel_size=3, stride=2, padding=1 and output_padding=1, every
    spatial extent L maps to (L - 1) * 2 - 2 * 1 + 3 + 1 = 2 * L.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        act_fn: activation module appended as the last layer (used as-is,
            not copied).

    Returns:
        ``nn.Sequential`` with children 0 = ConvTranspose3d,
        1 = InstanceNorm3d, 2 = ``act_fn``.
    """
    upsample = nn.ConvTranspose3d(
        in_dim,
        out_dim,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    return nn.Sequential(upsample, nn.InstanceNorm3d(out_dim), act_fn)
def double_conv_block(in_dim, out_dim, act_fn):
    """Chain two ``conv_block`` stages: in_dim -> out_dim, then out_dim -> out_dim.

    Both stages preserve spatial size, so only the channel count changes.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels of both stages.
        act_fn: activation module passed to (and shared by) both stages.

    Returns:
        ``nn.Sequential`` of the two stages, indexed 0 and 1.
    """
    first = conv_block(in_dim, out_dim, act_fn)
    second = conv_block(out_dim, out_dim, act_fn)
    return nn.Sequential(first, second)
def out_block(in_dim, out_dim):
    """Build the 1x1x1 projection head mapping ``in_dim`` channels to ``out_dim``.

    No normalization or activation is applied. The conv is wrapped in a
    one-element ``nn.Sequential`` so its parameters sit under index 0
    (``0.weight`` / ``0.bias``). Keep the wrapper so state_dict keys stay
    unchanged.
    """
    head = nn.Conv3d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
    return nn.Sequential(head)
class Unet3d(nn.Module):
    """Three-level 3D U-Net built from InstanceNorm + ReLU conv stages.

    Encoder: three double-conv stages with ``num_filter``, ``2*num_filter``
    and ``4*num_filter`` channels, each followed by 2x max pooling, then an
    ``8*num_filter`` bridge. Decoder: three 2x transposed-conv upsamplings.
    Each upsampling is concatenated with the matching encoder output on the
    channel axis and refined by a double-conv stage. A 1x1x1 conv maps to
    ``out_dim`` channels. No final activation is applied.

    Input must be batched, ``(N, in_dim, D, H, W)``: skip connections are
    concatenated along ``dim=1`` as the channel axis. Each spatial extent
    should be at least 8 so the three poolings leave a non-empty tensor.
    Extents that are not multiples of 8 are handled by zero-padding the
    upsampled tensor to the size of its skip connection.

    Args:
        in_dim: channels of the input volume.
        out_dim: channels of the output volume.
        num_filter: channel width of the first encoder stage.
    """

    def __init__(self, in_dim, out_dim, num_filter):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        # ReLU has no parameters, so a single instance is shared by all stages.
        act_fn = nn.ReLU(inplace=True)

        # Encoder.
        self.down_1 = double_conv_block(self.in_dim, self.num_filter, act_fn)
        self.pool_1 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.down_2 = double_conv_block(self.num_filter, self.num_filter * 2, act_fn)
        self.pool_2 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.down_3 = double_conv_block(self.num_filter * 2, self.num_filter * 4, act_fn)
        self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)

        self.bridge = double_conv_block(self.num_filter * 4, self.num_filter * 8, act_fn)

        # Decoder: each up_* consumes upsampled channels + skip channels
        # (8 + 4 = 12, 4 + 2 = 6, 2 + 1 = 3 multiples of num_filter).
        self.trans_1 = up_conv(self.num_filter * 8, self.num_filter * 8, act_fn)
        self.up_1 = double_conv_block(self.num_filter * 12, self.num_filter * 4, act_fn)
        self.trans_2 = up_conv(self.num_filter * 4, self.num_filter * 4, act_fn)
        self.up_2 = double_conv_block(self.num_filter * 6, self.num_filter * 2, act_fn)
        self.trans_3 = up_conv(self.num_filter * 2, self.num_filter * 2, act_fn)
        self.up_3 = double_conv_block(self.num_filter * 3, self.num_filter, act_fn)

        self.out = out_block(self.num_filter, self.out_dim)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` at the far end of each spatial axis to match ``ref``.

        Pooling floors odd extents, so the 2x upsampling yields
        ``2 * floor(L / 2)``, one voxel short of the skip tensor when L is
        odd. Returns ``x`` unchanged when the spatial sizes already agree.
        """
        diffs = [r - s for r, s in zip(ref.shape[2:], x.shape[2:])]
        if not any(diffs):
            return x
        # F.pad expects (before, after) pairs starting from the LAST dimension.
        pad = []
        for d in reversed(diffs):
            pad.extend((0, d))
        return nn.functional.pad(x, pad)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        down_1 = self.down_1(x)
        down_2 = self.down_2(self.pool_1(down_1))
        down_3 = self.down_3(self.pool_2(down_2))
        bridge = self.bridge(self.pool_3(down_3))

        # Decoder: upsample, align to the skip tensor, concat on channels, refine.
        trans_1 = self._pad_to(self.trans_1(bridge), down_3)
        up_1 = self.up_1(torch.cat([trans_1, down_3], dim=1))

        trans_2 = self._pad_to(self.trans_2(up_1), down_2)
        up_2 = self.up_2(torch.cat([trans_2, down_2], dim=1))

        trans_3 = self._pad_to(self.trans_3(up_2), down_1)
        up_3 = self.up_3(torch.cat([trans_3, down_1], dim=1))

        return self.out(up_3)
# net = Unet3d(in_dim=5, out_dim=2, num_filter=16).cuda()
# input = torch.ones(1, 5, 128, 192, 128).cuda()
# out = net(input)
# print(out.size())