-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathresnext.py
83 lines (75 loc) · 2.62 KB
/
resnext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from torch import nn
class ResNeXtUnit(nn.Module):
    """A ResNeXt bottleneck unit: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand,
    each followed by BatchNorm, plus a residual shortcut.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        mid_features: bottleneck width; defaults to out_features // 2.
        stride: stride of the 3x3 grouped convolution (spatial downsampling).
        groups: cardinality of the grouped convolution.
    """
    def __init__(self, in_features, out_features, mid_features=None, stride=1, groups=32):
        super(ResNeXtUnit, self).__init__()
        # Default bottleneck width is half the output width.
        width = int(out_features / 2) if mid_features is None else mid_features
        self.feas = nn.Sequential(
            nn.Conv2d(in_features, width, 1, stride=1),
            nn.BatchNorm2d(width),
            nn.Conv2d(width, width, 3, stride=stride, padding=1, groups=groups),
            nn.BatchNorm2d(width),
            nn.Conv2d(width, out_features, 1, stride=1),
            nn.BatchNorm2d(out_features),
        )
        if in_features != out_features:
            # Channel count changes: project the input with a 1x1 conv + BN
            # so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride=stride),
                nn.BatchNorm2d(out_features),
            )
        else:
            # Same channel count: the input is added directly (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        # Residual sum of the bottleneck branch and the (possibly projected) input.
        return self.feas(x) + self.shortcut(x)
class ResNeXt(nn.Module):
    """ResNeXt classifier for 32x32 inputs (e.g. CIFAR).

    Three stages of three ResNeXtUnit blocks each; spatial resolution goes
    32x32 -> 16x16 -> 8x8 via stride-2 units at the start of stages 2 and 3,
    then global 8x8 average pooling feeds a linear classifier.

    Args:
        class_num: number of output classes.
    """
    def __init__(self, class_num):
        super(ResNeXt, self).__init__()
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )  # 32x32
        self.stage_1 = nn.Sequential(
            ResNeXtUnit(64, 256, mid_features=128),
            nn.ReLU(),
            ResNeXtUnit(256, 256),
            nn.ReLU(),
            ResNeXtUnit(256, 256),
            nn.ReLU()
        )  # 32x32
        self.stage_2 = nn.Sequential(
            ResNeXtUnit(256, 512, stride=2),
            nn.ReLU(),
            ResNeXtUnit(512, 512),
            nn.ReLU(),
            ResNeXtUnit(512, 512),
            nn.ReLU()
        )  # 16x16
        self.stage_3 = nn.Sequential(
            ResNeXtUnit(512, 1024, stride=2),
            nn.ReLU(),
            ResNeXtUnit(1024, 1024),
            nn.ReLU(),
            ResNeXtUnit(1024, 1024),
            nn.ReLU()
        )  # 8x8
        self.pool = nn.AvgPool2d(8)
        self.classifier = nn.Sequential(
            nn.Linear(1024, class_num),
            # nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return raw class logits of shape (N, class_num)."""
        fea = self.basic_conv(x)
        fea = self.stage_1(fea)
        fea = self.stage_2(fea)
        fea = self.stage_3(fea)
        fea = self.pool(fea)  # (N, 1024, 1, 1)
        # BUG FIX: torch.squeeze(fea) also removed the batch dimension when
        # N == 1, yielding shape (class_num,) instead of (1, class_num).
        # Flatten only the non-batch dimensions instead.
        fea = fea.view(fea.size(0), -1)
        fea = self.classifier(fea)
        return fea
if __name__=='__main__':
    # Smoke test: run a random CIFAR-sized batch through the network
    # and print the logit tensor's shape.
    batch = torch.rand(8, 3, 32, 32)
    model = ResNeXt(10)
    logits = model(batch)
    print(logits.shape)