Skip to content

Commit 682cbcb

Browse files
author
marsggbo
committed
refactor mutables.ops.conv when searching groups & add test
1 parent dd0dd10 commit 682cbcb

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

hyperbox/mutables/ops/conv.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,16 @@ def __init__(
5959
_stride = stride.max_value if isinstance(stride, ValueSpace) else stride
6060
_padding = padding.max_value if isinstance(padding, ValueSpace) else padding
6161
_dilation = dilation.min_value if isinstance(dilation, ValueSpace) else dilation
62-
_groups = groups.min_value if isinstance(groups, ValueSpace) else groups
62+
63+
_groups = groups
64+
if isinstance(groups, ValueSpace):
65+
if isinstance(in_channels, ValueSpace):
66+
if groups is not in_channels:
67+
print('groups must be the same as in_channels when in_channels is ValueSpace')
68+
groups = in_channels
69+
_groups = groups.max_value
70+
else:
71+
_groups = groups.min_value
6372

6473
self.format_args(_kernel_size, _stride, _padding, _dilation)
6574
conv_kwargs = kwargs
@@ -199,7 +208,7 @@ def forward_conv(self, x):
199208
filters = filters[:out_channels, ...]
200209
if self.search_kernel_size:
201210
filters = self.transform_kernel_size(filters)
202-
if self.search_groups:
211+
if self.search_groups and not self.search_in_channel:
203212
filters = self.get_filters_by_groups(filters, in_channels, groups).contiguous()
204213
if self.auto_padding:
205214
kernel_size = filters.shape[2:]
@@ -222,15 +231,6 @@ def get_filters_by_groups(self, filters, in_channels, groups):
222231
start = part_id * sub_in_channels
223232
filter_crops.append(sub_filter[:, start:start + sub_in_channels, :, :])
224233
filters = torch.cat(filter_crops, dim=0)
225-
226-
# indices = []
227-
# for i in range(groups):
228-
# part_id = i % sub_ratio
229-
# start = part_id * sub_in_channels
230-
# indices.extend(list(range(start, start + sub_in_channels)))
231-
# print(f"groups={groups}, in_channels={in_channels}, indices={indices}")
232-
# filters = filters[:, indices, :, :]
233-
print(f"groups={groups}, in_channels={in_channels}, filters.shape={filters.shape}")
234234
return filters
235235

236236
def transform_kernel_size(self, filters):

tests/mutables/test_op_conv.py

+71-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,62 @@
55

66
from hyperbox.utils.calc_model_size import flops_size_counter
77

8-
if __name__ == '__main__':
8+
def test_groups(cin, cout, group_list):
9+
print(f"test_groups: cin={cin}, cout={cout}, group_list={group_list}\n")
10+
groups = ValueSpace(candidates=group_list)
11+
x = torch.rand(2, cin, 64, 64)
12+
conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups)
13+
m = RandomMutator(conv)
14+
for i in range(10):
15+
m.reset()
16+
print(f'\n*******step{i}********\n', conv)
17+
y = conv(x)
18+
19+
def test_cin_groups(cin_list, cout, group_list):
20+
print(f"test_cin_groups: cin_list={cin_list}, cout={cout}, group_list={group_list}\n")
21+
cin = ValueSpace(candidates=cin_list)
22+
groups = ValueSpace(candidates=group_list)
23+
x = torch.rand(2, 3, 64, 64)
24+
conv = torch.nn.Sequential(
25+
ops.Conv2d(3, cin, 3, 1, 1),
26+
ops.Conv2d(cin, cout, 3, 1, 1, groups=cin)
27+
)
28+
m = RandomMutator(conv)
29+
for i in range(10):
30+
m.reset()
31+
print(f'\n*******step{i}********\n', conv)
32+
y = conv(x)
33+
34+
def test_cout_groups(cin, cout_list, group_list):
35+
print(f"test_cout_groups: cin={cin}, cout_list={cout_list}, group_list={group_list}\n")
36+
cout = ValueSpace(candidates=cout_list)
37+
groups = ValueSpace(candidates=group_list)
38+
x = torch.rand(2, cin, 64, 64)
39+
conv = ops.Conv2d(cin, cout, 3, 1, 1, groups=groups)
40+
m = RandomMutator(conv)
41+
for i in range(10):
42+
m.reset()
43+
print(f'\n*******step{i}********\n', conv)
44+
y = conv(x)
45+
46+
def test_cin_cout_groups(cin_list, cout_list, group_list):
47+
print(f"test_cin_cout_groups: cin_list={cin_list}, cout_list={cout_list}, group_list={group_list}\n")
48+
cin = ValueSpace(candidates=cin_list)
49+
cout = ValueSpace(candidates=cout_list)
50+
groups = ValueSpace(candidates=group_list)
51+
x = torch.rand(2, 3, 64, 64)
52+
conv = torch.nn.Sequential(
53+
ops.Conv2d(3, cin, 3, 1, 1),
54+
ops.Conv2d(cin, cout, 3, 1, 1, groups=cin)
55+
)
56+
m = RandomMutator(conv)
57+
for i in range(10):
58+
m.reset()
59+
print(f'\n*******step{i}********\n', conv)
60+
y = conv(x)
61+
62+
def test_conv():
63+
print('testing conv flops and sizes')
964
x = torch.rand(1,3,64,64)
1065
vs1 = ValueSpace([10,2])
1166
vs2 = ValueSpace([3,5,7])
@@ -17,15 +72,22 @@
1772
)
1873
m = RandomMutator(op)
1974
m.reset()
20-
print(op)
21-
print(m._cache)
22-
print(conv.weight.shape)
23-
print(conv(x).shape)
24-
r = flops_size_counter(op, (1,3,8,8), True, True)
25-
print(r)
75+
# print(op)
76+
# print(m._cache)
77+
# print(conv.weight.shape)
78+
# print(conv(x).shape)
79+
r = flops_size_counter(op, (1,3,8,8), False, False)
80+
# print(r)
2681
op = torch.nn.Sequential(
2782
ops.Conv2d(3,8,3,1),
2883
ops.BatchNorm2d(8)
2984
)
30-
r = flops_size_counter(op, (1,3,8,8), False, True)
31-
print(conv(x).shape)
85+
r = flops_size_counter(op, (1,3,8,8), False, False)
86+
# print(conv(x).shape)
87+
88+
if __name__ == '__main__':
89+
test_conv()
90+
test_groups(32, 64, group_list=[1, 2, 4, 8, 16, 32])
91+
test_cin_groups([32, 64], 128, group_list=[32, 64])
92+
test_cout_groups(8, [16, 32], group_list=[1, 2, 4, 8])
93+
test_cin_cout_groups([8, 16], [32, 64], group_list=[1, 8, 16])

0 commit comments

Comments
 (0)