-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
85 changed files
with
2,723 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
image/ | ||
train/ | ||
*.sublime* | ||
*.*~ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
before_script: | ||
- sudo apt-add-repository ppa:octave/stable --yes | ||
- sudo apt-get update -y | ||
- sudo apt-get install octave -y | ||
- sudo apt-get install liboctave-dev -y | ||
script: | ||
- sh -c "octave tests/runalltests.m" | ||
|
||
notifications: | ||
email: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function cae = caeapplygrads(cae) | ||
cae.sv = 0; | ||
for j = 1 : numel(cae.a) | ||
for i = 1 : numel(cae.i) | ||
% cae.vik{i}{j} = cae.momentum * cae.vik{i}{j} + cae.alpha ./ (cae.sigma + cae.ddik{i}{j}) .* cae.dik{i}{j}; | ||
% cae.vok{i}{j} = cae.momentum * cae.vok{i}{j} + cae.alpha ./ (cae.sigma + cae.ddok{i}{j}) .* cae.dok{i}{j}; | ||
cae.vik{i}{j} = cae.alpha * cae.dik{i}{j}; | ||
cae.vok{i}{j} = cae.alpha * cae.dok{i}{j}; | ||
cae.sv = cae.sv + sum(cae.vik{i}{j}(:) .^ 2); | ||
cae.sv = cae.sv + sum(cae.vok{i}{j}(:) .^ 2); | ||
|
||
cae.ik{i}{j} = cae.ik{i}{j} - cae.vik{i}{j}; | ||
cae.ok{i}{j} = cae.ok{i}{j} - cae.vok{i}{j}; | ||
end | ||
% cae.vb{j} = cae.momentum * cae.vb{j} + cae.alpha / (cae.sigma + cae.ddb{j}) * cae.db{j}; | ||
cae.vb{j} = cae.alpha * cae.db{j}; | ||
cae.sv = cae.sv + sum(cae.vb{j} .^ 2); | ||
|
||
cae.b{j} = cae.b{j} - cae.vb{j}; | ||
end | ||
|
||
for i = 1 : numel(cae.o) | ||
% cae.vc{i} = cae.momentum * cae.vc{i} + cae.alpha / (cae.sigma + cae.ddc{i}) * cae.dc{i}; | ||
cae.vc{i} = cae.alpha * cae.dc{i}; | ||
cae.sv = cae.sv + sum(cae.vc{i} .^ 2); | ||
|
||
cae.c{i} = cae.c{i} - cae.vc{i}; | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function cae = caebbp(cae) | ||
|
||
%% backprop deltas | ||
for i = 1 : numel(cae.o) | ||
% output delta delta | ||
cae.odd{i} = (cae.o{i} .* (1 - cae.o{i}) .* cae.edgemask) .^ 2; | ||
% delta delta c | ||
cae.ddc{i} = sum(cae.odd{i}(:)) / size(cae.odd{i}, 1); | ||
end | ||
|
||
for j = 1 : numel(cae.a) % calc activation delta deltas | ||
z = 0; | ||
for i = 1 : numel(cae.o) | ||
z = z + convn(cae.odd{i}, flipall(cae.ok{i}{j} .^ 2), 'full'); | ||
end | ||
cae.add{j} = (cae.a{j} .* (1 - cae.a{j})) .^ 2 .* z; | ||
end | ||
|
||
%% calc params delta deltas | ||
ns = size(cae.odd{1}, 1); | ||
for j = 1 : numel(cae.a) | ||
cae.ddb{j} = sum(cae.add{j}(:)) / ns; | ||
for i = 1 : numel(cae.o) | ||
cae.ddok{i}{j} = convn(flipall(cae.a{j} .^ 2), cae.odd{i}, 'valid') / ns; | ||
cae.ddik{i}{j} = convn(cae.add{j}, flipall(cae.i{i} .^ 2), 'valid') / ns; | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
function cae = caebp(cae, y) | ||
|
||
%% backprop deltas | ||
cae.L = 0; | ||
for i = 1 : numel(cae.o) | ||
% error | ||
cae.e{i} = (cae.o{i} - y{i}) .* cae.edgemask; | ||
% loss function | ||
cae.L = cae.L + 1/2 * sum(cae.e{i}(:) .^2 ) / size(cae.e{i}, 1); | ||
% output delta | ||
cae.od{i} = cae.e{i} .* (cae.o{i} .* (1 - cae.o{i})); | ||
|
||
cae.dc{i} = sum(cae.od{i}(:)) / size(cae.e{i}, 1); | ||
end | ||
|
||
for j = 1 : numel(cae.a) % calc activation deltas | ||
z = 0; | ||
for i = 1 : numel(cae.o) | ||
z = z + convn(cae.od{i}, flipall(cae.ok{i}{j}), 'full'); | ||
end | ||
cae.ad{j} = cae.a{j} .* (1 - cae.a{j}) .* z; | ||
end | ||
|
||
%% calc gradients | ||
ns = size(cae.e{1}, 1); | ||
for j = 1 : numel(cae.a) | ||
cae.db{j} = sum(cae.ad{j}(:)) / ns; | ||
for i = 1 : numel(cae.o) | ||
cae.dok{i}{j} = convn(flipall(cae.a{j}), cae.od{i}, 'valid') / ns; | ||
cae.dik{i}{j} = convn(cae.ad{j}, flipall(cae.i{i}), 'valid') / ns; | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function cae = caedown(cae) | ||
pa = cae.a; | ||
pok = cae.ok; | ||
|
||
for i = 1 : numel(cae.o) | ||
z = 0; | ||
for j = 1 : numel(cae.a) | ||
z = z + convn(pa{j}, pok{i}{j}, 'valid'); | ||
end | ||
cae.o{i} = sigm(z + cae.c{i}); | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
%% mnist data | ||
clear all; close all; clc; | ||
load mnist_uint8; | ||
x = cell(100, 1); | ||
N = 600; | ||
for i = 1 : 100 | ||
x{i}{1} = reshape(train_x(((i - 1) * N + 1) : (i) * N, :), N, 28, 28) * 255; | ||
end | ||
%% ex 1 | ||
scae = { | ||
struct('outputmaps', 10, 'inputkernel', [1 5 5], 'outputkernel', [1 5 5], 'scale', [1 2 2], 'sigma', 0.1, 'momentum', 0.9, 'noise', 0) | ||
}; | ||
|
||
opts.rounds = 1000; | ||
opts.batchsize = 1; | ||
opts.alpha = 0.01; | ||
opts.ddinterval = 10; | ||
opts.ddhist = 0.5; | ||
scae = scaesetup(scae, x, opts); | ||
scae = scaetrain(scae, x, opts); | ||
cae = scae{1}; | ||
|
||
%Visualize the average reconstruction error | ||
plot(cae.rL); | ||
|
||
%Visualize the output kernels | ||
ff=[]; | ||
for i=1:numel(cae.ok{1}); | ||
mm = cae.ok{1}{i}(1,:,:); | ||
ff(i,:) = mm(:); | ||
end; | ||
figure;visualize(ff') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
function cae = caenumgradcheck(cae, x, y) | ||
epsilon = 1e-4; | ||
er = 1e-6; | ||
disp('performing numerical gradient checking...') | ||
for i = 1 : numel(cae.o) | ||
p_cae = cae; p_cae.c{i} = p_cae.c{i} + epsilon; | ||
m_cae = cae; m_cae.c{i} = m_cae.c{i} - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
|
||
e = abs(d - cae.dc{i}); | ||
if e > er | ||
disp('OUTPUT BIAS numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dc{i}); | ||
keyboard | ||
end | ||
end | ||
|
||
for a = 1 : numel(cae.a) | ||
|
||
p_cae = cae; p_cae.b{a} = p_cae.b{a} + epsilon; | ||
m_cae = cae; m_cae.b{a} = m_cae.b{a} - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dok{i}{a}(u) = d; | ||
e = abs(d - cae.db{a}); | ||
if e > er | ||
disp('BIAS numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.db{a}); | ||
keyboard | ||
end | ||
|
||
for i = 1 : numel(cae.o) | ||
for u = 1 : numel(cae.ok{i}{a}) | ||
p_cae = cae; p_cae.ok{i}{a}(u) = p_cae.ok{i}{a}(u) + epsilon; | ||
m_cae = cae; m_cae.ok{i}{a}(u) = m_cae.ok{i}{a}(u) - epsilon; | ||
|
||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dok{i}{a}(u) = d; | ||
e = abs(d - cae.dok{i}{a}(u)); | ||
if e > er | ||
disp('OUTPUT KERNEL numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dok{i}{a}(u)); | ||
% keyboard | ||
end | ||
end | ||
end | ||
|
||
for i = 1 : numel(cae.i) | ||
for u = 1 : numel(cae.ik{i}{a}) | ||
p_cae = cae; | ||
m_cae = cae; | ||
p_cae.ik{i}{a}(u) = p_cae.ik{i}{a}(u) + epsilon; | ||
m_cae.ik{i}{a}(u) = m_cae.ik{i}{a}(u) - epsilon; | ||
[m_cae, p_cae] = caerun(m_cae, p_cae, x, y); | ||
d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% cae.dik{i}{a}(u) = d; | ||
e = abs(d - cae.dik{i}{a}(u)); | ||
if e > er | ||
disp('INPUT KERNEL numerical gradient checking failed'); | ||
disp(e); | ||
disp(d / cae.dik{i}{a}(u)); | ||
end | ||
end | ||
end | ||
end | ||
|
||
disp('done') | ||
|
||
end | ||
|
||
function [m_cae, p_cae] = caerun(m_cae, p_cae, x, y) | ||
m_cae = caeup(m_cae, x); m_cae = caedown(m_cae); m_cae = caebp(m_cae, y); | ||
p_cae = caeup(p_cae, x); p_cae = caedown(p_cae); p_cae = caebp(p_cae, y); | ||
end | ||
|
||
%function checknumgrad(cae,what,x,y) | ||
% epsilon = 1e-4; | ||
% er = 1e-9; | ||
% | ||
% for i = 1 : numel(eval(what)) | ||
% if iscell(eval(['cae.' what])) | ||
% checknumgrad(cae,[what '{' num2str(i) '}'], x, y) | ||
% else | ||
% p_cae = cae; | ||
% m_cae = cae; | ||
% eval(['p_cae.' what '(' num2str(i) ')']) = eval([what '(' num2str(i) ')']) + epsilon; | ||
% eval(['m_cae.' what '(' num2str(i) ')']) = eval([what '(' num2str(i) ')']) - epsilon; | ||
% | ||
% m_cae = caeff(m_cae, x); m_cae = caedown(m_cae); m_cae = caebp(m_cae, y); | ||
% p_cae = caeff(p_cae, x); p_cae = caedown(p_cae); p_cae = caebp(p_cae, y); | ||
% | ||
% d = (p_cae.L - m_cae.L) / (2 * epsilon); | ||
% e = abs(d - eval(['cae.d' what '(' num2str(i) ')'])); | ||
% if e > er | ||
% error('numerical gradient checking failed'); | ||
% end | ||
% end | ||
% end | ||
% | ||
% end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
function cae = caesdlm(cae, opts, m) | ||
%stochastic diagonal levenberg-marquardt | ||
|
||
%first round | ||
if isfield(cae,'ddok') == 0 | ||
cae = caebbp(cae); | ||
end | ||
|
||
%recalculate double grads every opts.ddinterval | ||
if mod(m, opts.ddinterval) == 0 | ||
cae_n = caebbp(cae); | ||
|
||
for ii = 1 : numel(cae.o) | ||
cae.ddc{ii} = opts.ddhist * cae.ddc{ii} + (1 - opts.ddhist) * cae_n.ddc{ii}; | ||
end | ||
|
||
for jj = 1 : numel(cae.a) | ||
cae.ddb{jj} = opts.ddhist * cae.ddb{jj} + (1 - opts.ddhist) * cae_n.ddb{jj}; | ||
for ii = 1 : numel(cae.o) | ||
cae.ddok{ii}{jj} = opts.ddhist * cae.ddok{ii}{jj} + (1 - opts.ddhist) * cae_n.ddok{ii}{jj}; | ||
cae.ddik{ii}{jj} = opts.ddhist * cae.ddik{ii}{jj} + (1 - opts.ddhist) * cae_n.ddik{ii}{jj}; | ||
end | ||
end | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
function cae = caetrain(cae, x, opts) | ||
n = cae.inputkernel(1); | ||
cae.rL = []; | ||
for m = 1 : opts.rounds | ||
tic; | ||
disp([num2str(m) '/' num2str(opts.rounds) ' rounds']); | ||
i1 = randi(numel(x)); | ||
l = randi(size(x{i1}{1},1) - opts.batchsize - n + 1); | ||
x1{1} = double(x{i1}{1}(l : l + opts.batchsize - 1, :, :)) / 255; | ||
|
||
if n == 1 %Auto Encoder | ||
x2{1} = x1{1}; | ||
else %Predictive Encoder | ||
x2{1} = double(x{i1}{1}(l + n : l + n + opts.batchsize - 1, :, :)) / 255; | ||
end | ||
% Add noise to input, for denoising stacked autoenoder | ||
x1{1} = x1{1} .* (rand(size(x1{1})) > cae.noise); | ||
|
||
cae = caeup(cae, x1); | ||
cae = caedown(cae); | ||
cae = caebp(cae, x2); | ||
cae = caesdlm(cae, opts, m); | ||
% caenumgradcheck(cae,x1,x2); | ||
cae = caeapplygrads(cae); | ||
|
||
if m == 1 | ||
cae.rL(1) = cae.L; | ||
end | ||
% cae.rL(m + 1) = 0.99 * cae.rL(m) + 0.01 * cae.L; | ||
cae.rL(m + 1) = cae.L; | ||
% if cae.sv < 1e-10 | ||
% disp('Converged'); | ||
% break; | ||
% end | ||
toc; | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function cae = caeup(cae, x) | ||
cae.i = x; | ||
|
||
%init temp vars for parrallel processing | ||
pa = cell(size(cae.a)); | ||
pi = cae.i; | ||
pik = cae.ik; | ||
pb = cae.b; | ||
|
||
for j = 1 : numel(cae.a) | ||
z = 0; | ||
for i = 1 : numel(pi) | ||
z = z + convn(pi{i}, pik{i}{j}, 'full'); | ||
end | ||
pa{j} = sigm(z + pb{j}); | ||
|
||
% Max pool. | ||
if ~isequal(cae.scale, [1 1 1]) | ||
pa{j} = max3d(pa{j}, cae.M); | ||
end | ||
|
||
end | ||
cae.a = pa; | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function X = max3d(X, M) | ||
ll = size(X); | ||
B=X(M); | ||
B=B+rand(size(B))*1e-12; | ||
B=(B.*(B==repmat(max(B,[],2),[1 size(B,2) 1]))); | ||
X(M) = B; | ||
reshape(X,ll); | ||
end |
Oops, something went wrong.