-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
getCNN_alex.m
76 lines (60 loc) · 2.32 KB
/
getCNN_alex.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function layers=getCNN_alex(param)
%GETCNN_ALEX Build a regression CNN on top of AlexNet's pretrained layers.
%   LAYERS = GETCNN_ALEX(PARAM) returns a layer array consisting of:
%     1. an image input layer of size PARAM.szIn ('inputlayer');
%     2. 'conv1': a re-created copy of AlexNet's first convolution with ONE
%        extra input channel appended. The pretrained filters fill the
%        original channels, and the new channel is set to the constant 1/11.
%        Every bias is initialised to 1; the pretrained bias is not reused.
%     3. AlexNet's net.Layers(3:16), reused unchanged;
%     4. a new fully connected head, fc6 (4096) -> fc7 (2048) -> fc8 (1024).
%        Each of these is followed by ReLU and dropout(0.5). Then comes
%        fc9 with szOut(1)*szOut(2) outputs;
%     5. a regression output layer ('reg').
%
%   PARAM fields read by this function:
%     szIn                  - input size for imageInputLayer.
%                             NOTE(review): its channel count presumably
%                             must equal AlexNet's conv1 input channels + 1
%                             (i.e. 4) to match the modified conv1 -- confirm.
%     DataAugmentation      - forwarded to imageInputLayer.
%     Normalization         - forwarded to imageInputLayer.
%     WeightL2Factor        - weight L2 factor for conv1 and fc6..fc9.
%     WeightLearnRateFactor - weight learn-rate factor for fc6..fc9;
%                             conv1 uses 1/20 of this value.
%     BiasLearnRateFactor   - bias learn-rate factor for fc6..fc9;
%                             conv1 uses 1/20 of this value.
%     szOut                 - output size; only szOut(1) and szOut(2) are used.
%
%   Requires the Deep Learning Toolbox and the AlexNet support package.

%% input layer
inputlayer = imageInputLayer(param.szIn,...
    'DataAugmentation',param.DataAugmentation,...
    'Normalization',param.Normalization,...
    'Name','inputlayer');

%% pretrained AlexNet layers 2:16
net = alexnet;
layersTransfer = net.Layers(2:16);
clear net; % keep only the slice; drop the rest of the pretrained network

%% conv1 (modified: one extra input channel)
conv1Orig = layersTransfer(1);
W0 = conv1Orig.Weights;
nInOrig = size(W0,3);
% Build the widened weight tensor. 'like' keeps the numeric class of the
% pretrained weights instead of silently switching to double.
Weights = zeros([size(W0,1),size(W0,2),nInOrig+1,size(W0,4)],'like',W0);
Weights(:,:,1:nInOrig,:) = W0;    % pretrained filters for original channels
Weights(:,:,nInOrig+1,:) = 1/11;  % constant init for the added channel
%Weights(:,:,nInOrig+1,:) = randn(size(Weights(:,:,nInOrig+1,:)));

% conv1 is trained with learn-rate factors 20x smaller than the new head.
conv1 = convolution2dLayer([size(Weights,1),size(Weights,2)],size(Weights,4),...
    'Stride',4,'Padding',0,...
    'WeightL2Factor',param.WeightL2Factor,'Name','conv1',...
    'WeightLearnRateFactor',param.WeightLearnRateFactor/20,...
    'BiasLearnRateFactor',param.BiasLearnRateFactor/20);
conv1.Weights = Weights;
% The pretrained bias is intentionally not reused; all biases start at 1.
conv1.Bias = ones(size(conv1Orig.Bias),'like',conv1Orig.Bias); %=conv1Orig.Bias;

% Remaining transferred layers (everything after the original conv1).
layersTransfer = layersTransfer(2:end);

%% fully connected layers
% All new FC layers share the same regularisation / learn-rate settings.
makeFC = @(numOut,name) fullyConnectedLayer(numOut,...
    'WeightL2Factor',param.WeightL2Factor,...
    'WeightLearnRateFactor',param.WeightLearnRateFactor,...
    'BiasLearnRateFactor',param.BiasLearnRateFactor,...
    'Name',name);

% fc6
fc6   = makeFC(4096,'fc6');
relu6 = reluLayer('Name','relu6');
drop6 = dropoutLayer(0.5,'Name','drop6');
% fc7
fc7   = makeFC(2048,'fc7');
relu7 = reluLayer('Name','relu7');
drop7 = dropoutLayer(0.5,'Name','drop7');
% fc8
fc8   = makeFC(1024,'fc8');
relu8 = reluLayer('Name','relu8');
drop8 = dropoutLayer(0.5,'Name','drop8');
% fc9: one output per element of the first two dimensions of szOut
outSize = param.szOut(1)*param.szOut(2);
fc9   = makeFC(outSize,'fc9');

fc = [fc6;relu6;drop6;fc7;relu7;drop7;fc8;relu8;drop8;fc9];

%% regression layer
reg = regressionLayer('Name','reg');

% Terminating semicolon: avoid echoing the whole layer array on every call.
layers = [inputlayer;conv1;layersTransfer;fc;reg];