5
5
6
6
from hyperbox .utils .calc_model_size import flops_size_counter
7
7
8
- if __name__ == '__main__' :
8
+ def test_groups (cin , cout , group_list ):
9
+ print (f"test_groups: cin={ cin } , cout={ cout } , group_list={ group_list } \n " )
10
+ groups = ValueSpace (candidates = group_list )
11
+ x = torch .rand (2 , cin , 64 , 64 )
12
+ conv = ops .Conv2d (cin , cout , 3 , 1 , 1 , groups = groups )
13
+ m = RandomMutator (conv )
14
+ for i in range (10 ):
15
+ m .reset ()
16
+ print (f'\n *******step{ i } ********\n ' , conv )
17
+ y = conv (x )
18
+
19
+ def test_cin_groups (cin_list , cout , group_list ):
20
+ print (f"test_cin_groups: cin_list={ cin_list } , cout={ cout } , group_list={ group_list } \n " )
21
+ cin = ValueSpace (candidates = cin_list )
22
+ groups = ValueSpace (candidates = group_list )
23
+ x = torch .rand (2 , 3 , 64 , 64 )
24
+ conv = torch .nn .Sequential (
25
+ ops .Conv2d (3 , cin , 3 , 1 , 1 ),
26
+ ops .Conv2d (cin , cout , 3 , 1 , 1 , groups = cin )
27
+ )
28
+ m = RandomMutator (conv )
29
+ for i in range (10 ):
30
+ m .reset ()
31
+ print (f'\n *******step{ i } ********\n ' , conv )
32
+ y = conv (x )
33
+
34
+ def test_cout_groups (cin , cout_list , group_list ):
35
+ print (f"test_cout_groups: cin={ cin } , cout_list={ cout_list } , group_list={ group_list } \n " )
36
+ cout = ValueSpace (candidates = cout_list )
37
+ groups = ValueSpace (candidates = group_list )
38
+ x = torch .rand (2 , cin , 64 , 64 )
39
+ conv = ops .Conv2d (cin , cout , 3 , 1 , 1 , groups = groups )
40
+ m = RandomMutator (conv )
41
+ for i in range (10 ):
42
+ m .reset ()
43
+ print (f'\n *******step{ i } ********\n ' , conv )
44
+ y = conv (x )
45
+
46
+ def test_cin_cout_groups (cin_list , cout_list , group_list ):
47
+ print (f"test_cin_cout_groups: cin_list={ cin_list } , cout_list={ cout_list } , group_list={ group_list } \n " )
48
+ cin = ValueSpace (candidates = cin_list )
49
+ cout = ValueSpace (candidates = cout_list )
50
+ groups = ValueSpace (candidates = group_list )
51
+ x = torch .rand (2 , 3 , 64 , 64 )
52
+ conv = torch .nn .Sequential (
53
+ ops .Conv2d (3 , cin , 3 , 1 , 1 ),
54
+ ops .Conv2d (cin , cout , 3 , 1 , 1 , groups = cin )
55
+ )
56
+ m = RandomMutator (conv )
57
+ for i in range (10 ):
58
+ m .reset ()
59
+ print (f'\n *******step{ i } ********\n ' , conv )
60
+ y = conv (x )
61
+
62
+ def test_conv ():
63
+ print ('testing conv flops and sizes' )
9
64
x = torch .rand (1 ,3 ,64 ,64 )
10
65
vs1 = ValueSpace ([10 ,2 ])
11
66
vs2 = ValueSpace ([3 ,5 ,7 ])
17
72
)
18
73
m = RandomMutator (op )
19
74
m .reset ()
20
- print (op )
21
- print (m ._cache )
22
- print (conv .weight .shape )
23
- print (conv (x ).shape )
24
- r = flops_size_counter (op , (1 ,3 ,8 ,8 ), True , True )
25
- print (r )
75
+ # print(op)
76
+ # print(m._cache)
77
+ # print(conv.weight.shape)
78
+ # print(conv(x).shape)
79
+ r = flops_size_counter (op , (1 ,3 ,8 ,8 ), False , False )
80
+ # print(r)
26
81
op = torch .nn .Sequential (
27
82
ops .Conv2d (3 ,8 ,3 ,1 ),
28
83
ops .BatchNorm2d (8 )
29
84
)
30
- r = flops_size_counter (op , (1 ,3 ,8 ,8 ), False , True )
31
- print (conv (x ).shape )
85
+ r = flops_size_counter (op , (1 ,3 ,8 ,8 ), False , False )
86
+ # print(conv(x).shape)
87
+
88
+ if __name__ == '__main__' :
89
+ test_conv ()
90
+ test_groups (32 , 64 , group_list = [1 , 2 , 4 , 8 , 16 , 32 ])
91
+ test_cin_groups ([32 , 64 ], 128 , group_list = [32 , 64 ])
92
+ test_cout_groups (8 , [16 , 32 ], group_list = [1 , 2 , 4 , 8 ])
93
+ test_cin_cout_groups ([8 , 16 ], [32 , 64 ], group_list = [1 , 8 , 16 ])
0 commit comments