-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
completes exercises1 and 2 and starts working on 3
- Loading branch information
0 parents
commit bb79ee4
Showing
10 changed files
with
1,199 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
function res = edgecnn(x, w, b, dzdy) | ||
% EDGECNN A very simple CNN to detect edges in an image | ||
|
||
pad1 = ([size(w,1) size(w,1) size(w,2) size(w,2)] - 1) / 2 ; | ||
rho3 = 5 ; | ||
pad3 = (rho3 - 1) / 2 ; | ||
|
||
res.x1 = x ; | ||
res.x2 = vl_nnconv(res.x1, w, b, 'pad', pad1) ; | ||
res.x3 = abs(res.x2) ; | ||
res.x4 = vl_nnpool(res.x3, rho3, 'pad', pad3) ; | ||
|
||
if nargin > 3 | ||
res.dzdx4 = dzdy ; | ||
res.dzdx3 = vl_nnpool(res.x3, rho3, res.dzdx4, 'pad', pad3) ; | ||
res.dzdx2 = res.dzdx3 .* sign(res.x2) ; | ||
[res.dzdx1, res.dzdw, res.dzdb] = ... | ||
vl_nnconv(res.x1, w, b, res.dzdx2, 'pad', pad1) ; | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
setup ; | ||
|
||
% ------------------------------------------------------------------------- | ||
% Part 1.1: Linear convolution | ||
% ------------------------------------------------------------------------- | ||
|
||
% Read an example image | ||
x = imread('peppers.png') ; | ||
|
||
% Convert to single format | ||
x = im2single(x) ; | ||
|
||
% Visualize the input x | ||
figure(1) ; clf ; imagesc(x) ; | ||
|
||
% Create a bank of linear filters | ||
w = randn(5,5,3,10,'single') ; | ||
|
||
% Apply the convolutional operator | ||
y = vl_nnconv(x, w, []) ; | ||
|
||
% Visualize the output y | ||
figure(2) ; clf ; vl_imarraysc(y) ; colormap gray ; | ||
|
||
% Try again, downsampling the output | ||
y_ds = vl_nnconv(x, w, [], 'stride', 16) ; | ||
figure(3) ; clf ; vl_imarraysc(y_ds) ; colormap gray ; | ||
|
||
% Try padding | ||
y_pad = vl_nnconv(x, w, [], 'pad', 4) ; | ||
figure(4) ; clf ; vl_imarraysc(y_pad) ; colormap gray ; | ||
|
||
% Manually design a filter | ||
w = [0 1 0 ; | ||
1 -4 1 ; | ||
0 1 0 ] ; | ||
w = single(repmat(w, [1, 1, 3])) ; | ||
y_lap = vl_nnconv(x, w, []) ; | ||
figure(5) ; clf ; colormap gray ; | ||
subplot(1,2,1) ; imagesc(y_lap) ; title('filter output') ; | ||
subplot(1,2,2) ; imagesc(-abs(y_lap)) ; title('- abs(filter output)') ; | ||
|
||
|
||
% ------------------------------------------------------------------------- | ||
% Part 1.2: Non-linear gating (ReLU) | ||
% ------------------------------------------------------------------------- | ||
|
||
w = single(repmat([1 0 -1], [1, 1, 3])) ; | ||
w = cat(4, w, -w) ; | ||
y = vl_nnconv(x, w, []) ; | ||
z = vl_nnrelu(y) ; | ||
|
||
figure(6) ; clf ; colormap gray ; | ||
subplot(1,2,1) ; vl_imarraysc(y) ; | ||
subplot(1,2,2) ; vl_imarraysc(z) ; | ||
|
||
% ------------------------------------------------------------------------- | ||
% Part 1.2: Pooling | ||
% ------------------------------------------------------------------------- | ||
|
||
y = vl_nnpool(x, 15) ; | ||
figure(7) ; clf ; imagesc(y) ; | ||
|
||
% ------------------------------------------------------------------------- | ||
% Part 1.3: Normalization | ||
% ------------------------------------------------------------------------- | ||
|
||
rho = 5 ; | ||
kappa = 0 ; | ||
alpha = 1 ; | ||
beta = 0.5 ; | ||
y_nrm = vl_nnnormalize(x, [rho kappa alpha beta]) ; | ||
figure(8) ; clf ; imagesc(y_nrm) ; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
setup ; | ||
|
||
% ------------------------------------------------------------------------- | ||
% Part 2.1: Linear convolution derivatives | ||
% ------------------------------------------------------------------------- | ||
|
||
% Read an example image | ||
x = im2single(imread('peppers.png')) ; | ||
|
||
% Create a bank of linear filters and apply them to the image | ||
w = randn(5,5,3,10,'single') ; | ||
y = vl_nnconv(x, w, []) ; | ||
|
||
% Create the derivative dz/dy | ||
dzdy = randn(size(y), 'single') ; | ||
|
||
% Back-propagation | ||
[dzdx, dzdw] = vl_nnconv(x, w, [], dzdy) ; | ||
|
||
% Check the derivative numerically | ||
ex = randn(size(x), 'single') ; | ||
eta = 0.0001 ; | ||
xp = x + eta * ex ; | ||
yp = vl_nnconv(xp, w, []) ; | ||
|
||
dzdx_empirical = sum(dzdy(:) .* (yp(:) - y(:)) / eta) ; | ||
dzdx_computed = sum(dzdx(:) .* ex(:)) ; | ||
|
||
fprintf(... | ||
'der: empirical: %f, computed: %f, error: %.2f %%\n', ... | ||
dzdx_empirical, dzdx_computed, ... | ||
abs(1 - dzdx_empirical/dzdx_computed)*100) ; | ||
|
||
% ------------------------------------------------------------------------- | ||
% Part 2.2: Back-propagation | ||
% ------------------------------------------------------------------------- | ||
|
||
% Parameters of the CNN | ||
w1 = randn(5,5,3,10,'single') ; | ||
rho2 = 10 ; | ||
|
||
% Run the CNN forward | ||
x1 = im2single(imread('peppers.png')) ; | ||
x2 = vl_nnconv(x1, w1, []) ; | ||
x3 = vl_nnpool(x2, rho2) ; | ||
|
||
% Create the derivative dz/dx3 | ||
dzdx3 = randn(size(x3), 'single') ; | ||
|
||
% Run the CNN backward | ||
dzdx2 = vl_nnpool(x2, rho2, dzdx3) ; | ||
[dzdx1, dzdw1] = vl_nnconv(x1, w1, [], dzdx2) ; | ||
|
||
% Check the derivative numerically | ||
ew1 = randn(size(w1), 'single') ; | ||
eta = 0.0001 ; | ||
w1p = w1 + eta * ew1 ; | ||
|
||
x1p = x1 ; | ||
x2p = vl_nnconv(x1p, w1p, []) ; | ||
x3p = vl_nnpool(x2p, rho2) ; | ||
|
||
dzdw1_empirical = sum(dzdx3(:) .* (x3p(:) - x3(:)) / eta) ; | ||
dzdw1_computed = sum(dzdw1(:) .* ew1(:)) ; | ||
|
||
fprintf(... | ||
'der: empirical: %f, computed: %f, error: %.2f %%\n', ... | ||
dzdw1_empirical, dzdw1_computed, ... | ||
abs(1 - dzdw1_empirical/dzdw1_computed)*100) ; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
% ------------------------------------------------------------------------- | ||
% Part 3: Learning a simple CNN | ||
% ------------------------------------------------------------------------- | ||
|
||
setup ; | ||
|
||
% Load an image and compute its edge map | ||
im = im2single(imread('peppers.png')) ; | ||
[edges,fill] = extractImageEdges(im) ; | ||
figure(1) ; clf ; colormap gray ; | ||
subplot(1,3,1) ; imagesc(im) ; axis equal ; | ||
subplot(1,3,2) ; imagesc(edges) ; axis equal ; | ||
subplot(1,3,3) ; imagesc(fill) ; axis equal ; | ||
|
||
% Initialize the network | ||
w = randn(3, 3, 3, 'single') ; | ||
b = single(1) ; | ||
dzdw_momentum = zeros('like', w) ; | ||
dzdb_momentum = zeros('like', b) ; | ||
|
||
% SGD parameters | ||
T = 500 ; | ||
eta = 0.005 ; | ||
momentum = 0.9 ; | ||
energy = zeros(1, T) ; | ||
y = zeros(size(edges),'single') ; | ||
y(edges) = +1 ; | ||
y(fill) = -1 ; | ||
|
||
% SGD with momentum | ||
for t = 1:T | ||
% Forward pass | ||
res = edgecnn(im, w, b) ; | ||
|
||
% Loss | ||
z = y .* res.x4 ; | ||
E(t) = mean(max(0, 1 - z(:))) ; | ||
dzdx4 = - y .* (z < 1) / numel(z) ; | ||
|
||
% Backward pass | ||
res = edgecnn(im, w, b, dzdx4) ; | ||
|
||
% Gradient step | ||
dzdw_momentum = momentum * dzdw_momentum + res.dzdw ; | ||
dzdb_momentum = momentum * dzdb_momentum + res.dzdb ; | ||
w = w - eta * dzdw_momentum ; | ||
b = b - eta * dzdb_momentum ; | ||
|
||
% Plots | ||
if mod(t-1, 20) == 0 | ||
figure(2) ; clf ; colormap gray ; | ||
subplot(2,2,1) ; imagesc(res.x4) ; axis equal ; title('network output') ; | ||
subplot(2,2,2) ; plot(1:t, E(1:t)) ; grid on ; title('objective') ; | ||
subplot(2,2,3) ; vl_imarraysc(w) ; title('filter slices') ; axis equal ; | ||
subplot(2,2,4) ; imagesc(res.x2) ; title('first layer output') ; axis equal ; | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
name = practical-cnn-2015a | ||
DST = [email protected]:WWW/share | ||
DSTDOC = [email protected]:WWW/practicals/cnn | ||
|
||
.PHONY: prepack, pack, pack-data, pack-code, post, clean, distclean | ||
|
||
pack-all: pack-data pack-code pack | ||
|
||
code=\ | ||
exercise1.m \ | ||
exercise2.m \ | ||
exercise3.m | ||
README.md \ | ||
vlfeat \ | ||
matconvnet | ||
|
||
doc=\ | ||
doc/images \ | ||
doc/instructions.html | ||
|
||
data=\ | ||
data/mnist.mat | ||
|
||
code:=$(addprefix $(CURDIR)/,$(code)) | ||
data:=$(addprefix $(CURDIR)/,$(data)) | ||
deps:=$(shell find $(code) $(doc) $(data) -type f | sed "s/ /\\\\ /g") | ||
|
||
pack: data/$(name).tar.gz | ||
pack-data: data/$(name)-data-only.tar.gz | ||
pack-code: data/$(name)-code-only.tar.gz | ||
|
||
data/$(name).tar.gz: $(deps) | ||
rm -rf data/$(name) | ||
mkdir -p data/$(name)/{doc,data} | ||
ln -sf $(data) data/$(name)/data/ | ||
ln -sf $(doc) data/$(name)/doc/ | ||
ln -sf $(code) data/$(name)/ | ||
tar -C data -czvhf data/$(name).tar.gz $(name)/ | ||
|
||
data/$(name)-data-only.tar.gz: $(deps) | ||
rm -rf data/$(name) | ||
mkdir -p data/$(name)/{doc,data} | ||
ln -sf $(data) data/$(name)/ | ||
tar -C data -czvhf data/$(name)-data-only.tar.gz $(name)/ | ||
|
||
data/$(name)-code-only.tar.gz: $(deps) | ||
rm -rf data/$(name) | ||
mkdir -p data/$(name)/{doc,data} | ||
ln -sf $(doc) data/$(name)/doc/ | ||
ln -sf $(code) data/$(name)/ | ||
tar -C data -czvhf data/$(name)-code-only.tar.gz $(name)/ | ||
|
||
post-doc: | ||
rsync -rvt doc/images $(DSTDOC)/ | ||
rsync -vt doc/instructions.html $(DSTDOC)/index.html | ||
|
||
post: pack-all | ||
rsync -vt data/$(name).tar.gz $(DST)/ | ||
rsync -vt data/$(name)-data-only.tar.gz $(DST)/ | ||
rsync -vt data/$(name)-code-only.tar.gz $(DST)/ | ||
|
||
clean: | ||
find . -name '*~' -delete | ||
|
||
distclean: clean | ||
rm -f data/$(name)*.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
% PREPARELABDATA | ||
|
||
% -------------------------------------------------------------------- | ||
% Download VLFeat | ||
% -------------------------------------------------------------------- | ||
|
||
if ~exist('vlfeat', 'dir') | ||
from = 'http://www.vlfeat.org/download/vlfeat-0.9.18-bin.tar.gz' ; | ||
fprintf('Downloading vlfeat from %s\n', from) ; | ||
untar(from, 'data') ; | ||
movefile('data/vlfeat-0.9.18', 'vlfeat') ; | ||
end | ||
|
||
setup ; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
function [edges, fill] = extractImageEdges(im) | ||
edges = edge(rgb2gray(im),'canny') ; | ||
fill = ~imdilate(edges, strel('disk', 5, 0)) ; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
|
||
run matconvnet/matlab/vl_setupnn ; | ||
try | ||
vl_nnconv(single(1),single(1),[]) ; | ||
catch | ||
warning('VL_NNCONV() does not seem to be compiled. Trying to compile now.') ; | ||
vl_compilenn2('enableGpu', false, 'verbose', true) ; | ||
end | ||
|
||
run vlfeat/toolbox/vl_setup ; |