Skip to content

Commit

Permalink
Added example script to train sketch-based image retrieval and deep s…
Browse files Browse the repository at this point in the history
…hape matching network presented in ECCV18 paper: Deep Shape Matching.
  • Loading branch information
filipradenovic committed Oct 9, 2018
1 parent db11c57 commit 0db13cf
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
4 changes: 2 additions & 2 deletions cnntrain/train_network.m
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
end

for epoch=start+1:opts.numEpochs

% Set the random seed based on the epoch and opts.randomSeed.
% This is important for reproducibility, including when training
% is restarted from a checkpoint.
Expand Down Expand Up @@ -137,7 +137,7 @@

% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end

end

% -------------------------------------------------------------------------
Expand Down
1 change: 0 additions & 1 deletion examples/test_cnnsketch2imageretrieval.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
% Choose ECCV18 fine-tuned CNN network
network_file = fullfile(data_root, 'networks', 'retrieval-SfM-30k', 'retrievalSfM30k-edgemac-vgg.mat');

%% COMING SOON
% After running the training script train_cnnsketch2imageretrieval.m you can evaluate fine-tuned network
% network_file = fullfile(data_root, 'networks', 'exp', 'vgg_edgefilter_mac_test', 'net-epoch-20');

Expand Down
60 changes: 60 additions & 0 deletions examples/train_cnnsketch2imageretrieval.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
% TRAIN_CNNSKETCH2IMAGERETRIEVAL Code to train the methods presented in the paper:
% F. Radenovic, G. Tolias, O. Chum, Deep Shape Matching, ECCV 2018
%
% Note: The method has been re-coded since our ECCV 2018 paper and minor differences in performance might appear.
%
% Authors: F. Radenovic, G. Tolias, O. Chum. 2018.

clear;

%-------------------------------------------------------------------------------
% Set data folder
%-------------------------------------------------------------------------------

% Set data folder, change if you have downloaded the data somewhere else
data_root = fullfile(get_root_cnnimageretrieval(), 'data');
% Check, and, if necessary, download train data (db with edgemaps), and pre-trained imagenet networks
download_train_sketch(data_root);


%-------------------------------------------------------------------------------
% Reproduce training from ECCV18 paper: Deep Shape Matching ...
%-------------------------------------------------------------------------------

% Set architecture and initialization parameters
opts.init.model = 'VGG'; % (ALEX | VGG | GOOGLENET | RESNET101)
opts.init.modelDir = fullfile(data_root, 'networks', 'imagenet');
opts.init.method = 'edgefilter_mac';
opts.init.objectiveType = {'contrastiveloss', 0.7};
opts.init.errorType = {'batchmap'};
opts.init.averageImageScale = 0;
opts.init.imageChannels = 1;

% Set train parameters
opts.train.dbPath = fullfile(data_root, 'train', 'dbs', 'retrieval-SfM-30k-edgemap.mat');
opts.train.batchSize = 20;
opts.train.numSubBatches = 4;
opts.train.numEpochs = 20;
opts.train.learningRate = 0.001 .* exp(-(0:99)*0.1);
opts.train.numNegative = 5;
opts.train.numRemine = 3;
opts.train.gpus = [1];

opts.train.augment.jitterFlip = true;
opts.train.augment.jitterQueryBinarize = true;

% Trial name (to name a save directory)
trialName = 'test';

% Export directory expDir named after model, method and trialName
opts.init.method = [opts.init.method, '_', trialName];
opts.train.expDir = fullfile(data_root, 'networks', 'exp', [lower(opts.init.model) '_' lower(opts.init.method)]);
if ~exist(opts.train.expDir); mkdir(opts.train.expDir); end % create folder if its not there

% Load opts by respecting added opts
opts = load_opts_train(opts);

% Initialize and train the network
fprintf('>> Experiment folder is set to %s\n', opts.train.expDir);
net = init_network(opts.init);
[net, state, stats] = train_network(net, @(o,i,n,b,s,m,e) get_batch(o,i,n,b,s,m,e), opts.train);

0 comments on commit 0db13cf

Please sign in to comment.