1
+ # Copyright 2019 Google LLC
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Example of mnist model with pruning.
17
+ Adapted from TF model optimization example."""
18
+
19
+ import tempfile
20
+ import numpy as np
21
+
22
+ import tensorflow .keras .backend as K
23
+ from tensorflow .keras .datasets import mnist
24
+ from tensorflow .keras .layers import Activation
25
+ from tensorflow .keras .layers import Flatten
26
+ from tensorflow .keras .layers import Input
27
+ from tensorflow .keras .models import Model
28
+ from tensorflow .keras .models import Sequential
29
+ from tensorflow .keras .models import save_model
30
+ from tensorflow .keras .utils import to_categorical
31
+
32
+ from qkeras import QActivation
33
+ from qkeras import QDense
34
+ from qkeras import QConv2D
35
+ from qkeras import quantized_bits
36
+ from qkeras .utils import load_qmodel
37
+ from qkeras .utils import print_model_sparsity
38
+
39
+ from tensorflow_model_optimization .python .core .sparsity .keras import prune
40
+ from tensorflow_model_optimization .python .core .sparsity .keras import pruning_callbacks
41
+ from tensorflow_model_optimization .python .core .sparsity .keras import pruning_schedule
42
+
43
+
44
# Training hyperparameters for the MNIST pruning example.
batch_size = 128  # samples per gradient update
num_classes = 10  # MNIST digits 0-9
epochs = 12

prune_whole_model = True  # Prune whole model or just specified layers
49
+
50
+
51
def build_model(input_shape):
  """Build the quantized CNN used for MNIST classification.

  Three strided QConv2D stages (each followed by a 4-bit quantized ReLU)
  feed a flattened QDense classifier with a softmax head.

  Args:
    input_shape: shape of one input sample, excluding the batch axis.

  Returns:
    An uncompiled Keras functional ``Model`` mapping images to class
    probabilities.
  """
  # (filters, kernel_size, conv layer name, activation layer name)
  conv_stages = [
      (32, (2, 2), "conv2d_0_m", "act0_m"),
      (64, (3, 3), "conv2d_1_m", "act1_m"),
      (64, (2, 2), "conv2d_2_m", "act2_m"),
  ]

  x = x_in = Input(shape=input_shape, name="input")
  for filters, kernel, conv_name, act_name in conv_stages:
    # Weights and biases are quantized to 4 bits (0 integer bits, signed).
    x = QConv2D(
        filters, kernel, strides=(2, 2),
        kernel_quantizer=quantized_bits(4, 0, 1),
        bias_quantizer=quantized_bits(4, 0, 1),
        name=conv_name)(x)
    x = QActivation("quantized_relu(4,0)", name=act_name)(x)
  x = Flatten()(x)
  x = QDense(
      num_classes,
      kernel_quantizer=quantized_bits(4, 0, 1),
      bias_quantizer=quantized_bits(4, 0, 1),
      name="dense")(x)
  x = Activation("softmax", name="softmax")(x)

  return Model(inputs=[x_in], outputs=[x])
79
+
80
+
81
def build_layerwise_model(input_shape, **pruning_params):
  """Build the same quantized CNN with pruning applied layer by layer.

  Unlike :func:`build_model` (which is pruned as a whole afterwards),
  every weight-bearing layer here is individually wrapped with
  ``prune.prune_low_magnitude``.

  Args:
    input_shape: shape of one input sample, excluding the batch axis.
    **pruning_params: keyword arguments forwarded to
      ``prune.prune_low_magnitude`` (e.g. ``pruning_schedule``).

  Returns:
    A Keras ``Sequential`` model with per-layer pruning wrappers.
  """

  def quant():
    # Fresh 4-bit quantizer (0 integer bits, signed) per weight tensor.
    return quantized_bits(4, 0, 1)

  layers = [
      prune.prune_low_magnitude(
          QConv2D(32, (2, 2), strides=(2, 2),
                  kernel_quantizer=quant(),
                  bias_quantizer=quant(),
                  name="conv2d_0_m"),
          input_shape=input_shape,
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act0_m"),
      prune.prune_low_magnitude(
          QConv2D(64, (3, 3), strides=(2, 2),
                  kernel_quantizer=quant(),
                  bias_quantizer=quant(),
                  name="conv2d_1_m"),
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act1_m"),
      prune.prune_low_magnitude(
          QConv2D(64, (2, 2), strides=(2, 2),
                  kernel_quantizer=quant(),
                  bias_quantizer=quant(),
                  name="conv2d_2_m"),
          **pruning_params),
      QActivation("quantized_relu(4,0)", name="act2_m"),
      Flatten(),
      prune.prune_low_magnitude(
          QDense(num_classes,
                 kernel_quantizer=quant(),
                 bias_quantizer=quant(),
                 name="dense"),
          **pruning_params),
      Activation("softmax", name="softmax"),
  ]
  return Sequential(layers)
117
+
118
+
119
def train_and_save(model, x_train, y_train, x_test, y_test):
  """Compile, train, evaluate, save and reload a (pruned) model.

  Args:
    model: uncompiled Keras model, possibly wrapped for pruning.
    x_train: training images, float32, scaled to [0, 1].
    y_train: one-hot training labels.
    x_test: test images, float32, scaled to [0, 1].
    y_test: one-hot test labels.

  Side effects:
    Trains ``model`` in place, writes TensorBoard pruning summaries to
    /tmp/mnist_prune, saves the trained model to a temporary .h5 file,
    and prints evaluation metrics before and after reloading.
  """
  model.compile(
      loss="categorical_crossentropy",
      optimizer="adam",
      metrics=["accuracy"])

  # Print the model summary.
  model.summary()

  # UpdatePruningStep pegs the pruning step to the optimizer's step;
  # PruningSummaries adds pruning summaries to TensorBoard.
  callbacks = [
      pruning_callbacks.UpdatePruningStep(),
      pruning_callbacks.PruningSummaries(log_dir="/tmp/mnist_prune")
  ]

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test))
  score = model.evaluate(x_test, y_test, verbose=0)
  print("Test loss:", score[0])
  print("Test accuracy:", score[1])

  print_model_sparsity(model)

  # Export and import the model. Check that accuracy persists.
  _, keras_file = tempfile.mkstemp(".h5")
  print("Saving model to: ", keras_file)
  save_model(model, keras_file)

  print("Reloading model")
  # prune_scope() makes the pruning wrapper classes available for
  # deserialization of the saved model.
  with prune.prune_scope():
    loaded_model = load_qmodel(keras_file)
    score = loaded_model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
161
+
162
+
163
def main():
  """Load MNIST, build a (pruned) quantized model, train and save it."""
  # Input image dimensions.
  img_rows, img_cols = 28, 28

  # The data, shuffled and split between train and test sets.
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  # The position of the channel axis depends on the backend image format.
  if K.image_data_format() == "channels_first":
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)
  x_train = x_train.reshape((x_train.shape[0],) + input_shape)
  x_test = x_test.reshape((x_test.shape[0],) + input_shape)

  # Scale pixel values to [0, 1].
  x_train = x_train.astype("float32") / 255
  x_test = x_test.astype("float32") / 255
  print("x_train shape:", x_train.shape)
  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")

  # Convert class vectors to binary class matrices.
  y_train = to_categorical(y_train, num_classes)
  y_test = to_categorical(y_test, num_classes)

  # Hold sparsity at 75% from step 2000 onwards, re-pruning every 100 steps.
  pruning_params = {
      "pruning_schedule":
          pruning_schedule.ConstantSparsity(
              0.75, begin_step=2000, frequency=100)
  }

  if prune_whole_model:
    model = prune.prune_low_magnitude(
        build_model(input_shape), **pruning_params)
  else:
    model = build_layerwise_model(input_shape, **pruning_params)

  train_and_save(model, x_train, y_train, x_test, y_test)
203
+
204
+
205
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
  main()