-
Notifications
You must be signed in to change notification settings - Fork 3
/
assembleGoogLeNet.m
67 lines (55 loc) · 2.03 KB
/
assembleGoogLeNet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
function net = assembleGoogLeNet()
% assembleGoogLeNet Assemble GoogLeNet network
%
% net = assembleGoogLeNet creates a GoogLeNet network with weights
% trained on ImageNet. You can load the same GoogLeNet network by
% installing the Deep Learning Toolbox Model for GoogLeNet Network
% support package from the Add-On Explorer and then using the googlenet
% function.
%
% The pretrained parameters are downloaded once to the file
% "googlenetParams.mat" in a folder "GoogLeNet" under the system's
% temporary directory (tempdir). Later calls reuse that file instead of
% downloading it again.
%
% Output:
%   net - the network returned by assembleNetwork, built from the layer
%         graph produced by googlenetLayers with the stored parameters
%         copied into its layers.

% Copyright 2019 The MathWorks, Inc.

% Make sure the parameter file is available locally (download if needed).
paramFile = iEnsureParamFile(fullfile(tempdir, "GoogLeNet"));

% Load the network parameters. The file must contain a variable "params":
% a struct with one field per layer, named after the layer with '-'
% replaced by '_'.
s = load(paramFile);
params = s.params;

% Create a layer graph with the network architecture of GoogLeNet and
% copy the stored parameters into its layers.
lgraph = googlenetLayers;
lgraph = iAssignParameters(lgraph, params);

% Assemble the network.
net = assembleNetwork(lgraph);
end

function paramFile = iEnsureParamFile(dataDir)
% iEnsureParamFile Return the path of googlenetParams.mat inside dataDir,
% downloading it first if it does not exist yet.
paramFile = fullfile(dataDir, "googlenetParams.mat");
downloadUrl = "https://www.mathworks.com/supportfiles/nnet/data/networks/googlenetParams.mat";

if isfile(paramFile)
    disp("Skipping download, parameter file already exists.");
    return
end

if ~isfolder(dataDir)
    mkdir(dataDir);
end

disp("Downloading pretrained parameters file (26 MB).")
disp("This may take several minutes...");

% Download under a temporary name and move it into place only after
% websave succeeds. An interrupted download therefore never leaves a
% truncated googlenetParams.mat that later calls would treat as complete.
partialFile = paramFile + ".part";
try
    savedFile = websave(partialFile, downloadUrl);
catch err
    if isfile(partialFile)
        delete(partialFile);
    end
    rethrow(err);
end
movefile(savedFile, paramFile);
disp("Download finished.");
end

function lgraph = iAssignParameters(lgraph, params)
% iAssignParameters Copy stored parameters into the matching layers.
%
% For each layer, the field of params named after the layer (with '-'
% replaced by '_' so it is a valid struct field name) is a struct whose
% field names are layer property names. Each of those properties is set
% on the layer, and the layer is put back into the graph. Layers with no
% entry, or an empty entry, are left unchanged.

% Take a snapshot of the layers. Each one is modified and then replaced in
% the graph by name, so the snapshot order does not need to track the graph.
layers = lgraph.Layers;
for k = 1:numel(layers)
    layer = layers(k);
    fieldName = replace(layer.Name, '-', '_');
    if ~isfield(params, fieldName) || isempty(params.(fieldName))
        continue
    end

    layerParams = params.(fieldName);
    propNames = fieldnames(layerParams);
    for j = 1:numel(propNames)
        layer.(propNames{j}) = layerParams.(propNames{j});
    end

    % Add the updated layer back into the layer graph.
    lgraph = replaceLayer(lgraph, layer.Name, layer);
end
end