-
Notifications
You must be signed in to change notification settings - Fork 3
/
pp_liteseg.py
284 lines (239 loc) · 11 KB
/
pp_liteseg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# PyTorch implementation of PP-LiteSeg
import torch
import torch.nn as nn
from torch.nn import init
import math
import torch.nn.functional as F
from stdc import STDCNet813
from UAFM import UAFM_SpAtten, ConvBNReLU
class PPLiteSeg(nn.Module):
    """
    PyTorch port of PP-LiteSeg (originally implemented in PaddlePaddle).

    The original article refers to "Juncai Peng, Yi Liu, Shiyu Tang, Yuying Hao, Lutao Chu,
    Guowei Chen, Zewu Wu, Zeyu Chen, Zhiliang Yu, Yuning Du, Qingqing Dang,Baohua Lai,
    Qiwen Liu, Xiaoguang Hu, Dianhai Yu, Yanjun Ma. PP-LiteSeg: A Superior Real-Time Semantic
    Segmentation Model. https://arxiv.org/abs/2204.02681".

    The backbone is always an ``STDCNet813`` instance created inside the
    constructor; it must expose a ``feat_channels`` attribute.

    Args:
        num_classes (int): The number of target classes.
        backbone_indices (Sequence(int), optional): Indices into the list of backbone
            outputs that select the features fed to the decoder. Default: (2, 3, 4).
        arm_type (str, optional): Name of the fusion (attention refinement) module
            class used by the decoder head. Default: UAFM_SpAtten.
        cm_bin_sizes (Sequence(int), optional): The bin sizes of the context module.
            Default: (1, 2, 4).
        cm_out_ch (int, optional): The output channel of the last context module. Default: 128.
        arm_out_chs (Sequence(int), optional): The out channels of each arm module. A
            single value is repeated for every selected feature. Default: (32, 64, 128).
        seg_head_inter_chs (Sequence(int), optional): The intermediate channels of the
            segmentation heads. A single value is repeated for every selected feature.
            Default: (64, 64, 64).
        resize_mode (str, optional): The resize mode forwarded to the decoder head.
            Default: bilinear.
        pretrain_model (str, optional): Path of a checkpoint to load. When falsy,
            ``init_params`` is used instead. Default: None.
        is_training (bool, optional): Initial train/eval mode of the whole model. In
            training mode ``forward`` returns one logit map per selected feature,
            otherwise only the first one. Default: True.
    """

    def __init__(self,
                 num_classes,
                 backbone_indices=(2, 3, 4),
                 arm_type='UAFM_SpAtten',
                 cm_bin_sizes=(1, 2, 4),
                 cm_out_ch=128,
                 arm_out_chs=(32, 64, 128),  # (32, 64, 128) for STDCNet813; (64, 96, 128) for STDCNet1446
                 seg_head_inter_chs=(64, 64, 64),
                 resize_mode='bilinear',
                 pretrain_model=None,
                 is_training=True):
        super().__init__()

        # Private list copies: the instance never aliases a caller's (or a
        # default) sequence, and tuples are accepted as well as lists.
        backbone_indices = list(backbone_indices)
        cm_bin_sizes = list(cm_bin_sizes)
        arm_out_chs = list(arm_out_chs)
        seg_head_inter_chs = list(seg_head_inter_chs)

        backbone = STDCNet813()
        self.backbone = backbone

        assert hasattr(backbone, 'feat_channels'), \
            "The backbone should has feat_channels."
        # Checked before max() below so an empty list fails with this message.
        assert len(backbone_indices) > 1, "The lenght of backbone_indices " \
            "should be greater than 1"
        assert len(backbone.feat_channels) >= len(backbone_indices), \
            f"The length of input backbone_indices ({len(backbone_indices)}) should not be " \
            f"greater than the length of feat_channels ({len(backbone.feat_channels)})."
        assert len(backbone.feat_channels) > max(backbone_indices), \
            f"The max value ({max(backbone_indices)}) of backbone_indices should be " \
            f"less than the length of feat_channels ({len(backbone.feat_channels)})."

        self.backbone_indices = backbone_indices  # [..., x16_id, x32_id]
        backbone_out_chs = [backbone.feat_channels[i] for i in backbone_indices]

        # Decoder head: a single channel value is broadcast to every level.
        if len(arm_out_chs) == 1:
            arm_out_chs = arm_out_chs * len(backbone_indices)
        assert len(arm_out_chs) == len(backbone_indices), "The length of " \
            "arm_out_chs and backbone_indices should be equal"
        self.ppseg_head = PPLiteSegHead(backbone_out_chs, arm_out_chs,
                                        cm_bin_sizes, cm_out_ch, arm_type,
                                        resize_mode)

        if len(seg_head_inter_chs) == 1:
            seg_head_inter_chs = seg_head_inter_chs * len(backbone_indices)
        assert len(seg_head_inter_chs) == len(backbone_indices), "The length of " \
            "seg_head_inter_chs and backbone_indices should be equal"
        self.seg_heads = nn.ModuleList()  # [..., head_16, head32]
        for in_ch, mid_ch in zip(arm_out_chs, seg_head_inter_chs):
            self.seg_heads.append(SegHead(in_ch, mid_ch, num_classes))

        # Weights: either load a checkpoint or apply the default initialisation.
        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

        # Set the mode last so it propagates to every submodule; assigning
        # ``self.training`` directly would only flip the flag on the root module.
        self.train(bool(is_training))

    def forward(self, x):
        """
        Run the network on a batch of images.

        Args:
            x (Tensor): Input batch; the trailing two dimensions are taken as the
                spatial size that every output is resized to.

        Returns:
            List(Tensor): In training mode, one logit map per selected backbone
            feature; otherwise a single-element list holding the output of the
            first segmentation head. Every map is bilinearly resized to the
            input's spatial size.
        """
        x_hw = x.shape[2:]
        feats_backbone = self.backbone(x)  # [x2, x4, x8, x16, x32]
        assert len(feats_backbone) >= len(self.backbone_indices), \
            f"The nums of backbone feats ({len(feats_backbone)}) should be greater or " \
            f"equal than the nums of backbone_indices ({len(self.backbone_indices)})"

        feats_selected = [feats_backbone[i] for i in self.backbone_indices]
        feats_head = self.ppseg_head(feats_selected)  # [..., x8, x16, x32]

        if self.training:
            # Auxiliary heads are only evaluated while training.
            logits = [
                seg_head(feat)
                for feat, seg_head in zip(feats_head, self.seg_heads)
            ]
        else:
            logits = [self.seg_heads[0](feats_head[0])]

        return [
            F.interpolate(logit, x_hw, mode='bilinear', align_corners=False)
            for logit in logits
        ]

    def init_weight(self, pretrain_model):
        """
        Load pretrained weights from ``pretrain_model``.

        The file may hold either a bare state dict or a dict wrapping it under
        the ``"state_dict"`` key. Tensors whose name or shape does not match a
        tensor of this model are skipped (and reported); tensors of this model
        that are absent from the file keep their current values.

        Args:
            pretrain_model (str): Path of the checkpoint file.
        """
        # map_location='cpu' lets checkpoints saved on a GPU load on any host;
        # load_state_dict copies the values onto the model's own device.
        checkpoint = torch.load(pretrain_model, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        own_state = self.state_dict()
        skipped = []
        for key, value in state_dict.items():
            target = own_state.get(key)
            if target is None or getattr(value, 'shape', None) != target.shape:
                skipped.append(key)
                continue
            own_state[key] = value
        if skipped:
            print('skip {} pretrained tensor(s) without a matching target: {}'.format(
                len(skipped), skipped))
        self.load_state_dict(own_state)

    def init_params(self):
        """
        Apply the default initialisation to every submodule.

        Conv2d weights use Kaiming-normal (fan_out) with zero bias, BatchNorm2d
        uses weight 1 / bias 0, and Linear weights use a normal distribution
        with std 0.001 and zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
class PPLiteSegHead(nn.Module):
    """
    The decoder head of PPLiteSeg.

    A context module is applied to the deepest feature, then one fusion
    ("arm") module per level merges each backbone feature with the result of
    the level below it, going from the deepest level to the shallowest.

    Args:
        backbone_out_chs (List(int)): The channel counts of the selected backbone outputs.
        arm_out_chs (List(int)): The out channels of each arm module.
        cm_bin_sizes (List(int)): The bin size of context module.
        cm_out_ch (int): The output channel of the last context module.
        arm_type (str or callable): Name of the arm module class, looked up among
            the names available in this module, or the class itself.
        resize_mode (str): The resize mode for the upsampling operation in decoder.
    """

    def __init__(self, backbone_out_chs, arm_out_chs, cm_bin_sizes, cm_out_ch,
                 arm_type, resize_mode):
        super().__init__()

        self.cm = PPContextModule(backbone_out_chs[-1], cm_out_ch, cm_out_ch,
                                  cm_bin_sizes)

        arm_class = self._resolve_arm_class(arm_type)
        last_idx = len(backbone_out_chs) - 1
        self.arm_list = nn.ModuleList()  # [..., arm8, arm16, arm32]
        for i, low_chs in enumerate(backbone_out_chs):
            # The high-level input of an arm is the output of the next (deeper)
            # arm; the deepest arm is fed by the context module instead.
            high_ch = cm_out_ch if i == last_idx else arm_out_chs[i + 1]
            arm = arm_class(
                low_chs, high_ch, arm_out_chs[i], ksize=3, resize_mode=resize_mode)
            self.arm_list.append(arm)

    @staticmethod
    def _resolve_arm_class(arm_type):
        """Return the arm class for ``arm_type`` (a class, or a name visible in this module)."""
        if callable(arm_type):
            return arm_type
        if not isinstance(arm_type, str):
            raise TypeError(
                f"arm_type should be a class or its name, got {type(arm_type).__name__}")
        try:
            # Plain name lookup instead of eval(): same result for a valid
            # class name, without executing an arbitrary expression.
            return globals()[arm_type]
        except KeyError:
            # NameError matches what the previous eval()-based lookup raised.
            raise NameError(f"Unknown arm_type '{arm_type}'") from None

    def forward(self, in_feat_list):
        """
        Args:
            in_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32].
                x2, x4 and x8 are optional.
        Returns:
            out_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32].
                x2, x4 and x8 are optional.
                The length of in_feat_list and out_feat_list are the same.
        """
        high_feat = self.cm(in_feat_list[-1])

        # Fuse from the deepest level upwards, then restore the input order.
        fused_deep_first = []
        for low_feat, arm in zip(reversed(in_feat_list),
                                 reversed(self.arm_list[:len(in_feat_list)])):
            high_feat = arm(low_feat, high_feat)
            fused_deep_first.append(high_feat)
        return fused_deep_first[::-1]
class PPContextModule(nn.Module):
    """
    Simple context module.

    The input is average-pooled to each bin size, passed through a ConvBNReLU
    block, resized back to the input's spatial size, and the branches are
    summed before a final ConvBNReLU block.

    Args:
        in_channels (int): The number of input channels to pyramid pooling module.
        inter_channels (int): The number of inter channels to pyramid pooling module.
        out_channels (int): The number of output channels after pyramid pooling module.
        bin_sizes (Iterable(int)): The output sizes of the pooled feature maps,
            one pooling branch per entry.
        align_corners (bool): An argument of F.interpolate. It should be set to False
            when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 bin_sizes,
                 align_corners=False):
        super().__init__()
        self.align_corners = align_corners

        # One pooling branch per requested bin size.
        branches = []
        for size in bin_sizes:
            branches.append(self._make_stage(in_channels, inter_channels, size))
        self.stages = nn.ModuleList(branches)

        self.conv_out = ConvBNReLU(
            in_planes=inter_channels, out_planes=out_channels, kernel=3)

    def _make_stage(self, in_channels, out_channels, size):
        """Build one branch: adaptive average pooling to ``size``, then a ConvBNReLU block built with kernel=1."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=size),
            ConvBNReLU(
                in_planes=in_channels, out_planes=out_channels, kernel=1),
        )

    def forward(self, input):
        """Sum the resized outputs of all pooling branches and apply the output block."""
        target_hw = input.shape[2:]
        fused = None
        for idx, branch in enumerate(self.stages):
            resized = F.interpolate(
                branch(input),
                target_hw,
                mode='bilinear',
                align_corners=self.align_corners)
            if idx == 0:
                fused = resized
            else:
                # Accumulate into the first branch's tensor.
                fused += resized
        return self.conv_out(fused)
class SegHead(nn.Module):
    """
    Segmentation head producing per-class logits.

    Args:
        in_chan (int): Channels of the incoming feature map.
        mid_chan (int): Channels between the two layers.
        n_classes (int): Number of output channels (one per class).
    """

    def __init__(self, in_chan, mid_chan, n_classes):
        super().__init__()
        # ConvBNReLU block built with kernel=3, stride=1.
        self.conv = ConvBNReLU(in_chan, mid_chan, kernel=3, stride=1)
        # 1x1 bias-free convolution mapping to the class logits.
        self.conv_out = nn.Conv2d(
            in_channels=mid_chan,
            out_channels=n_classes,
            kernel_size=1,
            bias=False)

    def forward(self, x):
        """Apply the ConvBNReLU block and then the 1x1 classifier."""
        return self.conv_out(self.conv(x))
if __name__ == "__main__":
    # Smoke test: build a 6-class model, switch it to eval mode and print the
    # number and shapes of the outputs for a random batch.
    model = PPLiteSeg(num_classes=6, is_training=True)
    model.eval()

    dummy_batch = torch.randn(16, 3, 512, 512)
    outputs = model(dummy_batch)
    # To export the weights: torch.save(model.state_dict(), 'cat.pth')

    print(len(outputs))
    for logit in outputs:
        print(logit.shape)