Skip to content

Commit bc70882

Browse files
Merge pull request #18 from vloncar:pruning
PiperOrigin-RevId: 297473595 Change-Id: I300a5d241523d0ea09f4d2004e64f94dde6748d3
2 parents 8f17e33 + 3fe9a31 commit bc70882

File tree

6 files changed

+269
-6
lines changed

6 files changed

+269
-6
lines changed

examples/example_mnist_prune.py

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Example of mnist model with pruning.
17+
Adapted from TF model optimization example."""
18+
19+
import tempfile
20+
import numpy as np
21+
22+
import tensorflow.keras.backend as K
23+
from tensorflow.keras.datasets import mnist
24+
from tensorflow.keras.layers import Activation
25+
from tensorflow.keras.layers import Flatten
26+
from tensorflow.keras.layers import Input
27+
from tensorflow.keras.models import Model
28+
from tensorflow.keras.models import Sequential
29+
from tensorflow.keras.models import save_model
30+
from tensorflow.keras.utils import to_categorical
31+
32+
from qkeras import QActivation
33+
from qkeras import QDense
34+
from qkeras import QConv2D
35+
from qkeras import quantized_bits
36+
from qkeras.utils import load_qmodel
37+
from qkeras.utils import print_model_sparsity
38+
39+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
40+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
41+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
42+
43+
44+
batch_size = 128
45+
num_classes = 10
46+
epochs = 12
47+
48+
prune_whole_model = True # Prune whole model or just specified layers
49+
50+
51+
def build_model(input_shape):
    """Build the quantized convolutional MNIST classifier (functional API).

    The network is three stride-2 ``QConv2D`` layers, each followed by a
    ``quantized_relu(4,0)`` activation, then ``Flatten``, a ``QDense`` layer
    with ``num_classes`` units and a final softmax. Every kernel and bias
    gets its own ``quantized_bits(4, 0, 1)`` quantizer instance.

    Args:
      input_shape: shape of a single input sample, forwarded to ``Input``.

    Returns:
      The uncompiled ``Model`` connecting the input tensor to the softmax
      output.
    """
    # (filters, kernel_size, conv layer name, activation layer name)
    conv_stack = (
        (32, (2, 2), "conv2d_0_m", "act0_m"),
        (64, (3, 3), "conv2d_1_m", "act1_m"),
        (64, (2, 2), "conv2d_2_m", "act2_m"),
    )

    inputs = Input(shape=input_shape, name="input")
    features = inputs
    for filters, kernel_size, conv_name, act_name in conv_stack:
        features = QConv2D(
            filters, kernel_size, strides=(2, 2),
            kernel_quantizer=quantized_bits(4, 0, 1),
            bias_quantizer=quantized_bits(4, 0, 1),
            name=conv_name)(features)
        features = QActivation(
            "quantized_relu(4,0)", name=act_name)(features)

    features = Flatten()(features)
    logits = QDense(
        num_classes,
        kernel_quantizer=quantized_bits(4, 0, 1),
        bias_quantizer=quantized_bits(4, 0, 1),
        name="dense")(features)
    outputs = Activation("softmax", name="softmax")(logits)

    return Model(inputs=[inputs], outputs=[outputs])
79+
80+
81+
def build_layerwise_model(input_shape, **pruning_params):
    """Build the same MNIST network as a ``Sequential`` with per-layer pruning.

    Only the three ``QConv2D`` layers and the final ``QDense`` layer are
    wrapped with ``prune.prune_low_magnitude``; the activation, flatten and
    softmax layers are added unwrapped.

    Args:
      input_shape: shape of a single input sample; passed to the pruning
        wrapper of the first convolution.
      **pruning_params: keyword arguments forwarded unchanged to every
        ``prune.prune_low_magnitude`` call.

    Returns:
      The uncompiled ``Sequential`` model.
    """

    def quantized_conv(filters, kernel_size, name):
        # All convolutions share stride and quantizer settings.
        return QConv2D(
            filters, kernel_size, strides=(2, 2),
            kernel_quantizer=quantized_bits(4, 0, 1),
            bias_quantizer=quantized_bits(4, 0, 1),
            name=name)

    def relu(name):
        return QActivation("quantized_relu(4,0)", name=name)

    def pruned(layer):
        return prune.prune_low_magnitude(layer, **pruning_params)

    # Layers are created in network order, first to last.
    stack = []
    # The first wrapper is the only one that is told the input shape.
    stack.append(
        prune.prune_low_magnitude(
            quantized_conv(32, (2, 2), "conv2d_0_m"),
            input_shape=input_shape,
            **pruning_params))
    stack.append(relu("act0_m"))
    stack.append(pruned(quantized_conv(64, (3, 3), "conv2d_1_m")))
    stack.append(relu("act1_m"))
    stack.append(pruned(quantized_conv(64, (2, 2), "conv2d_2_m")))
    stack.append(relu("act2_m"))
    stack.append(Flatten())
    stack.append(
        pruned(
            QDense(
                num_classes,
                kernel_quantizer=quantized_bits(4, 0, 1),
                bias_quantizer=quantized_bits(4, 0, 1),
                name="dense")))
    stack.append(Activation("softmax", name="softmax"))

    return Sequential(stack)
117+
118+
119+
def train_and_save(model, x_train, y_train, x_test, y_test):
    """Train a pruning-wrapped model, report metrics, and round-trip it to disk.

    The model is compiled with categorical cross-entropy and Adam, trained
    for ``epochs`` epochs with batch size ``batch_size`` (module-level
    settings), evaluated on the test split, and its sparsity is printed.
    It is then saved to a temporary ``.h5`` file, reloaded with
    ``load_qmodel`` inside ``prune.prune_scope()`` and evaluated again so
    the two sets of printed metrics can be compared.

    Args:
      model: model to train; compiled in place by this function.
      x_train: training inputs.
      y_train: training targets (one-hot, to match the categorical loss).
      x_test: evaluation inputs, also used as validation data during fit.
      y_test: evaluation targets.
    """
    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"])

    # Print the model summary.
    model.summary()

    # UpdatePruningStep pegs the pruning step to the optimizer's step;
    # PruningSummaries writes pruning summaries for TensorBoard. The log
    # directory is a fixed path; switch to tempfile.mkdtemp() if a throwaway
    # directory is preferred.
    callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(log_dir="/tmp/mnist_prune"),
    ]

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])

    print_model_sparsity(model)

    # Export and import the model. Check that accuracy persists.
    # Only the path is needed: the context manager closes the handle right
    # away (mkstemp would hand back an open descriptor that must be closed
    # manually), and delete=False keeps the file for save_model to overwrite.
    with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
        keras_file = tmp_file.name
    print("Saving model to: ", keras_file)
    save_model(model, keras_file)

    print("Reloading model")
    with prune.prune_scope():
        loaded_model = load_qmodel(keras_file)
    score = loaded_model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
161+
162+
163+
def main():
    """Load MNIST, build a pruned quantized model, then train and save it.

    The images are reshaped to the backend's image data format, scaled by
    1/255 as float32, and the labels are one-hot encoded. Depending on the
    module-level ``prune_whole_model`` flag, either the whole functional
    model is wrapped for pruning or the layer-wise pruned model is built.
    """
    # input image dimensions
    img_rows, img_cols = 28, 28

    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Place the single channel axis where the backend expects it.
    if K.image_data_format() == "channels_first":
        input_shape = (1, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, 1)
    x_train = x_train.reshape((x_train.shape[0],) + input_shape)
    x_test = x_test.reshape((x_test.shape[0],) + input_shape)

    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    schedule = pruning_schedule.ConstantSparsity(
        0.75, begin_step=2000, frequency=100)
    pruning_params = {"pruning_schedule": schedule}

    if prune_whole_model:
        base_model = build_model(input_shape)
        model = prune.prune_low_magnitude(base_model, **pruning_params)
    else:
        model = build_layerwise_model(input_shape, **pruning_params)

    train_and_save(model, x_train, y_train, x_test, y_test)
203+
204+
205+
if __name__ == "__main__":
206+
main()

qkeras/qconvolutional.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@
2828
from tensorflow.keras.layers import Dropout
2929
from tensorflow.keras.layers import InputSpec
3030
from tensorflow.keras.layers import Layer
31+
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
3132

3233
from .qlayers import Clip
3334
from .qlayers import QActivation
3435
from .quantizers import get_quantizer
3536
from .quantizers import get_quantized_initializer
3637

3738

38-
class QConv1D(Conv1D):
39+
class QConv1D(Conv1D, PrunableLayer):
3940
"""1D convolution layer (e.g. spatial convolution over images)."""
4041

4142
# most of these parameters follow the implementation of Conv1D in Keras,
@@ -155,8 +156,11 @@ def get_config(self):
155156
def get_quantizers(self):
156157
return self.quantizers
157158

159+
def get_prunable_weights(self):
160+
return [self.kernel]
158161

159-
class QConv2D(Conv2D):
162+
163+
class QConv2D(Conv2D, PrunableLayer):
160164
"""2D convolution layer (e.g. spatial convolution over images)."""
161165

162166
# most of these parameters follow the implementation of Conv2D in Keras,
@@ -284,8 +288,11 @@ def get_config(self):
284288
def get_quantizers(self):
285289
return self.quantizers
286290

291+
def get_prunable_weights(self):
292+
return [self.kernel]
293+
287294

288-
class QDepthwiseConv2D(DepthwiseConv2D):
295+
class QDepthwiseConv2D(DepthwiseConv2D, PrunableLayer):
289296
"""Creates quantized depthwise conv2d. Copied from mobilenet."""
290297

291298
# most of these parameters follow the implementation of DepthwiseConv2D
@@ -457,6 +464,9 @@ def get_config(self):
457464
def get_quantizers(self):
458465
return self.quantizers
459466

467+
def get_prunable_weights(self):
468+
return []
469+
460470

461471
def QSeparableConv2D(filters, # pylint: disable=invalid-name
462472
kernel_size,

qkeras/qlayers.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from tensorflow.keras.constraints import Constraint
4646
from tensorflow.keras.layers import Dense
4747
from tensorflow.keras.layers import Layer
48+
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
4849

4950

5051
import numpy as np
@@ -60,7 +61,7 @@
6061
#
6162

6263

63-
class QActivation(Layer):
64+
class QActivation(Layer, PrunableLayer):
6465
"""Implements quantized activation layers."""
6566

6667
def __init__(self, activation, **kwargs):
@@ -97,6 +98,9 @@ def get_config(self):
9798
def compute_output_shape(self, input_shape):
9899
return input_shape
99100

101+
def get_prunable_weights(self):
102+
return []
103+
100104

101105
#
102106
# Constraint class to clip weights and bias between -1 and 1 so that:
@@ -149,7 +153,7 @@ def get_config(self):
149153
#
150154

151155

152-
class QDense(Dense):
156+
class QDense(Dense, PrunableLayer):
153157
"""Implements a quantized Dense layer."""
154158

155159
# most of these parameters follow the implementation of Dense in
@@ -284,3 +288,7 @@ def get_config(self):
284288

285289
def get_quantizers(self):
286290
return self.quantizers
291+
292+
def get_prunable_weights(self):
293+
return [self.kernel]
294+

qkeras/qnormalization.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorflow.python.ops import array_ops
3232
from tensorflow.python.ops import math_ops
3333
from tensorflow.python.ops import nn
34+
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer
3435

3536
import numpy as np
3637
import six
@@ -40,7 +41,7 @@
4041
from .safe_eval import safe_eval
4142

4243

43-
class QBatchNormalization(BatchNormalization):
44+
class QBatchNormalization(BatchNormalization, PrunableLayer):
4445
"""Quantized Batch Normalization layer.
4546
For training, mean and variance are not quantized.
4647
For inference, the quantized moving mean and moving variance are used.
@@ -302,3 +303,7 @@ def compute_output_shape(self, input_shape):
302303

