[cherry-pick] add loss info and skd distillation (#1612)
* add skd distillation. (#1587)

* add skd distillation.

* update skd's test.

* [ACT] add loss info (#1597)

* add loss info on ACT training.

* Add flops info.
zzjjay authored Dec 28, 2022
1 parent d521460 commit f68ec4b
Showing 7 changed files with 189 additions and 22 deletions.
17 changes: 15 additions & 2 deletions paddleslim/auto_compression/analysis.py
@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
params_filename,
analysis_file,
pruned_ratios,
target_loss=None):
target_loss=None,
criterion='l1_norm'):
'''
Args:
eval_function(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as its argument and return a score on the test dataset.
model_dir(str): Directory path to load the model from. To load an ONNX model, set ``model_dir`` to the ONNX file path, e.g. ``model_dir='model.onnx'``.
model_filename(str): The name of the model file. When loading an ONNX model, it should be None.
params_filename(str): The name of the parameters file. When loading an ONNX model, it should be None.
analysis_file(str): The file used to save the sensitivities. Newly computed sensitivities are appended to this file, and sensitivities already stored in it are not recomputed. The file can be loaded with the `pickle` library.
pruned_ratios(list): The ratios to be pruned.
criterion(str|function): The criterion used to sort channels for pruning. Currently supports 'l1_norm', 'bn_scale' and 'geometry_median'. Default: 'l1_norm'.
'''

devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
eval_function,
sensitivities_file=analysis_file,
eval_args=[exe, feed_target_names, fetch_targets],
pruned_ratios=pruned_ratios)
pruned_ratios=pruned_ratios,
criterion=criterion)

with open(analysis_file, 'rb') as f:
if sys.version_info < (3, 0):
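A usage sketch for the new `criterion` argument (not part of the commit; the file paths, the callback signature and the keyword layout below are assumptions inferred from the docstring above):

```python
import paddle
from paddleslim.auto_compression.analysis import analysis_prune

paddle.enable_static()

def eval_function(program, exe, feed_names, fetch_targets):
    # Hypothetical callback: evaluate `program` on a validation set with `exe`
    # and return a scalar score (higher is better).
    return 0.0

analysis_prune(
    eval_function,
    model_dir="./inference_model",      # hypothetical model directory
    model_filename="model.pdmodel",     # hypothetical filenames
    params_filename="model.pdiparams",
    analysis_file="./sensitivity.pkl",
    pruned_ratios=[0.1, 0.2, 0.3],
    criterion="bn_scale")               # new argument added by this commit
```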
18 changes: 12 additions & 6 deletions paddleslim/auto_compression/compressor.py
@@ -783,13 +783,17 @@ def _start_train(self, train_program_info, test_program_info, strategy,
total_epochs = train_config.epochs if train_config.epochs else 100
total_train_iter = 0
stop_training = False

loss_vars = [var for var in train_program_info.loss_dict.values()]
loss_names = [name for name in train_program_info.loss_dict.keys()]

for epoch_id in range(total_epochs):
if stop_training:
break
for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \
loss = self._exe.run(train_program_info.program, \
feed=data, \
fetch_list=train_program_info.fetch_targets)
fetch_list=train_program_info.fetch_targets+loss_vars)
if not isinstance(train_program_info.learning_rate, float):
train_program_info.learning_rate.step()
if 'unstructure' in strategy:
@@ -800,10 +804,12 @@ def _start_train(self, train_program_info, test_program_info, strategy,
else:
logging_iter = train_config.logging_iter
if batch_id % int(logging_iter) == 0:
_logger.info(
"Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id,
np_probs_float))
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
print_info += '{}: {} '.format(loss_names[idx],
loss_value)
_logger.info(print_info)
total_train_iter += 1
if total_train_iter % int(
train_config.eval_iter) == 0 and total_train_iter != 0:
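The extended fetch list keeps the original fetch target first and appends one entry per component of `loss_dict`, so `loss[1:]` lines up with `loss_names` (dicts preserve insertion order in Python 3.7+). A stand-alone illustration with made-up numbers:

```python
# Plain floats stand in for the values that exe.run would fetch.
loss_dict = {"soft_label": 0.80, "l2": 0.45}   # hypothetical loss components
loss_names = list(loss_dict.keys())

# fetch_list = fetch_targets + loss_vars -> total loss first, then components.
fetched = [1.25] + list(loss_dict.values())

print("total loss:", fetched[0])                # 1.25
for idx, loss_value in enumerate(fetched[1:]):
    print(loss_names[idx], loss_value)          # soft_label 0.8, then l2 0.45
```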
30 changes: 20 additions & 10 deletions paddleslim/auto_compression/create_compressed_program.py
@@ -24,6 +24,7 @@
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops

_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
@@ -118,7 +119,7 @@ def _parse_distill_loss(distill_node_pair,
distill_lambda=1.0):
"""parse distill loss config"""
loss_dist = 0.0
losses = []
losses = {}
if isinstance(distill_node_pair[0], str):
assert isinstance(distill_loss, str)
assert isinstance(distill_lambda, float)
@@ -128,16 +129,17 @@

assert len(distill_node_pair) == len(distill_loss)
assert len(distill_node_pair) == len(distill_lambda)
for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
tmp_loss = 0.0
_logger.info("train config.distill_node_pair: {}".format(node, loss,
lam))
for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
distill_lambda):
tmp_loss = losses.get(loss_clas, 0.0)
_logger.info("train config.distill_node_pair: {}".format(
node, loss_clas, lam))
assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2):
tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
loss_dist += lam * tmp_loss
losses.append(tmp_loss)
tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
loss_dist += tmp_loss
losses[loss_clas] = tmp_loss

return loss_dist, losses
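For context, a hedged sketch of the inputs and of the new return shape (the node names and weights below are made up); each component loss is now keyed by its loss type so it can be fetched and logged separately:

```python
# Each inner list interleaves (teacher_node, student_node) pairs.
distill_node_pair = [
    ["teacher_fc_0.tmp_0", "fc_0.tmp_0"],            # paired with the 'l2' loss
    ["teacher_softmax_0.tmp_0", "softmax_0.tmp_0"],  # paired with 'soft_label'
]
distill_loss = ["l2", "soft_label"]
distill_lambda = [1.0, 0.7]

# Calling _parse_distill_loss(distill_node_pair, distill_loss, distill_lambda)
# inside build_distill_program would then return:
#   loss_dist -> the weighted sum of all component losses
#   losses    -> {"l2": <Variable>, "soft_label": <Variable>}
```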

@@ -313,7 +315,7 @@ def build_distill_program(executor,
use_dynamic_loss_scaling=True,
**train_config['amp_config'])

distill_loss, losses = _parse_distill_loss(
distill_loss, loss_dict = _parse_distill_loss(
distill_node_pair,
config.get('loss') or 'l2', ### default loss is l2
config.get('alpha') or 1.0) ### default alpha is 1.0
@@ -334,7 +336,7 @@

train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, train_fetch_list,
optimizer, learning_rate)
optimizer, learning_rate, loss_dict)
test_program_info = ProgramInfo(startup_program, test_program,
feed_target_names, fetch_targets)
return train_program_info, test_program_info
@@ -469,6 +471,8 @@ def build_prune_program(executor,
params.append(param.name)
original_shapes[param.name] = param.shape

origin_flops = flops(train_program_info.program)

pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
@@ -485,6 +489,12 @@
param.name, original_shapes[param.name], param.shape))
_logger.info(
"####################channel pruning end##########################")

final_flops = flops(pruned_program)
pruned_flops = abs(origin_flops - final_flops) / origin_flops
_logger.info("FLOPs before pruning: {}".format(origin_flops))
_logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
final_flops, round(pruned_flops * 100, 2)))
train_program_info.program = pruned_program

elif strategy.startswith('asp'):
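As a worked example of the pruned-FLOPs ratio logged above (with hypothetical FLOPs counts):

```python
origin_flops = 4_100_000_000   # hypothetical FLOPs before pruning
final_flops = 2_900_000_000    # hypothetical FLOPs after pruning

pruned_flops = abs(origin_flops - final_flops) / origin_flops
print(round(pruned_flops * 100, 2))  # 29.27 (% of FLOPs removed)
```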
5 changes: 4 additions & 1 deletion paddleslim/auto_compression/strategy_config.py
@@ -431,7 +431,8 @@ def __init__(self,
feed_target_names,
fetch_targets,
optimizer=None,
learning_rate=None):
learning_rate=None,
loss_dict=None):
"""
ProgramInfo Config.
Args:
@@ -441,10 +442,12 @@
fetch_targets(list(Variable)): The fetch variable in the program.
optimizer(Optimizer, optional): Optimizer in training. Default: None.
learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
loss_dict(dict, optional): The components of the total loss, keyed by loss name. Default: None.
"""
self.startup_program = startup_program
self.program = program
self.feed_target_names = feed_target_names
self.fetch_targets = fetch_targets
self.optimizer = optimizer
self.learning_rate = learning_rate
self.loss_dict = loss_dict
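A minimal construction sketch under the extended signature (placeholder programs and names, not from the commit):

```python
import paddle
from paddleslim.auto_compression.strategy_config import ProgramInfo

paddle.enable_static()
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
# ... build the training graph, distillation losses and optimizer here ...

info = ProgramInfo(
    startup_program,
    train_program,
    feed_target_names=["image"],                 # placeholder feed names
    fetch_targets=[],                            # placeholder fetch list
    optimizer=None,
    learning_rate=None,
    loss_dict={"l2": None, "soft_label": None})  # name -> loss Variable in practice
```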
2 changes: 1 addition & 1 deletion paddleslim/dist/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .single_distiller import merge, fsp, l2, soft_label, loss, dkd
from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd
from .dml import DML
58 changes: 56 additions & 2 deletions paddleslim/dist/single_distiller.py
@@ -15,6 +15,7 @@
import numpy as np
import paddle
from paddleslim.core import GraphWrapper
import paddle.nn.functional as F


def merge(teacher_program,
@@ -203,8 +204,11 @@ def soft_label(teacher_var_name,
teacher_var = paddle.nn.functional.softmax(teacher_var /
teacher_temperature)
soft_label_loss = paddle.mean(
paddle.fluid.layers.cross_entropy(
student_var, teacher_var, soft_label=True))
paddle.nn.functional.cross_entropy(
input=student_var,
label=teacher_var,
soft_label=True,
use_softmax=False))
return soft_label_loss


@@ -305,3 +309,53 @@ def dkd(teacher_var_name,
temperature=temperature,
alpha=alpha,
beta=beta)


def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
"""Combine variables from student model and teacher model
by Spherical Knowledge Distillation loss (aka. skd-loss).
Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
Args:
teacher_var_name(str): The name of teacher_var.
student_var_name(str): The name of student_var.
program(Program): The input distiller program. If not specified,
the default program will be used. Default: None
multiplier(float): The multiplier to recover its norm to the original
level. When it's None, the appropriate multiplier can be computed by
teacher's logits with paddle.std(output_t, axis=1). Default: None.
Returns:
Variable: skd distiller loss.
"""
if program == None:
program = paddle.static.default_main_program()

student_var = program.global_block().var(student_var_name)
teacher_var = program.global_block().var(teacher_var_name)
teacher_var.stop_gradient = True

if multiplier is None:
multiplier = paddle.std(teacher_var, axis=1, keepdim=True)

logits_student = F.layer_norm(
student_var,
student_var.shape[1:],
weight=None,
bias=None,
epsilon=1e-7) * multiplier
logits_teacher = F.layer_norm(
teacher_var,
teacher_var.shape[1:],
weight=None,
bias=None,
epsilon=1e-7) * multiplier

student_out = F.softmax(logits_student, axis=1)
teacher_out = F.softmax(logits_teacher, axis=1)
skd_loss = paddle.mean(
F.cross_entropy(
input=student_out,
label=teacher_out,
soft_label=True,
use_softmax=False))
return skd_loss
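For reference, a hedged eager-mode sketch of the same computation (made-up tensor shapes; the static helper above additionally stops the teacher's gradient, mirrored here with `stop_gradient`):

```python
import paddle
import paddle.nn.functional as F

# Hypothetical logits: a batch of 4 samples over 10 classes.
student_logits = paddle.randn([4, 10])
teacher_logits = paddle.randn([4, 10])
teacher_logits.stop_gradient = True

# Per-sample multiplier: standard deviation of the teacher logits.
multiplier = paddle.std(teacher_logits, axis=1, keepdim=True)

# Normalize both logits with layer norm, then restore the magnitude.
s = F.layer_norm(student_logits, student_logits.shape[1:],
                 weight=None, bias=None, epsilon=1e-7) * multiplier
t = F.layer_norm(teacher_logits, teacher_logits.shape[1:],
                 weight=None, bias=None, epsilon=1e-7) * multiplier

# Soft cross-entropy between the resulting distributions.
skd_loss = paddle.mean(
    F.cross_entropy(
        F.softmax(s, axis=1), F.softmax(t, axis=1),
        soft_label=True, use_softmax=False))
print(float(skd_loss))
```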
81 changes: 81 additions & 0 deletions tests/test_skd_loss.py
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
from paddleslim.dist import merge, skd
from layers import conv_bn_layer
from static_case import StaticCase


class TestSKDLoss(StaticCase):
def test_skd_loss(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

student_program = paddle.static.Program()
student_startup = paddle.static.Program()
with paddle.static.program_guard(student_program, student_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
student_predict = conv1 + conv2

teacher_program = paddle.static.Program()
teacher_startup = paddle.static.Program()
with paddle.static.program_guard(teacher_program, teacher_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")

exe.run(teacher_startup)
exe.run(student_startup)

data_name_map = {'image': 'image'}
merge(teacher_program, student_program, data_name_map, place)
merged_ops = []
for block in student_program.blocks:
for op in block.ops:
merged_ops.append(op.type)
with paddle.static.program_guard(student_program, student_startup):
distill_loss = skd('teacher_' + teacher_predict.name,
student_predict.name,
program=None,
multiplier=None)

loss_ops = []
for block in student_program.blocks:
for op in block.ops:
loss_ops.append(op.type)
print(f"ret: {set(loss_ops).difference(set(merged_ops))}")
self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())

self.assertTrue({
'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm'
}.issubset(set(loss_ops).difference(set(merged_ops))))


if __name__ == '__main__':
unittest.main()
