-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgen_feats_israeli.m
59 lines (50 loc) · 1.76 KB
/
gen_feats_israeli.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
function gen_feats_israeli(db_ind)
% GEN_FEATS_ISRAELI  Compute network features for the Israeli trace images
% and write one MAT-file per trace under feats/<dbname>/.
%
%   gen_feats_israeli()        uses db_ind = 2.
%   gen_feats_israeli(db_ind)  passes db_ind to get_db_attrs('israeli', db_ind)
%                              to pick the network and output folder name.
%                              db_ind == 0 uses a fixed 1x1 convolution that
%                              keeps only the first of 3 input channels
%                              instead of loading a model.
%
% Reads:   datasets/israeli/preprocessed_data.mat (trace_ims, trace_treadids)
%          and, for db_ind ~= 0, models/<db_attr{3}>.
% Writes:  feats/<dbname>/israeli_001.mat with db_feats, db_labels,
%          feat_dims, rfsIm, trace_H, trace_W;
%          feats/<dbname>/israeli_NNN.mat (NNN >= 2) with db_feats, db_labels.
%          All files are saved with -v7.3.
%
% Errors if the truncation layer named by db_attr{2} is not found, or if
% feature generation returns no samples.

if nargin < 1
    db_ind = 2;
end
[db_attr, ~, dbname] = get_db_attrs('israeli', db_ind);

% Load into a struct instead of creating variables implicitly in the
% function workspace.
S = load(fullfile('datasets', 'israeli', 'preprocessed_data.mat'), ...
    'trace_ims', 'trace_treadids');
trace_ims = S.trace_ims;
trace_treadids = S.trace_treadids;
% Drop the struct so the original-precision image array can be freed once
% single() below produces its converted copy.
clear S
trace_H = size(trace_ims, 1);
trace_W = size(trace_ims, 2);

% Zero-center: subtract the per-channel mean, averaged over images (dim 4)
% and then over rows and columns.
trace_ims = single(trace_ims);
mean_im = mean(trace_ims, 4);
mean_im_pix = mean(mean(mean_im, 1), 2);
trace_ims = bsxfun(@minus, trace_ims, mean_im_pix);

net = build_feature_net(db_ind, db_attr);

% Run the network over all trace images, feeding them to its first variable.
data = {net.vars(1).name, trace_ims};
all_db_feats = generate_db_CNNfeats(net, data);
% One label per trace, laid out along dimension 4 to match the per-sample
% indexing all_db_feats(:,:,:,i) used below.
all_db_labels = reshape(trace_treadids, 1, 1, 1, []);

% NOTE(review): feat_dims is the size of whatever generate_db_CNNfeats left
% in the last network variable. If it evaluates in batches, a 4th dimension
% here would reflect the final batch size -- confirm against that helper.
feat_dims = size(net.vars(end).value);
rfs = net.getVarReceptiveFields(1);
rfsIm = rfs(end);

n_db = size(all_db_feats, 4);
if n_db == 0
    error('gen_feats_israeli:noFeatures', ...
        'Feature generation returned no samples; nothing to save.');
end

feat_dir = fullfile('feats', dbname);
if ~exist(feat_dir, 'dir')
    % Guarded so re-runs do not emit a "directory already exists" warning.
    mkdir(feat_dir);
end

for i = 1:n_db
    db_feats = all_db_feats(:,:,:, i); %#ok<NASGU> saved by name below
    db_labels = all_db_labels(:,:,:, i); %#ok<NASGU> saved by name below
    fname = fullfile(feat_dir, sprintf('israeli_%03d.mat', i));
    if i == 1
        % Only the first file carries the dataset-wide metadata.
        save(fname, 'db_feats', 'db_labels', 'feat_dims', ...
            'rfsIm', 'trace_H', 'trace_W', '-v7.3');
    else
        save(fname, 'db_feats', 'db_labels', '-v7.3');
    end
end
end

function net = build_feature_net(db_ind, db_attr)
% BUILD_FEATURE_NET  Construct the DagNN used for feature extraction.
%   db_ind == 0: a single bias-free 1x1 conv layer 'identity' mapping 'data'
%   to 'raw' with weights [1 0 0], so its output is the first of 3 input
%   channels.
%   Otherwise: load models/<db_attr{3}> and delete layer db_attr{2} together
%   with every layer after it.
if db_ind == 0
    net = dagnn.DagNN();
    net.addLayer('identity', dagnn.Conv('size', [1 1 3 1], ...
        'stride', 1, ...
        'pad', 0, ...
        'hasBias', false), ...
        {'data'}, {'raw'}, {'I'});
    net.params(1).value = reshape(single([1 0 0]), 1, 1, 3, 1);
else
    flatnn = load(fullfile('models', db_attr{3}));
    net = dagnn.DagNN.loadobj(flatnn);
    ind = net.getLayerIndex(db_attr{2});
    % An unknown layer name yields NaN, and NaN:end is an empty range. The
    % deletion would then be a no-op and the full network would be used
    % silently, so fail loudly instead.
    if isempty(ind) || any(isnan(ind))
        error('gen_feats_israeli:layerNotFound', ...
            'Layer "%s" not found in model %s.', db_attr{2}, db_attr{3});
    end
    net.layers(ind:end) = [];
    net.rebuild();
end
end