303304
def get_quantizers(self):
304305
return self.quantizers
306+
307+
def get_prunable_weights(self):
308+
return []
309+

qkeras/utils.py

+33
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
import six
1919

2020
import tensorflow as tf
21+
import tensorflow.keras.backend as K
2122

2223
from tensorflow.keras import initializers
2324
from tensorflow.keras.models import model_from_json
2425

26+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
27+
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
28+
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
29+
2530
import numpy as np
2631

2732
from .qlayers import QActivation
@@ -472,3 +477,31 @@ def load_qmodel(filepath, custom_objects=None, compile=True):
472477
qmodel = tf.keras.models.load_model(filepath, custom_objects=custom_objects, compile=compile)
473478

474479
return qmodel
480+
481+
482+
def print_model_sparsity(model):
    """Print, per layer, the fraction of zero entries in each prunable weight.

    A layer's prunable weights are taken from, in order of preference: the
    layer wrapped by a ``PruneLowMagnitude`` wrapper, the layer's own
    ``get_prunable_weights()`` when it is a ``PrunableLayer``, or the weight
    names listed by ``PruneRegistry`` when the registry supports the layer.
    Layers with no prunable weights are not printed.

    Args:
      model: model whose ``layers`` are inspected; ``model.name`` is used in
        the header line.
    """

    def _prunable_weights_of(layer):
        # Returns a list of weight variables, or None when the layer is not
        # recognised as prunable by any of the three mechanisms.
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
            return layer.layer.get_prunable_weights()
        if isinstance(layer, prunable_layer.PrunableLayer):
            return layer.get_prunable_weights()
        if prune_registry.PruneRegistry.supports(layer):
            names = prune_registry.PruneRegistry._weight_names(layer)
            return [getattr(layer, name) for name in names]
        return None

    def _zero_fraction(values):
        nonzero = np.count_nonzero(values)
        return 1.0 - nonzero / float(values.size)

    print("Model Sparsity Summary ({})".format(model.name))
    print("--")
    for layer in model.layers:
        weights = _prunable_weights_of(layer)
        if not weights:
            # Covers both "not prunable" (None) and "nothing to prune" ([]).
            continue
        entries = []
        for weight in weights:
            sparsity = _zero_fraction(K.get_value(weight))
            entries.append("({}, {})".format(weight.name, str(sparsity)))
        print("{}: {}".format(layer.name, ", ".join(entries)))
    print("\n")

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pyasn1<0.5.0,>=0.4.6
88
requests<3,>=2.21.0
99
pyparsing
1010
pytest>=4.6.9
11+
tensorflow-model-optimization>=0.2.1

0 commit comments

Comments
 (0)