-
Notifications
You must be signed in to change notification settings - Fork 21
/
findLayersToReplace.m
51 lines (38 loc) · 1.67 KB
/
findLayersToReplace.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
% findLayersToReplace  Find the layers to replace when retraining a network.
%
%   [learnableLayer, classLayer] = findLayersToReplace(lgraph) finds the
%   single classification layer of the layer graph lgraph. It then walks
%   backwards from that layer, one predecessor at a time, until it reaches
%   the nearest fully connected (nnet.cnn.layer.FullyConnectedLayer) or 2-D
%   convolution (nnet.cnn.layer.Convolution2DLayer) layer.
%
%   Inputs:
%     lgraph         - nnet.cnn.LayerGraph object.
%
%   Outputs:
%     learnableLayer - the fully connected / 2-D convolution layer found by
%                      the backward walk.
%     classLayer     - the classification layer (ClassificationOutputLayer
%                      or a custom nnet.layer.ClassificationLayer).
%
%   Errors (raised with error()):
%     - lgraph is not a LayerGraph object.
%     - lgraph does not contain exactly one classification layer.
%     - the backward walk reaches a layer with zero or several predecessors
%       before finding a learnable layer (for example a branch, or the
%       start of the graph).
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)
if ~isa(lgraph,'nnet.cnn.LayerGraph')
    error('Argument must be a LayerGraph object.')
end

% Connection endpoints and layer names as string arrays, so they can be
% compared with == and ismember.
src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

% Find the classification layer. The layer graph must have exactly one.
isClassificationLayer = arrayfun(@(l) ...
    isa(l,'nnet.cnn.layer.ClassificationOutputLayer') || ...
    isa(l,'nnet.layer.ClassificationLayer'), ...
    lgraph.Layers);
if sum(isClassificationLayer) ~= 1
    error('Layer graph must have a single classification layer.')
end
classLayer = lgraph.Layers(isClassificationLayer);

% Traverse the layer graph in reverse, starting from the classification
% layer. If a layer has zero or several predecessors (a branch), throw an
% error.
currentLayerIdx = find(isClassificationLayer);
while true
    if numel(currentLayerIdx) ~= 1
        error('Layer graph must have a single learnable layer preceding the classification layer.')
    end

    currentLayer = lgraph.Layers(currentLayerIdx);

    % Compare whole types, not characters. Calling ismember on a char
    % vector built with ['a','b'] concatenates the names into one char
    % vector and tests per-character membership, which can misclassify
    % unrelated layers as learnable.
    isLearnableLayer = isa(currentLayer,'nnet.cnn.layer.FullyConnectedLayer') || ...
        isa(currentLayer,'nnet.cnn.layer.Convolution2DLayer');
    if isLearnableLayer
        learnableLayer = currentLayer;
        return
    end

    % Predecessors are the sources of every connection whose destination
    % is the current layer. ismember copes with any number of incoming
    % connections; several matches trigger the branch error above on the
    % next iteration.
    incoming = dst == layerNames(currentLayerIdx);
    currentLayerIdx = find(ismember(layerNames, src(incoming)));
end
end